Transformers-源码解析-一百一十一-

Transformers 源码解析(一百一十一)

.\models\tapas\modeling_tf_tapas.py

# 设置文件编码为 UTF-8
# 版权声明,指出版权归属于 Google Research 和 HuggingFace Inc. 团队
# 根据 Apache 许可证 2.0 版本授权,除非符合许可证条件,否则不得使用此文件
# 获取许可证的副本,请访问 http://www.apache.org/licenses/LICENSE-2.0
# 除非适用法律要求或书面同意,否则依据 "AS IS" 原则分发软件,无论是明示还是默示的任何保证或条件都不包括在内
# 请参阅许可证,获取特定语言的详细信息和限制

"""TF 2.0 TAPAS 模型。"""

# 引入必要的库和模块
from __future__ import annotations

import enum  # 引入枚举类型
import math  # 引入数学库函数
from dataclasses import dataclass  # 引入数据类装饰器
from typing import Dict, Optional, Tuple, Union  # 引入类型提示

import numpy as np  # 引入 NumPy 库
import tensorflow as tf  # 引入 TensorFlow 库

# 引入相关的自定义库和模块
from ...activations_tf import get_tf_activation  # 从指定路径引入 TensorFlow 激活函数
from ...modeling_tf_outputs import (
    TFBaseModelOutputWithPastAndCrossAttentions,
    TFBaseModelOutputWithPooling,
    TFMaskedLMOutput,
    TFSequenceClassifierOutput,
)  # 从指定路径引入 TensorFlow 模型输出类
from ...modeling_tf_utils import (
    TFMaskedLanguageModelingLoss,
    TFModelInputType,
    TFPreTrainedModel,
    TFSequenceClassificationLoss,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)  # 从指定路径引入 TensorFlow 模型相关工具函数
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax  # 从指定路径引入 TensorFlow 相关工具函数
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_tensorflow_probability_available,
    logging,
    replace_return_docstrings,
    requires_backends,
)  # 从指定路径引入通用工具函数和类
from .configuration_tapas import TapasConfig  # 从指定路径引入 Tapas 配置类

# 获取 logger 对象,用于日志记录
logger = logging.get_logger(__name__)

# 软依赖项
# 检查是否导入了 TensorFlow Probability 库
if is_tensorflow_probability_available():
    try:
        import tensorflow_probability as tfp  # 导入 TensorFlow Probability 库
        # 在第一次调用时,检查安装的 TensorFlow 版本是否兼容
        # TensorFlow Probability 依赖于最新的稳定版本的 TensorFlow
        n = tfp.distributions.Normal(loc=0.0, scale=1.0)
    except ImportError:
        # 如果导入失败,则记录错误信息
        logger.error(
            "TAPAS 模型无法使用,因为无法加载 `tensorflow_probability`。"
            "看起来您安装了与 TensorFlow 版本不匹配的 `tensorflow_probability`。"
            "请尝试按照以下说明重新安装:https://github.com/tensorflow/probability。"
        )

# 用于文档的配置和检查点的字符串常量
_CONFIG_FOR_DOC = "TapasConfig"
_CHECKPOINT_FOR_DOC = "google/tapas-base"

# TF TAPAS 预训练模型的存档列表
TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST = [
    # 大模型
    "google/tapas-large",
    "google/tapas-large-finetuned-sqa",
    "google/tapas-large-finetuned-wtq",
    "google/tapas-large-finetuned-wikisql-supervised",
    "google/tapas-large-finetuned-tabfact",
    # 基础模型
    "google/tapas-base",
    "google/tapas-base-finetuned-sqa",
    "google/tapas-base-finetuned-wtq",
    "google/tapas-base-finetuned-wikisql-supervised",
    "google/tapas-base-finetuned-tabfact",
    # 小模型
    "google/tapas-small",
]
    # 定义一个列表,包含了各种 TAPAS 模型的名称字符串
    models = [
        "google/tapas-small-finetuned-sqa",  # Google TAPAS 小模型,针对 SQA 数据集进行了微调
        "google/tapas-small-finetuned-wtq",  # Google TAPAS 小模型,针对 WTQ 数据集进行了微调
        "google/tapas-small-finetuned-wikisql-supervised",  # Google TAPAS 小模型,针对 Wikisql 数据集进行了监督学习微调
        "google/tapas-small-finetuned-tabfact",  # Google TAPAS 小模型,针对 TabFact 数据集进行了微调
        # 迷你模型
        "google/tapas-mini",  # Google TAPAS 迷你模型
        "google/tapas-mini-finetuned-sqa",  # Google TAPAS 迷你模型,针对 SQA 数据集进行了微调
        "google/tapas-mini-finetuned-wtq",  # Google TAPAS 迷你模型,针对 WTQ 数据集进行了微调
        "google/tapas-mini-finetuned-wikisql-supervised",  # Google TAPAS 迷你模型,针对 Wikisql 数据集进行了监督学习微调
        "google/tapas-mini-finetuned-tabfact",  # Google TAPAS 迷你模型,针对 TabFact 数据集进行了微调
        # 超迷你模型
        "google/tapas-tiny",  # Google TAPAS 超迷你模型
        "google/tapas-tiny-finetuned-sqa",  # Google TAPAS 超迷你模型,针对 SQA 数据集进行了微调
        "google/tapas-tiny-finetuned-wtq",  # Google TAPAS 超迷你模型,针对 WTQ 数据集进行了微调
        "google/tapas-tiny-finetuned-wikisql-supervised",  # Google TAPAS 超迷你模型,针对 Wikisql 数据集进行了监督学习微调
        "google/tapas-tiny-finetuned-tabfact",  # Google TAPAS 超迷你模型,针对 TabFact 数据集进行了微调
        # 查看所有 TAPAS 模型,请访问 https://huggingface.co/models?filter=tapas
    ]
]

# 定义一个全局常量,用于避免零除错误的微小值
EPSILON_ZERO_DIVISION = 1e-10
# 定义一个常量,用于表示接近于负无穷大的值,通常用于表示对数概率为零的情况
CLOSE_ENOUGH_TO_LOG_ZERO = -10000.0


@dataclass
class TFTableQuestionAnsweringOutput(ModelOutput):
    """
    [`TFTapasForQuestionAnswering`]的输出类型。

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` (and possibly `answer`, `aggregation_labels`, `numeric_values` and `numeric_values_scale` are provided)):
            如果提供了 `labels`(可能还有 `answer`, `aggregation_labels`, `numeric_values` 和 `numeric_values_scale`),则返回总损失,
            包括分层单元选择的对数似然损失,以及(可选的)半监督回归损失和聚合的监督损失。
        logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            每个标记的单元选择头的预测分数。
        logits_aggregation (`tf.Tensor`, *optional*, of shape `(batch_size, num_aggregation_labels)`):
            每个聚合操作符的聚合头的预测分数。
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            形状为 `(batch_size, sequence_length, hidden_size)` 的 `tf.Tensor` 元组。
            模型在每个层的输出隐藏状态以及初始嵌入输出。
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            形状为 `(batch_size, num_heads, sequence_length, sequence_length)` 的 `tf.Tensor` 元组。
            在注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。

    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    logits_aggregation: tf.Tensor | None = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None


class TFTapasEmbeddings(keras.layers.Layer):
    """
    根据词嵌入、位置嵌入和标记类型嵌入构建嵌入。与 BertEmbeddings 相同,但包含多个用于编码表格结构的标记类型嵌入。
    """

    def __init__(self, config: TapasConfig, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.number_of_token_type_embeddings = len(config.type_vocab_sizes)
        self.reset_position_index_per_cell = config.reset_position_index_per_cell
        self.hidden_size = config.hidden_size
        self.max_position_embeddings = config.max_position_embeddings
        self.initializer_range = config.initializer_range
        # 创建一个 LayerNormalization 层,用于规范化输入数据
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # 创建一个 Dropout 层,用于随机失活输入单元,防止过拟合
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
    # 在构建模型时,定义输入形状,并添加词嵌入层的权重参数
    def build(self, input_shape=None):
        # 在 "word_embeddings" 命名空间下,创建词嵌入层的权重参数
        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        # 在 "position_embeddings" 命名空间下,创建位置嵌入层的权重参数
        with tf.name_scope("position_embeddings"):
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        # 对于每个类型的词汇表大小,分别创建对应的类型嵌入层的权重参数
        for i, type_vocab_size in enumerate(self.config.type_vocab_sizes):
            with tf.name_scope(f"token_type_embeddings_{i}"):
                setattr(
                    self,
                    f"token_type_embeddings_{i}",
                    self.add_weight(
                        name="embeddings",
                        shape=[type_vocab_size, self.hidden_size],
                        initializer=get_initializer(self.initializer_range),
                    ),
                )

        # 如果模型已经构建,则直接返回,避免重复构建
        if self.built:
            return

        # 标记模型已经构建
        self.built = True

        # 如果存在 LayerNorm 层,则在其命名空间下构建
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                # 根据配置的隐藏大小构建 LayerNorm 层
                self.LayerNorm.build([None, None, self.config.hidden_size])

    # 模型调用方法,接受输入张量并进行处理
    def call(
        self,
        input_ids: tf.Tensor = None,
        position_ids: tf.Tensor = None,
        token_type_ids: tf.Tensor = None,
        inputs_embeds: tf.Tensor = None,
        training: bool = False,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        # Ensure either `input_ids` or `inputs_embeds` is provided
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            # Get the shape of `input_ids`
            input_shape = shape_list(input_ids)
        else:
            # Get the shape of `inputs_embeds` excluding the last dimension
            input_shape = shape_list(inputs_embeds)[:-1]

        # Determine the sequence length from `input_shape`
        seq_length = input_shape[1]

        if token_type_ids is None:
            # If `token_type_ids` is not provided, fill with zeros
            token_type_ids = tf.fill(dims=input_shape + [self.number_of_token_type_embeddings], value=0)

        if position_ids is None:
            # Create absolute position embeddings
            position_ids = tf.expand_dims(tf.range(start=0, limit=seq_length), axis=0)
            position_ids = tf.broadcast_to(position_ids, shape=input_shape)

            # Conditionally create relative position embeddings when `reset_position_index_per_cell` is True
            if self.reset_position_index_per_cell:
                # Calculate column and row indices based on `token_type_ids`
                col_index = IndexMap(token_type_ids[:, :, 1], self.config.type_vocab_sizes[1], batch_dims=1)
                row_index = IndexMap(token_type_ids[:, :, 2], self.config.type_vocab_sizes[2], batch_dims=1)

                # Combine column and row indices to create full index
                full_index = ProductIndexMap(col_index, row_index)

                # Determine the first absolute position for every segment
                first_position_per_segment = reduce_min(position_ids, full_index)[0]

                # Calculate the first absolute position of the cell for every token
                first_position = gather(first_position_per_segment, full_index)

                # Calculate relative positions within the cell, ensuring within bounds
                position = tf.expand_dims(tf.range(start=0, limit=seq_length), axis=0)
                position_ids = tf.math.minimum(self.max_position_embeddings - 1, position - first_position)

        if input_ids is not None:
            # Validate `input_ids` are within bounds of vocabulary size
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            # Gather embeddings based on `input_ids`
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        # Gather position embeddings based on `position_ids`
        position_embeddings = tf.gather(self.position_embeddings, indices=position_ids)

        # Combine input embeddings with position embeddings
        final_embeddings = inputs_embeds + position_embeddings

        # Add token type embeddings for each token type
        for i in range(self.number_of_token_type_embeddings):
            name = f"token_type_embeddings_{i}"
            final_embeddings += tf.gather(params=getattr(self, name), indices=token_type_ids[:, :, i])

        # Apply layer normalization
        final_embeddings = self.LayerNorm(inputs=final_embeddings)

        # Apply dropout during training
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        # Return the final embedding tensor
        return final_embeddings
# 从 transformers.models.bert.modeling_tf_bert.TFBertSelfAttention 复制并修改为 Tapas
class TFTapasSelfAttention(keras.layers.Layer):
    def __init__(self, config: TapasConfig, **kwargs):
        super().__init__(**kwargs)

        # 检查隐藏层大小是否能被注意力头数整除
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        # 初始化参数
        self.num_attention_heads = config.num_attention_heads  # 注意力头的数量
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)  # 每个注意力头的大小
        self.all_head_size = self.num_attention_heads * self.attention_head_size  # 所有头部的总大小
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)  # 注意力头大小的平方根

        # 创建查询、键和值的全连接层
        self.query = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)  # 注意力概率的dropout

        self.is_decoder = config.is_decoder  # 是否为解码器
        self.config = config  # 配置信息

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # 将张量重塑从 [batch_size, seq_length, all_head_size] 到 [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # 转置张量从 [batch_size, seq_length, num_attention_heads, attention_head_size] 到 [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ):
        # 该函数定义层的正向传播逻辑,包括自注意力机制和可选的输出注意力权重
        # (此处省略具体实现细节,不在注释内展开)

    def build(self, input_shape=None):
        # 如果已经构建过,则直接返回
        if self.built:
            return
        self.built = True

        # 构建查询、键和值的全连接层
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Tapas
class TFTapasSelfOutput(keras.layers.Layer):
    def __init__(self, config: TapasConfig, **kwargs):
        super().__init__(**kwargs)

        # 创建一个全连接层,用于变换隐藏状态的维度
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # 创建一个 LayerNormalization 层,用于归一化隐藏状态
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # 创建一个 Dropout 层,用于在训练时随机失活部分神经元
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        # 将隐藏状态传入全连接层进行维度变换
        hidden_states = self.dense(inputs=hidden_states)
        # 在训练时对输出使用 Dropout
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        # 使用 LayerNormalization 层归一化隐藏状态并与输入张量相加
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 构建全连接层,设置输入维度为 config.hidden_size
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        # 构建 LayerNormalization 层,设置输入维度为 config.hidden_size
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])


# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Tapas
class TFTapasAttention(keras.layers.Layer):
    def __init__(self, config: TapasConfig, **kwargs):
        super().__init__(**kwargs)

        # 创建自注意力层对象
        self.self_attention = TFTapasSelfAttention(config, name="self")
        # 创建输出层对象
        self.dense_output = TFTapasSelfOutput(config, name="output")

    def prune_heads(self, heads):
        # 暂未实现剪枝功能
        raise NotImplementedError

    def call(
        self,
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # 调用自注意力层处理输入
        self_outputs = self.self_attention(
            hidden_states=input_tensor,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        # 将自注意力层的输出传递给输出层处理
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
        )
        # 如果需要输出注意力权重或过去键值对,则添加到输出元组中
        outputs = (attention_output,) + self_outputs[1:]

        return outputs
    # 定义神经网络层的构建方法,当输入形状为None时表示使用默认形状
    def build(self, input_shape=None):
        # 如果已经构建过,直接返回,避免重复构建
        if self.built:
            return
        # 标记该层已经构建
        self.built = True
        
        # 如果存在self_attention属性,执行以下操作
        if getattr(self, "self_attention", None) is not None:
            # 使用self_attention的名字作为命名空间,开始构建self_attention
            with tf.name_scope(self.self_attention.name):
                self.self_attention.build(None)
        
        # 如果存在dense_output属性,执行以下操作
        if getattr(self, "dense_output", None) is not None:
            # 使用dense_output的名字作为命名空间,开始构建dense_output
            with tf.name_scope(self.dense_output.name):
                self.dense_output.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Tapas
class TFTapasIntermediate(keras.layers.Layer):
    def __init__(self, config: TapasConfig, **kwargs):
        super().__init__(**kwargs)

        # 定义一个全连接层,输出大小为 config.intermediate_size,使用指定的初始化器初始化权重
        self.dense = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        # 根据配置中的 hidden_act 参数,获取中间激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.intermediate_act_fn = config.hidden_act
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 将输入的 hidden_states 输入到全连接层 dense 中
        hidden_states = self.dense(inputs=hidden_states)
        # 使用中间激活函数处理全连接层的输出
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # 构建全连接层 dense,输入大小为 [None, None, self.config.hidden_size]
                self.dense.build([None, None, self.config.hidden_size])


# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Tapas
class TFTapasOutput(keras.layers.Layer):
    def __init__(self, config: TapasConfig, **kwargs):
        super().__init__(**kwargs)

        # 定义一个全连接层,输出大小为 config.hidden_size,使用指定的初始化器初始化权重
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # 定义 LayerNormalization 层,epsilon 参数为 config.layer_norm_eps
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # 定义 Dropout 层,丢弃率为 config.hidden_dropout_prob
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        # 将输入的 hidden_states 输入到全连接层 dense 中
        hidden_states = self.dense(inputs=hidden_states)
        # 在训练时应用 dropout
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        # LayerNormalization 处理全连接层的输出并加上输入 tensor input_tensor
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # 构建全连接层 dense,输入大小为 [None, None, self.config.intermediate_size]
                self.dense.build([None, None, self.config.intermediate_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                # 构建 LayerNormalization 层,输入大小为 [None, None, self.config.hidden_size]
                self.LayerNorm.build([None, None, self.config.hidden_size])


# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Tapas
class TFTapasLayer(keras.layers.Layer):
    # 留待实现
    pass
    # 初始化方法,用于创建一个 Tapas 模型对象
    def __init__(self, config: TapasConfig, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)

        # 创建一个 TapasAttention 层对象,并命名为 "attention"
        self.attention = TFTapasAttention(config, name="attention")
        
        # 从配置中获取是否为解码器模型的标志
        self.is_decoder = config.is_decoder
        
        # 从配置中获取是否添加跨注意力的标志
        self.add_cross_attention = config.add_cross_attention
        
        # 如果要添加跨注意力,并且当前模型不是解码器模型,则抛出错误
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            
            # 创建一个 TapasAttention 层对象用于跨注意力,并命名为 "crossattention"
            self.crossattention = TFTapasAttention(config, name="crossattention")
        
        # 创建一个 TapasIntermediate 层对象,并命名为 "intermediate"
        self.intermediate = TFTapasIntermediate(config, name="intermediate")
        
        # 创建一个 TapasOutput 层对象,并命名为 "output"
        self.bert_output = TFTapasOutput(config, name="output")

    # 模型调用方法,用于执行模型的前向传播
    def call(
        self,
        hidden_states: tf.Tensor,  # 输入的隐藏状态张量
        attention_mask: tf.Tensor,  # 注意力掩码张量
        head_mask: tf.Tensor,  # 头部掩码张量
        encoder_hidden_states: tf.Tensor | None,  # 编码器的隐藏状态张量(可选)
        encoder_attention_mask: tf.Tensor | None,  # 编码器的注意力掩码张量(可选)
        past_key_value: Tuple[tf.Tensor] | None,  # 过去的键值元组(可选)
        output_attentions: bool,  # 是否输出注意力权重
        training: bool = False,  # 是否处于训练模式,默认为 False
    ) -> Tuple[tf.Tensor]:
        # 定义函数的返回类型为包含单个 TensorFlow 张量的元组
        # 如果有过去的键/值信息,则仅保留解码器自注意力部分的前两个位置
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # 使用自注意力层处理隐藏状态,生成自注意力的输出
        self_attention_outputs = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_value=self_attn_past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        # 提取自注意力输出中的第一个元素作为注意力输出
        attention_output = self_attention_outputs[0]

        # 如果是解码器模式,则最后一个输出为自注意力缓存的元组
        if self.is_decoder:
            # 输出中排除最后一个元素(自注意力缓存),其余部分为网络层输出
            outputs = self_attention_outputs[1:-1]
            # 提取当前的键/值信息作为解码器的最新键/值信息
            present_key_value = self_attention_outputs[-1]
        else:
            # 输出中排除第一个元素(自注意力输出),保留其余部分(可能包含注意力权重)
            outputs = self_attention_outputs[1:]
        
        cross_attn_present_key_value = None
        # 如果是解码器且有编码器的隐藏状态作为输入
        if self.is_decoder and encoder_hidden_states is not None:
            # 如果模型未包含交叉注意力层,抛出错误
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )
            
            # 从过去的键/值信息中提取交叉注意力层的键/值信息
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            # 使用交叉注意力层处理自注意力输出,生成交叉注意力的输出
            cross_attention_outputs = self.crossattention(
                input_tensor=attention_output,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            # 提取交叉注意力输出中的第一个元素作为注意力输出
            attention_output = cross_attention_outputs[0]
            # 将交叉注意力输出中的第二个到倒数第二个元素添加到输出中(可能包含注意力权重)
            outputs = outputs + cross_attention_outputs[1:-1]

            # 将交叉注意力输出中的最后一个元素添加到当前键/值信息中
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value
        
        # 使用中间层处理注意力输出,生成中间层输出
        intermediate_output = self.intermediate(hidden_states=attention_output)
        # 使用BERT输出层处理中间层输出和注意力输出,生成网络层输出
        layer_output = self.bert_output(
            hidden_states=intermediate_output, input_tensor=attention_output, training=training
        )
        # 将网络层输出添加到输出元组中
        outputs = (layer_output,) + outputs
        
        # 如果是解码器模式,将当前的键/值信息作为最后一个输出添加到输出元组中
        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        
        # 返回最终的输出元组
        return outputs
    # 构建方法,用于构造模型
    def build(self, input_shape=None):
        # 如果模型已经构建完成,则直接返回
        if self.built:
            return
        # 将模型标记为已构建状态
        self.built = True
        
        # 如果存在 self.attention 属性,则构建注意力层
        if getattr(self, "attention", None) is not None:
            # 使用注意力层的名称作为命名空间,构建注意力层
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        
        # 如果存在 self.intermediate 属性,则构建中间层
        if getattr(self, "intermediate", None) is not None:
            # 使用中间层的名称作为命名空间,构建中间层
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)
        
        # 如果存在 self.bert_output 属性,则构建 BERT 输出层
        if getattr(self, "bert_output", None) is not None:
            # 使用 BERT 输出层的名称作为命名空间,构建 BERT 输出层
            with tf.name_scope(self.bert_output.name):
                self.bert_output.build(None)
        
        # 如果存在 self.crossattention 属性,则构建交叉注意力层
        if getattr(self, "crossattention", None) is not None:
            # 使用交叉注意力层的名称作为命名空间,构建交叉注意力层
            with tf.name_scope(self.crossattention.name):
                self.crossattention.build(None)
# 从 transformers.models.bert.modeling_tf_bert.TFBertEncoder 复制并修改为 Tapas 模型的编码器类 TFTapasEncoder
class TFTapasEncoder(keras.layers.Layer):
    # 初始化方法,接受 TapasConfig 类型的配置参数 config
    def __init__(self, config: TapasConfig, **kwargs):
        super().__init__(**kwargs)
        # 保存传入的配置参数
        self.config = config
        # 创建多个 Tapas 层组成的列表,每一层命名为 "layer_._{i}"
        self.layer = [TFTapasLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    # 调用方法,定义了模型的前向传播逻辑
    def call(
        self,
        hidden_states: tf.Tensor,                          # 输入的隐藏状态张量
        attention_mask: tf.Tensor,                         # 注意力掩码张量
        head_mask: tf.Tensor,                              # 头部掩码张量
        encoder_hidden_states: tf.Tensor | None,           # 编码器的隐藏状态张量或空值
        encoder_attention_mask: tf.Tensor | None,           # 编码器的注意力掩码张量或空值
        past_key_values: Tuple[Tuple[tf.Tensor]] | None,    # 历史键值对的元组或空值
        use_cache: Optional[bool],                         # 是否使用缓存的布尔值,可选
        output_attentions: bool,                           # 是否输出注意力张量的布尔值
        output_hidden_states: bool,                        # 是否输出隐藏状态的布尔值
        return_dict: bool,                                 # 是否返回字典类型的布尔值
        training: bool = False,                            # 是否处于训练模式的布尔值,默认为 False
    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
        # 初始化存储所有隐藏状态、注意力张量和跨层注意力张量的元组
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        # 如果使用缓存,则初始化下一个解码器缓存的元组
        next_decoder_cache = () if use_cache else None

        # 遍历每一层 Tapas 层
        for i, layer_module in enumerate(self.layer):
            # 如果需要输出隐藏状态,则将当前隐藏状态添加到所有隐藏状态的元组中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 获取当前层的历史键值对
            past_key_value = past_key_values[i] if past_key_values is not None else None

            # 调用当前 Tapas 层的前向传播方法,获取该层的输出
            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            # 更新隐藏状态为当前层的输出的第一个元素
            hidden_states = layer_outputs[0]

            # 如果使用缓存,将当前层的输出的最后一个元素添加到下一个解码器缓存的元组中
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)

            # 如果输出注意力张量,则将当前层的输出的第二个元素添加到所有注意力张量的元组中
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
                # 如果配置要求添加跨层注意力,并且编码器的隐藏状态不为空,则将当前层的输出的第三个元素添加到所有跨层注意力张量的元组中
                if self.config.add_cross_attention and encoder_hidden_states is not None:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # 如果需要输出隐藏状态,则将最后一层的隐藏状态添加到所有隐藏状态的元组中
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果不返回字典类型的结果,则返回所有非空的张量组成的元组
        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
            )

        # 返回 TFBaseModelOutputWithPastAndCrossAttentions 类型的字典结果
        return TFBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
    # 定义神经网络模型的构建方法,接受输入形状参数,默认为None
    def build(self, input_shape=None):
        # 如果模型已经构建过,则直接返回,不再重复构建
        if self.built:
            return
        # 将模型标记为已构建状态
        self.built = True
        # 检查是否存在名为"layer"的属性
        if getattr(self, "layer", None) is not None:
            # 遍历模型中的每一层
            for layer in self.layer:
                # 使用层的名称为当前层创建一个命名空间
                with tf.name_scope(layer.name):
                    # 调用每一层的build方法,传入输入形状参数为None,表示根据需要自动确定输入形状
                    layer.build(None)
# 从transformers.models.bert.modeling_tf_bert.TFBertPooler复制并修改为Tapas
class TFTapasPooler(keras.layers.Layer):
    def __init__(self, config: TapasConfig, **kwargs):
        super().__init__(**kwargs)

        # 创建一个全连接层,用于池化模型的隐藏状态,输出维度为config.hidden_size
        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 通过仅仅取第一个标记的隐藏状态来“池化”模型
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(inputs=first_token_tensor)

        return pooled_output

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # 构建dense层,输入维度为[None, None, self.config.hidden_size]
                self.dense.build([None, None, self.config.hidden_size])


# 从transformers.models.bert.modeling_tf_bert.TFBertPredictionHeadTransform复制并修改为Tapas
class TFTapasPredictionHeadTransform(keras.layers.Layer):
    def __init__(self, config: TapasConfig, **kwargs):
        super().__init__(**kwargs)

        # 创建一个全连接层,输出维度为config.hidden_size
        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )

        # 根据config.hidden_act初始化激活函数
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.transform_act_fn = config.hidden_act

        # 创建LayerNormalization层,epsilon值为config.layer_norm_eps
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 通过dense层处理hidden_states
        hidden_states = self.dense(inputs=hidden_states)
        # 应用激活函数transform_act_fn
        hidden_states = self.transform_act_fn(hidden_states)
        # 应用LayerNormalization层
        hidden_states = self.LayerNorm(inputs=hidden_states)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # 构建dense层,输入维度为[None, None, self.config.hidden_size]
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                # 构建LayerNorm层,输入维度为[None, None, self.config.hidden_size]
                self.LayerNorm.build([None, None, self.config.hidden_size])


# 从transformers.models.bert.modeling_tf_bert.TFBertLMPredictionHead复制并修改为Tapas
class TFTapasLMPredictionHead(keras.layers.Layer):
    # 使用 TapasConfig 和输入的嵌入层初始化模型
    def __init__(self, config: TapasConfig, input_embeddings: keras.layers.Layer, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)

        # 保存配置对象和隐藏层大小
        self.config = config
        self.hidden_size = config.hidden_size

        # 创建预测头转换层,用于处理模型的输出
        self.transform = TFTapasPredictionHeadTransform(config, name="transform")

        # 输入嵌入层是模型的输入
        self.input_embeddings = input_embeddings

    # 构建模型
    def build(self, input_shape=None):
        # 添加偏置项,初始化为全零向量
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")

        # 如果已经构建过模型,直接返回
        if self.built:
            return
        self.built = True
        
        # 如果存在转换层,构建转换层
        if getattr(self, "transform", None) is not None:
            with tf.name_scope(self.transform.name):
                self.transform.build(None)

    # 获取输出嵌入层
    def get_output_embeddings(self) -> keras.layers.Layer:
        return self.input_embeddings

    # 设置输出嵌入层的权重和词汇表大小
    def set_output_embeddings(self, value: tf.Variable):
        self.input_embeddings.weight = value
        self.input_embeddings.vocab_size = shape_list(value)[0]

    # 获取偏置项
    def get_bias(self) -> Dict[str, tf.Variable]:
        return {"bias": self.bias}

    # 设置偏置项
    def set_bias(self, value: tf.Variable):
        self.bias = value["bias"]
        # 更新配置对象的词汇表大小
        self.config.vocab_size = shape_list(value["bias"])[0]

    # 模型的调用方法,进行前向传播
    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 使用转换层处理隐藏状态
        hidden_states = self.transform(hidden_states=hidden_states)
        
        # 获取序列长度
        seq_length = shape_list(hidden_states)[1]
        
        # 将隐藏状态重新形状为二维张量
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
        
        # 矩阵乘法计算输出
        hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
        
        # 将输出重新形状为三维张量
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
        
        # 添加偏置项到输出
        hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)

        # 返回处理后的隐藏状态
        return hidden_states
# 从transformers.models.bert.modeling_tf_bert.TFBertMLMHead复制而来,将Bert替换为Tapas
class TFTapasMLMHead(keras.layers.Layer):
    def __init__(self, config: TapasConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)

        # 使用TapasLMPredictionHead类创建predictions对象
        self.predictions = TFTapasLMPredictionHead(config, input_embeddings, name="predictions")

    # 调用函数,根据输入的sequence_output计算预测分数prediction_scores
    def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
        prediction_scores = self.predictions(hidden_states=sequence_output)

        return prediction_scores

    # 构建函数,在第一次调用时建立层的内部结构
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果存在predictions对象,则在其命名空间下建立结构
        if getattr(self, "predictions", None) is not None:
            with tf.name_scope(self.predictions.name):
                self.predictions.build(None)


# keras_serializable装饰器用于声明TFTapasMainLayer类是可序列化的
@keras_serializable
class TFTapasMainLayer(keras.layers.Layer):
    config_class = TapasConfig

    def __init__(self, config: TapasConfig, add_pooling_layer: bool = True, **kwargs):
        # 调用父类的初始化函数,并添加tensorflow_probability作为后端库的要求
        requires_backends(self, "tensorflow_probability")
        super().__init__(**kwargs)

        self.config = config

        # 创建TFTapasEmbeddings对象作为embeddings
        self.embeddings = TFTapasEmbeddings(config, name="embeddings")
        # 创建TFTapasEncoder对象作为encoder
        self.encoder = TFTapasEncoder(config, name="encoder")
        # 如果add_pooling_layer为True,则创建TFTapasPooler对象作为pooler
        self.pooler = TFTapasPooler(config, name="pooler") if add_pooling_layer else None

    # 返回embeddings对象,用于获取输入的嵌入层
    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.embeddings

    # 设置输入的嵌入层的权重值为value,并更新vocab_size属性
    def set_input_embeddings(self, value: tf.Variable):
        self.embeddings.weight = value
        self.embeddings.vocab_size = shape_list(value)[0]

    # _prune_heads函数用于修剪模型中的注意力头
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError

    # call函数,接收多个输入参数,并根据配置调用embeddings、encoder和pooler的对应方法
    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ):
        # 这里是call函数的具体实现,根据输入参数调用对应的功能

    # 构建函数,在第一次调用时建立层的内部结构
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果存在embeddings对象,则在其命名空间下建立结构
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        # 如果存在encoder对象,则在其命名空间下建立结构
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        # 如果存在pooler对象,则在其命名空间下建立结构
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build(None)
"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>
"""
    Parameters:
        config ([`TapasConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

TAPAS_INPUTS_DOCSTRING = r"""
"""


@add_start_docstrings(
    "The bare Tapas Model transformer outputting raw hidden-states without any specific head on top.",
    TAPAS_START_DOCSTRING,
)
class TFTapasModel(TFTapasPreTrainedModel):
    def __init__(self, config: TapasConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # 初始化Tapas主层,使用给定的配置参数
        self.tapas = TFTapasMainLayer(config, name="tapas")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        r"""
        前向传播函数,接受多种输入参数,并返回模型的输出。

        Args:
            input_ids: 输入的token IDs。
            attention_mask: 注意力掩码,指示哪些位置是padding的。
            token_type_ids: token类型IDs,用于区分segment。
            position_ids: 位置IDs,用于指定每个token在文本中的位置。
            head_mask: 头部掩码,用于指定哪些注意力头部被屏蔽。
            inputs_embeds: 嵌入的输入张量。
            output_attentions: 是否输出注意力权重。
            output_hidden_states: 是否输出所有隐藏状态。
            return_dict: 是否返回字典格式的输出。
            training: 是否为训练模式。

        Returns:
            模型的输出,可以是包含池化的基础模型输出或者张量的元组。

        Examples:
        
        ```
        >>> from transformers import AutoTokenizer, TapasModel
        >>> import pandas as pd

        >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base")
        >>> model = TapasModel.from_pretrained("google/tapas-base")

        >>> data = {
        ...     "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
        ...     "Age": ["56", "45", "59"],
        ...     "Number of movies": ["87", "53", "69"],
        ... }
        >>> table = pd.DataFrame.from_dict(data)
        >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"]

        >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="tf")
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        ```
        """
        # 调用Tapas主层处理输入,返回处理结果
        outputs = self.tapas(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "tapas", None) is not None:
            with tf.name_scope(self.tapas.name):
                self.tapas.build(None)


@add_start_docstrings("""Tapas Model with a `language modeling` head on top.""", TAPAS_START_DOCSTRING)
# TapasForMaskedLM 类继承自 TFTapasPreTrainedModel 和 TFMaskedLanguageModelingLoss,用于处理 Tapas 模型的 Masked Language Modeling 任务
class TFTapasForMaskedLM(TFTapasPreTrainedModel, TFMaskedLanguageModelingLoss):
    
    # 初始化方法,接受一个 TapasConfig 对象和额外的输入参数
    def __init__(self, config: TapasConfig, *inputs, **kwargs):
        # 调用父类的初始化方法,传入配置和其他输入参数
        super().__init__(config, *inputs, **kwargs)
        
        # 如果配置指定为 decoder,发出警告,建议将 `config.is_decoder` 设为 False,以支持双向自注意力
        if config.is_decoder:
            logger.warning(
                "If you want to use `TFTapasForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )
        
        # 创建 Tapas 主层对象,关闭添加池化层选项,命名为 "tapas"
        self.tapas = TFTapasMainLayer(config, add_pooling_layer=False, name="tapas")
        
        # 创建 Tapas MLM 头部对象,传入输入嵌入层为 tapas 的嵌入层,命名为 "cls"
        self.lm_head = TFTapasMLMHead(config, input_embeddings=self.tapas.embeddings, name="cls")
    
    # 获取 MLM 头部的方法,返回 lm_head 的 predictions 属性
    def get_lm_head(self) -> keras.layers.Layer:
        return self.lm_head.predictions
    
    # call 方法,处理模型的前向传播
    @unpack_inputs
    @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
        # 带有注释的参数列表,定义输入和控制模型行为的选项
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        outputs = self.tapas(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # 获取模型的序列输出
        sequence_output = outputs[0]
        # 将序列输出传递给语言模型头部以预测分数
        prediction_scores = self.lm_head(sequence_output)
        # 如果提供了标签,则计算损失;否则损失为 None
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)

        # 如果 return_dict 为 False,则组织输出格式
        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 构建 TFMaskedLMOutput 对象,封装损失、预测分数、隐藏状态和注意力权重
        return TFMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        # 如果模型已经构建,则直接返回
        if self.built:
            return
        # 标记模型已经构建
        self.built = True
        # 如果存在 Tapas 模型,则构建 Tapas 模型
        if getattr(self, "tapas", None) is not None:
            with tf.name_scope(self.tapas.name):
                self.tapas.build(None)
        # 如果存在语言模型头部,则构建语言模型头部
        if getattr(self, "lm_head", None) is not None:
            with tf.name_scope(self.lm_head.name):
                self.lm_head.build(None)
# 定义一个 TensorFlow 自定义层 TFTapasComputeTokenLogits,用于计算每个标记的逻辑回归结果
class TFTapasComputeTokenLogits(keras.layers.Layer):
    def __init__(self, config: TapasConfig, **kwargs):
        super().__init__(**kwargs)

        # 从配置中获取温度参数,用于调节逻辑回归的温度
        self.temperature = config.temperature

        # 定义输出层权重和偏置,这些权重用于计算逻辑回归
        with tf.name_scope("output"):
            self.output_weights = self.add_weight(
                name="output_weights",
                shape=(config.hidden_size,),
                dtype=tf.float32,
                trainable=True,
                # 根据配置选择初始化输出权重为零或截断正态分布
                initializer=tf.zeros_initializer()
                if config.init_cell_selection_weights_to_zero
                else keras.initializers.TruncatedNormal(stddev=config.initializer_range),
            )
            self.output_bias = self.add_weight(
                name="output_bias", shape=(), trainable=True, initializer=tf.zeros_initializer()
            )

    # 定义调用函数,输入是序列输出张量,输出是每个标记的逻辑回归结果张量
    def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
        """
        计算每个标记的逻辑回归结果

        Args:
            sequence_output (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                也称为 last_hidden_state。模型最后一层的隐藏状态序列输出。

        Returns:
            logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): 每个标记的逻辑回归结果。
        """
        # 计算逻辑回归结果,通过张量乘法和偏置,然后除以温度参数
        logits = (tf.einsum("bsj,j->bs", sequence_output, self.output_weights) + self.output_bias) / self.temperature
        return logits


# 定义另一个 TensorFlow 自定义层 TFTapasComputeColumnLogits,用于计算每列的逻辑回归结果
class TFTapasComputeColumnLogits(keras.layers.Layer):
    def __init__(self, config: TapasConfig, **kwargs):
        super().__init__(**kwargs)

        # 定义列输出层的权重和偏置,用于计算列的逻辑回归
        with tf.name_scope("column_output"):
            self.column_output_weights = self.add_weight(
                name="column_output_weights",
                shape=[config.hidden_size],
                dtype=tf.float32,
                trainable=True,
                # 根据配置选择初始化输出权重为零或截断正态分布
                initializer=tf.zeros_initializer()
                if config.init_cell_selection_weights_to_zero
                else keras.initializers.TruncatedNormal(stddev=config.initializer_range),
            )
            self.column_output_bias = self.add_weight(
                name="column_output_bias", shape=(), trainable=True, initializer=tf.zeros_initializer()
            )
    # 计算列的逻辑回归结果

    # 首先,计算没有温度调节的令牌逻辑回归结果 (batch_size, seq_len)
    token_logits = tf.einsum("bsj,j->bs", sequence_output, self.column_output_weights) + self.column_output_bias

    # 接下来,对每个单元格平均逻辑回归结果 (batch_size, max_num_cols*max_num_rows)
    cell_logits, cell_logits_index = reduce_mean(token_logits, cell_index)

    # 最后,对每列平均逻辑回归结果 (batch_size, max_num_cols)
    column_index = cell_index.project_inner(cell_logits_index)
    column_logits, out_index = reduce_sum(cell_logits * cell_mask, column_index)

    # 计算每列的单元格数目,避免零除错误
    cell_count, _ = reduce_sum(cell_mask, column_index)
    column_logits /= cell_count + EPSILON_ZERO_DIVISION

    # 掩盖不出现在示例中的列
    is_padding = tf.logical_and(cell_count < 0.5, tf.not_equal(out_index.indices, 0))
    column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(is_padding, tf.float32)

    # 如果不允许选择空列,进一步掩盖选择了空列的情况
    if not allow_empty_column_selection:
        column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(tf.equal(out_index.indices, 0), tf.float32)

    return column_logits
@add_start_docstrings(
    """
    Tapas Model with a cell selection head and optional aggregation head on top for question-answering tasks on tables
    (linear layers on top of the hidden-states output to compute `logits` and optional `logits_aggregation`), e.g. for
    SQA, WTQ or WikiSQL-supervised tasks.
    """,
    TAPAS_START_DOCSTRING,
)
class TFTapasForQuestionAnswering(TFTapasPreTrainedModel):
    def __init__(self, config: TapasConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # base model
        self.tapas = TFTapasMainLayer(config, name="tapas")

        # dropout
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)

        # initialize layer to compute token-level logits
        self.compute_token_logits = TFTapasComputeTokenLogits(config, name="compute_token_logits")

        # initialize layer to compute column-level logits
        self.compute_column_logits = TFTapasComputeColumnLogits(config, name="compute_column_logits")

        # optional aggregation classifier if specified in the configuration
        if config.num_aggregation_labels > 0:
            self.aggregation_classifier = keras.layers.Dense(
                config.num_aggregation_labels,
                kernel_initializer=get_initializer(config.initializer_range),
                name="aggregation_classifier",
            )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFTableQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        table_mask: np.ndarray | tf.Tensor | None = None,
        aggregation_labels: np.ndarray | tf.Tensor | None = None,
        float_answer: np.ndarray | tf.Tensor | None = None,
        numeric_values: np.ndarray | tf.Tensor | None = None,
        numeric_values_scale: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    # 如果已经构建过网络结构,则直接返回,避免重复构建
    if self.built:
        return
    # 设置标志表示网络结构已经构建
    self.built = True

    # 如果存在名为 "tapas" 的属性,则构建其对应的子模型
    if getattr(self, "tapas", None) is not None:
        # 使用 "tapas" 的名称作为命名空间,构建子模型
        with tf.name_scope(self.tapas.name):
            self.tapas.build(None)

    # 如果存在名为 "compute_token_logits" 的属性,则构建其对应的子模型
    if getattr(self, "compute_token_logits", None) is not None:
        # 使用 "compute_token_logits" 的名称作为命名空间,构建子模型
        with tf.name_scope(self.compute_token_logits.name):
            self.compute_token_logits.build(None)

    # 如果存在名为 "compute_column_logits" 的属性,则构建其对应的子模型
    if getattr(self, "compute_column_logits", None) is not None:
        # 使用 "compute_column_logits" 的名称作为命名空间,构建子模型
        with tf.name_scope(self.compute_column_logits.name):
            self.compute_column_logits.build(None)

    # 如果存在名为 "aggregation_classifier" 的属性,则构建其对应的子模型
    if getattr(self, "aggregation_classifier", None) is not None:
        # 使用 "aggregation_classifier" 的名称作为命名空间,构建子模型
        with tf.name_scope(self.aggregation_classifier.name):
            # 构建 "aggregation_classifier" 子模型,输入维度为 [None, None, self.config.hidden_size]
            self.aggregation_classifier.build([None, None, self.config.hidden_size])
# 使用自定义的装饰器添加文档字符串,描述 Tapas 模型用于序列分类任务的结构和功能
@add_start_docstrings(
    """
    Tapas Model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for table
    entailment tasks, such as TabFact (Chen et al., 2020).
    """,
    TAPAS_START_DOCSTRING,
)
# 定义 TFTapasForSequenceClassification 类,继承自 TFTapasPreTrainedModel 和 TFSequenceClassificationLoss
class TFTapasForSequenceClassification(TFTapasPreTrainedModel, TFSequenceClassificationLoss):
    
    # 初始化方法,接受 TapasConfig 对象和其他输入参数
    def __init__(self, config: TapasConfig, *inputs, **kwargs):
        # 调用父类的初始化方法
        super().__init__(config, *inputs, **kwargs)
        # 设置分类任务的标签数量
        self.num_labels = config.num_labels

        # 创建 Tapas 主层对象,命名为 "tapas"
        self.tapas = TFTapasMainLayer(config, name="tapas")
        # 创建 Dropout 层,使用配置中的隐藏层 Dropout 概率,命名为 "dropout"
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
        # 创建 Dense 层作为分类器,输出大小为标签数量,使用指定的初始化器范围初始化,命名为 "classifier"
        self.classifier = keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        # 将配置对象保存在实例中
        self.config = config

    # 使用装饰器将函数添加到模型前向传播路径中,并添加相关文档字符串描述输入格式
    @unpack_inputs
    @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    # 定义模型的前向传播方法,接受多种输入参数
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
        outputs = self.tapas(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # 调用模型的tapas方法进行前向传播,获取模型输出
        pooled_output = outputs[1]
        # 从模型输出中获取池化后的特征表示
        pooled_output = self.dropout(inputs=pooled_output, training=training)
        # 对池化后的特征表示进行dropout处理
        logits = self.classifier(inputs=pooled_output)
        # 使用分类器对特征表示进行分类得到logits
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
        # 如果提供了标签,计算损失函数,否则损失为None

        if not return_dict:
            # 如果不要求返回字典格式的输出
            output = (logits,) + outputs[2:]
            # 组装输出元组,包括logits和可能的其他输出
            return ((loss,) + output) if loss is not None else output
            # 返回包含损失和输出元组的结果,如果损失为None则只返回输出元组

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        # 返回TFSequenceClassifierOutput对象,包含损失、logits、隐藏状态和注意力信息
    # 构建方法,用于构建模型的各个组件的计算图
    def build(self, input_shape=None):
        # 如果已经构建过,直接返回,避免重复构建
        if self.built:
            return
        # 标记模型已经构建
        self.built = True
        
        # 如果存在 tapas 属性,则构建 tapas 组件的计算图
        if getattr(self, "tapas", None) is not None:
            # 使用 tapas 组件的名称作为命名空间
            with tf.name_scope(self.tapas.name):
                # 调用 tapas 组件的 build 方法,传入 None 作为输入形状
                self.tapas.build(None)
        
        # 如果存在 dropout 属性,则构建 dropout 组件的计算图
        if getattr(self, "dropout", None) is not None:
            # 使用 dropout 组件的名称作为命名空间
            with tf.name_scope(self.dropout.name):
                # 调用 dropout 组件的 build 方法,传入 None 作为输入形状
                self.dropout.build(None)
        
        # 如果存在 classifier 属性,则构建 classifier 组件的计算图
        if getattr(self, "classifier", None) is not None:
            # 使用 classifier 组件的名称作为命名空间
            with tf.name_scope(self.classifier.name):
                # 调用 classifier 组件的 build 方法,传入 [None, None, self.config.hidden_size] 作为输入形状
                self.classifier.build([None, None, self.config.hidden_size])
""" TAPAS utilities."""

# 定义一个枚举类,表示平均近似函数的不同类型
class AverageApproximationFunction(str, enum.Enum):
    RATIO = "ratio"         # 比率
    FIRST_ORDER = "first_order"   # 一阶
    SECOND_ORDER = "second_order"   # 二阶


# 与分段张量相关的所有内容的起点


class IndexMap(object):
    """Index grouping entries within a tensor."""

    def __init__(self, indices, num_segments, batch_dims=0):
        """
        Creates an index.

        Args:
          indices: <int32> Tensor of indices, same shape as `values`.
                  索引的张量,类型为<int32>,形状与`values`相同。
          num_segments: <int32> Scalar tensor, the number of segments. All elements
                        in a batched segmented tensor must have the same number of segments (although many segments can be empty).
                        分段张量的段数,作为一个标量张量。批处理的分段张量中的所有元素必须具有相同数量的段(尽管许多段可以为空)。
          batch_dims: Python integer, the number of batch dimensions. The first
                      `batch_dims` dimensions of a SegmentedTensor are treated as batch dimensions. Segments in different batch
                      elements are always distinct even if they have the same index.
                      批处理维度的数量,作为一个整数。分段张量的前`batch_dims`个维度被视为批处理维度。不同批处理元素中的段始终是不同的,即使它们具有相同的索引。
        """
        self.indices = tf.convert_to_tensor(indices)
        self.num_segments = tf.convert_to_tensor(num_segments)
        self.batch_dims = batch_dims

    def batch_shape(self):
        return tf.shape(self.indices)[: self.batch_dims]


class ProductIndexMap(IndexMap):
    """The product of two indices."""

    def __init__(self, outer_index, inner_index):
        """
        Combines indices i and j into pairs (i, j). The result is an index where each segment (i, j) is the
        intersection of segments i and j. For example if the inputs represent table cells indexed by respectively rows
        and columns the output will be a table indexed by (row, column) pairs, i.e. by cell. The implementation
        combines indices {0, .., n - 1} and {0, .., m - 1} into {0, .., nm - 1}. The output has `num_segments` equal to
        `outer_index.num_segements` * `inner_index.num_segments`.

        Args:
          outer_index: IndexMap.
                      外部索引,类型为IndexMap。
          inner_index: IndexMap, must have the same shape as `outer_index`.
                      内部索引,类型为IndexMap,必须与`outer_index`具有相同的形状。
        """
        if outer_index.batch_dims != inner_index.batch_dims:
            raise ValueError("outer_index.batch_dims and inner_index.batch_dims must be the same.")

        super(ProductIndexMap, self).__init__(
            indices=(
                inner_index.indices
                + outer_index.indices * tf.cast(inner_index.num_segments, inner_index.indices.dtype)
            ),
            num_segments=inner_index.num_segments * outer_index.num_segments,
            batch_dims=inner_index.batch_dims,
        )
        self.outer_index = outer_index
        self.inner_index = inner_index

    def project_outer(self, index):
        """Projects an index with the same index set onto the outer components."""
        return IndexMap(
            indices=tf.math.floordiv(index.indices, self.inner_index.num_segments),
            num_segments=self.outer_index.num_segments,
            batch_dims=index.batch_dims,
        )
    # 定义一个方法 `project_inner`,用于对传入的索引对象进行投影操作
    def project_inner(self, index):
        """Projects an index with the same index set onto the inner components."""
        # 使用 TensorFlow 的数学函数 `floormod` 对索引对象的 indices 属性进行取模运算
        # 以确保索引值不超过内部索引的段数,从而实现投影操作
        return IndexMap(
            indices=tf.math.floormod(index.indices, self.inner_index.num_segments),
            # 设置投影后的索引段数为内部索引的段数
            num_segments=self.inner_index.num_segments,
            # 将索引对象的批次维度(batch_dims)直接传递到新的 IndexMap 对象中
            batch_dims=index.batch_dims,
        )
# 使用 TensorFlow 提供的 tf.gather 函数,根据给定的索引从 values 中收集数据,index.indices 是索引的列表,
# batch_dims 指定批次维度的数量,name 是操作的名称
def gather(values, index, name="segmented_gather"):
    return tf.gather(values, index.indices, batch_dims=index.batch_dims, name=name)


# 将批处理的索引映射压平成一维索引映射。这个操作重新标记段,以保持批处理元素的不同性。
# 第 k 个批处理元素的索引会偏移 `num_segments` * (k - 1)。结果是一个张量,其大小是 `num_segments` 乘以批处理元素的数量。
def flatten(index, name="segmented_flatten"):
    batch_size = tf.reduce_prod(index.batch_shape())
    offset = tf.range(batch_size) * index.num_segments
    offset = tf.reshape(offset, index.batch_shape())
    for _ in range(index.batch_dims, index.indices.shape.rank):
        offset = tf.expand_dims(offset, -1)

    indices = tf.cast(offset, index.indices.dtype) + index.indices
    return IndexMap(indices=tf.reshape(indices, [-1]), num_segments=index.num_segments * batch_size, batch_dims=0)


# 构造一个索引映射,其值等于 range(num_segments)。
def range_index_map(batch_shape, num_segments, name="range_index_map"):
    batch_shape = tf.convert_to_tensor(batch_shape)
    batch_shape.shape.assert_has_rank(1)
    num_segments = tf.convert_to_tensor(num_segments)
    num_segments.shape.assert_has_rank(0)

    indices = tf.range(num_segments)
    shape = tf.concat([tf.ones_like(batch_shape, dtype=tf.int32), tf.expand_dims(num_segments, axis=0)], axis=0)
    indices = tf.reshape(indices, shape)
    multiples = tf.concat([batch_shape, [1]], axis=0)
    indices = tf.tile(indices, multiples)
    return IndexMap(indices=indices, num_segments=num_segments, batch_dims=batch_shape.shape.as_list()[0])


# 应用段内的分段减少功能。
# 此函数尚未完全定义,将在后续代码中继续定义。
def _segment_reduce(values, index, segment_reduce_fn, name):
    """
    Args:
        values (`tf.Tensor`):
            Tensor with segment values.  # 输入参数,包含分段数值的张量
        index (`IndexMap`):
            IndexMap.  # 输入参数,索引映射对象
        segment_reduce_fn (`str`):
            Name for the reduce operation. One of "sum", "mean", "max" or "min".  # 输入参数,指定分段操作的类型,可以是"sum"、"mean"、"max"或"min"
        name (`str`):
            Name for the operation. Currently not used  # 输入参数,操作的名称,目前未使用

    Returns:
        (`IndexMap`): IndexMap of shape batch_shape with elements equal to range(num_segments).
        # 返回值,返回形状为 batch_shape 的 IndexMap 对象,其元素等于范围内的 num_segments。

    """
    # Flatten the batch dimensions, as segments ops do not support batching.
    # However if `values` has extra dimensions to the right keep them
    # unflattened. Segmented ops support vector-valued operations.
    # 将批处理维度展平,因为分段操作不支持批处理。
    # 如果 `values` 右侧有额外的维度,则保持它们不展平。分段操作支持矢量值操作。
    flat_index = flatten(index)
    vector_shape = tf.shape(values)[index.indices.shape.rank :]
    flattened_shape = tf.concat([[-1], vector_shape], axis=0)
    flat_values = tf.reshape(values, flattened_shape)
    segment_means = segment_reduce_fn(
        data=flat_values, segment_ids=flat_index.indices, num_segments=flat_index.num_segments
    )

    # Unflatten the values.
    # 将值重新展开。
    new_shape = tf.concat([index.batch_shape(), [index.num_segments], vector_shape], axis=0)
    output_values = tf.reshape(segment_means, new_shape)
    output_index = range_index_map(index.batch_shape(), index.num_segments)
    return output_values, output_index
def reduce_mean(values, index, name="segmented_reduce_mean"):
    """
    Averages a tensor over its segments. Outputs 0 for empty segments. This operations computes the mean over segments,
    with support for:

      - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
      - Vectorization using the last dimension [V1, V2, ...]. If they are present the output will be a mean of vectors
        rather than scalars.
    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.

    Args:
      values: [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..] tensor of values to be
        averaged.
      index: IndexMap [B1, B2, ..., Bn, I1, .., Ik] index defining the segments.
      name: Name for the TensorFlow ops.

    Returns:
      A pair (output_values, output_index) where `output_values` is a tensor of shape [B1, B2, ..., Bn, num_segments,
      V1, V2, ..] and `index` is an IndexMap with shape [B1, B2, ..., Bn, num_segments].
    """
    return _segment_reduce(values, index, tf.math.unsorted_segment_mean, name)



def reduce_sum(values, index, name="segmented_reduce_sum"):
    """
    Sums a tensor over its segments. Outputs 0 for empty segments. This operations computes the sum over segments, with
    support for:

      - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
      - Vectorization using the last dimension [V1, V2, ...]. If they are present the output will be a sum of vectors
        rather than scalars.
    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.

    Args:
      values: [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..] tensor of values to be
        averaged.
      index: IndexMap [B1, B2, ..., Bn, I1, .., Ik] index defining the segments.
      name: Name for the TensorFlow ops.

    Returns:
      A pair (output_values, output_index) where `output_values` is a tensor of shape [B1, B2, ..., Bn, num_segments,
      V1, V2, ..] and `index` is an IndexMap with shape [B1, B2, ..., Bn, num_segments].
    """
    return _segment_reduce(values, index, tf.math.unsorted_segment_sum, name)



def reduce_max(values, index, name="segmented_reduce_max"):
    """
    Computes the maximum over segments. This operations computes the maximum over segments, with support for:

      - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
      - Vectorization using the last dimension [V1, V2, ...]. If they are present the output will be an element-wise
        maximum of vectors rather than scalars.
    Only the middle dimensions [I1, ..., Ik] are reduced by the operation.

    Args:
      values: [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..] tensor of values to be
        averaged.
      index: IndexMap [B1, B2, ..., Bn, I1, .., Ik] index defining the segments.
      name: Name for the TensorFlow ops.
    """
    # 使用 TensorFlow 的 unsorted_segment_max 函数对给定的 values 和 index 进行分段最大值计算
    return _segment_reduce(values, index, tf.math.unsorted_segment_max, name)
    # 调用私有函数 _segment_reduce,执行分段归约操作,使用 tf.math.unsorted_segment_max 函数进行归约
    # 函数返回一个元组 (output_values, output_index),其中 output_values 是形状为 [B1, B2, ..., Bn, num_segments, V1, V2, ..] 的张量
    # output_index 是形状为 [B1, B2, ..., Bn, num_segments] 的索引映射对象 IndexMap
    return _segment_reduce(values, index, tf.math.unsorted_segment_max, name)
def reduce_min(values, index, name="segmented_reduce_min"):
    """Computes the minimum over segments."""
    # 调用内部函数 _segment_reduce 来实现分段最小值计算,使用 tf.math.unsorted_segment_min 方法
    return _segment_reduce(values, index, tf.math.unsorted_segment_min, name)


def _single_column_cell_selection_loss(token_logits, column_logits, labels, cell_index, col_index, cell_mask):
    """
    Computes the loss for cell selection constrained to a single column. The loss is a hierarchical log-likelihood. The
    model first predicts a column and then selects cells within that column (conditioned on the column). Cells outside
    the selected column are never selected.

    Args:
        token_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Tensor containing the logits per token.
        column_logits (`tf.Tensor` of shape `(batch_size, max_num_cols)`):
            Tensor containing the logits per column.
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Labels per token.
        cell_index (`ProductIndexMap`):
            Index that groups tokens into cells.
        col_index (`IndexMap`):
            Index that groups tokens into columns.
        cell_mask (`tf.Tensor` of shape `(batch_size, max_num_rows * max_num_cols)`):
            Mask for cells that exist in the table (i.e. that are not padding).

    Returns:
        selection_loss_per_example (`tf.Tensor` of shape `(batch_size,)`): Loss for each example. logits (`tf.Tensor`
        of shape `(batch_size, sequence_length)`): New logits which are only allowed to select cells in a single
        column. Logits outside of the most likely column according to *column_logits* will be set to a very low value
        (such that the probabilities are 0).
    """
    # First find the column we should select. We use the column with maximum
    # number of selected cells.
    labels_per_column, _ = reduce_sum(tf.cast(labels, tf.float32), col_index)
    column_label = tf.argmax(labels_per_column, axis=-1, output_type=tf.int32)
    # Check if there are no selected cells in the column. In that case the model
    # should predict the special column id 0, which means "select nothing".
    no_cell_selected = tf.equal(tf.reduce_max(labels_per_column, axis=-1), 0)
    column_label = tf.where(no_cell_selected, tf.zeros_like(column_label), column_label)

    # Create a categorical distribution based on column logits for loss computation
    column_dist = tfp.distributions.Categorical(logits=column_logits)
    column_loss_per_example = -column_dist.log_prob(column_label)

    # Reduce the labels and logits to per-cell from per-token.
    logits_per_cell, _ = reduce_mean(token_logits, cell_index)
    labels_per_cell, labels_index = reduce_max(tf.cast(labels, tf.int32), cell_index)

    # Mask for the selected column.
    column_id_for_cells = cell_index.project_inner(labels_index).indices
    column_mask = tf.cast(tf.equal(column_id_for_cells, tf.expand_dims(column_label, axis=1)), tf.float32)

    # Compute the log-likelihood for cells, but only for the selected column.
    # 创建一个伯努利分布对象,使用给定的 logits 参数
    cell_dist = tfp.distributions.Bernoulli(logits=logits_per_cell)
    # 计算每个细胞的对数概率,根据标签值
    cell_log_prob = cell_dist.log_prob(labels_per_cell)
    # 计算每个细胞的损失,考虑列掩码和细胞掩码
    cell_loss = -tf.reduce_sum(cell_log_prob * column_mask * cell_mask, axis=1)
    # 将损失标准化为每列中的细胞数量,避免零除错误
    cell_loss /= tf.reduce_sum(column_mask * cell_mask, axis=1) + EPSILON_ZERO_DIVISION

    # 每个样本的选择损失等于每个列的损失
    selection_loss_per_example = column_loss_per_example
    # 添加细胞损失,仅在模型选择了细胞时
    selection_loss_per_example += tf.where(no_cell_selected, tf.zeros_like(selection_loss_per_example), cell_loss)

    # 根据模型选择的列,将选定列以外的概率设置为零
    selected_column_id = tf.argmax(column_logits, axis=-1, output_type=tf.int32)
    selected_column_mask = tf.cast(
        tf.equal(column_id_for_cells, tf.expand_dims(selected_column_id, axis=-1)), tf.float32
    )
    # 永远不要选择具有特殊列标识符 0 的细胞
    selected_column_mask = tf.where(
        tf.equal(column_id_for_cells, 0), tf.zeros_like(selected_column_mask), selected_column_mask
    )
    # 调整细胞的 logits,确保在选择的列之外的细胞概率为零
    logits_per_cell += CLOSE_ENOUGH_TO_LOG_ZERO * (1.0 - cell_mask * selected_column_mask)
    # 从 logits_per_cell 中收集指定的 logits
    logits = gather(logits_per_cell, cell_index)

    # 返回每个示例的选择损失和 logits
    return selection_loss_per_example, logits
# 计算聚合掩码,以确定模型是否应选择表中的单元格而非聚合
def _calculate_aggregate_mask(answer, pooled_output, cell_selection_preference, labels, aggregation_classifier):
    """
    Finds examples where the model should select cells with no aggregation.

    Returns a mask that determines for which examples should the model select answers directly from the table, without
    any aggregation function. If the answer is a piece of text the case is unambiguous as aggregation functions only
    apply to numbers. If the answer is a number but does not appear in the table then we must use some aggregation
    case. The ambiguous case is when the answer is a number that also appears in the table. In this case we use the
    aggregation function probabilities predicted by the model to decide whether to select or aggregate. The threshold
    for this is a hyperparameter *cell_selection_preference*

    Args:
        answer (`tf.Tensor` of shape `(batch_size, )`):
            Answer for every example in the batch. Nan if there is no scalar answer.
        pooled_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
            Output of the pooler (BertPooler) on top of the encoder layer.
        cell_selection_preference (`float`):
            Preference for cell selection in ambiguous cases.
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Labels per token.
        aggregation_classifier (`torch.nn.Linear`): Aggregation head

    Returns:
        aggregate_mask (`tf.Tensor` of shape `(batch_size,)`): A mask set to 1 for examples that should use aggregation
        functions.
    """
    # 初始化聚合掩码,判断答案是否为数字而非NaN
    aggregate_mask_init = tf.cast(tf.logical_not(tf.math.is_nan(answer)), tf.float32)
    
    # 计算聚合分类器的逻辑回归结果
    logits_aggregation = aggregation_classifier(pooled_output)
    
    # 创建分类分布对象,用于计算聚合函数的概率分布
    dist_aggregation = tfp.distributions.Categorical(logits=logits_aggregation)
    
    # 计算除“无聚合”外其他聚合操作的总质量
    aggregation_ops_total_mass = tf.reduce_sum(dist_aggregation.probs_parameter()[:, 1:], axis=1)
    
    # 根据当前模型判断是否选择单元格
    is_pred_cell_selection = aggregation_ops_total_mass <= cell_selection_preference
    
    # 判断是否存在单元格选择监督的例子
    is_cell_supervision_available = tf.reduce_sum(labels, axis=1) > 0
    
    # 根据判断结果设置聚合掩码
    aggregate_mask = tf.where(
        tf.logical_and(is_pred_cell_selection, is_cell_supervision_available),
        tf.zeros_like(aggregate_mask_init, dtype=tf.float32),
        aggregate_mask_init,
    )
    
    # 停止梯度在聚合掩码上的传播
    aggregate_mask = tf.stop_gradient(aggregate_mask)
    
    return aggregate_mask


def _calculate_aggregation_loss_known(
    logits_aggregation, aggregate_mask, aggregation_labels, use_answer_as_supervision, num_aggregation_labels
):
    """
    Calculates aggregation loss when its type is known during training.

    In the weakly supervised setting, the only known information is that for cell selection examples, "no aggregation"
    """
    # 计算已知类型聚合损失,用于训练中已知聚合类型的情况
    # 仅在弱监督设置中,已知信息是对于单元格选择的示例,“无聚合”
    """
    Calculate aggregation loss based on logits and supervision signals.

    Args:
        logits_aggregation (`tf.Tensor` of shape `(batch_size, num_aggregation_labels)`):
            Logits per aggregation operation.
        aggregate_mask (`tf.Tensor` of shape `(batch_size, )`):
            A mask set to 1 for examples that should use aggregation functions.
        aggregation_labels (`tf.Tensor` of shape `(batch_size, )`):
            Aggregation function id for every example in the batch.
        use_answer_as_supervision (`bool`, *optional*):
            Whether to use the answer as the only supervision for aggregation examples.
        num_aggregation_labels (`int`, *optional*, defaults to 0):
            The number of aggregation operators to predict.

    Returns:
        aggregation_loss_known (`tf.Tensor` of shape `(batch_size,)`): Aggregation loss (when its type is known during
        training) per example.
    """
    if use_answer_as_supervision:
        # Prepare "no aggregation" targets for cell selection examples.
        target_aggregation = tf.zeros_like(aggregate_mask, dtype=tf.int32)
    else:
        # Use aggregation supervision as the target.
        target_aggregation = aggregation_labels

    # Convert aggregation labels to one-hot encoding.
    one_hot_labels = tf.one_hot(target_aggregation, depth=num_aggregation_labels, dtype=tf.float32)

    # Compute log probabilities of the logits.
    log_probs = tf.nn.log_softmax(logits_aggregation, axis=-1)

    # Calculate cross entropy loss per example.
    per_example_aggregation_intermediate = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)

    if use_answer_as_supervision:
        # Accumulate loss only for examples requiring cell selection
        # (no aggregation).
        return per_example_aggregation_intermediate * (1 - aggregate_mask)
    else:
        # Return aggregation loss for all examples.
        return per_example_aggregation_intermediate
# 计算每个细胞的期望结果,考虑数值分布、数值、缩放因子、输入掩码、聚合逻辑和配置参数
def _calculate_expected_result(
    dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config
):
    Calculates the expected result given cell and aggregation probabilities.

    Args:
        dist_per_cell (`tfp.distributions.Bernoulli`):
            Cell selection distribution for each cell.
        numeric_values (`tf.Tensor` of shape `(batch_size, seq_length)`):
            Numeric values of every token. Nan for tokens which are not numeric values.
        numeric_values_scale (`tf.Tensor` of shape `(batch_size, seq_length)`):
            Scale of the numeric values of every token.
        input_mask_float (`tf.Tensor` of shape `(batch_size, seq_length)`):
            Mask for the table, without question tokens and table headers.
        logits_aggregation (`tf.Tensor` of shape `(batch_size, num_aggregation_labels)`):
            Logits per aggregation operation.
        config ([`TapasConfig`]):
            Model configuration class with all the hyperparameters of the model

    Returns:
        expected_result (`tf.Tensor` of shape `(batch_size,)`): The expected result per example.
    """
    if config.use_gumbel_for_cells:
        # 使用 Gumbel 分布进行采样,用于模拟以伯努利分布为基础的元胞选择
        gumbel_dist = tfp.distributions.RelaxedBernoulli(
            config.temperature,
            logits=dist_per_cell.logits_parameter() * config.temperature,
        )
        scaled_probability_per_cell = gumbel_dist.sample()  # 从 Gumbel 分布中采样元胞选择的概率
    else:
        scaled_probability_per_cell = dist_per_cell.probs_parameter()  # 直接使用伯努利分布的概率参数

    # 对每个元胞选择的概率进行缩放,同时应用数字值的比例和表的掩码
    scaled_probability_per_cell = (scaled_probability_per_cell / numeric_values_scale) * input_mask_float

    # 计算每个示例中选中元胞的数量总和
    count_result = tf.reduce_sum(scaled_probability_per_cell, axis=1)

    # 将非数字表格值的数值设为零,用于遮蔽那些非数值的标记
    numeric_values_masked = tf.where(
        tf.math.is_nan(numeric_values), tf.zeros_like(numeric_values), numeric_values
    )

    # 计算加权平均的结果总和
    sum_result = tf.reduce_sum(scaled_probability_per_cell * numeric_values_masked, axis=1)

    # 根据配置中的平均逼近方法选择相应的方法计算平均结果
    avg_approximation = config.average_approximation_function
    if avg_approximation == AverageApproximationFunction.RATIO:
        # 使用比率逼近方法计算平均结果
        average_result = sum_result / (count_result + EPSILON_ZERO_DIVISION)
    elif avg_approximation == AverageApproximationFunction.FIRST_ORDER:
        # 使用一阶逼近方法计算平均结果,考虑到其他元胞的概率
        ex = tf.reduce_sum(scaled_probability_per_cell, axis=1, keepdims=True) - scaled_probability_per_cell + 1
        average_result = tf.reduce_sum(numeric_values_masked * scaled_probability_per_cell / ex, axis=1)
    elif avg_approximation == AverageApproximationFunction.SECOND_ORDER:
        # 如果平均逼近方法为二阶,执行以下操作
        # 计算每个单元格的调整概率总和,除了当前单元格对应的概率,加上常数1
        ex = tf.reduce_sum(scaled_probability_per_cell, axis=1, keepdims=True) - scaled_probability_per_cell + 1
        # 计算每个单元格的点态方差
        pointwise_var = scaled_probability_per_cell * (1 - scaled_probability_per_cell)
        # 计算总体方差,排除当前单元格的贡献
        var = tf.reduce_sum(pointwise_var, axis=1, keepdims=True) - pointwise_var
        # 计算乘子,用于调整结果
        multiplier = (var / tf.math.square(ex) + 1) / ex
        # 计算加权平均结果
        average_result = tf.reduce_sum(numeric_values_masked * scaled_probability_per_cell * multiplier, axis=1)
    else:
        # 如果平均逼近方法不是二阶,则抛出错误
        raise ValueError("Invalid average_approximation_function: %s", config.average_approximation_function)

    if config.use_gumbel_for_aggregation:
        # 如果配置使用 Gumbel 分布进行聚合操作
        gumbel_dist = tfp.distributions.RelaxedOneHotCategorical(
            config.aggregation_temperature, logits=logits_aggregation[:, 1:]
        )
        # <float32>[batch_size, num_aggregation_labels - 1]
        # 从 Gumbel 分布中抽样,得到聚合操作的概率
        aggregation_op_only_probs = gumbel_dist.sample()
    else:
        # 如果不使用 Gumbel 分布进行聚合操作
        # <float32>[batch_size, num_aggregation_labels - 1]
        # 使用稳定的 softmax 函数计算聚合操作的概率
        aggregation_op_only_probs = stable_softmax(logits_aggregation[:, 1:] / config.aggregation_temperature, axis=-1)
    
    # 将所有结果按行拼接成一个张量
    all_results = tf.concat(
        [
            tf.expand_dims(sum_result, axis=1),
            tf.expand_dims(average_result, axis=1),
            tf.expand_dims(count_result, axis=1),
        ],
        axis=1,
    )
    # 计算期望结果,即所有结果与聚合操作概率的加权和
    expected_result = tf.reduce_sum(all_results * aggregation_op_only_probs, axis=1)
    # 返回期望结果张量
    return expected_result
    # 计算期望结果,根据每个单元格的分布、数值、数值规模、输入掩码、聚合操作的逻辑
    expected_result = _calculate_expected_result(
        dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config
    )

    # 将答案中的 NaN 替换为 0
    answer_masked = tf.where(tf.math.is_nan(answer), tf.zeros_like(answer), answer)

    # 如果配置启用了标准化答案损失
    if config.use_normalized_answer_loss:
        # 计算损失的标准化因子
        normalizer = tf.stop_gradient(
            tf.math.maximum(tf.math.abs(expected_result), tf.math.abs(answer_masked)) + EPSILON_ZERO_DIVISION
        )
        # 标准化答案和期望结果
        normalized_answer_masked = answer_masked / normalizer
        normalized_expected_result = expected_result / normalizer
        # 使用 Huber 损失函数计算每个示例的答案损失
        per_example_answer_loss = tf.compat.v1.losses.huber_loss(
            normalized_answer_masked * aggregate_mask,
            normalized_expected_result * aggregate_mask,
            delta=tf.cast(1.0, tf.float32),
            reduction=tf.losses.Reduction.NONE,
        )
    else:
        # 使用 Huber 损失函数计算每个示例的答案损失,未标准化的情况
        per_example_answer_loss = tf.compat.v1.losses.huber_loss(
            answer_masked * aggregate_mask,
            expected_result * aggregate_mask,
            delta=tf.cast(config.huber_loss_delta, tf.float32),
            reduction=tf.losses.Reduction.NONE,
        )
    # 如果配置中的答案损失截断值为 None,则创建一个全为 1 的张量作为大答案损失掩码
    if config.answer_loss_cutoff is None:
        large_answer_loss_mask = tf.ones_like(per_example_answer_loss, dtype=tf.float32)
    # 否则,根据答案损失是否大于答案损失截断值,生成大答案损失掩码
    else:
        large_answer_loss_mask = tf.where(
            per_example_answer_loss > config.answer_loss_cutoff,
            tf.zeros_like(per_example_answer_loss, dtype=tf.float32),
            tf.ones_like(per_example_answer_loss, dtype=tf.float32),
        )
    # 计算每个示例的答案损失加权,乘以聚合掩码
    per_example_answer_loss_scaled = config.answer_loss_importance * (per_example_answer_loss * aggregate_mask)
    # 返回加权后的每个示例的答案损失以及大答案损失掩码
    return per_example_answer_loss_scaled, large_answer_loss_mask

.\models\tapas\tokenization_tapas.py

# coding=utf-8
# Copyright 2020 Google Research and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for TAPAS model."""

import collections  # 导入 collections 模块
import datetime  # 导入 datetime 模块
import enum  # 导入 enum 枚举类型
import itertools  # 导入 itertools 模块
import math  # 导入 math 数学运算模块
import os  # 导入 os 操作系统接口模块
import re  # 导入 re 正则表达式模块
import unicodedata  # 导入 unicodedata Unicode 数据库
from dataclasses import dataclass  # 导入 dataclass 装饰器,用于定义不可变数据类
from typing import Callable, Dict, Generator, List, Optional, Text, Tuple, Union  # 导入类型提示模块

import numpy as np  # 导入 NumPy 数学计算库

from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace  # 导入 tokenization_utils 模块中的相关函数
from ...tokenization_utils_base import (  # 导入 tokenization_utils_base 模块中的函数和类
    ENCODE_KWARGS_DOCSTRING,
    VERY_LARGE_INTEGER,
    BatchEncoding,
    EncodedInput,
    PreTokenizedInput,
    TextInput,
)
from ...utils import ExplicitEnum, PaddingStrategy, TensorType, add_end_docstrings, is_pandas_available, logging  # 导入 utils 模块中的相关功能

if is_pandas_available():
    import pandas as pd  # 如果 pandas 可用,则导入 pandas 模块

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器


VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}  # 定义词汇表文件名

PRETRAINED_VOCAB_FILES_MAP = {  # 预训练模型词汇表文件映射为空字典
    # Map is intentionally left empty
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {name: 512 for name in PRETRAINED_VOCAB_FILES_MAP.keys()}  # 预训练位置嵌入大小映射,初始化为512
PRETRAINED_INIT_CONFIGURATION = {name: {"do_lower_case": True} for name in PRETRAINED_VOCAB_FILES_MAP.keys()}  # 预训练模型初始化配置,所有模型均为小写处理


class TapasTruncationStrategy(ExplicitEnum):
    """
    Possible values for the `truncation` argument in [`~TapasTokenizer.__call__`]. Useful for tab-completion in an IDE.
    """
    DROP_ROWS_TO_FIT = "drop_rows_to_fit"  # 截断策略:删除行以适应
    DO_NOT_TRUNCATE = "do_not_truncate"  # 截断策略:不截断


TableValue = collections.namedtuple("TokenValue", ["token", "column_id", "row_id"])  # 命名元组,用于表示表格中的一个单元格值


@dataclass(frozen=True)
class TokenCoordinates:
    column_index: int  # 列索引
    row_index: int  # 行索引
    token_index: int  # 令牌索引


@dataclass
class TokenizedTable:
    rows: List[List[List[Text]]]  # 表格的令牌化行列表
    selected_tokens: List[TokenCoordinates]  # 所选令牌的坐标列表


@dataclass(frozen=True)
class SerializedExample:
    tokens: List[Text]  # 序列化示例的令牌列表
    column_ids: List[int]  # 列标识符列表
    row_ids: List[int]  # 行标识符列表
    segment_ids: List[int]  # 段标识符列表


def _is_inner_wordpiece(token: Text):
    """判断是否为内部词片段"""
    return token.startswith("##")


def load_vocab(vocab_file):
    """加载词汇表文件到字典中"""
    vocab = collections.OrderedDict()  # 使用有序字典存储词汇表
    with open(vocab_file, "r", encoding="utf-8") as reader:  # 打开词汇表文件
        tokens = reader.readlines()  # 读取文件中的所有行
    for index, token in enumerate(tokens):  # 遍历行索引和行内容
        token = token.rstrip("\n")  # 去除行末换行符
        vocab[token] = index  # 将词汇和索引存入字典
    return vocab  # 返回加载后的词汇表字典


def whitespace_tokenize(text):
    """对文本进行基本的空格清理和分割"""
    text = text.strip()  # 去除文本两端空白字符
    # 如果文本为空,则返回空列表
    if not text:
        return []
    # 使用空格分割文本,生成令牌列表
    tokens = text.split()
    # 返回生成的令牌列表
    return tokens
"""
class TapasTokenizer(PreTrainedTokenizer):
    r"""
    Construct a TAPAS tokenizer. Based on WordPiece. Flattens a table and one or more related sentences to be used by
    TAPAS models.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods. [`TapasTokenizer`] creates several token type ids to
    encode tabular structure. To be more precise, it adds 7 token type ids, in the following order: `segment_ids`,
    `column_ids`, `row_ids`, `prev_labels`, `column_ranks`, `inv_column_ranks` and `numeric_relations`:

    - segment_ids: indicate whether a token belongs to the question (0) or the table (1). 0 for special tokens and
      padding.
    - column_ids: indicate to which column of the table a token belongs (starting from 1). Is 0 for all question
      tokens, special tokens and padding.
    - row_ids: indicate to which row of the table a token belongs (starting from 1). Is 0 for all question tokens,
      special tokens and padding. Tokens of column headers are also 0.
    - prev_labels: indicate whether a token was (part of) an answer to the previous question (1) or not (0). Useful in
      a conversational setup (such as SQA).
    - column_ranks: indicate the rank of a table token relative to a column, if applicable. For example, if you have a
      column "number of movies" with values 87, 53 and 69, then the column ranks of these tokens are 3, 1 and 2
      respectively. 0 for all question tokens, special tokens and padding.
    - inv_column_ranks: indicate the inverse rank of a table token relative to a column, if applicable. For example, if
      you have a column "number of movies" with values 87, 53 and 69, then the inverse column ranks of these tokens are
      1, 3 and 2 respectively. 0 for all question tokens, special tokens and padding.
    - numeric_relations: indicate numeric relations between the question and the tokens of the table. 0 for all
      question tokens, special tokens and padding.

    [`TapasTokenizer`] runs end-to-end tokenization on a table and associated sentences: punctuation splitting and
    wordpiece.

    """

    vocab_files_names = VOCAB_FILES_NAMES  # 词汇文件的名称列表
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP  # 预训练词汇文件映射
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES  # 预训练位置嵌入的最大模型输入大小
    # 初始化函数,用于设置和配置Tokenizer对象的各种参数和选项
    def __init__(
        # 词汇文件路径,用于加载Tokenizer的词汇表
        self,
        vocab_file,
        # 是否将输入文本转换为小写,默认为True
        do_lower_case=True,
        # 是否进行基本的分词,默认为True
        do_basic_tokenize=True,
        # 指定不进行分割的特殊标记列表,如果为None则没有特殊标记
        never_split=None,
        # 未知标记的字符串表示,默认为"[UNK]"
        unk_token="[UNK]",
        # 分隔标记的字符串表示,默认为"[SEP]"
        sep_token="[SEP]",
        # 填充标记的字符串表示,默认为"[PAD]"
        pad_token="[PAD]",
        # 类别标记的字符串表示,默认为"[CLS]"
        cls_token="[CLS]",
        # 掩码标记的字符串表示,默认为"[MASK]"
        mask_token="[MASK]",
        # 空标记的字符串表示,默认为"[EMPTY]"
        empty_token="[EMPTY]",
        # 是否对中文字符进行分词,默认为True
        tokenize_chinese_chars=True,
        # 是否去除字符串中的重音符号,默认为None(不去除)
        strip_accents=None,
        # 单元格修剪长度,指定列名称的最大长度,默认为-1(不限制)
        cell_trim_length: int = -1,
        # 最大列ID,默认为None(不限制)
        max_column_id: int = None,
        # 最大行ID,默认为None(不限制)
        max_row_id: int = None,
        # 是否去除列名的空格,默认为False
        strip_column_names: bool = False,
        # 是否更新答案坐标,默认为False
        update_answer_coordinates: bool = False,
        # 最小问题长度,默认为None(不限制)
        min_question_length=None,
        # 最大问题长度,默认为None(不限制)
        max_question_length=None,
        # 模型的最大长度,默认为512
        model_max_length: int = 512,
        # 额外的特殊标记列表,可以为None
        additional_special_tokens: Optional[List[str]] = None,
        # 其他可选参数,以字典形式接收
        **kwargs,
    ):
        ):
            # 检查是否安装了 Pandas 库,若未安装则抛出 ImportError 异常
            if not is_pandas_available():
                raise ImportError("Pandas is required for the TAPAS tokenizer.")

            # 处理额外的特殊标记,确保空标记在其中
            if additional_special_tokens is not None:
                if empty_token not in additional_special_tokens:
                    additional_special_tokens.append(empty_token)
            else:
                additional_special_tokens = [empty_token]

            # 检查词汇文件是否存在,若不存在则抛出 ValueError 异常
            if not os.path.isfile(vocab_file):
                raise ValueError(
                    f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                    " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
                )

            # 加载词汇表并创建词汇到 ID 的映射,保持有序字典
            self.vocab = load_vocab(vocab_file)
            self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])

            # 设置是否进行基本的分词处理
            self.do_basic_tokenize = do_basic_tokenize
            if do_basic_tokenize:
                self.basic_tokenizer = BasicTokenizer(
                    do_lower_case=do_lower_case,
                    never_split=never_split,
                    tokenize_chinese_chars=tokenize_chinese_chars,
                    strip_accents=strip_accents,
                )

            # 使用词汇表初始化 WordpieceTokenizer 对象
            self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))

            # 设置额外的属性
            self.cell_trim_length = cell_trim_length
            # 设置列的最大 ID,如果未提供则使用 model_max_length 或设为一个非常大的整数
            self.max_column_id = (
                max_column_id
                if max_column_id is not None
                else model_max_length
                if model_max_length is not None
                else VERY_LARGE_INTEGER
            )
            # 设置行的最大 ID,如果未提供则使用 model_max_length 或设为一个非常大的整数
            self.max_row_id = (
                max_row_id
                if max_row_id is not None
                else model_max_length
                if model_max_length is not None
                else VERY_LARGE_INTEGER
            )
            # 是否去除列名中的空白字符
            self.strip_column_names = strip_column_names
            # 是否更新答案的坐标
            self.update_answer_coordinates = update_answer_coordinates
            # 最小问题长度限制
            self.min_question_length = min_question_length
            # 最大问题长度限制
            self.max_question_length = max_question_length

            # 调用父类的构造方法,初始化基本参数和额外的特殊标记等
            super().__init__(
                do_lower_case=do_lower_case,
                do_basic_tokenize=do_basic_tokenize,
                never_split=never_split,
                unk_token=unk_token,
                sep_token=sep_token,
                pad_token=pad_token,
                cls_token=cls_token,
                mask_token=mask_token,
                empty_token=empty_token,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
                cell_trim_length=cell_trim_length,
                max_column_id=max_column_id,
                max_row_id=max_row_id,
                strip_column_names=strip_column_names,
                update_answer_coordinates=update_answer_coordinates,
                min_question_length=min_question_length,
                max_question_length=max_question_length,
                model_max_length=model_max_length,
                additional_special_tokens=additional_special_tokens,
                **kwargs,
            )

        @property
    # 返回当前实例中的基本分词器的小写设置
    def do_lower_case(self):
        return self.basic_tokenizer.do_lower_case

    # 返回当前词汇表的大小
    @property
    def vocab_size(self):
        return len(self.vocab)

    # 返回词汇表和添加的特殊token编码器组成的字典
    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    # 将文本进行标记化处理,返回标记列表
    def _tokenize(self, text):
        # 检查格式化后的文本是否为空文本,如果是,则返回一个特殊token的列表
        if format_text(text) == EMPTY_TEXT:
            return [self.additional_special_tokens[0]]
        split_tokens = []
        # 如果设置了基本分词,则使用基本分词器处理文本
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
                # 如果token属于不分割的特殊token集合,则直接加入split_tokens
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    # 否则使用wordpiece_tokenizer进一步分割token,加入split_tokens
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            # 否则直接使用wordpiece_tokenizer处理文本
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    # 根据token返回其在词汇表中的id,如果找不到则返回UNK(未知token)的id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    # 根据id返回词汇表中对应的token,如果找不到则返回UNK(未知token)
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    # 将token序列转换为单个字符串,去除"##"并去除两端空格
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    # 将词汇表保存到指定目录中的文件中,并返回保存的文件路径
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        index = 0
        # 如果保存目录已存在,则在其下创建词汇表文件
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            # 否则直接在指定的保存目录或者文件名前缀下创建词汇表文件
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        # 使用utf-8编码打开文件,并逐行写入词汇表中的token
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!"
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
    def create_attention_mask_from_sequences(self, query_ids: List[int], table_values: List[TableValue]) -> List[int]:
        """
        根据查询的token ID和表格值创建注意力掩码。

        Args:
            query_ids (`List[int]`): 与查询相关的token ID列表。
            table_values (`List[TableValue]`): 表格值的列表,其中包含命名元组,包括token值、列ID和行ID。

        Returns:
            `List[int]`: 包含注意力掩码值的整数列表。
        """
        # 创建一个全为1的列表,长度为查询token数加1再加上表格值数加1
        return [1] * (1 + len(query_ids) + 1 + len(table_values))

    def create_segment_token_type_ids_from_sequences(
        self, query_ids: List[int], table_values: List[TableValue]
    ) -> List[int]:
        """
        根据查询的token ID和表格值创建段落token类型ID。

        Args:
            query_ids (`List[int]`): 与查询相关的token ID列表。
            table_values (`List[TableValue]`): 表格值的列表,其中包含命名元组,包括token值、列ID和行ID。

        Returns:
            `List[int]`: 包含段落token类型ID值的整数列表。
        """
        # 如果有表格值,则提取出所有表格值的第一个元素(token值),否则为空列表
        table_ids = list(zip(*table_values))[0] if table_values else []
        # 返回一个以0填充的列表,长度为查询token数加1再加上1,再加上以1填充的列表,长度为表格值中token值的数量
        return [0] * (1 + len(query_ids) + 1) + [1] * len(table_ids)

    def create_column_token_type_ids_from_sequences(
        self, query_ids: List[int], table_values: List[TableValue]
    ) -> List[int]:
        """
        根据查询的token ID和表格值创建列token类型ID。

        Args:
            query_ids (`List[int]`): 与查询相关的token ID列表。
            table_values (`List[TableValue]`): 表格值的列表,其中包含命名元组,包括token值、列ID和行ID。

        Returns:
            `List[int]`: 包含列token类型ID值的整数列表。
        """
        # 如果有表格值,则提取出所有表格值的第二个元素(列ID),否则为空列表
        table_column_ids = list(zip(*table_values))[1] if table_values else []
        # 返回一个以0填充的列表,长度为查询token数加1再加上1,再加上表格值中列ID数量的列表
        return [0] * (1 + len(query_ids) + 1) + list(table_column_ids)

    def create_row_token_type_ids_from_sequences(
        self, query_ids: List[int], table_values: List[TableValue]
    ) -> List[int]:
        """
        根据查询的token ID和表格值创建行token类型ID。
        
        Args:
            query_ids (`List[int]`): 与查询相关的token ID列表。
            table_values (`List[TableValue]`): 表格值的列表,其中包含命名元组,包括token值、列ID和行ID。

        Returns:
            `List[int]`: 包含行token类型ID值的整数列表。
        """
        # 如果有表格值,则提取出所有表格值的第三个元素(行ID),否则为空列表
        table_row_ids = list(zip(*table_values))[2] if table_values else []
        # 返回一个以0填充的列表,长度为查询token数加1再加上1,再加上表格值中行ID数量的列表
        return [0] * (1 + len(query_ids) + 1) + list(table_row_ids)
    ) -> List[int]:
        """
        Creates the row token type IDs according to the query token IDs and a list of table values.

        Args:
            query_ids (`List[int]`): list of token IDs corresponding to the ID.
            table_values (`List[TableValue]`): lift of table values, which are named tuples containing the
                token value, the column ID and the row ID of said token.

        Returns:
            `List[int]`: List of ints containing the row token type IDs values.
        """
        # Extract row IDs from table_values if it's not empty, otherwise initialize as an empty list
        table_row_ids = list(zip(*table_values))[2] if table_values else []
        # Generate row token type IDs list by concatenating [0], query_ids, [0] (for padding), and table_row_ids
        return [0] * (1 + len(query_ids) + 1) + list(table_row_ids)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a question and flattened table for question answering or sequence classification tasks
        by concatenating and adding special tokens.

        Args:
            token_ids_0 (`List[int]`): The ids of the question.
            token_ids_1 (`List[int]`, *optional*): The ids of the flattened table.

        Returns:
            `List[int]`: The model input with special tokens.
        """
        # Check if token_ids_1 is provided; raise error if not provided with TAPAS
        if token_ids_1 is None:
            raise ValueError("With TAPAS, you must provide both question IDs and table IDs.")
        # Concatenate cls_token_id, token_ids_0, sep_token_id, and token_ids_1 to build model input
        return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + token_ids_1

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of question IDs.
            token_ids_1 (`List[int]`, *optional*):
                List of flattened table IDs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        # If already_has_special_tokens is True, delegate to the parent class method
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        
        # If token_ids_1 is not None, return a mask indicating special tokens (1) and sequence tokens (0)
        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
        # If token_ids_1 is None, return a mask indicating special tokens (1) and sequence tokens (0)
        return [1] + ([0] * len(token_ids_0)) + [1]

    @add_end_docstrings(TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    # 定义一个方法,使其能像函数一样被调用
    def __call__(
        self,
        # 表格数据,使用 pandas 的 DataFrame 类型
        table: "pd.DataFrame",
        # 查询输入,可以是文本输入、预分词输入、编码输入,或它们的列表
        queries: Optional[
            Union[
                TextInput,
                PreTokenizedInput,
                EncodedInput,
                List[TextInput],
                List[PreTokenizedInput],
                List[EncodedInput],
            ]
        ] = None,
        # 答案的坐标,可以是单个或多个坐标的列表
        answer_coordinates: Optional[Union[List[Tuple], List[List[Tuple]]]] = None,
        # 答案的文本形式,可以是单个或多个文本输入的列表
        answer_text: Optional[Union[List[TextInput], List[List[TextInput]]]] = None,
        # 是否添加特殊标记,默认为 True
        add_special_tokens: bool = True,
        # 填充策略,可以是布尔值、字符串或填充策略对象,默认为 False
        padding: Union[bool, str, PaddingStrategy] = False,
        # 截断策略,可以是布尔值、字符串或截断策略对象,默认为 False
        truncation: Union[bool, str, TapasTruncationStrategy] = False,
        # 最大长度限制,默认为 None
        max_length: Optional[int] = None,
        # 填充到的最接近的倍数,默认为 None
        pad_to_multiple_of: Optional[int] = None,
        # 返回的张量类型,默认为 None
        return_tensors: Optional[Union[str, TensorType]] = None,
        # 是否返回 token 类型 ID,默认为 None
        return_token_type_ids: Optional[bool] = None,
        # 是否返回注意力掩码,默认为 None
        return_attention_mask: Optional[bool] = None,
        # 是否返回溢出的 token,默认为 False
        return_overflowing_tokens: bool = False,
        # 是否返回特殊 token 掩码,默认为 False
        return_special_tokens_mask: bool = False,
        # 是否返回偏移映射,默认为 False
        return_offsets_mapping: bool = False,
        # 是否返回长度,默认为 False
        return_length: bool = False,
        # 是否启用详细输出,默认为 True
        verbose: bool = True,
        **kwargs,
    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    # 批量编码加方法的定义,具有与 __call__ 相似的参数
    def batch_encode_plus(
        self,
        # 表格数据,使用 pandas 的 DataFrame 类型
        table: "pd.DataFrame",
        # 查询输入的列表,可以是文本输入、预分词输入或编码输入的列表
        queries: Optional[
            Union[
                List[TextInput],
                List[PreTokenizedInput],
                List[EncodedInput],
            ]
        ] = None,
        # 答案的坐标列表的列表形式
        answer_coordinates: Optional[List[List[Tuple]]] = None,
        # 答案的文本列表的列表形式
        answer_text: Optional[List[List[TextInput]]] = None,
        # 是否添加特殊标记,默认为 True
        add_special_tokens: bool = True,
        # 填充策略,可以是布尔值、字符串或填充策略对象,默认为 False
        padding: Union[bool, str, PaddingStrategy] = False,
        # 截断策略,可以是布尔值、字符串或截断策略对象,默认为 False
        truncation: Union[bool, str, TapasTruncationStrategy] = False,
        # 最大长度限制,默认为 None
        max_length: Optional[int] = None,
        # 填充到的最接近的倍数,默认为 None
        pad_to_multiple_of: Optional[int] = None,
        # 返回的张量类型,默认为 None
        return_tensors: Optional[Union[str, TensorType]] = None,
        # 是否返回 token 类型 ID,默认为 None
        return_token_type_ids: Optional[bool] = None,
        # 是否返回注意力掩码,默认为 None
        return_attention_mask: Optional[bool] = None,
        # 是否返回溢出的 token,默认为 False
        return_overflowing_tokens: bool = False,
        # 是否返回特殊 token 掩码,默认为 False
        return_special_tokens_mask: bool = False,
        # 是否返回偏移映射,默认为 False
        return_offsets_mapping: bool = False,
        # 是否返回长度,默认为 False
        return_length: bool = False,
        # 是否启用详细输出,默认为 True
        verbose: bool = True,
        **kwargs,
    # 获取问题 tokens 的方法定义,输入参数是一个查询
    def _get_question_tokens(self, query):
        """Tokenizes the query, taking into account the max and min question length."""
        
        # 使用内部方法 tokenize 对查询进行分词处理,返回分词后的结果
        query_tokens = self.tokenize(query)
        # 如果设定了最大问题长度且查询分词后的长度超过最大问题长度,则记录警告并返回空字符串和空列表
        if self.max_question_length is not None and len(query_tokens) > self.max_question_length:
            logger.warning("Skipping query as its tokens are longer than the max question length")
            return "", []
        # 如果设定了最小问题长度且查询分词后的长度少于最小问题长度,则记录警告并返回空字符串和空列表
        if self.min_question_length is not None and len(query_tokens) < self.min_question_length:
            logger.warning("Skipping query as its tokens are shorter than the min question length")
            return "", []

        # 返回原始查询和其分词后的结果列表
        return query, query_tokens
    # 定义一个方法 `_batch_encode_plus`,用于批量编码输入数据并返回编码后的批处理结果
    def _batch_encode_plus(
        self,
        table,  # 表格数据,待编码的输入表格
        queries: Union[  # 查询数据,可以是文本输入、预分词输入或编码输入的列表
            List[TextInput],
            List[PreTokenizedInput],
            List[EncodedInput],
        ],
        answer_coordinates: Optional[List[List[Tuple]]] = None,  # 答案坐标,可选的二维列表,每个元素是一组坐标元组
        answer_text: Optional[List[List[TextInput]]] = None,  # 答案文本,可选的二维列表,每个元素是一组文本输入
        add_special_tokens: bool = True,  # 是否添加特殊标记,如 [CLS], [SEP]
        padding: Union[bool, str, PaddingStrategy] = False,  # 填充策略,指定填充的方式
        truncation: Union[bool, str, TapasTruncationStrategy] = False,  # 截断策略,指定截断的方式
        max_length: Optional[int] = None,  # 最大长度,限制编码后的最大长度
        pad_to_multiple_of: Optional[int] = None,  # 填充到的倍数
        return_tensors: Optional[Union[str, TensorType]] = None,  # 返回的张量类型
        return_token_type_ids: Optional[bool] = True,  # 是否返回token类型ID
        return_attention_mask: Optional[bool] = None,  # 是否返回注意力掩码
        return_overflowing_tokens: bool = False,  # 是否返回溢出的token
        return_special_tokens_mask: bool = False,  # 是否返回特殊token的掩码
        return_offsets_mapping: bool = False,  # 是否返回偏移映射
        return_length: bool = False,  # 是否返回长度信息
        verbose: bool = True,  # 是否启用详细输出模式
        **kwargs,  # 其他参数,灵活处理额外的关键字参数
    ) -> BatchEncoding:  # 返回类型为 BatchEncoding 对象
        # 对输入的表格数据进行标记化处理,得到表格数据的token表示
        table_tokens = self._tokenize_table(table)

        # 初始化查询数据的token表示列表
        queries_tokens = []
        # 遍历查询数据列表,对每个查询进行处理
        for idx, query in enumerate(queries):
            # 调用内部方法 `_get_question_tokens` 处理查询,获取查询文本和token表示
            query, query_tokens = self._get_question_tokens(query)
            # 更新查询数据列表中的查询文本
            queries[idx] = query
            # 将查询的token表示添加到查询token列表中
            queries_tokens.append(query_tokens)

        # 调用内部方法 `_batch_prepare_for_model` 准备模型输入数据,进行编码和准备
        batch_outputs = self._batch_prepare_for_model(
            table,  # 表格数据
            queries,  # 查询数据
            tokenized_table=table_tokens,  # 表格数据的token表示
            queries_tokens=queries_tokens,  # 查询数据的token表示
            answer_coordinates=answer_coordinates,  # 答案坐标
            padding=padding,  # 填充策略
            truncation=truncation,  # 截断策略
            answer_text=answer_text,  # 答案文本
            add_special_tokens=add_special_tokens,  # 是否添加特殊token
            max_length=max_length,  # 最大长度
            pad_to_multiple_of=pad_to_multiple_of,  # 填充到的倍数
            return_tensors=return_tensors,  # 返回的张量类型
            prepend_batch_axis=True,  # 是否在返回结果中添加批处理维度
            return_attention_mask=return_attention_mask,  # 是否返回注意力掩码
            return_token_type_ids=return_token_type_ids,  # 是否返回token类型ID
            return_overflowing_tokens=return_overflowing_tokens,  # 是否返回溢出的token
            return_special_tokens_mask=return_special_tokens_mask,  # 是否返回特殊token的掩码
            return_length=return_length,  # 是否返回长度信息
            verbose=verbose,  # 是否启用详细输出模式
        )

        # 返回 BatchEncoding 对象,封装了批处理后的编码结果
        return BatchEncoding(batch_outputs)
    # 定义一个方法 `_batch_prepare_for_model`,用于准备数据以供模型处理
    def _batch_prepare_for_model(
        self,
        raw_table: "pd.DataFrame",  # 原始数据表格,类型为 Pandas DataFrame
        raw_queries: Union[  # 原始查询数据的列表,可以是不同类型的输入数据
            List[TextInput],  # 文本输入列表
            List[PreTokenizedInput],  # 预标记化输入列表
            List[EncodedInput],  # 编码输入列表
        ],
        tokenized_table: Optional[TokenizedTable] = None,  # 可选的表格数据经过标记化的形式
        queries_tokens: Optional[List[List[str]]] = None,  # 可选的查询标记化后的词列表
        answer_coordinates: Optional[List[List[Tuple]]] = None,  # 可选的答案坐标列表
        answer_text: Optional[List[List[TextInput]]] = None,  # 可选的答案文本列表
        add_special_tokens: bool = True,  # 是否添加特殊标记,默认为 True
        padding: Union[bool, str, PaddingStrategy] = False,  # 填充策略,默认为 False
        truncation: Union[bool, str, TapasTruncationStrategy] = False,  # 截断策略,默认为 False
        max_length: Optional[int] = None,  # 最大长度限制,可选
        pad_to_multiple_of: Optional[int] = None,  # 填充到指定的倍数
        return_tensors: Optional[Union[str, TensorType]] = None,  # 返回的张量类型,可选
        return_token_type_ids: Optional[bool] = True,  # 是否返回token类型id,默认为 True
        return_attention_mask: Optional[bool] = True,  # 是否返回注意力掩码,默认为 True
        return_special_tokens_mask: bool = False,  # 是否返回特殊标记掩码,默认为 False
        return_offsets_mapping: bool = False,  # 是否返回偏移映射,默认为 False
        return_length: bool = False,  # 是否返回长度信息,默认为 False
        verbose: bool = True,  # 是否打印详细信息,默认为 True
        prepend_batch_axis: bool = False,  # 是否在结果中添加批处理维度,默认为 False
        **kwargs,  # 其它参数,灵活传递
    ):
    ) -> BatchEncoding:
        batch_outputs = {}  # 初始化一个空字典,用于存储批处理的输出结果

        # 遍历输入的四个列表的元素,每次迭代生成一个示例
        for index, example in enumerate(zip(raw_queries, queries_tokens, answer_coordinates, answer_text)):
            raw_query, query_tokens, answer_coords, answer_txt = example  # 解包示例元组到各个变量
            # 调用 self.prepare_for_model 方法准备模型输入,并获取输出结果
            outputs = self.prepare_for_model(
                raw_table,
                raw_query,
                tokenized_table=tokenized_table,
                query_tokens=query_tokens,
                answer_coordinates=answer_coords,
                answer_text=answer_txt,
                add_special_tokens=add_special_tokens,
                padding=PaddingStrategy.DO_NOT_PAD.value,  # 设置不进行单独填充,而是批处理后再进行
                truncation=truncation,
                max_length=max_length,
                pad_to_multiple_of=None,  # 批处理后再进行填充
                return_attention_mask=False,  # 批处理后再进行填充
                return_token_type_ids=return_token_type_ids,
                return_special_tokens_mask=return_special_tokens_mask,
                return_length=return_length,
                return_tensors=None,  # 在最后将整个批次转换为张量
                prepend_batch_axis=False,
                verbose=verbose,
                prev_answer_coordinates=answer_coordinates[index - 1] if index != 0 else None,  # 前一个答案的坐标
                prev_answer_text=answer_text[index - 1] if index != 0 else None,  # 前一个答案的文本
            )

            # 将每个输出项添加到批处理输出字典中
            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                batch_outputs[key].append(value)

        # 对批处理输出进行填充处理
        batch_outputs = self.pad(
            batch_outputs,
            padding=padding,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )

        # 将填充后的批处理输出转换为 BatchEncoding 类型
        batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)

        return batch_outputs  # 返回填充后的批处理输出对象
    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    # 使用装饰器添加文档字符串,文档字符串内容包括 ENCODE_KWARGS_DOCSTRING 和 TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING
    def encode_plus(
        self,
        table: "pd.DataFrame",
        # 表格数据,必须是一个 Pandas 的 DataFrame,所有单元格的值必须是文本格式。可以使用 *.astype(str)* 转换数据框为字符串格式。
        query: Optional[
            Union[
                TextInput,
                PreTokenizedInput,
                EncodedInput,
            ]
        ] = None,
        # 查询问题,可以是字符串或者字符串列表的形式,用于编码和查询相关的表格信息。
        answer_coordinates: Optional[List[Tuple]] = None,
        # 答案坐标,用于指定表格中答案的坐标位置。
        answer_text: Optional[List[TextInput]] = None,
        # 答案文本,用于指定表格中答案的文本内容。
        add_special_tokens: bool = True,
        # 是否添加特殊标记,通常用于控制是否在编码过程中添加特殊标记,如 [CLS], [SEP] 等。
        padding: Union[bool, str, PaddingStrategy] = False,
        # 填充策略,用于控制输入序列的填充方式,可以是布尔值、字符串或者填充策略对象。
        truncation: Union[bool, str, TapasTruncationStrategy] = False,
        # 截断策略,用于控制输入序列的截断方式,可以是布尔值、字符串或者截断策略对象。
        max_length: Optional[int] = None,
        # 最大长度,用于控制编码后的序列的最大长度。
        pad_to_multiple_of: Optional[int] = None,
        # 填充到的倍数,用于控制序列填充后的长度为指定倍数。
        return_tensors: Optional[Union[str, TensorType]] = None,
        # 返回张量类型,用于指定返回的编码结果的张量类型,如 'pt' 表示返回 PyTorch 张量。
        return_token_type_ids: Optional[bool] = None,
        # 是否返回 token 类型 ID,用于指定是否返回编码后序列的 token 类型 ID。
        return_attention_mask: Optional[bool] = None,
        # 是否返回注意力掩码,用于指定是否返回编码后序列的注意力掩码。
        return_special_tokens_mask: bool = False,
        # 是否返回特殊标记掩码,用于指定是否返回编码后序列的特殊标记掩码。
        return_offsets_mapping: bool = False,
        # 是否返回偏移映射,用于指定是否返回编码后序列的字符偏移映射。
        return_length: bool = False,
        # 是否返回长度,用于指定是否返回编码后序列的长度。
        verbose: bool = True,
        # 是否详细输出,用于控制是否输出详细的编码过程信息。
        **kwargs,
        # 其他参数,用于接收可能存在的其他关键字参数。
    ) -> BatchEncoding:
        """
        Prepare a table and a string for the model.

        Args:
            table (`pd.DataFrame`):
                Table containing tabular data. Note that all cell values must be text. Use *.astype(str)* on a Pandas
                dataframe to convert it to string.
            query (`str` or `List[str]`):
                Question related to a table to be encoded.
            answer_coordinates (`List[Tuple]` or `List[List[Tuple]]`, *optional*):
                Answer coordinates of each table-question pair in the batch. The answer_coordinates must be a single
                list of one or more tuples. Each tuple must be a (row_index, column_index) pair. The first data row
                (not the column header row) has index 0. The first column has index 0.
            answer_text (`List[str]` or `List[List[str]]`, *optional*):
                Answer text of each table-question pair in the batch. The answer_text must be a single list of one or
                more strings. Each string must be the answer text of a corresponding answer coordinate.
        """
        # 检查特殊情况,如果设置了return_token_type_ids但未设置add_special_tokens为True,则引发值错误
        if return_token_type_ids is not None and not add_special_tokens:
            raise ValueError(
                "Asking to return token_type_ids while setting add_special_tokens to False "
                "results in an undefined behavior. Please set add_special_tokens to True or "
                "set return_token_type_ids to None."
            )

        # 检查参数的一致性,如果提供了answer_coordinates但未提供answer_text,或者反之,则引发值错误
        if (answer_coordinates and not answer_text) or (not answer_coordinates and answer_text):
            raise ValueError("In case you provide answers, both answer_coordinates and answer_text should be provided")

        # 检查是否包含不支持的参数,如果kwargs中包含'is_split_into_words',则引发未实现错误
        if "is_split_into_words" in kwargs:
            raise NotImplementedError("Currently TapasTokenizer only supports questions as strings.")

        # 检查是否请求返回偏移映射,由于Python tokenizers不支持该功能,因此引发未实现错误
        if return_offsets_mapping:
            raise NotImplementedError(
                "return_offset_mapping is not available when using Python tokenizers. "
                "To use this feature, change your tokenizer to one deriving from "
                "transformers.PreTrainedTokenizerFast."
            )

        # 调用内部方法_encode_plus,用给定参数编码表和查询,并返回编码结果
        return self._encode_plus(
            table=table,
            query=query,
            answer_coordinates=answer_coordinates,
            answer_text=answer_text,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding=padding,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )
    # 定义一个方法 `_encode_plus`,用于对输入进行编码加工,并返回适用于模型输入的格式化数据
    def _encode_plus(
        self,
        table: "pd.DataFrame",  # 输入的表格数据,类型为 Pandas DataFrame
        query: Union[  # 查询文本,可以是文本输入的几种形式之一
            TextInput,  # 文本输入
            PreTokenizedInput,  # 预标记化的输入
            EncodedInput,  # 编码后的输入
        ],
        answer_coordinates: Optional[List[Tuple]] = None,  # 答案的坐标信息(可选)
        answer_text: Optional[List[TextInput]] = None,  # 答案的文本信息(可选)
        add_special_tokens: bool = True,  # 是否添加特殊标记(默认为 True)
        padding: Union[bool, str, PaddingStrategy] = False,  # 是否进行填充,填充策略可以是布尔值、字符串或填充策略对象
        truncation: Union[bool, str, TapasTruncationStrategy] = False,  # 是否截断输入,截断策略可以是布尔值、字符串或截断策略对象
        max_length: Optional[int] = None,  # 最大长度限制(可选)
        pad_to_multiple_of: Optional[int] = None,  # 填充到指定的倍数长度(可选)
        return_tensors: Optional[Union[str, TensorType]] = None,  # 返回的张量类型(可选)
        return_token_type_ids: Optional[bool] = True,  # 是否返回 token 类型 ID(默认为 True)
        return_attention_mask: Optional[bool] = True,  # 是否返回注意力掩码(默认为 True)
        return_special_tokens_mask: bool = False,  # 是否返回特殊标记的掩码(默认为 False)
        return_offsets_mapping: bool = False,  # 是否返回偏移映射(默认为 False)
        return_length: bool = False,  # 是否返回长度信息(默认为 False)
        verbose: bool = True,  # 是否显示详细信息(默认为 True)
        **kwargs,  # 其他关键字参数
    ):
        if query is None:  # 如果查询文本为 None
            query = ""  # 将查询文本设为空字符串
            logger.warning(  # 记录警告日志,提醒用户
                "TAPAS is a question answering model but you have not passed a query. Please be aware that the "
                "model will probably not behave correctly."
            )

        # 对表格进行标记化处理,生成表格 token
        table_tokens = self._tokenize_table(table)
        # 获取查询文本的 token 化结果和原始文本
        query, query_tokens = self._get_question_tokens(query)

        # 调用 self.prepare_for_model 方法,准备模型输入数据
        return self.prepare_for_model(
            table,  # 输入表格数据
            query,  # 查询文本
            tokenized_table=table_tokens,  # 表格的标记化结果
            query_tokens=query_tokens,  # 查询文本的标记化结果
            answer_coordinates=answer_coordinates,  # 答案的坐标信息
            answer_text=answer_text,  # 答案的文本信息
            add_special_tokens=add_special_tokens,  # 是否添加特殊标记
            truncation=truncation,  # 截断策略
            padding=padding,  # 填充策略
            max_length=max_length,  # 最大长度限制
            pad_to_multiple_of=pad_to_multiple_of,  # 填充到倍数长度
            return_tensors=return_tensors,  # 返回的张量类型
            prepend_batch_axis=True,  # 是否在结果中添加批次维度
            return_attention_mask=return_attention_mask,  # 是否返回注意力掩码
            return_token_type_ids=return_token_type_ids,  # 是否返回 token 类型 ID
            return_special_tokens_mask=return_special_tokens_mask,  # 是否返回特殊标记的掩码
            return_length=return_length,  # 是否返回长度信息
            verbose=verbose,  # 是否显示详细信息
        )

    # 将 ENCODE_KWARGS_DOCSTRING 和 TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING 添加为文档字符串
    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    # 定义一个方法,准备数据以供模型使用
    def prepare_for_model(
        self,
        raw_table: "pd.DataFrame",  # 接收原始数据表格,类型为 pandas DataFrame
        raw_query: Union[  # 接收原始查询,可以是多种类型的输入数据
            TextInput,  # 文本输入
            PreTokenizedInput,  # 预分词的输入
            EncodedInput,  # 编码后的输入
        ],
        tokenized_table: Optional[TokenizedTable] = None,  # 可选参数,已经分词的表格数据
        query_tokens: Optional[TokenizedTable] = None,  # 可选参数,查询的分词结果
        answer_coordinates: Optional[List[Tuple]] = None,  # 可选参数,答案的坐标列表
        answer_text: Optional[List[TextInput]] = None,  # 可选参数,答案的文本列表
        add_special_tokens: bool = True,  # 是否添加特殊标记,默认为 True
        padding: Union[bool, str, PaddingStrategy] = False,  # 填充策略,默认为 False
        truncation: Union[bool, str, TapasTruncationStrategy] = False,  # 截断策略,默认为 False
        max_length: Optional[int] = None,  # 可选参数,最大长度限制
        pad_to_multiple_of: Optional[int] = None,  # 可选参数,填充到指定的长度倍数
        return_tensors: Optional[Union[str, TensorType]] = None,  # 可选参数,返回的张量类型
        return_token_type_ids: Optional[bool] = True,  # 是否返回 token 类型 id,默认为 True
        return_attention_mask: Optional[bool] = True,  # 是否返回注意力掩码,默认为 True
        return_special_tokens_mask: bool = False,  # 是否返回特殊标记掩码,默认为 False
        return_offsets_mapping: bool = False,  # 是否返回偏移映射,默认为 False
        return_length: bool = False,  # 是否返回长度信息,默认为 False
        verbose: bool = True,  # 是否显示详细信息,默认为 True
        prepend_batch_axis: bool = False,  # 是否在结果中添加批次轴,默认为 False
        **kwargs,  # 其他关键字参数
    ):
        # 方法的具体实现在这里,根据参数准备数据以供模型使用

    # 定义一个方法,用于获取截断后的表格行
    def _get_truncated_table_rows(
        self,
        query_tokens: List[str],  # 查询的分词结果列表
        tokenized_table: TokenizedTable,  # 已分词的表格数据
        num_rows: int,  # 需要获取的行数
        num_columns: int,  # 表格的列数
        max_length: int,  # 最大长度限制
        truncation_strategy: Union[str, TapasTruncationStrategy],  # 截断策略,可以是字符串或 Tapas 截断策略对象
    ) -> Tuple[int, int]:
        """
        Truncates a sequence pair in-place following the strategy.

        Args:
            query_tokens (`List[str]`):
                List of strings corresponding to the tokenized query.
            tokenized_table (`TokenizedTable`):
                Tokenized table object representing the table.
            num_rows (`int`):
                Total number of rows in the table.
            num_columns (`int`):
                Total number of columns in the table.
            max_length (`int`):
                Maximum length constraint for the sequence pair.
            truncation_strategy (`str` or [`TapasTruncationStrategy`]):
                Truncation strategy to use. Only supports `"drop_rows_to_fit"` strategy.

        Returns:
            `Tuple[int, int]`: Tuple containing the number of rows after truncation and the number of tokens available
            for each table element.
        """
        # Ensure `truncation_strategy` is an instance of `TapasTruncationStrategy`
        if not isinstance(truncation_strategy, TapasTruncationStrategy):
            truncation_strategy = TapasTruncationStrategy(truncation_strategy)

        # Set `max_length` to default `self.model_max_length` if not provided
        if max_length is None:
            max_length = self.model_max_length

        # Implement truncation strategy: 'drop_rows_to_fit'
        if truncation_strategy == TapasTruncationStrategy.DROP_ROWS_TO_FIT:
            while True:
                # Calculate maximum number of tokens that can fit the table
                num_tokens = self._get_max_num_tokens(
                    query_tokens, tokenized_table, num_rows=num_rows, num_columns=num_columns, max_length=max_length
                )

                # If tokens fit the table, exit loop
                if num_tokens is not None:
                    # We could fit the table.
                    break

                # Attempt to drop a row to fit the table within the length constraint
                num_rows -= 1

                # Exit loop if no rows can be dropped further
                if num_rows < 1:
                    break
        elif truncation_strategy != TapasTruncationStrategy.DO_NOT_TRUNCATE:
            # Raise error if an unknown truncation strategy is provided
            raise ValueError(f"Unknown truncation strategy {truncation_strategy}.")

        # Return the number of rows after truncation and the number of tokens available,
        # ensuring at least 1 token is available if `num_tokens` is None
        return num_rows, num_tokens or 1

    def _tokenize_table(
        self,
        table=None,
    ):
        """
        Tokenizes column headers and cell texts of a table.

        Args:
            table (`pd.Dataframe`):
                Table to tokenize. Returns: `TokenizedTable`: TokenizedTable object.
        """
        tokenized_rows = []
        tokenized_row = []
        # tokenize column headers
        for column in table:
            # Check if column names should be stripped before tokenization
            if self.strip_column_names:
                # Tokenize an empty string for stripped column names
                tokenized_row.append(self.tokenize(""))
            else:
                # Tokenize the column name
                tokenized_row.append(self.tokenize(column))
        # Add tokenized column headers to the list of tokenized rows
        tokenized_rows.append(tokenized_row)

        # tokenize cell values
        for idx, row in table.iterrows():
            tokenized_row = []
            for cell in row:
                # Tokenize each cell value
                tokenized_row.append(self.tokenize(cell))
            # Add tokenized row to the list of tokenized rows
            tokenized_rows.append(tokenized_row)

        token_coordinates = []
        # Create token coordinates for each token in the tokenized table
        for row_index, row in enumerate(tokenized_rows):
            for column_index, cell in enumerate(row):
                for token_index, _ in enumerate(cell):
                    # Append token coordinates to the list
                    token_coordinates.append(
                        TokenCoordinates(
                            row_index=row_index,
                            column_index=column_index,
                            token_index=token_index,
                        )
                    )

        # Return a TokenizedTable object containing tokenized rows and token coordinates
        return TokenizedTable(
            rows=tokenized_rows,
            selected_tokens=token_coordinates,
        )

    def _question_encoding_cost(self, question_tokens):
        # Calculate the encoding cost for a question, including two extra tokens for SEP and CLS
        return len(question_tokens) + 2

    def _get_token_budget(self, question_tokens, max_length=None):
        """
        Computes the number of tokens left for the table after tokenizing a question, taking into account the max
        sequence length of the model.

        Args:
            question_tokens (`List[String]`):
                List of tokens representing the question. Returns: `int`: the number of tokens left for the table,
                given the model max length.
        """
        # Determine the remaining token budget for the table after encoding the question
        return (max_length if max_length is not None else self.model_max_length) - self._question_encoding_cost(
            question_tokens
        )
    def _get_table_values(self, table, num_columns, num_rows, num_tokens) -> Generator[TableValue, None, None]:
        """Iterates over partial table and returns token, column and row indexes."""
        # 遍历选定的表格中的令牌
        for tc in table.selected_tokens:
            # 第一行是表头行,跳过
            if tc.row_index >= num_rows + 1:
                continue
            # 如果列索引超过指定的列数,跳过
            if tc.column_index >= num_columns:
                continue
            # 获取表格中指定位置的单元格内容
            cell = table.rows[tc.row_index][tc.column_index]
            # 获取单元格中指定的令牌
            token = cell[tc.token_index]
            word_begin_index = tc.token_index
            # 不添加部分单词。查找起始词片段并检查是否符合令牌预算。
            while word_begin_index >= 0 and _is_inner_wordpiece(cell[word_begin_index]):
                word_begin_index -= 1
            # 如果起始词片段超过指定的令牌数量,跳过
            if word_begin_index >= num_tokens:
                continue
            # 返回表格中的值,包括令牌、列索引加一、行索引
            yield TableValue(token, tc.column_index + 1, tc.row_index)

    def _get_table_boundaries(self, table):
        """Return maximal number of rows, columns and tokens."""
        # 初始化最大的行数、列数和令牌数
        max_num_tokens = 0
        max_num_columns = 0
        max_num_rows = 0
        # 遍历选定的表格中的令牌
        for tc in table.selected_tokens:
            # 更新最大的列数、行数和令牌数
            max_num_columns = max(max_num_columns, tc.column_index + 1)
            max_num_rows = max(max_num_rows, tc.row_index + 1)
            max_num_tokens = max(max_num_tokens, tc.token_index + 1)
            # 确保最大的列数和行数不超过预设的最大值
            max_num_columns = min(self.max_column_id, max_num_columns)
            max_num_rows = min(self.max_row_id, max_num_rows)
        # 返回最大的行数、列数和令牌数
        return max_num_rows, max_num_columns, max_num_tokens

    def _get_table_cost(self, table, num_columns, num_rows, num_tokens):
        # 计算使用指定令牌数量时的表格代价
        return sum(1 for _ in self._get_table_values(table, num_columns, num_rows, num_tokens))

    def _get_max_num_tokens(self, question_tokens, tokenized_table, num_columns, num_rows, max_length):
        """Computes max number of tokens that can be squeezed into the budget."""
        # 获取问题令牌的预算
        token_budget = self._get_token_budget(question_tokens, max_length)
        # 获取表格的行数、列数和最大的令牌数
        _, _, max_num_tokens = self._get_table_boundaries(tokenized_table)
        # 如果单元格修剪长度大于等于零且最大令牌数超过单元格修剪长度,则将最大令牌数设为单元格修剪长度
        if self.cell_trim_length >= 0 and max_num_tokens > self.cell_trim_length:
            max_num_tokens = self.cell_trim_length
        num_tokens = 0
        # 遍历最大令牌数加一的范围
        for num_tokens in range(max_num_tokens + 1):
            # 计算使用指定令牌数量时的表格代价
            cost = self._get_table_cost(tokenized_table, num_columns, num_rows, num_tokens + 1)
            # 如果代价超过了令牌预算,停止遍历
            if cost > token_budget:
                break
        # 如果使用的令牌数小于最大令牌数
        if num_tokens < max_num_tokens:
            # 如果单元格修剪长度大于等于零,则不允许动态修剪
            if self.cell_trim_length >= 0:
                return None
            # 如果使用的令牌数为零,则返回空
            if num_tokens == 0:
                return None
        # 返回可使用的最大令牌数
        return num_tokens

    def _get_num_columns(self, table):
        # 获取表格的列数
        num_columns = table.shape[1]
        # 如果列数超过预设的最大列数,则抛出数值错误异常
        if num_columns >= self.max_column_id:
            raise ValueError("Too many columns")
        # 返回表格的列数
        return num_columns
    def _get_num_rows(self, table, drop_rows_to_fit):
        # 获取表格的行数
        num_rows = table.shape[0]
        # 如果行数超过最大允许的行数
        if num_rows >= self.max_row_id:
            # 如果允许删除超出部分的行
            if drop_rows_to_fit:
                # 将行数调整为最大允许行数减一
                num_rows = self.max_row_id - 1
            else:
                # 否则抛出异常,提示行数过多
                raise ValueError("Too many rows")
        # 返回最终确定的行数
        return num_rows

    def _serialize_text(self, question_tokens):
        """将文本序列化为索引数组。"""
        tokens = []
        segment_ids = []
        column_ids = []
        row_ids = []

        # 在序列化文本开头添加 [CLS] 标记
        tokens.append(self.cls_token)
        segment_ids.append(0)
        column_ids.append(0)
        row_ids.append(0)

        # 遍历问题的每个词汇
        for token in question_tokens:
            tokens.append(token)
            segment_ids.append(0)
            column_ids.append(0)
            row_ids.append(0)

        # 返回序列化后的 tokens, segment_ids, column_ids, row_ids
        return tokens, segment_ids, column_ids, row_ids

    def _serialize(
        self,
        question_tokens,
        table,
        num_columns,
        num_rows,
        num_tokens,
    ):
        """序列化表格和文本。"""
        tokens, segment_ids, column_ids, row_ids = self._serialize_text(question_tokens)

        # 在问题和表格 tokens 之间添加 [SEP] 标记
        tokens.append(self.sep_token)
        segment_ids.append(0)
        column_ids.append(0)
        row_ids.append(0)

        # 获取表格中的每个单元格的值,并添加到序列化结果中
        for token, column_id, row_id in self._get_table_values(table, num_columns, num_rows, num_tokens):
            tokens.append(token)
            segment_ids.append(1)  # 表示这是来自表格的内容
            column_ids.append(column_id)
            row_ids.append(row_id)

        # 返回序列化后的 SerializedExample 对象
        return SerializedExample(
            tokens=tokens,
            segment_ids=segment_ids,
            column_ids=column_ids,
            row_ids=row_ids,
        )

    def _get_column_values(self, table, col_index):
        """获取表格中指定列的数值。"""
        table_numeric_values = {}
        # 遍历表格的每一行
        for row_index, row in table.iterrows():
            cell = row[col_index]
            # 如果单元格的值是数值类型,则加入到结果字典中
            if cell.numeric_value is not None:
                table_numeric_values[row_index] = cell.numeric_value
        # 返回包含数值的字典
        return table_numeric_values

    def _get_cell_token_indexes(self, column_ids, row_ids, column_id, row_id):
        """获取特定列和行索引对应的 token 索引。"""
        # 遍历所有 token 的索引
        for index in range(len(column_ids)):
            # 如果找到与指定列和行索引对应的 token 索引
            if column_ids[index] - 1 == column_id and row_ids[index] - 1 == row_id:
                # 返回该 token 索引
                yield index
    def _get_numeric_column_ranks(self, column_ids, row_ids, table):
        """Returns column ranks for all numeric columns."""

        # 初始化列的排名和反向排名的列表,长度为列的数量
        ranks = [0] * len(column_ids)
        inv_ranks = [0] * len(column_ids)

        # 如果表格对象不为空
        if table is not None:
            # 遍历表格的所有列
            for col_index in range(len(table.columns)):
                # 获取当前列的所有数值
                table_numeric_values = self._get_column_values(table, col_index)

                # 如果当前列没有数值则跳过
                if not table_numeric_values:
                    continue

                try:
                    # 获取用于排序数值的函数
                    key_fn = get_numeric_sort_key_fn(table_numeric_values.values())
                except ValueError:
                    # 如果获取排序函数时发生错误则跳过当前列
                    continue

                # 将当前列的数值转换为排序后的字典形式
                table_numeric_values = {row_index: key_fn(value) for row_index, value in table_numeric_values.items()}

                # 创建一个反向映射字典,将数值映射到行索引的列表
                table_numeric_values_inv = collections.defaultdict(list)
                for row_index, value in table_numeric_values.items():
                    table_numeric_values_inv[value].append(row_index)

                # 对唯一的数值进行排序
                unique_values = sorted(table_numeric_values_inv.keys())

                # 根据数值的排名为每个单元格设置排名和反向排名
                for rank, value in enumerate(unique_values):
                    for row_index in table_numeric_values_inv[value]:
                        for index in self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index):
                            ranks[index] = rank + 1
                            inv_ranks[index] = len(unique_values) - rank

        # 返回列的排名和反向排名列表
        return ranks, inv_ranks

    def _get_numeric_sort_key_fn(self, table_numeric_values, value):
        """
        Returns the sort key function for comparing value to table values. The function returned will be a suitable
        input for the key param of the sort(). See number_annotation_utils._get_numeric_sort_key_fn for details

        Args:
            table_numeric_values: Numeric values of a column
            value: Numeric value in the question

        Returns:
            A function key function to compare column and question values.
        """
        # 如果表格数值为空,则返回 None
        if not table_numeric_values:
            return None
        # 将所有列的数值放入一个列表,并加入当前问题的数值
        all_values = list(table_numeric_values.values())
        all_values.append(value)
        try:
            # 获取所有数值的排序函数
            return get_numeric_sort_key_fn(all_values)
        except ValueError:
            # 如果获取排序函数时发生错误,则返回 None
            return None
    # 返回数值关系的嵌入

    # 创建一个字典,将表格单元格映射到其与问题中任何值的所有关系的集合
    cell_indices_to_relations = collections.defaultdict(set)
    
    # 如果问题和表格都不为空,则处理数值值跨度并添加到问题中
    if question is not None and table is not None:
        for numeric_value_span in question.numeric_spans:
            for value in numeric_value_span.values:
                for column_index in range(len(table.columns)):
                    # 获取该列的所有数值
                    table_numeric_values = self._get_column_values(table, column_index)
                    # 获取排序键函数
                    sort_key_fn = self._get_numeric_sort_key_fn(table_numeric_values, value)
                    if sort_key_fn is None:
                        continue
                    # 遍历每个单元格的数值,并确定数值关系
                    for row_index, cell_value in table_numeric_values.items():
                        relation = get_numeric_relation(value, cell_value, sort_key_fn)
                        if relation is not None:
                            cell_indices_to_relations[column_index, row_index].add(relation)

    # 为每个单元格的所有词片段添加一个特殊特征
    for (column_index, row_index), relations in cell_indices_to_relations.items():
        relation_set_index = 0
        for relation in relations:
            # 确保关系值大于等于Relation.EQ的值
            assert relation.value >= Relation.EQ.value
            relation_set_index += 2 ** (relation.value - Relation.EQ.value)
        # 获取单元格词片段的索引并设置数值关系
        for cell_token_index in self._get_cell_token_indexes(column_ids, row_ids, column_index, row_index):
            numeric_relations[cell_token_index] = relation_set_index

    # 返回计算得到的数值关系列表
    return numeric_relations
    # 返回用于计算答案损失的数值列表
    def _get_numeric_values(self, table, column_ids, row_ids):
        numeric_values = [float("nan")] * len(column_ids)  # 初始化一个长度为列数的数值列表,初始值为 NaN

        if table is not None:
            num_rows = table.shape[0]  # 获取表格的行数
            num_columns = table.shape[1]  # 获取表格的列数

            # 遍历表格的每一列和每一行
            for col_index in range(num_columns):
                for row_index in range(num_rows):
                    numeric_value = table.iloc[row_index, col_index].numeric_value  # 获取指定单元格的数值
                    if numeric_value is not None:
                        if numeric_value.float_value is None:
                            continue
                        float_value = numeric_value.float_value  # 获取数值的浮点值
                        if float_value == float("inf"):  # 如果浮点值为无穷大,则跳过
                            continue
                        # 获取当前单元格对应的 token 索引,并将数值赋给对应索引的数值列表
                        for index in self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index):
                            numeric_values[index] = float_value

        return numeric_values

    # 返回一个用于降低长单词价值的每个 token 的缩放比例列表
    def _get_numeric_values_scale(self, table, column_ids, row_ids):
        numeric_values_scale = [1.0] * len(column_ids)  # 初始化一个长度为列数的缩放比例列表,初始值为 1.0

        if table is None:
            return numeric_values_scale  # 如果表格为空,则直接返回初始的缩放比例列表

        num_rows = table.shape[0]  # 获取表格的行数
        num_columns = table.shape[1]  # 获取表格的列数

        # 遍历表格的每一列和每一行
        for col_index in range(num_columns):
            for row_index in range(num_rows):
                indices = list(self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index))  # 获取单元格对应的 token 索引列表
                num_indices = len(indices)  # 获取 token 索引列表的长度
                if num_indices > 1:
                    # 如果单元格对应的 token 索引数量大于 1,则将缩放比例设置为索引的数量
                    for index in indices:
                        numeric_values_scale[index] = float(num_indices)

        return numeric_values_scale

    # 将输入列表填充到模型最大长度
    def _pad_to_seq_length(self, inputs):
        while len(inputs) > self.model_max_length:  # 当输入列表长度超过模型最大长度时
            inputs.pop()  # 移除末尾的元素
        while len(inputs) < self.model_max_length:  # 当输入列表长度小于模型最大长度时
            inputs.append(0)  # 在末尾添加值为 0 的元素

    # 根据答案坐标获取所有答案的 token 索引列表和缺失答案数量
    def _get_all_answer_ids_from_coordinates(
        self,
        column_ids,
        row_ids,
        answers_list,
    ):
        """Maps lists of answer coordinates to token indexes."""
        answer_ids = [0] * len(column_ids)  # 初始化一个长度为列数的答案 ID 列表,初始值为 0
        found_answers = set()  # 用于存储已找到的答案坐标的集合
        all_answers = set()  # 用于存储所有答案坐标的集合

        for answers in answers_list:  # 遍历答案坐标列表
            column_index, row_index = answers  # 获取列索引和行索引
            all_answers.add((column_index, row_index))  # 将答案坐标添加到所有答案集合中
            for index in self._get_cell_token_indexes(column_ids, row_ids, column_index, row_index):
                # 获取答案坐标对应的 token 索引,并将答案标记为已找到
                found_answers.add((column_index, row_index))
                answer_ids[index] = 1  # 将答案对应的 token 索引位置设置为 1,表示找到了答案

        missing_count = len(all_answers) - len(found_answers)  # 计算未找到的答案数量
        return answer_ids, missing_count  # 返回答案 ID 列表和未找到的答案数量
    def _get_all_answer_ids(self, column_ids, row_ids, answer_coordinates):
        """
        Maps answer coordinates of a question to token indexes.

        In the SQA format (TSV), the coordinates are given as (row, column) tuples. Here, we first swap them to
        (column, row) format before calling _get_all_answer_ids_from_coordinates.
        """

        def _to_coordinates(answer_coordinates_question):
            # 转换答案坐标格式为 (column, row) 形式
            return [(coords[1], coords[0]) for coords in answer_coordinates_question]

        # 调用 _get_all_answer_ids_from_coordinates 方法,传入调整后的答案坐标
        return self._get_all_answer_ids_from_coordinates(
            column_ids, row_ids, answers_list=(_to_coordinates(answer_coordinates))
        )

    def _find_tokens(self, text, segment):
        """Return start index of segment in text or None."""
        # 记录调试信息,输出文本和查找的段落
        logging.info(f"text: {text} {segment}")
        # 在文本中查找段落的起始索引
        for index in range(1 + len(text) - len(segment)):
            for seg_index, seg_token in enumerate(segment):
                # 如果当前位置的字符与段落不匹配,则终止此次匹配
                if text[index + seg_index].piece != seg_token.piece:
                    break
            else:
                # 如果完全匹配,则返回段落在文本中的起始索引
                return index
        # 如果未找到匹配的段落,则返回 None
        return None

    def _find_answer_coordinates_from_answer_text(
        self,
        tokenized_table,
        answer_text,
    ):
        """Returns all occurrences of answer_text in the table."""
        # 记录调试信息,输出答案文本
        logging.info(f"answer text: {answer_text}")
        # 遍历表格的每一行和每一列,寻找答案文本的位置
        for row_index, row in enumerate(tokenized_table.rows):
            if row_index == 0:
                # 跳过表头行,不在表头中搜索答案
                continue
            for col_index, cell in enumerate(row):
                # 在单元格中查找答案文本的 token 索引
                token_index = self._find_tokens(cell, answer_text)
                if token_index is not None:
                    # 如果找到匹配的答案文本,则生成对应的 token 坐标
                    yield TokenCoordinates(
                        row_index=row_index,
                        column_index=col_index,
                        token_index=token_index,
                    )

    def _find_answer_ids_from_answer_texts(
        self,
        column_ids,
        row_ids,
        tokenized_table,
        answer_texts,
    ):
        """
        Returns answer IDs corresponding to given answer texts in a tokenized table.

        This function iterates through provided answer texts, finds their token positions in the tokenized table,
        and yields corresponding answer IDs based on column and row IDs.
        """
        # 循环遍历每个答案文本,查找其在 token 化表格中的位置,并返回对应的答案 ID
        for answer_text in answer_texts:
            for token_coord in self._find_answer_coordinates_from_answer_text(tokenized_table, answer_text):
                yield (column_ids[token_coord.column_index], row_ids[token_coord.row_index])
    ):
        """
        Maps question with answer texts to the first matching token indexes.
        """
        answer_ids = [0] * len(column_ids)
        for answer_text in answer_texts:
            for coordinates in self._find_answer_coordinates_from_answer_text(
                tokenized_table,
                answer_text,
            ):
                # Maps answer coordinates to indexes; this can fail if tokens/rows have
                # been pruned.
                indexes = list(
                    self._get_cell_token_indexes(
                        column_ids,
                        row_ids,
                        column_id=coordinates.column_index,
                        row_id=coordinates.row_index - 1,
                    )
                )
                indexes.sort()
                coordinate_answer_ids = []
                if indexes:
                    begin_index = coordinates.token_index + indexes[0]
                    end_index = begin_index + len(answer_text)
                    for index in indexes:
                        if index >= begin_index and index < end_index:
                            coordinate_answer_ids.append(index)
                if len(coordinate_answer_ids) == len(answer_text):
                    for index in coordinate_answer_ids:
                        answer_ids[index] = 1
                    break
        return answer_ids

    def _get_answer_ids(self, column_ids, row_ids, answer_coordinates):
        """
        Maps answer coordinates of a question to token indexes.
        """
        answer_ids, missing_count = self._get_all_answer_ids(column_ids, row_ids, answer_coordinates)

        if missing_count:
            raise ValueError("Couldn't find all answers")
        return answer_ids

    def get_answer_ids(self, column_ids, row_ids, tokenized_table, answer_texts_question, answer_coordinates_question):
        """
        Retrieves answer IDs based on whether to update answer coordinates or not.
        """
        if self.update_answer_coordinates:
            return self._find_answer_ids_from_answer_texts(
                column_ids,
                row_ids,
                tokenized_table,
                answer_texts=[self.tokenize(at) for at in answer_texts_question],
            )
        return self._get_answer_ids(column_ids, row_ids, answer_coordinates_question)

    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ):
        """
        Handles padding of encoded inputs according to specified strategies.
        """
        # Everything related to converting logits to predictions

    def _get_cell_token_probs(self, probabilities, segment_ids, row_ids, column_ids):
        """
        Yields token probabilities for cell tokens based on conditions.
        """
        for i, p in enumerate(probabilities):
            segment_id = segment_ids[i]
            col = column_ids[i] - 1
            row = row_ids[i] - 1
            if col >= 0 and row >= 0 and segment_id == 1:
                yield i, p
    # 计算每个单元格的平均概率,根据标记的概率值进行聚合计算
    def _get_mean_cell_probs(self, probabilities, segment_ids, row_ids, column_ids):
        """Computes average probability per cell, aggregating over tokens."""
        # 使用默认字典存储坐标对应的概率列表
        coords_to_probs = collections.defaultdict(list)
        # 遍历获取每个单元格中的标记概率
        for i, prob in self._get_cell_token_probs(probabilities, segment_ids, row_ids, column_ids):
            # 获取单元格所在列和行,将其从1-based转换为0-based
            col = column_ids[i] - 1
            row = row_ids[i] - 1
            # 将概率添加到坐标对应的概率列表中
            coords_to_probs[(col, row)].append(prob)
        # 计算每个坐标对应的单元格概率的平均值,并返回结果字典
        return {coords: np.array(cell_probs).mean() for coords, cell_probs in coords_to_probs.items()}
    
    # 转换逻辑值到预测结果的所有相关内容结束
# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
class BasicTokenizer(object):
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
            the full context of the words, such as contractions.
    """

    def __init__(
        self,
        do_lower_case=True,
        never_split=None,
        tokenize_chinese_chars=True,
        strip_accents=None,
        do_split_on_punc=True,
    ):
        # 如果 `never_split` 为 None,则初始化为空列表
        if never_split is None:
            never_split = []
        # 设置是否将输入文本转换为小写
        self.do_lower_case = do_lower_case
        # 将 `never_split` 转换为集合,用于存储不需要分割的特殊标记
        self.never_split = set(never_split)
        # 设置是否对中文字符进行单独的分词处理
        self.tokenize_chinese_chars = tokenize_chinese_chars
        # 设置是否去除所有的重音符号
        self.strip_accents = strip_accents
        # 设置是否进行基本的标点符号分割
        self.do_split_on_punc = do_split_on_punc
    # 对文本进行基本的分词处理。如果需要子词分词,请参考WordPieceTokenizer。
    # 
    # Args:
    #     never_split (`List[str]`, *optional*)
    #         为了向后兼容保留。现在直接在基类级别实现(参见`PreTrainedTokenizer.tokenize`)不分割的标记列表。
    def tokenize(self, text, never_split=None):
        # 如果传入了never_split参数,则将其与self.never_split合并为一个新的集合
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        # 清理文本,去除可能存在的特殊字符
        text = self._clean_text(text)

        # 以下代码块是为了处理多语言和中文模型而添加的,自2018年11月1日起。
        # 现在英语模型也适用,尽管由于英语模型未经过任何中文数据的训练,
        # 并且通常不包含任何中文数据(因为维基百科在英语版本中确实包含一些中文词汇)。
        if self.tokenize_chinese_chars:
            # 如果启用了中文字符分词,则调用内部方法_tokenize_chinese_chars处理文本
            text = self._tokenize_chinese_chars(text)
        
        # 使用Unicode NFC规范化文本,确保统一表示同一字符
        unicode_normalized_text = unicodedata.normalize("NFC", text)
        # 使用whitespace_tokenize对文本进行空白字符分割,获取原始token列表
        orig_tokens = whitespace_tokenize(unicode_normalized_text)
        split_tokens = []
        
        # 遍历原始token列表,处理每个token
        for token in orig_tokens:
            # 如果token不在never_split中,则可能需要进一步处理
            if token not in never_split:
                if self.do_lower_case:
                    # 如果需要小写化处理,则将token转换为小写
                    token = token.lower()
                    # 如果需要去除重音符号,则调用_run_strip_accents方法处理token
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    # 如果仅需要去除重音符号,则调用_run_strip_accents方法处理token
                    token = self._run_strip_accents(token)
            # 将处理后的token列表拼接到split_tokens中
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        # 将split_tokens中的token用空白字符连接成字符串,再进行空白字符分割,获取最终输出的token列表
        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    # 从文本中去除重音符号
    def _run_strip_accents(self, text):
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            # 如果字符的Unicode类别为Mn(Mark, Nonspacing),则跳过该字符
            if cat == "Mn":
                continue
            # 将不含重音符号的字符加入到output列表中
            output.append(char)
        # 将output列表中的字符连接成字符串并返回
        return "".join(output)
    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        # 如果不需要在标点符号处分割,或者指定的文本在never_split列表中,则返回原始文本列表
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                # 如果是标点符号,创建一个新列表存储当前标点符号
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    # 如果是新词的开始,创建一个空列表
                    output.append([])
                start_new_word = False
                # 将当前字符添加到当前词的列表中
                output[-1].append(char)
            i += 1

        # 将列表中的子列表转换为字符串并返回
        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                # 如果是中文字符,则在其前后添加空格
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        # 将列表中的字符连接成字符串并返回
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # 检查给定的码点是否是CJK字符的码点范围内
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                # 如果是空白字符,则替换为单个空格
                output.append(" ")
            else:
                output.append(char)
        # 将列表中的字符连接成字符串并返回
        return "".join(output)
# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
# WordpieceTokenizer 类,用于运行 WordPiece 分词算法。

class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""
    # 初始化 WordpieceTokenizer 类
    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        # 初始化词汇表、未知标记和每个单词最大输入字符数
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    # 对文本进行 WordPiece 分词处理
    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens.
        """
        # 初始化输出的 token 列表
        output_tokens = []
        # 使用 whitespace_tokenize 函数对文本进行分词
        for token in whitespace_tokenize(text):
            chars = list(token)
            # 如果单词长度超过设定的最大字符数,则使用未知标记代替
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            # 贪婪算法,尝试寻找最长匹配的词片段
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    # 对非首字符的片段加上 '##' 前缀,表示连接词的一部分
                    if start > 0:
                        substr = "##" + substr
                    # 如果片段在词汇表中,则认为是一个有效的词片段
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                # 如果没有找到匹配的词片段,则将该 token 标记为未知标记
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            # 如果标记为无效,则使用未知标记代替
            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        # 返回最终的 wordpiece tokens 列表
        return output_tokens


# Below: utilities for TAPAS tokenizer (independent from PyTorch/Tensorflow).
# This includes functions to parse numeric values (dates and numbers) from both the table and questions in order
# to create the column_ranks, inv_column_ranks, numeric_values, numeric values_scale and numeric_relations in
# prepare_for_model of TapasTokenizer.
# These are meant to be used in an academic setup, for production use cases Gold mine or Aqua should be used.


# taken from constants.py of the original implementation
# URL: https://github.com/google-research/tapas/blob/master/tapas/utils/constants.py
# 定义了不同类型的关系,用于在表格处理中连接不同的元素

class Relation(enum.Enum):
    HEADER_TO_CELL = 1  # 连接表头到单元格
    CELL_TO_HEADER = 2  # 连接单元格到表头
    QUERY_TO_HEADER = 3  # 连接查询到表头
    QUERY_TO_CELL = 4  # 连接查询到单元格
    ROW_TO_CELL = 5  # 连接行到单元格
    CELL_TO_ROW = 6  # 连接单元格到行
    EQ = 7  # 标注值等于单元格值
    # 定义常量 LT,表示注释值小于单元格值
    LT = 8  # Annotation value is less than cell value
    
    # 定义常量 GT,表示注释值大于单元格值
    GT = 9  # Annotation value is greater than cell value
# 使用 dataclass 装饰器定义日期类,支持可选的年、月、日属性
@dataclass
class Date:
    year: Optional[int] = None
    month: Optional[int] = None
    day: Optional[int] = None

# 使用 dataclass 装饰器定义数值类,支持可选的浮点值和日期属性
@dataclass
class NumericValue:
    float_value: Optional[float] = None
    date: Optional[Date] = None

# 使用 dataclass 装饰器定义数值区间类,包含开始和结束索引以及数值列表属性
@dataclass
class NumericValueSpan:
    begin_index: int = None
    end_index: int = None
    values: List[NumericValue] = None

# 使用 dataclass 装饰器定义单元格类,包含文本和可选的数值属性
@dataclass
class Cell:
    text: Text
    numeric_value: Optional[NumericValue] = None

# 使用 dataclass 装饰器定义问题类,包含原始文本、归一化后的文本和可选的数值区间列表属性
@dataclass
class Question:
    original_text: Text  # 原始问题字符串
    text: Text  # 归一化后的问题字符串
    numeric_spans: Optional[List[NumericValueSpan]] = None

# 下面是从 number_utils.py 中导入的所有函数以及从 text_utils.py 中导入的两个函数(即 get_all_spans 和 normalize_for_match)
# 原始实现的 URL 可查阅:
# - https://github.com/google-research/tapas/blob/master/tapas/utils/number_utils.py
# - https://github.com/google-research/tapas/blob/master/tapas/utils/text_utils.py

# 用于解析日期表达式的常量
# 命名元组 _DateMask 指定了哪些字段(年、月、日)将被填充
_DateMask = collections.namedtuple("_DateMask", ["year", "month", "day"])

# 常量 _YEAR 表示只填充年份
_YEAR = _DateMask(True, False, False)

# 常量 _YEAR_MONTH 表示填充年份和月份
_YEAR_MONTH = _DateMask(True, True, False)

# 常量 _YEAR_MONTH_DAY 表示填充年份、月份和日期
_YEAR_MONTH_DAY = _DateMask(True, True, True)

# 常量 _MONTH 表示只填充月份
_MONTH = _DateMask(False, True, False)

# 常量 _MONTH_DAY 表示填充月份和日期
_MONTH_DAY = _DateMask(False, True, True)

# _DATE_PATTERNS 是一个元组,每个元素包含一个日期格式和一个对应的 _DateMask,用于 datetime.strptime 的参数
_DATE_PATTERNS = (
    ("%B", _MONTH),
    ("%Y", _YEAR),
    ("%Ys", _YEAR),
    ("%b %Y", _YEAR_MONTH),
    ("%B %Y", _YEAR_MONTH),
    ("%B %d", _MONTH_DAY),
    ("%b %d", _MONTH_DAY),
    ("%d %b", _MONTH_DAY),
    ("%d %B", _MONTH_DAY),
    ("%B %d, %Y", _YEAR_MONTH_DAY),
    ("%d %B %Y", _YEAR_MONTH_DAY),
    ("%m-%d-%Y", _YEAR_MONTH_DAY),
    ("%Y-%m-%d", _YEAR_MONTH_DAY),
    ("%Y-%m", _YEAR_MONTH),
    ("%B %Y", _YEAR_MONTH),
    ("%d %b %Y", _YEAR_MONTH_DAY),
    ("%Y-%m-%d", _YEAR_MONTH_DAY),
    ("%b %d, %Y", _YEAR_MONTH_DAY),
    ("%d.%m.%Y", _YEAR_MONTH_DAY),
    ("%A, %b %d", _MONTH_DAY),
    ("%A, %B %d", _MONTH_DAY),
)

# _FIELD_TO_REGEX 是一个元组,每个元素包含一个日期格式和一个对应的正则表达式,用于将日期格式转换为正则表达式
_FIELD_TO_REGEX = (
    ("%A", r"\w+"),    # 本地化全名的星期几
    ("%B", r"\w+"),    # 本地化全名的月份
    ("%Y", r"\d{4}"),  # 带世纪的年份作为十进制数
    ("%b", r"\w{3}"),  # 本地化缩写的月份
    ("%d", r"\d{1,2}"),  # 月份中的天数,作为零填充的十进制数
    ("%m", r"\d{1,2}"),  # 月份作为零填充的十进制数
)

def _process_date_pattern(dp):
    """为每个日期模式计算一个正则表达式作为预过滤器。"""
    pattern, mask = dp
    regex = pattern
    regex = regex.replace(".", re.escape("."))  # 转义点号
    regex = regex.replace("-", re.escape("-"))  # 转义破折号
    regex = regex.replace(" ", r"\s+")  # 替换空格为匹配任意空白字符的正则表达式
    # 遍历 `_FIELD_TO_REGEX` 列表中的每个元素,元素包含字段名和对应的正则表达式
    for field, field_regex in _FIELD_TO_REGEX:
        # 替换当前正则表达式 `regex` 中的字段名 `field` 为对应的字段正则表达式 `field_regex`
        regex = regex.replace(field, field_regex)
    
    # 断言检查,确保替换后的 `regex` 中不包含 `%` 符号,否则输出当前的 `regex`
    assert "%" not in regex, regex
    
    # 返回编译后的模式 `pattern`、掩码 `mask` 和以 `regex` 开头和结尾的编译后的正则表达式对象
    return pattern, mask, re.compile("^" + regex + "$")
def _process_date_patterns():
    # 调用 _process_date_pattern 函数处理 _DATE_PATTERNS 中的每个模式,并返回处理后的元组
    return tuple(_process_date_pattern(dp) for dp in _DATE_PATTERNS)


_PROCESSED_DATE_PATTERNS = _process_date_patterns()

_MAX_DATE_NGRAM_SIZE = 5

# Following DynSp:
# https://github.com/Microsoft/DynSP/blob/master/util.py#L414.
_NUMBER_WORDS = [
    "zero",     # 数字 0 对应的英文单词
    "one",      # 数字 1 对应的英文单词
    "two",      # 数字 2 对应的英文单词
    "three",    # 数字 3 对应的英文单词
    "four",     # 数字 4 对应的英文单词
    "five",     # 数字 5 对应的英文单词
    "six",      # 数字 6 对应的英文单词
    "seven",    # 数字 7 对应的英文单词
    "eight",    # 数字 8 对应的英文单词
    "nine",     # 数字 9 对应的英文单词
    "ten",      # 数字 10 对应的英文单词
    "eleven",   # 数字 11 对应的英文单词
    "twelve",   # 数字 12 对应的英文单词
]

_ORDINAL_WORDS = [
    "zeroth",    # 序数 0 对应的英文单词
    "first",     # 序数 1 对应的英文单词
    "second",    # 序数 2 对应的英文单词
    "third",     # 序数 3 对应的英文单词
    "fourth",    # 序数 4 对应的英文单词
    "fith",      # 序数 5 对应的英文单词 (可能应为 fifth)
    "sixth",     # 序数 6 对应的英文单词
    "seventh",   # 序数 7 对应的英文单词
    "eighth",    # 序数 8 对应的英文单词
    "ninth",     # 序数 9 对应的英文单词
    "tenth",     # 序数 10 对应的英文单词
    "eleventh",  # 序数 11 对应的英文单词
    "twelfth",   # 序数 12 对应的英文单词
]

_ORDINAL_SUFFIXES = ["st", "nd", "rd", "th"]  # 各种序数的后缀列表

_NUMBER_PATTERN = re.compile(r"((^|\s)[+-])?((\.\d+)|(\d+(,\d\d\d)*(\.\d*)?))")
# 匹配简单的数值表达式的正则表达式模式,包括正负号、逗号分隔的千位数和小数点数值

# Following DynSp:
# https://github.com/Microsoft/DynSP/blob/master/util.py#L293.
_MIN_YEAR = 1700    # 可接受的最小年份
_MAX_YEAR = 2016    # 可接受的最大年份

_INF = float("INF")  # 无穷大的浮点数表示


def _get_numeric_value_from_date(date, mask):
    """Converts date (datetime Python object) to a NumericValue object with a Date object value."""
    if date.year < _MIN_YEAR or date.year > _MAX_YEAR:
        raise ValueError(f"Invalid year: {date.year}")

    new_date = Date()
    if mask.year:
        new_date.year = date.year
    if mask.month:
        new_date.month = date.month
    if mask.day:
        new_date.day = date.day
    return NumericValue(date=new_date)


def _get_span_length_key(span):
    """Sorts span by decreasing length first and increasing first index second."""
    return span[1] - span[0], -span[0]


def _get_numeric_value_from_float(value):
    """Converts float (Python) to a NumericValue object with a float value."""
    return NumericValue(float_value=value)


# Doesn't parse ordinal expressions such as '18th of february 1655'.
def _parse_date(text):
    """Attempts to format a text as a standard date string (yyyy-mm-dd)."""
    text = re.sub(r"Sept\b", "Sep", text)  # 替换文本中的 "Sept" 为 "Sep"
    for in_pattern, mask, regex in _PROCESSED_DATE_PATTERNS:
        if not regex.match(text):
            continue
        try:
            date = datetime.datetime.strptime(text, in_pattern).date()  # 尝试解析文本为日期对象
        except ValueError:
            continue
        try:
            return _get_numeric_value_from_date(date, mask)  # 转换日期为 NumericValue 对象并返回
        except ValueError:
            continue
    return None


def _parse_number(text):
    """Parses simple cardinal and ordinals numbers."""
    for suffix in _ORDINAL_SUFFIXES:
        if text.endswith(suffix):
            text = text[: -len(suffix)]  # 去除文本末尾的序数后缀
            break
    text = text.replace(",", "")  # 去除文本中的逗号
    try:
        value = float(text)  # 尝试将文本转换为浮点数
    except ValueError:
        return None
    if math.isnan(value):
        return None
    if value == _INF:
        return None
    return value


def get_all_spans(text, max_ngram_length):
    """
    Split a text into all possible ngrams up to 'max_ngram_length'. Split points are white space and punctuation.

    Args:
      text: Text to split.
      max_ngram_length: maximal ngram length.
    """
    # 初始化一个空列表,用于存储起始索引
    start_indexes = []
    # 遍历文本中的每个字符及其索引
    for index, char in enumerate(text):
        # 如果当前字符不是字母或数字,则跳过当前循环,继续下一个字符
        if not char.isalnum():
            continue
        # 如果当前字符是字母或数字,并且满足以下条件之一:
        # 1. 是文本的第一个字符
        # 2. 前一个字符不是字母或数字
        # 则将当前索引添加到起始索引列表中
        if index == 0 or not text[index - 1].isalnum():
            start_indexes.append(index)
        # 如果当前字符是字母或数字,并且满足以下条件之一:
        # 1. 是文本的最后一个字符
        # 2. 后一个字符不是字母或数字
        # 针对起始索引列表中的最后几个元素生成 n-gram 的起始索引和结束索引
        if index + 1 == len(text) or not text[index + 1].isalnum():
            for start_index in start_indexes[-max_ngram_length:]:
                # 返回生成器,生成 n-gram 的起始索引和结束索引(不包含结束索引本身)
                yield start_index, index + 1
# 将文本转换为小写,并去除多余的空格
def normalize_for_match(text):
    return " ".join(text.lower().split())


# 将文本转换为小写并去除标点符号
def format_text(text):
    """Lowercases and strips punctuation."""
    text = text.lower().strip()
    # 如果文本是 "n/a"、"?" 或 "nan",则置为空文本
    if text == "n/a" or text == "?" or text == "nan":
        text = EMPTY_TEXT

    # 使用正则表达式替换非字母数字字符为空格,并将下划线替换为空格
    text = re.sub(r"[^\w\d]+", " ", text).replace("_", " ")
    # 去除多余的空格
    text = " ".join(text.split())
    text = text.strip()
    # 如果处理后的文本非空,则返回处理后的文本;否则返回空文本
    if text:
        return text
    return EMPTY_TEXT


# 解析文本,提取最长的数字值和日期跨度
def parse_text(text):
    """
    Extracts longest number and date spans.

    Args:
      text: text to annotate

    Returns:
      List of longest numeric value spans.
    """
    span_dict = collections.defaultdict(list)

    # 提取所有数字模式的匹配项,并解析成数字
    for match in _NUMBER_PATTERN.finditer(text):
        span_text = text[match.start() : match.end()]
        number = _parse_number(span_text)
        if number is not None:
            # 将解析出的数字值添加到对应位置的列表中
            span_dict[match.span()].append(_get_numeric_value_from_float(number))

    # 提取所有单词长度为1的文本片段,并处理其中的数字和序数词
    for begin_index, end_index in get_all_spans(text, max_ngram_length=1):
        if (begin_index, end_index) in span_dict:
            continue
        span_text = text[begin_index:end_index]

        number = _parse_number(span_text)
        if number is not None:
            span_dict[begin_index, end_index].append(_get_numeric_value_from_float(number))
        
        # 检查是否为数字词或序数词,并将其添加到对应位置的列表中
        for number, word in enumerate(_NUMBER_WORDS):
            if span_text == word:
                span_dict[begin_index, end_index].append(_get_numeric_value_from_float(float(number)))
                break
        for number, word in enumerate(_ORDINAL_WORDS):
            if span_text == word:
                span_dict[begin_index, end_index].append(_get_numeric_value_from_float(float(number)))
                break

    # 提取所有长度不超过_MAX_DATE_NGRAM_SIZE的文本片段,并解析日期
    for begin_index, end_index in get_all_spans(text, max_ngram_length=_MAX_DATE_NGRAM_SIZE):
        span_text = text[begin_index:end_index]
        date = _parse_date(span_text)
        if date is not None:
            span_dict[begin_index, end_index].append(date)

    # 根据片段长度对结果进行排序,从长到短
    spans = sorted(span_dict.items(), key=lambda span_value: _get_span_length_key(span_value[0]), reverse=True)
    selected_spans = []

    # 选择不重叠的最长片段
    for span, value in spans:
        for selected_span, _ in selected_spans:
            if selected_span[0] <= span[0] and span[1] <= selected_span[1]:
                break
        else:
            selected_spans.append((span, value))

    # 根据起始索引排序选定的片段
    selected_spans.sort(key=lambda span_value: span_value[0][0])

    numeric_value_spans = []
    # 创建NumericValueSpan对象并添加到列表中
    for span, values in selected_spans:
        numeric_value_spans.append(NumericValueSpan(begin_index=span[0], end_index=span[1], values=values))
    return numeric_value_spans
# - https://github.com/google-research/tapas/blob/master/tapas/utils/text_utils.py

# 定义基本的数值类型,可以是 float 或包含可选的三个 float 的元组
_PrimitiveNumericValue = Union[float, Tuple[Optional[float], Optional[float], Optional[float]]]

# 定义排序键的函数类型,接受一个 NumericValue 参数,返回一个元组和省略号的 float
_SortKeyFn = Callable[[NumericValue], Tuple[float, Ellipsis]]

# 日期元组的大小
_DATE_TUPLE_SIZE = 3

# 表示空文本的常量
EMPTY_TEXT = "EMPTY"

# 表示数值类型的字符串常量
NUMBER_TYPE = "number"
# 表示日期类型的字符串常量
DATE_TYPE = "date"


def _get_value_type(numeric_value):
    # 根据 NumericValue 的内容返回其类型字符串
    if numeric_value.float_value is not None:
        return NUMBER_TYPE
    elif numeric_value.date is not None:
        return DATE_TYPE
    # 如果无法识别类型,则抛出异常
    raise ValueError(f"Unknown type: {numeric_value}")


def _get_value_as_primitive_value(numeric_value):
    """Maps a NumericValue proto to a float or tuple of float."""
    # 根据 NumericValue 返回其对应的 float 或者包含三个 float 的元组
    if numeric_value.float_value is not None:
        return numeric_value.float_value
    if numeric_value.date is not None:
        date = numeric_value.date
        value_tuple = [None, None, None]
        # 将日期的各个字段转换为 float,构成一个简单的基本数值
        if date.year is not None:
            value_tuple[0] = float(date.year)
        if date.month is not None:
            value_tuple[1] = float(date.month)
        if date.day is not None:
            value_tuple[2] = float(date.day)
        return tuple(value_tuple)
    # 如果无法识别类型,则抛出异常
    raise ValueError(f"Unknown type: {numeric_value}")


def _get_all_types(numeric_values):
    # 返回所有 NumericValue 中的类型集合
    return {_get_value_type(value) for value in numeric_values}


def get_numeric_sort_key_fn(numeric_values):
    """
    Creates a function that can be used as a sort key or to compare the values. Maps to primitive types and finds the
    biggest common subset. Consider the values "05/05/2010" and "August 2007". With the corresponding primitive values
    (2010.,5.,5.) and (2007.,8., None). These values can be compared by year and date so we map to the sequence (2010.,
    5.), (2007., 8.). If we added a third value "2006" with primitive value (2006., None, None), we could only compare
    by the year so we would map to (2010.,), (2007.,) and (2006.,).

    Args:
     numeric_values: Values to compare

    Returns:
     A function that can be used as a sort key function (mapping numeric values to a comparable tuple)

    Raises:
      ValueError if values don't have a common type or are not comparable.
    """
    value_types = _get_all_types(numeric_values)
    # 如果数值的类型不唯一,则抛出异常
    if len(value_types) != 1:
        raise ValueError(f"No common value type in {numeric_values}")

    value_type = next(iter(value_types))
    if value_type == NUMBER_TYPE:
        # 数字类型的原始值是简单的 float,此处无需处理
        return _get_value_as_primitive_value

    # 此时类型只能是日期,意味着原始类型是一个三元组的 float
    valid_indexes = set(range(_DATE_TUPLE_SIZE))
    # 遍历传入的 numeric_values 列表中的每个数值
    for numeric_value in numeric_values:
        # 调用函数 _get_value_as_primitive_value,获取 numeric_value 的原始值
        value = _get_value_as_primitive_value(numeric_value)
        # 断言 value 是一个元组
        assert isinstance(value, tuple)
        # 遍历元组 value 中的每个元素及其索引
        for tuple_index, inner_value in enumerate(value):
            # 如果 inner_value 是 None,则从 valid_indexes 中移除该索引
            if inner_value is None:
                valid_indexes.discard(tuple_index)

    # 如果 valid_indexes 集合为空集,表示没有共同的有效索引
    if not valid_indexes:
        # 抛出 ValueError 异常,指示 numeric_values 中没有共同的有效值
        raise ValueError(f"No common value in {numeric_values}")

    # 定义一个排序关键字函数 _sort_key_fn,接受 numeric_value 作为参数
    def _sort_key_fn(numeric_value):
        # 获取 numeric_value 的原始值
        value = _get_value_as_primitive_value(numeric_value)
        # 返回一个元组,包含 valid_indexes 中索引位置的值
        return tuple(value[index] for index in valid_indexes)

    # 返回排序关键字函数 _sort_key_fn
    return _sort_key_fn
# 对给定的行索引到数值列表的映射进行数值合并
def _consolidate_numeric_values(row_index_to_values, min_consolidation_fraction, debug_info):
    """
    Finds the most common numeric values in a column and returns them

    Args:
        row_index_to_values:
            每个行索引对应的数值列表。
        min_consolidation_fraction:
            需要进行合并的最小比例。
        debug_info:
            仅用于调试的额外信息。

    Returns:
        每个行索引对应的最常见数值的第一个匹配值。没有匹配值的行将被丢弃。如果无法合并值,则返回空列表。
    """
    # 统计不同类型出现的次数
    type_counts = collections.Counter()
    for numeric_values in row_index_to_values.values():
        type_counts.update(_get_all_types(numeric_values))
    
    if not type_counts:
        return {}

    # 找到出现次数最多的类型
    max_count = max(type_counts.values())
    if max_count < len(row_index_to_values) * min_consolidation_fraction:
        # logging.log_every_n(logging.INFO, f'Can\'t consolidate types: {debug_info} {row_index_to_values} {max_count}', 100)
        return {}

    valid_types = set()
    for value_type, count in type_counts.items():
        if count == max_count:
            valid_types.add(value_type)
    
    # 如果有多个最常见的类型,确保 DATE_TYPE 在其中
    if len(valid_types) > 1:
        assert DATE_TYPE in valid_types
        max_type = DATE_TYPE
    else:
        max_type = next(iter(valid_types))

    # 创建新的行索引到数值的映射
    new_row_index_to_value = {}
    for index, values in row_index_to_values.items():
        # 提取第一个匹配的值
        for value in values:
            if _get_value_type(value) == max_type:
                new_row_index_to_value[index] = value
                break

    return new_row_index_to_value


def _get_numeric_values(text):
    """解析文本并返回其中的数值。"""
    numeric_spans = parse_text(text)
    return itertools.chain(*(span.values for span in numeric_spans))


def _get_column_values(table, col_index):
    """
    解析表格中指定列的文本,并返回一个字典,将行索引映射到数值列表。
    这是原始实现中 number_annotation_utils.py 中的 _get_column_values 函数。

    Args:
      table: Pandas dataframe
      col_index: 整数,指示要获取数值的列的索引
    """
    index_to_values = {}
    for row_index, row in table.iterrows():
        text = normalize_for_match(row[col_index].text)
        index_to_values[row_index] = list(_get_numeric_values(text))
    return index_to_values


def get_numeric_relation(value, other_value, sort_key_fn):
    """比较两个值并返回它们的关系或 None。"""
    value = sort_key_fn(value)
    other_value = sort_key_fn(other_value)
    if value == other_value:
        return Relation.EQ
    if value < other_value:
        return Relation.LT
    if value > other_value:
        return Relation.GT
    return None


def add_numeric_values_to_question(question):
    """向问题中添加数值范围。"""
    # 将原始问题文本保存在变量 original_text 中
    original_text = question
    # 对问题文本进行规范化处理,使其适合匹配操作
    question = normalize_for_match(question)
    # 解析处理后的问题文本,提取其中的数值范围信息
    numeric_spans = parse_text(question)
    # 返回一个 Question 对象,包含原始文本、规范化后文本和数值范围信息
    return Question(original_text=original_text, text=question, numeric_spans=numeric_spans)
def filter_invalid_unicode(text):
    """
    检查并过滤无效的 Unicode 编码。
    
    Args:
        text: 要检查的文本。

    Returns:
        若 'text' 是无效的 Unicode,则返回空字符串和 True;否则返回原文本和 False。
    """
    return ("", True) if isinstance(text, bytes) else (text, False)


def filter_invalid_unicode_from_table(table):
    """
    从表格中移除无效的 Unicode 编码。检查表格单元格文本是否包含无效的 Unicode 编码,
    如果是,则将单元格文本重置为空字符串,并为每个无效的单元格记录警告日志。

    Args:
        table: 要清理的表格。
    """
    # to do: add table id support
    if not hasattr(table, "table_id"):
        table.table_id = 0

    for row_index, row in table.iterrows():
        for col_index, cell in enumerate(row):
            cell, is_invalid = filter_invalid_unicode(cell)
            if is_invalid:
                logging.warning(
                    f"Scrub an invalid table body @ table_id: {table.table_id}, row_index: {row_index}, "
                    f"col_index: {col_index}",
                )
    for col_index, column in enumerate(table.columns):
        column, is_invalid = filter_invalid_unicode(column)
        if is_invalid:
            logging.warning(f"Scrub an invalid table header @ table_id: {table.table_id}, col_index: {col_index}")


def add_numeric_table_values(table, min_consolidation_fraction=0.7, debug_info=None):
    """
    逐列解析表格中的文本,并添加合并后的数值。合并是指查找具有共同类型(日期或数字)的值。

    Args:
        table: 要注释的表格。
        min_consolidation_fraction: 列中需要具有合并值的单元格的分数。
        debug_info: 用于记录日志的附加信息。
    
    Returns:
        添加了数值属性的表格副本。
    """
    table = table.copy()
    # 首先,过滤掉表格中的无效 Unicode
    filter_invalid_unicode_from_table(table)

    # 其次,将单元格值替换为 Cell 对象
    for row_index, row in table.iterrows():
        for col_index, cell in enumerate(row):
            table.iloc[row_index, col_index] = Cell(text=cell)

    # 第三,为这些 Cell 对象添加 numeric_value 属性
    for col_index, column in enumerate(table.columns):
        column_values = _consolidate_numeric_values(
            _get_column_values(table, col_index),
            min_consolidation_fraction=min_consolidation_fraction,
            debug_info=(debug_info, column),
        )

        for row_index, numeric_value in column_values.items():
            table.iloc[row_index, col_index].numeric_value = numeric_value

    return table

.\models\tapas\__init__.py

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

# 从 ...utils 中导入 OptionalDependencyNotAvailable、_LazyModule、is_tf_available 和 is_torch_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available

# 定义一个字典 _import_structure,用于存储不同模块的导入结构
_import_structure = {
    "configuration_tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig"],
    "tokenization_tapas": ["TapasTokenizer"],
}

# 尝试导入 torch 版本的 Tapas 模块,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,则将相关模块添加到 _import_structure 中
    _import_structure["modeling_tapas"] = [
        "TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TapasForMaskedLM",
        "TapasForQuestionAnswering",
        "TapasForSequenceClassification",
        "TapasModel",
        "TapasPreTrainedModel",
        "load_tf_weights_in_tapas",
    ]

# 尝试导入 TensorFlow 版本的 Tapas 模块,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,则将相关模块添加到 _import_structure 中
    _import_structure["modeling_tf_tapas"] = [
        "TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFTapasForMaskedLM",
        "TFTapasForQuestionAnswering",
        "TFTapasForSequenceClassification",
        "TFTapasModel",
        "TFTapasPreTrainedModel",
    ]

# 如果是类型检查阶段,导入具体的类型和模块
if TYPE_CHECKING:
    from .configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
    from .tokenization_tapas import TapasTokenizer

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tapas import (
            TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
            TapasForMaskedLM,
            TapasForQuestionAnswering,
            TapasForSequenceClassification,
            TapasModel,
            TapasPreTrainedModel,
            load_tf_weights_in_tapas,
        )

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_tapas import (
            TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFTapasForMaskedLM,
            TFTapasForQuestionAnswering,
            TFTapasForSequenceClassification,
            TFTapasModel,
            TFTapasPreTrainedModel,
        )

# 如果不是类型检查阶段,则动态加载模块,并将当前模块替换为 _LazyModule 的实例
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\timesformer\configuration_timesformer.py

# coding=utf-8
# 定义编码方式为 UTF-8

# 导入预训练配置类 PretrainedConfig 和日志工具 logging
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义预训练模型配置文件的映射字典,指定了模型名称和其对应的配置文件链接
TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/timesformer": "https://huggingface.co/facebook/timesformer/resolve/main/config.json",
}

# TimesformerConfig 类,继承自 PretrainedConfig 类,用于存储 TimeSformer 模型的配置信息
class TimesformerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`TimesformerModel`]. It is used to instantiate a
    TimeSformer model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the TimeSformer
    [facebook/timesformer-base-finetuned-k600](https://huggingface.co/facebook/timesformer-base-finetuned-k600)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    # 设置默认参数:图像尺寸为224像素,每个补丁(patch)的尺寸为16像素,输入通道数为3
    # 每个视频包含8帧,编码器和池化器层的隐藏层大小为768维
    # Transformer编码器中隐藏层的数量为12,每个注意力层中的注意力头数为12
    # Transformer编码器中"中间"(即前馈)层的维度为3072
    # 编码器和池化器中的非线性激活函数为GELU
    # 嵌入层、编码器和池化器中所有全连接层的丢弃概率为0.0
    # 注意力概率的丢弃比例为0.0
    # 初始化所有权重矩阵的截断正态分布标准差为0.02
    # 层归一化层使用的epsilon为1e-06
    # 是否向查询、键和值添加偏置
    # 使用的注意力类型为"divided_space_time"
    # 随机深度的丢弃比率为0

    # 将模型类型设置为"timesformer"
    model_type = "timesformer"
    # 初始化函数,用于初始化一个自定义的神经网络模型
    def __init__(
        self,
        image_size=224,  # 图像输入大小,默认为224
        patch_size=16,  # 每个patch的大小,默认为16
        num_channels=3,  # 输入图像的通道数,默认为3(RGB)
        num_frames=8,  # 输入视频帧数,默认为8
        hidden_size=768,  # Transformer模型中隐藏层的大小,默认为768
        num_hidden_layers=12,  # Transformer模型中隐藏层的数量,默认为12
        num_attention_heads=12,  # Transformer模型中注意力头的数量,默认为12
        intermediate_size=3072,  # Transformer模型中Feedforward层的中间大小,默认为3072
        hidden_act="gelu",  # 隐藏层的激活函数,默认为GELU
        hidden_dropout_prob=0.0,  # 隐藏层的dropout概率,默认为0(无dropout)
        attention_probs_dropout_prob=0.0,  # 注意力层的dropout概率,默认为0(无dropout)
        initializer_range=0.02,  # 初始化权重的范围,默认为0.02
        layer_norm_eps=1e-6,  # LayerNorm层的epsilon值,默认为1e-6
        qkv_bias=True,  # 是否在QKV(查询、键、值)矩阵中使用偏置项,默认为True
        attention_type="divided_space_time",  # 注意力机制的类型,默认为“divided_space_time”
        drop_path_rate=0,  # DropPath层的drop率,默认为0(无drop)
        **kwargs,
    ):
        # 调用父类的初始化方法
        super().__init__(**kwargs)

        # 设置对象的属性值
        self.image_size = image_size  # 图像输入大小
        self.patch_size = patch_size  # 每个patch的大小
        self.num_channels = num_channels  # 输入图像的通道数
        self.num_frames = num_frames  # 输入视频帧数

        self.hidden_size = hidden_size  # Transformer模型中隐藏层的大小
        self.num_hidden_layers = num_hidden_layers  # Transformer模型中隐藏层的数量
        self.num_attention_heads = num_attention_heads  # Transformer模型中注意力头的数量
        self.intermediate_size = intermediate_size  # Transformer模型中Feedforward层的中间大小
        self.hidden_act = hidden_act  # 隐藏层的激活函数
        self.hidden_dropout_prob = hidden_dropout_prob  # 隐藏层的dropout概率
        self.attention_probs_dropout_prob = attention_probs_dropout_prob  # 注意力层的dropout概率
        self.initializer_range = initializer_range  # 初始化权重的范围
        self.layer_norm_eps = layer_norm_eps  # LayerNorm层的epsilon值
        self.qkv_bias = qkv_bias  # 是否在QKV(查询、键、值)矩阵中使用偏置项

        self.attention_type = attention_type  # 注意力机制的类型
        self.drop_path_rate = drop_path_rate  # DropPath层的drop率

.\models\timesformer\convert_timesformer_to_pytorch.py

# 设置编码格式为 UTF-8
# 版权声明及许可信息,指明代码使用的许可协议和版权归属
# 导入转换 TimeSformer 检查点所需的库和模块

import argparse  # 导入用于解析命令行参数的库
import json  # 导入处理 JSON 格式数据的库

import gdown  # 导入用于从 Google Drive 下载文件的库
import numpy as np  # 导入处理数值和数组的库
import torch  # 导入 PyTorch 深度学习框架
from huggingface_hub import hf_hub_download  # 导入从 Hugging Face Hub 下载资源的函数

from transformers import TimesformerConfig, TimesformerForVideoClassification, VideoMAEImageProcessor  # 导入 TimeSformer 模型所需的配置、模型和处理器类


def get_timesformer_config(model_name):
    config = TimesformerConfig()  # 创建一个 TimeSformer 的配置对象

    if "large" in model_name:
        config.num_frames = 96  # 如果模型名包含 'large',设置帧数为 96

    if "hr" in model_name:
        config.num_frames = 16  # 如果模型名包含 'hr',设置帧数为 16
        config.image_size = 448  # 同时设置图像尺寸为 448

    repo_id = "huggingface/label-files"
    if "k400" in model_name:
        config.num_labels = 400  # 如果模型名包含 'k400',设置标签数为 400
        filename = "kinetics400-id2label.json"  # 设置要下载的文件名为 kinetics400-id2label.json
    elif "k600" in model_name:
        config.num_labels = 600  # 如果模型名包含 'k600',设置标签数为 600
        filename = "kinetics600-id2label.json"  # 设置要下载的文件名为 kinetics600-id2label.json
    elif "ssv2" in model_name:
        config.num_labels = 174  # 如果模型名包含 'ssv2',设置标签数为 174
        filename = "something-something-v2-id2label.json"  # 设置要下载的文件名为 something-something-v2-id2label.json
    else:
        raise ValueError("Model name should either contain 'k400', 'k600' or 'ssv2'.")  # 如果模型名不符合预期,则引发错误
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))  # 从 Hugging Face Hub 下载并加载 JSON 格式的标签映射数据
    id2label = {int(k): v for k, v in id2label.items()}  # 将标签映射数据中的键转换为整数类型
    config.id2label = id2label  # 将加载的标签映射数据设置为配置对象的 id2label 属性
    config.label2id = {v: k for k, v in id2label.items()}  # 创建反向映射,从标签到 ID 的映射

    return config  # 返回配置对象


def rename_key(name):
    if "encoder." in name:
        name = name.replace("encoder.", "")  # 替换模型参数名中的 'encoder.' 为 ''
    if "cls_token" in name:
        name = name.replace("cls_token", "timesformer.embeddings.cls_token")  # 替换模型参数名中的 'cls_token' 为 'timesformer.embeddings.cls_token'
    if "pos_embed" in name:
        name = name.replace("pos_embed", "timesformer.embeddings.position_embeddings")  # 替换模型参数名中的 'pos_embed' 为 'timesformer.embeddings.position_embeddings'
    if "time_embed" in name:
        name = name.replace("time_embed", "timesformer.embeddings.time_embeddings")  # 替换模型参数名中的 'time_embed' 为 'timesformer.embeddings.time_embeddings'
    if "patch_embed.proj" in name:
        name = name.replace("patch_embed.proj", "timesformer.embeddings.patch_embeddings.projection")  # 替换模型参数名中的 'patch_embed.proj' 为 'timesformer.embeddings.patch_embeddings.projection'
    if "patch_embed.norm" in name:
        name = name.replace("patch_embed.norm", "timesformer.embeddings.norm")  # 替换模型参数名中的 'patch_embed.norm' 为 'timesformer.embeddings.norm'
    if "blocks" in name:
        name = name.replace("blocks", "timesformer.encoder.layer")  # 替换模型参数名中的 'blocks' 为 'timesformer.encoder.layer'
    if "attn.proj" in name:
        name = name.replace("attn.proj", "attention.output.dense")  # 替换模型参数名中的 'attn.proj' 为 'attention.output.dense'
    if "attn" in name and "bias" not in name and "temporal" not in name:
        name = name.replace("attn", "attention.self")  # 替换模型参数名中的 'attn' 为 'attention.self',排除包含 'bias' 和 'temporal' 的情况
    if "attn" in name and "temporal" not in name:
        name = name.replace("attn", "attention.attention")  # 替换模型参数名中的 'attn' 为 'attention.attention',排除包含 'temporal' 的情况
    # 检查字符串 "temporal_norm1" 是否在变量 name 中
    if "temporal_norm1" in name:
        # 如果是,则将字符串 "temporal_norm1" 替换为 "temporal_layernorm"
        name = name.replace("temporal_norm1", "temporal_layernorm")

    # 检查字符串 "temporal_attn.proj" 是否在变量 name 中
    if "temporal_attn.proj" in name:
        # 如果是,则将字符串 "temporal_attn" 替换为 "temporal_attention.output.dense"
        name = name.replace("temporal_attn", "temporal_attention.output.dense")

    # 检查字符串 "temporal_fc" 是否在变量 name 中
    if "temporal_fc" in name:
        # 如果是,则将字符串 "temporal_fc" 替换为 "temporal_dense"
        name = name.replace("temporal_fc", "temporal_dense")

    # 检查字符串 "norm1" 是否在变量 name 中,并且字符串中不包含 "temporal"
    if "norm1" in name and "temporal" not in name:
        # 如果是,则将字符串 "norm1" 替换为 "layernorm_before"
        name = name.replace("norm1", "layernorm_before")

    # 检查字符串 "norm2" 是否在变量 name 中
    if "norm2" in name:
        # 如果是,则将字符串 "norm2" 替换为 "layernorm_after"
        name = name.replace("norm2", "layernorm_after")

    # 检查字符串 "mlp.fc1" 是否在变量 name 中
    if "mlp.fc1" in name:
        # 如果是,则将字符串 "mlp.fc1" 替换为 "intermediate.dense"
        name = name.replace("mlp.fc1", "intermediate.dense")

    # 检查字符串 "mlp.fc2" 是否在变量 name 中
    if "mlp.fc2" in name:
        # 如果是,则将字符串 "mlp.fc2" 替换为 "output.dense"
        name = name.replace("mlp.fc2", "output.dense")

    # 检查字符串 "norm.weight" 是否在变量 name 中,并且字符串中不包含 "fc" 和 "temporal"
    if "norm.weight" in name and "fc" not in name and "temporal" not in name:
        # 如果是,则将字符串 "norm.weight" 替换为 "timesformer.layernorm.weight"
        name = name.replace("norm.weight", "timesformer.layernorm.weight")

    # 检查字符串 "norm.bias" 是否在变量 name 中,并且字符串中不包含 "fc" 和 "temporal"
    if "norm.bias" in name and "fc" not in name and "temporal" not in name:
        # 如果是,则将字符串 "norm.bias" 替换为 "timesformer.layernorm.bias"
        name = name.replace("norm.bias", "timesformer.layernorm.bias")

    # 检查字符串 "head" 是否在变量 name 中
    if "head" in name:
        # 如果是,则将字符串 "head" 替换为 "classifier"
        name = name.replace("head", "classifier")

    # 返回替换后的变量 name
    return name
# 根据给定的原始状态字典和配置,转换模型的状态字典
def convert_state_dict(orig_state_dict, config):
    # 遍历原始状态字典的键(需要复制,因为后续会修改原始字典)
    for key in orig_state_dict.copy().keys():
        # 弹出当前键对应的值
        val = orig_state_dict.pop(key)

        # 如果键以"model."开头,则去除该前缀
        if key.startswith("model."):
            key = key.replace("model.", "")

        # 如果键包含"qkv",则根据不同情况重新命名键
        if "qkv" in key:
            key_split = key.split(".")
            layer_num = int(key_split[1])
            prefix = "timesformer.encoder.layer."
            # 根据键名中是否包含"temporal"决定后缀
            if "temporal" in key:
                postfix = ".temporal_attention.attention.qkv."
            else:
                postfix = ".attention.attention.qkv."
            # 根据键名中是否包含"weight"决定修改状态字典中的键和对应的值
            if "weight" in key:
                orig_state_dict[f"{prefix}{layer_num}{postfix}weight"] = val
            else:
                orig_state_dict[f"{prefix}{layer_num}{postfix}bias"] = val
        else:
            # 否则,对键进行重命名
            orig_state_dict[rename_key(key)] = val

    # 返回转换后的原始状态字典
    return orig_state_dict


# 我们将在一个吃意大利面条的视频上验证我们的结果
# 使用的帧索引: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227]
def prepare_video():
    # 从指定的数据集仓库下载名为"eating_spaghetti.npy"的文件
    file = hf_hub_download(
        repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti.npy", repo_type="dataset"
    )
    # 加载视频数据并转换为列表返回
    video = np.load(file)
    return list(video)


def convert_timesformer_checkpoint(checkpoint_url, pytorch_dump_folder_path, model_name, push_to_hub):
    # 获取特定模型名称的配置信息
    config = get_timesformer_config(model_name)

    # 使用配置创建一个 TimesformerForVideoClassification 模型
    model = TimesformerForVideoClassification(config)

    # 下载托管在 Google Drive 上的原始检查点文件
    output = "pytorch_model.bin"
    gdown.cached_download(checkpoint_url, output, quiet=False)
    # 加载检查点文件,根据文件中的键名不同进行适配
    files = torch.load(output, map_location="cpu")
    if "model" in files:
        state_dict = files["model"]
    elif "module" in files:
        state_dict = files["module"]
    else:
        state_dict = files["model_state"]
    # 转换加载的状态字典到新的状态字典格式
    new_state_dict = convert_state_dict(state_dict, config)

    # 加载模型的新状态字典
    model.load_state_dict(new_state_dict)
    # 设置模型为评估模式
    model.eval()

    # 在基本输入上验证模型
    # 创建一个图像处理器对象,用于视频处理
    image_processor = VideoMAEImageProcessor(image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5])
    # 准备视频数据
    video = prepare_video()
    # 使用图像处理器处理前8帧视频,并返回PyTorch张量格式的输入
    inputs = image_processor(video[:8], return_tensors="pt")

    # 使用模型进行推理,获取输出结果
    outputs = model(**inputs)
    logits = outputs.logits

    # 定义一组模型名称列表,包含不同版本和分辨率的预训练检查点
    model_names = [
        # Kinetics-400 数据集检查点(hr = 使用448px高分辨率输入而非224px)
        "timesformer-base-finetuned-k400",
        "timesformer-large-finetuned-k400",
        "timesformer-hr-finetuned-k400",
        # Kinetics-600 数据集检查点(hr = 使用448px高分辨率输入而非224px)
        "timesformer-base-finetuned-k600",
        "timesformer-large-finetuned-k600",
        "timesformer-hr-finetuned-k600",
        # Something-Something-v2 数据集检查点(hr = 使用448px高分辨率输入而非224px)
        "timesformer-base-finetuned-ssv2",
        "timesformer-large-finetuned-ssv2",
        "timesformer-hr-finetuned-ssv2",
    ]

    # 注意:logits使用了图像均值和标准差 [0.5, 0.5, 0.5] 和 [0.5, 0.5, 0.5] 进行了测试
    # 根据模型名称设置预期的输出形状和预期的输出值
    if model_name == "timesformer-base-finetuned-k400":
        expected_shape = torch.Size([1, 400])
        expected_slice = torch.tensor([-0.3016, -0.7713, -0.4205])
    elif model_name == "timesformer-base-finetuned-k600":
        expected_shape = torch.Size([1, 600])
        expected_slice = torch.tensor([-0.7267, -0.7466, 3.2404])
    elif model_name == "timesformer-base-finetuned-ssv2":
        expected_shape = torch.Size([1, 174])
        expected_slice = torch.tensor([-0.9059, 0.6433, -3.1457])
    elif model_name == "timesformer-large-finetuned-k400":
        expected_shape = torch.Size([1, 400])
        expected_slice = torch.tensor([0, 0, 0])
    elif model_name == "timesformer-large-finetuned-k600":
        expected_shape = torch.Size([1, 600])
        expected_slice = torch.tensor([0, 0, 0])
    elif model_name == "timesformer-large-finetuned-ssv2":
        expected_shape = torch.Size([1, 174])
        expected_slice = torch.tensor([0, 0, 0])
    elif model_name == "timesformer-hr-finetuned-k400":
        expected_shape = torch.Size([1, 400])
        expected_slice = torch.tensor([-0.9617, -3.7311, -3.7708])
    elif model_name == "timesformer-hr-finetuned-k600":
        expected_shape = torch.Size([1, 600])
        expected_slice = torch.tensor([2.5273, 0.7127, 1.8848])
    elif model_name == "timesformer-hr-finetuned-ssv2":
        expected_shape = torch.Size([1, 174])
        expected_slice = torch.tensor([-3.6756, -0.7513, 0.7180])
    else:
        raise ValueError(f"Model name not supported. Should be one of {model_names}")

    # 验证模型输出的形状是否与预期一致
    assert logits.shape == expected_shape
    # 验证模型输出的前三个元素是否与预期的数值接近
    assert torch.allclose(logits[0, :3], expected_slice, atol=1e-4)
    # 打印确认信息
    print("Logits ok!")

    # 如果指定了 PyTorch 模型保存路径,则保存模型和图像处理器
    if pytorch_dump_folder_path is not None:
        print(f"Saving model and image processor to {pytorch_dump_folder_path}")
        image_processor.save_pretrained(pytorch_dump_folder_path)
        model.save_pretrained(pytorch_dump_folder_path)

    # 如果需要推送到 hub
    if push_to_hub:
        # 打印推送到 hub 的消息
        print("Pushing to the hub...")
        # 将模型推送到指定路径下的 hub
        model.push_to_hub(f"fcakyon/{model_name}")
if __name__ == "__main__":
    # 如果作为主程序执行,则开始解析命令行参数
    parser = argparse.ArgumentParser()
    
    # 添加必需的参数
    parser.add_argument(
        "--checkpoint_url",
        default="https://drive.google.com/u/1/uc?id=17yvuYp9L4mn-HpIcK5Zo6K3UoOy1kA5l&export=download",
        type=str,
        help=(
            "URL of the original PyTorch checkpoint (on Google Drive) you'd like to convert. Should be a direct"
            " download link."
        ),
    )
    
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="",
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    
    parser.add_argument(
        "--model_name", 
        default="timesformer-base-finetuned-k400", 
        type=str, 
        help="Name of the model."
    )
    
    parser.add_argument(
        "--push_to_hub", 
        action="store_true", 
        help="Whether or not to push the converted model to the 🤗 hub."
    )
    
    # 解析命令行参数
    args = parser.parse_args()
    
    # 调用函数 convert_timesformer_checkpoint,传入解析得到的参数
    convert_timesformer_checkpoint(
        args.checkpoint_url, args.pytorch_dump_folder_path, args.model_name, args.push_to_hub
    )
posted @ 2024-07-01 10:58  绝不原创的飞龙  阅读(23)  评论(0编辑  收藏  举报