Transformers-源码解析-十五-

Transformers 源码解析(十五)

.\models\bert\modeling_flax_bert.py

# 导入所需的模块和类
from typing import Callable, Optional, Tuple  # 导入类型注解相关的类和方法

import flax  # 导入 Flax 深度学习框架
import flax.linen as nn  # 导入 Flax 提供的线性 API 模块
import jax  # 导入 JAX,用于定义和执行计算
import jax.numpy as jnp  # 导入 JAX 提供的 NumPy 兼容的数组处理工具
import numpy as np  # 导入 NumPy 数组处理库
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # 导入 Flax 提供的冻结字典相关方法
from flax.linen import combine_masks, make_causal_mask  # 导入 Flax 提供的掩码组合和因果掩码生成方法
from flax.linen import partitioning as nn_partitioning  # 导入 Flax 提供的分区模块
from flax.linen.attention import dot_product_attention_weights  # 导入 Flax 提供的点积注意力权重计算方法
from flax.traverse_util import flatten_dict, unflatten_dict  # 导入 Flax 提供的字典扁平化和反扁平化方法
from jax import lax  # 导入 JAX 提供的线性代数加速模块

# 从外部模块导入不同输出类型的模型结果类
from ...modeling_flax_outputs import (
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxBaseModelOutputWithPooling,
    FlaxBaseModelOutputWithPoolingAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxNextSentencePredictorOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
# 从外部模块导入不同工具类和方法
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
# 从外部模块导入模型输出类和配置相关的方法
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging

# 从本地模块导入 BERT 模型配置类
from .configuration_bert import BertConfig

# 获取日志记录器实例
logger = logging.get_logger(__name__)

# 定义用于文档的检查点名称和配置名称
_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"

# 定义 Flax 提供的重要函数 remat
remat = nn_partitioning.remat

# 定义 FlaxBertForPreTrainingOutput 类,继承自 ModelOutput,用于表示预训练过程的输出类型
@flax.struct.dataclass
class FlaxBertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`BertForPreTraining`].
    """
    # 预测语言模型头部的预测得分(每个词汇标记的得分,未经过 SoftMax 处理)
    prediction_logits: jnp.ndarray = None

    # 下一序列预测(分类)头部的预测得分(True/False 继续的得分,未经过 SoftMax 处理)
    seq_relationship_logits: jnp.ndarray = None

    # 模型隐藏层的隐藏状态元组(当 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回)
    # 包含了每一层的输出(除了嵌入层外)的 `jnp.ndarray` 数组,形状为 `(batch_size, sequence_length, hidden_size)`
    hidden_states: Optional[Tuple[jnp.ndarray]] = None

    # 自注意力头部的注意力权重元组(当 `output_attentions=True` 或 `config.output_attentions=True` 时返回)
    # 包含了每一层的注意力权重 `jnp.ndarray` 数组,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`
    attentions: Optional[Tuple[jnp.ndarray]] = None
# BERT_START_DOCSTRING 是一个原始字符串文档,包含了有关模型继承关系和Flax模块的详细说明。
# 这个模型继承自 FlaxPreTrainedModel,并且可以作为 flax.linen.Module 使用。
# 支持 JAX 的 JIT 编译、自动微分、向量化和并行化特性。
# 
# Parameters:
#     config ([BertConfig]): 包含模型所有参数的配置类。
#         使用配置文件初始化时,不会加载与模型关联的权重,只加载配置。
#         可以使用 FlaxPreTrainedModel.from_pretrained 方法加载模型权重。
#     dtype (jax.numpy.dtype, 可选,默认为 jax.numpy.float32):
#         计算的数据类型。可以是 jax.numpy.float32、jax.numpy.float16(在GPU上)和 jax.numpy.bfloat16(在TPU上)。
#         可用于在GPU或TPU上启用混合精度训练或半精度推断。如果指定,则所有计算都将使用给定的 dtype。
#         
#         注意,这只指定计算的数据类型,不影响模型参数的数据类型。
#         
#         如果要更改模型参数的数据类型,请参阅 FlaxPreTrainedModel.to_fp16 和 FlaxPreTrainedModel.to_bf16。
BERT_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`BertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].

"""

# BERT_INPUTS_DOCSTRING 是一个原始字符串文档,目前为空,用于稍后添加BERT模型输入的说明文档。
BERT_INPUTS_DOCSTRING = r"""
    # Args 是一个 docstring(文档字符串)的一部分,用于描述函数的参数信息
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            # 输入序列中的标记索引,在词汇表中表示每个标记
            Indices of input sequence tokens in the vocabulary.
    
            # 可以使用 AutoTokenizer 来获取这些索引,详见 PreTrainedTokenizer.encode 和 PreTrainedTokenizer.__call__ 的详情
    
            [What are input IDs?](../glossary#input-ids)
            # 查看更多关于 input IDs 的信息,链接到 glossary 的相应条目
    
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            # 用于避免在填充的标记索引上进行注意力计算的掩码
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
    
            - 1 表示**不被遮蔽**的标记,
            - 0 表示**被遮蔽**的标记。
    
            [What are attention masks?](../glossary#attention-mask)
            # 查看更多关于 attention masks 的信息,链接到 glossary 的相应条目
    
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            # 段标记索引,指示输入的第一部分和第二部分。索引在 `[0, 1]` 中选择:
    
            - 0 对应于*句子 A* 的标记,
            - 1 对应于*句子 B* 的标记。
    
            [What are token type IDs?](../glossary#token-type-ids)
            # 查看更多关于 token type IDs 的信息,链接到 glossary 的相应条目
    
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            # 每个输入序列标记在位置嵌入中的位置索引。选择在范围 `[0, config.max_position_embeddings - 1]` 中。
    
        head_mask (`numpy.ndarray` of shape `({0})`, `optional):
            # 用于使注意力模块中的特定 head 失效的掩码。掩码值选择在 `[0, 1]` 中:
    
            - 1 表示该 head **未被遮蔽**,
            - 0 表示该 head **被遮蔽**。
    
        return_dict (`bool`, *optional*):
            # 是否返回 `~utils.ModelOutput` 而不是普通元组。
    
    # 上述内容描述了函数的各个参数及其作用,帮助用户理解如何使用这些参数调用函数。
"""
class FlaxBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # 初始化词嵌入层,输入词汇表大小和隐藏层大小,并使用正态分布初始化
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        # 初始化位置嵌入层,输入最大位置嵌入数量和隐藏层大小,并使用正态分布初始化
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        # 初始化类型嵌入层,输入类型词汇表大小和隐藏层大小,并使用正态分布初始化
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        # 初始化 Layer Normalization 层,设置 epsilon 为配置中的值
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # 初始化 Dropout 层,设置丢弃率为配置中的值
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        # 嵌入输入的词嵌入向量,将输入类型转换为整型
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        # 嵌入位置向量,将位置编码转换为整型
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        # 嵌入类型向量,将类型编码转换为整型
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # 汇总所有嵌入向量
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds

        # 进行 Layer Normalization
        hidden_states = self.LayerNorm(hidden_states)
        # 应用 Dropout,根据 deterministic 参数决定是否使用确定性 Dropout
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxBertSelfAttention(nn.Module):
    config: BertConfig
    causal: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
"""
    def setup(self):
        # 将隐藏层大小分成多个注意力头的维度
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
        # 如果隐藏层大小不能被注意力头数整除,抛出数值错误
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
                "                   : {self.config.num_attention_heads}"
            )

        # 初始化查询权重的全连接层
        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        # 初始化键权重的全连接层
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        # 初始化值权重的全连接层
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        # 如果启用因果注意力机制,创建因果掩码
        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        # 将隐藏状态张量分割成多个注意力头
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        # 将分割的注意力头合并回原始隐藏状态张量
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))

    @nn.compact
    # 从 transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache 复制而来
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # 检测是否通过缺失现有缓存数据进行初始化。
        is_initialized = self.has_variable("cache", "cached_key")
        # 获取或创建缓存的键值状态
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        # 获取或创建缓存索引
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # 使用新的一维空间片段更新键、值缓存
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            # 更新缓存索引以反映新加入的缓存向量数
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # 对于缓存的解码器自注意力,使用因果掩码:我们的单个查询位置只应与已生成和缓存的键位置相关联,而不是剩余的零元素。
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            # 合并因果掩码和给定的注意力掩码
            attention_mask = combine_masks(pad_mask, attention_mask)
        
        # 返回更新后的键、值和注意力掩码
        return key, value, attention_mask
# 定义 FlaxBertSelfOutput 类,继承自 nn.Module
class FlaxBertSelfOutput(nn.Module):
    config: BertConfig  # 类型注解,指定配置为 BertConfig 类型
    dtype: jnp.dtype = jnp.float32  # 计算中使用的数据类型

    # 设置方法,初始化网络层
    def setup(self):
        # 创建全连接层,输入大小为 hidden_size,使用正态分布初始化权重
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # 创建 LayerNorm 层,使用配置中的 epsilon 参数
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # 创建 Dropout 层,使用配置中的 hidden_dropout_prob 参数
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    # 调用方法,定义网络层间的传递逻辑
    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        # 全连接层处理隐藏状态
        hidden_states = self.dense(hidden_states)
        # 使用 Dropout 层处理隐藏状态,根据 deterministic 参数确定是否采用确定性方式
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # 使用 LayerNorm 层处理隐藏状态和输入张量的加和结果
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        # 返回处理后的隐藏状态
        return hidden_states


# 定义 FlaxBertAttention 类,继承自 nn.Module
class FlaxBertAttention(nn.Module):
    config: BertConfig  # 类型注解,指定配置为 BertConfig 类型
    causal: bool = False  # 是否使用因果注意力的标志,默认为 False
    dtype: jnp.dtype = jnp.float32  # 计算中使用的数据类型

    # 设置方法,初始化网络层
    def setup(self):
        # 创建自注意力层,使用配置和 causal 参数
        self.self = FlaxBertSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
        # 创建自输出层,使用配置和数据类型参数
        self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)

    # 调用方法,定义网络层间的传递逻辑
    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states=None,
        init_cache=False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        # 注意力掩码的形状为 (*batch_sizes, kv_length)
        # FLAX 需要形状为 (*batch_sizes, 1, 1, kv_length),以便与注意力权重形状匹配
        attn_outputs = self.self(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            key_value_states=key_value_states,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        # 获取自注意力层的输出
        attn_output = attn_outputs[0]
        # 使用自输出层处理自注意力层的输出和隐藏状态
        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)

        # 构建输出元组
        outputs = (hidden_states,)

        # 如果需要输出注意力权重,则添加到输出元组中
        if output_attentions:
            outputs += (attn_outputs[1],)

        # 返回最终输出
        return outputs


# 定义 FlaxBertIntermediate 类,继承自 nn.Module
class FlaxBertIntermediate(nn.Module):
    config: BertConfig  # 类型注解,指定配置为 BertConfig 类型
    dtype: jnp.dtype = jnp.float32  # 计算中使用的数据类型

    # 设置方法,初始化网络层
    def setup(self):
        # 创建全连接层,输入大小为 intermediate_size,使用正态分布初始化权重
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # 使用配置中的隐藏激活函数名称,创建激活函数层
        self.activation = ACT2FN[self.config.hidden_act]

    # 调用方法,定义网络层间的传递逻辑
    def __call__(self, hidden_states):
        # 全连接层处理隐藏状态
        hidden_states = self.dense(hidden_states)
        # 使用激活函数处理全连接层的输出
        hidden_states = self.activation(hidden_states)
        # 返回处理后的隐藏状态
        return hidden_states


# 定义 FlaxBertOutput 类,继承自 nn.Module
class FlaxBertOutput(nn.Module):
    config: BertConfig  # 类型注解,指定配置为 BertConfig 类型
    dtype: jnp.dtype = jnp.float32  # 计算中使用的数据类型
    # 初始化模型中的层和参数
    def setup(self):
        # 创建一个全连接层,输出维度为 self.config.hidden_size
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # 创建一个 Dropout 层,用于随机失活以防止过拟合
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        # 创建一个 LayerNorm 层,用于层标准化,epsilon 为 self.config.layer_norm_eps
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    # 定义模型的调用方法,实现前向传播
    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        # 使用全连接层进行线性变换
        hidden_states = self.dense(hidden_states)
        # 对变换后的结果进行 Dropout 操作,以减少过拟合风险
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # 对结果进行层标准化,并与 attention_output 相加作为最终输出
        hidden_states = self.LayerNorm(hidden_states + attention_output)
        # 返回最终的隐藏状态表示
        return hidden_states
class FlaxBertLayer(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # 计算的数据类型

    def setup(self):
        # 初始化注意力层、中间层和输出层
        self.attention = FlaxBertAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
        self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
        self.output = FlaxBertOutput(self.config, dtype=self.dtype)
        # 如果配置中包含交叉注意力,初始化交叉注意力层
        if self.config.add_cross_attention:
            self.crossattention = FlaxBertAttention(self.config, causal=False, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        # 自注意力机制
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attention_output = attention_outputs[0]

        # 交叉注意力块
        if encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask=encoder_attention_mask,
                layer_head_mask=layer_head_mask,
                key_value_states=encoder_hidden_states,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]

        hidden_states = self.intermediate(attention_output)
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attention_outputs[1],)
            if encoder_hidden_states is not None:
                outputs += (cross_attention_outputs[1],)
        return outputs


class FlaxBertLayerCollection(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # 计算的数据类型
    gradient_checkpointing: bool = False

    def setup(self):
        if self.gradient_checkpointing:
            # 如果梯度检查点开启,使用 remat 函数包装 FlaxBertLayer 并创建层集合
            FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7))
            self.layers = [
                FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]
        else:
            # 否则,直接创建 FlaxBertLayer 的层集合
            self.layers = [
                FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
            ]
    # 定义 __call__ 方法,用于将对象实例作为可调用函数使用
    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 如果需要输出注意力权重,则初始化空元组,否则设为 None
        all_attentions = () if output_attentions else None
        # 如果需要输出隐藏状态,则初始化空元组,否则设为 None
        all_hidden_states = () if output_hidden_states else None
        # 如果同时需要输出注意力权重且有编码器隐藏状态,则初始化空元组,否则设为 None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # 检查头部掩码的层数是否正确
        if head_mask is not None:
            if head_mask.shape[0] != (len(self.layers)):
                # 抛出异常,指出头部掩码应指定为与层数相同的层数
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
                    f"{head_mask.shape[0]}."
                )

        # 遍历每一层并执行前向传播
        for i, layer in enumerate(self.layers):
            # 如果需要输出隐藏状态,则记录当前层的隐藏状态
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # 调用当前层的前向传播方法
            layer_outputs = layer(
                hidden_states,
                attention_mask,
                head_mask[i] if head_mask is not None else None,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                deterministic,
                output_attentions,
            )

            # 提取当前层的输出隐藏状态
            hidden_states = layer_outputs[0]

            # 如果需要输出注意力权重,则记录当前层的注意力权重
            if output_attentions:
                all_attentions += (layer_outputs[1],)

                # 如果有编码器隐藏状态,则记录当前层的交叉注意力权重
                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # 如果需要输出隐藏状态,则记录最终的隐藏状态
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 整理所有输出,并根据 return_dict 决定输出格式
        outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)

        if not return_dict:
            # 如果不需要以字典形式返回,则返回一个包含非 None 值的元组
            return tuple(v for v in outputs if v is not None)

        # 否则,以指定的输出格式返回结果
        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
# 定义一个 FlaxBertEncoder 类,继承自 nn.Module,用于BERT编码器的实现
class FlaxBertEncoder(nn.Module):
    # 类属性:BERT 的配置信息
    config: BertConfig
    # 类属性:计算过程中使用的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32
    # 类属性:是否使用梯度检查点,默认为 False
    gradient_checkpointing: bool = False

    # 初始化方法,设置编码器的层集合
    def setup(self):
        self.layer = FlaxBertLayerCollection(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

    # 调用方法,用于执行编码器的前向计算
    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return self.layer(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


# 定义一个 FlaxBertPooler 类,继承自 nn.Module,用于BERT的池化器
class FlaxBertPooler(nn.Module):
    # 类属性:BERT 的配置信息
    config: BertConfig
    # 类属性:计算过程中使用的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 初始化方法,设置池化器的全连接层
    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    # 调用方法,用于执行池化器的前向计算
    def __call__(self, hidden_states):
        # 取第一个位置的隐藏状态作为池化器输入
        cls_hidden_state = hidden_states[:, 0]
        # 经过全连接层变换
        cls_hidden_state = self.dense(cls_hidden_state)
        # 使用双曲正切函数进行激活
        return nn.tanh(cls_hidden_state)


# 定义一个 FlaxBertPredictionHeadTransform 类,继承自 nn.Module,用于BERT预测头的变换
class FlaxBertPredictionHeadTransform(nn.Module):
    # 类属性:BERT 的配置信息
    config: BertConfig
    # 类属性:计算过程中使用的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 初始化方法,设置预测头变换的全连接层、激活函数和 LayerNorm 层
    def setup(self):
        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.activation = ACT2FN[self.config.hidden_act]
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    # 调用方法,用于执行预测头变换的前向计算
    def __call__(self, hidden_states):
        # 经过全连接层变换
        hidden_states = self.dense(hidden_states)
        # 经过激活函数变换
        hidden_states = self.activation(hidden_states)
        # 经过 LayerNorm 层变换
        return self.LayerNorm(hidden_states)


# 定义一个 FlaxBertLMPredictionHead 类,继承自 nn.Module,用于BERT的语言模型预测头
class FlaxBertLMPredictionHead(nn.Module):
    # 类属性:BERT 的配置信息
    config: BertConfig
    # 类属性:计算过程中使用的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32
    # 类属性:偏置初始化函数,默认为全零
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    # 初始化方法,设置预测头的变换和全连接层
    def setup(self):
        self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
        self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
    # 定义一个特殊方法 __call__,用于将对象实例像函数一样调用
    def __call__(self, hidden_states, shared_embedding=None):
        # 调用 transform 方法,对输入的 hidden_states 进行转换处理
        hidden_states = self.transform(hidden_states)

        # 如果传入了共享的嵌入 shared_embedding,则使用 decoder 对象的 apply 方法
        if shared_embedding is not None:
            # 通过 decoder 对象的 apply 方法应用共享嵌入的转置 kernel 参数到 hidden_states
            hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            # 如果没有共享嵌入,则直接使用 decoder 对象处理 hidden_states
            hidden_states = self.decoder(hidden_states)

        # 将对象的 bias 属性转换为 JAX 的数组,并将其加到 hidden_states 上
        bias = jnp.asarray(self.bias, self.dtype)
        hidden_states += bias
        
        # 返回经过处理后的 hidden_states
        return hidden_states
class FlaxBertOnlyMLMHead(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    # 初始化模块,设置预测头部
    def setup(self):
        self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)

    # 调用模块,生成预测结果
    def __call__(self, hidden_states, shared_embedding=None):
        hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding)
        return hidden_states


class FlaxBertOnlyNSPHead(nn.Module):
    dtype: jnp.dtype = jnp.float32

    # 初始化模块,设置序列关系预测头部
    def setup(self):
        self.seq_relationship = nn.Dense(2, dtype=self.dtype)

    # 调用模块,生成序列关系预测结果
    def __call__(self, pooled_output):
        return self.seq_relationship(pooled_output)


class FlaxBertPreTrainingHeads(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    # 初始化模块,设置预测头部和序列关系预测头部
    def setup(self):
        self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
        self.seq_relationship = nn.Dense(2, dtype=self.dtype)

    # 调用模块,生成预测头部和序列关系预测结果
    def __call__(self, hidden_states, pooled_output, shared_embedding=None):
        prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    module_class: nn.Module = None

    # 初始化预训练模型,设置模块类和参数
    def __init__(
        self,
        config: BertConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        module = self.module_class(
            config=config,
            dtype=dtype,
            gradient_checkpointing=gradient_checkpointing,
            **kwargs,
        )
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    # 启用梯度检查点
    def enable_gradient_checkpointing(self):
        self._module = self.module_class(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=True,
        )
    # 初始化模型权重的函数,接受随机数生成器rng、输入形状input_shape和可选参数params,并返回参数字典
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 初始化输入张量
        input_ids = jnp.zeros(input_shape, dtype="i4")  # 创建全零的输入张量
        token_type_ids = jnp.zeros_like(input_ids)  # 创建与input_ids相同形状的全零张量
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)  # 根据input_ids形状广播生成位置张量
        attention_mask = jnp.ones_like(input_ids)  # 创建与input_ids相同形状的全一张量作为注意力掩码
        head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))  # 创建全一头掩码张量

        params_rng, dropout_rng = jax.random.split(rng)  # 分割随机数生成器rng,用于参数和dropout

        rngs = {"params": params_rng, "dropout": dropout_rng}  # 创建包含params_rng和dropout_rng的随机数生成器字典

        if self.config.add_cross_attention:
            # 如果配置中包含跨注意力,则初始化编码器隐藏状态和注意力掩码
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )  # 调用模块的初始化函数,传入相应参数,返回初始化输出
        else:
            module_init_outputs = self.module.init(
                rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
            )  # 调用模块的初始化函数,传入相应参数,返回初始化输出

        random_params = module_init_outputs["params"]  # 从初始化输出中获取随机参数

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))  # 展开并解冻随机参数
            params = flatten_dict(unfreeze(params))  # 展开并解冻输入参数
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]  # 将缺失的键从随机参数复制到输入参数中
            self._missing_keys = set()  # 清空缺失键集合
            return freeze(unflatten_dict(params))  # 冻结和恢复输入参数字典结构并返回
        else:
            return random_params  # 如果没有输入参数,则直接返回随机参数

    # 从transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache复制的函数
    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                fast auto-regressive decoding使用的批量大小,定义初始化缓存的批量大小。
            max_length (`int`):
                auto-regressive decoding的最大可能长度,定义初始化缓存的序列长度。
        """
        # 初始化用于检索缓存的输入变量
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")  # 创建全一的输入张量
        attention_mask = jnp.ones_like(input_ids, dtype="i4")  # 创建与input_ids相同形状的全一张量作为注意力掩码
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)  # 根据input_ids形状广播生成位置张量

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )  # 调用模块的初始化函数,传入相应参数并初始化缓存,返回初始化变量
        return unfreeze(init_variables["cache"])  # 解冻并返回初始化变量中的缓存

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    # 定义一个特殊方法 __call__,使得对象可以像函数一样被调用
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: dict = None,
# 定义一个名为FlaxBertModule的类,并继承自nn.Module
class FlaxBertModule(nn.Module):
    # 声明一个类型为BertConfig的config变量
    config: BertConfig
    # 声明一个名为dtype的变量,类型为jnp.dtype,默认值为jnp.float32,表示计算的数据类型
    dtype: jnp.dtype = jnp.float32
    # 声明一个名为add_pooling_layer的变量,类型为bool,默认值为True
    add_pooling_layer: bool = True
    # 声明一个名为gradient_checkpointing的变量,类型为bool,默认值为False
    gradient_checkpointing: bool = False

    # 定义一个setup方法
    def setup(self):
        # 初始化self.embeddings为FlaxBertEmbeddings类的实例对象
        self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
        # 初始化self.encoder为FlaxBertEncoder类的实例对象
        self.encoder = FlaxBertEncoder(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # 初始化self.pooler为FlaxBertPooler类的实例对象
        self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)

    # 定义一个__call__方法
    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 如果token_type_ids为None,则初始化为和input_ids形状相同的全0数组
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # 如果position_ids为None,则初始化为将一维数组变成二维数组后进行广播扩展
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # 将输入数据传递给self.embeddings进行处理,得到hidden_states
        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        # 将hidden_states传递给self.encoder进行处理,得到outputs
        outputs = self.encoder(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            deterministic=deterministic,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 从outputs中获取第一个元素赋值给hidden_states
        hidden_states = outputs[0]
        # 如果add_pooling_layer为True,则将hidden_states传递给self.pooler进行处理得到pooled,否则pooled为None
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        # 如果return_dict为False
        if not return_dict:
            # 如果pooled为None,不返回pooled
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            # 返回包含hidden_states、pooled和outputs[1:]的元组
            return (hidden_states, pooled) + outputs[1:]

        # 返回FlaxBaseModelOutputWithPoolingAndCrossAttentions的实例对象,包含指定的属性和值
        return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

# 使用add_start_docstrings函数装饰FlaxBertModel类
@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class FlaxBertModel(FlaxBertPreTrainedModel):
    # 设置module_class属性为FlaxBertModule类
    module_class = FlaxBertModule
# 调用函数 `overwrite_call_docstring`,用于覆盖指定类的调用方法的文档字符串
overwrite_call_docstring(
    FlaxBertForPreTraining,



@add_start_docstrings(
    """
    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    BERT_START_DOCSTRING,
)



# 创建一个自定义的文档字符串注解,描述了 Bert 模型在预训练期间的结构,包括了 `masked language modeling` 和 `next sentence prediction (classification)` 两个头部任务
@add_start_docstrings(
    """
    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    BERT_START_DOCSTRING,
)



class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
    # 将模型类设置为 FlaxBertForPreTrainingModule
    module_class = FlaxBertForPreTrainingModule



FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```
    >>> from transformers import AutoTokenizer, FlaxBertForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    >>> model = FlaxBertForPreTraining.from_pretrained("google-bert/bert-base-uncased")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.prediction_logits
    >>> seq_relationship_logits = outputs.seq_relationship_logits
    ```



# 定义 FLAX_BERT_FOR_PRETRAINING_DOCSTRING,包含函数的返回值和使用示例
FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```
    >>> from transformers import AutoTokenizer, FlaxBertForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    >>> model = FlaxBertForPreTraining.from_pretrained("google-bert/bert-base-uncased")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.prediction_logits
    >>> seq_relationship_logits = outputs.seq_relationship_logits
    ```
    # 使用字符串格式化函数 BERT_INPUTS_DOCSTRING 格式化输入的参数 "batch_size, sequence_length",并加上 FLAX_BERT_FOR_PRETRAINING_DOCSTRING 的内容
    BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_PRETRAINING_DOCSTRING,
# 导入所需模块和函数
append_replace_return_docstrings(
    FlaxBertForPreTraining, output_type=FlaxBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
)

# 定义一个自定义的 nn.Module 类 FlaxBertForMaskedLMModule,用于处理 BERT 模型的 masked language modeling 任务
class FlaxBertForMaskedLMModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    # 初始化函数,设置模型的各种参数和组件
    def setup(self):
        # 初始化一个 FlaxBertModule 实例,作为主要的 BERT 模型
        self.bert = FlaxBertModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # 初始化一个 FlaxBertOnlyMLMHead 实例,用于预测 masked token
        self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

    # 调用函数,定义模型的前向传播过程
    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 调用 bert 模型进行前向传播,得到输出
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取隐藏状态作为模型预测的输入
        hidden_states = outputs[0]
        
        # 根据配置判断是否共享词嵌入
        if self.config.tie_word_embeddings:
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # 计算预测的 logits
        logits = self.cls(hidden_states, shared_embedding=shared_embedding)

        # 根据 return_dict 决定返回的格式
        if not return_dict:
            return (logits,) + outputs[1:]

        # 返回 FlaxMaskedLMOutput 类型的结果,包括 logits、隐藏状态和注意力分布
        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# 将自动生成的文档字符串添加到 FlaxBertForMaskedLM 类上,用于描述其语言建模头部的 BERT 模型
@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
    module_class = FlaxBertForMaskedLMModule


# 添加示例调用文档字符串到 FlaxBertForMaskedLM 类上,用于指定检查点、输出类型和配置的示例调用
append_call_sample_docstring(FlaxBertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)


# 定义一个自定义的 nn.Module 类 FlaxBertForNextSentencePredictionModule,用于处理 BERT 模型的下一句预测任务
class FlaxBertForNextSentencePredictionModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    # 初始化函数,设置模型的各种参数和组件
    def setup(self):
        # 初始化一个 FlaxBertModule 实例,作为主要的 BERT 模型
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # 初始化一个 FlaxBertOnlyNSPHead 实例,用于预测下一句
        self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)

    # 调用函数,定义模型的前向传播过程
    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        ):
        # 如果 return_dict 不为 None,则使用指定的 return_dict;否则使用类的默认配置值
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # 调用 BERT 模型进行推断
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从 BERT 输出中获取池化后的特征
        pooled_output = outputs[1]
        
        # 将池化后的特征输入到分类层,得到句子关系的预测分数
        seq_relationship_scores = self.cls(pooled_output)

        # 如果 return_dict 为 False,则返回结果元组,包含预测分数和额外的隐藏状态列表
        if not return_dict:
            return (seq_relationship_scores,) + outputs[2:]

        # 如果 return_dict 为 True,则返回 FlaxNextSentencePredictorOutput 对象,包含预测 logits、隐藏状态和注意力
        return FlaxNextSentencePredictorOutput(
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# 给 FlaxBertForNextSentencePrediction 类添加文档字符串,描述其包含“下一句预测(分类)”头部的 BERT 模型
@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top.""",
    BERT_START_DOCSTRING,
)
class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
    # 模块类指向 FlaxBertForNextSentencePredictionModule
    module_class = FlaxBertForNextSentencePredictionModule


# FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING 包含详细的文档字符串,说明返回值和使用示例
FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING = """
    Returns:

    Example:

    ```
    >>> from transformers import AutoTokenizer, FlaxBertForNextSentencePrediction

    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    >>> model = FlaxBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

    >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
    >>> encoding = tokenizer(prompt, next_sentence, return_tensors="jax")

    >>> outputs = model(**encoding)
    >>> logits = outputs.logits
    >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
    ```
"""

# 将 BERT_INPUTS_DOCSTRING 格式化后的字符串和 FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING 添加到 FlaxBertForNextSentencePrediction 类的文档字符串中
overwrite_call_docstring(
    FlaxBertForNextSentencePrediction,
    BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING,
)

# 附加和替换 FlaxBertForNextSentencePrediction 类的返回文档字符串,输出类型为 FlaxNextSentencePredictorOutput,配置类为 _CONFIG_FOR_DOC
append_replace_return_docstrings(
    FlaxBertForNextSentencePrediction, output_type=FlaxNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxBertForSequenceClassificationModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        # 设置 BERT 模块,根据配置选择是否使用梯度检查点
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # 根据配置设置分类器的 dropout 率,若未指定,则使用隐藏层的 dropout 率
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        # 设置 dropout 层
        self.dropout = nn.Dropout(rate=classifier_dropout)
        # 设置分类器,输出维度为配置中定义的标签数
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 调用 BERT 模型进行前向传播,获取输出结果
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从 BERT 输出中获取池化后的特征表示
        pooled_output = outputs[1]
        # 对池化后的特征表示应用 dropout,以防止过拟合
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        # 将池化后的特征表示输入分类器,得到 logits
        logits = self.classifier(pooled_output)

        # 如果不要求返回一个字典,则返回 logits 和额外的隐藏状态
        if not return_dict:
            return (logits,) + outputs[2:]

        # 如果要求返回一个字典,则封装输出为 FlaxSequenceClassifierOutput 类型
        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
    module_class = FlaxBertForSequenceClassificationModule



append_call_sample_docstring(
    FlaxBertForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)



class FlaxBertForMultipleChoiceModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        # 初始化 Bert 模型
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # Dropout 层,用于随机丢弃输入特征
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        # 分类器,全连接层,用于多项选择任务
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        num_choices = input_ids.shape[1]
        # 重塑输入以适应 Bert 模型的期望形状
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None

        # 调用 Bert 模型获取输出
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取池化后的输出
        pooled_output = outputs[1]
        # 应用 Dropout 层
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        # 应用分类器获取最终 logits
        logits = self.classifier(pooled_output)

        # 重塑 logits 以适应多项选择任务的形状
        reshaped_logits = logits.reshape(-1, num_choices)

        if not return_dict:
            # 如果不返回字典,则返回 logits 和额外的隐藏状态
            return (reshaped_logits,) + outputs[2:]

        # 如果返回字典,则构造输出对象
        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )



@add_start_docstrings(
    """
    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
    module_class = FlaxBertForMultipleChoiceModule



overwrite_call_docstring(
    # 导入 FlaxBertForMultipleChoice 类
    FlaxBertForMultipleChoice, 
    # 使用 BERT_INPUTS_DOCSTRING 格式化字符串,描述输入参数的文档字符串
    BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
# 调用append_call_sample_docstring函数,向FlaxBertForMultipleChoice类中添加示例文档字符串
append_call_sample_docstring(
    FlaxBertForMultipleChoice, _CHECKPOINT_FOR_DOC, FlaxMultipleChoiceModelOutput, _CONFIG_FOR_DOC
)


class FlaxBertForTokenClassificationModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        # 初始化self.bert作为FlaxBertModule实例,配置参数来自self.config,并设置相关参数
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # 设置分类器的dropout率,若未指定则使用self.config.hidden_dropout_prob的值
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        # 初始化self.dropout作为nn.Dropout实例,设定dropout率为classifier_dropout
        self.dropout = nn.Dropout(rate=classifier_dropout)
        # 初始化self.classifier作为nn.Dense实例,设定输出维度为self.config.num_labels
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 调用self.bert,传入参数并返回输出
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取BERT模型的隐藏状态
        hidden_states = outputs[0]
        # 对隐藏状态应用dropout操作,根据deterministic参数决定是否使用确定性dropout
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # 对dropout后的隐藏状态进行分类预测,生成logits
        logits = self.classifier(hidden_states)

        # 如果return_dict为False,返回(logits,) + outputs[1:]
        if not return_dict:
            return (logits,) + outputs[1:]

        # 如果return_dict为True,返回FlaxTokenClassifierOutput对象,包括logits、隐藏状态和注意力机制
        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert模型,在隐藏状态输出上增加了一个token分类头部(即隐藏状态输出之上的线性层),用于例如命名实体识别(NER)任务。
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
    # 指定模块类为FlaxBertForTokenClassificationModule
    module_class = FlaxBertForTokenClassificationModule


# 调用append_call_sample_docstring函数,向FlaxBertForTokenClassification类中添加示例文档字符串
append_call_sample_docstring(
    FlaxBertForTokenClassification, _CHECKPOINT_FOR_DOC, FlaxTokenClassifierOutput, _CONFIG_FOR_DOC
)


class FlaxBertForQuestionAnsweringModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        # 初始化self.bert作为FlaxBertModule实例,配置参数来自self.config,并设置相关参数
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # 初始化self.qa_outputs作为nn.Dense实例,设定输出维度为self.config.num_labels
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        调用模型的方法,用于执行前向推断。

        Args:
            input_ids: 输入的token ID序列
            attention_mask: 注意力掩码,标识每个token的重要性
            token_type_ids: token类型ID,用于区分句子A和句子B等信息
            position_ids: 位置ID,指示每个token在输入序列中的位置
            head_mask: 头部掩码,控制每个注意力头的重要性
            deterministic: 是否以确定性方式运行(默认为True)
            output_attentions: 是否输出注意力权重(默认为False)
            output_hidden_states: 是否输出所有隐藏状态(默认为False)
            return_dict: 是否返回字典形式的输出(默认为True)

        Returns:
            FlaxQuestionAnsweringModelOutput 或 tuple:
                如果return_dict为True,则返回FlaxQuestionAnsweringModelOutput对象,
                包含起始和结束logits、隐藏状态和注意力权重等信息;
                如果return_dict为False,则返回元组,包含起始和结束logits以及额外的输出。
        """
        # 调用BERT模型的前向传播,获取模型输出
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从模型输出中提取隐藏状态
        hidden_states = outputs[0]

        # 将隐藏状态传入问答输出层,获取起始和结束logits
        logits = self.qa_outputs(hidden_states)
        
        # 根据问题答案的数量将logits进行分割
        start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
        
        # 去除最后一维的数据
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        # 如果不需要以字典的形式返回结果,则返回元组
        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        # 以FlaxQuestionAnsweringModelOutput对象的形式返回结果
        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
    # 将 Bert 模型与用于抽取式问答任务的跨度分类头部结合在一起,例如在 SQuAD 上操作(在隐藏状态输出之上的线性层,
    # 用于计算 `span start logits` 和 `span end logits`)。
    module_class = FlaxBertForQuestionAnsweringModule


append_call_sample_docstring(
    FlaxBertForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)


class FlaxBertForCausalLMModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        # 初始化方法,设置模块中的组件
        self.bert = FlaxBertModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        token_type_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 模型调用方法
        # 调用内部的 Bert 模块进行前向传播
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            # 如果配置要求共享词嵌入,获取共享的嵌入参数
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # 计算预测分数
        logits = self.cls(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        # 返回带有交叉注意力的 FlaxCausalLMOutputWithCrossAttentions 对象
        return FlaxCausalLMOutputWithCrossAttentions(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


@add_start_docstrings(
    """
    Bert Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
    autoregressive tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForCausalLM(FlaxBertPreTrainedModel):
    # 将 Bert 模型与用于语言建模任务的头部结合在一起(在隐藏状态输出之上的线性层),例如自回归任务。
    # 设置模块类为 FlaxBertForCausalLMModule
    module_class = FlaxBertForCausalLMModule
    
    # 为生成器准备输入的函数定义,接受输入的token ids、最大长度、可选的注意力掩码
    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # 初始化缓存
        batch_size, seq_length = input_ids.shape
    
        # 使用模型自定义方法初始化缓存,返回过去的键值对
        past_key_values = self.init_cache(batch_size, max_length)
    
        # 注意,在通常情况下需要在 attention_mask 的 x > input_ids.shape[-1] 和 x < cache_length 的位置放置 0。
        # 但由于解码器使用因果注意力掩码,这些位置已经被掩盖了。
        # 因此,我们可以在这里创建一个静态的注意力掩码,这样更有效地进行编译
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
    
        # 如果有传入注意力掩码,则根据它计算位置 ids,并动态更新静态的注意力掩码
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            # 如果没有传入注意力掩码,则生成位置 ids,用于模型输入
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
    
        # 返回包含 past_key_values、extended_attention_mask 和 position_ids 的字典
        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }
    
    # 更新生成器输入的函数定义,接受模型输出和模型参数关键字作为输入
    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # 更新模型参数关键字中的 past_key_values 和 position_ids
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
# 调用函数 append_call_sample_docstring,添加示例文档字符串到 FlaxBertForCausalLM 类中
append_call_sample_docstring(
    FlaxBertForCausalLM,
    # 示例文档字符串的检查点
    _CHECKPOINT_FOR_DOC,
    # 使用交叉注意力的 FlaxCausalLMOutputWithCrossAttentions 类
    FlaxCausalLMOutputWithCrossAttentions,
    # 示例文档字符串的配置
    _CONFIG_FOR_DOC,
)

.\models\bert\modeling_tf_bert.py

# coding=utf-8
# 版权声明:2018 年由 Google AI 语言团队和 HuggingFace Inc. 团队所有。
# 版权声明:2018 年,NVIDIA CORPORATION 版权所有。
#
# 根据 Apache 许可证 2.0 版本("许可证")获得许可;
# 除非符合许可证要求或书面同意,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件按"原样"分发,
# 没有任何形式的明示或暗示的担保或条件。
# 有关详细信息,请参阅许可证。
""" TF 2.0 BERT 模型。"""


from __future__ import annotations

import math  # 导入数学函数库
import warnings  # 导入警告模块
from dataclasses import dataclass  # 导入 dataclass 用于定义数据类
from typing import Dict, Optional, Tuple, Union  # 导入类型提示工具

import numpy as np  # 导入 NumPy 库
import tensorflow as tf  # 导入 TensorFlow 库

from ...activations_tf import get_tf_activation  # 从本地包中导入 TensorFlow 激活函数
from ...modeling_tf_outputs import (
    TFBaseModelOutputWithPastAndCrossAttentions,  # 导入 TFBaseModelOutputWithPastAndCrossAttentions 输出类
    TFBaseModelOutputWithPoolingAndCrossAttentions,  # 导入 TFBaseModelOutputWithPoolingAndCrossAttentions 输出类
    TFCausalLMOutputWithCrossAttentions,  # 导入 TFCausalLMOutputWithCrossAttentions 输出类
    TFMaskedLMOutput,  # 导入 TFMaskedLMOutput 输出类
    TFMultipleChoiceModelOutput,  # 导入 TFMultipleChoiceModelOutput 输出类
    TFNextSentencePredictorOutput,  # 导入 TFNextSentencePredictorOutput 输出类
    TFQuestionAnsweringModelOutput,  # 导入 TFQuestionAnsweringModelOutput 输出类
    TFSequenceClassifierOutput,  # 导入 TFSequenceClassifierOutput 输出类
    TFTokenClassifierOutput,  # 导入 TFTokenClassifierOutput 输出类
)
from ...modeling_tf_utils import (
    TFCausalLanguageModelingLoss,  # 导入 TFCausalLanguageModelingLoss 损失类
    TFMaskedLanguageModelingLoss,  # 导入 TFMaskedLanguageModelingLoss 损失类
    TFModelInputType,  # 导入 TFModelInputType 输入类型
    TFMultipleChoiceLoss,  # 导入 TFMultipleChoiceLoss 损失类
    TFNextSentencePredictionLoss,  # 导入 TFNextSentencePredictionLoss 损失类
    TFPreTrainedModel,  # 导入 TFPreTrainedModel 预训练模型类
    TFQuestionAnsweringLoss,  # 导入 TFQuestionAnsweringLoss 损失类
    TFSequenceClassificationLoss,  # 导入 TFSequenceClassificationLoss 损失类
    TFTokenClassificationLoss,  # 导入 TFTokenClassificationLoss 损失类
    get_initializer,  # 导入获取初始化器函数
    keras,  # 导入 Keras 库
    keras_serializable,  # 导入 Keras 序列化功能
    unpack_inputs,  # 导入解包输入函数
)
from ...tf_utils import (
    check_embeddings_within_bounds,  # 导入检查嵌入范围的函数
    shape_list,  # 导入获取张量形状的函数
    stable_softmax,  # 导入稳定 Softmax 函数
)
from ...utils import (
    ModelOutput,  # 导入模型输出类
    add_code_sample_docstrings,  # 导入添加代码示例文档字符串函数
    add_start_docstrings,  # 导入添加起始文档字符串函数
    add_start_docstrings_to_model_forward,  # 导入向前模型添加起始文档字符串函数
    logging,  # 导入日志模块
    replace_return_docstrings,  # 导入替换返回文档字符串函数
)
from .configuration_bert import BertConfig  # 从本地配置文件导入 BertConfig 类


logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"  # 预训练模型的文档检查点
_CONFIG_FOR_DOC = "BertConfig"  # BertConfig 的文档配置

# TokenClassification 文档字符串
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"  # 标记分类预训练模型检查点
_TOKEN_CLASS_EXPECTED_OUTPUT = (
    "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
)  # 标记分类预期输出
_TOKEN_CLASS_EXPECTED_LOSS = 0.01  # 标记分类预期损失

# QuestionAnswering 文档字符串
_CHECKPOINT_FOR_QA = "ydshieh/bert-base-cased-squad2"  # 问答预训练模型检查点
_QA_EXPECTED_OUTPUT = "'a nice puppet'"  # 问答预期输出
_QA_EXPECTED_LOSS = 7.41  # 问答预期损失
_QA_TARGET_START_INDEX = 14  # 问答目标起始索引
_QA_TARGET_END_INDEX = 15  # 问答目标结束索引

# SequenceClassification 文档字符串
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ydshieh/bert-base-uncased-yelp-polarity"  # 序列分类预训练模型检查点
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"  # 序列分类预期输出
_SEQ_CLASS_EXPECTED_LOSS = 0.01  # 序列分类预期损失

TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google-bert/bert-base-uncased",  # 预训练模型存档列表
    "google-bert/bert-large-uncased",  # 预训练模型存档列表
    # 列出了多个预训练的BERT模型的名称,每个名称代表一个特定配置和语言的BERT模型
    [
        "google-bert/bert-base-cased",  # 谷歌的BERT基础模型,大小写敏感
        "google-bert/bert-large-cased",  # 谷歌的BERT大型模型,大小写敏感
        "google-bert/bert-base-multilingual-uncased",  # 谷歌的多语言BERT基础模型,大小写不敏感
        "google-bert/bert-base-multilingual-cased",  # 谷歌的多语言BERT基础模型,大小写敏感
        "google-bert/bert-base-chinese",  # 谷歌的中文BERT基础模型
        "google-bert/bert-base-german-cased",  # 谷歌的德语BERT基础模型,大小写敏感
        "google-bert/bert-large-uncased-whole-word-masking",  # 谷歌的大型BERT模型,全词遮盖,大小写不敏感
        "google-bert/bert-large-cased-whole-word-masking",  # 谷歌的大型BERT模型,全词遮盖,大小写敏感
        "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad",  # 谷歌的在SQuAD上微调的大型BERT模型,全词遮盖,大小写不敏感
        "google-bert/bert-large-cased-whole-word-masking-finetuned-squad",  # 谷歌的在SQuAD上微调的大型BERT模型,全词遮盖,大小写敏感
        "google-bert/bert-base-cased-finetuned-mrpc",  # 谷歌的在MRPC任务上微调的BERT基础模型,大小写敏感
        "cl-tohoku/bert-base-japanese",  # 东北大学的日语BERT基础模型
        "cl-tohoku/bert-base-japanese-whole-word-masking",  # 东北大学的日语BERT基础模型,全词遮盖
        "cl-tohoku/bert-base-japanese-char",  # 东北大学的日语BERT基础模型,字符级别
        "cl-tohoku/bert-base-japanese-char-whole-word-masking",  # 东北大学的日语BERT基础模型,字符级别,全词遮盖
        "TurkuNLP/bert-base-finnish-cased-v1",  # TurkuNLP的芬兰语BERT基础模型,大小写敏感
        "TurkuNLP/bert-base-finnish-uncased-v1",  # TurkuNLP的芬兰语BERT基础模型,大小写不敏感
        "wietsedv/bert-base-dutch-cased",  # Wietsedv的荷兰语BERT基础模型,大小写敏感
        # 查看所有BERT模型,请访问 https://huggingface.co/models?filter=bert
    ]
        super().__init__(**kwargs)

        # 初始化层参数,保存BERT配置
        self.config = config
        # 获取BERT模型隐藏层大小
        self.hidden_size = config.hidden_size
        # 获取最大位置嵌入数
        self.max_position_embeddings = config.max_position_embeddings
        # 获取初始化范围
        self.initializer_range = config.initializer_range
        # 创建LayerNorm层,并设置epsilon值
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # 创建Dropout层,并设置丢弃率
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
    # 定义 build 方法,用于构建模型结构
    def build(self, input_shape=None):
        # 在 "word_embeddings" 命名空间下创建权重矩阵,用于词嵌入
        self.weight = self.add_weight(
            name="weight",
            shape=[self.config.vocab_size, self.hidden_size],
            initializer=get_initializer(self.initializer_range),
        )

        # 在 "token_type_embeddings" 命名空间下创建权重矩阵,用于标记类型嵌入
        self.token_type_embeddings = self.add_weight(
            name="embeddings",
            shape=[self.config.type_vocab_size, self.hidden_size],
            initializer=get_initializer(self.initializer_range),
        )

        # 在 "position_embeddings" 命名空间下创建权重矩阵,用于位置嵌入
        self.position_embeddings = self.add_weight(
            name="embeddings",
            shape=[self.max_position_embeddings, self.hidden_size],
            initializer=get_initializer(self.initializer_range),
        )

        # 如果模型已构建,则直接返回,避免重复构建
        if self.built:
            return
        self.built = True
        
        # 如果存在 LayerNorm 层,则构建 LayerNorm 层,输入形状为 [None, None, self.config.hidden_size]
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])

    # 定义 call 方法,用于执行模型前向传播
    def call(
        self,
        input_ids: tf.Tensor = None,
        position_ids: tf.Tensor = None,
        token_type_ids: tf.Tensor = None,
        inputs_embeds: tf.Tensor = None,
        past_key_values_length=0,
        training: bool = False,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        # 如果没有提供 input_ids 或 inputs_embeds,则抛出 ValueError
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Need to provide either `input_ids` or `input_embeds`.")

        # 如果提供了 input_ids,则从权重矩阵中获取对应的嵌入向量
        if input_ids is not None:
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        # 如果未提供 token_type_ids,则创建一个形状与 inputs_embeds 相同的全 0 张量
        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        # 如果未提供 position_ids,则创建一个序列张量,范围从 past_key_values_length 到 input_shape[1] + past_key_values_length
        if position_ids is None:
            position_ids = tf.expand_dims(
                tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
            )

        # 从 position_embeddings 中根据 position_ids 获取位置嵌入向量
        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        # 从 token_type_embeddings 中根据 token_type_ids 获取标记类型嵌入向量
        token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
        # 计算最终的嵌入向量,包括输入嵌入、位置嵌入和标记类型嵌入
        final_embeddings = inputs_embeds + position_embeds + token_type_embeds
        # 对最终的嵌入向量应用 LayerNorm 层
        final_embeddings = self.LayerNorm(inputs=final_embeddings)
        # 对最终的嵌入向量应用 dropout,用于训练时防止过拟合
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        return final_embeddings
class TFBertSelfAttention(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        # 检查隐藏大小是否是注意力头数的整数倍
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        # 初始化注意力头数和每个注意力头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        # 创建用于计算查询、键和值的全连接层
        self.query = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        # 配置注意力概率的丢弃层
        self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)

        # 是否作为解码器使用和配置信息
        self.is_decoder = config.is_decoder
        self.config = config

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # 重塑张量形状,从 [batch_size, seq_length, all_head_size] 到 [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # 转置张量,从 [batch_size, seq_length, num_attention_heads, attention_head_size] 到 [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ):
        # 本层的调用方法,将在实际使用时详细处理各种输入和输出逻辑
        pass

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 构建查询、键和值层,设置它们的输入形状
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])
    # 初始化函数,接受一个BertConfig对象和其他关键字参数
    def __init__(self, config: BertConfig, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)

        # 创建一个全连接层,单元数为config.hidden_size,使用给定的初始化器初始化权重,命名为"dense"
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        
        # 创建一个LayerNormalization层,使用给定的epsilon参数,命名为"LayerNorm"
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        
        # 创建一个Dropout层,使用给定的dropout率
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        
        # 保存传入的BertConfig对象
        self.config = config

    # call方法,接受hidden_states(隐藏状态)、input_tensor(输入张量)、training(是否在训练模式下)参数,
    # 返回处理后的隐藏状态张量
    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        # 将隐藏状态传入全连接层进行线性变换
        hidden_states = self.dense(inputs=hidden_states)
        
        # 根据训练模式应用Dropout操作
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        
        # 将Dropout后的隐藏状态与输入张量相加,并通过LayerNormalization进行归一化处理
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)

        # 返回处理后的隐藏状态张量
        return hidden_states

    # build方法,用于构建层的权重(如果尚未构建)
    def build(self, input_shape=None):
        # 如果已经构建过,则直接返回
        if self.built:
            return
        
        # 标记为已构建
        self.built = True
        
        # 如果dense层已经定义,则使用dense层的名称作为作用域
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # 构建dense层的权重,输入形状为[None, None, config.hidden_size]
                self.dense.build([None, None, self.config.hidden_size])
        
        # 如果LayerNorm层已经定义,则使用LayerNorm层的名称作为作用域
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                # 构建LayerNorm层的权重,输入形状为[None, None, config.hidden_size]
                self.LayerNorm.build([None, None, self.config.hidden_size])
# 定义一个基于 Keras 的自定义层 TFBertAttention,用于 BERT 模型的自注意力机制
class TFBertAttention(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        # 创建自注意力层对象,使用给定的 BertConfig 进行配置
        self.self_attention = TFBertSelfAttention(config, name="self")
        # 创建自注意力输出层对象,使用给定的 BertConfig 进行配置
        self.dense_output = TFBertSelfOutput(config, name="output")

    # 未实现的方法,用于裁剪注意力机制中的某些头部
    def prune_heads(self, heads):
        raise NotImplementedError

    # 定义调用方法,实现自注意力机制的前向传播
    def call(
        self,
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # 使用 self_attention 对输入张量进行自注意力计算
        self_outputs = self.self_attention(
            hidden_states=input_tensor,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        # 使用 dense_output 对自注意力输出进行处理
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
        )
        # 如果需要输出注意力权重,将其加入到输出元组中
        outputs = (attention_output,) + self_outputs[1:]

        return outputs

    # 构建层结构,在第一次调用时构建子层的图结构
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 构建 self_attention 子层的图结构
        if getattr(self, "self_attention", None) is not None:
            with tf.name_scope(self.self_attention.name):
                self.self_attention.build(None)
        # 构建 dense_output 子层的图结构
        if getattr(self, "dense_output", None) is not None:
            with tf.name_scope(self.dense_output.name):
                self.dense_output.build(None)


# 定义一个基于 Keras 的自定义层 TFBertIntermediate,用于 BERT 模型的中间层处理
class TFBertIntermediate(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        # 创建全连接层对象,设置神经元数和初始化方式
        self.dense = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        # 根据配置获取中间激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.intermediate_act_fn = config.hidden_act
        self.config = config

    # 定义调用方法,实现中间层的前向传播
    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 使用全连接层对输入张量进行线性变换
        hidden_states = self.dense(inputs=hidden_states)
        # 使用中间激活函数对线性变换结果进行非线性变换
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states

    # 构建层结构,在第一次调用时构建图结构
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 构建 dense 子层的图结构
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])


class TFBertOutput(keras.layers.Layer):
    # 这里继续补充 TFBertOutput 类的注释
    # 初始化函数,用于创建一个新的实例
    def __init__(self, config: BertConfig, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)

        # 创建一个全连接层,用于线性变换
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # 创建一个层归一化层,用于归一化输入数据
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # 创建一个丢弃层,用于在训练时随机丢弃部分数据,防止过拟合
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        # 存储配置对象,方便在调用中使用
        self.config = config

    # 调用函数,定义了实例的前向传播逻辑
    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        # 将输入数据通过全连接层进行线性变换
        hidden_states = self.dense(inputs=hidden_states)
        # 在训练时,随机丢弃部分数据以防止过拟合
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        # 对线性变换后的数据进行层归一化,并与原始输入数据相加
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)

        # 返回处理后的数据作为输出
        return hidden_states

    # 构建函数,用于在首次调用时构建层的内部结构
    def build(self, input_shape=None):
        # 如果已经构建过,直接返回
        if self.built:
            return
        # 标记为已构建
        self.built = True
        # 如果存在全连接层,根据配置参数构建其内部结构
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.intermediate_size])
        # 如果存在层归一化层,根据配置参数构建其内部结构
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
# 定义一个自定义层 TFBertLayer,继承自 keras 的 Layer 类
class TFBertLayer(keras.layers.Layer):
    # 初始化方法,接受一个 BertConfig 类型的 config 参数和其他关键字参数
    def __init__(self, config: BertConfig, **kwargs):
        # 调用父类 Layer 的初始化方法
        super().__init__(**kwargs)

        # 创建一个 TFBertAttention 层实例,用给定的 config 参数和名称 "attention"
        self.attention = TFBertAttention(config, name="attention")
        
        # 根据 config 中的 is_decoder 属性设置当前层是否为解码器
        self.is_decoder = config.is_decoder
        
        # 根据 config 中的 add_cross_attention 属性设置是否添加跨注意力机制
        self.add_cross_attention = config.add_cross_attention
        
        # 如果要添加跨注意力机制,且当前层不是解码器,则抛出 ValueError 异常
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            
            # 创建一个 TFBertAttention 层实例,用给定的 config 参数和名称 "crossattention"
            self.crossattention = TFBertAttention(config, name="crossattention")
        
        # 创建一个 TFBertIntermediate 层实例,用给定的 config 参数和名称 "intermediate"
        self.intermediate = TFBertIntermediate(config, name="intermediate")
        
        # 创建一个 TFBertOutput 层实例,用给定的 config 参数和名称 "output"
        self.bert_output = TFBertOutput(config, name="output")

    # 定义层的调用方法,接受多个输入参数,包括隐藏状态、注意力掩码等
    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor | None,
        encoder_attention_mask: tf.Tensor | None,
        past_key_value: Tuple[tf.Tensor] | None,
        output_attentions: bool,
        training: bool = False,
        # 函数定义未完全展示,缺少返回类型注释
    ) -> Tuple[tf.Tensor]:
        # 如果存在过去的键/值缓存,取前两个元素作为自注意力的过去键/值
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # 使用自注意力层处理隐藏状态,计算自注意力输出
        self_attention_outputs = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_value=self_attn_past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        # 获取自注意力的输出
        attention_output = self_attention_outputs[0]

        # 如果模型为解码器,最后一个输出是自注意力缓存的元组
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            # 获取当前的键/值缓存
            present_key_value = self_attention_outputs[-1]
        else:
            # 否则,输出除了第一个元素外的所有元素(即自注意力权重)
            outputs = self_attention_outputs[1:]  # 如果需要输出注意力权重,添加自注意力权重

        cross_attn_present_key_value = None
        # 如果是解码器并且有编码器的隐藏状态输入
        if self.is_decoder and encoder_hidden_states is not None:
            # 如果未定义交叉注意力层,则引发错误
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # 如果存在过去的键/值缓存,取倒数第二个和最后一个元素作为交叉注意力的过去键/值
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            # 使用交叉注意力层处理自注意力输出,计算交叉注意力输出
            cross_attention_outputs = self.crossattention(
                input_tensor=attention_output,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            # 获取交叉注意力的输出
            attention_output = cross_attention_outputs[0]
            # 将交叉注意力权重添加到输出中
            outputs = outputs + cross_attention_outputs[1:-1]

            # 将交叉注意力缓存添加到当前的键/值缓存中的倒数第二个和最后一个位置
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # 使用中间层处理注意力输出,得到中间层输出
        intermediate_output = self.intermediate(hidden_states=attention_output)
        # 使用BERT输出层处理中间层输出和注意力输出,得到层输出
        layer_output = self.bert_output(
            hidden_states=intermediate_output, input_tensor=attention_output, training=training
        )
        # 将层输出与注意力权重(如果存在)合并到输出中
        outputs = (layer_output,) + outputs

        # 如果是解码器,将注意力的键/值作为最后一个输出返回
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        # 返回最终的输出
        return outputs
    # 构建模型的方法,用于在给定输入形状的情况下构建模型的各个部分
    def build(self, input_shape=None):
        # 如果模型已经构建过,则直接返回,避免重复构建
        if self.built:
            return
        # 将模型标记为已构建状态
        self.built = True
        
        # 如果模型具有 attention 属性,则构建 attention 部分
        if getattr(self, "attention", None) is not None:
            # 使用 attention 的名称作为命名空间,构建 attention 层
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        
        # 如果模型具有 intermediate 属性,则构建 intermediate 部分
        if getattr(self, "intermediate", None) is not None:
            # 使用 intermediate 的名称作为命名空间,构建 intermediate 层
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)
        
        # 如果模型具有 bert_output 属性,则构建 bert_output 部分
        if getattr(self, "bert_output", None) is not None:
            # 使用 bert_output 的名称作为命名空间,构建 bert_output 层
            with tf.name_scope(self.bert_output.name):
                self.bert_output.build(None)
        
        # 如果模型具有 crossattention 属性,则构建 crossattention 部分
        if getattr(self, "crossattention", None) is not None:
            # 使用 crossattention 的名称作为命名空间,构建 crossattention 层
            with tf.name_scope(self.crossattention.name):
                self.crossattention.build(None)
# 定义一个基于Keras层的TFBertEncoder类,用于BERT模型的编码器部分
class TFBertEncoder(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)
        # 初始化时保存BERT配置信息
        self.config = config
        # 创建多个TFBertLayer实例作为编码器的层,并命名每一层
        self.layer = [TFBertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    # 定义调用方法,实现编码器的前向传播
    def call(
        self,
        hidden_states: tf.Tensor,                 # 输入的隐藏状态张量
        attention_mask: tf.Tensor,                # 自注意力机制的掩码张量
        head_mask: tf.Tensor,                     # 头部掩码张量,用于控制多头注意力中的哪些头参与计算
        encoder_hidden_states: tf.Tensor | None,  # 编码器的隐藏状态张量,如果存在的话
        encoder_attention_mask: tf.Tensor | None, # 编码器的注意力掩码张量,如果存在的话
        past_key_values: Tuple[Tuple[tf.Tensor]] | None,  # 过去的键值对,用于缓存
        use_cache: Optional[bool],                # 是否使用缓存
        output_attentions: bool,                  # 是否输出注意力权重
        output_hidden_states: bool,               # 是否输出所有隐藏状态
        return_dict: bool,                        # 是否返回字典形式的结果
        training: bool = False,                   # 是否处于训练模式
    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
        # 初始化存储所有隐藏状态、注意力权重和交叉注意力权重的空元组
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        # 初始化下一个解码器缓存的空元组,如果使用缓存的话
        next_decoder_cache = () if use_cache else None

        # 遍历每一层编码器
        for i, layer_module in enumerate(self.layer):
            # 如果需要输出隐藏状态,则将当前层的隐藏状态添加到all_hidden_states中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 获取当前层的过去键值对,用于当前层的注意力机制
            past_key_value = past_key_values[i] if past_key_values is not None else None

            # 调用当前层的前向传播,得到当前层的输出
            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            # 更新当前隐藏状态为当前层的输出的第一个元素
            hidden_states = layer_outputs[0]

            # 如果使用缓存,则将当前层的输出的最后一个元素添加到下一个解码器缓存中
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)

            # 如果需要输出注意力权重,则将当前层的注意力权重添加到all_attentions中
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
                # 如果配置中添加了交叉注意力并且编码器隐藏状态不为空,则将当前层的交叉注意力添加到all_cross_attentions中
                if self.config.add_cross_attention and encoder_hidden_states is not None:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # 添加最后一层编码器的隐藏状态,如果需要输出隐藏状态的话
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果不需要返回字典形式的结果,则返回非None的所有元组元素
        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
            )

        # 返回TFBaseModelOutputWithPastAndCrossAttentions类型的结果字典
        return TFBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
    # 定义一个构建模型的方法,该方法可以接受输入形状作为参数
    def build(self, input_shape=None):
        # 如果模型已经构建过,则直接返回,不进行重复构建
        if self.built:
            return
        # 将标志位设置为已构建
        self.built = True
        # 检查是否存在self.layer属性,即模型是否包含层
        if getattr(self, "layer", None) is not None:
            # 遍历模型中的每一层
            for layer in self.layer:
                # 使用层的名称作为命名空间
                with tf.name_scope(layer.name):
                    # 调用每一层的build方法,传入None作为输入形状
                    layer.build(None)
class TFBertPooler(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        # 定义一个全连接层,用于池化隐藏状态
        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 池化模型的输出,简单地选择第一个 token 对应的隐藏状态
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(inputs=first_token_tensor)

        return pooled_output

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # 构建全连接层
                self.dense.build([None, None, self.config.hidden_size])


class TFBertPredictionHeadTransform(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        # 定义一个全连接层,用于转换隐藏状态
        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )

        # 根据配置选择激活函数
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.transform_act_fn = config.hidden_act

        # 使用 LayerNormalization 层对隐藏状态进行规范化
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 应用全连接层
        hidden_states = self.dense(inputs=hidden_states)
        # 应用激活函数
        hidden_states = self.transform_act_fn(hidden_states)
        # 应用 LayerNormalization
        hidden_states = self.LayerNorm(inputs=hidden_states)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # 构建全连接层
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                # 构建 LayerNormalization 层
                self.LayerNorm.build([None, None, self.config.hidden_size])


class TFBertLMPredictionHead(keras.layers.Layer):
    def __init__(self, config: BertConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.hidden_size = config.hidden_size

        # 创建一个预测头的转换层
        self.transform = TFBertPredictionHeadTransform(config, name="transform")

        # 输出权重与输入嵌入层相同,但每个 token 有一个仅输出的偏置
        self.input_embeddings = input_embeddings
    # 定义一个方法用于构建模型层,接受输入形状参数,默认为None
    def build(self, input_shape=None):
        # 初始化偏置项为零向量,形状与词汇表大小相同,可训练
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")

        # 如果模型已构建,则直接返回,避免重复构建
        if self.built:
            return
        self.built = True  # 标记模型已构建

        # 如果有transform属性,使用其名字空间构建transform层
        if getattr(self, "transform", None) is not None:
            with tf.name_scope(self.transform.name):
                self.transform.build(None)

    # 返回输入嵌入层
    def get_output_embeddings(self) -> keras.layers.Layer:
        return self.input_embeddings

    # 设置输出嵌入层,更新权重和词汇表大小
    def set_output_embeddings(self, value: tf.Variable):
        self.input_embeddings.weight = value
        self.input_embeddings.vocab_size = shape_list(value)[0]

    # 返回偏置项作为字典
    def get_bias(self) -> Dict[str, tf.Variable]:
        return {"bias": self.bias}

    # 设置偏置项,更新偏置和词汇表大小
    def set_bias(self, value: tf.Variable):
        self.bias = value["bias"]
        self.config.vocab_size = shape_list(value["bias"])[0]

    # 模型调用函数,接受隐藏状态张量作为输入,返回处理后的张量
    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 使用transform层处理隐藏状态
        hidden_states = self.transform(hidden_states=hidden_states)
        seq_length = shape_list(hidden_states)[1]  # 获取序列长度
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])  # 重塑张量形状
        hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)  # 执行矩阵乘法
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])  # 再次重塑张量形状
        hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)  # 添加偏置项到张量

        return hidden_states
class TFBertMLMHead(keras.layers.Layer):
    def __init__(self, config: BertConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)

        # 使用给定的配置和输入嵌入层创建预测头部对象
        self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")

    def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
        # 调用预测头部对象来计算序列输出的预测分数
        prediction_scores = self.predictions(hidden_states=sequence_output)

        return prediction_scores

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "predictions", None) is not None:
            with tf.name_scope(self.predictions.name):
                # 构建预测头部对象
                self.predictions.build(None)


class TFBertNSPHead(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        # 创建一个密集层来处理序列关系分数的预测
        self.seq_relationship = keras.layers.Dense(
            units=2,
            kernel_initializer=get_initializer(config.initializer_range),
            name="seq_relationship",
        )
        self.config = config

    def call(self, pooled_output: tf.Tensor) -> tf.Tensor:
        # 使用密集层计算池化输出的序列关系分数
        seq_relationship_score = self.seq_relationship(inputs=pooled_output)

        return seq_relationship_score

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "seq_relationship", None) is not None:
            with tf.name_scope(self.seq_relationship.name):
                # 构建密集层,指定输入形状为 [None, None, 隐藏大小]
                self.seq_relationship.build([None, None, self.config.hidden_size])


@keras_serializable
class TFBertMainLayer(keras.layers.Layer):
    config_class = BertConfig

    def __init__(self, config: BertConfig, add_pooling_layer: bool = True, **kwargs):
        super().__init__(**kwargs)

        # 初始化BERT主层对象,配置及是否添加池化层
        self.config = config
        self.is_decoder = config.is_decoder

        # 创建BERT的嵌入层、编码器层和池化层(如果需要的话)
        self.embeddings = TFBertEmbeddings(config, name="embeddings")
        self.encoder = TFBertEncoder(config, name="encoder")
        self.pooler = TFBertPooler(config, name="pooler") if add_pooling_layer else None

    def get_input_embeddings(self) -> keras.layers.Layer:
        # 返回嵌入层对象
        return self.embeddings

    def set_input_embeddings(self, value: tf.Variable):
        # 设置嵌入层的权重和词汇大小
        self.embeddings.weight = value
        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError

    @unpack_inputs
    # 定义一个类方法,用于调用模型。接受多个输入参数,都有默认值为None或False。
    def call(
        self,
        input_ids: TFModelInputType | None = None,  # 输入的token IDs,类型为TFModelInputType或None
        attention_mask: np.ndarray | tf.Tensor | None = None,  # 注意力遮罩,类型为numpy数组、Tensor或None
        token_type_ids: np.ndarray | tf.Tensor | None = None,  # token类型IDs,类型为numpy数组、Tensor或None
        position_ids: np.ndarray | tf.Tensor | None = None,  # 位置IDs,类型为numpy数组、Tensor或None
        head_mask: np.ndarray | tf.Tensor | None = None,  # 头部遮罩,类型为numpy数组、Tensor或None
        inputs_embeds: np.ndarray | tf.Tensor | None = None,  # 输入的嵌入,类型为numpy数组、Tensor或None
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,  # 编码器隐藏状态,类型为numpy数组、Tensor或None
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,  # 编码器注意力遮罩,类型为numpy数组、Tensor或None
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,  # 过去的键-值对,可选的类型为嵌套元组
        use_cache: Optional[bool] = None,  # 是否使用缓存,可选的布尔值
        output_attentions: Optional[bool] = None,  # 是否输出注意力权重,可选的布尔值
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态,可选的布尔值
        return_dict: Optional[bool] = None,  # 是否返回字典形式结果,可选的布尔值
        training: bool = False,  # 是否处于训练模式,默认为False


    # 构建模型的方法,用于建立模型的各个组件
    def build(self, input_shape=None):
        # 如果模型已经构建完毕,则直接返回
        if self.built:
            return
        # 标记模型为已构建状态
        self.built = True
        # 如果模型具有嵌入层(embeddings),则构建嵌入层
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):  # 使用嵌入层名称作为命名空间
                self.embeddings.build(None)  # 构建嵌入层,输入形状为None
        # 如果模型具有编码器(encoder),则构建编码器
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):  # 使用编码器名称作为命名空间
                self.encoder.build(None)  # 构建编码器,输入形状为None
        # 如果模型具有池化器(pooler),则构建池化器
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):  # 使用池化器名称作为命名空间
                self.pooler.build(None)  # 构建池化器,输入形状为None
class TFBertPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 设置配置类为BertConfig,用于模型配置
    config_class = BertConfig
    # 指定基础模型的前缀为"bert"
    base_model_prefix = "bert"


@dataclass
class TFBertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`TFBertForPreTraining`].

    Args:
        prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`tf.Tensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # 定义输出类,包括预训练过程中的损失、预测logits、序列关系logits、隐藏状态和注意力
    loss: tf.Tensor | None = None
    prediction_logits: tf.Tensor = None
    seq_relationship_logits: tf.Tensor = None
    hidden_states: Optional[Union[Tuple[tf.Tensor], tf.Tensor]] = None
    attentions: Optional[Union[Tuple[tf.Tensor], tf.Tensor]] = None


BERT_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>


Args:
    config ([`BertConfig`]): Model configuration class with all the parameters of the model.
        Initializing with a config file does not load the weights associated with the model, only the
        configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.

"""

BERT_INPUTS_DOCSTRING = r"""
"""

@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
)
class TFBertModel(TFBertPreTrainedModel):
def init(self, config: BertConfig, add_pooling_layer: bool = True, *inputs, **kwargs):
super().init(config, *inputs, **kwargs)

    # 初始化 Bert 主模型层,并设置是否添加池化层
    self.bert = TFBertMainLayer(config, add_pooling_layer, name="bert")

@unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
    config_class=_CONFIG_FOR_DOC,
)
def call(
    self,
    input_ids: TFModelInputType | None = None,
    attention_mask: np.ndarray | tf.Tensor | None = None,
    token_type_ids: np.ndarray | tf.Tensor | None = None,
    position_ids: np.ndarray | tf.Tensor | None = None,
    head_mask: np.ndarray | tf.Tensor | None = None,
    inputs_embeds: np.ndarray | tf.Tensor | None = None,
    encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
    encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
    past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    training: Optional[bool] = False,
    **kwargs
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
    """
    Perform the forward pass of the TFBertModel.

    This method overrides the call function in TFBertPreTrainedModel
    to allow for flexible input handling and model output specification.
    """
    # 以下代码为注释部分,解释了每个参数的作用和期望的输入输出类型
    # 参数解释和类型注释由 add_start_docstrings_to_model_forward 和 add_code_sample_docstrings 提供
    pass
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
    r"""
    encoder_hidden_states  (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
        the model is configured as a decoder.
    encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
        Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
        the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.

    past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
        contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
        If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
        don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
        `decoder_input_ids` of shape `(batch_size, sequence_length)`.
    use_cache (`bool`, *optional*, defaults to `True`):
        If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
        `past_key_values`). Set to `False` during training, `True` during generation
    """
    outputs = self.bert(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
    )
    return outputs

@add_start_docstrings(
"""
Bert Model with two heads on top as done during the pretraining:
a masked language modeling head and a next sentence prediction (classification) head.
""",
BERT_START_DOCSTRING,
)
class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
_keys_to_ignore_on_load_unexpected = [
r"position_ids",
r"cls.predictions.decoder.weight",
r"cls.predictions.decoder.bias",
]

def __init__(self, config: BertConfig, *inputs, **kwargs):
    super().__init__(config, *inputs, **kwargs)

    # Initialize the BERT main layer with the provided configuration
    self.bert = TFBertMainLayer(config, name="bert")
    
    # Initialize the Next Sentence Prediction (NSP) head with the provided configuration
    self.nsp = TFBertNSPHead(config, name="nsp___cls")
    
    # Initialize the Masked Language Modeling (MLM) head with the provided configuration,
    # using embeddings from the BERT main layer
    self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls")

def get_lm_head(self) -> keras.layers.Layer:
    # Return the predictions layer from the MLM head
    return self.mlm.predictions

def get_prefix_bias_name(self) -> str:
    # Deprecated method warning for obtaining the bias name
    warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
    return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

@unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
    self,
    input_ids: TFModelInputType | None = None,
    attention_mask: np.ndarray | tf.Tensor | None = None,
    token_type_ids: np.ndarray | tf.Tensor | None = None,
    position_ids: np.ndarray | tf.Tensor | None = None,
    head_mask: np.ndarray | tf.Tensor | None = None,
    inputs_embeds: np.ndarray | tf.Tensor | None = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    labels: np.ndarray | tf.Tensor | None = None,
    next_sentence_label: np.ndarray | tf.Tensor | None = None,
    training: Optional[bool] = False,
):
    # Method defining the forward pass of the model
    ...

def build(self, input_shape=None):
    if self.built:
        return
    self.built = True
    if getattr(self, "bert", None) is not None:
        # Build the BERT main layer within its name scope
        with tf.name_scope(self.bert.name):
            self.bert.build(None)
    if getattr(self, "nsp", None) is not None:
        # Build the NSP head within its name scope
        with tf.name_scope(self.nsp.name):
            self.nsp.build(None)
    if getattr(self, "mlm", None) is not None:
        # Build the MLM head within its name scope
        with tf.name_scope(self.mlm.name):
            self.mlm.build(None)

@add_start_docstrings("""Bert Model with a language modeling head on top.""", BERT_START_DOCSTRING)
class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
...
# 在加载时忽略的键列表,这些键是在加载模型时不期望出现的
_keys_to_ignore_on_load_unexpected = [
r"pooler", # 忽略名为"pooler"的键
r"cls.seq_relationship", # 忽略名为"cls.seq_relationship"的键
r"cls.predictions.decoder.weight", # 忽略名为"cls.predictions.decoder.weight"的键
r"nsp___cls", # 忽略名为"nsp___cls"的键
]

# 初始化方法,接受一个BertConfig对象作为参数,以及其他可能的输入和关键字参数
def __init__(self, config: BertConfig, *inputs, **kwargs):
    # 调用父类的初始化方法
    super().__init__(config, *inputs, **kwargs)

    # 如果配置指定为decoder,则发出警告
    if config.is_decoder:
        logger.warning(
            "If you want to use `TFBertForMaskedLM` make sure `config.is_decoder=False` for "
            "bi-directional self-attention."
        )

    # 创建TFBertMainLayer实例,用给定的配置和名称"bert",不添加池化层
    self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
    # 创建TFBertMLMHead实例,用给定的配置和输入嵌入self.bert.embeddings,名称为"mlm___cls"
    self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls")

# 返回语言模型头部(MLM头部)的Keras层对象
def get_lm_head(self) -> keras.layers.Layer:
    return self.mlm.predictions

# 获取前缀偏置名称,已弃用,发出警告
def get_prefix_bias_name(self) -> str:
    warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
    # 返回包含self.name、self.mlm.name和self.mlm.predictions.name的字符串
    return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

# 调用方法,接收多个输入参数和关键字参数,包括输入ID、注意力掩码等
@unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=TFMaskedLMOutput,
    config_class=_CONFIG_FOR_DOC,
    expected_output="'paris'",
    expected_loss=0.88,
)
def call(
    self,
    input_ids: TFModelInputType | None = None,
    attention_mask: np.ndarray | tf.Tensor | None = None,
    token_type_ids: np.ndarray | tf.Tensor | None = None,
    position_ids: np.ndarray | tf.Tensor | None = None,
    head_mask: np.ndarray | tf.Tensor | None = None,
    inputs_embeds: np.ndarray | tf.Tensor | None = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    labels: np.ndarray | tf.Tensor | None = None,
    training: Optional[bool] = False,
) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
    r"""
    labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
        config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
        loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
    """
    # 使用 TensorFlow 注解语法指定函数的返回类型,可以是 TFMaskedLMOutput 或包含 tf.Tensor 的元组
    outputs = self.bert(
        input_ids=input_ids,  # 输入的 token IDs
        attention_mask=attention_mask,  # 注意力遮罩,指示哪些 token 是真实的(1)和哪些是填充的(0)
        token_type_ids=token_type_ids,  # 用于区分两个句子的 token 类型 IDs
        position_ids=position_ids,  # 位置编码 IDs,用于指定 token 在序列中的位置
        head_mask=head_mask,  # 多头注意力机制中屏蔽的头部掩码
        inputs_embeds=inputs_embeds,  # 可选的输入嵌入,用于代替输入的 token IDs
        output_attentions=output_attentions,  # 是否返回注意力权重
        output_hidden_states=output_hidden_states,  # 是否返回隐藏状态
        return_dict=return_dict,  # 是否以字典形式返回输出
        training=training,  # 是否处于训练模式
    )
    # 获取 BERT 输出的序列输出
    sequence_output = outputs[0]
    # 将序列输出传递给 MLM 模型进行预测
    prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
    # 如果提供了标签,则计算 MLM 损失
    loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)

    # 如果不返回字典,则按照顺序构造输出元组
    if not return_dict:
        output = (prediction_scores,) + outputs[2:]  # 包含预测分数和额外的输出状态
        return ((loss,) + output) if loss is not None else output

    # 返回 TFMaskedLMOutput 类型的对象,包括损失、预测分数、隐藏状态和注意力权重
    return TFMaskedLMOutput(
        loss=loss,
        logits=prediction_scores,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

def build(self, input_shape=None):
    if self.built:
        return
    self.built = True
    # 如果已经构建过,则直接返回
    if getattr(self, "bert", None) is not None:
        with tf.name_scope(self.bert.name):  # 在 TensorFlow 中使用指定的命名空间构建 BERT 模型
            self.bert.build(None)  # 构建 BERT 模型
    if getattr(self, "mlm", None) is not None:
        with tf.name_scope(self.mlm.name):  # 在 TensorFlow 中使用指定的命名空间构建 MLM 模型
            self.mlm.build(None)  # 构建 MLM 模型

继承自 TFBertPreTrainedModel 和 TFCausalLanguageModelingLoss,实现了 BERT 语言模型的头部部分

class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
# 在从 PyTorch 模型加载到 TensorFlow 模型时,指定的可以忽略的层名称列表
_keys_to_ignore_on_load_unexpected = [
r"pooler", # 忽略名为 "pooler" 的层
r"cls.seq_relationship", # 忽略名为 "cls.seq_relationship" 的层
r"cls.predictions.decoder.weight", # 忽略名为 "cls.predictions.decoder.weight" 的层
r"nsp___cls", # 忽略名为 "nsp___cls" 的层
]

def __init__(self, config: BertConfig, *inputs, **kwargs):
    super().__init__(config, *inputs, **kwargs)

    # 如果配置不是解码器,则发出警告
    if not config.is_decoder:
        logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`")

    # 创建 BERT 主层,不添加池化层,命名为 "bert"
    self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
    # 创建 BERT MLM 头部,使用 BERT embeddings 作为输入嵌入,命名为 "mlm___cls"
    self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls")

# 返回 MLM 头部的预测层
def get_lm_head(self) -> keras.layers.Layer:
    return self.mlm.predictions

# 返回前缀偏置名称的字符串表示(已弃用)
def get_prefix_bias_name(self) -> str:
    warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
    return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

# 准备生成时的输入,处理输入的形状和注意力掩码
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
    input_shape = input_ids.shape
    # 如果没有提供注意力掩码,则创建全 1 的掩码
    if attention_mask is None:
        attention_mask = tf.ones(input_shape)

    # 如果有过去的键值对被使用,则截取输入的最后一个 token
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]

    # 返回包含输入 ids、注意力掩码和过去键值对的字典
    return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

# 调用模型时的装饰器,展开输入参数并添加代码示例的文档字符串
@unpack_inputs
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=TFCausalLMOutputWithCrossAttentions,
    config_class=_CONFIG_FOR_DOC,
)
def call(
    self,
    input_ids: TFModelInputType | None = None,
    attention_mask: np.ndarray | tf.Tensor | None = None,
    token_type_ids: np.ndarray | tf.Tensor | None = None,
    position_ids: np.ndarray | tf.Tensor | None = None,
    head_mask: np.ndarray | tf.Tensor | None = None,
    inputs_embeds: np.ndarray | tf.Tensor | None = None,
    encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
    encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
    past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    labels: np.ndarray | tf.Tensor | None = None,
    training: Optional[bool] = False,
    **kwargs,
# 如果模型已经构建完成,则直接返回,不进行重复构建
if self.built:
    return
# 将标志位设置为已构建状态
self.built = True

# 如果存在属性self.bert且不为None,则构建self.bert模型
if getattr(self, "bert", None) is not None:
    # 使用self.bert的名字作为命名空间,构建self.bert模型
    with tf.name_scope(self.bert.name):
        self.bert.build(None)

# 如果存在属性self.mlm且不为None,则构建self.mlm模型
if getattr(self, "mlm", None) is not None:
    # 使用self.mlm的名字作为命名空间,构建self.mlm模型
    with tf.name_scope(self.mlm.name):
        self.mlm.build(None)

使用装饰器添加模型文档字符串,描述带有顶部“下一个句子预测(分类)”头的Bert模型

@add_start_docstrings(
"""Bert Model with a next sentence prediction (classification) head on top.""",
BERT_START_DOCSTRING,
)
class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
# 在从PT模型加载TF模型时,指定要忽略的未预期/丢失的层名称列表
_keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"cls.predictions"]

def __init__(self, config: BertConfig, *inputs, **kwargs):
    # 调用父类构造函数初始化模型配置
    super().__init__(config, *inputs, **kwargs)

    # 初始化BERT主层,并命名为“bert”
    self.bert = TFBertMainLayer(config, name="bert")
    # 初始化下一个句子预测头部,并命名为“nsp___cls”
    self.nsp = TFBertNSPHead(config, name="nsp___cls")

# 使用装饰器解包输入和添加模型前向传播的文档字符串
@unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
# 替换返回值文档字符串,指定输出类型和配置类别
@replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def call(
    self,
    input_ids: TFModelInputType | None = None,
    attention_mask: np.ndarray | tf.Tensor | None = None,
    token_type_ids: np.ndarray | tf.Tensor | None = None,
    position_ids: np.ndarray | tf.Tensor | None = None,
    head_mask: np.ndarray | tf.Tensor | None = None,
    inputs_embeds: np.ndarray | tf.Tensor | None = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    next_sentence_label: np.ndarray | tf.Tensor | None = None,
    training: Optional[bool] = False,
    **kwargs
) -> Union[TFNextSentencePredictorOutput, Tuple[tf.Tensor]]:
    # 实现模型的前向传播逻辑,接受多个输入参数和训练标志
    pass  # 实际实现在此处被省略
# 定义一个函数,用于进行下一句预测。函数返回类型为TFNextSentencePredictorOutput或Tuple[tf.Tensor]
def __call__(
        self,
        input_ids: tf.Tensor,
        attention_mask: Optional[tf.Tensor] = None,
        token_type_ids: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        head_mask: Optional[tf.Tensor] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        training: Optional[bool] = False,
        next_sentence_label: Optional[tf.Tensor] = None,
    ) -> Union[TFNextSentencePredictorOutput, Tuple[tf.Tensor]]:
        r"""
        返回函数的说明
        示例代码
        """

        # 使用BERT模型进行预测,输出为outputs
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        
        # 得到池化后的输出
        pooled_output = outputs[1]
        # 使用NSP模型对池化后的输出进行预测,输出为seq_relationship_scores
        seq_relationship_scores = self.nsp(pooled_output=pooled_output)
        # 如果有下一个句子的标签,计算下一个句子的损失
        next_sentence_loss = (
            None
            if next_sentence_label is None
            else self.hf_compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
        )
        
        # 如果不返回字典形式的结果,返回序列的得分和其他输出
        if not return_dict:
            output = (seq_relationship_scores,) + outputs[2:]
            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
        
        # 如果返回字典形式的结果,返回字典类型的输出
        return TFNextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # 构建模型
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果已经建立BERT模型,继续构建BERT模型
        if getattr(self, "bert", None) is not None:
            with tf.name_scope(self.bert.name):
                self.bert.build(None)
        # 如果已经建立NSP模型,继续构建NSP模型
        if getattr(self, "nsp", None) is not None:
            with tf.name_scope(self.nsp.name):
                self.nsp.build(None)

定义一个带有顶部序列分类/回归头的 BERT 模型转换器(在汇总输出之上有一个线性层),例如用于 GLUE 任务

@add_start_docstrings(
"""
Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
output) e.g. for GLUE tasks.
""",
BERT_START_DOCSTRING,
)
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
# 当从 PT 模型加载 TF 模型时,带 '.' 的名称表示授权的意外/丢失的层
_keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
# 当从 PT 模型加载 TF 模型时,缺失的层的名称
_keys_to_ignore_on_load_missing = [r"dropout"]

def __init__(self, config: BertConfig, *inputs, **kwargs):
    super().__init__(config, *inputs, **kwargs)

    # 设置模型标签数
    self.num_labels = config.num_labels

    # 使用 TF 的 Bert 主层初始化 BERT 模型
    self.bert = TFBertMainLayer(config, name="bert")

    # 设置分类器的丢弃率
    classifier_dropout = (
        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
    )
    # 添加丢弃层
    self.dropout = keras.layers.Dropout(rate=classifier_dropout)
    # 定义分类器层
    self.classifier = keras.layers.Dense(
        units=config.num_labels,
        kernel_initializer=get_initializer(config.initializer_range),
        name="classifier",
    )
    # 保存配置
    self.config = config

# 调用模型的前向传播方法,用于处理输入并返回相应的输出和损失
@unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
    output_type=TFSequenceClassifierOutput,
    config_class=_CONFIG_FOR_DOC,
    expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
    expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
)
def call(
    self,
    input_ids: TFModelInputType | None = None,
    attention_mask: np.ndarray | tf.Tensor | None = None,
    token_type_ids: np.ndarray | tf.Tensor | None = None,
    position_ids: np.ndarray | tf.Tensor | None = None,
    head_mask: np.ndarray | tf.Tensor | None = None,
    inputs_embeds: np.ndarray | tf.Tensor | None = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    labels: np.ndarray | tf.Tensor | None = None,
    training: Optional[bool] = False,
    **kwargs,
):
    """
    BERT 模型的前向传播方法,处理输入并返回相应的输出和损失。

    Args:
        input_ids (TFModelInputType | None, optional): 输入的 token IDs. Defaults to None.
        attention_mask (np.ndarray | tf.Tensor | None, optional): 注意力遮罩. Defaults to None.
        token_type_ids (np.ndarray | tf.Tensor | None, optional): token 类型 IDs. Defaults to None.
        position_ids (np.ndarray | tf.Tensor | None, optional): 位置 IDs. Defaults to None.
        head_mask (np.ndarray | tf.Tensor | None, optional): 头部遮罩. Defaults to None.
        inputs_embeds (np.ndarray | tf.Tensor | None, optional): 输入的嵌入. Defaults to None.
        output_attentions (Optional[bool], optional): 是否输出注意力. Defaults to None.
        output_hidden_states (Optional[bool], optional): 是否输出隐藏状态. Defaults to None.
        return_dict (Optional[bool], optional): 是否返回字典. Defaults to None.
        labels (np.ndarray | tf.Tensor | None, optional): 标签. Defaults to None.
        training (Optional[bool], optional): 是否训练模式. Defaults to False.
        **kwargs: 其他关键字参数.

    Returns:
        TFSequenceClassifierOutput: 序列分类器输出对象.
    """
    # 处理输入以获取模型的输出
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # 调用 BERT 模型的前向传播
    outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
        **kwargs,
    )

    # 对 BERT 输出进行丢弃操作
    pooled_output = outputs[1]  # 汇总输出
    pooled_output = self.dropout(pooled_output, training=training)
    # 经过分类器层得到最终输出
    logits = self.classifier(pooled_output)

    # 准备模型的输出
    loss = None
    if labels is not None:
        if self.num_labels == 1:
            # 单标签分类任务
            loss_fn = tf.keras.losses.MeanSquaredError()
            loss = loss_fn(labels, logits)
        else:
            # 多标签分类任务
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            loss = loss_fn(labels, logits)

    if not return_dict:
        # 返回不同的输出对象
        output = (logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output

    # 返回序列分类器输出对象
    return TFSequenceClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
    r"""
    labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
        Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
        config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
        `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
    """
    # 调用BERT模型进行前向传播,获取输出
    outputs = self.bert(
        input_ids=input_ids,  # 输入的token IDs
        attention_mask=attention_mask,  # 注意力掩码,用于指示哪些token需要注意,哪些不需要
        token_type_ids=token_type_ids,  # token类型IDs,用于区分segment A和segment B
        position_ids=position_ids,  # token的位置IDs,指示每个token在序列中的位置
        head_mask=head_mask,  # 头部掩码,用于控制哪些attention头是有效的
        inputs_embeds=inputs_embeds,  # 输入的嵌入表示,代替输入的token IDs
        output_attentions=output_attentions,  # 是否输出attention权重
        output_hidden_states=output_hidden_states,  # 是否输出所有层的隐藏状态
        return_dict=return_dict,  # 返回类型,是否以字典形式返回输出
        training=training,  # 是否处于训练模式
    )
    pooled_output = outputs[1]  # 获取汇聚输出,通常是CLS token的表示
    pooled_output = self.dropout(inputs=pooled_output, training=training)  # 对汇聚输出进行dropout处理
    logits = self.classifier(inputs=pooled_output)  # 使用分类器对汇聚输出进行分类预测
    loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)  # 如果有标签,则计算损失

    if not return_dict:
        output = (logits,) + outputs[2:]  # 构造输出元组,包括logits和可能的额外输出
        return ((loss,) + output) if loss is not None else output  # 如果有损失,返回损失和输出,否则只返回输出

    # 如果return_dict为True,以TFSequenceClassifierOutput形式返回输出
    return TFSequenceClassifierOutput(
        loss=loss,  # 损失
        logits=logits,  # 预测的logits
        hidden_states=outputs.hidden_states,  # 所有层的隐藏状态
        attentions=outputs.attentions,  # 所有层的注意力权重
    )

def build(self, input_shape=None):
    if self.built:
        return
    self.built = True
    if getattr(self, "bert", None) is not None:
        with tf.name_scope(self.bert.name):
            self.bert.build(None)  # 构建BERT模型
    if getattr(self, "classifier", None) is not None:
        with tf.name_scope(self.classifier.name):
            self.classifier.build([None, None, self.config.hidden_size])  # 构建分类器模型

"""
Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
"""

继承自TFBertPreTrainedModel和TFMultipleChoiceLoss,实现Bert模型添加多选分类头部

class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):

# 当从PyTorch模型加载TF模型时,忽略的预期未知/丢失的层名称列表
_keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
# 当从PyTorch模型加载TF模型时,忽略的缺失层名称列表
_keys_to_ignore_on_load_missing = [r"dropout"]

def __init__(self, config: BertConfig, *inputs, **kwargs):
    super().__init__(config, *inputs, **kwargs)

    # 初始化Bert主层,并命名为"bert"
    self.bert = TFBertMainLayer(config, name="bert")
    # 初始化Dropout层,使用配置中的隐藏层Dropout概率
    self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
    # 初始化分类器Dense层,单元数为1,使用给定的初始化器范围初始化权重,并命名为"classifier"
    self.classifier = keras.layers.Dense(
        units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
    )
    self.config = config

@unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=TFMultipleChoiceModelOutput,
    config_class=_CONFIG_FOR_DOC,
)
# 定义模型前向传播方法,接受一系列输入参数
def call(
    self,
    input_ids: TFModelInputType | None = None,
    attention_mask: np.ndarray | tf.Tensor | None = None,
    token_type_ids: np.ndarray | tf.Tensor | None = None,
    position_ids: np.ndarray | tf.Tensor | None = None,
    head_mask: np.ndarray | tf.Tensor | None = None,
    inputs_embeds: np.ndarray | tf.Tensor | None = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    labels: np.ndarray | tf.Tensor | None = None,
    training: Optional[bool] = False,
    # 参数类型和默认值的注释
) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
    r"""
    定义函数的返回类型,可以是 TFMultipleChoiceModelOutput 或包含 tf.Tensor 的元组
    """
    if input_ids is not None:
        # 获取 input_ids 的形状列表,并取第二个维度的大小作为 num_choices
        num_choices = shape_list(input_ids)[1]
        # 获取 input_ids 的第三个维度的大小作为 seq_length
        seq_length = shape_list(input_ids)[2]
    else:
        # 如果 input_ids 为 None,则使用 inputs_embeds 的形状列表中的第二个维度作为 num_choices
        num_choices = shape_list(inputs_embeds)[1]
        # 使用 inputs_embeds 的第三个维度的大小作为 seq_length
        seq_length = shape_list(inputs_embeds)[2]

    # 如果 input_ids 不为 None,则将其重塑为 (-1, seq_length) 的形状,否则为 None
    flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None
    # 如果 attention_mask 不为 None,则将其重塑为 (-1, seq_length) 的形状,否则为 None
    flat_attention_mask = tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None
    # 如果 token_type_ids 不为 None,则将其重塑为 (-1, seq_length) 的形状,否则为 None
    flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None
    # 如果 position_ids 不为 None,则将其重塑为 (-1, seq_length) 的形状,否则为 None
    flat_position_ids = tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None
    # 如果 inputs_embeds 不为 None,则将其重塑为 (-1, seq_length, inputs_embeds 的第四个维度大小) 的形状,否则为 None
    flat_inputs_embeds = tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) if inputs_embeds is not None else None
    # 调用 BERT 模型,传递平铺后的输入及其他参数,并获取输出结果
    outputs = self.bert(
        input_ids=flat_input_ids,
        attention_mask=flat_attention_mask,
        token_type_ids=flat_token_type_ids,
        position_ids=flat_position_ids,
        head_mask=head_mask,
        inputs_embeds=flat_inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
    )
    # 从 BERT 输出中获取池化后的输出
    pooled_output = outputs[1]
    # 使用 dropout 方法对池化输出进行处理,根据 training 参数决定是否训练
    pooled_output = self.dropout(inputs=pooled_output, training=training)
    # 使用分类器对处理后的 pooled_output 进行分类预测
    logits = self.classifier(inputs=pooled_output)
    # 将 logits 重塑为 (-1, num_choices) 的形状
    reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
    # 如果 labels 不为 None,则计算损失,否则损失为 None
    loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)

    # 如果 return_dict 为 False,则返回格式化后的输出
    if not return_dict:
        output = (reshaped_logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output

    # 如果 return_dict 为 True,则返回 TFMultipleChoiceModelOutput 对象,包含损失、logits、隐藏状态和注意力分布
    return TFMultipleChoiceModelOutput(
        loss=loss,
        logits=reshaped_logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
# 定义神经网络模型的构建方法,如果已经构建过,则直接返回
def build(self, input_shape=None):
    if self.built:
        return
    # 标记模型为已构建状态
    self.built = True
    # 如果存在名为"bert"的属性,则构建BERT模型
    if getattr(self, "bert", None) is not None:
        # 在命名空间中构建BERT模型
        with tf.name_scope(self.bert.name):
            self.bert.build(None)
    # 如果存在名为"classifier"的属性,则构建分类器模型
    if getattr(self, "classifier", None) is not None:
        # 在命名空间中构建分类器模型,期望输入形状为[None, None, self.config.hidden_size]
        with tf.name_scope(self.classifier.name):
            self.classifier.build([None, None, self.config.hidden_size])

使用装饰器为类添加文档字符串,描述其作为一个在 Bert 模型上加入了一个标记分类头部的模型,用于命名实体识别 (NER) 等任务

class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
# 当从 PyTorch 模型加载到 TF 模型时,忽略的意外/缺失的层的名称列表,包含不匹配的层名
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"mlm___cls",
r"nsp___cls",
r"cls.predictions",
r"cls.seq_relationship",
]
# 当从 PyTorch 模型加载到 TF 模型时,忽略的缺失的层的名称列表,包含缺少的层名
_keys_to_ignore_on_load_missing = [r"dropout"]

def __init__(self, config: BertConfig, *inputs, **kwargs):
    # 调用父类构造函数,传递配置和其他输入参数
    super().__init__(config, *inputs, **kwargs)

    # 记录标签的数量
    self.num_labels = config.num_labels

    # 初始化 BERT 主层,禁用添加池化层,命名为 "bert"
    self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")

    # 设置分类器的丢弃率为配置中的分类器丢弃率或者隐藏层丢弃率,如果配置中未指定分类器丢弃率,则使用隐藏层丢弃率
    classifier_dropout = (
        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
    )
    # 添加一个丢弃层
    self.dropout = keras.layers.Dropout(rate=classifier_dropout)
    # 添加一个全连接层作为分类器,单元数为配置中的标签数量,初始化器使用配置中的初始化范围
    self.classifier = keras.layers.Dense(
        units=config.num_labels,
        kernel_initializer=get_initializer(config.initializer_range),
        name="classifier",
    )
    # 记录配置对象
    self.config = config

# 使用装饰器为模型的前向传播方法添加文档字符串,描述其输入和输出,以及模型的用法示例和预期输出
@unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
    output_type=TFTokenClassifierOutput,
    config_class=_CONFIG_FOR_DOC,
    expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
    expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
)
def call(
    self,
    input_ids: TFModelInputType | None = None,
    attention_mask: np.ndarray | tf.Tensor | None = None,
    token_type_ids: np.ndarray | tf.Tensor | None = None,
    position_ids: np.ndarray | tf.Tensor | None = None,
    head_mask: np.ndarray | tf.Tensor | None = None,
    inputs_embeds: np.ndarray | tf.Tensor | None = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    labels: np.ndarray | tf.Tensor | None = None,
    training: Optional[bool] = False,
) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
    r"""
    定义函数签名和返回类型注解,此函数可以返回 TFTokenClassifierOutput 或包含 tf.Tensor 的元组。
    labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
        用于计算标记分类损失的标签。索引应在 `[0, ..., config.num_labels - 1]` 范围内。
    """
    # 使用 BERT 模型处理输入数据,并获取输出结果
    outputs = self.bert(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
    )
    # 从 BERT 输出中获取序列输出
    sequence_output = outputs[0]
    # 根据训练状态应用 dropout 操作,用于防止过拟合
    sequence_output = self.dropout(inputs=sequence_output, training=training)
    # 使用分类器模型对序列输出进行分类,生成 logits
    logits = self.classifier(inputs=sequence_output)
    # 如果有标签,则计算损失,否则损失为 None
    loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

    # 如果 return_dict 为 False,则按照非字典格式返回结果
    if not return_dict:
        output = (logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output

    # 如果 return_dict 为 True,则以 TFTokenClassifierOutput 格式返回结果
    return TFTokenClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

# 构建模型的方法
def build(self, input_shape=None):
    # 如果模型已经构建完成,则直接返回
    if self.built:
        return
    # 将模型标记为已构建
    self.built = True
    # 如果存在名为 bert 的模型,则在 bert 命名空间下构建它
    if getattr(self, "bert", None) is not None:
        with tf.name_scope(self.bert.name):
            self.bert.build(None)
    # 如果存在名为 classifier 的模型,则在 classifier 命名空间下构建它
    if getattr(self, "classifier", None) is not None:
        with tf.name_scope(self.classifier.name):
            self.classifier.build([None, None, self.config.hidden_size])

"""
Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layer on top of the hidden-states output to compute span start logits and span end logits).
"""

引入函数装饰器,用于向模型添加文档字符串

@add_start_docstrings(
"""
Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layer on top of the hidden-states output to compute span start logits and span end logits).
""",
BERT_START_DOCSTRING,
)

声明 TF 模型类 TFBertForQuestionAnswering,继承自 TFBertPreTrainedModel 和 TFQuestionAnsweringLoss

class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
# 在从 PT 模型加载 TF 模型时,指定忽略的层名称正则表达式列表
_keys_to_ignore_on_load_unexpected = [
r"pooler",
r"mlm___cls",
r"nsp___cls",
r"cls.predictions",
r"cls.seq_relationship",
]

# 初始化方法,接受一个 BertConfig 对象和其他可选输入参数
def __init__(self, config: BertConfig, *inputs, **kwargs):
    super().__init__(config, *inputs, **kwargs)

    # 将配置中的标签数量赋值给实例变量 num_labels
    self.num_labels = config.num_labels

    # 创建一个 TFBertMainLayer 实例,用于 BERT 的主要层,不包含池化层,命名为 "bert"
    self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
    
    # 创建一个全连接层 Dense 实例,用于 QA 输出,指定单元数为配置中的标签数量,
    # 使用指定范围内的初始化器来初始化权重,命名为 "qa_outputs"
    self.qa_outputs = keras.layers.Dense(
        units=config.num_labels,
        kernel_initializer=get_initializer(config.initializer_range),
        name="qa_outputs",
    )
    
    # 将配置对象保存为实例变量
    self.config = config

# 使用装饰器声明 call 方法,定义模型的前向传播逻辑
@unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_QA,
    output_type=TFQuestionAnsweringModelOutput,
    config_class=_CONFIG_FOR_DOC,
    qa_target_start_index=_QA_TARGET_START_INDEX,
    qa_target_end_index=_QA_TARGET_END_INDEX,
    expected_output=_QA_EXPECTED_OUTPUT,
    expected_loss=_QA_EXPECTED_LOSS,
)
def call(
    self,
    input_ids: TFModelInputType | None = None,
    attention_mask: np.ndarray | tf.Tensor | None = None,
    token_type_ids: np.ndarray | tf.Tensor | None = None,
    position_ids: np.ndarray | tf.Tensor | None = None,
    head_mask: np.ndarray | tf.Tensor | None = None,
    inputs_embeds: np.ndarray | tf.Tensor | None = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    start_positions: np.ndarray | tf.Tensor | None = None,
    end_positions: np.ndarray | tf.Tensor | None = None,
    training: Optional[bool] = False,
) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
    r"""
    start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
        Labels for position (index) of the start of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
        are not taken into account for computing the loss.
    end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
        Labels for position (index) of the end of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
        are not taken into account for computing the loss.
    """
    outputs = self.bert(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
    )
    # 获取BERT模型的输出,包括序列输出
    sequence_output = outputs[0]
    # 将序列输出传递给QA输出层进行预测
    logits = self.qa_outputs(inputs=sequence_output)
    # 将预测的logits张量按照最后一个维度分割成起始和结束位置的预测
    start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
    # 去除多余的维度,使得张量的维度降低为2
    start_logits = tf.squeeze(input=start_logits, axis=-1)
    end_logits = tf.squeeze(input=end_logits, axis=-1)
    # 初始化损失变量为None
    loss = None

    # 如果提供了起始位置和结束位置的标签,则计算损失
    if start_positions is not None and end_positions is not None:
        labels = {"start_position": start_positions}
        labels["end_position"] = end_positions
        # 使用Hugging Face的损失计算函数计算损失
        loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))

    # 如果不要求返回字典格式的输出,则根据条件返回输出结果
    if not return_dict:
        output = (start_logits, end_logits) + outputs[2:]
        return ((loss,) + output) if loss is not None else output

    # 返回TFQuestionAnsweringModelOutput类型的输出,包括损失、起始和结束位置的logits以及隐藏状态和注意力
    return TFQuestionAnsweringModelOutput(
        loss=loss,
        start_logits=start_logits,
        end_logits=end_logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

def build(self, input_shape=None):
    # 如果模型已经构建完毕,则直接返回
    if self.built:
        return
    # 将模型标记为已构建状态
    self.built = True
    # 如果BERT模型存在,则构建BERT模型
    if getattr(self, "bert", None) is not None:
        with tf.name_scope(self.bert.name):
            self.bert.build(None)
    # 如果QA输出层存在,则构建QA输出层
    if getattr(self, "qa_outputs", None) is not None:
        with tf.name_scope(self.qa_outputs.name):
            self.qa_outputs.build([None, None, self.config.hidden_size])

# `.\models\bert\tokenization_bert.py`

```py
# 指定编码为 UTF-8

# 版权声明,版权归Google AI Language Team和HuggingFace Inc.团队所有,使用Apache License 2.0授权
# 只有在遵守许可证的情况下才能使用此文件
# 可以在以下网址获取许可证的副本:http://www.apache.org/licenses/LICENSE-2.0

# 如果适用法律要求或书面同意,软件将按“原样”分发,不提供任何明示或暗示的保证或条件
# 请参阅许可证以了解详细信息

"""Bert的标记化类。"""

# 导入所需模块
import collections  # 导入collections模块
import os  # 导入os模块
import unicodedata  # 导入unicodedata模块
from typing import List, Optional, Tuple  # 导入类型提示所需的模块

# 从tokenization_utils.py中导入预训练的标记器和一些辅助函数
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace

# 导入日志记录功能
from ...utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义词汇文件的名称,这里是一个包含词汇的文本文件
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# 预训练词汇文件的映射,这里假设只有一个vocab_file键,对应的值是vocab.txt文件名
PRETRAINED_VOCAB_FILES_MAP = {
    {
        "vocab_file": {
            "google-bert/bert-base-uncased": "https://huggingface.co/google-bert/bert-base-uncased/resolve/main/vocab.txt",
            "google-bert/bert-large-uncased": "https://huggingface.co/google-bert/bert-large-uncased/resolve/main/vocab.txt",
            "google-bert/bert-base-cased": "https://huggingface.co/google-bert/bert-base-cased/resolve/main/vocab.txt",
            "google-bert/bert-large-cased": "https://huggingface.co/google-bert/bert-large-cased/resolve/main/vocab.txt",
            "google-bert/bert-base-multilingual-uncased": (
                "https://huggingface.co/google-bert/bert-base-multilingual-uncased/resolve/main/vocab.txt"
            ),
            "google-bert/bert-base-multilingual-cased": "https://huggingface.co/google-bert/bert-base-multilingual-cased/resolve/main/vocab.txt",
            "google-bert/bert-base-chinese": "https://huggingface.co/google-bert/bert-base-chinese/resolve/main/vocab.txt",
            "google-bert/bert-base-german-cased": "https://huggingface.co/google-bert/bert-base-german-cased/resolve/main/vocab.txt",
            "google-bert/bert-large-uncased-whole-word-masking": (
                "https://huggingface.co/google-bert/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt"
            ),
            "google-bert/bert-large-cased-whole-word-masking": (
                "https://huggingface.co/google-bert/bert-large-cased-whole-word-masking/resolve/main/vocab.txt"
            ),
            "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": (
                "https://huggingface.co/google-bert/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt"
            ),
            "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": (
                "https://huggingface.co/google-bert/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt"
            ),
            "google-bert/bert-base-cased-finetuned-mrpc": (
                "https://huggingface.co/google-bert/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt"
            ),
            "google-bert/bert-base-german-dbmdz-cased": "https://huggingface.co/google-bert/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
            "google-bert/bert-base-german-dbmdz-uncased": (
                "https://huggingface.co/google-bert/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt"
            ),
            "TurkuNLP/bert-base-finnish-cased-v1": (
                "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt"
            ),
            "TurkuNLP/bert-base-finnish-uncased-v1": (
                "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt"
            ),
            "wietsedv/bert-base-dutch-cased": (
                "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt"
            ),
        }
    }
    
    
    注释:
    
    # vocab_file 是一个包含不同 BERT 模型及其对应词汇表 URL 的字典
    {
        "google-bert/bert-base-uncased": "https://huggingface.co/google-bert/bert-base-uncased/resolve/main/vocab.txt",  # Google BERT base uncased 模型的词汇表 URL
        "google-bert/bert-large-uncased": "https://huggingface.co/google-bert/bert-large-uncased/resolve/main/vocab.txt",  # Google BERT large uncased 模型的词汇表 URL
        "google-bert/bert-base-cased": "https://huggingface.co/google-bert/bert-base-cased/resolve/main/vocab.txt",  # Google BERT base cased 模型的词汇表 URL
        "google-bert/bert-large-cased": "https://huggingface.co/google-bert/bert-large-cased/resolve/main/vocab.txt",  # Google BERT large cased 模型的词汇表 URL
        "google-bert/bert-base-multilingual-uncased": (
            "https://huggingface.co/google-bert/bert-base-multilingual-uncased/resolve/main/vocab.txt"  # Google BERT base 多语言 uncased 模型的词汇表 URL
        ),
        "google-bert/bert-base-multilingual-cased": "https://huggingface.co/google-bert/bert-base-multilingual-cased/resolve/main/vocab.txt",  # Google BERT base 多语言 cased 模型的词汇表 URL
        "google-bert/bert-base-chinese": "https://huggingface.co/google-bert/bert-base-chinese/resolve/main/vocab.txt",  # Google BERT base 中文模型的词汇表 URL
        "google-bert/bert-base-german-cased": "https://huggingface.co/google-bert/bert-base-german-cased/resolve/main/vocab.txt",  # Google BERT base 德语 cased 模型的词汇表 URL
        "google-bert/bert-large-uncased-whole-word-masking": (
            "https://huggingface.co/google-bert/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt"  # Google BERT large uncased 整词屏蔽模型的词汇表 URL
        ),
        "google-bert/bert-large-cased-whole-word-masking": (
            "https://huggingface.co/google-bert/bert-large-cased-whole-word-masking/resolve/main/vocab.txt"  # Google BERT large cased 整词屏蔽模型的词汇表 URL
        ),
        "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": (
            "https://huggingface.co/google-bert/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt"  # Google BERT large uncased 整词屏蔽模型(在 SQuAD 上微调)的词汇表 URL
        ),
        "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": (
            "https://huggingface.co/google-bert/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt"  # Google BERT large cased 整词屏蔽模型(在 SQuAD 上微调)的词汇表 URL
        ),
        "google-bert/bert-base-cased-finetuned-mrpc": (
            "https://huggingface.co/google-bert/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt"  # Google BERT base cased 模型(在 MRPC 数据集上微调)的词汇表 URL
        ),
        "google-bert/bert-base-german-dbmdz-cased": "https://huggingface.co/google-bert/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",  # Google BERT base 德语(由 DBMDZ 组织提供,cased)模型的词汇表 URL
        "google-bert/bert-base-german-dbmdz-uncased": (
            "https://huggingface.co/google-bert/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt"  # Google BERT base 德语(由 DBMDZ 组织提供,uncased)模型的词汇表 URL
        ),
        "TurkuNLP/bert-base-finnish-cased-v1": (
            "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt"  # TurkuNLP 提供的芬兰语 cased BERT base v1 模型的词汇表 URL
        ),
        "TurkuNLP/bert-base-finnish-uncased-v1": (
            "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt"  # TurkuNLP 提供的芬兰语 uncased BERT base v1 模型的词汇表 URL
        ),
        "wietsedv/bert-base-dutch-cased": (
            "
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google-bert/bert-base-uncased": 512,  # 设置预训练模型的位置嵌入尺寸
    "google-bert/bert-large-uncased": 512,
    "google-bert/bert-base-cased": 512,
    "google-bert/bert-large-cased": 512,
    "google-bert/bert-base-multilingual-uncased": 512,
    "google-bert/bert-base-multilingual-cased": 512,
    "google-bert/bert-base-chinese": 512,
    "google-bert/bert-base-german-cased": 512,
    "google-bert/bert-large-uncased-whole-word-masking": 512,
    "google-bert/bert-large-cased-whole-word-masking": 512,
    "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": 512,
    "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": 512,
    "google-bert/bert-base-cased-finetuned-mrpc": 512,
    "google-bert/bert-base-german-dbmdz-cased": 512,
    "google-bert/bert-base-german-dbmdz-uncased": 512,
    "TurkuNLP/bert-base-finnish-cased-v1": 512,
    "TurkuNLP/bert-base-finnish-uncased-v1": 512,
    "wietsedv/bert-base-dutch-cased": 512,
}

PRETRAINED_INIT_CONFIGURATION = {
    "google-bert/bert-base-uncased": {"do_lower_case": True},  # 配置预训练模型初始化参数
    "google-bert/bert-large-uncased": {"do_lower_case": True},
    "google-bert/bert-base-cased": {"do_lower_case": False},
    "google-bert/bert-large-cased": {"do_lower_case": False},
    "google-bert/bert-base-multilingual-uncased": {"do_lower_case": True},
    "google-bert/bert-base-multilingual-cased": {"do_lower_case": False},
    "google-bert/bert-base-chinese": {"do_lower_case": False},
    "google-bert/bert-base-german-cased": {"do_lower_case": False},
    "google-bert/bert-large-uncased-whole-word-masking": {"do_lower_case": True},
    "google-bert/bert-large-cased-whole-word-masking": {"do_lower_case": False},
    "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
    "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
    "google-bert/bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
    "google-bert/bert-base-german-dbmdz-cased": {"do_lower_case": False},
    "google-bert/bert-base-german-dbmdz-uncased": {"do_lower_case": True},
    "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
    "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
    "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
}


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()  # 创建一个有序字典用于存储词汇表
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()  # 读取词汇文件中的所有行
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")  # 去除每个词汇的换行符
        vocab[token] = index  # 将词汇添加到字典中,键为词汇,值为索引
    return vocab  # 返回加载后的词汇表字典


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()  # 去除文本首尾空白字符
    if not text:
        return []  # 如果文本为空,则返回空列表
    tokens = text.split()  # 使用空格分割文本生成词汇列表
    return tokens  # 返回分割后的词汇列表


class BertTokenizer(PreTrainedTokenizer):
    r"""
    Construct a BERT tokenizer. Based on WordPiece.
    """
    # 从`PreTrainedTokenizer`继承,该类包含大多数主要方法。用户应参考这个超类以获取关于这些方法的更多信息。

    # 参数:
    # vocab_file (`str`):
    #     包含词汇表的文件。
    # do_lower_case (`bool`, *可选*, 默认为 `True`):
    #     在标记化时是否将输入转换为小写。
    # do_basic_tokenize (`bool`, *可选*, 默认为 `True`):
    #     是否在使用WordPiece之前进行基本的标记化。
    # never_split (`Iterable`, *可选*):
    #     在标记化时永远不会分割的一组标记。仅在 `do_basic_tokenize=True` 时有效。
    # unk_token (`str`, *可选*, 默认为 `"[UNK]"`):
    #     未知标记。词汇表中不存在的标记无法转换为ID,并将被设置为此标记。
    # sep_token (`str`, *可选*, 默认为 `"[SEP]"`):
    #     分隔符标记,在构建来自多个序列的序列时使用,例如用于序列分类或用于文本和问题的问题回答。在使用特殊标记构建的序列的最后一个标记也会使用此标记。
    # pad_token (`str`, *可选*, 默认为 `"[PAD]"`):
    #     用于填充的标记,例如在批处理不同长度的序列时使用。
    # cls_token (`str`, *可选*, 默认为 `"[CLS]"`):
    #     分类器标记,在进行序列分类(整个序列的分类而不是每个标记的分类)时使用。在使用特殊标记构建的序列的第一个标记。
    # mask_token (`str`, *可选*, 默认为 `"[MASK]"`):
    #     用于屏蔽值的标记。这是在使用掩蔽语言建模训练模型时使用的标记。模型将尝试预测此标记。
    # tokenize_chinese_chars (`bool`, *可选*, 默认为 `True`):
    #     是否标记化中文字符。
    #     对于日文,这可能应该停用(参见此[问题](https://github.com/huggingface/transformers/issues/328))。
    # strip_accents (`bool`, *可选*):
    #     是否删除所有重音符号。如果未指定此选项,则将根据 `lowercase` 的值确定(与原始BERT相同)。
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # 初始化方法,用于初始化一个Tokenizer对象
    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs,
    ):
        # 检查给定的词汇文件是否存在,如果不存在则抛出异常
        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        # 加载词汇表文件到self.vocab中
        self.vocab = load_vocab(vocab_file)
        # 根据加载的词汇表构建从id到token的有序字典
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        # 是否进行基本的tokenize操作
        self.do_basic_tokenize = do_basic_tokenize
        # 如果需要进行基本tokenize,则初始化BasicTokenizer对象
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )

        # 初始化WordpieceTokenizer对象,使用加载的词汇表和未知token
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))

        # 调用父类的初始化方法,传递相同的参数和额外的参数
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

    # 属性方法,返回是否进行小写处理的标志位
    @property
    def do_lower_case(self):
        return self.basic_tokenizer.do_lower_case

    # 属性方法,返回词汇表的大小
    @property
    def vocab_size(self):
        return len(self.vocab)

    # 方法,返回包含所有词汇和特殊token编码的字典
    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    # 方法,对文本进行tokenize操作,返回token列表
    def _tokenize(self, text, split_special_tokens=False):
        split_tokens = []
        # 如果需要进行基本tokenize操作
        if self.do_basic_tokenize:
            # 遍历基本tokenizer的tokenize结果
            for token in self.basic_tokenizer.tokenize(
                text, never_split=self.all_special_tokens if not split_special_tokens else None
            ):
                # 如果token在不分割集合中,则直接加入split_tokens列表
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    # 否则,使用WordpieceTokenizer对token进行进一步的分词处理,并加入split_tokens列表
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            # 否则,直接使用WordpieceTokenizer对整个text进行tokenize操作
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    # 方法,根据token获取对应的id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    # 方法,根据id获取对应的token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)
    def convert_tokens_to_string(self, tokens):
        """
        Converts a sequence of tokens (string) into a single string by joining them,
        removing '##' and stripping leading/trailing whitespace.

        Args:
            tokens (List[str]): List of tokens to be converted.

        Returns:
            str: The concatenated string of tokens.
        """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Builds model inputs from a sequence or a pair of sequences for sequence classification tasks
        by adding special tokens. A BERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (List[int]): List of token IDs for the first sequence.
            token_ids_1 (Optional[List[int]]): Optional list of token IDs for the second sequence.

        Returns:
            List[int]: List of input IDs with the appropriate special tokens added.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves a mask indicating whether each token in the input list is a special token
        (1 for special token, 0 for sequence token). This is used when preparing tokens for a model.

        Args:
            token_ids_0 (List[int]): List of token IDs for the first sequence.
            token_ids_1 (Optional[List[int]]): Optional list of token IDs for the second sequence.
            already_has_special_tokens (bool, optional): Whether the input token lists already include special tokens.

        Returns:
            List[int]: A list of integers representing the mask.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates token type IDs from token lists representing sequences or pairs of sequences.

        Args:
            token_ids_0 (List[int]): List of token IDs for the first sequence.
            token_ids_1 (Optional[List[int]]): Optional list of token IDs for the second sequence.

        Returns:
            List[int]: List of token type IDs.
        """
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs representing the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List representing the token type IDs for the given sequence(s).
        """
        # Define separator and classification tokens
        sep = [self.sep_token_id]  # Separator token ID
        cls = [self.cls_token_id]  # Classification token ID
        
        # If token_ids_1 is None, return a mask with zeros corresponding to the first sequence only
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]  # Create and return mask with zeros
        
        # If token_ids_1 is provided, return a mask with zeros for the first sequence and ones for the second sequence
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Initialize index counter
        index = 0
        
        # Determine vocabulary file path
        if os.path.isdir(save_directory):
            # If save_directory is a directory, construct file path inside the directory
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            # Otherwise, treat save_directory as the full file path
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        
        # Write vocabulary to the specified file
        with open(vocab_file, "w", encoding="utf-8") as writer:
            # Iterate through vocabulary items sorted by index
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                # Check for non-consecutive indices in the vocabulary
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!"
                    )
                    index = token_index  # Update index to current token's index
                writer.write(token + "\n")  # Write token to file
                index += 1  # Increment index for the next token
        
        # Return the path to the saved vocabulary file
        return (vocab_file,)
# 定义一个名为 BasicTokenizer 的类,用于执行基本的分词(如分割标点符号、转换为小写等)。
class BasicTokenizer(object):
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
            the full context of the words, such as contractions.
    """

    # 初始化方法,设置类的属性
    def __init__(
        self,
        do_lower_case=True,          # 是否将输入转换为小写,默认为True
        never_split=None,            # 永远不分割的 token 集合,默认为 None
        tokenize_chinese_chars=True, # 是否分割中文字符,默认为 True
        strip_accents=None,          # 是否去除所有重音符号,默认根据 lowercase 决定
        do_split_on_punc=True,       # 是否在基本标点符号处分割,默认为 True
    ):
        # 如果 never_split 为 None,则设为一个空列表
        if never_split is None:
            never_split = []
        # 设置实例的属性值
        self.do_lower_case = do_lower_case                  # 是否小写化输入
        self.never_split = set(never_split)                 # 永远不分割的 token 集合,转为集合类型
        self.tokenize_chinese_chars = tokenize_chinese_chars # 是否分割中文字符
        self.strip_accents = strip_accents                  # 是否去除重音符号
        self.do_split_on_punc = do_split_on_punc            # 是否在基本标点符号处分割
    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*)
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of token not to split.
        """
        # 使用 never_split 参数更新当前对象的 never_split 集合(若提供的话)
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        # 清理文本,如去除无用空白等
        text = self._clean_text(text)

        # 以下部分是为了支持多语言和中文模型而添加的代码(2018 年 11 月 1 日起)
        # 现在英语模型也应用了这一代码,但由于英语模型未经过中文数据的训练,
        # 这段代码对英语模型基本没有影响(尽管英语词汇表中包含了一些中文单词,
        # 这是因为英语维基百科中包含了一些中文词汇)。
        if self.tokenize_chinese_chars:
            # 对包含中文字符的文本进行特殊处理,分词
            text = self._tokenize_chinese_chars(text)
        # 将文本中的 Unicode 标准化为 NFC 格式(避免同一字符的不同 Unicode 编码被视为不同字符)
        unicode_normalized_text = unicodedata.normalize("NFC", text)
        # 使用空白符分割文本,得到原始 token 列表
        orig_tokens = whitespace_tokenize(unicode_normalized_text)
        split_tokens = []
        # 遍历每个原始 token
        for token in orig_tokens:
            # 如果 token 不在 never_split 集合中
            if token not in never_split:
                # 如果设置为小写处理,则将 token 转换为小写
                if self.do_lower_case:
                    token = token.lower()
                    # 如果需要去除重音符号,则执行去除重音符号的操作
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                # 如果需要去除重音符号,则执行去除重音符号的操作
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            # 将处理后的 token 通过标点符号分割函数进一步分割
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        # 使用空白符重新组合处理后的 token,并分割为最终的输出 token 列表
        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        # 返回最终的输出 token 列表
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        # 将文本中的字符标准化为 NFD 格式
        text = unicodedata.normalize("NFD", text)
        output = []
        # 遍历文本中的每个字符
        for char in text:
            # 获取字符的 Unicode 分类
            cat = unicodedata.category(char)
            # 如果字符是非组合型记号(Mn),则跳过
            if cat == "Mn":
                continue
            # 否则将字符添加到输出列表中
            output.append(char)
        # 将输出列表中的字符连接成字符串并返回
        return "".join(output)
    def _run_split_on_punc(self, text, never_split=None):
        """按照标点符号分割文本。

        Args:
            text (str): 要分割的文本。
            never_split (set): 不应该被分割的文本集合。

        Returns:
            list: 分割后的文本列表。

        """
        # 如果不需要按标点符号分割,或者文本在不分割的集合中,则直接返回原文本列表
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        # 将文本转换为字符列表
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            # 如果是标点符号,则作为新词开始
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                # 如果不是标点符号,根据start_new_word标记将字符添加到当前词列表中
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        # 将列表中的字符列表连接为字符串,并返回分割后的文本列表
        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """在每个CJK字符周围添加空格。

        Args:
            text (str): 要处理的文本。

        Returns:
            str: 处理后的文本。

        """
        output = []
        for char in text:
            cp = ord(char)
            # 如果是CJK字符,添加空格前后包裹该字符
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        # 将列表中的字符连接为一个字符串,并返回处理后的文本
        return "".join(output)

    def _is_chinese_char(self, cp):
        """检查CP是否是CJK字符的码点。

        Args:
            cp (int): 要检查的字符的Unicode码点。

        Returns:
            bool: 如果是CJK字符则返回True,否则返回False。

        """
        # 这里的CJK字符定义来自于CJK统一表意文字块的Unicode范围
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)
        ):
            return True

        return False

    def _clean_text(self, text):
        """对文本进行无效字符移除和空白字符清理。

        Args:
            text (str): 要清理的文本。

        Returns:
            str: 清理后的文本。

        """
        output = []
        for char in text:
            cp = ord(char)
            # 移除无效字符和控制字符,以及替换空白字符为单个空格
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        # 将列表中的字符连接为一个字符串,并返回清理后的文本
        return "".join(output)
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        # 初始化WordpieceTokenizer对象,设置词汇表、未知标记和单词的最大字符数
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens.
        """
        # 初始化输出token列表
        output_tokens = []
        # 使用whitespace_tokenize函数将文本分割成单词或标记
        for token in whitespace_tokenize(text):
            # 将token转换为字符列表
            chars = list(token)
            # 如果token的长度超过最大输入字符数,则将未知标记添加到输出token列表中
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            # 初始化标志变量和起始位置
            is_bad = False
            start = 0
            sub_tokens = []
            # 循环直到处理完所有字符
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                # 使用最长匹配算法找到合适的子串
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    # 如果找到了匹配词汇表的子串,则更新当前子串并跳出循环
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                # 如果未找到合适的子串,则标记为无效
                if cur_substr is None:
                    is_bad = True
                    break
                # 将找到的子串添加到sub_tokens列表中
                sub_tokens.append(cur_substr)
                start = end

            # 如果标记为无效,则将未知标记添加到输出token列表中;否则将sub_tokens列表中的token添加到输出token列表中
            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        # 返回最终的token列表
        return output_tokens

.\models\bert\tokenization_bert_fast.py

# coding=utf-8
# 上面是指定脚本的编码格式为 UTF-8

# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# 版权声明,指明了代码的版权归属

#
# Licensed under the Apache License, Version 2.0 (the "License");
# 根据 Apache License, Version 2.0 许可证,可以自由使用本代码
# you may not use this file except in compliance with the License.
# 除非遵循许可证规定,否则不能使用该文件

# You may obtain a copy of the License at
# 可以在以下链接获取许可证的副本
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是基于“原样”分发的,没有任何形式的保证或条件。
# 请参阅许可证以获取详细的权限和限制信息。

"""Fast Tokenization classes for Bert."""
# 用于 Bert 的快速标记化类

import json
from typing import List, Optional, Tuple

from tokenizers import normalizers

from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from .tokenization_bert import BertTokenizer

# 获取日志记录器对象
logger = logging.get_logger(__name__)

# 定义词汇文件的名称映射
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}

# 预训练模型所需的词汇文件映射
PRETRAINED_VOCAB_FILES_MAP = {
    {
        "vocab_file": {
            "google-bert/bert-base-uncased": "https://huggingface.co/google-bert/bert-base-uncased/resolve/main/vocab.txt",
            "google-bert/bert-large-uncased": "https://huggingface.co/google-bert/bert-large-uncased/resolve/main/vocab.txt",
            "google-bert/bert-base-cased": "https://huggingface.co/google-bert/bert-base-cased/resolve/main/vocab.txt",
            "google-bert/bert-large-cased": "https://huggingface.co/google-bert/bert-large-cased/resolve/main/vocab.txt",
            "google-bert/bert-base-multilingual-uncased": (
                "https://huggingface.co/google-bert/bert-base-multilingual-uncased/resolve/main/vocab.txt"
            ),
            "google-bert/bert-base-multilingual-cased": "https://huggingface.co/google-bert/bert-base-multilingual-cased/resolve/main/vocab.txt",
            "google-bert/bert-base-chinese": "https://huggingface.co/google-bert/bert-base-chinese/resolve/main/vocab.txt",
            "google-bert/bert-base-german-cased": "https://huggingface.co/google-bert/bert-base-german-cased/resolve/main/vocab.txt",
            "google-bert/bert-large-uncased-whole-word-masking": (
                "https://huggingface.co/google-bert/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt"
            ),
            "google-bert/bert-large-cased-whole-word-masking": (
                "https://huggingface.co/google-bert/bert-large-cased-whole-word-masking/resolve/main/vocab.txt"
            ),
            "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": (
                "https://huggingface.co/google-bert/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt"
            ),
            "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": (
                "https://huggingface.co/google-bert/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt"
            ),
            "google-bert/bert-base-cased-finetuned-mrpc": (
                "https://huggingface.co/google-bert/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt"
            ),
            "google-bert/bert-base-german-dbmdz-cased": "https://huggingface.co/google-bert/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
            "google-bert/bert-base-german-dbmdz-uncased": (
                "https://huggingface.co/google-bert/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt"
            ),
            "TurkuNLP/bert-base-finnish-cased-v1": (
                "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt"
            ),
            "TurkuNLP/bert-base-finnish-uncased-v1": (
                "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt"
            ),
            "wietsedv/bert-base-dutch-cased": (
                "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt"
            )
        }
    }
    
    
    
    # 注释:
    "vocab_file" 字典包含了多个键值对,每个键代表一个预训练的BERT模型,对应的值是该模型的词汇表(vocab.txt)的下载链接。
    这些链接可以通过Hugging Face模型中心获取,用于获取BERT模型的词汇表数据。
    {
        // Tokenizer文件的映射,键是模型名称,值是对应的Tokenizer.json文件的URL
        "tokenizer_file": {
            "google-bert/bert-base-uncased": "https://huggingface.co/google-bert/bert-base-uncased/resolve/main/tokenizer.json",
            "google-bert/bert-large-uncased": "https://huggingface.co/google-bert/bert-large-uncased/resolve/main/tokenizer.json",
            "google-bert/bert-base-cased": "https://huggingface.co/google-bert/bert-base-cased/resolve/main/tokenizer.json",
            "google-bert/bert-large-cased": "https://huggingface.co/google-bert/bert-large-cased/resolve/main/tokenizer.json",
            "google-bert/bert-base-multilingual-uncased": (
                "https://huggingface.co/google-bert/bert-base-multilingual-uncased/resolve/main/tokenizer.json"
            ),
            "google-bert/bert-base-multilingual-cased": (
                "https://huggingface.co/google-bert/bert-base-multilingual-cased/resolve/main/tokenizer.json"
            ),
            "google-bert/bert-base-chinese": "https://huggingface.co/google-bert/bert-base-chinese/resolve/main/tokenizer.json",
            "google-bert/bert-base-german-cased": "https://huggingface.co/google-bert/bert-base-german-cased/resolve/main/tokenizer.json",
            "google-bert/bert-large-uncased-whole-word-masking": (
                "https://huggingface.co/google-bert/bert-large-uncased-whole-word-masking/resolve/main/tokenizer.json"
            ),
            "google-bert/bert-large-cased-whole-word-masking": (
                "https://huggingface.co/google-bert/bert-large-cased-whole-word-masking/resolve/main/tokenizer.json"
            ),
            "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": (
                "https://huggingface.co/google-bert/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/tokenizer.json"
            ),
            "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": (
                "https://huggingface.co/google-bert/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/tokenizer.json"
            ),
            "google-bert/bert-base-cased-finetuned-mrpc": (
                "https://huggingface.co/google-bert/bert-base-cased-finetuned-mrpc/resolve/main/tokenizer.json"
            ),
            "google-bert/bert-base-german-dbmdz-cased": (
                "https://huggingface.co/google-bert/bert-base-german-dbmdz-cased/resolve/main/tokenizer.json"
            ),
            "google-bert/bert-base-german-dbmdz-uncased": (
                "https://huggingface.co/google-bert/bert-base-german-dbmdz-uncased/resolve/main/tokenizer.json"
            ),
            "TurkuNLP/bert-base-finnish-cased-v1": (
                "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/tokenizer.json"
            ),
            "TurkuNLP/bert-base-finnish-uncased-v1": (
                "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/tokenizer.json"
            ),
            "wietsedv/bert-base-dutch-cased": (
                "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/tokenizer.json"
            )
        }
    }
}

# 首先定义了一个空的类 BertTokenizerFast,该类继承自 PreTrainedTokenizerFast
class BertTokenizerFast(PreTrainedTokenizerFast):
    # docstring: 构建一个“快速”BERT tokenizer,使用 HuggingFace 的 tokenizers 库支持,基于 WordPiece
    r"""
    Construct a "fast" BERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.
    """
    # 定义类 BertTokenizer,用于处理 BERT 模型的分词器功能
    class BertTokenizer:
    
        # 类的初始化方法,用于设置分词器的各种参数和选项
        def __init__(
            self,
            vocab_file=None,  # 词汇表文件路径,用于加载模型的词汇表
            tokenizer_file=None,  # 分词器文件路径,可选,用于加载预训练的分词器模型
            do_lower_case=True,  # 是否将输入转换为小写
            unk_token="[UNK]",  # 未知标记,当词汇表中不存在某个词时使用
            sep_token="[SEP]",  # 分隔符标记,在构建多序列时使用
            pad_token="[PAD]",  # 填充标记,在对不同长度的序列进行批处理时使用
            cls_token="[CLS]",  # 分类器标记,用于序列分类任务中
            mask_token="[MASK]",  # 掩码标记,用于掩码语言模型任务中
            tokenize_chinese_chars=True,  # 是否分词中文字符
            strip_accents=None,  # 是否去除所有重音符号
            **kwargs,  # 其他参数,用于兼容未来可能添加的参数
    ):
        # 调用父类的构造函数,初始化模型的tokenizer
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        # 获取当前tokenizer的规范化器状态并转换为JSON格式
        normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
        # 检查是否有用户设置的规范化器状态与当前初始化参数不匹配,如果不匹配则进行更新
        if (
            normalizer_state.get("lowercase", do_lower_case) != do_lower_case
            or normalizer_state.get("strip_accents", strip_accents) != strip_accents
            or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
        ):
            # 获取当前规范化器的类并进行实例化
            normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
            # 更新规范化器的参数
            normalizer_state["lowercase"] = do_lower_case
            normalizer_state["strip_accents"] = strip_accents
            normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
            # 将更新后的规范化器应用于当前的tokenizer对象
            self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)

        # 更新当前对象的小写处理标志
        self.do_lower_case = do_lower_case

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        # 构建带有特殊标记的模型输入序列,用于序列分类任务
        output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]

        # 如果存在第二个序列token_ids_1,则连接第二个序列的特殊标记
        if token_ids_1 is not None:
            output += token_ids_1 + [self.sep_token_id]

        return output

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    def create_token_type_ids_from_sequences(self,
                                            token_ids_0: List[int],
                                            token_ids_1: Optional[List[int]] = None) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of token IDs representing the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of token IDs representing the second sequence in sequence-pair tasks.

        Returns:
            `List[int]`: List of token type IDs according to the given sequence(s).
        """
        # Define the separator token ID and the classification token ID
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        # If only one sequence is provided, return a mask with 0s for the first sequence
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]

        # If both sequences are provided, concatenate their lengths with separator and classification tokens
        # Return a mask with 0s for the first sequence and 1s for the second sequence
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]


    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary files associated with the tokenizer's model to a specified directory.

        Args:
            save_directory (str):
                Directory where the vocabulary files will be saved.
            filename_prefix (Optional[str]):
                Optional prefix to prepend to the saved vocabulary file names.

        Returns:
            Tuple[str]: Tuple containing the filenames of the saved vocabulary files.
        """
        # Call the model's save method to save the vocabulary files to the specified directory
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        
        # Return the filenames as a tuple
        return tuple(files)

.\models\bert\tokenization_bert_tf.py

    # 导入所需的标准库和模块
    import os
    from typing import List, Union

    # 导入 TensorFlow 库
    import tensorflow as tf
    # 导入 TensorFlow Text 库中的 BERT 分词器
    from tensorflow_text import BertTokenizer as BertTokenizerLayer
    from tensorflow_text import FastBertTokenizer, ShrinkLongestTrimmer, case_fold_utf8, combine_segments, pad_model_inputs

    # 导入自定义的 Keras 辅助函数
    from ...modeling_tf_utils import keras
    # 导入自定义的 BERT 分词器
    from .tokenization_bert import BertTokenizer

    # 定义一个 Keras 层,用于在图中进行 BERT 分词
    class TFBertTokenizer(keras.layers.Layer):
        """
        This is an in-graph tokenizer for BERT. It should be initialized similarly to other tokenizers, using the
        `from_pretrained()` method. It can also be initialized with the `from_tokenizer()` method, which imports settings
        from an existing standard tokenizer object.

        In-graph tokenizers, unlike other Hugging Face tokenizers, are actually Keras layers and are designed to be run
        when the model is called, rather than during preprocessing. As a result, they have somewhat more limited options
        than standard tokenizer classes. They are most useful when you want to create an end-to-end model that goes
        straight from `tf.string` inputs to outputs.
        """
    # 初始化函数,用于创建一个 Tokenizer 对象
    def __init__(
        self,
        vocab_list: List,                   # 词汇表列表,包含了 Tokenizer 所需的词汇
        do_lower_case: bool,                # 是否将输入文本转换为小写进行分词
        cls_token_id: int = None,           # 分类器标记的 ID,在序列分类中用作序列的第一个标记
        sep_token_id: int = None,           # 分隔符标记的 ID,在构建序列时用于多序列的分隔
        pad_token_id: int = None,           # 填充标记的 ID,在批处理不同长度的序列时使用
        padding: str = "longest",           # 填充类型,可以是"longest"或"max_length"
        truncation: bool = True,            # 是否对序列进行截断,使其不超过最大长度
        max_length: int = 512,              # 序列的最大长度,用于填充和截断
        pad_to_multiple_of: int = None,     # 如果设置,序列将填充到此值的倍数
        return_token_type_ids: bool = True, # 是否返回 token_type_ids
        return_attention_mask: bool = True, # 是否返回 attention_mask
        use_fast_bert_tokenizer: bool = True,  # 是否使用 FastBertTokenizer 类(Tensorflow Text)进行分词
        **tokenizer_kwargs,                 # 其他可能传递给 tokenizer 的参数
        ):
            super().__init__()
            # 调用父类的初始化方法

            if use_fast_bert_tokenizer:
                # 如果使用快速的 BERT 分词器
                self.tf_tokenizer = FastBertTokenizer(
                    vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case, **tokenizer_kwargs
                )
            else:
                # 否则使用静态词汇表创建查找表
                lookup_table = tf.lookup.StaticVocabularyTable(
                    tf.lookup.KeyValueTensorInitializer(
                        keys=vocab_list,
                        key_dtype=tf.string,
                        values=tf.range(tf.size(vocab_list, out_type=tf.int64), dtype=tf.int64),
                        value_dtype=tf.int64,
                    ),
                    num_oov_buckets=1,
                )
                # 使用查找表创建 BERT 分词器层
                self.tf_tokenizer = BertTokenizerLayer(
                    lookup_table, token_out_type=tf.int64, lower_case=do_lower_case, **tokenizer_kwargs
                )

            self.vocab_list = vocab_list
            self.do_lower_case = do_lower_case
            # 设置特殊 token 的索引,如果未提供则从 vocab_list 中获取
            self.cls_token_id = vocab_list.index("[CLS]") if cls_token_id is None else cls_token_id
            self.sep_token_id = vocab_list.index("[SEP]") if sep_token_id is None else sep_token_id
            self.pad_token_id = vocab_list.index("[PAD]") if pad_token_id is None else pad_token_id
            # 初始化用于截断最长序列的 paired_trimmer
            self.paired_trimmer = ShrinkLongestTrimmer(max_length - 3, axis=1)  # Allow room for special tokens
            self.max_length = max_length
            self.padding = padding
            self.truncation = truncation
            self.pad_to_multiple_of = pad_to_multiple_of
            self.return_token_type_ids = return_token_type_ids
            self.return_attention_mask = return_attention_mask
    def from_tokenizer(cls, tokenizer: "PreTrainedTokenizerBase", **kwargs):  # noqa: F821
        """
        Initialize a `TFBertTokenizer` from an existing `Tokenizer`.

        Args:
            tokenizer (`PreTrainedTokenizerBase`):
                The tokenizer to use to initialize the `TFBertTokenizer`.

        Examples:

        ```
        from transformers import AutoTokenizer, TFBertTokenizer

        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        tf_tokenizer = TFBertTokenizer.from_tokenizer(tokenizer)
        ```
        """
        # Retrieve the 'do_lower_case' parameter from kwargs; if not provided, use tokenizer's setting
        do_lower_case = kwargs.pop("do_lower_case", None)
        do_lower_case = tokenizer.do_lower_case if do_lower_case is None else do_lower_case
        # Retrieve the 'cls_token_id' parameter from kwargs; if not provided, use tokenizer's setting
        cls_token_id = kwargs.pop("cls_token_id", None)
        cls_token_id = tokenizer.cls_token_id if cls_token_id is None else cls_token_id
        # Retrieve the 'sep_token_id' parameter from kwargs; if not provided, use tokenizer's setting
        sep_token_id = kwargs.pop("sep_token_id", None)
        sep_token_id = tokenizer.sep_token_id if sep_token_id is None else sep_token_id
        # Retrieve the 'pad_token_id' parameter from kwargs; if not provided, use tokenizer's setting
        pad_token_id = kwargs.pop("pad_token_id", None)
        pad_token_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id

        # Get the vocabulary dictionary from the tokenizer and sort it by indices
        vocab = tokenizer.get_vocab()
        vocab = sorted(vocab.items(), key=lambda x: x[1])
        # Extract just the vocabulary tokens into a list
        vocab_list = [entry[0] for entry in vocab]
        # Instantiate a new TFBertTokenizer using the retrieved parameters and vocab_list
        return cls(
            vocab_list=vocab_list,
            do_lower_case=do_lower_case,
            cls_token_id=cls_token_id,
            sep_token_id=sep_token_id,
            pad_token_id=pad_token_id,
            **kwargs,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
        """
        Instantiate a `TFBertTokenizer` from a pre-trained tokenizer.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The name or path to the pre-trained tokenizer.

        Examples:

        ```
        from transformers import TFBertTokenizer

        tf_tokenizer = TFBertTokenizer.from_pretrained("google-bert/bert-base-uncased")
        ```
        """
        try:
            # Attempt to create a BertTokenizer instance from the provided pretrained_model_name_or_path
            tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
        except:  # noqa: E722
            # If the above fails, fall back to using BertTokenizerFast
            from .tokenization_bert_fast import BertTokenizerFast

            tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
        # Call from_tokenizer to create a TFBertTokenizer instance using the obtained tokenizer
        return cls.from_tokenizer(tokenizer, **kwargs)

    def unpaired_tokenize(self, texts):
        # If do_lower_case is True, convert texts to lowercase using case_fold_utf8
        if self.do_lower_case:
            texts = case_fold_utf8(texts)
        # Tokenize texts using tf_tokenizer's tokenize method
        tokens = self.tf_tokenizer.tokenize(texts)
        # Merge dimensions from 1 to -1 in tokens
        return tokens.merge_dims(1, -1)

    def call(
        self,
        text,
        text_pair=None,
        padding=None,
        truncation=None,
        max_length=None,
        pad_to_multiple_of=None,
        return_token_type_ids=None,
        return_attention_mask=None,
    # 定义一个方法,用于获取配置信息的字典
    def get_config(self):
        # 返回包含各种配置项的字典
        return {
            "vocab_list": self.vocab_list,       # 返回实例的词汇表列表
            "do_lower_case": self.do_lower_case, # 返回是否执行小写转换的布尔值
            "cls_token_id": self.cls_token_id,   # 返回类别标记的 ID
            "sep_token_id": self.sep_token_id,   # 返回分隔标记的 ID
            "pad_token_id": self.pad_token_id,   # 返回填充标记的 ID
        }
posted @ 2024-06-30 15:34  绝不原创的飞龙  阅读(6)  评论(0编辑  收藏  举报