Transformers Source Code Walkthrough (66)

.\models\longformer\modeling_tf_longformer.py

# Standard library: warnings for deprecation notices
import warnings
# dataclass decorator used to define the output classes below
from dataclasses import dataclass
# Type hints
from typing import Optional, Tuple, Union

# NumPy
import numpy as np
# TensorFlow
import tensorflow as tf

# Library-internal helpers: activations and TF modeling utilities
from ...activations_tf import get_tf_activation
from ...modeling_tf_utils import (
    TFMaskedLanguageModelingLoss,
    TFModelInputType,
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
# TF utility functions
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
# Generic utilities: output base class, docstring helpers and logging
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
# Longformer model configuration
from .configuration_longformer import LongformerConfig

# Module-level logger
logger = logging.get_logger(__name__)

# Checkpoint and config names used in the docstrings
_CHECKPOINT_FOR_DOC = "allenai/longformer-base-4096"
_CONFIG_FOR_DOC = "LongformerConfig"

# Large negative constant used to mask out positions before a softmax
LARGE_NEGATIVE = -1e8

# Archive list of pretrained Longformer checkpoints
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "allenai/longformer-base-4096",
    "allenai/longformer-large-4096",
    "allenai/longformer-large-4096-finetuned-triviaqa",
    "allenai/longformer-base-4096-extra.pos.embd.only",
    "allenai/longformer-large-4096-extra.pos.embd.only",
    # See all Longformer models at https://huggingface.co/models?filter=longformer
]

@dataclass
class TFLongformerBaseModelOutput(ModelOutput):
    """
    Base class for Longformer's outputs, with potential hidden states, local and global attentions.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining `attention_window
            + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
            remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
            token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
            (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
            If the attention window contains a token with global attention, the attention weight at the corresponding
            index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
            attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be
            accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    last_hidden_state: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None
    global_attentions: Tuple[tf.Tensor, ...] | None = None


# TFLongformerBaseModelOutputWithPooling additionally pools the last hidden states.
@dataclass
class TFLongformerBaseModelOutputWithPooling(ModelOutput):
    """
    Base class for Longformer's outputs that also contains a pooling of the last hidden states.
    """

    # Sequence of hidden states at the output of the last layer
    last_hidden_state: tf.Tensor = None
    # Pooled output, derived from the first token's hidden state
    pooler_output: tf.Tensor = None
    # Hidden states of every layer, if requested
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    # Local attention weights of every layer, if requested
    attentions: Tuple[tf.Tensor, ...] | None = None
    # Global attention weights of every layer, if requested
    global_attentions: Tuple[tf.Tensor, ...] | None = None


# TFLongformerMaskedLMOutput stores the outputs of the masked language model.
@dataclass
class TFLongformerMaskedLMOutput(ModelOutput):
    """
    Base class for masked language models outputs.
    """
    # Masked language modeling loss of shape `(1,)`, returned when `labels` is provided
    loss: tf.Tensor | None = None
    # Prediction scores of the language modeling head (before SoftMax)
    logits: tf.Tensor = None
    # Hidden states of every layer, if requested
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    # Local attention weights of every layer, if requested
    attentions: Tuple[tf.Tensor, ...] | None = None
    # Global attention weights of every layer, if requested
    global_attentions: Tuple[tf.Tensor, ...] | None = None


# TFLongformerQuestionAnsweringModelOutput stores the outputs of question-answering Longformer models.
@dataclass
class TFLongformerQuestionAnsweringModelOutput(ModelOutput):
    """
    Base class for outputs of question answering Longformer models.
    问题回答 Longformer 模型输出的基类。
    """

    # Total span-extraction loss (sum of the cross-entropy for start and end positions), or None
    loss: tf.Tensor | None = None

    # Span-start scores (before SoftMax)
    start_logits: tf.Tensor = None

    # Span-end scores (before SoftMax)
    end_logits: tf.Tensor = None

    # Hidden states of every layer, if requested
    hidden_states: Tuple[tf.Tensor, ...] | None = None

    # Local attention weights of every layer, if requested
    attentions: Tuple[tf.Tensor, ...] | None = None

    # Global attention weights of every layer, if requested
    global_attentions: Tuple[tf.Tensor, ...] | None = None


# TFLongformerSequenceClassifierOutput stores the outputs of sequence classification models.
@dataclass
class TFLongformerSequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sentence classification models.
    句子分类模型输出的基类。
    """
    # Classification (or regression, if config.num_labels==1) loss, returned when `labels` is provided
    loss: tf.Tensor | None = None
    # Classification scores (before SoftMax)
    logits: tf.Tensor = None
    # Hidden states of every layer, if requested
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    # Local attention weights of every layer, if requested
    attentions: Tuple[tf.Tensor, ...] | None = None
    # Global attention weights of every layer, if requested
    global_attentions: Tuple[tf.Tensor, ...] | None = None


@dataclass
class TFLongformerMultipleChoiceModelOutput(ModelOutput):
    """
    Base class for outputs of multiple choice models.

    Args:
        loss (`tf.Tensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`tf.Tensor` of shape `(batch_size, num_choices)`):
            *num_choices* is the second dimension of the input tensors. (See *input_ids* above.)

            Classification scores (before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining `attention_window
            + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
            remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
            token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
            (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
            If the attention window contains a token with global attention, the attention weight at the corresponding
            index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
            attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be
            accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None
    global_attentions: Tuple[tf.Tensor, ...] | None = None


@dataclass
class TFLongformerTokenClassifierOutput(ModelOutput):
    """
    Base class for outputs of token classification models.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining `attention_window
            + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
            remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
            token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
            (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
            If the attention window contains a token with global attention, the attention weight at the corresponding
            index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
            attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be
            accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None
    global_attentions: Tuple[tf.Tensor, ...] | None = None


def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_sep_token=True):
    """
    Computes the global attention mask: attention is put on all tokens before the separator token if
    `before_sep_token` is True, otherwise on all tokens after it.
    """
    # `sep_token_indices` pairs (batch index, position), so its second dimension must be 2
    assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions"
    # Position of the first separator per sequence (QA inputs contain 3 separator tokens)
    question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1][:, None]
    # Attention mask with ones at the locations of global attention
    attention_mask = tf.expand_dims(tf.range(input_ids_shape[1], dtype=tf.int64), axis=0)
    attention_mask = tf.tile(attention_mask, (input_ids_shape[0], 1))
    if before_sep_token is True:
        # Mark every position strictly before the first separator
        question_end_index = tf.tile(question_end_index, (1, input_ids_shape[1]))
        attention_mask = tf.cast(attention_mask < question_end_index, dtype=question_end_index.dtype)
    else:
        # The last token is a separator and should not be counted; there are two separator tokens in the middle
        question_end_index = tf.tile(question_end_index + 1, (1, input_ids_shape[1]))
        attention_mask = tf.cast(attention_mask > question_end_index, dtype=question_end_index.dtype) * tf.cast(
            attention_mask < input_ids_shape[-1], dtype=question_end_index.dtype
        )
    return attention_mask
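
# A minimal usage sketch (not part of the original file; the token ids are
# illustrative, with 2 assumed as the </s>/separator id of Longformer's
# RoBERTa-style tokenizer):
#
#     input_ids = tf.constant([[0, 100, 200, 2, 2, 300, 400, 2]], dtype=tf.int64)
#     sep_token_indices = tf.where(input_ids == 2)  # three (batch, position) pairs per row
#     mask = _compute_global_attention_mask(shape_list(input_ids), sep_token_indices)
#     # mask[0] == [1, 1, 1, 0, 0, 0, 0, 0]: global attention on every token
#     # before the first separator (the "question" segment of a QA input)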


# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead and adapted for Longformer
class TFLongformerLMHead(keras.layers.Layer):
    """Longformer Head for masked language modeling."""

    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.hidden_size = config.hidden_size
        # Dense transform applied before the vocabulary projection
        self.dense = keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # Layer normalization
        self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        # GELU activation
        self.act = get_tf_activation("gelu")

        # The output weights are tied to the input embeddings, but each token gets its own output bias
        self.decoder = input_embeddings

    def build(self, input_shape=None):
        # Output bias, one value per vocabulary entry
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")

        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "layer_norm", None) is not None:
            with tf.name_scope(self.layer_norm.name):
                self.layer_norm.build([None, None, self.config.hidden_size])

    def get_output_embeddings(self):
        # The decoder (tied input embeddings) acts as the output embeddings
        return self.decoder
    # Set the output embeddings
    def set_output_embeddings(self, value):
        # Point the decoder's weight at the given value
        self.decoder.weight = value
        # Update the decoder's vocab size from the first dimension of the new weights
        self.decoder.vocab_size = shape_list(value)[0]

    # Return the bias term
    def get_bias(self):
        return {"bias": self.bias}

    # Set the bias term
    def set_bias(self, value):
        self.bias = value["bias"]
        # Keep config.vocab_size consistent with the new bias
        self.config.vocab_size = shape_list(value["bias"])[0]
    
    # Forward pass of the LM head
    def call(self, hidden_states):
        # Dense transform
        hidden_states = self.dense(hidden_states)
        # GELU activation
        hidden_states = self.act(hidden_states)
        # Layer normalization
        hidden_states = self.layer_norm(hidden_states)

        # Project back to the size of the vocabulary and add the per-token bias
        seq_length = shape_list(tensor=hidden_states)[1]
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
        hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True)
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
        hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)

        return hidden_states
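
    # Shape sketch (illustrative, Longformer-base values): with hidden states of
    # shape (batch, seq_len, hidden_size) = (2, 10, 768) and vocab_size = 50265,
    # the projection above maps
    #   (2, 10, 768) -> (20, 768) -> (20, 50265) -> (2, 10, 50265),
    # and `bias_add` then adds one learned bias per vocabulary entry.
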
class TFLongformerEmbeddings(keras.layers.Layer):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing and some extra casting.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        # Padding token index (1, as in RoBERTa)
        self.padding_idx = 1
        self.config = config
        self.hidden_size = config.hidden_size
        self.max_position_embeddings = config.max_position_embeddings
        self.initializer_range = config.initializer_range
        # Layer normalization
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # Dropout
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def build(self, input_shape=None):
        with tf.name_scope("word_embeddings"):
            # 添加词嵌入权重
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("token_type_embeddings"):
            # 添加token类型嵌入
            self.token_type_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.config.type_vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("position_embeddings"):
            # 添加位置嵌入
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        if self.built:
            return
        self.built = True
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                # Build the LayerNorm sublayer
                self.LayerNorm.build([None, None, self.config.hidden_size])

    def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
        """
        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
        symbols are ignored. This is modified from fairseq's `utils.make_positions`.

        Args:
            input_ids: tf.Tensor
        Returns: tf.Tensor
        """
        # Mask marking the positions of non-padding symbols
        mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
        # Cumulative index over non-padding positions, offset by the past key/value length
        incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask

        return incremental_indices + self.padding_idx
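
    # Worked example (sketch, assuming padding_idx == 1 as set in __init__):
    #   input_ids             = [[5, 7, 9, 1, 1]]
    #   mask                  = [[1, 1, 1, 0, 0]]
    #   cumsum(mask) * mask   = [[1, 2, 3, 0, 0]]
    #   returned position ids = [[2, 3, 4, 1, 1]]
    # Real tokens count up from padding_idx + 1 while padding keeps padding_idx.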

    def call(
        self,
        input_ids=None,
        position_ids=None,
        token_type_ids=None,
        inputs_embeds=None,
        past_key_values_length=0,
        training=False,
    ):
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        # At least one of `input_ids` and `inputs_embeds` must be provided
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            # Check that `input_ids` are within the vocabulary
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            # Gather the embedding vectors for `input_ids` from the weight matrix
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        # Shape of the input embeddings without the hidden dimension
        input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            # Default to all-zero token type ids
            token_type_ids = tf.cast(tf.fill(dims=input_shape, value=0), tf.int64)

        if position_ids is None:
            if input_ids is not None:
                # Derive position ids from `input_ids`, skipping padding positions
                position_ids = self.create_position_ids_from_input_ids(
                    input_ids=input_ids, past_key_values_length=past_key_values_length
                )
            else:
                # Otherwise use sequential position ids starting at `padding_idx + 1`
                position_ids = tf.expand_dims(
                    tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1, dtype=tf.int64),
                    axis=0,
                )

        # Gather the position and token type embeddings
        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
        # Sum the three embeddings
        final_embeddings = inputs_embeds + position_embeds + token_type_embeds
        # LayerNorm, then dropout (active only in training)
        final_embeddings = self.LayerNorm(inputs=final_embeddings)
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        return final_embeddings


# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Longformer
class TFLongformerIntermediate(keras.layers.Layer):
    def __init__(self, config: LongformerConfig, **kwargs):
        super().__init__(**kwargs)

        # Dense layer for the intermediate (feed-forward) representation
        self.dense = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        # Resolve the intermediate activation function (string names map to TF activations)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.intermediate_act_fn = config.hidden_act
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # Dense projection followed by the intermediate activation
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Build the dense sublayer if present
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])


# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Longformer
class TFLongformerOutput(keras.layers.Layer):
    def __init__(self, config: LongformerConfig, **kwargs):
        super().__init__(**kwargs)

        # Dense layer projecting back to the hidden size
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # LayerNormalization over the summed representation
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # Dropout for regularization
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        # Dense projection
        hidden_states = self.dense(inputs=hidden_states)
        # Dropout (training only)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        # Residual connection followed by LayerNorm
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Build the dense sublayer if present
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.intermediate_size])
        # Build the LayerNorm sublayer if present
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])


# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Longformer
class TFLongformerPooler(keras.layers.Layer):
    def __init__(self, config: LongformerConfig, **kwargs):
        super().__init__(**kwargs)

        # Dense layer with tanh activation used for pooling
        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # "Pool" the model by simply taking the hidden state of the first token
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(inputs=first_token_tensor)

        return pooled_output

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Longformer
class TFLongformerSelfOutput(keras.layers.Layer):
    def __init__(self, config: LongformerConfig, **kwargs):
        super().__init__(**kwargs)

        # Dense layer projecting back to the hidden size
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # LayerNormalization over the summed representation
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # Dropout for regularization
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        # Dense projection
        hidden_states = self.dense(inputs=hidden_states)
        # Dropout (training only)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        # Residual connection followed by LayerNorm
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Build the dense sublayer if present
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        # Build the LayerNorm sublayer if present
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])


class TFLongformerSelfAttention(keras.layers.Layer):
    # Constructor: takes the config and the id of this layer
    def __init__(self, config, layer_id, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        # The hidden size must be divisible by the number of attention heads
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_heads = config.num_attention_heads
        self.head_dim = int(config.hidden_size / config.num_attention_heads)
        self.embed_dim = config.hidden_size

        # Query, key and value projections for the sliding-window (local) attention
        self.query = keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name="query",
        )
        self.key = keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name="key",
        )
        self.value = keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name="value",
        )

        # Separate query, key and value projections for the global attention
        self.query_global = keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name="query_global",
        )
        self.key_global = keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name="key_global",
        )
        self.value_global = keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name="value_global",
        )

        # Dropout on the attention probabilities
        self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
        self.global_dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)

        self.layer_id = layer_id

        # Validate the local attention window for this layer
        attention_window = config.attention_window[self.layer_id]
        assert (
            attention_window % 2 == 0
        ), f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
        assert (
            attention_window > 0
        ), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"

        self.one_sided_attn_window_size = attention_window // 2
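
        # Example: with attention_window = 512 for this layer, each token attends to
        # one_sided_attn_window_size = 256 tokens on each side, i.e. a local window
        # of 2 * 256 + 1 = 513 positions including the token itself.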
    # Builds the layer's sublayers
    def build(self, input_shape=None):
        # Build the global projections first if the layer has not been built yet
        if not self.built:
            with tf.name_scope("query_global"):
                self.query_global.build((self.config.hidden_size,))
            with tf.name_scope("key_global"):
                self.key_global.build((self.config.hidden_size,))
            with tf.name_scope("value_global"):
                self.value_global.build((self.config.hidden_size,))

        # Return early if already built
        if self.built:
            return

        self.built = True

        # Build the local and global projection sublayers if present
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])

        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])

        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])

        if getattr(self, "query_global", None) is not None:
            with tf.name_scope(self.query_global.name):
                self.query_global.build([None, None, self.config.hidden_size])

        if getattr(self, "key_global", None) is not None:
            with tf.name_scope(self.key_global.name):
                self.key_global.build([None, None, self.config.hidden_size])

        if getattr(self, "value_global", None) is not None:
            with tf.name_scope(self.value_global.name):
                self.value_global.build([None, None, self.config.hidden_size])

    # Forward pass of the self-attention layer
    def call(
        self,
        inputs,
        training=False,
    ):
        # The full forward logic (sliding-window plus global attention) is omitted from this excerpt.
        pass

    # Static helper that masks out attention scores at invalid (out-of-sequence) locations
    @staticmethod
    def _mask_invalid_locations(input_tensor, window_overlap):
        # Create the correct upper-triangular boolean mask
        mask_2d_upper = tf.reverse(
            tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
            axis=[0],
        )

        # Pad to the full matrix
        padding = tf.convert_to_tensor(
            [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]]
        )
        mask_2d = tf.pad(mask_2d_upper, padding)  # create lower-triangular mask

        # Combine with the upper-triangular mask
        mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])

        # Broadcast to the full batch
        mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))

        # -inf tensor used for masking
        inf_tensor = -float("inf") * tf.ones_like(input_tensor)

        # Apply the mask
        input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)

        return input_tensor
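
    # Illustration (sketch): for window_overlap = w, the first w rows of each score
    # slab have no full left context and the last w rows no full right context, so
    # the mask above fills those upper-left and lower-right corner entries of the
    # (seq_len, 2 * w + 1) diagonal scores with -inf before the softmax.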
    def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap):
        """
        Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. Returned tensor will be of the
        same shape as `attn_probs`
        """

        # Shape of the value tensor: batch_size, seq_len, num_heads, head_dim
        batch_size, seq_len, num_heads, head_dim = shape_list(value)

        # seq_len must be a multiple of 2 * window_overlap for the chunking below
        tf.debugging.assert_equal(
            seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
        )

        # attn_probs and value must agree on the first three dimensions
        tf.debugging.assert_equal(
            shape_list(attn_probs)[:3],
            shape_list(value)[:3],
            message="value and attn_probs must have same dims (except head_dim)",
        )

        # The last dimension of attn_probs must be 2 * window_overlap + 1
        tf.debugging.assert_equal(
            shape_list(attn_probs)[3],
            2 * window_overlap + 1,
            message="attn_probs last dim has to be 2 * window_overlap + 1",
        )

        # Number of chunk boundaries; each chunk advances by window_overlap
        chunks_count = seq_len // window_overlap - 1

        # Group batch and heads, then chunk attn_probs into windows for the matmul below
        chunked_attn_probs = tf.reshape(
            tf.transpose(attn_probs, (0, 2, 1, 3)),
            (
                batch_size * num_heads,
                seq_len // window_overlap,
                window_overlap,
                2 * window_overlap + 1,
            ),
        )

        # Group batch and heads in the value tensor as well
        value = tf.reshape(
            tf.transpose(value, (0, 2, 1, 3)),
            (batch_size * num_heads, seq_len, head_dim),
        )

        # Pad seq_len with window_overlap values at the beginning and at the end
        paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]])
        padded_value = tf.pad(value, paddings, constant_values=-1)

        # Chunk padded_value into frames of size 3 * window_overlap * head_dim
        frame_size = 3 * window_overlap * head_dim
        frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
        chunked_value = tf.signal.frame(
            tf.reshape(padded_value, (batch_size * num_heads, -1)),
            frame_size,
            frame_hop_size,
        )
        chunked_value = tf.reshape(
            chunked_value,
            (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
        )

        # Sanity-check the chunked value shape
        tf.debugging.assert_equal(
            shape_list(chunked_value),
            [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
            message="Chunked value has the wrong shape",
        )

        # Pad and diagonalize the chunked attention probabilities
        chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)

        # Batched matrix multiplication via einsum yields the context vectors
        context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)

        # Restore the (batch_size, seq_len, num_heads, head_dim) layout
        context = tf.transpose(
            tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
            (0, 2, 1, 3),
        )

        return context
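
    # Shape trace (sketch) for the method above with batch_size = 1, num_heads = 2,
    # seq_len = 8, window_overlap = 2, head_dim = 4:
    #   chunks_count                 = 8 // 2 - 1 = 3
    #   chunked_attn_probs           : (2, 4, 2, 5), diagonalized to (2, 4, 2, 6)
    #   padded_value                 : (2, 12, 4)
    #   chunked_value                : (2, 4, 6, 4)
    #   context (einsum "bcwd,bcdh") : (2, 4, 2, 4) -> reshaped back to (1, 8, 2, 4)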
    @staticmethod
    def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings):
        """
        Pads the last two dimensions of `hidden_states_padded` tensor and then transposes the last two dimensions.

        Args:
            hidden_states_padded: Input tensor to be padded and transposed.
            paddings: Tensor specifying the padding amounts for each dimension.

        Returns:
            Transposed tensor after padding.
        """
        hidden_states_padded = tf.pad(
            hidden_states_padded, paddings
        )  # padding value is not important because it will be overwritten
        batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
        hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))

        return hidden_states_padded

    @staticmethod
    def _pad_and_diagonalize(chunked_hidden_states):
        """
        Shifts every row 1 step right, converting columns into diagonals.

        Example:

        chunked_hidden_states: A 4-dimensional tensor representing chunked hidden states.
        window_overlap: Integer representing the number of rows/columns to shift.

        Returns:
            Tensor with padded and diagonalized dimensions.
        """
        total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states)
        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
        chunked_hidden_states = tf.pad(
            chunked_hidden_states, paddings
        )  # total_num_heads x num_chunks x window_overlap x (hidden_dim + window_overlap + 1). Padding value is not important because it'll be overwritten
        chunked_hidden_states = tf.reshape(
            chunked_hidden_states, (total_num_heads, num_chunks, -1)
        )  # total_num_heads x num_chunks x window_overlap * (hidden_dim + window_overlap + 1)
        chunked_hidden_states = chunked_hidden_states[
            :, :, :-window_overlap
        ]  # total_num_heads x num_chunks x window_overlap * (hidden_dim + window_overlap)
        chunked_hidden_states = tf.reshape(
            chunked_hidden_states,
            (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
        )  # total_num_heads x num_chunks x window_overlap x (hidden_dim + window_overlap)
        chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]  # drop the trailing dummy column

        return chunked_hidden_states
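
    # Worked example (sketch) with window_overlap = 2 and one chunk of rows
    # [[a, b], [c, d]]: padding gives [[a, b, 0, 0, 0], [c, d, 0, 0, 0]];
    # flattening gives [a, b, 0, 0, 0, c, d, 0, 0, 0]; dropping the last 2 and
    # reshaping gives [[a, b, 0, 0], [0, c, d, 0]]; dropping the last column
    # yields [[a, b, 0], [0, c, d]]: every row is shifted one step right.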
    @staticmethod
    def _chunk(hidden_states, window_overlap):
        """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
        # Shape of the hidden states
        batch_size, seq_length, hidden_dim = shape_list(hidden_states)
        # Number of output chunks: each has size 2w and overlaps the next by w
        num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1

        # Frame size and hop size (as in a strided convolution)
        frame_hop_size = window_overlap * hidden_dim
        frame_size = 2 * frame_hop_size
        # Flatten to 2-D so tf.signal.frame can do the chunking
        hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim))

        # Chunk with the frame size and hop size above
        chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)

        # Sanity-check the chunked shape
        tf.debugging.assert_equal(
            shape_list(chunked_hidden_states),
            [batch_size, num_output_chunks, frame_size],
            message=(
                "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
                f" {[batch_size, num_output_chunks, frame_size]}, but got {shape_list(chunked_hidden_states)}."
            ),
        )

        # Restore the per-chunk layout
        chunked_hidden_states = tf.reshape(
            chunked_hidden_states,
            (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
        )

        return chunked_hidden_states
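
    # Worked example (sketch): seq_len = 8 and window_overlap = 2 give
    # num_output_chunks = 2 * (8 // 4) - 1 = 3 chunks of size 2 * 2 = 4, covering
    # positions [0:4], [2:6] and [4:8]; consecutive chunks overlap by exactly
    # window_overlap tokens.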

    @staticmethod
    def _get_global_attn_indices(is_index_global_attn):
        """compute global attn indices required throughout forward pass"""
        # 计算每个样本中全局注意力索引的数量
        num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
        num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)

        # 批次中全局注意力索引的最大数量
        max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices)

        # 提取非零元素的全局注意力索引
        is_index_global_attn_nonzero = tf.where(is_index_global_attn)

        # 计算哪些位置是局部索引中的全局注意力索引
        is_local_index_global_attn = tf.range(max_num_global_attn_indices) < tf.expand_dims(
            num_global_attn_indices, axis=-1
        )

        # 提取局部索引中非零元素的位置
        is_local_index_global_attn_nonzero = tf.where(is_local_index_global_attn)

        # 提取局部索引中零元素的位置
        is_local_index_no_global_attn_nonzero = tf.where(tf.math.logical_not(is_local_index_global_attn))

        return (
            max_num_global_attn_indices,
            is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero,
        )
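
    # Worked example (sketch): is_index_global_attn = [[True, False, True],
    # [True, False, False]] gives per-row counts [2, 1], so
    #   max_num_global_attn_indices           = 2
    #   is_index_global_attn_nonzero          = [[0, 0], [0, 2], [1, 0]]
    #   is_local_index_global_attn_nonzero    = [[0, 0], [0, 1], [1, 0]]
    #   is_local_index_no_global_attn_nonzero = [[1, 1]]  (row 1 has one unused slot)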

    def _concat_with_global_key_attn_probs(
        self,
        attn_scores,
        key_vectors,
        query_vectors,
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    ):
        # Batch size
        batch_size = shape_list(key_vectors)[0]

        # Select the key vectors at the global attention positions
        global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)

        # Scatter them into a tensor that contains only the global key vectors
        key_vectors_only_global = tf.scatter_nd(
            is_local_index_global_attn_nonzero,
            global_key_vectors,
            shape=(
                batch_size,
                max_num_global_attn_indices,
                self.num_heads,
                self.head_dim,
            ),
        )

        # Attention probabilities from the global keys via einsum
        # shape (batch_size, seq_len, num_heads, max_num_global_attn_indices)
        attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global)

        # Transpose to (batch_size, max_num_global_attn_indices, seq_len, num_heads)
        attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2))

        # Shape of the mask for the unused global slots
        mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
            shape_list(attn_probs_from_global_key_trans)[-2:]
        )

        # Large negative mask, cast to the same dtype as the scores
        mask = tf.ones(mask_shape) * -10000.0
        mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)

        # Scatter the mask over the unused slots
        attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
            attn_probs_from_global_key_trans,
            is_local_index_no_global_attn_nonzero,
            mask,
        )

        # Transpose back to (batch_size, seq_len, num_heads, max_num_global_attn_indices)
        attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans, (0, 2, 3, 1))

        # Concatenate with the local attention scores
        # shape (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
        attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1)

        return attn_scores

    def _compute_attn_output_with_global_indices(
        self,
        value_vectors,
        attn_probs,
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
    ):
        # Batch size
        batch_size = shape_list(attn_probs)[0]

        # Keep only the attention probabilities over the global tokens (first max_num_global_attn_indices)
        attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices]

        # Select the value vectors at the global attention positions
        global_value_vectors = tf.gather_nd(value_vectors, is_index_global_attn_nonzero)

        # Scatter them into a tensor that contains only the global value vectors
        value_vectors_only_global = tf.scatter_nd(
            is_local_index_global_attn_nonzero,
            global_value_vectors,
            shape=(
                batch_size,
                max_num_global_attn_indices,
                self.num_heads,
                self.head_dim,
            ),
        )

        # Attention output contributed by the global tokens
        attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global)

        # Remaining (local) attention probabilities
        attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:]

        # Attention output contributed by the sliding-window attention
        attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
            attn_probs_without_global, value_vectors, self.one_sided_attn_window_size
        )

        # Sum of the global and local contributions
        return attn_output_only_global + attn_output_without_global

    def _compute_global_attn_output_from_hidden(
        self,
        attn_output,
        hidden_states,
        max_num_global_attn_indices,
        layer_head_mask,
        is_local_index_global_attn_nonzero,
        is_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
        is_index_masked,
        training,
    ):
        # The full body of this method is omitted from this excerpt.
        pass

    # Helper used above: flatten (batch, seq, heads, head_dim) projections into
    # (batch * heads, seq, head_dim) for batched matrix multiplications
    def reshape_and_transpose(self, vector, batch_size):
        return tf.reshape(
            tf.transpose(
                tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)),
                (0, 2, 1, 3),
            ),
            (batch_size * self.num_heads, -1, self.head_dim),
        )


class TFLongformerAttention(keras.layers.Layer):
    def __init__(self, config, layer_id=0, **kwargs):
        super().__init__(**kwargs)
        # Self-attention sublayer
        self.self_attention = TFLongformerSelfAttention(config, layer_id, name="self")
        # Output projection sublayer
        self.dense_output = TFLongformerSelfOutput(config, name="output")

    def prune_heads(self, heads):
        # Head pruning is not implemented for Longformer
        raise NotImplementedError

    def call(self, inputs, training=False):
        (
            hidden_states,
            attention_mask,
            layer_head_mask,
            is_index_masked,
            is_index_global_attn,
            is_global_attn,
        ) = inputs
        # Run self-attention
        self_outputs = self.self_attention(
            [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
            training=training,
        )
        # Project the attention output and apply the residual connection
        attention_output = self.dense_output(self_outputs[0], hidden_states, training=training)
        # Assemble the outputs tuple (attention output plus any extra outputs)
        outputs = (attention_output,) + self_outputs[1:]

        return outputs

    def build(self, input_shape=None):
        # Return early if already built
        if self.built:
            return
        self.built = True
        if getattr(self, "self_attention", None) is not None:
            with tf.name_scope(self.self_attention.name):
                self.self_attention.build(None)
        if getattr(self, "dense_output", None) is not None:
            with tf.name_scope(self.dense_output.name):
                self.dense_output.build(None)


class TFLongformerLayer(keras.layers.Layer):
    def __init__(self, config, layer_id=0, **kwargs):
        super().__init__(**kwargs)
        # Attention sublayer
        self.attention = TFLongformerAttention(config, layer_id, name="attention")
        # Intermediate (feed-forward) sublayer
        self.intermediate = TFLongformerIntermediate(config, name="intermediate")
        # Output sublayer
        self.longformer_output = TFLongformerOutput(config, name="output")

    def call(self, inputs, training=False):
        (
            hidden_states,
            attention_mask,
            layer_head_mask,
            is_index_masked,
            is_index_global_attn,
            is_global_attn,
        ) = inputs
        # Run the attention sublayer
        attention_outputs = self.attention(
            [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
            training=training,
        )
        attention_output = attention_outputs[0]
        # Feed-forward over the attention output
        intermediate_output = self.intermediate(attention_output)
        # Output projection with residual connection
        layer_output = self.longformer_output(intermediate_output, attention_output, training=training)
        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them

        return outputs
    # Builds the layer's sublayers
    def build(self, input_shape=None):
        # Return early if already built
        if self.built:
            return
        self.built = True

        # Build the attention sublayer if present
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)

        # Build the intermediate sublayer if present
        if getattr(self, "intermediate", None) is not None:
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)

        # Build the output sublayer if present
        if getattr(self, "longformer_output", None) is not None:
            with tf.name_scope(self.longformer_output.name):
                self.longformer_output.build(None)


class TFLongformerEncoder(keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        # Whether to output hidden states and attentions, taken from the config
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions

        # One TFLongformerLayer per hidden layer
        self.layer = [TFLongformerLayer(config, i, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        padding_len=0,
        is_index_masked=None,
        is_index_global_attn=None,
        is_global_attn=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_attentions = all_global_attentions = () if output_attentions else None

        for idx, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # Strip the padding before recording the hidden states
                hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
                all_hidden_states = all_hidden_states + (hidden_states_to_add,)

            # Forward pass of the current layer
            layer_outputs = layer_module(
                [
                    hidden_states,
                    attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    is_index_masked,
                    is_index_global_attn,
                    is_global_attn,
                ],
                training=training,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                # Reorder the local attention dimensions:
                # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
                all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)

                # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
                all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),)

        # Record the hidden states of the last layer
        if output_hidden_states:
            hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
            all_hidden_states = all_hidden_states + (hidden_states_to_add,)

        # Undo padding: trim the hidden states back to the original input_ids length
        hidden_states = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
        if output_attentions:
            # Trim the padding from all recorded attention weights as well
            all_attentions = (
                tuple([state[:, :, :-padding_len, :] for state in all_attentions])
                if padding_len > 0
                else all_attentions
            )

        if not return_dict:
            # Return a tuple of the non-None values
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
            )

        # Otherwise return a TFLongformerBaseModelOutput
        return TFLongformerBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            global_attentions=all_global_attentions,
        )

    def build(self, input_shape=None):
        # Build every layer under its own name scope
        if self.built:
            return
        self.built = True
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)


# Mark the class as a serializable Keras layer
@keras_serializable
class TFLongformerMainLayer(keras.layers.Layer):
    config_class = LongformerConfig

    def __init__(self, config, add_pooling_layer=True, **kwargs):
        super().__init__(**kwargs)

        # If attention_window is a single int, validate it and broadcast it to all layers
        if isinstance(config.attention_window, int):
            assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value"
            assert config.attention_window > 0, "`config.attention_window` has to be positive"
            config.attention_window = [config.attention_window] * config.num_hidden_layers  # one value per layer
        else:
            # Otherwise it must provide exactly one window per hidden layer
            assert len(config.attention_window) == config.num_hidden_layers, (
                "`len(config.attention_window)` should equal `config.num_hidden_layers`. "
                f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
            )

        # Cache frequently used config values
        self.config = config
        self.num_hidden_layers = config.num_hidden_layers
        self.initializer_range = config.initializer_range
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.return_dict = config.use_return_dict
        self.pad_token_id = config.pad_token_id
        self.attention_window = config.attention_window
        # Embeddings, encoder and (optionally) pooler sublayers
        self.embeddings = TFLongformerEmbeddings(config, name="embeddings")
        self.encoder = TFLongformerEncoder(config, name="encoder")
        self.pooler = TFLongformerPooler(config, name="pooler") if add_pooling_layer else None

    # 返回 embeddings 属性,用作输入嵌入层的方法
    def get_input_embeddings(self):
        return self.embeddings

    # 设置输入嵌入层的方法,将 value 赋值给 embeddings 的权重,并更新 vocab_size 属性
    def set_input_embeddings(self, value):
        self.embeddings.weight = value
        self.embeddings.vocab_size = shape_list(value)[0]

    # 抽象方法,用于剪枝模型中的注意力头
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError

    # Forward pass, with the unpack_inputs decorator handling the different input formats
    @unpack_inputs
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        global_attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        # Forward-pass body (padding to the window size, embedding, encoding, pooling) not included in this excerpt
        ...

    # _pad_to_window_size pads the input sequences so their length is a multiple of the attention window
    def _pad_to_window_size(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        inputs_embeds,
        pad_token_id,
    ):
        """A helper function to pad tokens and mask to work with implementation of Longformer selfattention."""
        # padding
        attention_window = (
            self.attention_window if isinstance(self.attention_window, int) else max(self.attention_window)
        )

        assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"

        # Get the shape of the inputs
        input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)
        batch_size, seq_len = input_shape[:2]
        # Compute how much padding is needed so that seq_len becomes divisible by attention_window
        padding_len = (attention_window - seq_len % attention_window) % attention_window

        # Build the padding spec: pad only at the end of the sequence dimension
        paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])

        # If input_ids are given, pad them with pad_token_id
        if input_ids is not None:
            input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)

        # If position_ids are given, pad them with pad_token_id as well
        if position_ids is not None:
            # pad with pad_token_id, following modeling_roberta.RobertaEmbeddings
            position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)

        # If inputs_embeds are given, pad them by embedding a block of pad tokens
        if inputs_embeds is not None:
            if padding_len > 0:
                # Create an input_ids_padding tensor matching the padding length and embed it
                input_ids_padding = tf.cast(tf.fill((batch_size, padding_len), self.pad_token_id), tf.int64)
                inputs_embeds_padding = self.embeddings(input_ids_padding)
                # Concatenate the padding embeddings to the end of inputs_embeds
                inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)

        # Pad the attention_mask and token_type_ids, then return everything together with padding_len
        attention_mask = tf.pad(attention_mask, paddings, constant_values=False)  # no attention on the padding tokens
        token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0)  # pad with token_type_id = 0

        return (
            padding_len,
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            inputs_embeds,
        )
    # Build the layer, constructing each sub-layer inside its own name scope
    def build(self, input_shape=None):
        # If the layer has already been built, return early to avoid building twice
        if self.built:
            return
        self.built = True

        # Build the embeddings layer, if present
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)

        # Build the encoder, if present
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)

        # Build the pooler, if present
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build(None)
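
To make the padding arithmetic in `_pad_to_window_size` concrete, here is a minimal standalone sketch of rounding a sequence length up to the next multiple of the attention window (the window size is illustrative):

```py
# Sketch of the padding arithmetic used by _pad_to_window_size (illustrative values)
attention_window = 512
for seq_len in (1, 511, 512, 513, 1000):
    padding_len = (attention_window - seq_len % attention_window) % attention_window
    assert (seq_len + padding_len) % attention_window == 0
    print(f"seq_len={seq_len:4d} -> pad by {padding_len:3d} -> {seq_len + padding_len}")
```
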
    """
    这是一个抽象类,处理权重初始化以及下载和加载预训练模型的简单接口。

    config_class = LongformerConfig
    base_model_prefix = "longformer"

    @property
    def input_signature(self):
        sig = super().input_signature
        sig["global_attention_mask"] = tf.TensorSpec((None, None), tf.int32, name="global_attention_mask")
        return sig
    """
LONGFORMER_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.).

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras model and refer to the TF 2.0 documentation for all matters related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()`, you can simply pass your inputs and
    labels in any format that `model.fit()` supports! If, however, you want to use the second format outside of Keras
    methods (for instance when creating your own layers or models with the Keras `Functional` API), there are three
    possibilities you can use to gather all the input tensors in the first positional argument:

    - a single tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input tensors IN THE ORDER given in the docstring:
      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input tensors associated to the input names given in the docstring:
      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/), you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Parameters:
        config ([`LongformerConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

LONGFORMER_INPUTS_DOCSTRING = r"""
    """
@add_start_docstrings(
    "The bare Longformer Model outputting raw hidden-states without any specific head on top.",
    LONGFORMER_START_DOCSTRING,
)
class TFLongformerModel(TFLongformerPreTrainedModel):
    """
    TFLongformerModel类继承自TFLongformerPreTrainedModel,用于输出不带特定头部的原始隐藏状态。

    This class copies code from [`TFRobertaModel`] and overwrites standard self-attention with longformer
    self-attention to provide the ability to process long sequences following the self-attention approach described in
    [Longformer: the Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, and
    Arman Cohan. Longformer self-attention combines a local (sliding window) and global attention to extend to long
    documents without the O(n^2) increase in memory and compute.

    The self-attention module `TFLongformerSelfAttention` implemented here supports the combination of local and global
    attention but it lacks support for autoregressive attention and dilated attention. Autoregressive and dilated
    attention are more relevant for autoregressive language modeling than finetuning on downstream tasks. Future
    release will add support for autoregressive attention, but the support for dilated attention requires a custom CUDA
    kernel to be memory and compute efficient.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        
        # Instantiate the main Longformer layer, named "longformer", which does the actual work
        self.longformer = TFLongformerMainLayer(config, name="longformer")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        global_attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFLongformerBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        # Forward all inputs to self.longformer and collect its outputs
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            global_attention_mask=global_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return outputs
    # build sets up the model's sub-layers
    def build(self, input_shape=None):
        # If the model has already been built, return early to avoid building twice
        if self.built:
            return
        self.built = True
        # If the `longformer` attribute exists and is not None, build it inside its name scope
        if getattr(self, "longformer", None) is not None:
            with tf.name_scope(self.longformer.name):
                self.longformer.build(None)
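
Before moving on to the task-specific heads, here is a minimal usage sketch of the bare model; it loads the public `allenai/longformer-base-4096` checkpoint, and the shape comment assumes that checkpoint:

```py
from transformers import LongformerTokenizer, TFLongformerModel

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("A very long document ...", return_tensors="tf")
outputs = model(**inputs)

# (batch_size, sequence_length, hidden_size); hidden_size is 768 for the base checkpoint
print(outputs.last_hidden_state.shape)
```
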
@add_start_docstrings(
    """Longformer Model with a `language modeling` head on top.""",
    LONGFORMER_START_DOCSTRING,
)
class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
    # Names with a '.' represent authorized unexpected/missing layers when loading a TF model from a PT model
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # Main Longformer layer without a pooling layer, named "longformer"
        self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
        # Language-modeling head tied to the Longformer embeddings, named "lm_head"
        self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head")

    def get_lm_head(self):
        # Return the language-modeling head
        return self.lm_head

    def get_prefix_bias_name(self):
        # get_prefix_bias_name is deprecated; `get_bias` should be used instead
        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        # Return the model name joined with the LM head name
        return self.name + "/" + self.lm_head.name

    @unpack_inputs
    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="allenai/longformer-base-4096",
        output_type=TFLongformerMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="<mask>",
        expected_output="' Paris'",
        expected_loss=0.44,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        global_attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
        **kwargs,
    ) -> Union[TFLongformerMaskedLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """

        # Run the Longformer forward pass and collect the model outputs
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            global_attention_mask=global_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Sequence representation from the model output
        sequence_output = outputs[0]
        # Predict token scores with the LM head
        prediction_scores = self.lm_head(sequence_output, training=training)
        # Compute the loss if labels are given, otherwise leave it as None
        loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)

        if not return_dict:
            # Without return_dict, build a plain tuple, prepending the loss if present
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # Return a TFLongformerMaskedLMOutput with loss, logits, hidden states and attentions
        return TFLongformerMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            global_attentions=outputs.global_attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "longformer", None) is not None:
            with tf.name_scope(self.longformer.name):  # name scope for the Longformer layer
                self.longformer.build(None)
        if getattr(self, "lm_head", None) is not None:
            with tf.name_scope(self.lm_head.name):  # name scope for the LM head
                self.lm_head.build(None)
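
A minimal fill-mask sketch for this head; the checkpoint, the `<mask>` prompt and the expected answer (" Paris") match the code-sample docstring attached to `call` above:

```py
import tensorflow as tf

from transformers import LongformerTokenizer, TFLongformerForMaskedLM

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("The capital of France is <mask>.", return_tensors="tf")
logits = model(**inputs).logits

# Take the highest-scoring token at the masked position
mask_index = tf.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0, 0]
predicted_id = int(tf.argmax(logits[0, mask_index]))
print(tokenizer.decode([predicted_id]))  # expected: " Paris"
```
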
"""
Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD /
TriviaQA (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
"""
# 引用了 TFLongformerPreTrainedModel 和 TFQuestionAnsweringLoss,构建了一个带有跨度分类头部的 Longformer 模型,用于类似 SQuAD / TriviaQA 的抽取式问答任务。

class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
    # 当从 PT 模型加载 TF 模型时,'.' 表示授权的意外/缺失的层
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        # Main Longformer layer without a pooling layer
        self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
        # Dense output layer producing config.num_labels values (span start/end logits)
        self.qa_outputs = keras.layers.Dense(
            config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="qa_outputs",
        )
        self.config = config

    @unpack_inputs
    # Document the forward signature and attach a usage example with its expected output
    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="allenai/longformer-large-4096-finetuned-triviaqa",
        output_type=TFLongformerQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="' puppet'",
        expected_loss=0.96,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        global_attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        start_positions: np.ndarray | tf.Tensor | None = None,
        end_positions: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ):
        pass  # Forward-pass body not included in this excerpt

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Build self.longformer if it has been defined
        if getattr(self, "longformer", None) is not None:
            with tf.name_scope(self.longformer.name):
                self.longformer.build(None)
        # Build self.qa_outputs if it has been defined
        if getattr(self, "qa_outputs", None) is not None:
            with tf.name_scope(self.qa_outputs.name):
                # The Dense layer receives inputs of shape [None, None, self.config.hidden_size]
                self.qa_outputs.build([None, None, self.config.hidden_size])
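
A minimal extractive-QA sketch for this head, using the TriviaQA-finetuned checkpoint named in the code-sample docstring above; the argmax span decoding at the end is the usual heuristic, not part of the model itself:

```py
import tensorflow as tf

from transformers import LongformerTokenizer, TFLongformerForQuestionAnswering

ckpt = "allenai/longformer-large-4096-finetuned-triviaqa"
tokenizer = LongformerTokenizer.from_pretrained(ckpt)
model = TFLongformerForQuestionAnswering.from_pretrained(ckpt)

question = "Who was Jim Henson?"
context = "Jim Henson was a nice puppet."
inputs = tokenizer(question, context, return_tensors="tf")
outputs = model(**inputs)

# Most likely start/end positions, decoded back to text
start = int(tf.argmax(outputs.start_logits, axis=-1)[0])
end = int(tf.argmax(outputs.end_logits, axis=-1)[0])
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))  # expected: " puppet"
```
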


class TFLongformerClassificationHead(keras.layers.Layer):
    """Head for sentence-level classification tasks."""
    # Constructor: set up the dense, dropout and output-projection layers
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        # Dense layer projecting the pooled token back to hidden_size, with tanh activation
        self.dense = keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )

        # Dropout layer that randomly drops units during training
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)

        # Output projection producing one logit per label
        self.out_proj = keras.layers.Dense(
            config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="out_proj",
        )

        # Keep the config around for later use
        self.config = config

    # Forward pass: compute the classification logits
    def call(self, hidden_states, training=False):
        # Keep only the first hidden state of each sample, i.e. the <s> token (equivalent to [CLS])
        hidden_states = hidden_states[:, 0, :]

        # Apply dropout (active only in training mode)
        hidden_states = self.dropout(hidden_states, training=training)

        # Project through the dense layer
        hidden_states = self.dense(hidden_states)

        # Apply dropout again before the output projection
        hidden_states = self.dropout(hidden_states, training=training)

        # Final output projection gives the logits
        output = self.out_proj(hidden_states)

        return output

    # Build the sub-layers
    def build(self, input_shape=None):
        # Return early if already built
        if self.built:
            return

        self.built = True

        # Build the dense layer, if present
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])

        # Build the output projection, if present
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                self.out_proj.build([None, None, self.config.hidden_size])


# The sequence-classification model: a Longformer transformer with a classification/regression
# head (a linear layer on top of the pooled output), e.g. for GLUE tasks
@add_start_docstrings(
    """
    Longformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    LONGFORMER_START_DOCSTRING,
)
class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
    # Names with a '.' represent authorized unexpected/missing layers when loading a TF model from a PT model
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # Number of labels the model can predict
        self.num_labels = config.num_labels

        # Main Longformer layer without a pooling layer, named "longformer"
        self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")

        # Classification head, named "classifier"
        self.classifier = TFLongformerClassificationHead(config, name="classifier")

    # Document the forward signature and attach a code-sample docstring
    @unpack_inputs
    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFLongformerSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        global_attention_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFLongformerSequenceClassifierOutput, Tuple[tf.Tensor]]:
        # Convert input_ids to an int64 tensor if given
        if input_ids is not None and not isinstance(input_ids, tf.Tensor):
            input_ids = tf.convert_to_tensor(input_ids, dtype=tf.int64)
        elif input_ids is not None:
            input_ids = tf.cast(input_ids, tf.int64)

        # Convert attention_mask to an int64 tensor if given
        if attention_mask is not None and not isinstance(attention_mask, tf.Tensor):
            attention_mask = tf.convert_to_tensor(attention_mask, dtype=tf.int64)
        elif attention_mask is not None:
            attention_mask = tf.cast(attention_mask, tf.int64)

        # Convert global_attention_mask to an int64 tensor if given
        if global_attention_mask is not None and not isinstance(global_attention_mask, tf.Tensor):
            global_attention_mask = tf.convert_to_tensor(global_attention_mask, dtype=tf.int64)
        elif global_attention_mask is not None:
            global_attention_mask = tf.cast(global_attention_mask, tf.int64)

        # If no global_attention_mask was given, default to global attention on the CLS token
        if global_attention_mask is None and input_ids is not None:
            logger.warning_once("Initializing global attention on CLS token...")
            # Global attention on the CLS token only
            global_attention_mask = tf.zeros_like(input_ids)
            # One update per batch row
            updates = tf.ones(shape_list(input_ids)[0], dtype=tf.int64)
            # Indices (i, 0) for every row i, i.e. the CLS position of each sequence
            indices = tf.pad(
                tensor=tf.expand_dims(tf.range(shape_list(input_ids)[0], dtype=tf.int64), axis=1),
                paddings=[[0, 0], [0, 1]],
                constant_values=0,
            )
            # Scatter the ones into the zero mask
            global_attention_mask = tf.tensor_scatter_nd_update(
                global_attention_mask,
                indices,
                updates,
            )

        # Run the Longformer forward pass
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            global_attention_mask=global_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Sequence output from the model
        sequence_output = outputs[0]
        # Compute logits with the classification head
        logits = self.classifier(sequence_output)

        # Compute the loss if labels are given, otherwise leave it as None
        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        # Without return_dict, build a plain tuple, prepending the loss if present
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # Return a TFLongformerSequenceClassifierOutput with loss, logits, hidden states and attentions
        return TFLongformerSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            global_attentions=outputs.global_attentions,
        )
    def build(self, input_shape=None):
        # Return early if the layers have already been built
        if self.built:
            return
        self.built = True
        # Build the Longformer layer, if present
        if getattr(self, "longformer", None) is not None:
            with tf.name_scope(self.longformer.name):
                self.longformer.build(None)
        # Build the classifier head, if present
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build(None)
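
The default global-attention initialization in `call` above is easy to verify in isolation; this standalone sketch reproduces the scatter update for a toy batch of two sequences:

```py
import tensorflow as tf

input_ids = tf.constant([[0, 100, 200, 2], [0, 300, 400, 2]], dtype=tf.int64)  # toy batch

global_attention_mask = tf.zeros_like(input_ids)
updates = tf.ones(2, dtype=tf.int64)
# Indices (i, 0) for every row i, i.e. the CLS position of each sequence
indices = tf.pad(tf.expand_dims(tf.range(2, dtype=tf.int64), axis=1), [[0, 0], [0, 1]])
global_attention_mask = tf.tensor_scatter_nd_update(global_attention_mask, indices, updates)
print(global_attention_mask.numpy())  # [[1 0 0 0]
                                      #  [1 0 0 0]]
```
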
@add_start_docstrings(
    """
    Longformer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
    a softmax) e.g. for RocStories/SWAG tasks.
    """,
    LONGFORMER_START_DOCSTRING,
)
class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoiceLoss):
    # Names with a '.' represent authorized unexpected/missing layers when loading a TF model from a PT model
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # Main Longformer layer (with pooling), named "longformer"
        self.longformer = TFLongformerMainLayer(config, name="longformer")

        # Dropout layer applied to the pooled output to reduce overfitting
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)

        # Dense classifier producing a single score per choice
        self.classifier = keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

        # Keep the config around for later use
        self.config = config

    @property
    def input_signature(self):
        # Input signature: multiple-choice inputs carry an extra num_choices dimension
        return {
            "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
            "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
            "global_attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="global_attention_mask"),
        }

    @unpack_inputs
    @add_start_docstrings_to_model_forward(
        LONGFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFLongformerMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        global_attention_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFLongformerMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
            where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
        """

        # Derive num_choices and seq_length from input_ids if given, otherwise from inputs_embeds
        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]

        # Flatten each input to 2D (merging the batch and choice dimensions) where it is not None
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
        flat_global_attention_mask = (
            tf.reshape(global_attention_mask, (-1, shape_list(global_attention_mask)[-1]))
            if global_attention_mask is not None
            else None
        )
        flat_inputs_embeds = (
            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )

        # Run the Longformer forward pass on the flattened inputs
        outputs = self.longformer(
            flat_input_ids,
            position_ids=flat_position_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
            head_mask=head_mask,
            global_attention_mask=flat_global_attention_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Pooled output of the model
        pooled_output = outputs[1]

        # Apply dropout to the pooled output
        pooled_output = self.dropout(pooled_output)
        # Compute one logit per (example, choice) pair
        logits = self.classifier(pooled_output)
        # Reshape the logits back to (batch_size, num_choices)
        reshaped_logits = tf.reshape(logits, (-1, num_choices))

        # Compute the loss if labels are given, otherwise leave it as None
        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)

        # Without return_dict, build a plain tuple, prepending the loss if present
        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # Return a TFLongformerMultipleChoiceModelOutput with loss, logits, hidden states and attentions
        return TFLongformerMultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            global_attentions=outputs.global_attentions,
        )
    # Build the sub-layers; return early if already built
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True

        # Build the Longformer layer, if present
        if getattr(self, "longformer", None) is not None:
            with tf.name_scope(self.longformer.name):
                self.longformer.build(None)

        # Build the classifier, whose inputs have shape [None, None, self.config.hidden_size]
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])
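
The flatten-then-reshape trick in `call` above is the standard way a multiple-choice head reuses a single-sequence encoder; this shape-only sketch shows the round trip for a toy batch:

```py
import tensorflow as tf

batch_size, num_choices, seq_length = 2, 4, 8
input_ids = tf.zeros((batch_size, num_choices, seq_length), dtype=tf.int64)

# Merge the batch and choice dimensions before the encoder ...
flat = tf.reshape(input_ids, (-1, seq_length))
print(flat.shape)  # (8, 8): batch_size * num_choices rows

# ... and restore one logit per choice after the classifier
logits = tf.zeros((batch_size * num_choices, 1))  # stand-in for the classifier output
print(tf.reshape(logits, (-1, num_choices)).shape)  # (2, 4)
```
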
# Longformer with a token-classification head on top (a linear layer over the hidden states),
# e.g. for Named-Entity-Recognition (NER) tasks
@add_start_docstrings(
    """
    Longformer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    LONGFORMER_START_DOCSTRING,
)
class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
    # Unexpected layers that may be ignored when loading this TF model from a PyTorch checkpoint
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    # Missing layers that may be ignored when loading
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels  # number of labels from the config
        self.longformer = TFLongformerMainLayer(config=config, add_pooling_layer=False, name="longformer")
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)  # dropout probability from the config
        self.classifier = keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )  # dense classifier with num_labels outputs
        self.config = config  # keep the config around

    # Document the forward signature and attach a code-sample docstring
    @unpack_inputs
    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFLongformerTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        global_attention_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[Union[np.array, tf.Tensor]] = None,
        training: Optional[bool] = False,
    ) -> Union[TFLongformerTokenClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        # Run the Longformer forward pass
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            global_attention_mask=global_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Sequence output from the model
        sequence_output = outputs[0]
        # Apply dropout to the sequence output
        sequence_output = self.dropout(sequence_output)
        # Compute per-token logits with the classifier
        logits = self.classifier(sequence_output)
        # Compute the loss if labels are given, otherwise leave it as None
        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        # Without return_dict, build a plain tuple, prepending the loss if present
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # With return_dict, return a TFLongformerTokenClassifierOutput
        return TFLongformerTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            global_attentions=outputs.global_attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Build the Longformer layer, if present
        if getattr(self, "longformer", None) is not None:
            with tf.name_scope(self.longformer.name):
                self.longformer.build(None)
        # Build the classifier with its expected input shape
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])
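
A minimal token-classification sketch; the classification head is randomly initialized unless a fine-tuned checkpoint is loaded, so the logits below are only meaningful after training (`num_labels=9` is an illustrative assumption in the spirit of CoNLL-style NER tag sets):

```py
from transformers import LongformerTokenizer, TFLongformerForTokenClassification

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
# num_labels=9 is an illustrative assumption (CoNLL-2003-style NER tag set)
model = TFLongformerForTokenClassification.from_pretrained("allenai/longformer-base-4096", num_labels=9)

inputs = tokenizer("HuggingFace is based in New York City", return_tensors="tf")
logits = model(**inputs).logits
print(logits.shape)  # (1, sequence_length, 9): one score per token per label
```
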

.\models\longformer\tokenization_longformer.py

# Import the required modules and libraries
import json  # JSON handling
import os  # operating-system utilities
from functools import lru_cache  # decorator that caches function results
from typing import List, Optional, Tuple  # type hints

import regex as re  # the `regex` library, imported as re

# AddedToken and PreTrainedTokenizer from the tokenization utilities
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
# Logging utilities
from ...utils import logging

# Logger for this module
logger = logging.get_logger(__name__)

# Names of the vocabulary files
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}

# Map from pretrained model name to the URLs of its vocabulary files
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/vocab.json",
        "allenai/longformer-large-4096": (
            "https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json"
        ),
        "allenai/longformer-large-4096-finetuned-triviaqa": (
            "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/vocab.json"
        ),
        "allenai/longformer-base-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/vocab.json"
        ),
        "allenai/longformer-large-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/vocab.json"
        ),
    },
    "merges_file": {
        "allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/merges.txt",
        "allenai/longformer-large-4096": (
            "https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt"
        ),
        "allenai/longformer-large-4096-finetuned-triviaqa": (
            "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/merges.txt"
        ),
        "allenai/longformer-base-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/merges.txt"
        ),
        "allenai/longformer-large-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/merges.txt"
        ),
    },
}

# Map from pretrained model name to its maximum positional-embedding size
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "allenai/longformer-base-4096": 4096,
    "allenai/longformer-large-4096": 4096,
    "allenai/longformer-large-4096-finetuned-triviaqa": 4096,
    "allenai/longformer-base-4096-extra.pos.embd.only": 4096,
    "allenai/longformer-large-4096-extra.pos.embd.only": 4096,
}


@lru_cache()
# Copied from transformers.models.roberta.tokenization_roberta.bytes_to_unicode
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
    characters the bpe code barfs on.

    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
    tables between utf-8 bytes and unicode strings.
    """
    # bs starts with the printable utf-8 bytes plus two extra Latin-1 ranges
    bs = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
    # cs starts as a copy of bs
    cs = bs[:]
    # Counter for bytes that still need a codepoint
    n = 0
    # Walk over all 256 possible byte values
    for b in range(2**8):
        # Any byte not already in bs ...
        if b not in bs:
            # ... is appended to bs
            bs.append(b)
            # ... and mapped to the codepoint 2**8 + n, after which n is incremented
            cs.append(2**8 + n)
            n += 1
    # Convert every codepoint in cs to its unicode character and return the byte-to-unicode table
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
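
A quick probe of the table makes the mapping concrete: printable bytes map to themselves, while the space and control bytes are remapped to unused codepoints:

```py
table = bytes_to_unicode()
print(len(table))       # 256: every byte value gets a unicode character
print(table[ord("A")])  # 'A'  (printable bytes map to themselves)
print(table[ord(" ")])  # 'Ġ'  (the space byte is remapped; it shows up inside BPE tokens)
print(table[0])         # 'Ā'  (control bytes get fresh codepoints too)
```
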


# Copied from transformers.models.roberta.tokenization_roberta.get_pairs
def get_pairs(word):
    """
    Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    # Start with an empty set of pairs
    pairs = set()
    # Walk the word from its second symbol on, pairing each symbol with its predecessor
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    # Return the set of adjacent symbol pairs
    return pairs
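
For example:

```py
print(get_pairs(("h", "e", "l", "l", "o")))
# {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}
```
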


# Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer
class LongformerTokenizer(PreTrainedTokenizer):
    """
    Constructs a Longformer tokenizer, derived from the GPT-2 tokenizer, using byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```
    >>> from transformers import LongformerTokenizer

    >>> tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
    >>> tokenizer("Hello world")["input_ids"]
    [0, 31414, 232, 2]

    >>> tokenizer(" Hello world")["input_ids"]
    [0, 20920, 232, 2]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.
    """
    # 构造一个Longformer分词器,继承自PreTrainedTokenizer类

    def __init__(self, *init_inputs, **kwargs):
        # 调用父类的构造函数,传入所有初始化参数和关键字参数
        super().__init__(*init_inputs, **kwargs)

    # LongformerTokenizer类还有其他方法和属性,在此省略...
    # 定义一个名为 vocab_file 的参数,表示词汇表文件的路径
    vocab_file (`str`):
        Path to the vocabulary file.
    # 定义一个名为 merges_file 的参数,表示合并文件的路径
    merges_file (`str`):
        Path to the merges file.
    # 定义一个名为 errors 的参数,表示解码字节为 UTF-8 时的错误处理方式,默认为 "replace"
    errors (`str`, *optional*, defaults to `"replace"`):
        Paradigm to follow when decoding bytes to UTF-8. See
        [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
    # 定义一个名为 bos_token 的参数,表示序列的起始标记,默认为 `"<s>"`
    bos_token (`str`, *optional*, defaults to `"<s>"`):
        The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.

        <Tip>

        When building a sequence using special tokens, this is not the token that is used for the beginning of
        sequence. The token used is the `cls_token`.

        </Tip>

    # 定义一个名为 eos_token 的参数,表示序列的结束标记,默认为 `"</s>"`
    eos_token (`str`, *optional*, defaults to `"</s>"`):
        The end of sequence token.

        <Tip>

        When building a sequence using special tokens, this is not the token that is used for the end of sequence.
        The token used is the `sep_token`.

        </Tip>

    # 定义一个名为 sep_token 的参数,表示序列的分隔标记,默认为 `"</s>"`
    sep_token (`str`, *optional*, defaults to `"</s>"`):
        The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
        sequence classification or for a text and a question for question answering. It is also used as the last
        token of a sequence built with special tokens.
    # 定义一个名为 cls_token 的参数,表示分类器标记,默认为 `"<s>"`
    cls_token (`str`, *optional*, defaults to `"<s>"`):
        The classifier token which is used when doing sequence classification (classification of the whole sequence
        instead of per-token classification). It is the first token of the sequence when built with special tokens.
    # 定义一个名为 unk_token 的参数,表示未知标记,默认为 `"<unk>"`
    unk_token (`str`, *optional*, defaults to `"<unk>"`):
        The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
        token instead.
    # 定义一个名为 pad_token 的参数,表示填充标记,默认为 `"<pad>"`
    pad_token (`str`, *optional*, defaults to `"<pad>"`):
        The token used for padding, for example when batching sequences of different lengths.
    # 定义一个名为 mask_token 的参数,表示掩码标记,默认为 `"<mask>"`
    mask_token (`str`, *optional*, defaults to `"<mask>"`):
        The token used for masking values. This is the token used when training this model with masked language
        modeling. This is the token which the model will try to predict.
    # 定义一个名为 add_prefix_space 的参数,表示是否在输入开头添加空格,默认为 `False`
    add_prefix_space (`bool`, *optional*, defaults to `False`):
        Whether or not to add an initial space to the input. This allows to treat the leading word just as any
        other word. (Longformer tokenizer detect beginning of words by the preceding space).

vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        add_prefix_space=False,
        **kwargs,
    ):
        # Wrap each special token given as a plain string in an AddedToken that keeps surrounding spaces
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token

        # The mask token behaves like a normal word, i.e. it includes the space before it
        mask_token = (
            AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
            if isinstance(mask_token, str)
            else mask_token
        )

        # These special tokens are not part of vocab.json, let's add them in the correct order
        # Load the vocabulary JSON from vocab_file into self.encoder
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        # Build the reverse mapping (id -> token) in self.decoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        # Byte-to-unicode table and its reverse
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Read merges_file and split it into the list of BPE merges
        with open(merges_file, encoding="utf-8") as merges_handle:
            bpe_merges = merges_handle.read().split("\n")[1:-1]
        # Turn each merge rule into a tuple and map it to its rank
        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        # Cache of already-computed BPE splits
        self.cache = {}
        # Whether a prefix space is added to the input
        self.add_prefix_space = add_prefix_space

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions.
        # The pattern matches contractions ('s, 't, 're, ...), runs of letters, runs of digits, runs of
        # other symbols, and whitespace (including trailing whitespace not followed by a non-space character)
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        # Initialize the parent class with the resolved tokens and options
        super().__init__(
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

    @property
    def vocab_size(self):
        # Size of the vocabulary, i.e. the number of entries in self.encoder
        return len(self.encoder)

    def get_vocab(self):
        # Copy self.encoder and extend it with any added tokens
        vocab = dict(self.encoder).copy()
        vocab.update(self.added_tokens_encoder)
        return vocab
    def bpe(self, token):
        # Return the cached result if this token has been processed before
        if token in self.cache:
            return self.cache[token]
        # Represent the token as a tuple of symbols
        word = tuple(token)
        # Collect all adjacent symbol pairs
        pairs = get_pairs(word)

        # A single-symbol token cannot be merged any further
        if not pairs:
            return token

        # Repeatedly merge the best-ranked pair until no known merge applies
        while True:
            # Find the pair with the lowest (best) merge rank
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            # Stop if the best pair is not a known merge
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            # Scan the word and merge every occurrence of (first, second)
            while i < len(word):
                try:
                    # Find the next occurrence of `first` starting at position i
                    j = word.index(first, i)
                except ValueError:
                    # No more occurrences: copy the rest of the word and stop
                    new_word.extend(word[i:])
                    break
                else:
                    # Copy everything before the occurrence
                    new_word.extend(word[i:j])
                    i = j

                # If `first` is followed by `second`, merge them into one symbol
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    # Otherwise keep the current symbol and move on
                    new_word.append(word[i])
                    i += 1
            # Continue with the merged word
            new_word = tuple(new_word)
            word = new_word
            # A word of length 1 cannot be merged further
            if len(word) == 1:
                break
            else:
                # Recompute the pairs for the next round
                pairs = get_pairs(word)

        # Join the final symbols with spaces, cache and return the result
        word = " ".join(word)
        self.cache[token] = word
        return word
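
To watch the merge loop work without a real merges file, here is a toy run that calls `bpe` on a stand-in object carrying a hand-made rank table (the ranks are invented for illustration):

```py
from types import SimpleNamespace

from transformers import LongformerTokenizer

# Invented rank table: lower rank = higher merge priority (illustrative only)
toy = SimpleNamespace(cache={}, bpe_ranks={("l", "o"): 0, ("lo", "w"): 1})

print(LongformerTokenizer.bpe(toy, "low"))    # "low":     l o w -> lo w -> low
print(LongformerTokenizer.bpe(toy, "lower"))  # "low e r": no merge is defined beyond "low"
```
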

    def _tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
        # Find all chunks that match the pre-tokenization pattern
        for token in re.findall(self.pat, text):
            # Map every byte of the chunk to its unicode stand-in so BPE never sees
            # control characters (spaces, in our case)
            token = "".join(
                self.byte_encoder[b] for b in token.encode("utf-8")
            )
            # Run BPE and collect the resulting sub-tokens
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # Look the token up in the vocabulary, falling back to the unknown token
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        # Look the id up in the reverse vocabulary
        return self.decoder.get(index)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # Concatenate the tokens into one string
        text = "".join(tokens)
        # Map each unicode stand-in back to its byte and decode as UTF-8 with the configured error handling
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Make sure the save directory exists; log an error and bail out otherwise
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        # Path of the vocabulary file, combining the save directory and an optional filename prefix
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        # Path of the merges file, built the same way
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        # Write the contents of self.encoder to the vocabulary file as JSON
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        # Write the BPE merges to the merges file
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            # Write the merges sorted by rank
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # Warn if the merge indices are not consecutive
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        # Return the paths of the saved vocabulary and merges files
        return vocab_file, merge_file
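
A minimal round-trip sketch: `save_pretrained` (which calls `save_vocabulary` under the hood and also writes the tokenizer config files) followed by reloading from the same directory:

```py
import tempfile

from transformers import LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
with tempfile.TemporaryDirectory() as tmp_dir:
    tokenizer.save_pretrained(tmp_dir)  # writes vocab.json and merges.txt, among other files
    reloaded = LongformerTokenizer.from_pretrained(tmp_dir)
    assert reloaded("Hello world")["input_ids"] == tokenizer("Hello world")["input_ids"]
```
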

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A Longformer sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        # Single sequence: add the beginning and end special tokens
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        
        # Sequence pair: concatenate as <s> A </s></s> B </s>
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
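
    # A quick sketch of the resulting layouts (the ids assume the RoBERTa-style vocab
    # used by Longformer, where <s> = 0 and </s> = 2; the payload ids are made up):
    #
    #     tokenizer.build_inputs_with_special_tokens([10, 11])        # -> [0, 10, 11, 2]
    #     tokenizer.build_inputs_with_special_tokens([10, 11], [12])  # -> [0, 10, 11, 2, 2, 12, 2]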

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        # If special tokens are already present, delegate to the parent class
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # Otherwise build the mask by hand, depending on whether token_ids_1 is provided
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        else:
            return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
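
    # For the layouts built above, the mask marks exactly the special-token positions
    # (payload ids are made up):
    #
    #     tokenizer.get_special_tokens_mask([10, 11])        # -> [1, 0, 0, 1]
    #     tokenizer.get_special_tokens_mask([10, 11], [12])  # -> [1, 0, 0, 1, 1, 0, 1]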

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. Longformer does not
        make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        # Lists holding the SEP and CLS special token ids
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        # Return a list of zeros whose length matches the full sequence, special tokens included
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        else:
            return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        """
        Prepare text for tokenization by optionally adding a space prefix based on the arguments.

        Args:
            text (str): Input text to be tokenized.
            is_split_into_words (bool, optional): Whether the text is already split into words.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[str, dict]: Processed text and any remaining keyword arguments.
        """
        # Pop add_prefix_space from kwargs, defaulting to the tokenizer's own setting
        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
        
        # Prepend a space if needed and the text does not already start with whitespace
        if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
            text = " " + text
        
        # Return the processed text and the remaining keyword arguments
        return (text, kwargs)
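
    # A small sketch of the effect; both conditions trigger the same prefixing
    # (outputs follow directly from the code above):
    #
    #     tokenizer.prepare_for_tokenization("Hello", is_split_into_words=True)  # -> (" Hello", {})
    #     tokenizer.prepare_for_tokenization("Hello", add_prefix_space=True)     # -> (" Hello", {})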

.\models\longformer\tokenization_longformer_fast.py

# coding=utf-8
# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Tokenization classes for Longformer."""
# Import required modules
import json
from typing import List, Optional, Tuple

from tokenizers import pre_tokenizers, processors

# Import the base tokenization utilities and the fast tokenizer base class
from ...tokenization_utils_base import AddedToken, BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging

# Import the slow Longformer tokenizer class
from .tokenization_longformer import LongformerTokenizer

# Get the module logger
logger = logging.get_logger(__name__)

# Vocabulary file name constants
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}

# Mapping from pretrained model names to their vocabulary file URLs
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/vocab.json",
        "allenai/longformer-large-4096": (
            "https://huggingface.co/allenai/longformer-large-4096/resolve/main/vocab.json"
        ),
        "allenai/longformer-large-4096-finetuned-triviaqa": (
            "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/vocab.json"
        ),
        "allenai/longformer-base-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/vocab.json"
        ),
        "allenai/longformer-large-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/vocab.json"
        ),
    },
    "merges_file": {
        "allenai/longformer-base-4096": "https://huggingface.co/allenai/longformer-base-4096/resolve/main/merges.txt",
        "allenai/longformer-large-4096": (
            "https://huggingface.co/allenai/longformer-large-4096/resolve/main/merges.txt"
        ),
        "allenai/longformer-large-4096-finetuned-triviaqa": (
            "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/merges.txt"
        ),
        "allenai/longformer-base-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/merges.txt"
        ),
        "allenai/longformer-large-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/merges.txt"
        ),
    },
    # Mapping from model names to their tokenizer.json file URLs
    "tokenizer_file": {
        "allenai/longformer-base-4096": (
            "https://huggingface.co/allenai/longformer-base-4096/resolve/main/tokenizer.json"
        ),
        "allenai/longformer-large-4096": (
            "https://huggingface.co/allenai/longformer-large-4096/resolve/main/tokenizer.json"
        ),
        "allenai/longformer-large-4096-finetuned-triviaqa": (
            "https://huggingface.co/allenai/longformer-large-4096-finetuned-triviaqa/resolve/main/tokenizer.json"
        ),
        "allenai/longformer-base-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-base-4096-extra.pos.embd.only/resolve/main/tokenizer.json"
        ),
        "allenai/longformer-large-4096-extra.pos.embd.only": (
            "https://huggingface.co/allenai/longformer-large-4096-extra.pos.embd.only/resolve/main/tokenizer.json"
        ),
    },
}

# Mapping from model names to the size of their pretrained positional embeddings
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "allenai/longformer-base-4096": 4096,
    "allenai/longformer-large-4096": 4096,
    "allenai/longformer-large-4096-finetuned-triviaqa": 4096,
    "allenai/longformer-base-4096-extra.pos.embd.only": 4096,
    "allenai/longformer-large-4096-extra.pos.embd.only": 4096,
}

# Copied from transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast,
# adapted from FacebookAI/roberta-base to allenai/longformer-base-4096 (RoBERTa -> Longformer in all casings)
class LongformerTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" Longformer tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2
    tokenizer, using byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```
    >>> from transformers import LongformerTokenizerFast

    >>> tokenizer = LongformerTokenizerFast.from_pretrained("allenai/longformer-base-4096")
    >>> tokenizer("Hello world")["input_ids"]
    [0, 31414, 232, 2]

    >>> tokenizer(" Hello world")["input_ids"]
    [0, 20920, 232, 2]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (The Longformer tokenizer detects the beginning of words by the preceding space.)
        trim_offsets (`bool`, *optional*, defaults to `True`):
            Whether the post processing step should trim offsets to avoid including whitespaces.
    """

    # Pull in the file name constants, pretrained file maps, and max model input sizes defined above
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # Names of the model inputs: token ids and the attention mask
    model_input_names = ["input_ids", "attention_mask"]
    # The corresponding slow tokenizer class
    slow_tokenizer_class = LongformerTokenizer

    # Initializer: sets up the tokenizer's parameters and backend state
    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        errors="replace",
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        add_prefix_space=False,
        trim_offsets=True,
        **kwargs,
    ):
        # If mask_token is a plain string, wrap it in an AddedToken that strips the space on its left
        mask_token = (
            AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
            if isinstance(mask_token, str)
            else mask_token
        )
        # Call the parent initializer with the basic parameters and file paths
        super().__init__(
            vocab_file,
            merges_file,
            tokenizer_file=tokenizer_file,
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            trim_offsets=trim_offsets,
            **kwargs,
        )

        # Read the backend tokenizer's pre-tokenizer state and adjust it to match the requested settings
        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
            # Rebuild the pre-tokenizer with the updated add_prefix_space setting
            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
            pre_tok_state["add_prefix_space"] = add_prefix_space
            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)

        # Store add_prefix_space on this instance
        self.add_prefix_space = add_prefix_space

        # Fetch the tokenizer's post-processor component and adjust its state if necessary
        tokenizer_component = "post_processor"
        tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
        if tokenizer_component_instance:
            state = json.loads(tokenizer_component_instance.__getstate__())

            # The 'sep' and 'cls' values need to be tuples for the post-processor class constructor
            if "sep" in state:
                state["sep"] = tuple(state["sep"])
            if "cls" in state:
                state["cls"] = tuple(state["cls"])

            changes_to_apply = False

            # Check whether add_prefix_space or trim_offsets need to be updated
            if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
                state["add_prefix_space"] = add_prefix_space
                changes_to_apply = True

            if state.get("trim_offsets", trim_offsets) != trim_offsets:
                state["trim_offsets"] = trim_offsets
                changes_to_apply = True

            # If anything changed, rebuild the post-processor with the new state
            if changes_to_apply:
                component_class = getattr(processors, state.pop("type"))
                new_value = component_class(**state)
                setattr(self.backend_tokenizer, tokenizer_component, new_value)
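
    # A minimal standalone sketch of the serialize/patch/rebuild round-trip used above,
    # assuming a ByteLevel pre-tokenizer (the default for byte-level BPE backends):
    #
    #     import json
    #     from tokenizers import pre_tokenizers
    #
    #     state = json.loads(backend_tokenizer.pre_tokenizer.__getstate__())
    #     pre_tok_class = getattr(pre_tokenizers, state.pop("type"))  # e.g. ByteLevel
    #     state["add_prefix_space"] = True
    #     backend_tokenizer.pre_tokenizer = pre_tok_class(**state)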

    @property
    def mask_token(self) -> str:
        """
        `str`: 返回掩码标记,用于在进行掩码语言建模训练时使用。如果在未设置的情况下使用,则记录错误日志。

        Longformer 分词器有一个特殊的掩码标记,可在填充掩码流程中使用。该掩码标记将贪婪地包括 *<mask>* 前面的空格。
        """
        if self._mask_token is None:
            if self.verbose:
                logger.error("Using mask_token, but it is not set yet.")
            return None
        return str(self._mask_token)

    @mask_token.setter
    def mask_token(self, value):
        """
        重写掩码标记的默认行为,使其能够吸收其前的空格。

        这是为了与基于 Longformer 的所有先前使用的模型保持向后兼容性。
        """
        # 掩码标记表现得像普通单词,即包括前面的空格,因此设置 lstrip 为 True
        value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
        self._mask_token = value
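
    # A sketch of the behavior this enables in the fill-mask pipeline (the token
    # strings shown are illustrative):
    #
    #     tok = LongformerTokenizerFast.from_pretrained("allenai/longformer-base-4096")
    #     tok.tokenize("Paris is the <mask> of France.")
    #     # e.g. ['Paris', 'Ġis', 'Ġthe', '<mask>', 'Ġof', 'ĠFrance', '.']
    #     # lstrip=True lets "<mask>" absorb the preceding space, so the model predicts
    #     # a word-initial ("Ġ"-prefixed) token at the masked position.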

    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)
        assert self.add_prefix_space or not is_split_into_words, (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
            "to use it with pretokenized inputs."
        )

        return super()._batch_encode_plus(*args, **kwargs)

    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)

        assert self.add_prefix_space or not is_split_into_words, (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
            "to use it with pretokenized inputs."
        )

        return super()._encode_plus(*args, **kwargs)
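
    # The two guards above enforce the requirement from the class docstring; intended
    # usage with pretokenized input looks like:
    #
    #     tok = LongformerTokenizerFast.from_pretrained(
    #         "allenai/longformer-base-4096", add_prefix_space=True
    #     )
    #     enc = tok(["Hello", "world"], is_split_into_words=True)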

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        将词汇表保存到指定目录,并返回保存的文件名。

        调用分词器模型的保存方法,将模型保存到指定目录中,并使用指定的文件名前缀。
        """
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        根据给定的 token_ids 构建包含特殊标记的输入序列。

        将序列开始标记 (bos_token_id)、序列 0 的 token_ids 和序列结束标记 (eos_token_id) 拼接成输出列表。
        如果存在第二个序列 (token_ids_1),则将其与序列 0 后面的 eos_token_id 拼接。
        """
        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        if token_ids_1 is None:
            return output

        return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. Longformer does not
        make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        # Lists holding the special separator and classifier token ids
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        # Single sequence: zeros for the length of the sequence plus its special tokens
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        
        # Sequence pair: zeros for both sequences plus all special tokens
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

.\models\longformer\__init__.py

# TYPE_CHECKING flag, used to import modules and classes only during static type checking
from typing import TYPE_CHECKING

# Import the lazy-module helper, the optional-dependency exception, and availability checks from the package utils
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
)

# Import structure: maps each submodule to the names it exports
_import_structure = {
    "configuration_longformer": [
        "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "LongformerConfig",
        "LongformerOnnxConfig",
    ],
    "tokenization_longformer": ["LongformerTokenizer"],
}

# Check whether the tokenizers library is available; raise OptionalDependencyNotAvailable if not
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If available, add the tokenization_longformer_fast module to the import structure
    _import_structure["tokenization_longformer_fast"] = ["LongformerTokenizerFast"]

# Check whether PyTorch is available; raise OptionalDependencyNotAvailable if not
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If available, add the modeling_longformer module to the import structure
    _import_structure["modeling_longformer"] = [
        "LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "LongformerForMaskedLM",
        "LongformerForMultipleChoice",
        "LongformerForQuestionAnswering",
        "LongformerForSequenceClassification",
        "LongformerForTokenClassification",
        "LongformerModel",
        "LongformerPreTrainedModel",
        "LongformerSelfAttention",
    ]

# Check whether TensorFlow is available; raise OptionalDependencyNotAvailable if not
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If available, add the modeling_tf_longformer module to the import structure
    _import_structure["modeling_tf_longformer"] = [
        "TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFLongformerForMaskedLM",
        "TFLongformerForMultipleChoice",
        "TFLongformerForQuestionAnswering",
        "TFLongformerForSequenceClassification",
        "TFLongformerForTokenClassification",
        "TFLongformerModel",
        "TFLongformerPreTrainedModel",
        "TFLongformerSelfAttention",
    ]

# Under type checking, import the configuration and tokenizer classes directly
if TYPE_CHECKING:
    from .configuration_longformer import (
        LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LongformerConfig,
        LongformerOnnxConfig,
    )
    from .tokenization_longformer import LongformerTokenizer

    # Under type checking, also import the fast tokenizer class if the tokenizers library is available
    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_longformer_fast import LongformerTokenizerFast

    # Under type checking, also import the PyTorch model classes if torch is available
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    # If torch is unavailable, fall through to the except branch below and skip these imports
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Import the PyTorch Longformer modeling classes
        from .modeling_longformer import (
            LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
            LongformerForMaskedLM,
            LongformerForMultipleChoice,
            LongformerForQuestionAnswering,
            LongformerForSequenceClassification,
            LongformerForTokenClassification,
            LongformerModel,
            LongformerPreTrainedModel,
            LongformerSelfAttention,
        )

    # Check whether TensorFlow is available; skip these imports if not
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Import the TensorFlow Longformer modeling classes
        from .modeling_tf_longformer import (
            TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFLongformerForMaskedLM,
            TFLongformerForMultipleChoice,
            TFLongformerForQuestionAnswering,
            TFLongformerForSequenceClassification,
            TFLongformerForTokenClassification,
            TFLongformerModel,
            TFLongformerPreTrainedModel,
            TFLongformerSelfAttention,
        )
else:
    # Import sys so we can replace this module object
    import sys

    # Replace this module with a _LazyModule instance so submodules are loaded lazily on first access
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\longt5\configuration_longt5.py

# Import the required modules and classes: the pretrained config base class, OnnxSeq2SeqConfigWithPast, and logging
""" LongT5 model configuration"""
from typing import Mapping  # the Mapping type for the ONNX inputs property

from ...configuration_utils import PretrainedConfig  # pretrained configuration base class
from ...onnx import OnnxSeq2SeqConfigWithPast  # ONNX seq2seq config with past key values
from ...utils import logging  # logging utilities

# Get the module logger
logger = logging.get_logger(__name__)

# Mapping from pretrained model names to their config file URLs
LONGT5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/long-t5-local-base": "https://huggingface.co/google/long-t5-local-base/blob/main/config.json",
    "google/long-t5-local-large": "https://huggingface.co/google/long-t5-local-large/blob/main/config.json",
    "google/long-t5-tglobal-base": "https://huggingface.co/google/long-t5-tglobal-base/blob/main/config.json",
    "google/long-t5-tglobal-large": "https://huggingface.co/google/long-t5-tglobal-large/blob/main/config.json",
}


class LongT5Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`LongT5Model`] or a [`FlaxLongT5Model`]. It is
    used to instantiate a LongT5 model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the LongT5
    [google/long-t5-local-base](https://huggingface.co/google/long-t5-local-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """

    # Model type identifier
    model_type = "longt5"
    # Keys to ignore at inference time
    keys_to_ignore_at_inference = ["past_key_values"]
    # Map common attribute names to LongT5's own parameter names
    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}

    # Initializer: sets the model's hyper-parameters
    def __init__(
        self,
        vocab_size=32128,  # vocabulary size
        d_model=512,  # hidden size
        d_kv=64,  # dimension of the key/value projections per head
        d_ff=2048,  # inner dimension of the feed-forward network
        num_layers=6,  # number of encoder layers
        num_decoder_layers=None,  # number of decoder layers; defaults to num_layers when None
        num_heads=8,  # number of attention heads
        local_radius=127,  # radius of the local attention
        global_block_size=16,  # block size for transient-global attention
        relative_attention_num_buckets=32,  # number of buckets for relative attention
        relative_attention_max_distance=128,  # maximum distance for relative attention
        dropout_rate=0.1,  # dropout rate
        layer_norm_epsilon=1e-6,  # epsilon for layer normalization
        initializer_factor=1.0,  # initialization factor
        feed_forward_proj="relu",  # activation of the feed-forward network
        is_encoder_decoder=True,  # whether this is an encoder-decoder architecture
        encoder_attention_type="local",  # encoder attention type: "local" or "transient-global"
        use_cache=True,  # whether to use the key/value cache
        pad_token_id=0,  # id of the padding token
        eos_token_id=1,  # id of the end-of-sequence token
        **kwargs,  # remaining keyword arguments forwarded to the parent constructor
    ):
        # Store the hyper-parameters on the config object
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        # Use the given number of decoder layers if provided, otherwise mirror the encoder depth
        self.num_decoder_layers = num_decoder_layers if num_decoder_layers is not None else self.num_layers
        self.num_heads = num_heads
        self.local_radius = local_radius
        self.global_block_size = global_block_size
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj
        self.encoder_attention_type = encoder_attention_type
        self.use_cache = use_cache

        # Parse the feed-forward activation spec into the activation name and a gated flag
        act_info = self.feed_forward_proj.split("-")
        self.dense_act_fn = act_info[-1]  # the activation function name
        self.is_gated_act = act_info[0] == "gated"  # whether the activation is gated

        # Raise a ValueError if the activation spec has an unexpected format
        if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
            raise ValueError(
                f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
                "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
                "'gated-gelu' or 'relu'"
            )

        # For backward compatibility, map 'gated-gelu' to the 'gelu_new' activation
        if feed_forward_proj == "gated-gelu":
            self.dense_act_fn = "gelu_new"

        # Call the parent constructor with the token ids, the encoder-decoder flag, and the remaining kwargs
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )
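
# A quick sketch of how the activation spec is resolved (values follow directly from
# the parsing logic above):
#
#     from transformers import LongT5Config
#
#     cfg = LongT5Config(feed_forward_proj="gated-gelu")
#     assert cfg.is_gated_act and cfg.dense_act_fn == "gelu_new"
#
#     cfg = LongT5Config(feed_forward_proj="relu")
#     assert not cfg.is_gated_act and cfg.dense_act_fn == "relu"
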
# LongT5OnnxConfig inherits from OnnxSeq2SeqConfigWithPast
class LongT5OnnxConfig(OnnxSeq2SeqConfigWithPast):

    # The inputs property returns a mapping describing the model's input structure
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Common inputs shared by both branches: input_ids and attention_mask
        common_inputs = {
            "input_ids": {0: "batch", 1: "encoder_sequence"},
            "attention_mask": {0: "batch", 1: "encoder_sequence"},
        }

        # When past key values are used (use_past is True)
        if self.use_past:
            # The attention mask then covers the past encoder sequence plus the new tokens
            common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
            # Decoder input ids: only the batch dimension is dynamic
            common_inputs["decoder_input_ids"] = {0: "batch"}
            # The decoder attention mask covers the past decoder sequence plus the new tokens
            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
        else:
            # Without past key values, use the plain decoder input descriptions
            common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}

        # When using past key values, add the past key/value tensors to the inputs
        if self.use_past:
            self.fill_with_past_key_values_(common_inputs, direction="inputs")

        # Return the mapping describing the input structure
        return common_inputs

    # Default ONNX opset version used for export
    @property
    def default_onnx_opset(self) -> int:
        # Use ONNX opset 13
        return 13

.\models\longt5\convert_longt5x_checkpoint_to_flax.py

# Import the required libraries and modules
import argparse  # command-line argument parsing

from t5x import checkpoints  # utilities for loading original T5X checkpoints

from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM  # auto config and the Flax seq2seq model class


def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
    # Build the config from the given config name
    config = AutoConfig.from_pretrained(config_name)
    # Instantiate a Flax seq2seq model from the config
    flax_model = FlaxAutoModelForSeq2SeqLM.from_config(config=config)
    # Load the T5X checkpoint
    t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)

    # Check whether the MLP uses split (gated) wi weights
    split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]

    # Determine the name of the encoder attention module from the config
    if config.model_type == "t5":
        encoder_attn_name = "SelfAttention"
    elif config.model_type == "longt5" and config.encoder_attention_type == "local":
        encoder_attn_name = "LocalSelfAttention"
    elif config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
        encoder_attn_name = "TransientGlobalSelfAttention"
    else:
        # Raise if the config does not match the expected model type / attention type
        raise ValueError(
            "Given config is expected to have `model_type='t5'`, or `model_type='longt5` with `encoder_attention_type`"
            " attribute with a value from ['local', 'transient-global']."
        )

    # Encoder
    # The following applies only to layer 0:
    # Extract the encoder relative position embedding from the T5X model and assign it to the Flax model
    t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
    flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["relative_attention_bias"][
        "embedding"
    ] = t5x_encoder_rel_embedding

    # For longt5 with transient-global attention, also copy the global relative position bias
    if config.model_type == "longt5" and config.encoder_attention_type == "transient-global":
        t5x_encoder_global_rel_embedding = t5x_model["target"]["encoder"]["side_relpos_bias"]["rel_embedding"].T
        flax_model.params["encoder"]["block"]["0"]["layer"]["0"][encoder_attn_name]["global_relative_attention_bias"][
            "embedding"
        ] = t5x_encoder_global_rel_embedding

    # Assign the encoder's final layer norm weight
    t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
    flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm

    # Decoder
    # Assign the decoder's final layer norm weight
    tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
    flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm
    # Only for layer 0:

    # Extract and transpose the decoder relative position bias embedding from the T5X model
    t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T

    # Assign the transposed relative attention bias embedding to the Flax model's decoder
    flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][
        "embedding"
    ] = t5x_decoder_rel_embedding

    # Token Embeddings

    # Extract the token embedding (word embedding) matrix from the T5X model
    tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
    
    # Assign the token embeddings to the Flax model's shared embedding layer
    flax_model.params["shared"]["embedding"] = tx5_token_embeddings

    # LM Head (only in v1.1 and LongT5 checkpoints)

    # logits_dense is present in v1.1 and LongT5 checkpoints
    if "logits_dense" in t5x_model["target"]["decoder"]:
        # Copy the logits_dense kernel (weight matrix) to the Flax model's LM head
        flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]

    # Save the converted Flax model to the given folder
    flax_model.save_pretrained(flax_dump_folder_path)

    # Print a success message
    print("T5X Model was successfully converted!")
if __name__ == "__main__":
    # Entry point when the script is executed directly

    parser = argparse.ArgumentParser()
    # Create the argument parser

    # Required arguments
    parser.add_argument(
        "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the T5X checkpoint."
    )
    # --t5x_checkpoint_path: path to the T5X checkpoint (required)

    parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of LongT5/T5 model.")
    # --config_name: config name of the LongT5/T5 model (required)

    parser.add_argument(
        "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
    )
    # --flax_dump_folder_path: output folder for the FLAX model (required)

    args = parser.parse_args()
    # Parse the command-line arguments into args

    convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
    # Run the conversion with the parsed checkpoint path, config name, and output folder
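
    # Equivalently, the converter can be driven from Python (paths are placeholders):
    #
    #     convert_t5x_checkpoint_to_flax(
    #         t5x_checkpoint_path="/path/to/t5x_checkpoint",
    #         config_name="google/long-t5-local-base",
    #         flax_dump_folder_path="/path/to/flax_output",
    #     )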

.\models\longt5\modeling_flax_longt5.py

# Import the required modules and classes
import copy  # for copying objects
from typing import Any, Callable, List, Optional, Tuple  # type hints

import flax.linen as nn  # Flax's linen module, aliased as nn
import jax  # the JAX library
import jax.numpy as jnp  # JAX's NumPy API, aliased as jnp
import numpy as np  # NumPy, aliased as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # frozen dict utilities
from flax.linen import combine_masks, make_causal_mask  # mask helpers
from flax.linen import partitioning as nn_partitioning  # Flax partitioning module, aliased as nn_partitioning
from flax.linen.attention import dot_product_attention_weights  # attention weight helper
from flax.traverse_util import flatten_dict, unflatten_dict  # pytree flattening utilities
from jax.random import PRNGKey  # PRNG key type

from ...modeling_flax_outputs import (  # model output classes
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
    FlaxSeq2SeqModelOutput,
)
from ...modeling_flax_utils import (  # Flax model utility functions and classes
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings  # docstring and logging utilities
from .configuration_longt5 import LongT5Config  # the LongT5 configuration class


logger = logging.get_logger(__name__)  # module logger

_CHECKPOINT_FOR_DOC = "google/long-t5-local-base"  # checkpoint name used in the docs
_CONFIG_FOR_DOC = "LongT5Config"  # config name used in the docs

remat = nn_partitioning.remat  # alias for Flax rematerialization (gradient checkpointing)


# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = jnp.zeros_like(input_ids)  # zero array with the same shape as input_ids
    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])  # shift everything one position to the right
    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)  # put the decoder start token first

    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)  # replace -100 label padding with pad_token_id
    return shifted_input_ids  # return the shifted ids
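
# A worked example (values follow directly from the function):
#
#     ids = jnp.array([[5, 6, -100]])
#     shift_tokens_right(ids, pad_token_id=0, decoder_start_token_id=2)
#     # -> [[2, 5, 6]]: the start token is prepended and everything shifts right;
#     #    any -100 labels remaining after the shift would become pad_token_id.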


def _pad_to_multiple(x: jnp.ndarray, block_len: int, axis: int, pad_value: int = 0) -> jnp.ndarray:
    """Pad an array so its length along `axis` is a multiple of `block_len`."""
    pad_len = -x.shape[axis] % block_len  # how many elements are needed to reach the next multiple
    pad = [(0, 0)] * x.ndim  # one (before, after) pair per dimension
    pad[axis] = (0, pad_len)  # pad only at the end of `axis`
    x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value)  # pad with the constant pad_value
    return x  # the padded array


def _split_into_blocks(x: jnp.ndarray, block_len: int, axis: int) -> jnp.ndarray:
    """Split an input array into blocks of length `block_len` along the given axis."""
    # If the length along `axis` is not a multiple of block_len, pad with pad_value first
    # pad tensor to multiple of block_len
    if x.shape[axis] % block_len != 0:
        x = _pad_to_multiple(x, block_len, axis, pad_value=0)

    # Number of blocks the array is split into along `axis`
    num_blocks = x.shape[axis] // block_len

    # Output shape: keep all other dimensions, replace `axis` by (num_blocks, block_len)
    output_shape = x.shape[:axis] + (num_blocks, block_len) + x.shape[(axis + 1):]

    # Reshape to the blocked layout
    return x.reshape(output_shape)
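
# For example (shapes follow from the padding and reshape above):
#
#     x = jnp.ones((2, 10, 4))                       # [batch, seq_len, dim]
#     blocks = _split_into_blocks(x, block_len=4, axis=1)
#     blocks.shape                                   # -> (2, 3, 4, 4): 10 is padded to 12 = 3 blocks of 4
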
# For each block, gather the block itself together with its left and right neighbor
# blocks (the block axis is first padded with one block of `pad_value` on each side)
def _concatenate_3_blocks(x: jnp.ndarray, block_axis: int, sequence_axis: int, pad_value: int = 0) -> jnp.ndarray:
    """Concatenate three consecutive blocks for each input block for local attention.

    For more information, see: https://arxiv.org/pdf/2112.07916.pdf.
    """
    num_blocks = x.shape[block_axis]

    pad = [(0, 0)] * x.ndim
    pad[block_axis] = (1, 1)
    # [batch_size, num_blocks, block_len] -> [batch_size, num_blocks + 2, block_len]
    x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value)

    blocks_list: List[np.array] = []
    for i in range(3):
        # Indexing with a variable number of indices, see:
        # https://numpy.org/doc/stable/user/basics.indexing.html#dealing-with-variable-numbers-of-indices-within-programs
        indices = [slice(0, None)] * x.ndim
        indices[block_axis] = slice(i, i + num_blocks)
        indices = tuple(indices)
        blocks_list.append(x[indices])
    # Concatenate the three shifted views along sequence_axis
    return jnp.concatenate(blocks_list, axis=sequence_axis)  # [batch_size, num_blocks, 3 * block_len, ...]


def _make_3block_relative_position_ids(block_len: int) -> jnp.ndarray:
    """Makes 3-blocked relative position ids for local attention."""
    position_ids = jnp.arange(3 * block_len, dtype=jnp.int32)
    center_position_ids = position_ids[block_len:-block_len]
    relative_position_ids = position_ids[None, :] - center_position_ids[:, None]  # [block_len, 3 * block_len]
    return relative_position_ids
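
# For block_len = 2 the result is (computed directly from the definition above):
#
#     _make_3block_relative_position_ids(2)
#     # -> [[-2, -1,  0,  1,  2,  3],
#     #     [-3, -2, -1,  0,  1,  2]]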


def _mask_local_attention_mask(local_attention_mask: np.ndarray, block_len: int) -> jnp.ndarray:
    """Mask local attention mask to enforce that tokens are not allowed to attend tokens farther than ``local_radius."""
    relative_position_ids = _make_3block_relative_position_ids(block_len)
    locality_mask = jnp.abs(relative_position_ids) < block_len
    locality_mask = locality_mask[None, None, :, :]
    return jnp.logical_and(local_attention_mask, locality_mask)


def _get_local_attention_mask(attention_mask: np.ndarray, block_len: int) -> jnp.ndarray:
    """Prepare attention mask to be applied for a local attention."""
    # [batch_size, num_blocks, block_len]
    _blocked_attention_mask = _split_into_blocks(attention_mask, block_len, axis=1)
    # [batch_size, num_block, 3 * block_len]
    _3blocked_attention_mask = _concatenate_3_blocks(_blocked_attention_mask, block_axis=1, sequence_axis=2)

    _blocked_attention_mask = _blocked_attention_mask[..., None]
    _3blocked_attention_mask = _3blocked_attention_mask[..., None, :]
    # [batch_size, num_block, block_len, 3 * block_len]
    local_attention_mask = jnp.logical_and(_blocked_attention_mask, _3blocked_attention_mask)
    local_attention_mask = _mask_local_attention_mask(local_attention_mask, block_len)
    # [batch_size, 1, num_block, block_len, 3 * block_len]
    return local_attention_mask[:, None, ...]


def _make_global_fixed_block_ids(attention_mask: np.ndarray, global_block_size: int) -> Tuple[jnp.ndarray, np.ndarray]:
    """Make global fixed block ids for global attention."""
    ...
    """Obtain the "fixed block" global id corresponding to each input token.

    This implementation is a simplified version of the original Flaxformer implementation adopted from:
    https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py.

    In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not make for
    the whole fixed block, are assigned to the preceding block.

    Padding tokens from the original sequence are represented by -1.
    """
    # Batch size and sequence length of the attention mask
    batch_size, seq_len = attention_mask.shape[:2]

    # Assign orphan tokens (tokens that do not fill a whole fixed block) to the preceding block
    def handle_orphan_tokens(block_ids: np.ndarray) -> jnp.ndarray:
        # Positions where a fixed block ends
        block_ends = (jnp.arange(seq_len) % global_block_size) == global_block_size - 1
        # A block end only counts when the block id there is non-negative (not padding)
        true_block_ends = jnp.logical_and(block_ends, block_ids >= 0)
        # Number of complete blocks per sequence
        full_blocks = true_block_ends.sum(-1)[..., None]
        # Clip the block ids so orphans fall into the last complete block
        block_ids = jnp.minimum(block_ids, full_blocks - 1)
        return block_ids

    # Each position contributes 1 / global_block_size towards its block index
    fixed_block_mask = jnp.ones_like(attention_mask) / global_block_size
    # Exclusive cumulative sum: position i gets i / global_block_size
    fixed_block_mask = jnp.cumsum(fixed_block_mask, axis=1) - fixed_block_mask
    # Non-padding positions get 1.0, padding positions a large negative value
    mask = jnp.where(attention_mask != 0.0, 1.0, -1000.0)
    # Block ids: floor of the cumulative count, clipped below at -1.0 (padding)
    global_block_ids = jnp.maximum(
        jnp.floor(mask + fixed_block_mask - 1.0), jnp.array(-1.0, dtype=attention_mask.dtype)
    )
    # Force padding tokens to block id -1
    global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1)
    # Reassign orphan tokens to the preceding block
    global_block_ids = handle_orphan_tokens(global_block_ids)
    # Number of global blocks
    num_globals = seq_len // global_block_size

    # Global segment ids of shape [batch_size, seq_len // global_block_size]
    if num_globals > 0:
        # Broadcast the per-sequence maximum block id across all global positions
        _sequence_block_ids_max = jnp.repeat(global_block_ids.max(axis=-1)[:, None], repeats=num_globals, axis=1)
    else:
        # No global blocks: use an empty array
        _sequence_block_ids_max = jnp.zeros((batch_size, 0), dtype=global_block_ids.dtype)
    # A global segment is active (1) when its index does not exceed the maximum block id, else 0
    global_segment_ids = jnp.cumsum(jnp.ones((batch_size, num_globals)), axis=-1) - 1
    global_segment_ids = jnp.where(global_segment_ids <= _sequence_block_ids_max, 1, 0)
    # Return the per-token block ids and the per-block segment ids
    return global_block_ids, global_segment_ids
# Create the relative position tensor for local -> global attention
def _make_side_relative_position_ids(attention_mask: np.ndarray, global_block_size: int) -> np.ndarray:
    # Compute the global fixed block ids and global segment ids
    block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size)
    # Length of the global sequence
    global_seq_len = global_segment_ids.shape[-1]
    # Indices of the global positions
    global_positions = jnp.arange(global_seq_len)
    # Side relative position: distance from each token's block to every global position
    side_relative_position = global_positions - block_ids[..., None]
    return side_relative_position


# Compute per-block aggregates by summing the hidden states assigned to each block
def _create_global_aggregates(hidden_states: np.ndarray, block_ids: np.ndarray, global_seq_len: int) -> np.ndarray:
    """Compute individual block aggregates by summing over individual blocks."""
    # One-hot encode the block ids over the global sequence length
    one_hot_block_ids = jax.nn.one_hot(block_ids, global_seq_len)
    # Sum the hidden states into their blocks via an einsum contraction
    return jnp.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids)
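
# A small sketch of the aggregation; note that padding tokens carry block id -1,
# whose one-hot encoding is all zeros, so they contribute nothing:
#
#     hidden = jnp.ones((1, 4, 8))              # [batch, seq_len, dim]
#     block_ids = jnp.array([[0, 0, 1, -1]])    # last token is padding
#     agg = _create_global_aggregates(hidden, block_ids, global_seq_len=2)
#     agg.shape                                 # -> (1, 2, 8)
#     # agg[0, 0] sums tokens 0-1, agg[0, 1] is token 2, the padding token is dropped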


# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerNorm with T5->LongT5
class FlaxLongT5LayerNorm(nn.Module):
    hidden_size: int
    dtype: jnp.dtype = jnp.float32
    eps: float = 1e-6
    weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones

    def setup(self):
        self.weight = self.param("weight", self.weight_init, (self.hidden_size,))

    def __call__(self, hidden_states):
        """
        Construct a layernorm module in the LongT5 style; No bias and no subtraction of mean.
        """
        # layer norm should always be calculated in float32
        variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
        # Normalize by the root mean square; no mean subtraction and no bias
        hidden_states = hidden_states / jnp.sqrt(variance + self.eps)

        return self.weight * hidden_states
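
# In other words, this is RMS normalization; a rough NumPy equivalent (sketch):
#
#     def rms_norm(x, weight, eps=1e-6):
#         variance = np.mean(np.square(x.astype(np.float32)), axis=-1, keepdims=True)
#         return weight * (x / np.sqrt(variance + eps))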


# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseActDense with T5->LongT5
class FlaxLongT5DenseActDense(nn.Module):
    config: LongT5Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Standard deviations for weight initialization, scaled by initializer_factor
        wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
        wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)

        # Input projection up to d_ff (no bias)
        self.wi = nn.Dense(
            self.config.d_ff,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(wi_init_std),
            dtype=self.dtype,
        )
        # Activation function
        self.act = ACT2FN[self.config.dense_act_fn]
        # Dropout layer
        self.dropout = nn.Dropout(self.config.dropout_rate)
        # Output projection back to d_model (no bias)
        self.wo = nn.Dense(
            self.config.d_model,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(wo_init_std),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states, deterministic=True):
        # Project up to d_ff
        hidden_states = self.wi(hidden_states)
        # Apply the activation
        hidden_states = self.act(hidden_states)
        # Apply dropout (skipped when deterministic)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # Project back down to d_model
        hidden_states = self.wo(hidden_states)
        return hidden_states


# 从 transformers.models.t5.modeling_flax_t5.FlaxT5DenseGatedActDense 复制并将 T5 更改为 LongT5
# 定义一个名为 FlaxLongT5DenseGatedActDense 的类,继承自 nn.Module
class FlaxLongT5DenseGatedActDense(nn.Module):
    # 类变量 config,类型为 LongT5Config,表示配置信息
    config: LongT5Config
    # 类变量 dtype,默认为 jnp.float32,表示计算中使用的数据类型

    # 初始化方法 setup,用于设置网络层
    def setup(self):
        # 初始化权重矩阵的标准差,wi_init_std 和 wo_init_std 分别为 d_model 和 d_ff 的倒数乘以 initializer_factor 的结果
        wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
        wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)

        # 创建第一个全连接层 wi_0
        self.wi_0 = nn.Dense(
            self.config.d_ff,  # 输出维度为 d_ff
            use_bias=False,  # 不使用偏置
            kernel_init=jax.nn.initializers.normal(wi_init_std),  # 使用正态分布初始化权重
            dtype=self.dtype,  # 指定数据类型为 dtype
        )
        # 创建第二个全连接层 wi_1
        self.wi_1 = nn.Dense(
            self.config.d_ff,  # 输出维度为 d_ff
            use_bias=False,  # 不使用偏置
            kernel_init=jax.nn.initializers.normal(wi_init_std),  # 使用正态分布初始化权重
            dtype=self.dtype,  # 指定数据类型为 dtype
        )
        # 创建输出全连接层 wo
        self.wo = nn.Dense(
            self.config.d_model,  # 输出维度为 d_model
            use_bias=False,  # 不使用偏置
            kernel_init=jax.nn.initializers.normal(wo_init_std),  # 使用正态分布初始化权重
            dtype=self.dtype,  # 指定数据类型为 dtype
        )
        # 创建 Dropout 层,使用配置中的 dropout_rate
        self.dropout = nn.Dropout(self.config.dropout_rate)
        # 根据配置中的 dense_act_fn 选择激活函数,并赋值给 act 变量
        self.act = ACT2FN[self.config.dense_act_fn]

    # 实现 __call__ 方法,定义类的可调用行为
    def __call__(self, hidden_states, deterministic):
        # 计算使用激活函数处理后的 hidden_states
        hidden_gelu = self.act(self.wi_0(hidden_states))
        # 计算 hidden_states 经过第二个全连接层的结果
        hidden_linear = self.wi_1(hidden_states)
        # 计算最终的 hidden_states,是经过门控激活函数处理的结果
        hidden_states = hidden_gelu * hidden_linear
        # 对 hidden_states 应用 Dropout,根据 deterministic 参数确定是否确定性执行
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # 计算输出结果,经过全连接层 wo 处理
        hidden_states = self.wo(hidden_states)
        # 返回处理后的 hidden_states
        return hidden_states


# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerFF with T5->LongT5
class FlaxLongT5LayerFF(nn.Module):
    # Configuration object
    config: LongT5Config
    # Computation dtype, float32 by default
    dtype: jnp.dtype = jnp.float32

    # setup: create the layers
    def setup(self):
        # Use the gated feed-forward variant if the config requests a gated activation
        if self.config.is_gated_act:
            self.DenseReluDense = FlaxLongT5DenseGatedActDense(self.config, dtype=self.dtype)
        else:
            self.DenseReluDense = FlaxLongT5DenseActDense(self.config, dtype=self.dtype)

        # Layer norm over d_model with the configured epsilon
        self.layer_norm = FlaxLongT5LayerNorm(
            self.config.d_model,
            eps=self.config.layer_norm_epsilon,
            dtype=self.dtype,
        )
        # Dropout layer with the configured rate
        self.dropout = nn.Dropout(self.config.dropout_rate)

    def __call__(self, hidden_states, deterministic=True):
        # Pre-layer-norm the input
        forwarded_states = self.layer_norm(hidden_states)
        # Run the feed-forward block
        forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic)
        # Residual connection around the (dropped-out) feed-forward output
        hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic)
        return hidden_states


# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention with T5->LongT5
class FlaxLongT5Attention(nn.Module):
    # Configuration object
    config: LongT5Config
    # Whether this layer owns the relative attention bias embedding
    has_relative_attention_bias: bool = False
    # Whether the attention is causal (autoregressive)
    causal: bool = False
    # Computation dtype, float32 by default
    dtype: jnp.dtype = jnp.float32

    # setup: read the hyper-parameters from the config and create the projections
    def setup(self):
        # Number of buckets for relative attention, from the config
        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        # Maximum distance for relative attention, from the config
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        # Model hidden size
        self.d_model = self.config.d_model
        # Dimension of the key/value projections per head
        self.key_value_proj_dim = self.config.d_kv
        # Number of attention heads
        self.n_heads = self.config.num_heads
        # Dropout rate
        self.dropout = self.config.dropout_rate
        # Inner dimension: number of heads times the per-head projection dimension
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Standard deviations for the query projection
        q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
        # Standard deviation for the key/value projections
        kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
        # Standard deviation for the output projection
        o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)

        # Query projection
        self.q = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(q_init_std),
            dtype=self.dtype,
        )
        # Key projection
        self.k = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(kv_init_std),
            dtype=self.dtype,
        )
        # Value projection
        self.v = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(kv_init_std),
            dtype=self.dtype,
        )
        # Output projection back to d_model
        self.o = nn.Dense(
            self.d_model,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(o_init_std),
            dtype=self.dtype,
        )

        # Relative attention bias embedding, only when this layer owns it
        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embed(
                self.relative_attention_num_buckets,
                self.n_heads,
                embedding_init=jax.nn.initializers.normal(kv_init_std),
                dtype=self.dtype,
            )

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        """
        relative_buckets = 0
        if bidirectional:
            # If bidirectional, adjust the number of buckets and determine if relative_position is positive
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            # If not bidirectional, ensure relative_position is non-positive
            relative_position = -jnp.clip(relative_position, a_max=0)
        
        # Ensure relative_position is in the range [0, inf)
        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # Compute bucket index for larger positions logarithmically
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )
        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)

        # Determine final relative_bucket based on whether relative_position is small or large
        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

        return relative_buckets.astype("i4")

    def compute_bias(self, query_length, key_length):
        """Compute binned relative position bias"""
        # Create matrices of context and memory positions
        context_position = jnp.arange(query_length, dtype="i4")[:, None]
        memory_position = jnp.arange(key_length, dtype="i4")[None, :]

        # Compute relative position as memory_position - context_position
        relative_position = memory_position - context_position

        # Compute relative_position_bucket using _relative_position_bucket function
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=(not self.causal),  # Determine bidirectionality based on 'causal' attribute
            num_buckets=self.relative_attention_num_buckets,  # Number of buckets for relative positions
            max_distance=self.relative_attention_max_distance,  # Maximum distance for mapping to buckets
        )

        # Obtain relative_attention_bias values based on computed relative_position_bucket
        values = self.relative_attention_bias(relative_position_bucket)

        # Rearrange values to match expected dimensions
        values = values.transpose((2, 0, 1))[None, :, :, :]

        return values

    def _split_heads(self, hidden_states):
        # Reshape hidden_states to split into heads
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))
    # Reshape the hidden states back to (batch_size, seq_len, inner_dim) after attention
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))

    # Compact Flax method that concatenates the projected key/value states of the
    # current token with the states cached from previous steps
    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # Detect initialization by the absence of existing cache data
        is_initialized = self.has_variable("cache", "cached_key")
        # Initialize or fetch the cached keys and values, zero-filled on creation
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        # Fetch the cache index, which marks the current write position
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            # Unpack the cached shape so the buffers can be updated
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # Update the key and value caches with the new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = jax.lax.dynamic_update_slice(cached_key.value, key, indices)
            value = jax.lax.dynamic_update_slice(cached_value.value, value, indices)
            # Write the updated keys and values back into the cache
            cached_key.value = key
            cached_value.value = value
            # Advance the cache index by the number of cache vectors just written
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # Causal mask for the cached decoder self-attention: a single query position
            # may only attend to key positions that have already been generated and
            # cached, not to the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            # Combine the padding mask with the incoming attention mask
            attention_mask = combine_masks(pad_mask, attention_mask)

        # Return the updated keys, values and attention mask
        return key, value, attention_mask
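
# --- Standalone sketch, not part of the model file ---
# A toy illustration of the jax.lax.dynamic_update_slice pattern used above: a
# pre-allocated cache of max_length positions receives the new key vector at the
# running index while the rest of the buffer stays untouched.
import jax
import jax.numpy as jnp

cached = jnp.zeros((1, 6, 2, 4))   # (batch, max_length, heads, head_dim)
new_key = jnp.ones((1, 1, 2, 4))   # projected key of a single decoding step
cur_index = 3                      # current value of cache_index
updated = jax.lax.dynamic_update_slice(cached, new_key, (0, cur_index, 0, 0))
print(updated[0, :, 0, 0])         # [0. 0. 0. 1. 0. 0.]
# --- End of sketch ---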

    # Build the position bias used for positional encoding in the attention mechanism
    def _create_position_bias(
        self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift
    ):
        # The cache counts as filled when the model is causal, a cached key exists,
        # and we are not currently initializing the cache (init_cache is False)
        cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache)
        # Length of the key states
        key_length = key_states.shape[1]
        # If the cache is filled, the query length equals the key length; otherwise it
        # is taken from the query states
        query_length = key_length if cache_is_filled else query_states.shape[1]

        # If the model uses relative attention bias, compute the position bias
        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(query_length, key_length)
        # Otherwise, if an attention mask exists, use a zero array of the same shape
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        # Otherwise fall back to a zero array of shape (1, self.n_heads, query_length, key_length)
        else:
            position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype)

        # If the cache is filled, only the bias for the latest query positions is needed
        if cache_is_filled:
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            position_bias = jax.lax.dynamic_slice(
                position_bias,
                (0, 0, causal_attention_mask_shift, 0),
                (1, self.n_heads, seq_length, max_decoder_length),
            )
        # Return the computed position bias
        return position_bias

    # Forward pass of the attention module
        def __call__(
            self,
            hidden_states,
            attention_mask=None,
            key_value_states=None,
            position_bias=None,
            use_cache=False,
            output_attentions=False,
            deterministic=True,
            init_cache=False,
class FlaxLongT5LocalAttention(nn.Module):
    config: LongT5Config
    has_relative_attention_bias: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.local_radius = self.config.local_radius
        self.block_len = self.local_radius + 1
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
        kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
        o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)

        # Query projection matrix (Q)
        self.q = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(q_init_std),
            dtype=self.dtype,
        )
        # Key projection matrix (K)
        self.k = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(kv_init_std),
            dtype=self.dtype,
        )
        # Value projection matrix (V)
        self.v = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(kv_init_std),
            dtype=self.dtype,
        )
        # Output projection matrix (O)
        self.o = nn.Dense(
            self.d_model,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(o_init_std),
            dtype=self.dtype,
        )

        # If the config enables relative attention bias, create its Embed table
        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embed(
                self.relative_attention_num_buckets,
                self.n_heads,
                embedding_init=jax.nn.initializers.normal(kv_init_std),
            )

    @staticmethod
    # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention._relative_position_bucket
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        """
        relative_buckets = 0
        # For bidirectional attention, halve the bucket count and route positive
        # offsets into the upper half of the bucket range
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            # For unidirectional attention, clamp relative positions to be non-positive
            relative_position = -jnp.clip(relative_position, a_max=0)
        # relative_position is now in the range [0, inf)

        # Half of the buckets cover exact position increments
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half covers larger positions logarithmically, up to max_distance
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )
        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)

        # Pick the bucket according to whether the relative position is small or large
        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

        return relative_buckets.astype("i4")

    def compute_bias(self, block_length: int):
        """Compute binned relative position bias"""
        # Build the memory- and context-position arrays
        memory_position = jnp.arange(3 * block_length, dtype="i4")
        context_position = memory_position[block_length:-block_length]

        # Compute the relative positions and map them to buckets
        relative_position = memory_position[None, :] - context_position[:, None]
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )

        # Look up the relative attention bias values
        values = self.relative_attention_bias(relative_position_bucket)
        values = values.transpose((2, 0, 1))[None, None, :, :, :]
        return values
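
# --- Standalone sketch, not part of the model file ---
# The index geometry behind compute_bias: every query block of length block_len
# attends to a 3 * block_len memory window made of the previous, current, and next
# block, so the relative offsets cover both neighbouring blocks.
import jax.numpy as jnp

block_len = 4
memory_position = jnp.arange(3 * block_len)
context_position = memory_position[block_len:-block_len]  # the middle block
relative_position = memory_position[None, :] - context_position[:, None]
print(relative_position.shape)  # (4, 12): 4 queries against a 12-token window
print(relative_position[0])     # [-4 -3 -2 -1  0  1  2  3  4  5  6  7]
# --- End of sketch ---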

    def _split_heads(self, hidden_states):
        # Split the hidden states into separate attention heads
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    # Merge the heads back into shape (batch_size, sequence_length, inner_dim)
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim)

    # Build the position-bias matrix used for positional encoding in attention
    def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray:
        # position_bias shape: (1, 1, n_heads, block_len, 3 * block_len)
        if self.has_relative_attention_bias:
            # If the model uses relative attention bias, compute it
            position_bias = self.compute_bias(block_len)
        elif attention_mask is not None:
            # If an attention mask exists, use a zero matrix of the same shape
            position_bias = jnp.zeros_like(attention_mask)
        else:
            # Otherwise use a zero matrix of shape (1, 1, self.n_heads, block_len, 3 * block_len)
            position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype)

        return position_bias

    # Main forward pass of the local-attention module
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        key_value_states=None,
        position_bias=None,
        output_attentions=False,
        deterministic=True,
class FlaxLongT5TransientGlobalAttention(nn.Module):
    config: LongT5Config
    has_relative_attention_bias: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.local_radius = self.config.local_radius
        self.block_len = self.local_radius + 1
        self.global_block_size = self.config.global_block_size
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
        kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
        o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)

        # Initialize query, key, value, and output dense layers with appropriate parameters
        self.q = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(q_init_std),
            dtype=self.dtype,
        )
        self.k = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(kv_init_std),
            dtype=self.dtype,
        )
        self.v = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(kv_init_std),
            dtype=self.dtype,
        )
        self.o = nn.Dense(
            self.d_model,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(o_init_std),
            dtype=self.dtype,
        )

        if self.has_relative_attention_bias:
            # Initialize relative attention bias if enabled
            self.relative_attention_bias = nn.Embed(
                self.relative_attention_num_buckets,
                self.n_heads,
                embedding_init=jax.nn.initializers.normal(kv_init_std),
            )

        # Initialize global relative attention bias and layer normalization for global attention
        if self.has_relative_attention_bias:
            self.global_relative_attention_bias = nn.Embed(
                self.relative_attention_num_buckets,
                self.n_heads,
                embedding_init=jax.nn.initializers.normal(kv_init_std),
            )
        self.global_input_layer_norm = FlaxLongT5LayerNorm(
            self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
        )

    @staticmethod
    # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention._relative_position_bucket
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        """
        # Start with bucket offset 0
        relative_buckets = 0

        # For bidirectional relative attention, halve the bucket count and shift
        # positive offsets into the upper half of the bucket range
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            # For unidirectional relative attention, clamp positions to be non-positive
            relative_position = -jnp.clip(relative_position, a_max=0)

        # relative_position is now in the range [0, inf)

        # Half of the buckets cover small absolute positions exactly
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half grows logarithmically with position, up to max_distance
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )
        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)

        # Pick the appropriate bucket based on the position magnitude
        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

        # Cast the bucket indices to int32 and return them
        return relative_buckets.astype("i4")

    def compute_bias(self, block_length: int):
        """Compute binned relative position bias"""
        # Array of memory positions covering three blocks
        memory_position = jnp.arange(3 * block_length, dtype="i4")

        # Select the context (middle-block) positions from the memory positions
        context_position = memory_position[block_length:-block_length]

        # Relative position between every memory/context position pair
        relative_position = memory_position[None, :] - context_position[:, None]

        # Map the relative positions to bucket indices
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )

        # Look up the relative attention bias values for the buckets
        values = self.relative_attention_bias(relative_position_bucket)

        # Rearrange the dimensions to the shape expected by the model and return
        values = values.transpose((2, 0, 1))[None, None, :, :, :]
        return values
    def compute_side_bias(self, attention_mask: np.ndarray, global_segment_ids: np.ndarray) -> np.ndarray:
        # (batch_size, 1, 1, seq_len, global_seq_len)
        # Side attention mask: compare each position's attention mask with the global segment ids
        side_attention_mask = jnp.equal(attention_mask[..., None], global_segment_ids[:, None, :])[:, None, ...]
        # Apply the attention bias selectively according to the side attention mask
        attention_side_bias = jax.lax.select(
            side_attention_mask > 0,
            jnp.full(side_attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(side_attention_mask.shape, -1e10).astype(self.dtype),
        )
        # (batch_size, seq_len, global_seq_len)
        # Relative positions towards the side (global) tokens
        side_relative_position = _make_side_relative_position_ids(attention_mask, self.global_block_size)
        # Map the side relative positions to buckets
        side_relative_position_bucket = self._relative_position_bucket(
            side_relative_position,
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # (batch_size, seq_len, global_seq_len, num_heads)
        # Global relative attention bias values
        side_bias = self.global_relative_attention_bias(side_relative_position_bucket)

        # (batch_size, 1, num_heads, seq_len, global_seq_len)
        # Reorder the dimensions to match the attention-bias shape
        side_bias = jnp.transpose(side_bias, (0, 3, 1, 2))
        # (batch_size, num_heads, seq_len, global_seq_len)
        # Combine the side attention bias with the global relative attention bias
        attention_side_bias = attention_side_bias + side_bias
        return attention_side_bias
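
# --- Standalone sketch, not part of the model file ---
# The jax.lax.select masking idiom used above, in isolation: allowed positions get
# an additive bias of 0, disallowed ones a large negative value that vanishes in
# the softmax.
import jax
import jax.numpy as jnp

mask = jnp.array([[True, False, True]])
bias = jax.lax.select(mask, jnp.zeros(mask.shape), jnp.full(mask.shape, -1e10))
print(bias)  # [[ 0.e+00 -1.e+10  0.e+00]]
# --- End of sketch ---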

    def _split_heads(self, hidden_states):
        # Split the hidden states into separate attention heads
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    def _merge_heads(self, hidden_states):
        # Merge the attention heads back together
        return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim)

    def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray:
        # position_bias shape: (1, 1, n_heads, block_len, 3 * block_len)
        # Build the position bias from the relative-attention table or the attention mask
        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(block_len)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype)

        return position_bias

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        key_value_states=None,
        position_bias=None,
        output_attentions=False,
        deterministic=True,
class FlaxLongT5LayerLocalSelfAttention(nn.Module):
    """Local self attention used in encoder"""

    config: LongT5Config  # configuration object for the LongT5 model
    has_relative_attention_bias: bool = False  # whether to use a relative attention bias, off by default
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.LocalSelfAttention = FlaxLongT5LocalAttention(
            self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
        )
        self.layer_norm = FlaxLongT5LayerNorm(
            self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
        )
        self.dropout = nn.Dropout(self.config.dropout_rate)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        output_attentions=False,
        deterministic=True,
        **kwargs: Any,  # extra arguments such as init_cache are accepted but unused here
    ):
        normed_hidden_states = self.layer_norm(hidden_states)  # layer-normalize the input hidden states
        attention_output = self.LocalSelfAttention(
            normed_hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            deterministic=deterministic,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
        outputs = (hidden_states,) + attention_output[1:]  # append the attentions to the outputs if requested
        return outputs
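
# --- Standalone sketch, not part of the model file ---
# The pre-LayerNorm residual pattern shared by these layers: normalize, run the
# sublayer, then add the result onto the *un-normalized* input. The lambdas below
# are illustrative stand-ins (dropout omitted), not the real attention or T5 norm.
import jax.numpy as jnp

def pre_norm_residual(hidden, sublayer, norm):
    return hidden + sublayer(norm(hidden))

hidden = jnp.ones((1, 4, 8))
out = pre_norm_residual(
    hidden,
    sublayer=lambda x: 0.5 * x,  # stand-in for the attention sublayer
    norm=lambda x: x / jnp.sqrt((x**2).mean(-1, keepdims=True)),  # RMS-style norm
)
print(out.shape)  # (1, 4, 8)
# --- End of sketch ---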


class FlaxLongT5LayerTransientGlobalSelfAttention(nn.Module):
    """Transient-Global self attention used in encoder"""

    config: LongT5Config  # configuration object for the LongT5 model
    has_relative_attention_bias: bool = False  # whether to use a relative attention bias, off by default
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.TransientGlobalSelfAttention = FlaxLongT5TransientGlobalAttention(
            self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
        )
        self.layer_norm = FlaxLongT5LayerNorm(
            self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
        )
        self.dropout = nn.Dropout(self.config.dropout_rate)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        output_attentions=False,
        deterministic=True,
        **kwargs: Any,  # extra arguments such as init_cache are accepted but unused here
    ):
        normed_hidden_states = self.layer_norm(hidden_states)  # layer-normalize the input hidden states
        attention_output = self.TransientGlobalSelfAttention(
            normed_hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            deterministic=deterministic,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
        outputs = (hidden_states,) + attention_output[1:]  # append the attentions to the outputs if requested
        return outputs


# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerSelfAttention with T5->LongT5
class FlaxLongT5LayerSelfAttention(nn.Module):
    config: LongT5Config
    has_relative_attention_bias: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Self-attention layer, built from the config and the given settings
        self.SelfAttention = FlaxLongT5Attention(
            self.config,
            has_relative_attention_bias=self.has_relative_attention_bias,
            causal=self.config.causal,
            dtype=self.dtype,
        )
        # Layer normalization with the model dimension and epsilon from the config
        self.layer_norm = FlaxLongT5LayerNorm(
            self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
        )
        # Dropout with the rate from the config
        self.dropout = nn.Dropout(self.config.dropout_rate)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        output_attentions=False,
        deterministic=True,
        init_cache=False,
    ):
        # Layer-normalize the input hidden states
        normed_hidden_states = self.layer_norm(hidden_states)
        # Run self-attention over the normalized hidden states
        attention_output = self.SelfAttention(
            normed_hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            deterministic=deterministic,
            init_cache=init_cache,
        )
        # Residual connection: add the dropout-regularized attention output back
        hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
        # Assemble the outputs: updated hidden states plus attentions if requested
        outputs = (hidden_states,) + attention_output[1:]
        return outputs


# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCrossAttention with T5->LongT5
class FlaxLongT5LayerCrossAttention(nn.Module):
    config: LongT5Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Encoder-decoder (cross) attention layer
        self.EncDecAttention = FlaxLongT5Attention(
            self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype
        )
        # Layer normalization with the model dimension and epsilon from the config
        self.layer_norm = FlaxLongT5LayerNorm(
            self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
        )
        # Dropout with the rate from the config
        self.dropout = nn.Dropout(self.config.dropout_rate)

    def __call__(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        output_attentions=False,
        deterministic=True,
    ):
        # Layer-normalize the input hidden states
        normed_hidden_states = self.layer_norm(hidden_states)
        # Run cross-attention over the normalized hidden states and the key/value states
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            attention_mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            output_attentions=output_attentions,
        )
        # Residual connection: add the dropout-regularized attention output back
        hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
        # Assemble the outputs: updated hidden states plus attentions if requested
        outputs = (hidden_states,) + attention_output[1:]
        return outputs


class FlaxLongT5Block(nn.Module):
    config: LongT5Config
    has_relative_attention_bias: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Whether this block uses causal (decoder-style) attention
        self.causal = self.config.causal
        # Pick the self-attention layer type according to the attention mechanism
        if self.causal:
            attention_layer = FlaxLongT5LayerSelfAttention
        elif self.config.encoder_attention_type == "local":
            attention_layer = FlaxLongT5LayerLocalSelfAttention
        elif self.config.encoder_attention_type == "transient-global":
            attention_layer = FlaxLongT5LayerTransientGlobalSelfAttention
        else:
            # Unknown encoder attention type: raise a ValueError
            raise ValueError(
                "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, "
                f"but got {self.config.encoder_attention_type}."
            )
        # Instantiate the block's attention layer
        self.layer = (
            attention_layer(
                self.config,
                has_relative_attention_bias=self.has_relative_attention_bias,
                name=str(0),
                dtype=self.dtype,
            ),
        )
        # Index at which the feed-forward layer will be added
        feed_forward_index = 1
        # Causal (decoder) blocks additionally get a cross-attention layer
        if self.causal:
            self.layer += (FlaxLongT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),)
            feed_forward_index += 1

        # Add the feed-forward layer
        self.layer += (FlaxLongT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),)

    # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Block.__call__ with T5->LongT5
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        output_attentions=False,
        return_dict=True,
        deterministic=True,
        init_cache=False,
    ):
        # Run the self-attention layer with the attention mask and position bias
        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            deterministic=deterministic,
            init_cache=init_cache,
        )
        # Update the hidden states with the self-attention output
        hidden_states = self_attention_outputs[0]
        # Keep the self-attention outputs and relative position weights
        attention_outputs = self_attention_outputs[1:]

        # Cross-attention runs only for causal blocks when encoder hidden states exist
        do_cross_attention = self.causal and encoder_hidden_states is not None
        if do_cross_attention:
            # Run the cross-attention layer with the encoder key/value states and mask
            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                output_attentions=output_attentions,
                deterministic=deterministic,
            )
            # Update the hidden states with the cross-attention output
            hidden_states = cross_attention_outputs[0]

            # Keep the cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[1:]

        # Apply the feed-forward layer to the hidden states
        hidden_states = self.layer[-1](hidden_states, deterministic=deterministic)

        # Assemble the output tuple, starting with the updated hidden states
        outputs = (hidden_states,)

        # Append the attention outputs to the tuple
        outputs = outputs + attention_outputs

        # Returned tuple structure: hidden-states, present_key_value_states,
        # (self-attention position bias), (self-attention weights),
        # (cross-attention position bias), (cross-attention weights)
        return outputs
# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCollection with T5->LongT5
class FlaxLongT5LayerCollection(nn.Module):
    config: LongT5Config  # configuration object holding the LongT5 model settings
    has_relative_attention_bias: bool  # whether to use a relative attention bias
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Wrap a single FlaxLongT5Block
        self.layer = FlaxLongT5Block(
            self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
        )

    def __call__(
        self,
        hidden_states,  # input hidden states
        attention_mask=None,  # attention mask controlling which positions can be attended to
        position_bias=None,  # position bias for the self-attention positional encoding
        encoder_hidden_states=None,  # encoder hidden states for encoder-decoder attention
        encoder_attention_mask=None,  # attention mask over the encoder states
        encoder_decoder_position_bias=None,  # position bias from encoder to decoder
        output_attentions=False,  # whether to return attention weights
        deterministic=True,  # whether to run deterministically (no dropout)
        init_cache=False,  # whether to initialize the cache
    ):
        # Delegate the forward pass to the wrapped FlaxLongT5Block
        return self.layer(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            encoder_decoder_position_bias=encoder_decoder_position_bias,
            output_attentions=output_attentions,
            deterministic=deterministic,
            init_cache=init_cache,
        )


# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5BlockCollection with T5->LongT5
class FlaxLongT5BlockCollection(nn.Module):
    config: LongT5Config  # configuration object holding the LongT5 model settings
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False  # whether to use gradient checkpointing

    def setup(self):
        self.causal = self.config.causal  # whether to run in causal (autoregressive) mode
        if self.gradient_checkpointing:
            FlaxLongT5CheckpointLayer = remat(FlaxLongT5LayerCollection, static_argnums=(6, 7, 8))
            self.blocks = [
                # LongT5 layers wrapped with gradient checkpointing
                FlaxLongT5CheckpointLayer(
                    self.config,
                    has_relative_attention_bias=(i == 0),
                    dtype=self.dtype,
                    name=str(i),
                )
                for i in range(self.config.num_layers)
            ]
        else:
            self.blocks = [
                # Plain list of LongT5 layers; only layer 0 owns the relative attention bias
                FlaxLongT5LayerCollection(
                    self.config,
                    has_relative_attention_bias=(i == 0),
                    dtype=self.dtype,
                    name=str(i),
                )
                for i in range(self.config.num_layers)
            ]

    def __call__(
        self,
        hidden_states=None,  # input hidden states
        attention_mask=None,  # attention mask controlling which positions can be attended to
        encoder_hidden_states=None,  # encoder hidden states for encoder-decoder attention
        encoder_attention_mask=None,  # attention mask over the encoder states
        output_attentions: bool = False,  # whether to return attention weights
        output_hidden_states: bool = False,  # whether to return all hidden states
        deterministic: bool = True,  # whether to run deterministically (no dropout)
        init_cache: bool = False,  # whether to initialize the cache
    ):
        # Prepare containers for the requested outputs
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.causal) else None
        position_bias = None
        encoder_decoder_position_bias = None

        # Iterate over the Transformer blocks
        for i, layer_module in enumerate(self.blocks):
            if output_hidden_states:
                # If hidden states were requested, record the current ones
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Run the current block's forward pass
            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                position_bias,
                encoder_hidden_states,
                encoder_attention_mask,
                encoder_decoder_position_bias,
                output_attentions,
                deterministic,
                init_cache,
            )

            # Update the hidden states with the block's output
            hidden_states = layer_outputs[0]

            # Reuse the self-attention position bias computed by this block
            position_bias = layer_outputs[1]

            # For causal blocks with encoder states, update the encoder-decoder position bias
            if self.causal and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]

            # If attention weights were requested, record this block's attentions
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)
                # For causal blocks, also record the cross-attention weights
                if self.causal:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[4],)

        # Return the outputs produced after all Transformer blocks
        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
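
# --- Standalone sketch, not part of the model file ---
# The bias-sharing loop above in miniature: only the first block owns a relative
# attention bias table; it returns the computed bias, and every later block reuses
# it via `position_bias`. `toy_block` is a hypothetical stand-in.
import jax.numpy as jnp

def toy_block(hidden, position_bias, owns_bias):
    if position_bias is None and owns_bias:  # only block 0 builds the bias
        position_bias = jnp.full((1, 2, 4, 4), 0.25)
    return hidden + 1.0, position_bias

hidden, position_bias = jnp.zeros((1, 4, 8)), None
for i in range(3):
    hidden, position_bias = toy_block(hidden, position_bias, owns_bias=(i == 0))
print(position_bias[0, 0, 0, 0])  # 0.25 -- built once, reused by blocks 1 and 2
# --- End of sketch ---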
# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Stack with T5->LongT5
class FlaxLongT5Stack(nn.Module):
    # Model configuration
    config: LongT5Config
    # Token embedding layer
    embed_tokens: nn.Embed
    # The dtype of the computation, jnp.float32 by default
    dtype: jnp.dtype = jnp.float32
    # Gradient checkpointing, off by default
    gradient_checkpointing: bool = False

    def setup(self):
        # Whether the stack runs in causal (decoder) mode
        self.causal = self.config.causal
        # Collection of LongT5 blocks
        self.block = FlaxLongT5BlockCollection(
            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        # Final layer normalization
        self.final_layer_norm = FlaxLongT5LayerNorm(
            self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
        )
        # Dropout layer
        self.dropout = nn.Dropout(self.config.dropout_rate)

    # Forward pass: process the inputs and produce the outputs
    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
        init_cache: bool = False,
    ):
        # Look up the token embeddings
        hidden_states = self.embed_tokens(input_ids)
        # Apply dropout
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        # Run the block collection over the hidden states
        outputs = self.block(
            hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            deterministic=deterministic,
            init_cache=init_cache,
        )

        # Take the hidden states from the block outputs
        hidden_states = outputs[0]

        # Apply the final layer normalization
        hidden_states = self.final_layer_norm(hidden_states)
        # Apply dropout once more
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        # Append the final hidden states (when all hidden states are returned)
        all_hidden_states = None
        if output_hidden_states:
            all_hidden_states = outputs.hidden_states
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Build the output according to the requested return type
        if not return_dict:
            if output_hidden_states:
                return (
                    hidden_states,
                    all_hidden_states,
                ) + outputs[2:]  # hidden states plus the extra outputs
            return (hidden_states,) + outputs[1:]  # hidden states only

        # Return the base model output with past and cross attentions
        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


# Docstring describing the inputs accepted by the LongT5 encoder
LONGT5_ENCODE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings,
            so the inputs can be padded on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            To know more about how to prepare `input_ids` for pretraining, take a look at [LongT5
            Training](./longt5#training).
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values are selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under the
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under the returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

class FlaxLongT5PreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Configuration class for this model
    config_class = LongT5Config
    # Prefix under which the base model's weights are stored
    base_model_prefix = "transformer"
    # Module class, set by subclasses
    module_class: nn.Module = None
    # Initialize a LongT5 model instance from the given config
    def __init__(
        self,
        config: LongT5Config,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # Instantiate the module class from the given config and arguments
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # Call the parent constructor with the config, module instance and other arguments
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    # Enable gradient checkpointing by rebuilding the module with it turned on
    def enable_gradient_checkpointing(self):
        self._module = self.module_class(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=True,
        )

    # Initialize the model weights using the PRNG key `rng` and the input shape
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # Initialize the input_ids tensor
        input_ids = jnp.zeros(input_shape, dtype="i4")

        # Create the attention_mask, decoder_input_ids and decoder_attention_mask tensors
        attention_mask = jnp.ones_like(input_ids)
        decoder_input_ids = jnp.ones_like(input_ids)
        decoder_attention_mask = jnp.ones_like(input_ids)

        # Split the PRNG key into params_rng and dropout_rng
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # Initialize the model parameters
        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
        )["params"]

        # If initial params were provided, fill in any missing keys from random_params
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    # Forward call, documented with LONGT5_INPUTS_DOCSTRING
    @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: jnp.ndarray = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # Fall back to the config default when output_attentions is not given
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # Fall back to the config default when output_hidden_states is not given
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Fall back to the config default when return_dict is not given
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # decoder_input_ids must be provided; raise a ValueError otherwise
        if decoder_input_ids is None:
            raise ValueError(
                "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed"
                " here."
            )

        # Prepare the encoder attention mask
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Prepare the decoder attention mask
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        # Collect any PRNG keys that need to be handled
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        # Run the model via self.module.apply
        return self.module.apply(
            {"params": params or self.params},  # model parameters
            input_ids=jnp.array(input_ids, dtype="i4"),  # encoder token ids
            attention_mask=jnp.array(attention_mask, dtype="i4"),  # encoder attention mask
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),  # decoder token ids
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),  # decoder attention mask
            output_attentions=output_attentions,  # whether to return attention weights
            output_hidden_states=output_hidden_states,  # whether to return hidden states
            return_dict=return_dict,  # whether to return a dict-style output
            deterministic=not train,  # deterministic (i.e. not training) mode
            rngs=rngs,  # PRNG key dict
        )
    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
            encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
                `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
                `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
                is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
                cross-attention of the decoder.
        """
        # Initialize input variables to retrieve the cache
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                **kwargs,
            )

        # Initialize the module with the given inputs and collect the variables
        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # only the decoder needs to run to initialize the cache
        )
        # Return the unfrozen cache variables
        return unfreeze(init_variables["cache"])

    @add_start_docstrings(LONGT5_ENCODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=LongT5Config)
    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```
        >>> from transformers import AutoTokenizer, FlaxLongT5ForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
        >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, return_tensors="np")
        >>> encoder_outputs = model.encode(**inputs)
        ```"""
        # Fall back to the config defaults when not explicitly given
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # If no attention_mask was given, create an all-ones mask
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # If a dropout_rng was given, add it to the RNG dict
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # Inner function that forwards the inputs through the encoder module
        def _encoder_forward(module, input_ids, attention_mask, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(input_ids, attention_mask, **kwargs)

        # Run the forward pass via the Flax module's apply method
        return self.module.apply(
            {"params": params or self.params},  # use the given params or the instance's own
            input_ids=jnp.array(input_ids, dtype="i4"),  # convert the input ids to a JAX array
            attention_mask=jnp.array(attention_mask, dtype="i4"),  # convert the attention mask to a JAX array
            output_attentions=output_attentions,  # whether to return attention weights
            output_hidden_states=output_hidden_states,  # whether to return hidden states
            return_dict=return_dict,  # whether to return a dict-style output
            deterministic=not train,  # deterministic (i.e. not training) mode
            rngs=rngs,  # PRNG key dict
            method=_encoder_forward,  # the method to call: the encoder forward pass
        )
    
    @add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=LongT5Config)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
"""
    The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long
    Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo
    Ni, Yun-Hsuan Sung and Yinfei Yang. It's an encoder-decoder transformer pre-trained in a text-to-text denoising
    generative setting. LongT5 model is an extension of T5 model, and it enables using one of the two different
    efficient attention mechanisms - (1) Local attention, or (2) Transient-Global attention.

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`LongT5Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

@add_start_docstrings(
    "The bare LONGT5 Model transformer outputting raw hidden-stateswithout any specific head on top.",
    LONGT5_START_DOCSTRING,
)
# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Module with T5->LongT5
class FlaxLongT5Module(nn.Module):
    """
    Flax module for the LongT5 model, extending the T5 architecture to support long sequences and different attention mechanisms.

    Inherits from `nn.Module`, enabling it to be used as a Flax Linen module. Refer to Flax documentation for usage details.

    Attributes:
        config (LongT5Config): Model configuration object containing all parameters.
        dtype (jnp.dtype): Data type for computation, defaulting to jnp.float32.
    """
    config: LongT5Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    # Gradient checkpointing flag, off by default
    gradient_checkpointing: bool = False
    
    # Return the encoder module
    def _get_encoder_module(self):
        return self.encoder

    # Return the decoder module
    def _get_decoder_module(self):
        return self.decoder

    # Set up and configure the model
    def setup(self):
        # Shared embedding layer used for both the encoder and the decoder vocabulary
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
            dtype=self.dtype,
        )

        # Copy the config for the encoder, disable causality, and build the encoder
        encoder_config = copy.deepcopy(self.config)
        encoder_config.causal = False
        self.encoder = FlaxLongT5Stack(
            encoder_config,
            embed_tokens=self.shared,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

        # Copy the config for the decoder, enable causality, and build the decoder
        decoder_config = copy.deepcopy(self.config)
        decoder_config.causal = True
        decoder_config.num_layers = self.config.num_decoder_layers
        self.decoder = FlaxLongT5Stack(
            decoder_config,
            embed_tokens=self.shared,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
    
    # __call__ implements the forward pass of the module
    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        deterministic: bool = True,
    ):
        # Fall back to the config default when return_dict is not specified
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encoder forward pass, producing the encoder outputs
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        # Decoder forward pass, producing the decoder outputs
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],  # feed the encoder hidden states into the decoder
            encoder_attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        # When no dict is requested, concatenate the decoder and encoder output tuples
        if not return_dict:
            return decoder_outputs + encoder_outputs

        # Otherwise build and return the seq2seq model output
        return FlaxSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
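
Since the tuple branch above simply concatenates `decoder_outputs` with `encoder_outputs`, the index of each element depends on which optional outputs were requested. A minimal sketch with all optional outputs left off (reusing the checkpoint from the docstring example below):

```
from transformers import AutoTokenizer, FlaxLongT5Model

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = FlaxLongT5Model.from_pretrained("google/long-t5-local-base")

enc = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np")
dec = tokenizer("Studies show that", return_tensors="np")

# With return_dict=False the decoder tuple comes first, then the encoder tuple.
outputs = model(input_ids=enc.input_ids, decoder_input_ids=dec.input_ids, return_dict=False)
decoder_last_hidden_state = outputs[0]  # first element of decoder_outputs
encoder_last_hidden_state = outputs[1]  # first element of encoder_outputs
```
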
# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Model with T5->LongT5
class FlaxLongT5Model(FlaxLongT5PreTrainedModel):
    # Use FlaxLongT5Module as the module class
    module_class = FlaxLongT5Module

# Append the call sample docstring to FlaxLongT5Model, documented against _CHECKPOINT_FOR_DOC with FlaxSeq2SeqModelOutput
append_call_sample_docstring(FlaxLongT5Model, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)

# FLAX_LONGT5_MODEL_DOCSTRING holds the return section and usage example for the model
FLAX_LONGT5_MODEL_DOCSTRING = """
    Returns:

    Example:

    ```
    >>> from transformers import AutoTokenizer, FlaxLongT5Model

    >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
    >>> model = FlaxLongT5Model.from_pretrained("google/long-t5-local-base")

    >>> input_ids = tokenizer(
    ...     "Studies have been shown that owning a dog is good for you", return_tensors="np"
    ... ).input_ids
    >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids

    >>> # forward pass
    >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
"""

# Overwrite the call docstring of FlaxLongT5Model with LONGT5_INPUTS_DOCSTRING plus FLAX_LONGT5_MODEL_DOCSTRING
overwrite_call_docstring(FlaxLongT5Model, LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_MODEL_DOCSTRING)

# Replace the return docstring of FlaxLongT5Model with FlaxSeq2SeqLMOutput, using _CONFIG_FOR_DOC as the config class
append_replace_return_docstrings(FlaxLongT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)

@add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING)
# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5ForConditionalGenerationModule with T5->LongT5
class FlaxLongT5ForConditionalGenerationModule(nn.Module):
    # Model configuration, a LongT5Config instance
    config: LongT5Config
    # dtype of the computation, jnp.float32 by default
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    # Gradient checkpointing defaults to False
    gradient_checkpointing: bool = False

    # Return the encoder module
    def _get_encoder_module(self):
        return self.encoder

    # Return the decoder module
    def _get_decoder_module(self):
        return self.decoder

    # Module setup
    def setup(self):
        # Model dimension taken from d_model in the config
        self.model_dim = self.config.d_model

        # Shared embedding layer, initialized from a normal distribution
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
            dtype=self.dtype,
        )

        # Copy the encoder config and set encoder-specific flags
        encoder_config = copy.deepcopy(self.config)
        encoder_config.causal = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        # Build the LongT5 encoder stack
        self.encoder = FlaxLongT5Stack(
            encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )

        # Copy the decoder config and set decoder-specific flags
        decoder_config = copy.deepcopy(self.config)
        decoder_config.causal = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = self.config.num_decoder_layers
        # Build the LongT5 decoder stack
        self.decoder = FlaxLongT5Stack(
            decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )

        # Language modeling head, a dense layer without bias
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_factor),
            dtype=self.dtype,
        )
    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        deterministic: bool = True,
    ):
        # Fall back to the config default when return_dict is not specified
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = encoder_outputs[0]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # Rescale the output before projecting onto the vocabulary when embeddings are tied
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        if self.config.tie_word_embeddings:
            # Reuse the shared embedding matrix as the LM head kernel
            shared_embedding = self.shared.variables["params"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
        else:
            lm_logits = self.lm_head(sequence_output)

        if not return_dict:
            # Return a plain tuple when no dict is requested
            return (lm_logits,) + decoder_outputs[1:] + encoder_outputs

        # Otherwise build a FlaxSeq2SeqLMOutput
        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
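
The two `tie_word_embeddings` branches above implement weight tying: the decoder output is first rescaled by `d_model ** -0.5` (the Mesh TensorFlow convention linked in the comment) and the shared embedding matrix, transposed, serves as the LM head kernel. A standalone NumPy sketch of the same arithmetic, with illustrative shapes:

```
import numpy as np

d_model, vocab_size, seq_len = 8, 32, 4
rng = np.random.default_rng(0)
E = rng.normal(size=(vocab_size, d_model))  # shared embedding table
h = rng.normal(size=(seq_len, d_model))     # decoder hidden states

sequence_output = h * (d_model ** -0.5)     # rescale before the vocabulary projection
lm_logits = sequence_output @ E.T           # reuse E (transposed) as the output projection
print(lm_logits.shape)                      # (4, 32)
```
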
class FlaxLongT5ForConditionalGeneration(FlaxLongT5PreTrainedModel):
    # Use FlaxLongT5ForConditionalGenerationModule as the module class
    module_class = FlaxLongT5ForConditionalGenerationModule

    @add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=LongT5Config)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """
        解码函数,根据给定的输入和条件生成输出。

        Args:
            decoder_input_ids: 解码器的输入 ID。
            encoder_outputs: 编码器的输出。
            encoder_attention_mask: 可选,编码器的注意力遮罩。
            decoder_attention_mask: 可选,解码器的注意力遮罩。
            past_key_values: 可选,过去的键值对,用于加速生成。
            output_attentions: 可选,是否输出注意力权重。
            output_hidden_states: 可选,是否输出隐藏状态。
            return_dict: 可选,是否以字典形式返回输出。
            train: 是否处于训练模式。
            params: 可选,模型参数。
            dropout_rng: 可选,用于 dropout 的随机数生成器。

        Returns:
            解码后的输出结果。
        """
        ...

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        为生成过程准备输入,初始化缓存并生成注意力掩码。

        Args:
            decoder_input_ids: 解码器的输入 ID。
            max_length: 生成的最大长度。
            attention_mask: 可选,编码器的注意力遮罩。
            decoder_attention_mask: 可选,解码器的注意力遮罩。
            encoder_outputs: 可选,编码器的输出。
            **kwargs: 其他关键字参数。

        Returns:
            包含生成过程所需输入的字典。
        """
        # 初始化缓存
        batch_size, seq_length = decoder_input_ids.shape
        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)

        # 创建扩展的注意力掩码
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            extended_attention_mask = jax.lax.dynamic_update_slice(
                extended_attention_mask, decoder_attention_mask, (0, 0)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
        }
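
The mask is allocated as all ones over the full `max_length` so that one statically shaped mask can serve the whole generation loop (positions beyond the current step are hidden by the decoder's causal masking anyway); `dynamic_update_slice` then overwrites the prefix with the user-supplied mask. A small sketch of that call with illustrative shapes:

```
import jax
import jax.numpy as jnp

max_length = 8
decoder_attention_mask = jnp.array([[1, 1, 0]], dtype="i4")  # user-supplied prefix mask

extended = jnp.ones((1, max_length), dtype="i4")
extended = jax.lax.dynamic_update_slice(extended, decoder_attention_mask, (0, 0))
print(extended)  # [[1 1 0 1 1 1 1 1]]
```
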

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """
        更新生成过程的输入,将过去的键值对更新为模型输出的过去键值对。

        Args:
            model_outputs: 模型的输出。
            model_kwargs: 模型的关键字参数。

        Returns:
            更新后的模型关键字参数。
        """
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        return model_kwargs


FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING = """
    Returns:
        The generated outputs.

    Example:

    ```
    >>> from transformers import AutoTokenizer, FlaxLongT5ForConditionalGeneration

    >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
    >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")

    >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs."
    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np")

    >>> # Generate a summary
    >>> summary_ids = model.generate(inputs["input_ids"]).sequences
    >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
    ```
"""

# Overwrite the call docstring of FlaxLongT5ForConditionalGeneration with the inputs
# docstring plus the conditional-generation example above
overwrite_call_docstring(
    FlaxLongT5ForConditionalGeneration, LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING
)

# Replace the return docstring of FlaxLongT5ForConditionalGeneration with FlaxSeq2SeqLMOutput,
# using _CONFIG_FOR_DOC as the config class
append_replace_return_docstrings(
    FlaxLongT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)