Transformers Source Code Analysis (Part 44)

.\models\electra\modeling_flax_electra.py

# Required libraries and modules
from typing import Callable, Optional, Tuple  # type-hint helpers

import flax  # the Flax deep-learning library
import flax.linen as nn  # Flax's Linen neural-network API
import jax  # JAX, for autodiff and parallel computation
import jax.numpy as jnp  # JAX's NumPy-compatible array API
import numpy as np  # plain NumPy, for host-side numerics
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # immutable parameter dictionaries
from flax.linen import combine_masks, make_causal_mask  # mask combination and causal-mask helpers
from flax.linen import partitioning as nn_partitioning  # partitioning utilities (provides remat)
from flax.linen.attention import dot_product_attention_weights  # scaled dot-product attention weights
from flax.traverse_util import flatten_dict, unflatten_dict  # nested-dict flatten/unflatten helpers
from jax import lax  # JAX's low-level ops API

# Model output classes and utility functions
from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import (
    ACT2FN,  # mapping from activation-function names to functions
    FlaxPreTrainedModel,  # base class for Flax pretrained models
    append_call_sample_docstring,  # appends a call example to a model's docstring
    append_replace_return_docstrings,  # appends/replaces the return-value docstring
    overwrite_call_docstring,  # overwrites a model's __call__ docstring
)
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging  # output base class, docstring helpers, logging
from .configuration_electra import ElectraConfig  # the Electra configuration class

logger = logging.get_logger(__name__)  # module-level logger

_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"  # checkpoint referenced in the docstrings
_CONFIG_FOR_DOC = "ElectraConfig"  # config class referenced in the docstrings

remat = nn_partitioning.remat  # alias for Flax's rematerialization (gradient checkpointing) helper

@flax.struct.dataclass
class FlaxElectraForPreTrainingOutput(ModelOutput):
    """
    Output type of [`ElectraForPreTraining`].

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None


# Docstring shared by all Electra model classes below: it describes the inheritance from
# FlaxPreTrainedModel (download/save/weight-conversion helpers) and the supported JAX features.
ELECTRA_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`ElectraConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Docstring describing the inputs accepted by the model's forward pass
ELECTRA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in
            `[0, 1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
            `[0, config.max_position_embeddings - 1]`.
        head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class FlaxElectraEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: ElectraConfig  # the Electra model configuration (vocab size, embedding size, ...)
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Word embedding table: vocab_size x embedding_size
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        # Position embedding table: max_position_embeddings x embedding_size
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        # Token-type (segment) embedding table: type_vocab_size x embedding_size
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        # Layer normalization and dropout applied to the summed embeddings
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        # Look up word, position and token-type embeddings
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all three embeddings into the initial hidden states
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds

        # Normalize, then apply dropout
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states

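For reference, `nn.Embed` is a plain lookup table: given integer ids it returns the matching rows of a learned `(num_embeddings, features)` matrix. A minimal standalone sketch (the sizes below are arbitrary example values, not taken from any checkpoint):

import jax
import jax.numpy as jnp
import flax.linen as nn

embed = nn.Embed(
    num_embeddings=30522,  # vocabulary size (example value)
    features=128,          # embedding size (example value)
    embedding_init=jax.nn.initializers.normal(stddev=0.02),
)
ids = jnp.array([[101, 2023, 2003, 102]], dtype=jnp.int32)  # (batch=1, seq=4)
params = embed.init(jax.random.PRNGKey(0), ids)
vectors = embed.apply(params, ids)
assert vectors.shape == (1, 4, 128)  # one 128-d vector per token
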
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Electra
class FlaxElectraSelfAttention(nn.Module):
    config: ElectraConfig
    causal: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    def setup(self):
        # Dimension of each attention head
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
        # The hidden size must be evenly divisible by the number of attention heads
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )

        # Dense projection for the queries
        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        # Dense projection for the keys
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        # Dense projection for the values
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        # For causal (decoder-style) attention, precompute a causal mask
        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    # Split the hidden states into (num_heads, head_dim)
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))

    # Merge the per-head representations back into a single hidden dimension
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))

    @nn.compact
    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # Detect whether we are initializing by the absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        # Fetch (or create, zero-initialized) the cached keys
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        # Fetch (or create, zero-initialized) the cached values
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        # Fetch (or create, starting at zero) the cache index
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            # Current cache dimensions: batch dims, max length, number of heads, depth per head
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # Update the key/value caches with the new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            # Store the updated keys and values
            cached_key.value = key
            cached_value.value = value
            # Advance the cache index by the number of vectors just written
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # Causal mask for cached decoder self-attention: our single query position should only
            # attend to the key positions that have already been generated and cached, not the
            # remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            # Combine with the incoming attention mask
            attention_mask = combine_masks(pad_mask, attention_mask)
        # Return the updated keys, values and attention mask
        return key, value, attention_mask
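
The heart of the cache update above is `lax.dynamic_update_slice`, which writes a small block into a preallocated buffer at a runtime index. The same pattern in isolation (shapes are illustrative only):

import jax.numpy as jnp
from jax import lax

max_length, num_heads, head_dim = 8, 2, 4
cached_key = jnp.zeros((1, max_length, num_heads, head_dim))  # preallocated cache
new_key = jnp.ones((1, 1, num_heads, head_dim))               # key for one new token
cur_index = 3                                                 # next free slot

# Write the new key at position cur_index along the length axis
cached_key = lax.dynamic_update_slice(cached_key, new_key, (0, cur_index, 0, 0))
assert bool((cached_key[0, 3] == 1).all()) and bool((cached_key[0, 4] == 0).all())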


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Electra
class FlaxElectraSelfOutput(nn.Module):
    config: ElectraConfig  # the Electra model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Dense layer projecting back to hidden_size, with normal-initialized weights
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # LayerNorm with the configured epsilon
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # Dropout with the configured hidden dropout probability
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        # Project the attention output
        hidden_states = self.dense(hidden_states)
        # Apply dropout
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # Residual connection followed by layer normalization
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Electra
class FlaxElectraAttention(nn.Module):
    config: ElectraConfig  # the Electra model configuration
    causal: bool = False  # whether to use causal (decoder-style) attention
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Self-attention sub-layer
        self.self = FlaxElectraSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
        # Output sub-layer (dense + dropout + residual LayerNorm)
        self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype)
    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states=None,
        init_cache=False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        # The attention mask has shape (*batch_sizes, kv_length); Flax expects
        # (*batch_sizes, 1, 1, kv_length) so it can be broadcast over heads and query positions.
        # The attention outputs (hidden states and, optionally, weights) land in attn_outputs.
        attn_outputs = self.self(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            key_value_states=key_value_states,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]
        # Post-process the attention output (residual connection and layer normalization)
        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_outputs[1],)

        return outputs


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Electra
class FlaxElectraIntermediate(nn.Module):
    config: ElectraConfig  # the Electra model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Dense layer expanding to intermediate_size, with normal-initialized weights
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # Activation function selected by config.hidden_act
        self.activation = ACT2FN[self.config.hidden_act]
    def __call__(self, hidden_states):
        # Linear projection followed by the configured activation
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Electra
class FlaxElectraOutput(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Dense layer projecting the intermediate states back to hidden_size
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # Dropout to regularize the projected states
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        # LayerNorm applied after the residual connection
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        # Project back down to hidden_size
        hidden_states = self.dense(hidden_states)
        # Apply dropout
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # Residual connection with the attention output, then layer normalization
        hidden_states = self.LayerNorm(hidden_states + attention_output)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Electra
class FlaxElectraLayer(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Self-attention (causal when the config marks the model as a decoder)
        self.attention = FlaxElectraAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
        # Intermediate (feed-forward expansion) layer
        self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype)
        # Output (feed-forward projection + residual) layer
        self.output = FlaxElectraOutput(self.config, dtype=self.dtype)
        # Optional cross-attention layer for encoder-decoder setups
        if self.config.add_cross_attention:
            self.crossattention = FlaxElectraAttention(self.config, causal=False, dtype=self.dtype)
    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        # Self Attention
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attention_output = attention_outputs[0]

        # Cross-Attention Block: only runs when encoder hidden states are provided
        if encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask=encoder_attention_mask,
                layer_head_mask=layer_head_mask,
                key_value_states=encoder_hidden_states,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]

        # Feed-forward block: expand via the intermediate layer, then project back with a residual
        hidden_states = self.intermediate(attention_output)
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)

        outputs = (hidden_states,)

        # Optionally append the self-attention (and cross-attention) weights
        if output_attentions:
            outputs += (attention_outputs[1],)
            if encoder_hidden_states is not None:
                outputs += (cross_attention_outputs[1],)

        return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Electra
class FlaxElectraLayerCollection(nn.Module):
    config: ElectraConfig  # the Electra model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False  # whether to rematerialize activations during backprop

    def setup(self):
        if self.gradient_checkpointing:
            # Wrap FlaxElectraLayer with remat; arguments 5, 6 and 7 (init_cache,
            # deterministic, output_attentions) are treated as static Python values
            FlaxElectraCheckpointLayer = remat(FlaxElectraLayer, static_argnums=(5, 6, 7))
            # One checkpointed layer per hidden layer
            self.layers = [
                FlaxElectraCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]
        else:
            # One plain layer per hidden layer
            self.layers = [
                FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Accumulators for attention weights, hidden states and cross-attention weights
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # The head mask, if given, must provide one entry per layer
        if head_mask is not None:
            if head_mask.shape[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
                    f"{head_mask.shape[0]}."
                )

        # Run the input through every layer in order
        for i, layer in enumerate(self.layers):
            # Record the hidden states entering this layer, if requested
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # Forward pass of the current layer
            layer_outputs = layer(
                hidden_states,
                attention_mask,
                head_mask[i] if head_mask is not None else None,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                deterministic,
                output_attentions,
            )

            # The layer's first output becomes the next layer's input
            hidden_states = layer_outputs[0]

            # Record this layer's attention weights, if requested
            if output_attentions:
                all_attentions += (layer_outputs[1],)

                # And, if encoder states were given, its cross-attention weights too
                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # Record the final hidden states, if requested
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Assemble the outputs
        outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)

        # Without return_dict, return only the non-None entries as a tuple
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        # Otherwise wrap everything in the dedicated output class
        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
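
The `remat` wrapper used in `setup` above makes a layer recompute its activations during the backward pass instead of storing them. The primitive underneath is `jax.checkpoint`; a minimal sketch on a plain function (names here are illustrative):

import jax
import jax.numpy as jnp

def ffn(x, w1, w2):
    return jnp.dot(jax.nn.gelu(jnp.dot(x, w1)), w2)

# Recompute ffn's intermediates during backprop, trading compute for memory
ffn_remat = jax.checkpoint(ffn)

x = jnp.ones((4, 8))
w1 = jnp.ones((8, 32))
w2 = jnp.ones((32, 8))
grads = jax.grad(lambda w: ffn_remat(x, w, w2).sum())(w1)
assert grads.shape == w1.shape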


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Electra
class FlaxElectraEncoder(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False  # whether to rematerialize activations during backprop

    def setup(self):
        # The stack of Electra encoder layers
        self.layer = FlaxElectraLayerCollection(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Delegate to the layer collection
        return self.layer(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class FlaxElectraGeneratorPredictions(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # LayerNorm and dense projection used by the generator prediction head
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype)

    def __call__(self, hidden_states):
        # Generator head: dense -> activation -> LayerNorm
        hidden_states = self.dense(hidden_states)
        hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class FlaxElectraDiscriminatorPredictions(nn.Module):
    """Prediction module for the discriminator, made up of two dense layers."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Hidden projection followed by a scalar prediction layer
        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.dense_prediction = nn.Dense(1, dtype=self.dtype)

    def __call__(self, hidden_states):
        # Discriminator head: dense -> activation -> one scalar score per token
        hidden_states = self.dense(hidden_states)
        hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
        hidden_states = self.dense_prediction(hidden_states).squeeze(-1)
        return hidden_states
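
This head returns raw per-token scores. In the ELECTRA pretraining objective these scores are trained as binary logits, so turning them into "was this token replaced?" probabilities is just a sigmoid. For example:

import jax
import jax.numpy as jnp

token_scores = jnp.array([[-2.0, 0.0, 3.5]])   # (batch, seq) scores from dense_prediction
replaced_probs = jax.nn.sigmoid(token_scores)  # probability each token was replaced
assert replaced_probs.shape == (1, 3)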


class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ElectraConfig
    base_model_prefix = "electra"
    module_class: nn.Module = None

    def __init__(
        self,
        config: ElectraConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        # Instantiate the underlying Flax module with the given config and options
        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
        # Let the base class handle parameter initialization and bookkeeping
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
    def enable_gradient_checkpointing(self):
        # Re-create the module with gradient checkpointing turned on
        self._module = self.module_class(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=True,
        )

    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # Dummy input tensors used only for shape inference
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        attention_mask = jnp.ones_like(input_ids)
        head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        # Split the RNG into separate streams for parameters and dropout
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            # Initialize the module parameters, including the cross-attention weights
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            # Initialize the module parameters (no cross-attention)
            module_init_outputs = self.module.init(
                rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
            )

        # Randomly initialized parameters
        random_params = module_init_outputs["params"]

        # If pre-existing parameters were given, fill in any missing keys from the random init
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
    # Initialize the cache used for fast auto-regressive decoding
    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # Dummy inputs used to trace the cache shapes
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # Run module.init with init_cache=True so the cache variables are created
        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        # Return the (unfrozen) cache collection
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: dict = None,
    ):
        ...


class FlaxElectraModule(nn.Module):
    config: ElectraConfig  # the Electra model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False  # whether to rematerialize activations during backprop

    def setup(self):
        # Token/position/segment embeddings
        self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype)
        # If the embedding size differs from the hidden size, project embeddings up
        if self.config.embedding_size != self.config.hidden_size:
            self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        # The transformer encoder stack
        self.encoder = FlaxElectraEncoder(
            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )

    # Forward pass: embed, optionally project, then encode
    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask: Optional[np.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Compute the input embeddings
        embeddings = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        # Project embeddings to hidden_size if a projection layer exists
        if hasattr(self, "embeddings_project"):
            embeddings = self.embeddings_project(embeddings)

        # Encode the (projected) embeddings
        return self.encoder(
            embeddings,
            attention_mask,
            head_mask=head_mask,
            deterministic=deterministic,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


@add_start_docstrings(
    "The bare Electra Model transformer outputting raw hidden-states without any specific head on top.",
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraModel(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraModule


append_call_sample_docstring(FlaxElectraModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)
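
The appended call sample corresponds to the usual usage pattern, roughly:

from transformers import AutoTokenizer, FlaxElectraModel

tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
model = FlaxElectraModel.from_pretrained("google/electra-small-discriminator")

inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)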


# A dense layer whose kernel is supplied at call time, used to tie the LM head to the word embeddings
class FlaxElectraTiedDense(nn.Module):
    embedding_size: int  # size of the tied embedding dimension
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    precision = None  # optional matmul precision
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros  # bias initializer

    def setup(self):
        # Only the bias is a learned parameter here; the kernel is passed in at call time
        self.bias = self.param("bias", self.bias_init, (self.embedding_size,))

    def __call__(self, x, kernel):
        # Cast the input and the shared kernel to the computation dtype
        x = jnp.asarray(x, self.dtype)
        kernel = jnp.asarray(kernel, self.dtype)
        # Contract the last axis of x with the first axis of kernel (a matmul)
        y = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        # Add the bias in the computation dtype
        bias = jnp.asarray(self.bias, self.dtype)
        return y + bias
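
The `dimension_numbers` tuple passed to `lax.dot_general` above, `(((x.ndim - 1,), (0,)), ((), ()))`, contracts the last axis of `x` with the first axis of `kernel` and declares no batch dimensions, i.e. an ordinary matrix product. A quick check:

import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 5, 8))      # (batch, seq, embedding_size)
kernel = jnp.ones((8, 100))  # e.g. a transposed word-embedding matrix

y = lax.dot_general(x, kernel, (((x.ndim - 1,), (0,)), ((), ())))
assert y.shape == (2, 5, 100)
assert jnp.allclose(y, x @ kernel)  # same result as a plain matmul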


# Masked-LM (generator) head module for Electra
class FlaxElectraForMaskedLMModule(nn.Module):
    config: ElectraConfig  # the Electra model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False  # whether to rematerialize activations during backprop

    def setup(self):
        # The underlying Electra encoder
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        # The generator prediction head (dense -> activation -> LayerNorm)
        self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
        if self.config.tie_word_embeddings:
            # Tie the LM head to the shared word-embedding matrix
            self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
        else:
            # Otherwise use an ordinary dense LM head
            self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)
    # Forward pass: encode, run the generator head, then project onto the vocabulary
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Run the Electra encoder
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # The first output is the sequence of hidden states
        hidden_states = outputs[0]
        # Apply the generator prediction head
        prediction_scores = self.generator_predictions(hidden_states)

        if self.config.tie_word_embeddings:
            # Fetch the shared word-embedding matrix from the encoder ...
            shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            # ... and use its transpose as the LM head's kernel
            prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
        else:
            # Standalone dense LM head
            prediction_scores = self.generator_lm_head(prediction_scores)

        # Without return_dict, return a tuple of scores plus the remaining outputs
        if not return_dict:
            return (prediction_scores,) + outputs[1:]

        # Otherwise wrap the scores, hidden states and attentions in a FlaxMaskedLMOutput
        return FlaxMaskedLMOutput(
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
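
Weight tying, as used above when `config.tie_word_embeddings` is set, simply reuses the `(vocab_size, embedding_size)` word-embedding matrix (transposed) as the LM head's kernel, so no separate output projection is learned. Schematically:

import jax.numpy as jnp

vocab_size, embedding_size = 100, 8
shared_embedding = jnp.ones((vocab_size, embedding_size))  # nn.Embed stores (vocab, dim)

hidden = jnp.ones((2, 5, embedding_size))  # generator predictions
logits = hidden @ shared_embedding.T       # project back onto the vocabulary
assert logits.shape == (2, 5, vocab_size)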


@add_start_docstrings("""Electra Model with a `language modeling` head on top.""", ELECTRA_START_DOCSTRING)
class FlaxElectraForMaskedLM(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForMaskedLMModule


append_call_sample_docstring(FlaxElectraForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)


class FlaxElectraForPreTrainingModule(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Run the Electra encoder
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        # Score every position: was this token replaced by the generator?
        logits = self.discriminator_predictions(hidden_states)

        if not return_dict:
            return (logits,) + outputs[1:]

        # Otherwise wrap the logits in a FlaxElectraForPreTrainingOutput
        return FlaxElectraForPreTrainingOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.

    It is recommended to load the discriminator checkpoint into that model.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForPreTraining(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForPreTrainingModule


FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```
    >>> from transformers import AutoTokenizer, FlaxElectraForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
    >>> model = FlaxElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.logits
    ```
"""

overwrite_call_docstring(
    FlaxElectraForPreTraining,
    ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING,
)
# Overwrite the __call__ docstring of FlaxElectraForPreTraining with the inputs and example above
# Attach the return-value docstring for the pretraining output type
append_replace_return_docstrings(
    FlaxElectraForPreTraining, output_type=FlaxElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxElectraForTokenClassificationModule(nn.Module):
    config: ElectraConfig  # the Electra model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False  # whether to rematerialize activations during backprop

    def setup(self):
        # The underlying Electra encoder
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        # Use classifier_dropout if configured, otherwise fall back to hidden_dropout_prob
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # Per-token classification layer over num_labels classes
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Run the Electra encoder
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # The first output is the sequence of hidden states
        hidden_states = outputs[0]

        # Dropout, then classify each token
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        logits = self.classifier(hidden_states)

        # Without return_dict, return a tuple of logits plus the remaining outputs
        if not return_dict:
            return (logits,) + outputs[1:]

        # Otherwise wrap everything in a FlaxTokenClassifierOutput
        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForTokenClassification(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForTokenClassificationModule


append_call_sample_docstring(
    FlaxElectraForTokenClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxTokenClassifierOutput,
    _CONFIG_FOR_DOC,
)


def identity(x, **kwargs):
    # A no-op used as the default for the summary/dropout slots below
    return x


class FlaxElectraSequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
    """

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The projection defaults to a no-op
        self.summary = identity
        # If the config asks for a projection after extraction ...
        if hasattr(self.config, "summary_use_proj") and self.config.summary_use_proj:
            # ... project to num_labels classes when summary_proj_to_labels is set ...
            if (
                hasattr(self.config, "summary_proj_to_labels")
                and self.config.summary_proj_to_labels
                and self.config.num_labels > 0
            ):
                num_classes = self.config.num_labels
            else:
                # ... otherwise project back to hidden_size
                num_classes = self.config.hidden_size
            self.summary = nn.Dense(num_classes, dtype=self.dtype)

        # Optional activation applied after the projection (identity if unset)
        activation_string = getattr(self.config, "summary_activation", None)
        self.activation = ACT2FN[activation_string] if activation_string else lambda x: x  # noqa F407

        # Optional dropout before the projection
        self.first_dropout = identity
        if hasattr(self.config, "summary_first_dropout") and self.config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(self.config.summary_first_dropout)

        # Optional dropout after the projection and activation
        self.last_dropout = identity
        if hasattr(self.config, "summary_last_dropout") and self.config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(self.config.summary_last_dropout)
    def __call__(self, hidden_states, cls_index=None, deterministic: bool = True):
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`jnp.ndarray` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`jnp.ndarray` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

        Returns:
            `jnp.ndarray`: The summary of the sequence hidden states.
        """
        # Extract the first token's hidden state from each sequence (summary_type == "first")
        output = hidden_states[:, 0]

        # Apply dropout to the extracted hidden state
        output = self.first_dropout(output, deterministic=deterministic)

        # Compute the summary vector using a predefined method
        output = self.summary(output)

        # Apply an activation function to the computed summary vector
        output = self.activation(output)

        # Apply dropout to the final output vector before returning
        output = self.last_dropout(output, deterministic=deterministic)

        # Return the final summary vector
        return output
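
A minimal sketch of running this module in isolation (it assumes the class above is importable; the config values are arbitrary):

import jax
import jax.numpy as jnp
from transformers import ElectraConfig

config = ElectraConfig(hidden_size=16, num_labels=3, summary_use_proj=True, summary_proj_to_labels=True)
module = FlaxElectraSequenceSummary(config=config)

hidden_states = jnp.zeros((2, 7, 16))  # (batch, seq, hidden)
params = module.init(jax.random.PRNGKey(0), hidden_states)
summary = module.apply(params, hidden_states)  # first token -> projection
assert summary.shape == (2, 3)  # projected to num_labels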


# Module for the Flax Electra multiple-choice model
class FlaxElectraForMultipleChoiceModule(nn.Module):
    config: ElectraConfig  # the Electra model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False  # whether to rematerialize activations during backprop

    def setup(self):
        # The underlying Electra encoder
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        # Summarizes each sequence into a single vector
        self.sequence_summary = FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype)
        # Scores each choice with a single output unit
        self.classifier = nn.Dense(1, dtype=self.dtype)

    # Forward pass: flatten choices, encode, summarize, score, then unflatten (see the sketch after this class)
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Number of answer choices per example
        num_choices = input_ids.shape[1]
        # Fold the choice dimension into the batch dimension
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None

        # Run the Electra encoder on the flattened inputs
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        # Summarize each (example, choice) sequence into one vector
        pooled_output = self.sequence_summary(hidden_states, deterministic=deterministic)
        # One score per (example, choice) pair
        logits = self.classifier(pooled_output)

        # Unfold back to one row of choice scores per example
        reshaped_logits = logits.reshape(-1, num_choices)

        # Without return_dict, return a tuple of logits plus the remaining outputs
        if not return_dict:
            return (reshaped_logits,) + outputs[1:]

        # Otherwise wrap everything in a FlaxMultipleChoiceModelOutput
        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
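
The reshaping around the encoder is the standard multiple-choice trick: choices are folded into the batch dimension before encoding and unfolded after scoring. In shapes:

import jax.numpy as jnp

batch_size, num_choices, seq_len = 2, 4, 7
input_ids = jnp.zeros((batch_size, num_choices, seq_len), dtype=jnp.int32)

# Fold choices into the batch axis before running the encoder
flat_ids = input_ids.reshape(-1, input_ids.shape[-1])
assert flat_ids.shape == (batch_size * num_choices, seq_len)

# The classifier then yields one score per (example, choice) pair ...
logits = jnp.zeros((batch_size * num_choices, 1))
# ... which is unfolded back to one row of choice scores per example
reshaped_logits = logits.reshape(-1, num_choices)
assert reshaped_logits.shape == (batch_size, num_choices)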


@add_start_docstrings(
    """
    ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForMultipleChoice(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForMultipleChoiceModule


# Adapt the inputs docstring to the (batch_size, num_choices, sequence_length) shape
overwrite_call_docstring(
    FlaxElectraForMultipleChoice, ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
# Attach a call example to the class docstring
append_call_sample_docstring(
    FlaxElectraForMultipleChoice,
    _CHECKPOINT_FOR_DOC,
    FlaxMultipleChoiceModelOutput,
    _CONFIG_FOR_DOC,
)


# Module for the Flax Electra question-answering model
class FlaxElectraForQuestionAnsweringModule(nn.Module):
    config: ElectraConfig  # the Electra model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False  # whether to rematerialize activations during backprop

    def setup(self):
        # The underlying Electra encoder
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        # Projects each token's hidden state to num_labels (= 2: start and end) scores
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
    
    # Forward pass: encode, then split per-token scores into start/end logits
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Run the Electra encoder
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        # Per-token scores for span start and span end
        logits = self.qa_outputs(hidden_states)
        # Split the last axis into start and end logits (num_labels == 2)
        start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
        # Drop the trailing singleton dimension
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        # Without return_dict, return a tuple of logits plus the remaining outputs
        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        # Otherwise wrap everything in a FlaxQuestionAnsweringModelOutput
        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
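
With `num_labels == 2`, the `split`/`squeeze` pair above turns the `(batch, seq, 2)` logits into two `(batch, seq)` arrays: one score per position for the span start and one for the span end. A quick standalone check of the shapes:

import jax.numpy as jnp

num_labels = 2
logits = jnp.zeros((2, 7, num_labels))  # (batch, seq, 2)
start_logits, end_logits = logits.split(num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)  # (batch, seq)
end_logits = end_logits.squeeze(-1)
assert start_logits.shape == end_logits.shape == (2, 7)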


@add_start_docstrings(
    """
    ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForQuestionAnswering(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForQuestionAnsweringModule

append_call_sample_docstring(
    FlaxElectraForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)


class FlaxElectraClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Initialize a fully connected layer with hidden_size neurons
        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        
        # Determine dropout rate based on config values
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        # Apply dropout with computed rate
        self.dropout = nn.Dropout(classifier_dropout)
        
        # Final output layer with num_labels neurons
        self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic: bool = True):
        # Take the representation of the first token (<s>, equivalent to [CLS])
        x = hidden_states[:, 0, :]
        
        # Apply dropout to the extracted token representation
        x = self.dropout(x, deterministic=deterministic)
        
        # Pass through the fully connected layer
        x = self.dense(x)
        
        # Apply GELU activation (where BERT uses tanh here, ELECTRA uses GELU)
        x = ACT2FN["gelu"](x)
        
        # Apply dropout again
        x = self.dropout(x, deterministic=deterministic)
        
        # Pass through the output layer
        x = self.out_proj(x)
        
        # Return the logits for sequence classification
        return x


class FlaxElectraForSequenceClassificationModule(nn.Module):
    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        # Initialize Electra module with specified configuration
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        
        # Initialize classification head using the same configuration
        self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Run the Electra backbone
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        # Compute classification logits from the hidden states
        logits = self.classifier(hidden_states, deterministic=deterministic)

        # Without return_dict, return a tuple of logits plus the remaining outputs
        if not return_dict:
            return (logits,) + outputs[1:]

        # Otherwise return a FlaxSequenceClassifierOutput with logits, hidden states and attentions
        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Electra Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForSequenceClassificationModule


# Append a sample call docstring to FlaxElectraForSequenceClassification
append_call_sample_docstring(
    FlaxElectraForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)
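
A minimal usage sketch for the sequence-classification head (checkpoint and text are illustrative; the classification head is randomly initialized unless a fine-tuned checkpoint is used):

import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxElectraForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
model = FlaxElectraForSequenceClassification.from_pretrained("google/electra-small-discriminator")

inputs = tokenizer("This movie was great!", return_tensors="np")
logits = model(**inputs).logits
predicted_class = jnp.argmax(logits, axis=-1)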

# Flax module implementing a causal language model on top of Electra
class FlaxElectraForCausalLMModule(nn.Module):
    config: ElectraConfig  # model configuration
    dtype: jnp.dtype = jnp.float32  # computation dtype, defaults to float32
    gradient_checkpointing: bool = False  # whether to use gradient checkpointing

    def setup(self):
        # Electra backbone built from the config, dtype and gradient-checkpointing flag
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        # Generator prediction head
        self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
        # With tied word embeddings, use a tied dense LM head; otherwise a plain nn.Dense
        if self.config.tie_word_embeddings:
            self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
        else:
            self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask: Optional[jnp.ndarray] = None,
        token_type_ids: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Run the Electra backbone
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Last-layer hidden states from the backbone
        hidden_states = outputs[0]
        # Project the hidden states with the generator prediction head
        prediction_scores = self.generator_predictions(hidden_states)

        if self.config.tie_word_embeddings:
            # Reuse the word-embedding matrix as the LM head weights
            shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
        else:
            prediction_scores = self.generator_lm_head(prediction_scores)

        if not return_dict:
            return (prediction_scores,) + outputs[1:]

        # Return logits together with hidden states, attentions and cross-attentions
        return FlaxCausalLMOutputWithCrossAttentions(
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
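
The tied LM head above reuses the backbone's word-embedding matrix (transposed) as the output projection, so no separate vocabulary-sized weight matrix is learned. A minimal sketch of the idea, with illustrative names rather than the actual FlaxElectraTiedDense implementation:

import jax.numpy as jnp

def tied_lm_head(hidden_states, embedding_matrix, bias):
    # embedding_matrix has shape (vocab_size, embedding_size); its transpose
    # maps embedding-sized hidden states back onto vocabulary logits
    return jnp.dot(hidden_states, embedding_matrix.T) + bias
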
@add_start_docstrings(
    """
    Electra Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g. for
    autoregressive tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->Electra
class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # Initialize the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)

        # Note that usually one would have to put 0s in the attention_mask for
        # positions in [input_ids.shape[-1], cache_length), but since the decoder
        # uses a causal mask, those positions are already masked out. We can thus
        # create a single static attention_mask here, which is more efficient for
        # compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            # Derive position ids from the cumulative sum of the attention mask
            position_ids = attention_mask.cumsum(axis=-1) - 1
            # Write the provided attention mask into the static one
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            # Without an attention mask, broadcast sequential position ids
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # Update the model kwargs for the next generation step
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs


# Append a sample call docstring to FlaxElectraForCausalLM
append_call_sample_docstring(
    FlaxElectraForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutputWithCrossAttentions,
    _CONFIG_FOR_DOC,
)
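
A minimal usage sketch for the causal LM (the generator checkpoint is an illustrative choice; ELECTRA was not pretrained autoregressively, so logits from an unadapted checkpoint are of limited quality):

from transformers import AutoTokenizer, FlaxElectraForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/electra-small-generator")
model = FlaxElectraForCausalLM.from_pretrained("google/electra-small-generator")

inputs = tokenizer("Hello, my dog is", return_tensors="np")
outputs = model(**inputs)
# Vocabulary logits for every input position
print(outputs.logits.shape)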

.\models\electra\modeling_tf_electra.py

# coding=utf-8
# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF Electra model."""


from __future__ import annotations

import math  # standard math utilities
import warnings  # warning helpers
from dataclasses import dataclass  # dataclass for structured output classes
from typing import Optional, Tuple, Union  # typing helpers

import numpy as np  # NumPy
import tensorflow as tf  # TensorFlow

from ...activations_tf import get_tf_activation  # resolve TF activation functions by name
from ...modeling_tf_outputs import (  # TF model output classes
    TFBaseModelOutputWithPastAndCrossAttentions,
    TFMaskedLMOutput,
    TFMultipleChoiceModelOutput,
    TFQuestionAnsweringModelOutput,
    TFSequenceClassifierOutput,
    TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (  # TF modeling utilities and base classes
    TFMaskedLanguageModelingLoss,
    TFModelInputType,
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFSequenceSummary,
    TFTokenClassificationLoss,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import (  # low-level TF helpers
    check_embeddings_within_bounds,
    shape_list,
    stable_softmax,
)
from ...utils import (  # generic utilities
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_electra import ElectraConfig  # Electra configuration class


logger = logging.get_logger(__name__)  # module-level logger

_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"  # checkpoint used in the docs
_CONFIG_FOR_DOC = "ElectraConfig"  # config class name used in the docs


TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [  # archive of pretrained checkpoints
    "google/electra-small-generator",
    "google/electra-base-generator",
    "google/electra-large-generator",
    "google/electra-small-discriminator",
    "google/electra-base-discriminator",
    "google/electra-large-discriminator",
    # See all ELECTRA models at https://huggingface.co/models?filter=electra
]


# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Electra
class TFElectraSelfAttention(keras.layers.Layer):
    def __init__(self, config: ElectraConfig, **kwargs):
        # Call the parent initializer, forwarding extra keyword arguments
        super().__init__(**kwargs)

        # The hidden size must be divisible by the number of attention heads
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        # Number of attention heads and per-head size
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        # Dense layers projecting the hidden states to queries, keys and values
        self.query = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        # Dropout applied to the attention probabilities
        self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)

        # Whether this layer is used as a decoder
        self.is_decoder = config.is_decoder
        self.config = config

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # Reshape from [batch_size, seq_length, all_head_size]
        # to [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # Transpose to [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ):
        # The full scaled dot-product attention forward pass is elided in this listing.
        ...

    def build(self, input_shape=None):
        # Skip if the layer has already been built
        if self.built:
            return
        self.built = True
        # Build the query, key and value projection layers
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])


# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Electra
class TFElectraSelfOutput(keras.layers.Layer):
    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        # Dense layer projecting back to hidden_size
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # Layer normalization applied after the residual connection
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # Dropout to regularize the hidden states
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        # Project the hidden states
        hidden_states = self.dense(inputs=hidden_states)
        # Apply dropout (only active during training)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        # Residual connection followed by layer normalization
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Build the dense layer with its expected input shape
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        # Build the LayerNorm layer with its expected input shape
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])


# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Electra
class TFElectraAttention(keras.layers.Layer):
    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        # Self-attention sub-layer
        self.self_attention = TFElectraSelfAttention(config, name="self")
        # Output sub-layer applied on top of the self-attention output
        self.dense_output = TFElectraSelfOutput(config, name="output")

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(
        self,
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # Run self-attention on the input tensor
        self_outputs = self.self_attention(
            hidden_states=input_tensor,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        # Apply the output sub-layer to the attention output and the original input
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
        )
        # Prepend the attention output to the remaining outputs (e.g. attention weights)
        outputs = (attention_output,) + self_outputs[1:]

        return outputs

    def build(self, input_shape=None):
        # Skip if the layer has already been built
        if self.built:
            return
        self.built = True
        # Build the self-attention sub-layer if present
        if getattr(self, "self_attention", None) is not None:
            with tf.name_scope(self.self_attention.name):
                self.self_attention.build(None)
        # Build the output sub-layer if present
        if getattr(self, "dense_output", None) is not None:
            with tf.name_scope(self.dense_output.name):
                self.dense_output.build(None)


# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Electra
class TFElectraIntermediate(keras.layers.Layer):
    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        # Dense layer expanding to the intermediate size
        self.dense = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        # Resolve the activation function; strings are looked up via get_tf_activation
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.intermediate_act_fn = config.hidden_act
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # Project the hidden states to the intermediate size
        hidden_states = self.dense(inputs=hidden_states)
        # Apply the configured activation function
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states

    def build(self, input_shape=None):
        # Skip if the layer has already been built
        if self.built:
            return
        self.built = True
        # Build the dense layer with its expected input shape
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])


# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Electra
class TFElectraOutput(keras.layers.Layer):
    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        # Dense layer projecting from the intermediate size back to the hidden size
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # Layer normalization applied after the residual connection
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # Dropout to regularize the hidden states
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        # Project the hidden states
        hidden_states = self.dense(inputs=hidden_states)
        # Apply dropout (only active during training)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        # Residual connection followed by layer normalization
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)

        return hidden_states

    def build(self, input_shape=None):
        # Skip if the layer has already been built
        if self.built:
            return
        self.built = True
        # Build the dense layer; its input has the intermediate size
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.intermediate_size])
        # Build the LayerNorm layer with its expected input shape
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])


# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Electra
class TFElectraLayer(keras.layers.Layer):
    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        # Self-attention block
        self.attention = TFElectraAttention(config, name="attention")

        # Whether this layer is used as a decoder
        self.is_decoder = config.is_decoder

        # Whether a cross-attention block should be added
        self.add_cross_attention = config.add_cross_attention

        # Cross-attention only makes sense for decoder models
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")

            # Cross-attention block
            self.crossattention = TFElectraAttention(config, name="crossattention")

        # Intermediate (feed-forward) block
        self.intermediate = TFElectraIntermediate(config, name="intermediate")

        # Output block
        self.bert_output = TFElectraOutput(config, name="output")

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor | None,
        encoder_attention_mask: tf.Tensor | None,
        past_key_value: Tuple[tf.Tensor] | None,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # The cached self-attention key/value tuple occupies positions 1 and 2 of past_key_value
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # Run the self-attention block
        self_attention_outputs = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_value=self_attn_past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = self_attention_outputs[0]

        # If this layer is a decoder, the last output is the self-attention key/value cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            # Otherwise keep the self-attention weights (if any) in the outputs
            outputs = self_attention_outputs[1:]

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            # Cross-attention requires the layer to have been created with add_cross_attention=True
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # The cached cross-attention key/value tuple occupies the last two positions of past_key_value
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            # Run the cross-attention block on the self-attention output
            cross_attention_outputs = self.crossattention(
                input_tensor=attention_output,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            attention_output = cross_attention_outputs[0]
            # Add the cross-attention weights to the outputs
            outputs = outputs + cross_attention_outputs[1:-1]

            # Append the cross-attention key/value cache to the present cache
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # Feed-forward block followed by the output block with residual connection
        intermediate_output = self.intermediate(hidden_states=attention_output)
        layer_output = self.bert_output(
            hidden_states=intermediate_output, input_tensor=attention_output, training=training
        )
        outputs = (layer_output,) + outputs

        # If this layer is a decoder, return the key/value cache as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs
    def build(self, input_shape=None):
        # Skip if the layer has already been built
        if self.built:
            return
        self.built = True

        # Build the self-attention block if present
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)

        # Build the intermediate block if present
        if getattr(self, "intermediate", None) is not None:
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)

        # Build the output block if present
        if getattr(self, "bert_output", None) is not None:
            with tf.name_scope(self.bert_output.name):
                self.bert_output.build(None)

        # Build the cross-attention block if present
        if getattr(self, "crossattention", None) is not None:
            with tf.name_scope(self.crossattention.name):
                self.crossattention.build(None)


# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Electra
class TFElectraEncoder(keras.layers.Layer):
    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        # Build the stack of Electra layers, each named "layer_._{i}"
        self.layer = [TFElectraLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor | None,
        encoder_attention_mask: tf.Tensor | None,
        past_key_values: Tuple[Tuple[tf.Tensor]] | None,
        use_cache: Optional[bool],
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
        # Tuples accumulating hidden states and attentions; None when not requested
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        # Cache of key/value states for the next decoding step, if requested
        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            # Record the hidden states entering this layer if requested
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Fetch this layer's cached key/value states, if any
            past_key_value = past_key_values[i] if past_key_values is not None else None

            # Run the layer
            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_outputs[0]  # the layer's first output is the new hidden states

            # Accumulate the layer's key/value cache if caching is enabled
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)

            # Accumulate attention weights if requested
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
                # Also accumulate cross-attention weights when cross-attention is configured
                if self.config.add_cross_attention and encoder_hidden_states is not None:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # Record the final hidden states if requested
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Without return_dict, return a tuple of the non-None values
        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
            )

        # Otherwise return a structured output with cache and cross-attentions
        return TFBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
    # 定义一个方法 `build`,用于构建神经网络层
    def build(self, input_shape=None):
        # 如果已经构建过网络,则直接返回
        if self.built:
            return
        # 将标志位 `built` 设为 True,表示网络已构建
        self.built = True
        # 如果属性 `layer` 存在
        if getattr(self, "layer", None) is not None:
            # 遍历每个层对象
            for layer in self.layer:
                # 在 TensorFlow 的命名空间中,按层的名称设置命名空间
                with tf.name_scope(layer.name):
                    # 调用每个层对象的 `build` 方法来构建层,参数为 None
                    layer.build(None)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Electra
class TFElectraPooler(keras.layers.Layer):
    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        # Initialize a dense layer for pooling with specified hidden size, tanh activation, and initializer
        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # Extract the hidden state of the first token for pooling
        first_token_tensor = hidden_states[:, 0]
        # Apply the dense layer to the first token's hidden state for pooling
        pooled_output = self.dense(inputs=first_token_tensor)

        return pooled_output

    def build(self, input_shape=None):
        # Check if layer is already built
        if self.built:
            return
        self.built = True
        # Build the dense layer with specified input shape and hidden size from config
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])


# Copied from transformers.models.albert.modeling_tf_albert.TFAlbertEmbeddings with Albert->Electra
class TFElectraEmbeddings(keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embedding_size = config.embedding_size
        self.max_position_embeddings = config.max_position_embeddings
        self.initializer_range = config.initializer_range
        # Layer normalization for embeddings with specified epsilon
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # Dropout layer for embeddings with specified dropout rate
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def build(self, input_shape=None):
        # Build word embeddings with vocab size and embedding size
        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        # Build token type embeddings with type vocab size and embedding size
        with tf.name_scope("token_type_embeddings"):
            self.token_type_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.config.type_vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        # Build position embeddings with max position and embedding size
        with tf.name_scope("position_embeddings"):
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        # Check if layer is already built
        if self.built:
            return
        self.built = True
        # Build layer normalization for embeddings with specified input shape and embedding size
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.embedding_size])
    # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
    def call(
        self,
        input_ids: tf.Tensor = None,  # token id tensor
        position_ids: tf.Tensor = None,  # position id tensor
        token_type_ids: tf.Tensor = None,  # token type id tensor
        inputs_embeds: tf.Tensor = None,  # precomputed input embeddings
        past_key_values_length=0,  # length of cached key/values, defaults to 0
        training: bool = False,  # whether we are in training mode
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Need to provide either `input_ids` or `input_embeds`.")

        if input_ids is not None:
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        # Shape of the input embeddings without the embedding dimension
        input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            # Default token type ids are all zeros
            token_type_ids = tf.fill(dims=input_shape, value=0)

        if position_ids is None:
            # Positions start at past_key_values_length so cached decoding continues the sequence
            position_ids = tf.expand_dims(
                tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
            )

        # Look up position and token-type embeddings
        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)

        # Sum word, position and token-type embeddings
        final_embeddings = inputs_embeds + position_embeds + token_type_embeds

        # Layer normalization followed by dropout (active only during training)
        final_embeddings = self.LayerNorm(inputs=final_embeddings)
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        return final_embeddings
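
For intuition, position ids in the embedding layer start at `past_key_values_length`, so cached decoding continues the position sequence instead of restarting at zero. A standalone sketch with illustrative values:

import tensorflow as tf

past_key_values_length = 3  # three tokens already in the cache (assumed)
new_tokens = 2              # tokens fed in this step

position_ids = tf.expand_dims(
    tf.range(start=past_key_values_length, limit=new_tokens + past_key_values_length), axis=0
)
print(position_ids.numpy())  # [[3 4]]
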
class TFElectraDiscriminatorPredictions(keras.layers.Layer):
    # ELECTRA discriminator prediction head (a Keras layer)
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        # Dense layer keeping the hidden size
        self.dense = keras.layers.Dense(config.hidden_size, name="dense")

        # Dense layer producing a single score per token
        self.dense_prediction = keras.layers.Dense(1, name="dense_prediction")

        self.config = config

    def call(self, discriminator_hidden_states, training=False):
        # Project the discriminator hidden states
        hidden_states = self.dense(discriminator_hidden_states)

        # Apply the activation function from the config
        hidden_states = get_tf_activation(self.config.hidden_act)(hidden_states)

        # Squeeze away the trailing singleton dimension of the per-token scores
        logits = tf.squeeze(self.dense_prediction(hidden_states), -1)

        return logits

    def build(self, input_shape=None):
        if self.built:
            return

        self.built = True

        # Build the dense layer if present
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])

        # Build the prediction layer if present
        if getattr(self, "dense_prediction", None) is not None:
            with tf.name_scope(self.dense_prediction.name):
                self.dense_prediction.build([None, None, self.config.hidden_size])


class TFElectraGeneratorPredictions(keras.layers.Layer):
    # ELECTRA generator prediction head (a Keras layer)
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        # Layer normalization with the configured epsilon
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")

        # Dense layer projecting down to the embedding size
        self.dense = keras.layers.Dense(config.embedding_size, name="dense")

        self.config = config

    def call(self, generator_hidden_states, training=False):
        # Project the generator hidden states to the embedding size
        hidden_states = self.dense(generator_hidden_states)

        # Apply the GELU activation
        hidden_states = get_tf_activation("gelu")(hidden_states)

        # Normalize the activated hidden states
        hidden_states = self.LayerNorm(hidden_states)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return

        self.built = True

        # Build the LayerNorm layer if present
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.embedding_size])

        # Build the dense layer if present
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])


class TFElectraPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Default configuration class
    config_class = ElectraConfig

    # Prefix of the base model
    base_model_prefix = "electra"

    # Keys to ignore when unexpected weights are loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"generator_lm_head.weight"]

    # Keys to ignore when weights are missing on load
    _keys_to_ignore_on_load_missing = [r"dropout"]


@keras_serializable
class TFElectraMainLayer(keras.layers.Layer):
    # ELECTRA main layer (a Keras layer)
    config_class = ElectraConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        # Store the configuration and the decoder flag
        self.config = config
        self.is_decoder = config.is_decoder

        # Embedding layer, named "embeddings"
        self.embeddings = TFElectraEmbeddings(config, name="embeddings")

        # If the embedding size differs from the hidden size, project the embeddings up
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = keras.layers.Dense(config.hidden_size, name="embeddings_project")

        # Transformer encoder, named "encoder"
        self.encoder = TFElectraEncoder(config, name="encoder")

    def get_input_embeddings(self):
        # Return the embedding layer
        return self.embeddings

    def set_input_embeddings(self, value):
        # Replace the embedding weights and update the vocabulary size accordingly
        self.embeddings.weight = value
        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError

    def get_extended_attention_mask(self, attention_mask, input_shape, dtype, past_key_values_length=0):
        batch_size, seq_length = input_shape

        # Without an explicit attention mask, attend to every position
        if attention_mask is None:
            attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)

        # Build a broadcastable 4D attention mask from the 2D mask:
        # [batch_size, 1, 1, to_seq_length] broadcasts to
        # [batch_size, num_heads, from_seq_length, to_seq_length].
        # This is simpler than the triangular masking used in OpenAI GPT; we only
        # need to prepare the broadcast dimensions here.
        attention_mask_shape = shape_list(attention_mask)

        mask_seq_length = seq_length + past_key_values_length
        # Copied from `modeling_tf_t5.py`:
        # given a padding mask of shape [batch_size, mask_seq_length],
        # - for a decoder, combine it with a causal mask
        # - for an encoder, make it broadcastable to
        #   [batch_size, num_heads, mask_seq_length, mask_seq_length]
        if self.is_decoder:
            seq_ids = tf.range(mask_seq_length)
            causal_mask = tf.less_equal(
                tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                seq_ids[None, :, None],
            )
            causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
            extended_attention_mask = causal_mask * attention_mask[:, None, :]
            attention_mask_shape = shape_list(extended_attention_mask)
            extended_attention_mask = tf.reshape(
                extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
            )
            # With cached key/values, keep only the mask rows for the new positions
            if past_key_values_length > 0:
                extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
        else:
            # For an encoder, reshape the mask to [batch_size, 1, 1, attention_mask_shape[1]]
            extended_attention_mask = tf.reshape(
                attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
            )

        # Convert the mask to the requested dtype and turn it into an additive mask:
        # positions with 1.0 become 0.0 and positions with 0.0 become -10000.0,
        # which suppresses unattended positions before the softmax.
        extended_attention_mask = tf.cast(extended_attention_mask, dtype=dtype)
        one_cst = tf.constant(1.0, dtype=dtype)
        ten_thousand_cst = tf.constant(-10000.0, dtype=dtype)
        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

        return extended_attention_mask

    def get_head_mask(self, head_mask):
        # Head masking is not implemented for this model
        if head_mask is not None:
            raise NotImplementedError
        else:
            # Without a head mask, use one None entry per hidden layer
            head_mask = [None] * self.config.num_hidden_layers

        return head_mask

    # Unpack the inputs with the unpack_inputs decorator
    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ):
        # The full forward-pass body is elided in this listing.
        ...

    def build(self, input_shape=None):
        # Skip if the model has already been built
        if self.built:
            return
        self.built = True
        # Build the embeddings layer if present
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        # Build the encoder if present
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        # Build the embeddings projection with its expected input shape if present
        if getattr(self, "embeddings_project", None) is not None:
            with tf.name_scope(self.embeddings_project.name):
                self.embeddings_project.build([None, None, self.config.embedding_size])
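
The additive-mask trick in `get_extended_attention_mask` converts a {0, 1} padding mask into {-10000, 0}, so masked positions are effectively removed before the attention softmax. A standalone numeric sketch:

import tensorflow as tf

attention_mask = tf.constant([[1.0, 1.0, 0.0]])  # the last position is padding
additive_mask = (1.0 - attention_mask) * -10000.0
print(additive_mask.numpy())  # [[-0. -0. -10000.]]
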
@dataclass
class TFElectraForPreTrainingOutput(ModelOutput):
    """
    Output type of [`TFElectraForPreTraining`].

    Args:
        loss (*optional*, returned when `labels` is provided, `tf.Tensor` of shape `(1,)`):
            Total loss of the ELECTRA objective.
        logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Prediction scores of the head (scores for each token before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None


ELECTRA_START_DOCSTRING = r"""

    此模型继承自 [`TFPreTrainedModel`]。查看超类文档以获取库实现的所有模型的通用方法(如下载或保存、调整输入嵌入、修剪头部等)。

    此模型还是 [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) 的子类。将其视为常规的 TF 2.0 Keras 模型,并参考 TF 2.0 文档,了解有关一般用法和行为的所有内容。

    <Tip>

    `transformers` 中的 TensorFlow 模型和层接受两种输入格式:

    - 将所有输入作为关键字参数(类似于 PyTorch 模型);
    - 将所有输入作为列表、元组或字典的第一个位置参数。

    支持第二种格式的原因是,当传递输入给模型和层时,Keras 方法更喜欢此格式。由于这种支持,在使用诸如 `model.fit()` 等方法时,您应该能够“只需传递”您的输入和标签 - 只需使用 `model.fit()` 支持的任何格式!但是,如果您想在 Keras 方法如 `fit()` 和 `predict()` 之外使用第二种格式,比如在使用 Keras `Functional` API 创建自己的层或模型时,有三种可能性可以用于在第一个位置参数中收集所有输入张量:

    - 只有 `input_ids` 的单个张量:`model(input_ids)`
    - 可变长度列表,其中按文档字符串中给出的顺序包含一个或多个输入张量:

"""
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - 当使用模型对象 `model` 时,可以传入一个包含输入张量的字典,键名需与文档字符串中给出的输入名称对应:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!
    - 当使用子类化创建模型和层时,您无需担心这些细节,可以像传递任何其他 Python 函数的输入一样操作!

    Parameters:
        config ([`ElectraConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
        - config ([`ElectraConfig`]): 包含模型所有参数的配置类。
          使用配置文件初始化模型时,并不会加载与模型关联的权重,只加载配置信息。
          查看 [`~PreTrainedModel.from_pretrained`] 方法以加载模型的权重。
"""

ELECTRA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""


@add_start_docstrings(
    """
    The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to
    the BERT model except that it uses an additional linear layer between the embedding layer and encoder if the
    hidden size and embedding size are different.

    Both the generator and discriminator checkpoints may be loaded into this model.
    """,
    ELECTRA_START_DOCSTRING,
)
class TFElectraModel(TFElectraPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        # Call the parent initializer
        super().__init__(config, *inputs, **kwargs)

        # Instantiate the Electra main layer
        self.electra = TFElectraMainLayer(config, name="electra")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
        r"""
        encoder_hidden_states  (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
            contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`). Set to `False` during training, `True` during generation
        """
        # Run the Electra model's forward pass with the given inputs
        outputs = self.electra(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Return the Electra model's outputs
        return outputs

    def build(self, input_shape=None):
        # If the model has already been built, return immediately to avoid rebuilding
        if self.built:
            return
        # Mark the model as built
        self.built = True
        # If self.electra exists, build the Electra model under its name scope
        if getattr(self, "electra", None) is not None:
            with tf.name_scope(self.electra.name):
                # Build the Electra model, passing None as the input shape
                self.electra.build(None)
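
As a quick illustration of this class, here is a minimal usage sketch (the checkpoint is the `_CHECKPOINT_FOR_DOC` used throughout this file; the input sentence is arbitrary):

import tensorflow as tf
from transformers import ElectraTokenizer, TFElectraModel

tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
model = TFElectraModel.from_pretrained("google/electra-small-discriminator")

inputs = tokenizer("ELECTRA pretrains a discriminator.", return_tensors="tf")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, sequence_length, hidden_size)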
# Decorator attaching a docstring that describes this class's purpose
@add_start_docstrings(
    """
    Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.

    Even though both the discriminator and generator may be loaded into this model, the discriminator is the only model
    of the two to have the correct classification head to be used for this model.
    """,
    ELECTRA_START_DOCSTRING,
)
class TFElectraForPreTraining(TFElectraPreTrainedModel):
    
    # Initialization method, accepting the config and other keyword arguments
    def __init__(self, config, **kwargs):
        # Call the parent class's initialization method
        super().__init__(config, **kwargs)

        # Create the Electra main layer, named "electra"
        self.electra = TFElectraMainLayer(config, name="electra")

        # Create the Electra discriminator prediction head, named "discriminator_predictions"
        self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions")

    # Forward-pass method, accepting multiple input arguments
    @unpack_inputs
    # Decorator adding the forward-pass docstring, describing the expected input format
    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    # Decorator replacing the return docstring, declaring the output type as TFElectraForPreTrainingOutput
    @replace_return_docstrings(output_type=TFElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFElectraForPreTrainingOutput, Tuple[tf.Tensor]]:
        # Run the Electra forward pass with the given inputs to obtain the discriminator's hidden states
        discriminator_hidden_states = self.electra(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Take the sequence output, i.e. the first element of the discriminator's hidden states
        discriminator_sequence_output = discriminator_hidden_states[0]

        # Compute the discriminator logits from the sequence output
        logits = self.discriminator_predictions(discriminator_sequence_output)

        # If return_dict is False, return the logits together with the remaining hidden states
        if not return_dict:
            return (logits,) + discriminator_hidden_states[1:]

        # If return_dict is True, return a TFElectraForPreTrainingOutput holding the logits,
        # hidden states and attention weights
        return TFElectraForPreTrainingOutput(
            logits=logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )

    def build(self, input_shape=None):
        # If the model has already been built, return immediately to avoid rebuilding
        if self.built:
            return
        # Mark the model as built
        self.built = True
        # If self.electra exists, build it under its own name scope
        if getattr(self, "electra", None) is not None:
            with tf.name_scope(self.electra.name):
                self.electra.build(None)
        # If self.discriminator_predictions exists, build it under its own name scope
        if getattr(self, "discriminator_predictions", None) is not None:
            with tf.name_scope(self.discriminator_predictions.name):
                self.discriminator_predictions.build(None)
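
The logits above are per-token replaced-token-detection scores. A minimal sketch of how they might be used (the sample sentence is arbitrary; thresholding the logit at 0.0 corresponds to a 0.5 sigmoid probability):

import tensorflow as tf
from transformers import ElectraTokenizer, TFElectraForPreTraining

tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
model = TFElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="tf")
logits = model(**inputs).logits
# 1 where the discriminator believes the token was replaced by the generator, else 0
is_replaced = tf.cast(logits > 0, tf.int32)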
class TFElectraMaskedLMHead(keras.layers.Layer):
    # Layer implementing the masked language modeling head of the Electra model
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embedding_size = config.embedding_size
        self.input_embeddings = input_embeddings

    def build(self, input_shape):
        # Add the weights; the bias vector is initialized to zeros
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")

        super().build(input_shape)

    def get_output_embeddings(self):
        # Return the input embedding layer
        return self.input_embeddings

    def set_output_embeddings(self, value):
        # Set the input embedding layer's weights and update its vocabulary size
        self.input_embeddings.weight = value
        self.input_embeddings.vocab_size = shape_list(value)[0]

    def get_bias(self):
        # Return the bias vector as a dict
        return {"bias": self.bias}

    def set_bias(self, value):
        # Set the bias vector and update the vocabulary size in the config
        self.bias = value["bias"]
        self.config.vocab_size = shape_list(value["bias"])[0]

    def call(self, hidden_states):
        # Compute the masked language modeling logits by projecting onto the tied input embeddings
        seq_length = shape_list(tensor=hidden_states)[1]
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
        hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
        hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)

        return hidden_states
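
The call method above implements weight tying: instead of a separate output projection, the hidden states are multiplied with the transposed input embedding matrix. A small NumPy sketch of the same three reshape/matmul/bias steps, with toy sizes chosen only for illustration:

import numpy as np

batch, seq_len, emb_size, vocab = 2, 4, 8, 16  # toy sizes
hidden = np.random.randn(batch, seq_len, emb_size)
embedding_matrix = np.random.randn(vocab, emb_size)  # shared with the input embeddings
bias = np.zeros(vocab)

# Flatten, project onto the transposed embeddings, restore the shape, add the bias
flat = hidden.reshape(-1, emb_size)
scores = (flat @ embedding_matrix.T).reshape(batch, seq_len, vocab) + bias
assert scores.shape == (batch, seq_len, vocab)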


@add_start_docstrings(
    """
    Electra model with a language modeling head on top.

    Even though both the discriminator and generator may be loaded into this model, the generator is the only model of
    the two to have been trained for the masked language modeling task.
    """,
    ELECTRA_START_DOCSTRING,
)
class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLoss):
    # Electra model with a language modeling head on top
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.config = config
        # Electra main layer
        self.electra = TFElectraMainLayer(config, name="electra")
        # Electra generator prediction head
        self.generator_predictions = TFElectraGeneratorPredictions(config, name="generator_predictions")

        if isinstance(config.hidden_act, str):
            self.activation = get_tf_activation(config.hidden_act)
        else:
            self.activation = config.hidden_act

        # Electra masked language modeling head, tied to the input embeddings
        self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")

    def get_lm_head(self):
        # Return the masked language modeling head
        return self.generator_lm_head

    def get_prefix_bias_name(self):
        # Deprecated: use `get_bias` instead
        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.generator_lm_head.name

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="google/electra-small-generator",
        output_type=TFMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="[MASK]",
        expected_output="'paris'",
        expected_loss=1.22,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
        r"""
        Define the call function for the Electra generator model.
    
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        # Generate hidden states using the Electra model with provided inputs
        generator_hidden_states = self.electra(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Extract sequence output from generator hidden states
        generator_sequence_output = generator_hidden_states[0]
        # Generate prediction scores using the generator predictions function
        prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
        # Apply language modeling head to the generator prediction scores
        prediction_scores = self.generator_lm_head(prediction_scores, training=training)
        # Compute loss only if labels are provided using the provided loss computation function
        loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
    
        # Prepare output based on whether return_dict is False or True
        if not return_dict:
            output = (prediction_scores,) + generator_hidden_states[1:]
            return ((loss,) + output) if loss is not None else output
    
        # Return TFMaskedLMOutput with detailed components if return_dict is True
        return TFMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=generator_hidden_states.hidden_states,
            attentions=generator_hidden_states.attentions,
        )
    # Build method: sets up the model structure and parameters
    def build(self, input_shape=None):
        # If the model has already been built, return immediately to avoid rebuilding
        if self.built:
            return
        # Mark the model as built
        self.built = True

        # If the "electra" submodel exists, build it
        if getattr(self, "electra", None) is not None:
            # Use the Electra layer's name as the name scope
            with tf.name_scope(self.electra.name):
                # Build the Electra layer; an input shape of None means the default shape is used
                self.electra.build(None)

        # If the "generator_predictions" submodel exists, build it
        if getattr(self, "generator_predictions", None) is not None:
            # Use the generator prediction head's name as the name scope
            with tf.name_scope(self.generator_predictions.name):
                self.generator_predictions.build(None)

        # If the "generator_lm_head" submodel exists, build it
        if getattr(self, "generator_lm_head", None) is not None:
            # Use the generator LM head's name as the name scope
            with tf.name_scope(self.generator_lm_head.name):
                self.generator_lm_head.build(None)
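
A minimal fill-mask sketch using the generator checkpoint named in the code-sample decorator above (the sample sentence is arbitrary; per that decorator, the expected top prediction is "paris"):

import tensorflow as tf
from transformers import ElectraTokenizer, TFElectraForMaskedLM

tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-generator")
model = TFElectraForMaskedLM.from_pretrained("google/electra-small-generator")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="tf")
logits = model(**inputs).logits

# Locate the [MASK] position and take the highest-scoring vocabulary entry
mask_index = int(tf.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0][0])
predicted_id = int(tf.argmax(logits[0, mask_index]))
print(tokenizer.decode([predicted_id]))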
    """
    ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """

@add_start_docstrings(
    """
    ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceClassificationLoss):
    """
    ELECTRA模型的转换器,顶部带有序列分类/回归头(在汇聚输出顶部的线性层),例如用于GLUE任务。
    """

    def __init__(self, config, *inputs, **kwargs):
        """
        初始化方法。

        Args:
            config (ElectraConfig): 模型的配置对象,包含模型的超参数。
            *inputs: 可变长度的输入参数。
            **kwargs: 其他关键字参数。
        """
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels  # Number of labels used by the model
        self.electra = TFElectraMainLayer(config, name="electra")  # ELECTRA main layer
        self.classifier = TFElectraClassificationHead(config, name="classifier")  # Classification head

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="bhadresh-savani/electra-base-emotion",
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="'joy'",
        expected_loss=0.06,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,  # IDs of the input token sequence; may be None
        attention_mask: np.ndarray | tf.Tensor | None = None,  # Attention mask marking which positions should be attended to
        token_type_ids: np.ndarray | tf.Tensor | None = None,  # Token type IDs distinguishing the different sequences
        position_ids: np.ndarray | tf.Tensor | None = None,  # Position ID of each token in the input
        head_mask: np.ndarray | tf.Tensor | None = None,  # Head mask marking which self-attention heads to mask out
        inputs_embeds: np.ndarray | tf.Tensor | None = None,  # Optional embedded inputs, provided instead of input_ids
        output_attentions: Optional[bool] = None,  # Whether to return the attention weights
        output_hidden_states: Optional[bool] = None,  # Whether to return the hidden states
        return_dict: Optional[bool] = None,  # Whether to return a result dict
        labels: np.ndarray | tf.Tensor | None = None,  # Labels for the sequence classification/regression loss
        training: Optional[bool] = False,  # Whether the model runs in training mode
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # Run the Electra forward pass
        outputs = self.electra(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Feed the Electra output into the classifier
        logits = self.classifier(outputs[0])
        # Compute the loss if labels were provided
        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        # Decide the output format based on return_dict
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        else:
            # If return_dict is True, return a TFSequenceClassifierOutput object
            return TFSequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

    def build(self, input_shape=None):
        # If the model has already been built, return immediately
        if self.built:
            return
        self.built = True
        # If the Electra layer exists, build it
        if getattr(self, "electra", None) is not None:
            with tf.name_scope(self.electra.name):
                self.electra.build(None)
        # If the classifier head exists, build it
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build(None)
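
A minimal usage sketch with the emotion-classification checkpoint named in the code-sample decorator above (the input sentence is arbitrary):

import tensorflow as tf
from transformers import AutoTokenizer, TFElectraForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/electra-base-emotion")
model = TFElectraForSequenceClassification.from_pretrained("bhadresh-savani/electra-base-emotion")

inputs = tokenizer("What a wonderful day!", return_tensors="tf")
logits = model(**inputs).logits
predicted_class = int(tf.argmax(logits, axis=-1)[0])
print(model.config.id2label[predicted_class])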
"""
ELECTRA 模型,顶部带有多选分类头部(在池化输出的基础上是一个线性层和一个 softmax),例如用于 RocStories/SWAG 任务。

继承自 TFElectraPreTrainedModel 和 TFMultipleChoiceLoss。
"""
@add_start_docstrings(
    """
    ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss):
    def __init__(self, config, *inputs, **kwargs):
        """
        初始化方法,设置模型的各个组件。

        Parameters:
        - config: ELECTRA 模型的配置对象。
        - *inputs: 可变长度的输入。
        - **kwargs: 其他关键字参数。
        """
        super().__init__(config, *inputs, **kwargs)

        # ELECTRA 主体层
        self.electra = TFElectraMainLayer(config, name="electra")
        # 序列汇总层
        self.sequence_summary = TFSequenceSummary(
            config, initializer_range=config.initializer_range, name="sequence_summary"
        )
        # 分类器层
        self.classifier = keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        # 保存配置对象
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
        """
        调用方法,执行 ELECTRA 模型的前向传播。

        Parameters:
        - input_ids: 输入的 token IDs。
        - attention_mask: 注意力掩码。
        - token_type_ids: token 类型 IDs。
        - position_ids: 位置 IDs。
        - head_mask: 头部掩码。
        - inputs_embeds: 输入的嵌入。
        - output_attentions: 是否输出注意力。
        - output_hidden_states: 是否输出隐藏状态。
        - return_dict: 是否返回字典形式结果。
        - labels: 标签数据。
        - training: 是否处于训练模式。

        Returns:
        ELECTRA 模型的输出对象。
        """
        ...
    ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
            where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
        """

        # If input_ids is given, read num_choices and seq_length from its second and third dimensions
        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            # Otherwise read them from inputs_embeds
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]

        # Flatten each input tensor to two dimensions (merging the batch and choice axes) when it is not None
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
        flat_inputs_embeds = (
            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )
        
        # Run the Electra forward pass with the flattened tensors and remaining arguments
        outputs = self.electra(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_mask,
            token_type_ids=flat_token_type_ids,
            position_ids=flat_position_ids,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Summarize the sequence output of the Electra model
        logits = self.sequence_summary(outputs[0])
        # Feed the summarized representation into the classifier
        logits = self.classifier(logits)
        # Reshape the logits to shape (-1, num_choices)
        reshaped_logits = tf.reshape(logits, (-1, num_choices))
        # Compute the loss if labels were provided
        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)

        # If return_dict=False, return a plain tuple
        if not return_dict:
            output = (reshaped_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # If return_dict=True, return a TFMultipleChoiceModelOutput object
        return TFMultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        # If already built, return immediately
        if self.built:
            return
        # Mark as built
        self.built = True

        # If self.electra exists, build the Electra layer
        if getattr(self, "electra", None) is not None:
            with tf.name_scope(self.electra.name):
                self.electra.build(None)

        # If self.sequence_summary exists, build the sequence summary layer
        if getattr(self, "sequence_summary", None) is not None:
            with tf.name_scope(self.sequence_summary.name):
                self.sequence_summary.build(None)

        # If self.classifier exists, build the classifier layer
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])
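
A minimal sketch of the expected (batch_size, num_choices, sequence_length) input layout (prompt and choices are made up; with the bare pretrained checkpoint the choice head is freshly initialized, so the output here is only illustrative):

import tensorflow as tf
from transformers import ElectraTokenizer, TFElectraForMultipleChoice

tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
model = TFElectraForMultipleChoice.from_pretrained("google/electra-small-discriminator")

prompt = "The cat sat on the"
choices = ["mat.", "stock market."]
encoding = tokenizer([prompt, prompt], choices, return_tensors="tf", padding=True)
# Add the batch dimension so each tensor has shape (1, num_choices, seq_length)
inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()}
logits = model(**inputs).logits  # shape (1, num_choices)
print(int(tf.argmax(logits, axis=-1)[0]))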
@add_start_docstrings(
    """
    Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.
    """,
    ELECTRA_START_DOCSTRING,
)
class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        # Initialize the Electra main layer, named "electra"
        self.electra = TFElectraMainLayer(config, name="electra")

        # Set up the classifier dropout layer from the dropout probability in the config
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = keras.layers.Dropout(classifier_dropout)

        # Dense layer used as the classifier; its output dimension equals the number of labels
        self.classifier = keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

        # Keep the configuration object
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="bhadresh-savani/electra-base-discriminator-finetuned-conll03-english",
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="['B-LOC', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'I-LOC']",
        expected_loss=0.11,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        # Run the ELECTRA model to obtain the discriminator's hidden states
        discriminator_hidden_states = self.electra(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Take the sequence output from the discriminator's hidden states
        discriminator_sequence_output = discriminator_hidden_states[0]
        # Apply dropout to the sequence output
        discriminator_sequence_output = self.dropout(discriminator_sequence_output)
        # Feed the dropped-out output into the classifier to obtain the logits
        logits = self.classifier(discriminator_sequence_output)
        # Compute the loss if labels were provided
        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        # Decide the output format based on return_dict
        if not return_dict:
            # If a dict is not requested, return the logits together with the remaining hidden states
            output = (logits,) + discriminator_hidden_states[1:]
            return ((loss,) + output) if loss is not None else output

        # Otherwise return a TFTokenClassifierOutput object
        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )

    def build(self, input_shape=None):
        # If the model has already been built, return immediately
        if self.built:
            return
        # Mark the model as built
        self.built = True
        # If the ELECTRA layer exists, build it
        if getattr(self, "electra", None) is not None:
            with tf.name_scope(self.electra.name):
                self.electra.build(None)
        # If the classifier exists, build it
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])
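
A minimal NER sketch with the CoNLL-03 checkpoint named in the code-sample decorator above (the input sentence is arbitrary):

import tensorflow as tf
from transformers import AutoTokenizer, TFElectraForTokenClassification

name = "bhadresh-savani/electra-base-discriminator-finetuned-conll03-english"
tokenizer = AutoTokenizer.from_pretrained(name)
model = TFElectraForTokenClassification.from_pretrained(name)

inputs = tokenizer("HuggingFace is based in Paris", return_tensors="tf")
logits = model(**inputs).logits  # shape (1, seq_length, num_labels)
predictions = tf.argmax(logits, axis=-1)[0]
print([model.config.id2label[int(i)] for i in predictions])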
# Decorator attaching the docstring: an Electra model for extractive question answering (e.g. SQuAD), with a linear
# layer on top of the hidden states that computes the span start and span end logits.
@add_start_docstrings(
    """
    Electra Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ELECTRA_START_DOCSTRING,
)
# TFElectraForQuestionAnswering inherits from TFElectraPreTrainedModel and TFQuestionAnsweringLoss
class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):

    # Initialization method, accepting the config and other inputs
    def __init__(self, config, *inputs, **kwargs):
        # Call the parent class's initialization method
        super().__init__(config, *inputs, **kwargs)

        # Number of labels used by the model
        self.num_labels = config.num_labels
        # Create the Electra main layer, named "electra"
        self.electra = TFElectraMainLayer(config, name="electra")
        # Output layer: a Dense layer of size config.num_labels with the configured weight initializer, named "qa_outputs"
        self.qa_outputs = keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )
        # Keep the configuration object
        self.config = config

    # Decorators wrapping the call method to add the forward-pass docstrings
    @unpack_inputs
    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="bhadresh-savani/electra-base-squad2",
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        qa_target_start_index=11,
        qa_target_end_index=12,
        expected_output="'a nice puppet'",
        expected_loss=2.64,
    )
    # Forward-pass method, accepting the model inputs and several control flags
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        start_positions: np.ndarray | tf.Tensor | None = None,
        end_positions: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
        r"""
        start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        discriminator_hidden_states = self.electra(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Take the sequence output from the discriminator's hidden states
        discriminator_sequence_output = discriminator_hidden_states[0]
        # Compute the question-answering logits from the sequence output
        logits = self.qa_outputs(discriminator_sequence_output)
        # Split the logits along the last dimension into start and end logits
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        # Squeeze the last dimension of the start and end logits
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)
        # Initialize the loss
        loss = None

        # If start and end positions were provided, compute the loss
        if start_positions is not None and end_positions is not None:
            # Assemble the labels used for the loss computation
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            # Compute the loss from the labels and the predicted logits
            loss = self.hf_compute_loss(labels, (start_logits, end_logits))

        # If a dict is not requested, assemble the tuple output
        if not return_dict:
            output = (
                start_logits,
                end_logits,
            ) + discriminator_hidden_states[1:]
            # Prepend the loss to the output when it is not None
            return ((loss,) + output) if loss is not None else output

        # Otherwise return a TFQuestionAnsweringModelOutput with the loss and the other outputs
        return TFQuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )

    def build(self, input_shape=None):
        # If the model has already been built, return immediately
        if self.built:
            return
        # Mark the model as built
        self.built = True
        # If the electra layer exists, build it
        if getattr(self, "electra", None) is not None:
            with tf.name_scope(self.electra.name):
                self.electra.build(None)
        # If the qa_outputs layer exists, build it
        if getattr(self, "qa_outputs", None) is not None:
            with tf.name_scope(self.qa_outputs.name):
                self.qa_outputs.build([None, None, self.config.hidden_size])
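
A minimal extractive-QA sketch with the SQuAD2 checkpoint named in the code-sample decorator above (question and context are made up):

import tensorflow as tf
from transformers import AutoTokenizer, TFElectraForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/electra-base-squad2")
model = TFElectraForQuestionAnswering.from_pretrained("bhadresh-savani/electra-base-squad2")

question = "Who wrote Hamlet?"
context = "Hamlet is a tragedy written by William Shakespeare."
inputs = tokenizer(question, context, return_tensors="tf")
outputs = model(**inputs)

# The answer span runs from the argmax of the start logits to the argmax of the end logits
start = int(tf.argmax(outputs.start_logits, axis=-1)[0])
end = int(tf.argmax(outputs.end_logits, axis=-1)[0])
print(tokenizer.decode(inputs["input_ids"][0][start : end + 1]))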

.\models\electra\tokenization_electra.py

# Declares the file encoding as UTF-8
# Copyright notice and license information
# This code is released under the Apache License, Version 2.0; see the referenced URL for the full license text

# Import the required standard library modules
import collections  # extra container datatypes beyond the built-ins
import os  # interaction with the operating system
import unicodedata  # access to the Unicode character database

# Import List, Optional, Tuple from typing for type hints
from typing import List, Optional, Tuple

# Import the PreTrainedTokenizer class and helper predicates from tokenization_utils
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
# Import the logging helper from utils
from ...utils import logging

# Get the logger for the current module
logger = logging.get_logger(__name__)

# Dictionary mapping each vocabulary file key to its default filename
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# Nested dictionary mapping each pretrained model to its vocabulary file URL
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/electra-small-generator": (
            "https://huggingface.co/google/electra-small-generator/resolve/main/vocab.txt"
        ),
        "google/electra-base-generator": "https://huggingface.co/google/electra-base-generator/resolve/main/vocab.txt",
        "google/electra-large-generator": (
            "https://huggingface.co/google/electra-large-generator/resolve/main/vocab.txt"
        ),
        "google/electra-small-discriminator": (
            "https://huggingface.co/google/electra-small-discriminator/resolve/main/vocab.txt"
        ),
        "google/electra-base-discriminator": (
            "https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt"
        ),
        "google/electra-large-discriminator": (
            "https://huggingface.co/google/electra-large-discriminator/resolve/main/vocab.txt"
        ),
    }
}

# Dictionary mapping each pretrained model to its positional embedding size
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/electra-small-generator": 512,
    "google/electra-base-generator": 512,
    "google/electra-large-generator": 512,
    "google/electra-small-discriminator": 512,
    "google/electra-base-discriminator": 512,
    "google/electra-large-discriminator": 512,
}

# Dictionary of default initialization configurations for each pretrained model
PRETRAINED_INIT_CONFIGURATION = {
    "google/electra-small-generator": {"do_lower_case": True},
    "google/electra-base-generator": {"do_lower_case": True},
    "google/electra-large-generator": {"do_lower_case": True},
    "google/electra-small-discriminator": {"do_lower_case": True},
    "google/electra-base-discriminator": {"do_lower_case": True},
    "google/electra-large-discriminator": {"do_lower_case": True},
}

# Copied from transformers.models.bert.tokenization_bert.load_vocab; loads a vocabulary file
def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    # Create an ordered dict to hold the vocabulary
    vocab = collections.OrderedDict()
    # Open the vocabulary file with UTF-8 encoding
    with open(vocab_file, "r", encoding="utf-8") as reader:
        # Read the vocabulary file line by line
        tokens = reader.readlines()
    # Iterate over the tokens, tracking both the index and the token itself
    for index, token in enumerate(tokens):
        # Strip the trailing newline "\n" from the token
        token = token.rstrip("\n")
        # Add the token to the vocab dict, mapping token -> index
        vocab[token] = index

    # Return the populated vocab dict
    return vocab
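
A small self-contained check of load_vocab's behavior (the three-token vocabulary is made up):

import collections
import os
import tempfile

# Write a tiny vocabulary, one token per line
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
    f.write("[PAD]\n[UNK]\nhello\n")
    path = f.name

vocab = load_vocab(path)
assert vocab == collections.OrderedDict([("[PAD]", 0), ("[UNK]", 1), ("hello", 2)])
os.remove(path)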
# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize; performs basic whitespace cleanup and splitting.
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    # Strip whitespace from both ends of the text
    text = text.strip()
    # If the cleaned text is empty, return an empty list
    if not text:
        return []
    # Split the text on whitespace to obtain the token list
    tokens = text.split()
    # Return the resulting token list
    return tokens


# Copied from transformers.models.bert.tokenization_bert.BertTokenizer and adapted for Electra; builds the Electra tokenizer.
class ElectraTokenizer(PreTrainedTokenizer):
    r"""
    Construct a Electra tokenizer. Based on WordPiece.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.
    """
    # Class attributes describing the pretrained vocabulary files and related configuration
    
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    
    # Initialization method creating a new tokenizer instance
    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs,
    ):
        """
        Args:
            vocab_file (`str`):
                包含词汇表的文件。
            do_lower_case (`bool`, *optional*, defaults to `True`):
                是否在进行分词时将输入转换为小写。
            do_basic_tokenize (`bool`, *optional*, defaults to `True`):
                是否在WordPiece分词前进行基本分词。
            never_split (`Iterable`, *optional*):
                在分词过程中不应拆分的标记集合。仅在 `do_basic_tokenize=True` 时有效。
            unk_token (`str`, *optional*, defaults to `"[UNK]"`):
                未知标记。当输入中的标记不在词汇表中时,将其替换为此标记。
            sep_token (`str`, *optional*, defaults to `"[SEP]"`):
                分隔符标记,在构建多个序列的序列时使用,例如序列分类或问答问题时使用。也用作构建带有特殊标记的序列的最后一个标记。
            pad_token (`str`, *optional*, defaults to `"[PAD]"`):
                用于填充的标记,例如在批处理不同长度的序列时使用。
            cls_token (`str`, *optional*, defaults to `"[CLS]"`):
                分类器标记,在进行序列分类时使用(整个序列的分类而不是每个标记的分类)。它是构建带有特殊标记的序列的第一个标记。
            mask_token (`str`, *optional*, defaults to `"[MASK]"`):
                用于屏蔽值的标记。这是在进行遮蔽语言建模训练时使用的标记。模型将尝试预测此标记。
            tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
                是否对中文字符进行分词。
    
                对于日语可能需要禁用此选项(参见此问题: https://github.com/huggingface/transformers/issues/328)。
            strip_accents (`bool`, *optional*):
                是否删除所有重音符号。如果未指定此选项,则将根据 `lowercase` 的值来确定(与原始Electra一样)。
        """
        # If the given vocabulary file does not exist, raise a ValueError pointing at the missing path
        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = ElectraTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        # Load the vocabulary file into self.vocab
        self.vocab = load_vocab(vocab_file)
        # Build an ordered id -> token mapping with collections.OrderedDict
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        # Flag controlling whether basic tokenization is performed
        self.do_basic_tokenize = do_basic_tokenize
        # If basic tokenization is requested
        if do_basic_tokenize:
            # Instantiate a BasicTokenizer with the given options
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )

        # Instantiate a WordpieceTokenizer with the vocabulary and the unknown token
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))

        # Call the parent class's initialization method with the collected options
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

    @property
    def do_lower_case(self):
        # Return the basic tokenizer's lowercase flag
        return self.basic_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        # Return the size of the vocabulary
        return len(self.vocab)

    def get_vocab(self):
        # Return the vocabulary merged with the added-tokens encoder
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text, split_special_tokens=False):
        # List collecting the split tokens
        split_tokens = []
        # If basic tokenization is enabled
        if self.do_basic_tokenize:
            # Tokenize with the BasicTokenizer first
            for token in self.basic_tokenizer.tokenize(
                text, never_split=self.all_special_tokens if not split_special_tokens else None
            ):
                # If the token is in the never-split set, keep it whole
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    # Otherwise split the token further with the WordpieceTokenizer
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            # Tokenize the whole text with the WordpieceTokenizer
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # Look the token up in the vocab, falling back to the unknown token's id
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        # Look the id up, falling back to the unknown token
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # Join the tokens into a single string, removing the "##" continuation markers
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A Electra sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        # If token_ids_1 is not provided, return the single-sequence format
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        
        # For pair of sequences, concatenate tokens with special tokens separating them
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep
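
A short sketch of the two formats described in the docstring (the tokens "hello"/"world"/"good"/"bye" are arbitrary choices, assumed to be in the vocabulary):

from transformers import ElectraTokenizer

tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
ids_a = tokenizer.convert_tokens_to_ids(["hello", "world"])
ids_b = tokenizer.convert_tokens_to_ids(["good", "bye"])

single = tokenizer.build_inputs_with_special_tokens(ids_a)
pair = tokenizer.build_inputs_with_special_tokens(ids_a, ids_b)
print(tokenizer.convert_ids_to_tokens(single))  # ['[CLS]', 'hello', 'world', '[SEP]']
print(tokenizer.convert_ids_to_tokens(pair))    # ['[CLS]', 'hello', 'world', '[SEP]', 'good', 'bye', '[SEP]']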

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        # If the tokens already have special tokens, delegate to the superclass method
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # Calculate the mask for tokens with special tokens added
        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create token type IDs tensor from token id tensors. `0` for the first sentence tokens, `1` for the second sentence
        tokens.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of token type IDs according to the sequences provided.
        """
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A Electra sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        # Define the separator and classification tokens
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        
        # If only one sequence is provided (token_ids_1 is None), return a mask with all zeros
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        
        # If both sequences are provided, return a mask with zeros for the first sequence and ones for the second
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
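
A short sketch of the resulting masks (the integer IDs are hypothetical placeholders; only the list lengths matter here):

from transformers import ElectraTokenizer

tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
ids_a = [7, 8]  # hypothetical token IDs for the first sequence
ids_b = [9]     # hypothetical token IDs for the second sequence
print(tokenizer.create_token_type_ids_from_sequences(ids_a))         # [0, 0, 0, 0]
print(tokenizer.create_token_type_ids_from_sequences(ids_a, ids_b))  # [0, 0, 0, 0, 1, 1]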

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        index = 0
        
        # Determine the vocabulary file path based on the provided save_directory
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        
        # Write the vocabulary to the specified file
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                # Check if vocabulary indices are consecutive and warn if not
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!"
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        
        # Return the path to the saved vocabulary file
        return (vocab_file,)
# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
# Basic tokenizer class performing basic tokenization (punctuation splitting, lower casing, etc.)
class BasicTokenizer(object):
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
            是否在分词时将输入转换为小写。

        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
            在分词过程中永远不会被分开的标记集合。仅在 `do_basic_tokenize=True` 时有效。

        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.
            是否对中文字符进行分词。建议对日文关闭此选项(参见这个问题链接)。

        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
            是否去除所有重音符号。如果未指定此选项,则由 `lowercase` 的值来确定(与原始的BERT一致)。

        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
            the full context of the words, such as contractions.
            在某些情况下,我们希望跳过基本的标点符号分割,以便后续的分词可以捕获单词的完整上下文,比如缩写词。
    """

    def __init__(
        self,
        do_lower_case=True,
        never_split=None,
        tokenize_chinese_chars=True,
        strip_accents=None,
        do_split_on_punc=True,
    ):
        # If never_split is None, initialize it as an empty list
        if never_split is None:
            never_split = []
        # Whether to lowercase the input
        self.do_lower_case = do_lower_case
        # Store never_split as a set of tokens that are never split
        self.never_split = set(never_split)
        # Whether to tokenize Chinese characters
        self.tokenize_chinese_chars = tokenize_chinese_chars
        # Whether to strip accents; if unspecified, this is decided by the lowercase setting
        self.strip_accents = strip_accents
        # Whether to perform basic punctuation splitting
        self.do_split_on_punc = do_split_on_punc
    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*)
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of token not to split.
        """
        # Merge self.never_split with the given never_split via union()
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        # Clean the text, removing invalid characters and normalizing whitespace
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese models. It is now also applied to
        # the English models, but that doesn't matter since the English models were not trained on any Chinese data
        # and generally don't have Chinese data in them (the English Wikipedia does contain some Chinese words,
        # which is why the vocabulary has some Chinese characters).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        # Normalize the Unicode characters in the text using the NFC form
        unicode_normalized_text = unicodedata.normalize("NFC", text)
        # Split the text into the original token list with whitespace_tokenize()
        orig_tokens = whitespace_tokenize(unicode_normalized_text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    # If do_lower_case is True, lowercase the token
                    token = token.lower()
                    # Unless strip_accents is explicitly False, strip accents with _run_strip_accents()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    # If strip_accents is True, strip accents with _run_strip_accents()
                    token = self._run_strip_accents(token)
            # Split the token on punctuation and extend split_tokens with the pieces
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        # Re-join split_tokens with spaces and tokenize on whitespace once more
        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        # Return the final token list
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        # Normalize the Unicode characters in the text using the NFD form
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            # Get the character's Unicode category
            cat = unicodedata.category(char)
            # Skip characters in category Mn (Nonspacing_Mark), i.e. combining accents
            if cat == "Mn":
                continue
            # Otherwise keep the character
            output.append(char)
        # Join the remaining characters back into a string and return it
        return "".join(output)
    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        # If punctuation splitting is disabled, or the text is in never_split, return it whole
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        # Convert the text into a list of characters
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        # Walk over the characters
        while i < len(chars):
            char = chars[i]
            # A punctuation character becomes its own item in the output, and the next character starts a new word
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                # For non-punctuation, start a new word if flagged to do so
                if start_new_word:
                    output.append([])
                start_new_word = False
                # Append the current character to the last word
                output[-1].append(char)
            i += 1

        # Join each sub-list into a string and return the list of pieces
        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        # Walk over each character in the text
        for char in text:
            cp = ord(char)
            # If the character is a CJK character, surround it with spaces
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                # Otherwise keep the character as-is
                output.append(char)
        # Join the list back into a string and return it
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # Check whether the codepoint falls inside one of the CJK Unicode blocks
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)
        ):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        # 遍历文本中的每个字符
        for char in text:
            cp = ord(char)
            # 跳过无效字符(NUL、替换字符 0xFFFD)以及控制字符
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            # 空白字符统一替换为普通空格,其余字符原样保留
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        # 将列表转换为字符串并返回
        return "".join(output)
# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        # 初始化 WordpieceTokenizer 类,设置词汇表、未知 token 和最大输入字符数
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens.
        """
        # 初始化输出 token 列表
        output_tokens = []
        # 对文本进行分词,使用空白字符进行分隔
        for token in whitespace_tokenize(text):
            chars = list(token)
            # 如果 token 长度超过最大输入字符数,则添加未知 token
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            # 使用贪婪最长匹配算法进行 tokenization
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    # 对于非首字符的 substr,添加 '##' 前缀
                    if start > 0:
                        substr = "##" + substr
                    # 如果 substr 在词汇表中,则作为当前 token
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                # 如果未找到合适的 token,则标记为 bad token
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            # 如果存在 bad token,则添加未知 token,否则添加所有子 token
            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
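
一个演示贪婪最长匹配优先算法的最小示例(词表为虚构的玩具词表,仅用于说明):

```python
# 玩具词表:tokenize 只用到成员判断,这里直接用 set 即可
vocab = {"un", "##aff", "##able", "[UNK]"}
wp = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

print(wp.tokenize("unaffable"))  # ['un', '##aff', '##able']
print(wp.tokenize("xyz"))        # ['[UNK]'],没有任何子串能在词表中匹配
```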

.\models\electra\tokenization_electra_fast.py

# 导入必要的模块
import json  # 导入用于处理 JSON 数据的模块
from typing import List, Optional, Tuple  # 导入类型提示模块

from tokenizers import normalizers  # 从 tokenizers 模块导入 normalizers 功能

from ...tokenization_utils_fast import PreTrainedTokenizerFast  # 导入快速分词器基类
from .tokenization_electra import ElectraTokenizer  # 从当前目录下的 tokenization_electra 模块导入 ElectraTokenizer 类

# 定义文件名与文件路径映射关系的常量字典
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}

# 定义预训练模型与其词汇文件和分词器文件映射关系的常量字典
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/electra-small-generator": (
            "https://huggingface.co/google/electra-small-generator/resolve/main/vocab.txt"
        ),
        "google/electra-base-generator": "https://huggingface.co/google/electra-base-generator/resolve/main/vocab.txt",
        "google/electra-large-generator": (
            "https://huggingface.co/google/electra-large-generator/resolve/main/vocab.txt"
        ),
        "google/electra-small-discriminator": (
            "https://huggingface.co/google/electra-small-discriminator/resolve/main/vocab.txt"
        ),
        "google/electra-base-discriminator": (
            "https://huggingface.co/google/electra-base-discriminator/resolve/main/vocab.txt"
        ),
        "google/electra-large-discriminator": (
            "https://huggingface.co/google/electra-large-discriminator/resolve/main/vocab.txt"
        ),
    },
    "tokenizer_file": {
        "google/electra-small-generator": (
            "https://huggingface.co/google/electra-small-generator/resolve/main/tokenizer.json"
        ),
        "google/electra-base-generator": (
            "https://huggingface.co/google/electra-base-generator/resolve/main/tokenizer.json"
        ),
        "google/electra-large-generator": (
            "https://huggingface.co/google/electra-large-generator/resolve/main/tokenizer.json"
        ),
        "google/electra-small-discriminator": (
            "https://huggingface.co/google/electra-small-discriminator/resolve/main/tokenizer.json"
        ),
        "google/electra-base-discriminator": (
            "https://huggingface.co/google/electra-base-discriminator/resolve/main/tokenizer.json"
        ),
        "google/electra-large-discriminator": (
            "https://huggingface.co/google/electra-large-discriminator/resolve/main/tokenizer.json"
        ),
    },
}

# 定义预训练模型与其位置嵌入大小的映射关系的常量字典
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/electra-small-generator": 512,
    "google/electra-base-generator": 512,
    # 键为模型名称,值为该模型的最大输入序列长度(即位置嵌入大小,均为 512)
    "google/electra-large-generator": 512,
    "google/electra-small-discriminator": 512,
    "google/electra-base-discriminator": 512,
    "google/electra-large-discriminator": 512,
}

# 预定义的预训练配置字典,包含了Electra模型的不同预训练变体及其配置信息
PRETRAINED_INIT_CONFIGURATION = {
    "google/electra-small-generator": {"do_lower_case": True},
    "google/electra-base-generator": {"do_lower_case": True},
    "google/electra-large-generator": {"do_lower_case": True},
    "google/electra-small-discriminator": {"do_lower_case": True},
    "google/electra-base-discriminator": {"do_lower_case": True},
    "google/electra-large-discriminator": {"do_lower_case": True},
}

# 从transformers.models.bert.tokenization_bert_fast.BertTokenizerFast复制而来,修改为支持Electra模型的快速分词器
class ElectraTokenizerFast(PreTrainedTokenizerFast):
    r"""
    构建一个“快速”的ELECTRA分词器(基于HuggingFace的*tokenizers*库),基于WordPiece。

    此分词器继承自[`PreTrainedTokenizerFast`],其中包含大多数主要方法。用户应参考该超类获取更多关于这些方法的信息。
    """

    # 默认的词汇文件名列表
    vocab_files_names = VOCAB_FILES_NAMES
    # 预训练模型的词汇文件映射
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # 预训练模型的初始化配置
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    # 预训练位置嵌入的最大模型输入大小
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # ElectraTokenizerFast 对应的慢速分词器类
    slow_tokenizer_class = ElectraTokenizer

    # 初始化方法,用于创建一个 ElectraTokenizerFast 对象
    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        do_lower_case=True,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs,
    ):
        # 调用父类的初始化方法,传入必要的参数和关键字参数来初始化对象
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        # 从后端分词器中取出标准化器(normalizer)的状态,并反序列化为 Python 对象
        normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
        # 检查标准化器的状态是否与当前传入的参数一致,不一致时需要重建标准化器
        if (
            normalizer_state.get("lowercase", do_lower_case) != do_lower_case
            or normalizer_state.get("strip_accents", strip_accents) != strip_accents
            or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
        ):
            # 根据标准化器的类型,用当前参数重建标准化器对象,确保二者一致
            normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
            normalizer_state["lowercase"] = do_lower_case
            normalizer_state["strip_accents"] = strip_accents
            normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
            self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)

        # 将当前对象的 do_lower_case 属性更新为传入的值
        self.do_lower_case = do_lower_case

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. An ELECTRA sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        # 单一序列:在开头加 [CLS],末尾加 [SEP]
        output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]

        # 如果提供了第二个序列 token_ids_1,则将其连接到 output 中并在末尾追加 [SEP] 分隔标记
        if token_ids_1 is not None:
            output += token_ids_1 + [self.sep_token_id]

        return output

    # 根据给定的序列创建 token type IDs,用于区分序列对中两个序列的类型
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. An ELECTRA
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs for the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of token type IDs according to the given sequence(s).
        """
        # Define the separation and classification tokens
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        # If token_ids_1 is not provided, return a mask with zeros for the first sequence only
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]

        # Return a mask with zeros for the first sequence and ones for the second sequence
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the tokenizer's vocabulary files to the specified directory.

        Args:
            save_directory (str):
                Directory where the vocabulary files will be saved.
            filename_prefix (str, *optional*):
                Optional prefix for the saved files.

        Returns:
            `Tuple[str]`: Tuple containing the filenames of the saved vocabulary files.
        """
        # Save the model's vocabulary files using the tokenizer's internal method
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)
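
结合上面两个方法,下面的示例演示特殊标记与 token type IDs 的构建方式(需要联网下载词表,注释中的结论为按上述逻辑推断):

```python
from transformers import ElectraTokenizerFast

tok = ElectraTokenizerFast.from_pretrained("google/electra-small-discriminator")
ids_a = tok.convert_tokens_to_ids(tok.tokenize("how are you"))
ids_b = tok.convert_tokens_to_ids(tok.tokenize("fine"))

input_ids = tok.build_inputs_with_special_tokens(ids_a, ids_b)  # [CLS] A [SEP] B [SEP]
token_type_ids = tok.create_token_type_ids_from_sequences(ids_a, ids_b)
print(token_type_ids)  # 前半段(含 [CLS] A [SEP])全为 0,后半段(B [SEP])全为 1
```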

.\models\electra\__init__.py

# 引入类型检查模块,用于静态类型检查
from typing import TYPE_CHECKING

# 从工具模块中引入所需函数和异常类
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
)

# 定义导入结构,包含不同模块及其对应的导入内容列表
_import_structure = {
    "configuration_electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig", "ElectraOnnxConfig"],
    "tokenization_electra": ["ElectraTokenizer"],
}

# 检查是否存在 tokenizers 库,若不存在则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果存在 tokenizers 库,则将 ElectraTokenizerFast 添加到导入结构中
    _import_structure["tokenization_electra_fast"] = ["ElectraTokenizerFast"]

# 检查是否存在 torch 库,若不存在则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果存在 torch 库,则将相关的 Electra 模型导入添加到导入结构中
    _import_structure["modeling_electra"] = [
        "ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
        "ElectraForCausalLM",
        "ElectraForMaskedLM",
        "ElectraForMultipleChoice",
        "ElectraForPreTraining",
        "ElectraForQuestionAnswering",
        "ElectraForSequenceClassification",
        "ElectraForTokenClassification",
        "ElectraModel",
        "ElectraPreTrainedModel",
        "load_tf_weights_in_electra",
    ]

# 检查是否存在 tensorflow 库,若不存在则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果存在 tensorflow 库,则将相关的 TFElectra 模型导入添加到导入结构中
    _import_structure["modeling_tf_electra"] = [
        "TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFElectraForMaskedLM",
        "TFElectraForMultipleChoice",
        "TFElectraForPreTraining",
        "TFElectraForQuestionAnswering",
        "TFElectraForSequenceClassification",
        "TFElectraForTokenClassification",
        "TFElectraModel",
        "TFElectraPreTrainedModel",
    ]

# 检查是否存在 flax 库,若不存在则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果存在 flax 库,则将相关的 FlaxElectra 模型导入添加到导入结构中
    _import_structure["modeling_flax_electra"] = [
        "FlaxElectraForCausalLM",
        "FlaxElectraForMaskedLM",
        "FlaxElectraForMultipleChoice",
        "FlaxElectraForPreTraining",
        "FlaxElectraForQuestionAnswering",
        "FlaxElectraForSequenceClassification",
        "FlaxElectraForTokenClassification",
        "FlaxElectraModel",
        "FlaxElectraPreTrainedModel",
    ]

# 如果在类型检查环境下
if TYPE_CHECKING:
    # 从当前目录中导入以下模块和变量,分别是预训练配置映射、ElectraConfig 类和 ElectraOnnxConfig 类
    from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraOnnxConfig
    # 导入 ElectraTokenizer 类,用于处理 Electra 模型的分词器
    from .tokenization_electra import ElectraTokenizer

    # 检查是否安装了 tokenizers 库,若未安装则抛出 OptionalDependencyNotAvailable 异常
    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 若 tokenizers 可用,则从当前目录中导入 ElectraTokenizerFast 类,用于更快速的分词操作
        from .tokenization_electra_fast import ElectraTokenizerFast
    
    # 检查是否安装了 PyTorch 库,若未安装则抛出 OptionalDependencyNotAvailable 异常
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 若 PyTorch 可用,则从当前目录中导入以下 Electra 相关类和函数
        from .modeling_electra import (
            ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
            ElectraForCausalLM,
            ElectraForMaskedLM,
            ElectraForMultipleChoice,
            ElectraForPreTraining,
            ElectraForQuestionAnswering,
            ElectraForSequenceClassification,
            ElectraForTokenClassification,
            ElectraModel,
            ElectraPreTrainedModel,
            load_tf_weights_in_electra,
        )
    
    # 检查是否安装了 TensorFlow 库,若未安装则抛出 OptionalDependencyNotAvailable 异常
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 若 TensorFlow 可用,则从当前目录中导入以下 TF-Electra 相关类和函数
        from .modeling_tf_electra import (
            TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFElectraForMaskedLM,
            TFElectraForMultipleChoice,
            TFElectraForPreTraining,
            TFElectraForQuestionAnswering,
            TFElectraForSequenceClassification,
            TFElectraForTokenClassification,
            TFElectraModel,
            TFElectraPreTrainedModel,
        )
    
    # 检查是否安装了 Flax 库,若未安装则抛出 OptionalDependencyNotAvailable 异常
    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 若 Flax 可用,则从当前目录中导入以下 Flax-Electra 相关类和函数
        from .modeling_flax_electra import (
            FlaxElectraForCausalLM,
            FlaxElectraForMaskedLM,
            FlaxElectraForMultipleChoice,
            FlaxElectraForPreTraining,
            FlaxElectraForQuestionAnswering,
            FlaxElectraForSequenceClassification,
            FlaxElectraForTokenClassification,
            FlaxElectraModel,
            FlaxElectraPreTrainedModel,
        )
else:
    # 导入 sys 模块,用于动态设置当前模块为懒加载模块
    import sys

    # 使用 sys.modules 来将当前模块设置为懒加载模块的实例
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
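
懒加载的效果是:导入包本身几乎不会触发任何重量级依赖,只有在首次访问某个属性时,对应的子模块才被真正导入。一个简单的验证方式(假设本地已安装 transformers):

```python
import transformers.models.electra as electra

# 此时 modeling_electra 等子模块尚未导入;首次访问属性才会触发真正的 import
config_cls = electra.ElectraConfig
print(config_cls.model_type)  # "electra"
```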

.\models\encodec\configuration_encodec.py

# 设置编码格式为 UTF-8,确保代码可以正确处理各种字符
# 版权声明,指出版权归 Meta Platforms, Inc. 及其关联公司和 HuggingFace Inc. 团队所有
# 根据 Apache 许可证 2.0 版本授权,只有在符合许可证的情况下才能使用此文件
# 可以通过链接获取许可证的副本
# 根据适用法律或书面同意,软件根据“原样”分发,无任何明示或暗示的保证或条件
# 请参阅许可证了解具体语言的规定,以及许可证下的限制
""" EnCodec model configuration"""

# 导入数学库
import math
# 导入类型提示模块,用于类型注解
from typing import Optional

# 导入 numpy 库,用于数值操作
import numpy as np

# 导入配置工具中的预训练配置类
from ...configuration_utils import PretrainedConfig
# 导入日志工具
from ...utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 预训练模型的配置文件映射,将模型名称映射到其配置文件的 URL
ENCODEC_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/encodec_24khz": "https://huggingface.co/facebook/encodec_24khz/resolve/main/config.json",
    "facebook/encodec_48khz": "https://huggingface.co/facebook/encodec_48khz/resolve/main/config.json",
}

# EncodecConfig 类,用于存储 Encodec 模型的配置信息
class EncodecConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`EncodecModel`]. It is used to instantiate a
    Encodec model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the
    [facebook/encodec_24khz](https://huggingface.co/facebook/encodec_24khz) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Example:

    ```
    >>> from transformers import EncodecModel, EncodecConfig

    >>> # Initializing a "facebook/encodec_24khz" style configuration
    >>> configuration = EncodecConfig()

    >>> # Initializing a model (with random weights) from the "facebook/encodec_24khz" style configuration
    >>> model = EncodecModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    # 模型类型为 encodec
    model_type = "encodec"

    # 构造方法,初始化 EncodecConfig 实例
    def __init__(
        self,
        target_bandwidths=[1.5, 3.0, 6.0, 12.0, 24.0],
        sampling_rate=24_000,
        audio_channels=1,
        normalize=False,
        chunk_length_s=None,
        overlap=None,
        hidden_size=128,
        num_filters=32,
        num_residual_layers=1,
        upsampling_ratios=[8, 5, 4, 2],
        norm_type="weight_norm",
        kernel_size=7,
        last_kernel_size=7,
        residual_kernel_size=3,
        dilation_growth_rate=2,
        use_causal_conv=True,
        pad_mode="reflect",
        compress=2,
        num_lstm_layers=2,
        trim_right_ratio=1.0,
        codebook_size=1024,
        codebook_dim=None,
        use_conv_shortcut=True,
        **kwargs,
    ):
        self.target_bandwidths = target_bandwidths
        self.sampling_rate = sampling_rate
        self.audio_channels = audio_channels
        self.normalize = normalize
        self.chunk_length_s = chunk_length_s
        self.overlap = overlap
        self.hidden_size = hidden_size
        self.num_filters = num_filters
        self.num_residual_layers = num_residual_layers
        self.upsampling_ratios = upsampling_ratios
        self.norm_type = norm_type
        self.kernel_size = kernel_size
        self.last_kernel_size = last_kernel_size
        self.residual_kernel_size = residual_kernel_size
        self.dilation_growth_rate = dilation_growth_rate
        self.use_causal_conv = use_causal_conv
        self.pad_mode = pad_mode
        self.compress = compress
        self.num_lstm_layers = num_lstm_layers
        self.trim_right_ratio = trim_right_ratio
        self.codebook_size = codebook_size
        # 设置 codebook_dim,如果未指定,则使用 hidden_size
        self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
        self.use_conv_shortcut = use_conv_shortcut

        # 检查 norm_type 是否为支持的类型,否则抛出 ValueError 异常
        if self.norm_type not in ["weight_norm", "time_group_norm"]:
            raise ValueError(
                f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`, got {self.norm_type}'
            )

        # 调用父类的构造方法,传入其他未明确列出的关键字参数
        super().__init__(**kwargs)

    # 由于 chunk_length_s 可能会在运行时更改,所以这是一个属性
    @property
    def chunk_length(self) -> Optional[int]:
        # 如果 chunk_length_s 为 None,则返回 None
        if self.chunk_length_s is None:
            return None
        else:
            # 否则返回计算得到的 chunk_length
            return int(self.chunk_length_s * self.sampling_rate)

    # 由于 chunk_length_s 和 overlap 可能会在运行时更改,所以这是一个属性
    @property
    def chunk_stride(self) -> Optional[int]:
        # 如果 chunk_length_s 或 overlap 为 None,则返回 None
        if self.chunk_length_s is None or self.overlap is None:
            return None
        else:
            # 否则返回计算得到的 chunk_stride
            return max(1, int((1.0 - self.overlap) * self.chunk_length))

    # 计算并返回帧率,这是一个属性
    @property
    def frame_rate(self) -> int:
        hop_length = np.prod(self.upsampling_ratios)  # 计算 upsampling_ratios 的乘积
        return math.ceil(self.sampling_rate / hop_length)  # 计算并返回帧率

    # 返回 quantizer 的数量,这是一个属性
    @property
    def num_quantizers(self) -> int:
        return int(1000 * self.target_bandwidths[-1] // (self.frame_rate * 10))
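
用默认的 24kHz 配置可以验证上面几个属性之间的换算关系(注释中的数值为按公式手算的结果):

```python
from transformers import EncodecConfig

cfg = EncodecConfig(chunk_length_s=1.0, overlap=0.01)
# hop_length = 8 * 5 * 4 * 2 = 320,frame_rate = ceil(24000 / 320) = 75
print(cfg.frame_rate)      # 75
print(cfg.chunk_length)    # int(1.0 * 24000) = 24000
print(cfg.chunk_stride)    # max(1, int((1 - 0.01) * 24000)) = 23760
print(cfg.num_quantizers)  # int(1000 * 24.0 // (75 * 10)) = 32
```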

.\models\encodec\convert_encodec_checkpoint_to_pytorch.py

# 设置编码方式为 UTF-8
# 版权声明,指出版权属于 2023 年的 HuggingFace Inc. 团队所有
# 根据 Apache 许可证版本 2.0 使用本文件,详细信息可以访问指定网址获取
# 除非法律要求或书面同意,否则不得使用本文件
# 根据 Apache 许可证版本 2.0,本软件基于“原样”分发,不提供任何形式的担保或条件
# 请查看许可证,了解具体语言版本的细节

"""Convert EnCodec checkpoints."""

# 导入必要的库
import argparse  # 用于解析命令行参数

import torch  # PyTorch 库

from transformers import (  # 导入 transformers 库中的相关模块
    EncodecConfig,  # EnCodec 的配置类
    EncodecFeatureExtractor,  # EnCodec 的特征提取器类
    EncodecModel,  # EnCodec 的模型类
    logging,  # 日志记录模块
)

# 设置日志记录的详细程度为 info 级别
logging.set_verbosity_info()
# 获取名为 "transformers.models.encodec" 的日志记录器
logger = logging.get_logger("transformers.models.encodec")

# 定义映射字典,用于重命名量化器(quantizer)中的模型参数
MAPPING_QUANTIZER = {
    "quantizer.vq.layers.*._codebook.inited": "quantizer.layers.*.codebook.inited",
    "quantizer.vq.layers.*._codebook.cluster_size": "quantizer.layers.*.codebook.cluster_size",
    "quantizer.vq.layers.*._codebook.embed": "quantizer.layers.*.codebook.embed",
    "quantizer.vq.layers.*._codebook.embed_avg": "quantizer.layers.*.codebook.embed_avg",
}

# 定义映射字典,用于重命名编码器(encoder)中的模型参数
MAPPING_ENCODER = {
    "encoder.model.0.conv.conv": "encoder.layers.0.conv",
    "encoder.model.1.block.1.conv.conv": "encoder.layers.1.block.1.conv",
    "encoder.model.1.block.3.conv.conv": "encoder.layers.1.block.3.conv",
    "encoder.model.1.shortcut.conv.conv": "encoder.layers.1.shortcut.conv",
    "encoder.model.3.conv.conv": "encoder.layers.3.conv",
    "encoder.model.4.block.1.conv.conv": "encoder.layers.4.block.1.conv",
    "encoder.model.4.block.3.conv.conv": "encoder.layers.4.block.3.conv",
    "encoder.model.4.shortcut.conv.conv": "encoder.layers.4.shortcut.conv",
    "encoder.model.6.conv.conv": "encoder.layers.6.conv",
    "encoder.model.7.block.1.conv.conv": "encoder.layers.7.block.1.conv",
    "encoder.model.7.block.3.conv.conv": "encoder.layers.7.block.3.conv",
    "encoder.model.7.shortcut.conv.conv": "encoder.layers.7.shortcut.conv",
    "encoder.model.9.conv.conv": "encoder.layers.9.conv",
    "encoder.model.10.block.1.conv.conv": "encoder.layers.10.block.1.conv",
    "encoder.model.10.block.3.conv.conv": "encoder.layers.10.block.3.conv",
    "encoder.model.10.shortcut.conv.conv": "encoder.layers.10.shortcut.conv",
    "encoder.model.12.conv.conv": "encoder.layers.12.conv",
    "encoder.model.13.lstm": "encoder.layers.13.lstm",
    "encoder.model.15.conv.conv": "encoder.layers.15.conv",
}

# 定义映射字典,用于重命名 48kHz 编码器(encoder)中的模型参数
MAPPING_ENCODER_48K = {
    "encoder.model.0.conv.norm": "encoder.layers.0.norm",
    "encoder.model.1.block.1.conv.norm": "encoder.layers.1.block.1.norm",
    "encoder.model.1.block.3.conv.norm": "encoder.layers.1.block.3.norm",
    "encoder.model.1.shortcut.conv.norm": "encoder.layers.1.shortcut.norm",
    "encoder.model.3.conv.norm": "encoder.layers.3.norm",
    "encoder.model.4.block.1.conv.norm": "encoder.layers.4.block.1.norm",
    "encoder.model.4.block.3.conv.norm": "encoder.layers.4.block.3.norm",
    "encoder.model.4.shortcut.conv.norm": "encoder.layers.4.shortcut.norm",
    "encoder.model.6.conv.norm": "encoder.layers.6.norm",
    "encoder.model.7.block.1.conv.norm": "encoder.layers.7.block.1.norm",
    "encoder.model.7.block.3.conv.norm": "encoder.layers.7.block.3.norm",
    "encoder.model.7.shortcut.conv.norm": "encoder.layers.7.shortcut.norm",
    "encoder.model.9.conv.norm": "encoder.layers.9.norm",
    "encoder.model.10.block.1.conv.norm": "encoder.layers.10.block.1.norm",
    "encoder.model.10.block.3.conv.norm": "encoder.layers.10.block.3.norm",
    "encoder.model.10.shortcut.conv.norm": "encoder.layers.10.shortcut.norm",
    "encoder.model.12.conv.norm": "encoder.layers.12.norm",
    "encoder.model.15.conv.norm": "encoder.layers.15.norm",
}

MAPPING_DECODER = {
    "decoder.model.0.conv.conv": "decoder.layers.0.conv",
    "decoder.model.1.lstm": "decoder.layers.1.lstm",
    "decoder.model.3.convtr.convtr": "decoder.layers.3.conv",
    "decoder.model.4.block.1.conv.conv": "decoder.layers.4.block.1.conv",
    "decoder.model.4.block.3.conv.conv": "decoder.layers.4.block.3.conv",
    "decoder.model.4.shortcut.conv.conv": "decoder.layers.4.shortcut.conv",
    "decoder.model.6.convtr.convtr": "decoder.layers.6.conv",
    "decoder.model.7.block.1.conv.conv": "decoder.layers.7.block.1.conv",
    "decoder.model.7.block.3.conv.conv": "decoder.layers.7.block.3.conv",
    "decoder.model.7.shortcut.conv.conv": "decoder.layers.7.shortcut.conv",
    "decoder.model.9.convtr.convtr": "decoder.layers.9.conv",
    "decoder.model.10.block.1.conv.conv": "decoder.layers.10.block.1.conv",
    "decoder.model.10.block.3.conv.conv": "decoder.layers.10.block.3.conv",
    "decoder.model.10.shortcut.conv.conv": "decoder.layers.10.shortcut.conv",
    "decoder.model.12.convtr.convtr": "decoder.layers.12.conv",
    "decoder.model.13.block.1.conv.conv": "decoder.layers.13.block.1.conv",
    "decoder.model.13.block.3.conv.conv": "decoder.layers.13.block.3.conv",
    "decoder.model.13.shortcut.conv.conv": "decoder.layers.13.shortcut.conv",
    "decoder.model.15.conv.conv": "decoder.layers.15.conv",
}
# 映射字典,将原始模型中解码器各层的参数命名映射到新模型中对应的命名,用于结构对齐

MAPPING_DECODER_48K = {
    "decoder.model.0.conv.norm": "decoder.layers.0.norm",
    "decoder.model.3.convtr.norm": "decoder.layers.3.norm",
    "decoder.model.4.block.1.conv.norm": "decoder.layers.4.block.1.norm",
    "decoder.model.4.block.3.conv.norm": "decoder.layers.4.block.3.norm",
    "decoder.model.4.shortcut.conv.norm": "decoder.layers.4.shortcut.norm",
    "decoder.model.6.convtr.norm": "decoder.layers.6.norm",
    "decoder.model.7.block.1.conv.norm": "decoder.layers.7.block.1.norm",
    "decoder.model.7.block.3.conv.norm": "decoder.layers.7.block.3.norm",
    "decoder.model.7.shortcut.conv.norm": "decoder.layers.7.shortcut.norm",
    "decoder.model.9.convtr.norm": "decoder.layers.9.norm",
    "decoder.model.10.block.1.conv.norm": "decoder.layers.10.block.1.norm",
    "decoder.model.10.block.3.conv.norm": "decoder.layers.10.block.3.norm",
    "decoder.model.10.shortcut.conv.norm": "decoder.layers.10.shortcut.norm",
    "decoder.model.12.convtr.norm": "decoder.layers.12.norm",
    "decoder.model.13.block.1.conv.norm": "decoder.layers.13.block.1.norm",
    "decoder.model.13.block.3.conv.norm": "decoder.layers.13.block.3.norm",
    "decoder.model.13.shortcut.conv.norm": "decoder.layers.13.shortcut.norm",
    "decoder.model.15.conv.norm": "decoder.layers.15.norm",
}
# 映射字典,将原始模型中 48kHz 解码器各层的归一化命名映射到新模型对应的命名

MAPPING_24K = {
    **MAPPING_QUANTIZER,
    **MAPPING_ENCODER,
    **MAPPING_DECODER,
}
# 将量化器、编码器和解码器的映射合并到一个字典中,用于24K配置

MAPPING_48K = {
    **MAPPING_QUANTIZER,
    **MAPPING_ENCODER,
    **MAPPING_ENCODER_48K,
    **MAPPING_DECODER,
    **MAPPING_DECODER_48K,
}
# 将量化器、编码器、解码器48K配置的映射合并到一个字典中,用于48K配置

TOP_LEVEL_KEYS = []
# 初始化一个空列表,用于存储顶层键

IGNORE_KEYS = []
# 初始化一个空列表,用于存储需要忽略的键

def set_recursively(hf_pointer, key, value, full_name, weight_type):
    # 将 key 按 "." 分割成属性列表,逐级获取 hf_pointer 的属性值
    for attribute in key.split("."):
        hf_pointer = getattr(hf_pointer, attribute)

    # 如果指定了 weight_type,则获取 hf_pointer 对应属性的形状
    if weight_type is not None:
        hf_shape = getattr(hf_pointer, weight_type).shape
    else:
        # 否则获取 hf_pointer 自身的形状
        hf_shape = hf_pointer.shape

    # 检查获取的形状是否与 value 的形状相匹配,如果不匹配则抛出 ValueError 异常
    if hf_shape != value.shape:
        raise ValueError(
            f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
            f" {value.shape} for {full_name}"
        )

    # 根据 weight_type 类型设置 hf_pointer 对应的数据值
    if weight_type == "weight":
        hf_pointer.weight.data = value
    elif weight_type == "weight_g":
        hf_pointer.weight_g.data = value
    elif weight_type == "weight_v":
        hf_pointer.weight_v.data = value
    elif weight_type == "bias":
        hf_pointer.bias.data = value
    elif weight_type == "running_mean":
        hf_pointer.running_mean.data = value
    elif weight_type == "running_var":
        hf_pointer.running_var.data = value
    elif weight_type == "num_batches_tracked":
        hf_pointer.num_batches_tracked.data = value
    elif weight_type == "weight_ih_l0":
        hf_pointer.weight_ih_l0.data = value
    elif weight_type == "weight_hh_l0":
        hf_pointer.weight_hh_l0.data = value
    elif weight_type == "bias_ih_l0":
        hf_pointer.bias_ih_l0.data = value
    elif weight_type == "bias_hh_l0":
        hf_pointer.bias_hh_l0.data = value
    elif weight_type == "weight_ih_l1":
        hf_pointer.weight_ih_l1.data = value
    elif weight_type == "weight_hh_l1":
        hf_pointer.weight_hh_l1.data = value
    elif weight_type == "bias_ih_l1":
        hf_pointer.bias_ih_l1.data = value
    elif weight_type == "bias_hh_l1":
        hf_pointer.bias_hh_l1.data = value
    else:
        # 如果 weight_type 未指定或未匹配到特定类型,直接设置 hf_pointer 的数据值
        hf_pointer.data = value

    # 记录日志,指示成功初始化的属性和其来源
    logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.")
# 判断给定的权重名称是否应该被忽略,根据 ignore_keys 中的规则进行匹配
def should_ignore(name, ignore_keys):
    # 遍历 ignore_keys 列表中的每一个关键字
    for key in ignore_keys:
        # 如果关键字以 ".*" 结尾,检查 name 是否以 key[:-1] 开头,如果是则返回 True
        if key.endswith(".*"):
            if name.startswith(key[:-1]):
                return True
        # 如果关键字包含 ".*.",则将 key 拆分成前缀 prefix 和后缀 suffix,如果 name 同时包含这两部分则返回 True
        elif ".*." in key:
            prefix, suffix = key.split(".*.")
            if prefix in name and suffix in name:
                return True
        # 否则,如果关键字 key 直接在 name 中出现则返回 True
        elif key in name:
            return True
    # 如果都没有匹配成功,则返回 False,表示不忽略该权重名称
    return False
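
三种匹配规则(前缀通配、前后缀通配、子串)的行为示例:

```python
print(should_ignore("encoder.model.0.conv.conv.weight", ["encoder.model.*"]))      # True,前缀匹配
print(should_ignore("quantizer.vq.layers.0._codebook.embed", ["layers.*.embed"]))  # True,前后缀都出现
print(should_ignore("decoder.layers.0.conv.bias", ["lstm"]))                       # False,无匹配
```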


# 根据给定的模型名和原始字典 orig_dict,加载对应模型的权重到 hf_model 中,并返回未使用的权重列表
def recursively_load_weights(orig_dict, hf_model, model_name):
    # 初始化未使用的权重列表
    unused_weights = []

    # 根据不同的模型名选择相应的映射关系
    if model_name in ("encodec_24khz", "encodec_32khz"):
        MAPPING = MAPPING_24K
    elif model_name == "encodec_48khz":
        MAPPING = MAPPING_48K
    else:
        # 如果模型名不在支持列表中,抛出 ValueError 异常
        raise ValueError(f"Unsupported model: {model_name}")
    # 遍历原始字典的键值对
    for name, value in orig_dict.items():
        # 如果应该忽略该键名,则记录日志并跳过当前循环
        if should_ignore(name, IGNORE_KEYS):
            logger.info(f"{name} was ignored")
            continue

        # 标志:用于检查是否在后续处理中使用了该键名对应的数值
        is_used = False

        # 遍历映射字典中的键值对
        for key, mapped_key in MAPPING.items():
            # 如果当前映射键包含通配符"*"
            if "*" in key:
                # 拆分通配符前缀和后缀
                prefix, suffix = key.split(".*.")
                # 如果键名同时包含前缀和后缀,则使用后缀作为新的键名
                if prefix in name and suffix in name:
                    key = suffix

            # 如果当前映射键在键名中找到匹配
            if key in name:
                # 特定情况下的处理:防止 ".embed_avg" 初始化为 ".embed"
                if key.endswith("embed") and name.endswith("embed_avg"):
                    continue

                # 设置标志表明该键名已被使用
                is_used = True

                # 如果映射值中存在通配符"*",则根据层索引替换通配符
                if "*" in mapped_key:
                    layer_index = name.split(key)[0].split(".")[-2]
                    mapped_key = mapped_key.replace("*", layer_index)

                # 根据特定的权重类型为权重键赋值
                if "weight_g" in name:
                    weight_type = "weight_g"
                elif "weight_v" in name:
                    weight_type = "weight_v"
                elif "weight_ih_l0" in name:
                    weight_type = "weight_ih_l0"
                elif "weight_hh_l0" in name:
                    weight_type = "weight_hh_l0"
                elif "bias_ih_l0" in name:
                    weight_type = "bias_ih_l0"
                elif "bias_hh_l0" in name:
                    weight_type = "bias_hh_l0"
                elif "weight_ih_l1" in name:
                    weight_type = "weight_ih_l1"
                elif "weight_hh_l1" in name:
                    weight_type = "weight_hh_l1"
                elif "bias_ih_l1" in name:
                    weight_type = "bias_ih_l1"
                elif "bias_hh_l1" in name:
                    weight_type = "bias_hh_l1"
                elif "bias" in name:
                    weight_type = "bias"
                elif "weight" in name:
                    weight_type = "weight"
                elif "running_mean" in name:
                    weight_type = "running_mean"
                elif "running_var" in name:
                    weight_type = "running_var"
                elif "num_batches_tracked" in name:
                    weight_type = "num_batches_tracked"
                else:
                    weight_type = None

                # 递归地设置新模型的映射键对应的值
                set_recursively(hf_model, mapped_key, value, name, weight_type)

            # 继续下一个映射键的处理
            continue
        
        # 如果没有任何映射键被使用,则将该键名添加到未使用的权重列表中
        if not is_used:
            unused_weights.append(name)

    # 记录未使用的权重列表到警告日志中
    logger.warning(f"Unused weights: {unused_weights}")
# 用装饰器 @torch.no_grad() 标记该函数,禁止在函数内部进行梯度计算
def convert_checkpoint(
    model_name,
    checkpoint_path,
    pytorch_dump_folder_path,
    config_path=None,
    repo_id=None,
):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    # 如果提供了配置文件路径,则从预训练模型加载配置
    if config_path is not None:
        config = EncodecConfig.from_pretrained(config_path)
    else:
        # 否则创建一个新的配置对象
        config = EncodecConfig()

    # 根据模型名称设置配置对象的参数
    if model_name == "encodec_24khz":
        pass  # 对于 "encodec_24khz" 模型,配置已经是正确的
    elif model_name == "encodec_32khz":
        # 根据模型名称调整配置对象的参数
        config.upsampling_ratios = [8, 5, 4, 4]
        config.target_bandwidths = [2.2]
        config.num_filters = 64
        config.sampling_rate = 32_000
        config.codebook_size = 2048
        config.use_causal_conv = False
        config.normalize = False
        config.use_conv_shortcut = False
    elif model_name == "encodec_48khz":
        # 根据模型名称调整配置对象的参数
        config.upsampling_ratios = [8, 5, 4, 2]
        config.target_bandwidths = [3.0, 6.0, 12.0, 24.0]
        config.sampling_rate = 48_000
        config.audio_channels = 2
        config.use_causal_conv = False
        config.norm_type = "time_group_norm"
        config.normalize = True
        config.chunk_length_s = 1.0
        config.overlap = 0.01
    else:
        # 如果模型名称不在已知列表中,抛出异常
        raise ValueError(f"Unknown model name: {model_name}")

    # 根据配置对象创建模型
    model = EncodecModel(config)

    # 根据配置对象创建特征提取器
    feature_extractor = EncodecFeatureExtractor(
        feature_size=config.audio_channels,
        sampling_rate=config.sampling_rate,
        chunk_length_s=config.chunk_length_s,
        overlap=config.overlap,
    )

    # 将特征提取器保存到指定路径
    feature_extractor.save_pretrained(pytorch_dump_folder_path)

    # 加载原始 PyTorch 检查点
    original_checkpoint = torch.load(checkpoint_path)
    
    # 如果原始检查点中包含 "best_state" 键,只保留权重信息
    if "best_state" in original_checkpoint:
        original_checkpoint = original_checkpoint["best_state"]

    # 递归加载权重到模型中
    recursively_load_weights(original_checkpoint, model, model_name)

    # 将模型保存到指定路径
    model.save_pretrained(pytorch_dump_folder_path)

    # 如果提供了 repo_id,将特征提取器和模型推送到指定的 hub
    if repo_id:
        print("Pushing to the hub...")
        feature_extractor.push_to_hub(repo_id)
        model.push_to_hub(repo_id)


if __name__ == "__main__":
    # 解析命令行参数
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="encodec_24khz",
        type=str,
        help="The model to convert. Should be one of 'encodec_24khz', 'encodec_32khz', 'encodec_48khz'.",
    )
    parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    parser.add_argument(
        "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
    )

    # 解析参数
    args = parser.parse_args()
    # 调用函数 convert_checkpoint,用于转换模型的检查点文件格式
    convert_checkpoint(
        args.model,                     # 指定模型名称参数
        args.checkpoint_path,           # 指定检查点文件路径参数
        args.pytorch_dump_folder_path,  # 指定转换后的 PyTorch 模型输出文件夹路径参数
        args.config_path,               # 指定模型配置文件路径参数
        args.push_to_hub,               # 指定是否将转换后的模型推送到 Hub 的参数
    )