Transformers Source Code Analysis (Part 8)

.\modeling_tf_outputs.py

# The warnings module, used for deprecation warnings
import warnings
# The dataclass decorator used to define the output classes below
from dataclasses import dataclass
# Typing helpers used in the field annotations
from typing import List, Optional, Tuple

# TensorFlow
import tensorflow as tf

# ModelOutput base class from the local utils module
from .utils import ModelOutput


@dataclass
class TFBaseModelOutput(ModelOutput):
    """
    模型输出的基类,包含可能的隐藏状态和注意力信息。

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            模型最后一层的隐藏状态序列。
        hidden_states (`tuple(tf.FloatTensor)`, *optional*, 当 `output_hidden_states=True` 时返回或者 `config.output_hidden_states=True`):
            元组,包含每一层的隐藏状态 `tf.Tensor`(一个用于嵌入输出,一个用于每一层的输出)的形状为 `(batch_size, sequence_length, hidden_size)`。

            模型在每一层的隐藏状态,包括初始嵌入层的输出。
        attentions (`tuple(tf.Tensor)`, *optional*, 当 `output_attentions=True` 时返回或者 `config.output_attentions=True`):
            元组,包含每一层的注意力权重 `tf.Tensor` 的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。

            注意力 softmax 后的注意力权重,用于计算自注意力头的加权平均值。
    """

    last_hidden_state: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None
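
All of these output classes inherit from `ModelOutput`, which can be used like a dataclass, a dict or a tuple. A minimal sketch with dummy tensors (not produced by a real model):

out = TFBaseModelOutput(last_hidden_state=tf.zeros((2, 5, 8)))

out.last_hidden_state     # attribute access
out["last_hidden_state"]  # dict-style access
out[0]                    # tuple-style access to the same tensor
out.to_tuple()            # plain tuple; fields left as None are dropped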


@dataclass
class TFBaseModelOutputWithNoAttention(ModelOutput):
    """
    模型输出的基类,包含可能的隐藏状态,但不包含注意力信息。

    Args:
        last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`):
            模型最后一层的隐藏状态序列。
        hidden_states (`tuple(tf.Tensor)`, *optional*, 当 `output_hidden_states=True` 时返回或者 `config.output_hidden_states=True`):
            元组,包含每一层的隐藏状态 `tf.Tensor`(一个用于嵌入层的输出,如果模型有嵌入层,一个用于每一层的输出)的形状为 `(batch_size, num_channels, height, width)`。

            模型在每一层的隐藏状态,包括可选的初始嵌入层的输出。
    """

    last_hidden_state: tf.Tensor = None
    # 声明一个可选类型的变量hidden_states,默认为None
    hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
@dataclass
class TFBaseModelOutputWithPooling(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) further processed by a
            Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
            prediction (classification) objective during pretraining.

            This output is usually *not* a good summary of the semantic content of the input, you're often better with
            averaging or pooling the sequence of hidden-states for the whole input sequence.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """
    
    last_hidden_state: tf.Tensor = None  # hidden states of the last layer
    pooler_output: tf.Tensor = None  # first-token hidden state after the Linear + Tanh pooling head
    hidden_states: Tuple[tf.Tensor] | None = None  # per-layer hidden states, including the embedding output
    attentions: Tuple[tf.Tensor] | None = None  # per-layer self-attention weights


@dataclass
class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state after a pooling operation on the spatial dimensions.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, num_channels, height, width)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    last_hidden_state: tf.Tensor = None  # hidden states of the last layer
    pooler_output: tf.Tensor = None  # last hidden state after pooling over the spatial dimensions
    hidden_states: Optional[Tuple[tf.Tensor, ...]] = None  # per-layer hidden states, plus the optional embedding output


@dataclass
class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) further processed by a
            Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
            prediction (classification) objective during pretraining.

            This output is usually *not* a good summary of the semantic content of the input, you're often better with
            averaging or pooling the sequence of hidden-states for the whole input sequence.
        past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
            sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
    """

    # Hidden states of the last layer, shape (batch_size, sequence_length, hidden_size)
    last_hidden_state: tf.Tensor = None

    # Classification-token hidden state after the Linear + Tanh pooling head, shape (batch_size, hidden_size);
    # during pretraining the Linear weights are trained on the next sentence prediction objective
    pooler_output: tf.Tensor = None

    # Pre-computed keys and values of the attention blocks, used to speed up sequential decoding;
    # returned when use_cache=True is passed or config.use_cache=True
    past_key_values: List[tf.Tensor] | None = None

    # Per-layer hidden states plus the initial embedding output;
    # returned when output_hidden_states=True is passed or config.output_hidden_states=True
    hidden_states: Tuple[tf.Tensor] | None = None

    # Per-layer self-attention weights, shape (batch_size, num_heads, sequence_length, sequence_length);
    # returned when output_attentions=True is passed or config.output_attentions=True
    attentions: Tuple[tf.Tensor] | None = None

    # Per-layer cross-attention weights of the decoder, after the attention softmax
    cross_attentions: Tuple[tf.Tensor] | None = None


# Output class for models that may also carry past key/values (to speed up sequential decoding)
@dataclass
class TFBaseModelOutputWithPast(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
            sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Hidden states of the last layer
    last_hidden_state: tf.Tensor = None
    # Cached key/values used to speed up sequential decoding
    past_key_values: List[tf.Tensor] | None = None
    # Tuple of per-layer hidden states
    hidden_states: Tuple[tf.Tensor] | None = None
    # Tuple of per-layer attention weights
    attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFBaseModelOutputWithCrossAttentions(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.
    """
    """
    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            模型最后一层的输出隐藏状态序列。
        hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            元组的形式,包含每层模型的隐藏状态,以及初始嵌入输出。
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            元组的形式,包含每层注意力权重,用于计算自注意力中加权平均值。
        cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            元组的形式,包含解码器跨注意力层的注意力权重,用于计算跨注意力中加权平均值。
    """
    
    last_hidden_state: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None
    cross_attentions: Tuple[tf.Tensor] | None = None
@dataclass
class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
    """
    Model output class for Transformer-based models that includes past key/values and cross-attentions.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Final layer hidden-states of the model.
            
            If `past_key_values` is used, only the last hidden-state of shape `(batch_size, 1, hidden_size)` is output.
        past_key_values (`List[tf.Tensor]`, *optional*):
            List of tensors, each of shape `(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Pre-computed hidden-states (key and values) for sequential decoding speed-up.
        hidden_states (`Tuple[tf.Tensor]`, *optional*):
            Tuple of tensors, one for embeddings and one for each layer's hidden-states, each of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`Tuple[tf.Tensor]`, *optional*):
            Tuple of tensors, one for each layer, each of shape `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after softmax, used for self-attention heads.
        cross_attentions (`Tuple[tf.Tensor]`, *optional*):
            Tuple of tensors, one for each layer, each of shape `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights of the decoder's cross-attention layer after softmax.
    """

    last_hidden_state: tf.Tensor = None
    past_key_values: List[tf.Tensor] | None = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None
    cross_attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFSeq2SeqModelOutput(ModelOutput):
    """
    Model output class for Seq2Seq Transformer-based models.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Final layer hidden-states of the encoder.
        past_key_values (`List[tf.Tensor]`, *optional*):
            List of tensors, each of shape `(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Pre-computed hidden-states (key and values) for decoder's sequential decoding speed-up.
        decoder_hidden_states (`Tuple[tf.Tensor]`, *optional*):
            Tuple of tensors for decoder's hidden-states, each of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer for the decoder.
        decoder_attentions (`Tuple[tf.Tensor]`, *optional*):
            Tuple of tensors for decoder's attentions, each of shape `(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after softmax for the decoder.
    """

    last_hidden_state: tf.Tensor = None
    past_key_values: List[tf.Tensor] | None = None
    decoder_hidden_states: Tuple[tf.Tensor] | None = None
    decoder_attentions: Tuple[tf.Tensor] | None = None
    # Cross-attention weights of the decoder, or None
    cross_attentions: Tuple[tf.Tensor] | None = None
    # Last hidden state of the encoder, or None
    encoder_last_hidden_state: tf.Tensor | None = None
    # Per-layer hidden states of the encoder, or None
    encoder_hidden_states: Tuple[tf.Tensor] | None = None
    # Per-layer attention weights of the encoder, or None
    encoder_attentions: Tuple[tf.Tensor] | None = None
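
The seq2seq output above simply groups decoder-side and encoder-side tensors in one object; a small illustrative sketch with dummy tensors (shapes are arbitrary):

out = TFSeq2SeqModelOutput(
    last_hidden_state=tf.zeros((2, 4, 8)),          # decoder output
    encoder_last_hidden_state=tf.zeros((2, 9, 8)),  # encoder output (different sequence length)
)
out.last_hidden_state.shape          # TensorShape([2, 4, 8])
out.encoder_last_hidden_state.shape  # TensorShape([2, 9, 8])
out.past_key_values is None          # True: optional fields default to None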


# Output class for causal language model (or autoregressive) outputs, based on ModelOutput.
@dataclass
class TFCausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Language modeling loss of shape (n,), returned when `labels` is provided
    loss: tf.Tensor | None = None
    # Prediction scores for each vocabulary token before SoftMax, shape (batch_size, sequence_length, config.vocab_size)
    logits: tf.Tensor = None
    # Per-layer hidden states plus the embedding output, each of shape (batch_size, sequence_length, hidden_size)
    hidden_states: Tuple[tf.Tensor] | None = None
    # Per-layer attention weights, each of shape (batch_size, num_heads, sequence_length, sequence_length)
    attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFCausalLMOutputWithPast(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.
    """
    # 定义函数参数和返回类型的注释
    Args:
        loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
            语言建模损失(用于下一个标记预测)。
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            语言建模头的预测分数(在 SoftMax 之前每个词汇标记的分数)。
        past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            包含预先计算的隐藏状态(注意力块中的键和值)的列表,可用于加速顺序解码。
            长度为 `config.n_layers` 的 `tf.Tensor` 列表,每个张量的形状为 `(2, batch_size, num_heads, sequence_length, embed_size_per_head)`。
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            模型在每一层输出的隐藏状态加上初始嵌入输出的元组。
            包含 `tf.Tensor`(嵌入输出的一个 + 每层输出的一个),形状为 `(batch_size, sequence_length, hidden_size)`。
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            self-attention 头部中的加权平均计算所使用的注意力 softmax 后的注意力权重。
            包含每一层的 `tf.Tensor`,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
@dataclass
class TFCausalLMOutputWithCrossAttentions(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
            sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
    """

    loss: tf.Tensor | None = None  # Language modeling loss tensor, optional
    logits: tf.Tensor = None  # Prediction scores before SoftMax for each token
    past_key_values: List[tf.Tensor] | None = None  # Pre-computed hidden states for sequential decoding
    hidden_states: Tuple[tf.Tensor] | None = None  # Hidden states of the model at each layer output
    attentions: Tuple[tf.Tensor] | None = None  # Attention weights for self-attention heads
    cross_attentions: Tuple[tf.Tensor] | None = None  # Attention weights for cross-attention heads


@dataclass
class TFMaskedLMOutput(ModelOutput):
    """
    Base class for masked language models outputs.
    """
    # Masked language modeling loss of shape (n,), returned when `labels` is provided
    loss: tf.Tensor | None = None
    # Prediction scores of the language modeling head, shape (batch_size, sequence_length, config.vocab_size)
    logits: tf.Tensor = None
    # Per-layer hidden states, each of shape (batch_size, sequence_length, hidden_size);
    # returned when output_hidden_states=True or config.output_hidden_states=True
    hidden_states: Tuple[tf.Tensor] | None = None
    # Per-layer self-attention weights, each of shape (batch_size, num_heads, sequence_length, sequence_length);
    # returned when output_attentions=True or config.output_attentions=True
    attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFSeq2SeqLMOutput(ModelOutput):
    """
    Base class for sequence-to-sequence language models outputs.
    """

    # Optional: Loss tensor representing the model's computed loss
    loss: tf.Tensor | None = None

    # Optional: Logits tensor containing the model's predictions
    logits: tf.Tensor = None

    # Optional: List of past key values for attention mechanisms
    past_key_values: List[tf.Tensor] | None = None

    # Optional: Tuple of tensors for hidden states of the decoder
    decoder_hidden_states: Tuple[tf.Tensor] | None = None

    # Optional: Tuple of tensors for attention weights of the decoder
    decoder_attentions: Tuple[tf.Tensor] | None = None

    # Optional: Tuple of tensors for cross-attention weights
    cross_attentions: Tuple[tf.Tensor] | None = None

    # Optional: Tensor representing the last hidden state of the encoder
    encoder_last_hidden_state: tf.Tensor | None = None

    # Optional: Tuple of tensors for hidden states of the encoder
    encoder_hidden_states: Tuple[tf.Tensor] | None = None

    # Optional: Tuple of tensors for attention weights of the encoder
    encoder_attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFNextSentencePredictorOutput(ModelOutput):
    """
    Base class for outputs of models predicting if two sentences are consecutive or not.
    """

    # Optional: Loss tensor representing the next sentence prediction loss
    loss: tf.Tensor | None = None

    # Required: Logits tensor for the next sentence prediction
    logits: tf.Tensor = None

    # Optional: Tuple of tensors for hidden states of the model
    hidden_states: Tuple[tf.Tensor] | None = None

    # Optional: Tuple of tensors for attention weights of the model
    attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFSequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sentence classification models.
    """
    Args:
        loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided):
            分类(或回归,如果 `config.num_labels==1`)的损失。
            当提供 `labels` 参数时返回。
        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
            分类(或回归,如果 `config.num_labels==1`)的分数(SoftMax 之前的值)。
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            由 `tf.Tensor` 组成的元组(当传递 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回)。
            形状为 `(batch_size, sequence_length, hidden_size)`。

            模型在每个层输出的隐藏状态,以及初始嵌入输出。
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            由 `tf.Tensor` 组成的元组(当传递 `output_attentions=True` 或 `config.output_attentions=True` 时返回)。
            形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。

            注意力权重经过注意力 SoftMax 后的结果,用于计算自注意力头部的加权平均值。
    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None


# Output class for sequence-to-sequence sentence classification models.
@dataclass
class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
    """
    序列到序列句子分类模型输出的基础类。

    """

    # 表示损失值的张量,可以为 None
    loss: tf.Tensor | None = None
    # 表示逻辑回归输出的张量
    logits: tf.Tensor = None
    # 表示过去键值的列表,可以为 None
    past_key_values: List[tf.Tensor] | None = None
    # 表示解码器隐藏状态的元组,可以为 None
    decoder_hidden_states: Tuple[tf.Tensor] | None = None
    # 表示解码器注意力的元组,可以为 None
    decoder_attentions: Tuple[tf.Tensor] | None = None
    # 表示交叉注意力的元组,可以为 None
    cross_attentions: Tuple[tf.Tensor] | None = None
    # 表示编码器最后隐藏状态的张量,可以为 None
    encoder_last_hidden_state: tf.Tensor | None = None
    # 表示编码器隐藏状态的元组,可以为 None
    encoder_hidden_states: Tuple[tf.Tensor] | None = None
    # 表示编码器注意力的元组,可以为 None
    encoder_attentions: Tuple[tf.Tensor] | None = None


# Output class for semantic segmentation models.
@dataclass
class TFSemanticSegmenterOutput(ModelOutput):
    """
    Base class for outputs of semantic segmentation models.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if `config.num_labels==1`) loss.
        logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
            Classification scores for each pixel.

            <Tip warning={true}>

            The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
            to avoid doing two interpolations and lose some quality when resizing the logits back to the original
            image size. You should always check your logits shape and resize as needed.

            </Tip>

        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None


# Output class for semantic segmentation models that do not output attention scores.
@dataclass
class TFSemanticSegmenterOutputWithNoAttention(ModelOutput):
    """
    Base class for outputs of semantic segmentation models that do not output attention scores.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if `config.num_labels==1`) loss.
        logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
            Classification scores for each pixel.

            <Tip warning={true}>

            The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
            to avoid doing two interpolations and lose some quality when resizing the logits back to the original
            image size. You should always check your logits shape and resize as needed.

            </Tip>

        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None


# TFImageClassifierOutput: output class for image classification models.
@dataclass
class TFImageClassifierOutput(ModelOutput):
    """
    Base class for outputs of image classification models.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            分类损失(如果提供 `labels` 参数)。
        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
            分类得分(如果 `config.num_labels==1` 则是回归分数),未经 SoftMax 处理。
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            一个元组,包含 `tf.Tensor`(用于嵌入层的输出,如果模型有嵌入层,+ 每个阶段的输出),形状为 `(batch_size, sequence_length, hidden_size)`。
            模型在每个阶段输出的隐藏状态(也称为特征图)。
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            一个元组,包含 `tf.Tensor`(每个层的注意力权重),形状为 `(batch_size, num_heads, patch_size, sequence_length)`。

            注意力 softmax 后的注意力权重,用于在自注意力头中计算加权平均值。
    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None


# TFMultipleChoiceModelOutput: output class for multiple choice models.
@dataclass
class TFMultipleChoiceModelOutput(ModelOutput):
    """
    Base class for outputs of multiple choice models.

    Args:
        loss (`tf.Tensor` of shape *(batch_size, )*, *optional*, returned when `labels` is provided):
            分类损失。
        logits (`tf.Tensor` of shape `(batch_size, num_choices)`):
            `num_choices` 是输入张量的第二维。参见上面的 `input_ids`。

            分类得分(未经 SoftMax 处理)。
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            一个元组,包含 `tf.Tensor`(用于嵌入层的输出 + 每层的输出),形状为 `(batch_size, sequence_length, hidden_size)`。

            模型在每层输出的隐藏状态加上初始嵌入输出。
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            一个元组,包含 `tf.Tensor`(每个层的注意力权重),形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。

            注意力 softmax 后的注意力权重,用于在自注意力头中计算加权平均值。
    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFTokenClassifierOutput(ModelOutput):
    """
    Token 分类模型输出的基类。

    Args:
        loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of unmasked labels, returned when `labels` is provided):
            分类损失。
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`):
            分类分数(SoftMax 之前)。
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            元组,包含 `tf.Tensor`(一个用于嵌入输出 + 每层输出的 `tf.Tensor`),形状为 `(batch_size, sequence_length, hidden_size)`。

            每层模型的隐藏状态,加上初始嵌入输出。
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            元组,包含 `tf.Tensor`(每个层的注意力权重),形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。

            注意力权重经过 SoftMax 后的结果,用于计算自注意力头部的加权平均值。
    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFQuestionAnsweringModelOutput(ModelOutput):
    """
    问答模型输出的基类。

    Args:
        loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `start_positions` and `end_positions` are provided):
            总的跨度提取损失,为开始和结束位置的交叉熵之和。
        start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            起始位置的分数(SoftMax 之前)。
        end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            结束位置的分数(SoftMax 之前)。
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            元组,包含 `tf.Tensor`(一个用于嵌入输出 + 每层输出的 `tf.Tensor`),形状为 `(batch_size, sequence_length, hidden_size)`。

            每层模型的隐藏状态,加上初始嵌入输出。
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            元组,包含 `tf.Tensor`(每个层的注意力权重),形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。

            注意力权重经过 SoftMax 后的结果,用于计算自注意力头部的加权平均值。
    """

    loss: tf.Tensor | None = None
    # Span-start scores (before SoftMax)
    start_logits: tf.Tensor = None
    # Span-end scores (before SoftMax)
    end_logits: tf.Tensor = None
    # Per-layer hidden states, or None
    hidden_states: Tuple[tf.Tensor] | None = None
    # Per-layer attention weights, or None
    attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
    """
    Base class for outputs of sequence-to-sequence question answering models.
    """
    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
            sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
            used (see `past_key_values` input) to speed up sequential decoding.
        decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    # Total span extraction loss, or None
    loss: tf.Tensor | None = None
    # Span-start logits
    start_logits: tf.Tensor = None
    # Span-end logits
    end_logits: tf.Tensor = None
    # Cached decoder key/values, or None
    past_key_values: List[tf.Tensor] | None = None
    # Decoder hidden states, or None
    decoder_hidden_states: Tuple[tf.Tensor] | None = None
    # Decoder attention weights, or None
    decoder_attentions: Tuple[tf.Tensor] | None = None
    # Last hidden state of the encoder, or None
    encoder_last_hidden_state: tf.Tensor | None = None
    # Encoder hidden states, or None
    encoder_hidden_states: Tuple[tf.Tensor] | None = None
    # Encoder attention weights, or None
    encoder_attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFSequenceClassifierOutputWithPast(ModelOutput):
    """
    用于句子分类模型输出的基础类。

    Args:
        loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided):
            分类(或回归,如果 config.num_labels==1)的损失值。
        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
            分类(或回归,如果 config.num_labels==1)的分数(SoftMax 之前)。
        past_key_values (`List[tf.Tensor]`, *optional*, 当传递 `use_cache=True` 或 `config.use_cache=True` 时返回):
            长度为 `config.n_layers` 的 `tf.Tensor` 列表,每个张量的形状为 `(2, batch_size, num_heads, sequence_length, embed_size_per_head)`。

            包含预先计算的隐藏状态(注意力块中的键和值),可用于加速序列解码。
        hidden_states (`tuple(tf.Tensor)`, *optional*, 当传递 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回):
            形状为 `(batch_size, sequence_length, hidden_size)` 的 `tf.Tensor` 元组。

            模型在每一层输出的隐藏状态,以及初始嵌入输出。
        attentions (`tuple(tf.Tensor)`, *optional*, 当传递 `output_attentions=True` 或 `config.output_attentions=True` 时返回):
            形状为 `(batch_size, num_heads, sequence_length, sequence_length)` 的 `tf.Tensor` 元组。

            注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    past_key_values: List[tf.Tensor] | None = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFImageClassifierOutputWithNoAttention(ModelOutput):
    """
    用于图像分类模型输出的基础类。

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, 当提供 `labels` 时返回):
            分类(或回归,如果 config.num_labels==1)的损失值。
        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
            分类(或回归,如果 config.num_labels==1)的分数(SoftMax 之前)。
        hidden_states (`tuple(tf.Tensor)`, *optional*, 当传递 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回):
            形状为 `(batch_size, num_channels, height, width)` 的 `tf.Tensor` 元组。

            模型在每个阶段输出的隐藏状态(也称为特征图)。
    """
    # 定义变量 `loss`,其类型为 `tf.Tensor` 或者 `None`,初始值为 `None`
    loss: tf.Tensor | None = None
    # 定义变量 `logits`,其类型为 `tf.Tensor`,初始值为 `None`
    logits: tf.Tensor = None
    # 定义变量 `hidden_states`,其类型为 `Optional`,包含一个元组,元组中的每个元素为 `tf.Tensor` 对象,初始值为 `None`
    hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
@dataclass
class TFMaskedImageModelingOutput(ModelOutput):
    """
    Base class for outputs of masked image completion / in-painting models.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Reconstruction loss.
        reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
           Reconstructed / completed images.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called
            feature maps) of the model at the output of each stage.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Reconstruction loss, returned when `bool_masked_pos` is provided
    loss: tf.Tensor | None = None
    # Reconstructed / completed images
    reconstruction: tf.Tensor = None
    # Hidden states (feature maps) at the output of each stage, plus the optional embedding output
    hidden_states: Tuple[tf.Tensor] | None = None
    # Attention weights used to compute the weighted average in the self-attention heads
    attentions: Tuple[tf.Tensor] | None = None

    @property
    def logits(self):
        # Deprecation notice: the `logits` attribute will be removed in Transformers v5;
        # use the `reconstruction` attribute to retrieve the final output instead.
        warnings.warn(
            "logits attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the reconstruction attribute to retrieve the final output instead.",
            FutureWarning,
        )
        # Return the reconstruction as the output
        return self.reconstruction
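
For illustration, accessing the deprecated attribute behaves as follows (dummy tensor, just a sketch):

out = TFMaskedImageModelingOutput(reconstruction=tf.zeros((1, 3, 224, 224)))
img = out.logits  # emits a FutureWarning and returns out.reconstruction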

.\modeling_tf_pytorch_utils.py

# coding=utf-8
# Copyright: The Google AI Language Team Authors, The HuggingFace Inc. team and NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0; you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for more
# details.

""" PyTorch - TF 2.0 general utilities """

# Standard library imports
import os
import re

import numpy

# Helper functions from the local utils module
from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze, tensor_size
from .utils import transpose as transpose_func

# Module-level logger
logger = logging.get_logger(__name__)

# Enum describing how (if at all) a weight must be transposed when converting between TF 2.0 and PyTorch
class TransposeType(ExplicitEnum):
    """
    Possible ...
    """

    NO = "no"
    SIMPLE = "simple"
    CONV1D = "conv1d"
    CONV2D = "conv2d"

# Convert a TF 2.0 model variable name into a PyTorch model weight name
def convert_tf_weight_name_to_pt_weight_name(
    tf_name, start_prefix_to_remove="", tf_weight_shape=None, name_scope=None
):
    """
    Convert a TF 2.0 model variable name into a PyTorch model weight name.

    Conventions for the TF 2.0 scope -> PyTorch attribute name conversion:

        - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF 2.0 vs PyTorch)
        - '_._' is replaced by a new level separation (can be used to convert TF 2.0 lists into PyTorch nn.ModuleList)

    Returns a tuple with:

        - the PyTorch model weight name
        - transpose: a `TransposeType` member indicating whether and how the TF 2.0 and PyTorch weight matrices should
          be transposed with regards to each other
    """
    if name_scope is not None:
        if not tf_name.startswith(name_scope) and "final_logits_bias" not in tf_name:
            raise ValueError(
                f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error "
                "in Transformers, so (unless you were doing something really evil) please open an issue to report it!"
            )
        tf_name = tf_name[len(name_scope) :]
        tf_name = tf_name.lstrip("/")
    tf_name = tf_name.replace(":0", "")  # strip the device id
    tf_name = re.sub(
        r"/[^/]*___([^/]*)/", r"/\1/", tf_name
    )  # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF 2.0 vs PyTorch)
    tf_name = tf_name.replace(
        "_._", "/"
    )  # '_._' is replaced by a level separation (can be used to convert TF 2.0 lists into PyTorch nn.ModuleList)
    tf_name = re.sub(r"//+", "/", tf_name)  # Remove empty levels at the end
    tf_name = tf_name.split("/")  # Convert from TF 2.0 '/' separators to PyTorch '.' separators
    # Only strip the outermost scope when the name has more than one level (e.g. BART's final_logits_bias has a single level)
    if len(tf_name) > 1:
        tf_name = tf_name[1:]  # Remove level zero
    
    # The TF weight shape as a list
    tf_weight_shape = list(tf_weight_shape)

    # Decide whether (and how) the weight needs to be transposed
    if tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 4:
        # A 4D "kernel" is a Conv2D weight
        transpose = TransposeType.CONV2D
    elif tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 3:
        # A 3D "kernel" is a Conv1D weight
        transpose = TransposeType.CONV1D
    elif bool(
        tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
        or "emb_projs" in tf_name
        or "out_projs" in tf_name
    ):
        # Dense kernels, separable-conv kernels and projection matrices need a simple transpose
        transpose = TransposeType.SIMPLE
    else:
        # Otherwise, no transpose is needed
        transpose = TransposeType.NO

    # Convert standard TF 2.0 names to PyTorch names
    if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma":
        tf_name[-1] = "weight"
    if tf_name[-1] == "beta":
        tf_name[-1] = "bias"

    # The SeparableConv1D TF layer contains two weights that are translated to PyTorch Conv1D here
    if tf_name[-1] == "pointwise_kernel" or tf_name[-1] == "depthwise_kernel":
        tf_name[-1] = tf_name[-1].replace("_kernel", ".weight")

    # Join the name components with PyTorch '.' separators
    tf_name = ".".join(tf_name)

    # Optionally strip a prefix from the start of the name
    if start_prefix_to_remove:
        tf_name = tf_name.replace(start_prefix_to_remove, "", 1)

    # Return the converted PyTorch weight name and the transpose type
    return tf_name, transpose
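
A quick illustration of these rules (a minimal sketch; the variable name and shape below are made-up examples in the BERT naming style, not taken from this file):

pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
    "tf_bert_model/bert/encoder/layer_._0/attention/self/query/kernel:0",
    tf_weight_shape=(768, 768),
)
# ":0" is stripped, "_._" becomes a new level, level zero is dropped and "kernel" becomes "weight":
# pt_name   == "bert.encoder.layer.0.attention.self.query.weight"
# transpose == TransposeType.SIMPLE  (2D dense kernels are simply transposed)

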
def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True):
    """
    Apply a transpose operation to a weight tensor and optionally reshape it to match a target shape, in a framework-agnostic manner.
    """
    # Pick the transpose according to the TransposeType
    if transpose is TransposeType.CONV2D:
        # Conv2D weight:
        #    PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
        # -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
        axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1)
        weight = transpose_func(weight, axes=axes)
    elif transpose is TransposeType.CONV1D:
        # Conv1D weight:
        #    PT: (num_out_channel, num_in_channel, kernel)
        # -> TF: (kernel, num_in_channel, num_out_channel)
        weight = transpose_func(weight, axes=(2, 1, 0))
    elif transpose is TransposeType.SIMPLE:
        # Plain transpose
        weight = transpose_func(weight)

    # If no target shape is given, return the (possibly transposed) weight as-is
    if match_shape is None:
        return weight

    # Adjust the rank of the weight to match the target shape
    if len(match_shape) < len(weight.shape):
        weight = squeeze(weight)  # drop singleton dimensions when the target has fewer dims
    elif len(match_shape) > len(weight.shape):
        weight = expand_dims(weight, axis=0)  # add a leading dimension when the target has more dims

    # If the shapes still differ, try to reshape the weight to the target shape
    if list(match_shape) != list(weight.shape):
        try:
            weight = reshape(weight, match_shape)  # reshape to the target shape
        except AssertionError as e:
            e.args += (match_shape, match_shape)
            raise e  # re-raise with the shapes attached

    return weight
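
To make the Conv2D case concrete, a small NumPy sketch (made-up shapes) showing that the two axis orders used above are inverses of each other:

import numpy as np

# Hypothetical PyTorch Conv2D kernel: (num_out_channel=8, num_in_channel=3, kernel_h=5, kernel_w=5)
pt_kernel = np.zeros((8, 3, 5, 5))

# PT -> TF uses axes (2, 3, 1, 0): TF layout is (kernel_h, kernel_w, num_in_channel, num_out_channel)
tf_kernel = np.transpose(pt_kernel, axes=(2, 3, 1, 0))
assert tf_kernel.shape == (5, 5, 3, 8)

# TF -> PT uses the inverse permutation (3, 2, 0, 1) and recovers the PyTorch layout
pt_again = np.transpose(tf_kernel, axes=(3, 2, 0, 1))
assert pt_again.shape == (8, 3, 5, 5)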


#####################
# PyTorch => TF 2.0 #
#####################


def load_pytorch_checkpoint_in_tf2_model(
    tf_model,
    pytorch_checkpoint_path,
    tf_inputs=None,
    allow_missing_keys=False,
    output_loading_info=False,
    _prefix=None,
    tf_to_pt_weight_rename=None,
):
    """Load pytorch checkpoints into a TF 2.0 model"""
    try:
        import tensorflow as tf  # noqa: F401
        import torch  # noqa: F401
        from safetensors.torch import load_file as safe_load_file  # noqa: F401

        from .pytorch_utils import is_torch_greater_or_equal_than_1_13  # noqa: F401
    except ImportError:
        logger.error(
            "Loading a PyTorch model in TensorFlow requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # Treat a single file path as a one-element collection of shards
    if isinstance(pytorch_checkpoint_path, str):
        pytorch_checkpoint_path = [pytorch_checkpoint_path]

    # Load all shards into a single state dict
    pt_state_dict = {}
    for path in pytorch_checkpoint_path:
        pt_path = os.path.abspath(path)
        logger.info(f"Loading PyTorch weights from {pt_path}")
        # Pick the loading method based on the file extension
        if pt_path.endswith(".safetensors"):
            state_dict = safe_load_file(pt_path)
        else:
            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
            state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)

        pt_state_dict.update(state_dict)

    # Log the total number of parameters contained in the PyTorch checkpoint
    logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")
    
    # Hand the merged state dict over to the weight-loading routine
    return load_pytorch_weights_in_tf2_model(
        tf_model,                     # the TF 2.0 model object
        pt_state_dict,                # the PyTorch state dict
        tf_inputs=tf_inputs,          # optional inputs used to build the TF model
        allow_missing_keys=allow_missing_keys,  # optionally tolerate missing keys
        output_loading_info=output_loading_info,  # optionally return loading statistics
        _prefix=_prefix,              # optional name-scope prefix used while loading
        tf_to_pt_weight_rename=tf_to_pt_weight_rename,  # optional TF-to-PyTorch weight renaming function
    )
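
These loading helpers are usually not called directly; the common entry point is `from_pretrained(..., from_pt=True)` on a TF model class, which routes through the functions in this file. A minimal, hedged sketch (the checkpoint name is only an example):

from transformers import TFBertModel

# Builds the TF 2.0 architecture and fills it with the weights of the PyTorch checkpoint;
# internally this ends up in load_pytorch_checkpoint_in_tf2_model / load_pytorch_state_dict_in_tf2_model.
model = TFBertModel.from_pretrained("bert-base-uncased", from_pt=True)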


# Load the weights of an in-memory PyTorch model into a TF 2.0 model
def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False):
    """Load pytorch checkpoints in a TF 2.0 model"""
    # Grab the state dict of the PyTorch model
    pt_state_dict = pt_model.state_dict()

    # Delegate to the generic weight-loading routine
    return load_pytorch_weights_in_tf2_model(
        tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys
    )


# Load a PyTorch state_dict (of torch tensors) into a TF 2.0 model
def load_pytorch_weights_in_tf2_model(
    tf_model,
    pt_state_dict,
    tf_inputs=None,
    allow_missing_keys=False,
    output_loading_info=False,
    _prefix=None,
    tf_to_pt_weight_rename=None,
):
    """Load pytorch state_dict in a TF 2.0 model."""
    try:
        import tensorflow as tf  # TensorFlow
        import torch  # PyTorch
    except ImportError:
        # If either framework is missing, log an error and re-raise
        logger.error(
            "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # Convert the tensors in the PyTorch state dict to NumPy arrays
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
    # Delegate to load_pytorch_state_dict_in_tf2_model
    return load_pytorch_state_dict_in_tf2_model(
        tf_model,
        pt_state_dict,
        tf_inputs=tf_inputs,
        allow_missing_keys=allow_missing_keys,
        output_loading_info=output_loading_info,
        _prefix=_prefix,
        tf_to_pt_weight_rename=tf_to_pt_weight_rename,
    )


# Load a PyTorch state_dict (NumPy arrays or a lazy safetensors archive) into a TF 2.0 model
def load_pytorch_state_dict_in_tf2_model(
    tf_model,
    pt_state_dict,
    tf_inputs=None,
    allow_missing_keys=False,
    output_loading_info=False,
    _prefix=None,
    tf_to_pt_weight_rename=None,
    ignore_mismatched_sizes=False,
):
    """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
    safetensors archive created with the safe_open() function."""
    import tensorflow as tf

    # If no inputs are given, fall back to the model's dummy inputs
    if tf_inputs is None:
        tf_inputs = tf_model.dummy_inputs

    # Default the prefix to an empty string
    if _prefix is None:
        _prefix = ""

    # If we have inputs, run them through the model to make sure it is built
    if tf_inputs:
        with tf.name_scope(_prefix):
            tf_model(tf_inputs, training=False)  # Make sure model is built

    # Build a mapping from converted (TF-style) keys to the original PyTorch keys
    tf_keys_to_pt_keys = {}
    # Iterate over the keys of the input state dict
    for key in pt_state_dict.keys():
        new_key = None
        # "gamma" in the key becomes "weight"
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        # "beta" in the key becomes "bias"
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        # "running_var" in the key becomes "moving_variance"
        if "running_var" in key:
            new_key = key.replace("running_var", "moving_variance")
        # "running_mean" in the key becomes "moving_mean"
        if "running_mean" in key:
            new_key = key.replace("running_mean", "moving_mean")

        # Handle the new `weight_norm` naming introduced in https://github.com/huggingface/transformers/pull/24030
        key_components = key.split(".")
        name = None
        # Detect the parametrization pattern and rebuild the corresponding name
        if key_components[-3::2] == ["parametrizations", "original0"]:
            name = key_components[-2] + "_g"
        elif key_components[-3::2] == ["parametrizations", "original1"]:
            name = key_components[-2] + "_v"
        if name is not None:
            key_components = key_components[:-3] + [name]
            new_key = ".".join(key_components)

        # If no rule matched, keep the original key
        if new_key is None:
            new_key = key
        # Record the mapping from the new key to the original key
        tf_keys_to_pt_keys[new_key] = key
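
    # For illustration (hypothetical key names), the renaming rules above map e.g.:
    #   "encoder.layer.0.LayerNorm.gamma" -> "encoder.layer.0.LayerNorm.weight"
    #   "encoder.layer.0.LayerNorm.beta"  -> "encoder.layer.0.LayerNorm.bias"
    #   "backbone.bn1.running_mean"       -> "backbone.bn1.moving_mean"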

    # Matt: All TF models store the actual model stem in a MainLayer class, including the base model.
    # In PT, derived models (with heads) use the base model class as the stem instead, without a MainLayer class.
    # This means that TF base models have one extra layer in their weight names, corresponding to the MainLayer class.
    # The block below compensates for that.

    # If none of the TF keys starts with tf_model.base_model_prefix, the prefix to remove is base_model_prefix + "."
    start_prefix_to_remove = ""
    if not any(s.startswith(tf_model.base_model_prefix) for s in tf_keys_to_pt_keys.keys()):
        start_prefix_to_remove = tf_model.base_model_prefix + "."

    # All symbolic weights of the TF model (trainable and non-trainable)
    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
    # Number of weight elements loaded into the TF model so far
    tf_loaded_numel = 0
    # Set of all PyTorch-derived keys
    all_pytorch_weights = set(tf_keys_to_pt_keys.keys())
    # Keys missing from the PyTorch state dict
    missing_keys = []
    # Keys whose shapes do not match
    mismatched_keys = []
    # A safetensors archive exposes "get_tensor"; a plain dict does not
    is_safetensor_archive = hasattr(pt_state_dict, "get_tensor")
    # Iterate over every symbolic weight of the TF model
    for symbolic_weight in symbolic_weights:
        # Name of the current symbolic weight
        sw_name = symbolic_weight.name
        # Convert the TF weight name to a PyTorch weight name and find out whether it must be transposed
        name, transpose = convert_tf_weight_name_to_pt_weight_name(
            sw_name,
            start_prefix_to_remove=start_prefix_to_remove,
            tf_weight_shape=symbolic_weight.shape,
            name_scope=_prefix,
        )

        # If a TF-to-PT weight renaming function was given, use it to get possible aliases
        if tf_to_pt_weight_rename is not None:
            aliases = tf_to_pt_weight_rename(name)  # Is a tuple to account for possible name aliasing
            # Use the first matching alias, in priority order
            for alias in aliases:
                if alias in tf_keys_to_pt_keys:
                    name = alias
                    break
            else:
                # If none of the aliases match, take the first one (it will be reported as missing)
                name = aliases[0]

        # Find the corresponding numpy array in the PyTorch model state dict
        if name not in tf_keys_to_pt_keys:
            # If missing keys are allowed, record the name and move on to the next weight
            if allow_missing_keys:
                missing_keys.append(name)
                continue
            # If the model defines keys to ignore when missing at load time, check them
            elif tf_model._keys_to_ignore_on_load_missing is not None:
                if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
                    continue
            # Otherwise the key really is absent from the PyTorch model
            raise AttributeError(f"{name} not found in PyTorch model")

        # Original key of the array in the PyTorch state dict
        state_dict_name = tf_keys_to_pt_keys[name]
        # For a safetensors archive, fetch the tensor lazily; otherwise index the dict
        if is_safetensor_archive:
            array = pt_state_dict.get_tensor(state_dict_name)
        else:
            array = pt_state_dict[state_dict_name]

        # Try to transpose/reshape the array to the shape of the symbolic weight
        try:
            array = apply_transpose(transpose, array, symbolic_weight.shape)
        except tf.errors.InvalidArgumentError as e:
            # On a size mismatch, either raise or record the mismatch depending on ignore_mismatched_sizes
            if not ignore_mismatched_sizes:
                error_msg = str(e)
                error_msg += (
                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
                )
                raise tf.errors.InvalidArgumentError(error_msg)
            else:
                # Record the mismatched key and shapes, then move on
                mismatched_keys.append((name, array.shape, symbolic_weight.shape))
                continue

        # Count the number of elements loaded into the TF model
        tf_loaded_numel += tensor_size(array)

        # Cast the PyTorch array to the dtype of the symbolic weight and assign it
        symbolic_weight.assign(tf.cast(array, symbolic_weight.dtype))
        # Free the array immediately to keep peak memory usage as low as possible
        del array
        # Remove the handled key from the set of remaining PyTorch weights
        all_pytorch_weights.discard(name)

    # 记录加载了多少个参数到 TF 2.0 模型中
    logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")

    # 将未预期的键列表转换为列表形式
    unexpected_keys = list(all_pytorch_weights)

    # 如果定义了在加载时忽略的缺失键列表,则根据列表中的模式匹配规则进行过滤
    if tf_model._keys_to_ignore_on_load_missing is not None:
        for pat in tf_model._keys_to_ignore_on_load_missing:
            missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
    
    # 如果定义了在加载时忽略的未预期键列表,则根据列表中的模式匹配规则进行过滤
    if tf_model._keys_to_ignore_on_load_unexpected is not None:
        for pat in tf_model._keys_to_ignore_on_load_unexpected:
            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
    # 如果存在未预期的键(权重),记录警告信息到日志
    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
            f" {tf_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {tf_model.__class__.__name__} from a PyTorch model trained on another task or with another architecture"
            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
            f" NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect"
            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        # 如果所有 PyTorch 模型的权重都被使用,记录相应信息到日志
        logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
    
    # 如果存在未初始化的键(权重或缓冲区),记录警告信息到日志
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the"
            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
            " down-stream task to be able to use it for predictions and inference."
        )
    else:
        # 如果所有权重都从 PyTorch 模型初始化,记录相应信息到日志
        logger.warning(
            f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {tf_model.__class__.__name__} for predictions without further training."
        )
    
    # 如果存在形状不匹配的键,生成对应的警告信息,并记录到日志
    if len(mismatched_keys) > 0:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        logger.warning(
            f"Some weights of {tf_model.__class__.__name__} were not initialized from the model checkpoint"
            f" and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
            " to use it for predictions and inference."
        )
    
    # 如果需要输出加载信息,返回 TensorFlow 模型及加载信息
    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "mismatched_keys": mismatched_keys,
        }
        return tf_model, loading_info
    
    # 返回加载后的 TensorFlow 模型
    return tf_model
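# Illustrative examples (not part of the library code): the key-renaming loop near the top of the
# function above maps PyTorch state_dict keys to the names expected on the TF side. The sample keys
# below are hypothetical and only show the effect of each rule:
#   "encoder.layer.0.LayerNorm.gamma"         -> "encoder.layer.0.LayerNorm.weight"
#   "encoder.layer.0.LayerNorm.beta"          -> "encoder.layer.0.LayerNorm.bias"
#   "bn.running_mean"                         -> "bn.moving_mean"
#   "bn.running_var"                          -> "bn.moving_variance"
#   "conv.parametrizations.weight.original0"  -> "conv.weight_g"   (weight_norm renaming, PR #24030)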
#####################
# TF 2.0 => PyTorch #
#####################

# 在 PyTorch 模型中加载 TF 2.0 的检查点
def load_tf2_checkpoint_in_pytorch_model(
    pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
):
    """
    Load a TF 2.0 HDF5 checkpoint in a PyTorch model. We use HDF5 to easily do transfer learning (see
    https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
    """
    try:
        import tensorflow as tf  # noqa: F401
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    import transformers

    from .modeling_tf_utils import load_tf_weights

    logger.info(f"Loading TensorFlow weights from {tf_checkpoint_path}")

    # 实例化并加载相关的 TF 2.0 模型
    tf_model_class_name = "TF" + pt_model.__class__.__name__  # 在类名前加上 "TF"
    tf_model_class = getattr(transformers, tf_model_class_name)
    tf_model = tf_model_class(pt_model.config)

    if tf_inputs is None:
        tf_inputs = tf_model.dummy_inputs

    if tf_inputs is not None:
        tf_model(tf_inputs, training=False)  # 确保模型已构建

    load_tf_weights(tf_model, tf_checkpoint_path)

    return load_tf2_model_in_pytorch_model(
        pt_model, tf_model, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
    )
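# A minimal usage sketch (hypothetical checkpoint path and model class): converting a TF 2.0 HDF5
# checkpoint into the matching PyTorch architecture. The TF class is resolved by prepending "TF" to
# the PyTorch class name, so an equivalent TF model class must exist in transformers.
#
#     from transformers import BertConfig, BertForSequenceClassification
#
#     config = BertConfig.from_pretrained("bert-base-uncased")
#     pt_model = BertForSequenceClassification(config)
#     pt_model = load_tf2_checkpoint_in_pytorch_model(pt_model, "path/to/tf_model.h5")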


# 在 PyTorch 模型中加载 TF 2.0 模型
def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False, output_loading_info=False):
    """Load TF 2.0 model in a pytorch model"""
    weights = tf_model.weights

    return load_tf2_weights_in_pytorch_model(
        pt_model, weights, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
    )


# 在 PyTorch 模型中加载 TF 2.0 权重
def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False, output_loading_info=False):
    """Load TF2.0 symbolic weights in a PyTorch model"""
    try:
        import tensorflow as tf  # noqa: F401
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # 将 TF 2.0 的权重转换为字典形式
    tf_state_dict = {tf_weight.name: tf_weight.numpy() for tf_weight in tf_weights}
    return load_tf2_state_dict_in_pytorch_model(
        pt_model, tf_state_dict, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
    )


# 在 PyTorch 模型中加载 TF 2.0 的状态字典
def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False):
    import torch

    new_pt_params_dict = {}
    # 获取当前 PyTorch 模型的所有命名参数,并转换成字典形式
    current_pt_params_dict = dict(pt_model.named_parameters())

    # 确保能够加载 PyTorch 基础模型和派生模型(带有头部)
    # TF 模型总是有一个前缀,而一些 PyTorch 基础模型则没有
    start_prefix_to_remove = ""
    if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict.keys()):
        start_prefix_to_remove = pt_model.base_model_prefix + "."

    # 构建一个从潜在的 PyTorch 权重名称到 TF 2.0 变量的映射
    tf_weights_map = {}
    for name, tf_weight in tf_state_dict.items():
        # 转换 TF 的权重名称到 PyTorch 的权重名称
        pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
            name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
        )
        tf_weights_map[pt_name] = (tf_weight, transpose)

    # 获取所有 TF 权重名称的集合
    all_tf_weights = set(tf_weights_map.keys())
    
    # 用于存储已加载的 PyTorch 权重数据指针的字典
    loaded_pt_weights_data_ptr = {}
    
    # 存储缺失的 PyTorch 键列表
    missing_keys_pt = []

    # 遍历当前 PyTorch 模型的所有参数
    for pt_weight_name, pt_weight in current_pt_params_dict.items():
        # 处理 PyTorch 共享权重(在 TF 2.0 中不重复)
        if pt_weight.data_ptr() in loaded_pt_weights_data_ptr:
            new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
            continue

        # 准备用于检查的 PyTorch 权重名称
        pt_weight_name_to_check = pt_weight_name
        
        # 处理新的 `weight_norm`(来自 https://github.com/huggingface/transformers/pull/24030)
        key_components = pt_weight_name.split(".")
        name = None
        if key_components[-3::2] == ["parametrizations", "original0"]:
            name = key_components[-2] + "_g"
        elif key_components[-3::2] == ["parametrizations", "original1"]:
            name = key_components[-2] + "_v"
        if name is not None:
            key_components = key_components[:-3] + [name]
            pt_weight_name_to_check = ".".join(key_components)

        # 检查 PyTorch 权重名称是否在 TF 2.0 权重映射中
        if pt_weight_name_to_check not in tf_weights_map:
            # 如果允许缺失的键,则将其添加到缺失的 PyTorch 键列表中
            if allow_missing_keys:
                missing_keys_pt.append(pt_weight_name)
                continue

            # 否则,抛出属性错误,指明找不到对应的 TF 2.0 模型的键
            raise AttributeError(f"{pt_weight_name} not found in TF 2.0 model")

        # 获取对应的 numpy 数组和转置信息
        array, transpose = tf_weights_map[pt_weight_name_to_check]

        # 应用转置(如果需要),将 TF 数组转换为 PyTorch 数组
        array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)

        # 如果数组是标量,转换为 numpy 数组
        if numpy.isscalar(array):
            array = numpy.array(array)
        # 如果不是 torch 张量也不是 numpy 数组,则假定为 numpy 数组并转换为 torch 张量
        if not is_torch_tensor(array) and not is_numpy_array(array):
            array = array.numpy()
        if is_numpy_array(array):
            # 转换为 torch 张量
            array = torch.from_numpy(array)

        # 将转换后的数组存储到新的 PyTorch 参数字典中
        new_pt_params_dict[pt_weight_name] = array
        # 将已加载的 PyTorch 权重数据指针存储到字典中,以避免重复加载
        loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array
        # 从所有 TF 权重集合中移除当前处理的 PyTorch 权重名称
        all_tf_weights.discard(pt_weight_name)

    # 使用新的 PyTorch 参数字典加载模型状态,允许缺失的键
    missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
    # 将缺失的 PyTorch 键列表添加到总的缺失键列表中
    missing_keys += missing_keys_pt
    # 如果模型定义了要在加载时忽略的键,将这些键从缺失键列表中移除,避免不必要地向用户发出警告。
    if pt_model._keys_to_ignore_on_load_missing is not None:
        for pat in pt_model._keys_to_ignore_on_load_missing:
            # 使用正则表达式模式匹配并移除缺失键列表中与模式匹配的键
            missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

    # 如果模型定义了要在加载时忽略的意外键,将这些键从意外键列表中移除,同样避免不必要的警告。
    if pt_model._keys_to_ignore_on_load_unexpected is not None:
        for pat in pt_model._keys_to_ignore_on_load_unexpected:
            # 使用正则表达式模式匹配并移除意外键列表中与模式匹配的键
            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

    # 如果存在未使用的权重(意外键),向日志记录警告信息,说明这在某些情况下是预期的,比如模型从不同任务或架构的 TF 2.0 模型初始化时。
    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the TF 2.0 model were not used when initializing the PyTorch model"
            f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {pt_model.__class__.__name__} from a TF 2.0 model trained on another task or with another architecture"
            " (e.g. initializing a BertForSequenceClassification model from a TFBertForPreTraining model).\n- This IS"
            f" NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect"
            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
            " TFBertForSequenceClassification model)."
        )
    else:
        # 如果没有未使用的权重,向日志记录警告信息,说明所有 TF 2.0 模型权重都已使用。
        logger.warning(f"All TF 2.0 model weights were used when initializing {pt_model.__class__.__name__}.\n")

    # 如果存在未初始化的权重(缺失键),向日志记录警告信息,建议用户在下游任务上训练模型以便进行预测和推断。
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {pt_model.__class__.__name__} were not initialized from the TF 2.0 model and are newly"
            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )
    else:
        # 如果没有未初始化的权重,向日志记录警告信息,说明所有权重都已从 TF 2.0 模型初始化。
        logger.warning(
            f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {pt_model.__class__.__name__} for predictions without further training."
        )

    # 向日志记录加载信息,显示哪些 TF 2.0 模型的权重或缓冲区未加载。
    logger.info(f"Weights or buffers not loaded from TF 2.0 model: {all_tf_weights}")

    # 如果需要输出加载信息,返回模型及加载信息的字典。
    if output_loading_info:
        loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
        return pt_model, loading_info

    # 否则,只返回加载后的 PyTorch 模型。
    return pt_model

.\modeling_tf_utils.py

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TF general model utils."""

from __future__ import annotations  # Future import to allow forward references

import functools  # Importing functools module for higher-order functions
import gc  # Importing gc module for garbage collection utilities
import inspect  # Importing inspect module for examining live objects
import json  # Importing json module for JSON encoding and decoding
import os  # Importing os module for operating system functionalities
import pickle  # Importing pickle module for object serialization
import re  # Importing re module for regular expressions
import warnings  # Importing warnings module for issuing warnings

from collections.abc import Mapping  # Importing Mapping from collections.abc for ABCs of collections
from pathlib import Path  # Importing Path from pathlib for object-oriented filesystem paths
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union  # Importing typing modules for type hints

import h5py  # Importing h5py for HDF5 file support
import numpy as np  # Importing numpy for numerical computing
import tensorflow as tf  # Importing tensorflow library

from packaging.version import parse  # Importing parse from packaging.version for version parsing

from . import DataCollatorWithPadding, DefaultDataCollator  # Importing local modules
from .activations_tf import get_tf_activation  # Importing get_tf_activation from activations_tf module
from .configuration_utils import PretrainedConfig  # Importing PretrainedConfig from configuration_utils module
from .dynamic_module_utils import custom_object_save  # Importing custom_object_save from dynamic_module_utils module
from .generation import GenerationConfig, TFGenerationMixin  # Importing GenerationConfig and TFGenerationMixin
from .tf_utils import (
    convert_batch_encoding,  # Importing convert_batch_encoding function
    expand_1d,  # Importing expand_1d function
    load_attributes_from_hdf5_group,  # Importing load_attributes_from_hdf5_group function
    save_attributes_to_hdf5_group,  # Importing save_attributes_to_hdf5_group function
    shape_list,  # Importing shape_list function
)

from .utils import (
    SAFE_WEIGHTS_INDEX_NAME,  # Importing constants from utils module
    SAFE_WEIGHTS_NAME,
    TF2_WEIGHTS_INDEX_NAME,
    TF2_WEIGHTS_NAME,
    TF_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    ModelOutput,  # Importing ModelOutput class
    PushToHubMixin,  # Importing PushToHubMixin class
    cached_file,  # Importing cached_file function
    download_url,  # Importing download_url function
    find_labels,  # Importing find_labels function
    has_file,  # Importing has_file function
    is_offline_mode,  # Importing is_offline_mode function
    is_remote_url,  # Importing is_remote_url function
    is_safetensors_available,  # Importing is_safetensors_available function
    is_tf_symbolic_tensor,  # Importing is_tf_symbolic_tensor function
    logging,  # Importing logging utilities
    requires_backends,  # Importing requires_backends decorator
    working_or_temp_dir,  # Importing working_or_temp_dir function
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files  # Importing hub-related utilities

# Checking if safetensors library is available and importing related functions if so
if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.tensorflow import save_file as safe_save_file

# Checking if TYPE_CHECKING is True, then importing PreTrainedTokenizerBase from local module
if TYPE_CHECKING:
    from . import PreTrainedTokenizerBase

# Getting logger from logging utilities
logger = logging.get_logger(__name__)

# Setting TF_USE_LEGACY_KERAS environment variable to '1' for compatibility with Keras 2
if "TF_USE_LEGACY_KERAS" not in os.environ:
    os.environ["TF_USE_LEGACY_KERAS"] = "1"
elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
    # Warning if TF_USE_LEGACY_KERAS is set to '0' explicitly, which may cause issues with Transformers models
    logger.warning(
        "Transformers is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
        "This may result in unexpected behaviour or errors if Keras 3 objects are passed to Transformers models."
    )

# Attempting to import tf_keras as keras and backend as K, falling back to keras and keras.backend if not available
try:
    import tf_keras as keras
    from tf_keras import backend as K
except (ModuleNotFoundError, ImportError):
    import keras
    from keras import backend as K
    # 检查导入的 Keras 版本是否大于 2
    if parse(keras.__version__).major > 2:
        # 如果版本大于 2,则抛出值错误异常
        raise ValueError(
            "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
            "Transformers. Please install the backwards-compatible tf-keras package with "
            "`pip install tf-keras`."
        )
# 获取 TensorFlow 的日志记录器对象
tf_logger = tf.get_logger()

# 定义一个类型别名,表示可以作为 TF 模型的输入的多种可能类型
TFModelInputType = Union[
    List[tf.Tensor],         # 列表中包含 TensorFlow 张量
    List[np.ndarray],        # 列表中包含 NumPy 数组
    Dict[str, tf.Tensor],    # 字典,键是字符串,值是 TensorFlow 张量
    Dict[str, np.ndarray],   # 字典,键是字符串,值是 NumPy 数组
    tf.Tensor,               # 单个 TensorFlow 张量
    np.ndarray,              # 单个 NumPy 数组
]

# 定义一个简单的损失函数,如果预测值的维度小于等于 1,则直接返回预测值,否则返回沿指定轴的均值
def dummy_loss(y_true, y_pred):
    if y_pred.shape.rank <= 1:
        return y_pred
    else:
        reduction_axes = list(range(1, y_pred.shape.rank))
        return tf.reduce_mean(y_pred, axis=reduction_axes)
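# Quick illustration of dummy_loss (made-up values): a per-sample loss of rank > 1 is averaged over
# every non-batch axis, while a rank-0/rank-1 "loss" is passed through unchanged.
#
#     y_pred = tf.constant([[1.0, 3.0], [2.0, 4.0]])  # shape (2, 2)
#     dummy_loss(None, y_pred)                        # -> [2.0, 3.0]
#     dummy_loss(None, tf.constant([0.5, 0.7]))       # -> [0.5, 0.7] (returned as-is)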


class TFModelUtilsMixin:
    """
    `keras.Model` 的几个实用工具方法,作为 Mixin 使用。
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        获取模型中参数的数量(可选只计算可训练的参数)。

        Args:
            only_trainable (`bool`, *optional*, 默认为 `False`):
                是否只返回可训练参数的数量。

        Returns:
            `int`: 参数的数量。
        """
        if only_trainable:
            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
        else:
            return self.count_params()


def keras_serializable(cls):
    """
    装饰一个 Keras 层类,以支持 Keras 序列化。

    这是通过以下方式实现的:

    1. 在 `get_config` 中为 Keras 配置字典添加 `transformers_config` 字典(在序列化时由 Keras 调用)。
    2. 包装 `__init__` 方法以接受 `transformers_config` 字典(在反序列化时由 Keras 传递)并将其转换为实际层初始化器的配置对象。
    3. 在 Keras 中注册该类作为自定义对象(如果 Tensorflow 版本支持),因此在调用 `keras.models.load_model` 时不需要在 `custom_objects` 中提供它。

    Args:
        cls (a `keras.layers.Layers subclass`):
            通常是项目中的 `TF.MainLayer` 类,一般必须接受 `config` 参数作为其初始化器。

    Returns:
        经过修改以支持 Keras 反序列化的同一类对象。
    """
    initializer = cls.__init__

    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")

    @functools.wraps(initializer)
    def wrapped_init(self, *args, **kwargs):
        config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.pop("config", None)

        if isinstance(config, dict):
            config = config_class.from_dict(config)
            initializer(self, config, *args, **kwargs)
        elif isinstance(config, PretrainedConfig):
            if len(args) > 0:
                initializer(self, *args, **kwargs)
            else:
                initializer(self, config, *args, **kwargs)
        else:
            raise ValueError("Must pass either `config` (PretrainedConfig) or `config` (dict)")

        self._config = config
        self._kwargs = kwargs

    cls.__init__ = wrapped_init
    # 检查类 cls 是否具有 get_config 方法,如果没有,则抛出 TypeError 异常
    if not hasattr(cls, "get_config"):
        raise TypeError("Only use @keras_serializable on keras.layers.Layer subclasses")
    
    # 检查 cls 的 get_config 方法是否具有 "_is_default" 属性
    if hasattr(cls.get_config, "_is_default"):
        
        # 定义新的 get_config 方法,用于序列化对象的配置信息
        def get_config(self):
            # 调用父类的 get_config 方法,获取默认配置
            cfg = super(cls, self).get_config()
            # 将当前对象的配置转换为字典,并存储在 cfg["config"] 中
            cfg["config"] = self._config.to_dict()
            # 将对象的关键字参数更新到 cfg 中
            cfg.update(self._kwargs)
            return cfg
        
        # 将新定义的 get_config 方法赋值给 cls 的 get_config 属性
        cls.get_config = get_config
    
    # 将 _keras_serializable 标记设置为 True,表示对象已经被序列化
    cls._keras_serializable = True
    
    # 如果 keras.utils 中存在 register_keras_serializable 方法,则注册 cls
    if hasattr(keras.utils, "register_keras_serializable"):
        cls = keras.utils.register_keras_serializable()(cls)
    
    # 返回经过处理的 cls 对象
    return cls
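# A minimal usage sketch, assuming a hypothetical config class `MyConfig` (a PretrainedConfig
# subclass): the decorated MainLayer must define `config_class` and accept `config` as its first
# initializer argument, after which it can round-trip through `keras.models.load_model`.
#
#     @keras_serializable
#     class TFMyMainLayer(keras.layers.Layer):
#         config_class = MyConfig
#
#         def __init__(self, config, **kwargs):
#             super().__init__(**kwargs)
#             self.hidden_size = config.hidden_size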
# 定义一个适用于因果语言建模(CLM)的损失函数类,即猜测下一个标记的任务。
class TFCausalLanguageModelingLoss:
    """
    Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token.

    <Tip>

    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    </Tip>
    """

    # 使用标签和logits计算损失的方法
    def hf_compute_loss(self, labels, logits):
        # 定义稀疏分类交叉熵损失函数,from_logits=True 表示输入为 logits
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        
        # 如果配置为 tf_legacy_loss,则仅仅处理不等于 -100 的标签
        if self.config.tf_legacy_loss:
            # 创建一个布尔掩码,标记所有不等于 -100 的位置
            active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
            # 使用布尔掩码过滤 logits,并降维处理
            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
            # 使用布尔掩码过滤标签,并降维处理
            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
            return loss_fn(labels, reduced_logits)
        
        # 将负标签裁剪为零,以避免 NaN 和错误,这些位置将在后续被掩码
        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
        # 创建一个损失掩码,确保仅处理不等于 -100 的标签
        loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
        # 应用损失掩码到未掩码的损失
        masked_loss = unmasked_loss * loss_mask
        # 计算平均掩码后的损失
        reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
        return tf.reshape(reduced_masked_loss, (1,))
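# Illustration of the -100 masking in the non-legacy branch above (made-up shapes and values, with
# `loss_fn` as defined in the method): positions labelled -100 are zeroed out by the mask and the
# final loss is averaged over the remaining positions only.
#
#     labels = tf.constant([[5, -100, 7]])              # (batch=1, seq_len=3)
#     logits = tf.random.uniform((1, 3, 32))            # (batch, seq_len, vocab_size)
#     per_token = loss_fn(tf.nn.relu(labels), logits)   # shape (1, 3)
#     mask = tf.cast(labels != -100, per_token.dtype)   # [[1., 0., 1.]]
#     loss = tf.reduce_sum(per_token * mask) / tf.reduce_sum(mask)  # mean over the 2 real tokens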


class TFQuestionAnsweringLoss:
    """
    Loss function suitable for question answering.
    """

    # 使用标签和logits计算损失的方法
    def hf_compute_loss(self, labels, logits):
        # 定义稀疏分类交叉熵损失函数,from_logits=True 表示输入为 logits
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        # 计算起始位置的损失
        start_loss = loss_fn(labels["start_position"], logits[0])
        # 计算结束位置的损失
        end_loss = loss_fn(labels["end_position"], logits[1])
        # 返回起始和结束位置损失的平均值
        return (start_loss + end_loss) / 2.0


class TFTokenClassificationLoss:
    """
    Loss function suitable for token classification.

    <Tip>

    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    </Tip>
    """
    # 定义一个方法用于计算损失,需要传入标签和对数概率
    def hf_compute_loss(self, labels, logits):
        # 使用稀疏分类交叉熵损失函数,设置为从对数概率计算,不进行损失的汇总
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        
        # 如果当前是即时执行模式(eager execution),则执行以下条件判断
        if tf.executing_eagerly():  # Data-dependent conditionals are forbidden in XLA
            # 如果标签中存在值为 -1 的情况,打印警告信息,建议使用 -100 替代 -1 来屏蔽损失
            if tf.math.reduce_any(labels == -1):
                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
        
        # 如果配置中指定使用传统的 TensorFlow 损失计算方法
        if self.config.tf_legacy_loss:
            # 如果标签中存在值为 -1 的情况,打印警告信息,建议使用 -100 替代 -1 来屏蔽损失
            if tf.math.reduce_any(labels == -1):
                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
                # 将标签中不等于 -1 的位置筛选出来,作为有效的损失位置
                active_loss = tf.reshape(labels, (-1,)) != -1
            else:
                # 将标签中不等于 -100 的位置筛选出来,作为有效的损失位置
                active_loss = tf.reshape(labels, (-1,)) != -100
            
            # 从 logits 中筛选出有效的预测值,并且展平为一维数组
            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
            # 从标签中筛选出有效的标签值,并且展平为一维数组
            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
            
            # 返回计算后的损失值
            return loss_fn(labels, reduced_logits)
        
        # 对负数标签进行裁剪,转换为零,避免出现 NaN 和错误,这些位置之后会被屏蔽掉
        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
        
        # 确保只有标签不等于 -100 或 -1 的位置被计入损失计算
        loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype)
        
        # 避免之后可能出现的除以零错误
        # 屏蔽掉的位置将因为 -100 和 -1 不是有效标签而导致损失为 NaN
        masked_loss = unmasked_loss * loss_mask
        
        # 计算屏蔽后的损失总和,并除以有效损失位置的总数来得到平均损失
        reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
        
        # 将结果重新整形为长度为 1 的张量,并返回
        return tf.reshape(reduced_masked_loss, (1,))
class TFSequenceClassificationLoss:
    """
    Loss function suitable for sequence classification.
    """

    def hf_compute_loss(self, labels, logits):
        # 如果 logits 的形状是 1 维或者第二维是 1,使用均方误差损失函数
        if logits.shape.rank == 1 or logits.shape[1] == 1:
            loss_fn = keras.losses.MeanSquaredError(reduction=keras.losses.Reduction.NONE)
            if labels.shape.rank == 1:
                # 如果 labels 是 1 维的,则将其扩展为二维
                labels = tf.expand_dims(labels, axis=-1)
        else:
            # 否则使用稀疏分类交叉熵损失函数
            loss_fn = keras.losses.SparseCategoricalCrossentropy(
                from_logits=True, reduction=keras.losses.Reduction.NONE
            )

        return loss_fn(labels, logits)
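# The branch above picks the loss from the logits shape (hypothetical shapes for illustration):
#     logits of shape (batch, 1) or rank 1 -> regression, MeanSquaredError against float labels
#     logits of shape (batch, num_labels)  -> single-label classification, sparse cross-entropy
#                                             against integer class labels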


class TFMultipleChoiceLoss:
    """Loss function suitable for multiple choice tasks."""

    def hf_compute_loss(self, labels, logits):
        # 使用稀疏分类交叉熵损失函数
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        return loss_fn(labels, logits)


class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
    """
    Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens.

    <Tip>

    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    </Tip>
    """


class TFNextSentencePredictionLoss:
    """
    Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence.

    <Tip>

    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    </Tip>
    """

    def hf_compute_loss(self, labels, logits):
        # 使用稀疏分类交叉熵损失函数
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        if self.config.tf_legacy_loss:
            # 确保仅计算不等于 -100 的标签作为损失
            next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
            next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
            next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)

            return loss_fn(next_sentence_label, next_sentence_reduced_logits)

        # 确保仅计算不等于 -100 的标签作为损失

        # 在这里将负标签剪切为零,以避免 NaN 和错误 - 这些位置后续将被屏蔽
        unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits)
        ns_loss_mask = tf.cast(labels != -100, dtype=unmasked_ns_loss.dtype)
        # 将标签为 -100 的样本归零,不进行减少
        masked_ns_loss = unmasked_ns_loss * ns_loss_mask

        return masked_ns_loss


def booleans_processing(config, **kwargs):
    """
    Process the input booleans of each model.
    """
    # 创建一个空字典,用于存储最终的布尔值选项
    final_booleans = {}
    
    # 如果在传入的参数 kwargs 中存在 "output_attentions",则处理其布尔值设定:
    # 如果 kwargs["output_attentions"] 不为 None,则使用它;否则使用 config.output_attentions 的值
    if "output_attentions" in kwargs:
        final_booleans["output_attentions"] = (
            kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions
        )
    
    # 处理 "output_hidden_states" 的布尔值设定:
    # 如果 kwargs["output_hidden_states"] 不为 None,则使用它;否则使用 config.output_hidden_states 的值
    final_booleans["output_hidden_states"] = (
        kwargs["output_hidden_states"] if kwargs["output_hidden_states"] is not None else config.output_hidden_states
    )
    
    # 处理 "return_dict" 的布尔值设定:
    # 如果 kwargs["return_dict"] 不为 None,则使用它;否则使用 config.return_dict 的值
    final_booleans["return_dict"] = kwargs["return_dict"] if kwargs["return_dict"] is not None else config.return_dict
    
    # 如果在 kwargs 中有 "use_cache" 参数,则处理其布尔值设定:
    # 如果 kwargs["use_cache"] 不为 None,则使用它;否则尝试使用 config.use_cache 的值,如果 config 没有 use_cache 属性则为 None
    if "use_cache" in kwargs:
        final_booleans["use_cache"] = (
            kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None)
        )
    
    # 返回存储了所有布尔选项的字典
    return final_booleans
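# Precedence illustration (the config values are whatever the model config holds): an explicit
# keyword argument wins, and `None` falls back to the corresponding config attribute.
#
#     booleans_processing(config, output_attentions=None, output_hidden_states=True, return_dict=None)
#     # -> {"output_attentions": config.output_attentions,
#     #     "output_hidden_states": True,
#     #     "return_dict": config.return_dict}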
# 定义一个装饰器函数,用于处理传递给 Keras 层的输入参数,将它们作为关键字参数传递给层。这样可以通过它们的变量名在下游使用这些输入,即使它们作为字典打包在第一个输入中(在 Keras 中很常见)。
def unpack_inputs(func):
    original_signature = inspect.signature(func)
    # 获取传入函数的原始签名信息

    @functools.wraps(func)
    def run_call_with_unpacked_inputs(self, *args, **kwargs):
        # 从装饰函数的 kwargs 中隔离出实际的 `**kwargs`
        kwargs_call = {key: val for key, val in kwargs.items() if key not in dict(original_signature.parameters)}
        # 从 kwargs 中分离出用于函数调用的参数和关键字参数
        fn_args_and_kwargs = {key: val for key, val in kwargs.items() if key not in kwargs_call}
        fn_args_and_kwargs.update({"kwargs_call": kwargs_call})

        # 如果存在任何位置参数,将其移动到 kwargs 中
        fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))

        # 对于 EncoderDecoder 模型,将配置选项应用于其内部模型。
        if "EncoderDecoder" in self.__class__.__name__:
            config = None
        else:
            config = self.config

        # 调用 input_processing 函数处理输入
        unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
        # 调用原始函数并传递解包后的输入
        return func(self, **unpacked_inputs)

    # Keras 要求传递第一个层参数,并通过 `inspect.getfullargspec()` 进行检查。这个函数不遵循装饰器链(即不考虑 `functools.wraps()`),因此必须使用以下行以确保 Keras 检查第一个参数与原始签名匹配。
    run_call_with_unpacked_inputs.__signature__ = original_signature

    return run_call_with_unpacked_inputs
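# Usage sketch (hypothetical model class): decorating `call` lets the method receive its inputs as
# named keyword arguments even when they arrive packed in the first positional argument, e.g. as a
# dict, which is common when Keras invokes the layer.
#
#     class TFMyModel(TFPreTrainedModel):
#         @unpack_inputs
#         def call(self, input_ids=None, attention_mask=None, output_attentions=None, training=False):
#             ...
#
#     outputs = model({"input_ids": input_ids, "attention_mask": attention_mask})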
    # 定义允许的数据类型元组,包括 TensorFlow 张量、布尔值、整数、模型输出、元组、列表、字典和 NumPy 数组
    allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
    
    # 如果 kwargs 字典中包含键 "kwargs_call" 中的 "inputs",发出警告并将其替换为 "input_ids"
    if "inputs" in kwargs["kwargs_call"]:
        warnings.warn(
            "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
            FutureWarning,
        )
        output["input_ids"] = kwargs["kwargs_call"].pop("inputs")
    
    # 如果 kwargs 字典中包含键 "kwargs_call" 中的 "decoder_cached_states",发出警告并将其替换为 "past_key_values"
    if "decoder_cached_states" in kwargs["kwargs_call"]:
        warnings.warn(
            "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
            " `past_key_values` instead.",
            FutureWarning,
        )
        output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
    
    # 如果 kwargs 字典中同时包含 "past" 和 "past_key_values",根据参数名称列表作相应处理
    if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names:
        warnings.warn(
            "The `past` argument is deprecated and will be removed in a future version, use `past_key_values`"
            " instead.",
            FutureWarning,
        )
        kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
    elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names:
        kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")
    
    # 如果存在额外的关键字参数(kwargs_call),将其从 kwargs 中弹出并存储在 output 字典中的 "kwargs" 键下
    if has_kwargs:
        output["kwargs"] = kwargs.pop("kwargs_call", {})
    else:
        # 如果 kwargs_call 不为空,则引发 ValueError 异常,指示模型不支持这些关键字参数
        if len(kwargs["kwargs_call"]) > 0:
            raise ValueError(
                "The following keyword arguments are not supported by this model:"
                f" {list(kwargs['kwargs_call'].keys())}."
            )
        kwargs.pop("kwargs_call")
    
    # 遍历 kwargs 字典,检查每个键值对的值是否是允许的数据类型之一,如果是则存储在 output 字典中对应的键下
    for k, v in kwargs.items():
        if isinstance(v, allowed_types) or tf.is_tensor(v) or v is None:
            output[k] = v
        else:
            # 如果值的类型不允许,则引发 ValueError 异常,指出具体类型和不允许的数据类型列表
            raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
    
    # 如果 main_input 是元组或列表,则遍历其中的每个输入
    if isinstance(main_input, (tuple, list)):
        for i, input in enumerate(main_input):
            # 如果输入是 TensorFlow 符号张量,并且输入的名称在 parameter_names 中,则存储在 output 中对应的键下
            if is_tf_symbolic_tensor(input):
                # TensorFlow 张量的名称通常是 `name:id` 格式,这里只提取 `name` 部分
                tensor_name = input.name.split(":")[0]
    
                if tensor_name in parameter_names:
                    output[tensor_name] = input
                else:
                    output[parameter_names[i]] = input
            # 如果输入是允许的数据类型之一或为 None,则存储在 output 中对应的键下
            elif isinstance(input, allowed_types) or input is None:
                output[parameter_names[i]] = input
            else:
                # 如果输入的类型不允许,则引发 ValueError 异常,指出具体类型和不允许的数据类型列表
                raise ValueError(
                    f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
                    f" {parameter_names[i]}."
                )
    # 如果 main_input 是一个 Mapping 类型(如字典),则执行以下操作
    elif isinstance(main_input, Mapping):
        # 如果 main_input 中包含键 "inputs"
        if "inputs" in main_input:
            # 发出警告,说明 `inputs` 参数已废弃,并在将来的版本中会移除,建议使用 `input_ids` 替代
            warnings.warn(
                "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
                " instead.",
                FutureWarning,
            )
            # 将 main_input 中的 "inputs" 弹出并放入 output 的 "input_ids" 中
            output["input_ids"] = main_input.pop("inputs")

        # 如果 main_input 中包含键 "decoder_cached_states"
        if "decoder_cached_states" in main_input:
            # 发出警告,说明 `decoder_cached_states` 参数已废弃,并在将来的版本中会移除,建议使用 `past_key_values` 替代
            warnings.warn(
                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
                " `past_key_values` instead.",
                FutureWarning,
            )
            # 将 main_input 中的 "decoder_cached_states" 弹出并放入 output 的 "past_key_values" 中
            output["past_key_values"] = main_input.pop("decoder_cached_states")

        # 遍历 main_input 中的键值对
        for k, v in dict(main_input).items():
            # 如果值 v 的类型属于允许的类型 allowed_types 或者为 None
            if isinstance(v, allowed_types) or v is None:
                # 将键值对放入 output 中
                output[k] = v
            # 如果键 k 不在参数名称列表 parameter_names 中,且 "args" 不在参数名称列表中
            elif k not in parameter_names and "args" not in parameter_names:
                # 记录警告日志,说明参数 k 不属于参数列表 parameter_names 中,并将被忽略
                logger.warning(
                    f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
                )
                continue
            else:
                # 抛出数值错误,说明类型为 type(v) 的数据不允许,只有 allowed_types 类型允许传递给参数 k
                raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
    
    # 如果 main_input 不是 Mapping 类型,则执行以下操作
    else:
        # 如果 main_input 是 TensorFlow 的张量或者为 None
        if tf.is_tensor(main_input) or main_input is None:
            # 将 main_input 放入 output 中,键为 main_input_name
            output[main_input_name] = main_input
        else:
            # 抛出数值错误,说明类型为 type(main_input) 的数据不允许,只有 allowed_types 类型允许传递给 main_input_name
            raise ValueError(
                f"Data of type {type(main_input)} is not allowed only {allowed_types} is accepted for {main_input_name}."
            )

    # 将未指定的参数按照签名的默认值填充到 output 中
    for name in parameter_names:
        # 如果参数名称 name 不在 output 的键列表中,且不为 "args"
        if name not in list(output.keys()) and name != "args":
            # 将参数名称 name 的默认值(来自 kwargs 或者签名中)填充到 output 中
            output[name] = kwargs.pop(name, signature[name].default)

    # 当创建 SavedModel 时,TF 会通过 LayerCall.__call__(args, **kwargs) 调用方法
    # 因此为了正确输出,需要处理此异常情况
    if "args" in output:
        # 如果 output 中的 "args" 不为 None,并且是 TensorFlow 符号张量
        if output["args"] is not None and is_tf_symbolic_tensor(output["args"]):
            # 获取张量的名称
            tensor_name = output["args"].name.split(":")[0]
            # 将 output 中的 "args" 放入 output 中,键为张量的名称
            output[tensor_name] = output["args"]
        else:
            # 在这种情况下,"args" 总是第一个参数,然后是 "input_ids"
            output["input_ids"] = output["args"]

        # 从 output 中删除 "args"
        del output["args"]

    # 如果 output 中存在 "kwargs",从 output 中删除 "kwargs"
    if "kwargs" in output:
        del output["kwargs"]

    # 创建一个新的字典 cast_output
    cast_output = {}
    # 遍历 output 中的键值对
    for key, val in output.items():
        # 如果值 val 是 TensorFlow 的张量且数据类型为 tf.int64
        if isinstance(val, tf.Tensor) and val.dtype == tf.int64:
            # 将 val 转换为 tf.int32 类型,并放入 cast_output 中
            cast_output[key] = tf.cast(val, tf.int32)
        # 如果值 val 是 NumPy 的数组且数据类型为 np.int64
        elif isinstance(val, np.ndarray) and val.dtype == np.int64:
            # 将 val 转换为 np.int32 类型,并放入 cast_output 中
            cast_output[key] = val.astype(np.int32)
        else:
            # 否则直接将 val 放入 cast_output 中
            cast_output[key] = val

    # 将 cast_output 赋值给 output
    output = cast_output
    # 删除 cast_output
    del cast_output
    # 如果配置对象不为空,则从输出字典中提取指定键的键值对,形成布尔类型的字典
    boolean_dict = {
        k: v
        for k, v in output.items()
        if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
    }

    # 调用 booleans_processing 函数处理布尔类型的配置,更新输出字典
    output.update(
        booleans_processing(
            config=config,
            **boolean_dict,
        )
    )

    # 返回更新后的输出字典
    return output
def tf_shard_checkpoint(weights, max_shard_size="10GB"):
    """
    Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size.

    The sub-checkpoints are determined by iterating through the `weights` in the order of its keys, ensuring that each
    sub-checkpoint does not exceed `max_shard_size`.

    Args:
        weights (`Dict[str, tf.ResourceVariable]`): The dictionary of tf.ResourceVariable objects representing weights
            of a model.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).
    
    Returns:
        A tuple containing:
            - A dictionary mapping each shard file name (derived from `TF2_WEIGHTS_NAME`) to the list of weights
              assigned to that shard.
            - An index dictionary with a `"metadata"` entry (the total size in bytes) and a `"weight_map"` entry
              mapping each weight name to the shard file that stores it, or `None` when a single shard is enough.
    """
    max_shard_size = convert_file_size_to_int(max_shard_size)  # Convert `max_shard_size` string to integer bytes

    sharded_state_dicts = []  # Initialize list to hold sub-checkpoints
    current_block = []  # Initialize current sub-checkpoint
    current_block_size = 0  # Initialize current sub-checkpoint size
    total_size = 0  # Initialize total size accumulator

    for item in weights:  # Iterate through each weight item
        weight_size = item.numpy().size * dtype_byte_size(item.dtype)  # Calculate size of current weight in bytes

        # Check if adding current weight would exceed `max_shard_size`, if so, start a new sub-checkpoint
        if current_block_size + weight_size > max_shard_size:
            sharded_state_dicts.append(current_block)  # Append current sub-checkpoint to list
            current_block = []  # Reset current sub-checkpoint
            current_block_size = 0  # Reset current sub-checkpoint size

        current_block.append(item)  # Add current weight to current sub-checkpoint
        current_block_size += weight_size  # Update current sub-checkpoint size
        total_size += weight_size  # Update total size accumulator

    sharded_state_dicts.append(current_block)  # Append the last sub-checkpoint

    # If only one sub-checkpoint exists, return it directly
    if len(sharded_state_dicts) == 1:
        return {TF2_WEIGHTS_NAME: sharded_state_dicts[0]}, None

    # Otherwise, prepare and return a dictionary mapping each checkpoint name to its corresponding list of weights
    weight_map = {}
    shards = {}
    # 遍历分片状态字典列表,同时追踪索引号和每个状态字典
    for idx, shard in enumerate(sharded_state_dicts):
        # 根据索引号生成分片文件名,将 ".h5" 替换为格式化的编号
        shard_file = TF2_WEIGHTS_NAME.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5")
        # 将当前分片存入分片字典,以生成的文件名作为键,分片数据作为值
        shards[shard_file] = shard
        # 遍历当前分片中的每个权重,并将权重名映射到对应的分片文件名
        for weight in shard:
            weight_name = weight.name
            weight_map[weight_name] = shard_file

    # 创建元数据字典,包含总大小信息
    metadata = {"total_size": total_size}
    # 创建索引字典,包含元数据和权重映射信息
    index = {"metadata": metadata, "weight_map": weight_map}
    # 返回分片字典和索引字典作为结果
    return shards, index
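# Sketch of the result when more than one shard is produced (weight names and sizes are hypothetical;
# the shard file names follow the pattern built above from TF2_WEIGHTS_NAME, i.e. "tf_model.h5"):
#
#     shards = {
#         "tf_model-00001-of-00002.h5": [w0, w1, ...],
#         "tf_model-00002-of-00002.h5": [w2, ...],
#     }
#     index = {
#         "metadata": {"total_size": 1234567890},
#         "weight_map": {"tf_bert/embeddings/word_embeddings/weight:0": "tf_model-00001-of-00002.h5", ...},
#     }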
# 加载 TensorFlow 分片权重的函数,用于从分片检查点中加载模型的权重。检测缺失和意外的层,并根据它们的名称和形状从分片文件中加载 TensorFlow 权重。
def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None):
    """
    This is the same as `load_tf_weights` but for a sharded checkpoint. Detects missing and unexpected layers and loads
    the TF weights from the shard files according to their names and shapes.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        model (`keras.models.Model`): The model in which to load the checkpoint.
        shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Whether or not to ignore the mismatch between the sizes.
        strict (`bool`, *optional*, defaults to `False`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.

    Returns:
        Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
        mismatched layers.
    """

    # 创建空集合来存储意外的键、保存的键和不匹配的键
    unexpected_keys = set()
    saved_keys = set()
    mismatched_keys = set()

    # 由于 TensorFlow 将其权重的类名添加到权重中,并使用索引而不是层名称加载权重,因此我们必须去掉层名称的第一个前缀。
    # 创建模型键集合和映射字典
    model_keys = set()
    model_layer_map = {}
    for i, k in enumerate(model.weights):
        layer_name = k.name
        # 如果有前缀,并且层名称以前缀开头,则去除前缀和斜杠
        if _prefix is not None and layer_name.startswith(_prefix):
            layer_name = layer_name[len(_prefix):]
            layer_name = layer_name.lstrip("/")
        # 如果层名称中包含 "model." 或只有一个部分,则保持不变;否则,只保留第二部分作为层名称
        if not ("model." in layer_name or len(layer_name.split("/")) == 1):
            layer_name = "/".join(layer_name.split("/")[1:])
        # 将处理后的层名称添加到模型键集合和映射字典中
        model_keys.add(layer_name)
        model_layer_map[layer_name] = i

    # 遍历每个分片文件,并加载权重
    for shard_file in shard_files:
        # 调用 load_tf_shard 函数加载分片文件中的权重
        saved_weight_names_set, unexpected_keys_set, mismatched_keys_set = load_tf_shard(
            model,
            model_layer_map,
            shard_file,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            _prefix=_prefix,
        )
        # 更新保存的键、意外的键和不匹配的键集合
        saved_keys.update(saved_weight_names_set)
        unexpected_keys.update(unexpected_keys_set)
        mismatched_keys.update(mismatched_keys_set)
        # 手动进行垃圾回收
        gc.collect()

    # 计算缺失的键集合
    missing_keys = model_keys - saved_keys
    # 如果 strict 为 True 并且存在缺失的键或意外的键,则抛出运行时错误
    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
        if len(missing_keys) > 0:
            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
            error_message += f"\nMissing key(s): {str_missing_keys}."
        if len(unexpected_keys) > 0:
            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
        raise RuntimeError(error_message)
    # 返回三个变量:missing_keys(缺失的键列表)、unexpected_keys(意外的键列表)、mismatched_keys(不匹配的键列表)
    return missing_keys, unexpected_keys, mismatched_keys
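# Example of the layer-name normalization performed at the top of this function (the weight name is
# hypothetical): the leading class-name scope that TF prepends is stripped so that names can be
# matched against the keys stored in the shard files.
#
#     "tf_bert_model/bert/embeddings/word_embeddings/weight:0"
#         -> "bert/embeddings/word_embeddings/weight:0"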
# 从分片的检查点文件中加载一个分片。处理缺失的键和意外的键。

def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
    """
    Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys.

    Args:
        model (`keras.models.Model`): Model in which the weights are loaded
        model_layer_map (`Dict`): A dictionary mapping the layer name to the index of the layer in the model.
        resolved_archive_file (`str`): Path to the checkpoint file from which the weights will be loaded
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore the mismatched keys

    Returns:
        Three sets: one with the weight names that were found and successfully restored from the shard file, one with
        the unexpected keys, and another one with the mismatched keys.
    """

    # 保存已读取的权重名称的集合
    saved_weight_names_set = set()
    # 存储已加载的权重数据的字典
    saved_weights = {}
    # 存储不匹配的键的集合
    mismatched_keys = set()
    # 存储意外的键的集合
    unexpected_keys = set()

    # 读取 H5 文件
    try:
        # 使用 "r" 模式打开 H5 文件作为 sharded_checkpoint_file,使用 with 语句确保文件操作后自动关闭
        with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
            # 从 H5 文件中加载每个层的名称,并存储为集合 saved_h5_model_layers_name
            saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))
            # 初始化空列表,用于存储权重的元组 [(权重对象, 权重值), ...]
            weight_value_tuples = []

            # 遍历每个保存的层名称
            for layer_name in saved_h5_model_layers_name:
                # 获取 H5 文件中的层对象
                h5_layer_object = sharded_checkpoint_file[layer_name]
                # 将 H5 文件中的权重转换为 NumPy 数组,并存储在 saved_weights 字典中
                saved_weights[layer_name] = np.asarray(h5_layer_object)

                # 将当前层名称添加到 saved_weight_names_set 集合中
                saved_weight_names_set.add(layer_name)

                # 如果层名称不在 model_layer_map 中,将其添加到 unexpected_keys 集合中
                if layer_name not in model_layer_map:
                    unexpected_keys.add(layer_name)
                else:
                    # 从 model_layer_map 中获取符号权重并赋值给 symbolic_weight
                    symbolic_weight = model.weights[model_layer_map[layer_name]]

                    # 获取保存的权重值
                    saved_weight_value = saved_weights[layer_name]
                    # 如果保存的权重值不为空
                    if saved_weight_value is not None:
                        # 检查当前权重的形状与 H5 文件中的形状是否不同
                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
                            # 如果形状不兼容,尝试重新调整保存的权重值的形状以匹配当前权重
                            try:
                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
                            except ValueError as e:
                                # 如果 ignore_mismatched_sizes 为 True,则将不兼容的形状添加到 mismatched_keys 中
                                if ignore_mismatched_sizes:
                                    mismatched_keys.add(
                                        (layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
                                    )
                                    continue
                                else:
                                    raise e
                        else:
                            array = saved_weight_value

                    # 创建权重元组 (symbolic_weight, array),并添加到 weight_value_tuples 列表中
                    weight_value_tuples.append((symbolic_weight, array))

        # 使用 K.batch_set_value 批量设置模型权重
        K.batch_set_value(weight_value_tuples)

        # 返回结果:保存的权重名称集合、未预期的键集合和不匹配的键集合
        return saved_weight_names_set, unexpected_keys, mismatched_keys
    # 捕获任何异常,并尝试处理
    except Exception as e:
        # 尝试打开已解析的归档文件
        try:
            # 使用上下文管理器打开文件
            with open(resolved_archive_file) as f:
                # 如果文件内容以 "version" 开头
                if f.read().startswith("version"):
                    # 抛出 OSError,提示缺少 git-lfs
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    # 否则,抛出 ValueError,提示无法找到必要的预训练模型文件
                    raise ValueError(
                        f"Unable to locate the file {resolved_archive_file} which is necessary to load this pretrained"
                        " model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            # 捕获 UnicodeDecodeError 或 ValueError 异常,抛出 OSError
            raise OSError(
                f"Unable to load weights from TF checkpoint file for '{resolved_archive_file}' "
                f"at '{resolved_archive_file}'. "
                "If you tried to load a TF model from a sharded checkpoint, you should try converting the model "
                "by loading it in pytorch and saving it localy. A convertion script should be realeased soon."
            )
# 根据文件后缀判断使用哪种函数加载 TF 权重:如果是 ".safetensors" 后缀,则使用安全张量的加载函数,否则使用 H5 文件的加载函数
def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
    if resolved_archive_file.endswith(".safetensors"):
        load_function = load_tf_weights_from_safetensors
    else:
        load_function = load_tf_weights_from_h5

    # 调用相应的加载函数,加载模型的权重并返回结果
    return load_function(
        model, resolved_archive_file, ignore_mismatched_sizes=ignore_mismatched_sizes, _prefix=_prefix
    )



def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
    # 初始化一个空列表来存放形状不匹配的层
    mismatched_layers = []

    # 从 H5 文件中读取权重值,并批量设置到模型中
    K.batch_set_value(weight_value_tuples)

    # 计算缺失的和意外的层
    missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
    unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))

    # 返回缺失的层、意外的层和形状不匹配的层
    return missing_layers, unexpected_layers, mismatched_layers



def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
    # 从安全张量文件中读取权重
    # 使用安全的方式打开解析后的存档文件,支持 TensorFlow 框架
    with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
        # 初始化一个空列表,用于存储不匹配的层信息
        mismatched_layers = []
        
        # 获取模型所有权重的名称列表(去除模型名称和前缀)
        weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights]
        
        # 获取加载的权重文件中所有的键(即权重名称)
        loaded_weight_names = list(safetensors_archive.keys())
        
        # 找出在高级层列表中存在但在加载的权重中不存在的层
        missing_layers = list(set(weight_names) - set(loaded_weight_names))
        
        # 找出在加载的权重中存在但在高级层列表中不存在的层
        unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
        
        # 遍历模型的每一个权重
        for weight in model.weights:
            # 获取去除模型名称和前缀后的权重名称
            weight_name = strip_model_name_and_prefix(weight.name, _prefix=_prefix)
            
            # 如果该权重在加载的权重名称列表中
            if weight_name in loaded_weight_names:
                # 从安全存档中获取该权重的值
                weight_value = safetensors_archive.get_tensor(weight_name)
                
                # 检查当前权重和从H5文件中读取的权重形状是否不同
                if K.int_shape(weight) != weight_value.shape:
                    # 如果形状不同,尝试将从文件中读取的权重值重塑为当前权重的形状
                    try:
                        weight_value = tf.reshape(weight_value, K.int_shape(weight))
                    except (ValueError, tf.errors.InvalidArgumentError) as e:
                        # 如果无法重塑且不忽略形状不匹配,则抛出异常
                        if ignore_mismatched_sizes:
                            # 如果忽略形状不匹配,则将当前权重和文件中权重的不匹配信息添加到列表中
                            mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
                            continue
                        else:
                            raise e
                
                # 将重新形状后的权重值赋值给当前权重
                K.set_value(weight, weight_value)  # weight.assign() might break if weight is a DTensor
    
    # 返回缺失的层列表、意外的层列表和不匹配的层列表
    return missing_layers, unexpected_layers, mismatched_layers
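

# Self-contained sketch (illustrative only, not part of the library source) of the safetensors
# round trip used above: save a dict of TF tensors with `safetensors.tensorflow.save_file`, then
# reopen the file with `safe_open(..., framework="tf")` and read tensors back by name.
# The file name "example.safetensors" is an arbitrary choice for the example.
import tensorflow as tf
from safetensors import safe_open
from safetensors.tensorflow import save_file

save_file({"layer/kernel": tf.ones((2, 3)), "layer/bias": tf.zeros((3,))}, "example.safetensors")
with safe_open("example.safetensors", framework="tf") as archive:
    print(list(archive.keys()))                   # e.g. ['layer/bias', 'layer/kernel']
    kernel = archive.get_tensor("layer/kernel")   # -> tf.Tensor of shape (2, 3)
    print(kernel.shape)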
def init_copy_embeddings(old_embeddings, new_num_tokens):
    r"""
    This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case
    new_num_tokens > old_num_tokens. A mask is also computed in order to know which weight in the embeddings should be
    kept or not. Example:

        - if new_num_tokens=5 and old_num_tokens=4 and old_embeddings=[w1,w2,w3,w4]

            -  mask=[True,True,True,True,False] and current_weights=[w1,w2,w3,w4,-1]
        - if new_num_tokens=4 and old_num_tokens=5 and old_embeddings=[w1,w2,w3,w4,w5]

            - mask=[True,True,True,True] and current_weights=[w1,w2,w3,w4]
    """
    # Get the number of tokens and embedding dimension from the old embeddings
    old_num_tokens, old_embedding_dim = shape_list(old_embeddings)
    
    # Calculate the difference in size between old and new embeddings
    size_diff = new_num_tokens - old_num_tokens

    # initialize new embeddings
    # Copy token embeddings from the previous ones
    if tf.math.greater(size_diff, 0):
        # if the new size is greater than the old one, we extend the current embeddings with a padding until getting new size
        # and we create a mask to properly identify the padded values and be replaced by the values of the newly created
        # embeddings
        
        # Pad the old embeddings with -1 to extend to the new size
        current_weights = tf.pad(
            old_embeddings.value(), tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=-1
        )
        
        # Determine how many tokens to copy and create a mask to identify them
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        mask = tf.fill(tf.convert_to_tensor([num_tokens_to_copy, 1]), True)
        
        # Pad the mask to match the extended embeddings size
        mask = tf.pad(mask, tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=False)
    else:
        # if the new size is lower than the old one, we take the current embeddings until the new size
        
        # Slice the old embeddings to match the new size
        current_weights = tf.slice(
            old_embeddings.value(),
            tf.convert_to_tensor([0, 0]),
            tf.convert_to_tensor([new_num_tokens, old_embedding_dim]),
        )
        
        # Create a mask for the entire new size
        mask = tf.fill(tf.convert_to_tensor([new_num_tokens, 1]), True)

    # Return the mask and the current weights
    return mask, current_weights
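

# A minimal, self-contained sketch (illustrative aside) of the growing branch above: the old rows
# are padded with -1 up to the new size and a boolean mask marks which rows hold real values.
# tf.where(mask, current_weights, new_weights) later keeps the old rows and lets the padded rows
# be replaced by freshly initialized vectors.
import tensorflow as tf

old_embeddings = tf.Variable(tf.reshape(tf.range(8.0), (4, 2)))   # 4 tokens, embedding dim 2
new_num_tokens = 6
size_diff = new_num_tokens - 4

current_weights = tf.pad(old_embeddings.value(), [[0, size_diff], [0, 0]], constant_values=-1)
mask = tf.pad(tf.fill([4, 1], True), [[0, size_diff], [0, 0]], constant_values=False)
print(tf.squeeze(mask).numpy())   # [ True  True  True  True False False]
print(current_weights.numpy())    # old rows followed by two rows of -1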
    """
    Class attributes (overridden by derived classes):

        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
          for this model architecture.
        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
          classes of the same architecture adding modules on top of the base model.
        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
          models, `pixel_values` for vision models and `input_values` for speech models).
    """

    # 配置类,用作该模型架构的配置类,应该是 PretrainedConfig 的子类
    config_class = None

    # 基础模型前缀,表示在相同架构的派生类中与基础模型相关联的属性字符串
    base_model_prefix = ""

    # 主要输入名称,模型的主要输入名称,通常为 `input_ids`(用于 NLP 模型)、`pixel_values`(用于视觉模型)和 `input_values`(用于语音模型)
    main_input_name = "input_ids"

    # 自动分类,未指定
    _auto_class = None

    # 使用虚拟损失,未指定
    _using_dummy_loss = None

    # 标签到输出映射,未指定
    _label_to_output_map = None

    # 在加载模型权重时,对于模型中存在但检查点中缺失的张量,名称匹配这些正则表达式的将被忽略,避免不必要的警告
    _keys_to_ignore_on_load_missing = None

    # 在加载模型权重时,对于检查点中存在但模型用不到的张量,名称匹配这些正则表达式的将被忽略,避免不必要的警告
    _keys_to_ignore_on_load_unexpected = None

    # 是否需要加载权重前缀,默认为 False
    _requires_load_weight_prefix = False

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """
        Dummy inputs to build the network.

        Returns:
            `Dict[str, tf.Tensor]`: The dummy inputs.
        """
        dummies = {}
        for key, spec in self.input_signature.items():
            # 2 是最正确的任意大小。我不会回答这个问题
            dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
            if spec.shape[0] is None:
                # 但是,为了节省内存,让批量大小为 1
                dummy_shape[0] = 1
            dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype)
            if key == "token_type_ids":
                # 一些模型具有 token_type_ids,但 vocab_size 为 1
                dummies[key] = tf.zeros_like(dummies[key])
        if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters:
            if "encoder_hidden_states" not in dummies:
                if self.main_input_name == "input_ids":
                    dummies["encoder_hidden_states"] = tf.ones(
                        shape=(1, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
                    )
                else:
                    raise NotImplementedError(
                        "Model has cross-attention but we couldn't infer the shape for the encoder hidden states. Please manually override dummy_inputs!"
                    )
        return dummies
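

# Self-contained sketch (illustrative aside, not part of the class above) of how the loop in
# `dummy_inputs` turns a TensorSpec into a dummy tensor: unknown dimensions become 2, except the
# batch dimension, which becomes 1 to save memory.
import tensorflow as tf

spec = tf.TensorSpec([None, None], tf.int32, name="input_ids")
dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
if spec.shape[0] is None:
    dummy_shape[0] = 1
dummy = tf.ones(shape=dummy_shape, dtype=spec.dtype)
print(dummy.shape, dummy.dtype)   # (1, 2) <dtype: 'int32'>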

    def build_in_name_scope(self):
        with tf.name_scope(self.name):
            self.build(input_shape=None)

    @property
    def framework(self) -> str:
        """
        :str: Identifies that this is a TensorFlow model.
        """
        return "tf"
    # 定义一个方法 `build`,用于构建模型,接受一个可选的输入形状参数 `input_shape`
    def build(self, input_shape=None):
        pass  # 这里只是为了确保不调用父类的 `build()`

    # 初始化方法 `__init__`,接受一个配置参数 `config` 和可变数量的位置参数 `inputs` 和关键字参数 `kwargs`
    def __init__(self, config, *inputs, **kwargs):
        # 调用父类的初始化方法
        super().__init__(*inputs, **kwargs)
        # 如果 `config` 不是 `PretrainedConfig` 类的实例,则抛出异常
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
                "`PretrainedConfig`. To create a model from a pretrained model use "
                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        # 将 `config` 和预训练权重的原始来源(如果在模型中给出)保存在实例中
        self.config = config
        self.name_or_path = config.name_or_path
        # 如果模型可以生成文本,则根据 `config` 创建 `GenerationConfig` 实例并保存在 `generation_config` 中,否则设为 `None`
        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
        # 设置保存规范为输入签名 `input_signature` 的保存规范
        self._set_save_spec(self.input_signature)

    # 获取模型配置的方法,返回配置的字典表示
    def get_config(self):
        return self.config.to_dict()

    # 使用 `convert_batch_encoding` 转换参数,然后调用父类的 `fit` 方法
    @functools.wraps(keras.Model.fit)
    def fit(self, *args, **kwargs):
        args, kwargs = convert_batch_encoding(*args, **kwargs)
        return super().fit(*args, **kwargs)

    # 使用 `convert_batch_encoding` 转换参数,然后调用父类的 `train_on_batch` 方法
    @functools.wraps(keras.Model.train_on_batch)
    def train_on_batch(self, *args, **kwargs):
        args, kwargs = convert_batch_encoding(*args, **kwargs)
        return super().train_on_batch(*args, **kwargs)

    # 使用 `convert_batch_encoding` 转换参数,然后调用父类的 `test_on_batch` 方法
    @functools.wraps(keras.Model.test_on_batch)
    def test_on_batch(self, *args, **kwargs):
        args, kwargs = convert_batch_encoding(*args, **kwargs)
        return super().test_on_batch(*args, **kwargs)

    # 使用 `convert_batch_encoding` 转换参数,然后调用父类的 `predict_on_batch` 方法
    @functools.wraps(keras.Model.predict_on_batch)
    def predict_on_batch(self, *args, **kwargs):
        args, kwargs = convert_batch_encoding(*args, **kwargs)
        return super().predict_on_batch(*args, **kwargs)

    # 使用 `convert_batch_encoding` 转换参数,然后调用父类的 `predict` 方法
    @functools.wraps(keras.Model.predict)
    def predict(self, *args, **kwargs):
        args, kwargs = convert_batch_encoding(*args, **kwargs)
        return super().predict(*args, **kwargs)

    # 使用 `convert_batch_encoding` 转换参数,然后调用父类的 `evaluate` 方法
    @functools.wraps(keras.Model.evaluate)
    def evaluate(self, *args, **kwargs):
        args, kwargs = convert_batch_encoding(*args, **kwargs)
        return super().evaluate(*args, **kwargs)

    # 类方法 `from_config`,接受 `config` 和其他关键字参数 `kwargs`
    @classmethod
    def from_config(cls, config, **kwargs):
        # 如果 `config` 是 `PretrainedConfig` 类的实例,则调用 `_from_config` 方法
        if isinstance(config, PretrainedConfig):
            return cls._from_config(config, **kwargs)
        # 否则,根据 `config` 字典创建 `config_class` 实例,并调用 `_from_config` 方法
        return cls._from_config(cls.config_class.from_dict(config, **kwargs))

    # 类方法 `_from_config`,接受 `config` 和其他关键字参数 `kwargs`
    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        所有模型初始化时应置于其下的上下文管理器都在这里。
        """
        # 使用 `config` 和其他关键字参数初始化类 `cls` 的实例
        return cls(config, **kwargs)
    def get_head_mask(self, head_mask: tf.Tensor | None, num_hidden_layers: int) -> tf.Tensor:
        """
        Prepare the head mask if needed.

        Args:
            head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
            num_hidden_layers (`int`):
                The number of hidden layers in the model.

        Returns:
            `tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
            `[None]` for each layer.
        """
        # 如果传入的头部掩码不为 None,则调用 _convert_head_mask_to_5d 方法将其转换为 5 维张量
        if head_mask is not None:
            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
        else:
            # 如果头部掩码为 None,则创建一个列表,包含 num_hidden_layers 个 None 元素
            head_mask = [None] * num_hidden_layers

        return head_mask

    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
        # 如果头部掩码的维度为 1,将其扩展为 [1 x 1 x num_heads x 1 x 1] 的形式,并复制为 num_hidden_layers 个
        if head_mask.shape.rank == 1:
            head_mask = head_mask[None, None, :, None, None]
            head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0)
        # 如果头部掩码的维度为 2,将其扩展为 [num_hidden_layers x 1 x num_heads x 1 x 1] 的形式
        elif head_mask.shape.rank == 2:
            head_mask = head_mask[:, None, :, None, None]
        # 断言头部掩码的维度必须为 5,否则抛出异常
        assert head_mask.shape.rank == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
        # 将头部掩码转换为 float32 类型,以支持 float16 兼容性
        head_mask = tf.cast(head_mask, tf.float32)
        return head_mask
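

# Self-contained sketch (illustrative aside) of the head-mask broadcasting performed by
# `_convert_head_mask_to_5d`: a rank-1 mask over the attention heads is expanded to rank 5 and
# repeated once per hidden layer.
import tensorflow as tf

num_hidden_layers, num_heads = 3, 4
head_mask = tf.constant([1.0, 0.0, 1.0, 1.0])            # keep heads 0, 2 and 3
mask_5d = head_mask[None, None, :, None, None]            # shape [1, 1, num_heads, 1, 1]
mask_5d = tf.repeat(mask_5d, repeats=num_hidden_layers, axis=0)
print(mask_5d.shape)                                       # (3, 1, 4, 1, 1)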

    @tf.function
    def serving(self, inputs):
        """
        Args:
        Method used for serving the model. Does not have a specific signature, but will be specialized as concrete
        functions when saving with `save_pretrained`.
            inputs (`Dict[str, tf.Tensor]`):
                The input of the saved model as a dictionary of tensors.
        """
        # 调用模型的 call 方法进行推理,获取输出
        output = self.call(inputs)

        # 返回推理输出的 serving_output 结果
        return self.serving_output(output)

    @property
    # 定义一个方法,返回一个字典,将模型输入的名称映射到 tf.TensorSpec 对象,用于描述模型输入的预期形状和数据类型。
    def input_signature(self) -> Dict[str, tf.TensorSpec]:
        """
        This property should return a dict mapping input names to tf.TensorSpec objects, representing the expected
        shape and dtype for model inputs. It is used for both serving and for generating dummy inputs.
        """
        # 获取调用方法 self.call 的参数列表
        model_inputs = list(inspect.signature(self.call).parameters)
        # 初始化一个空字典用于存储输入签名
        sig = {}
        
        # 检查是否存在 "input_ids" 作为模型输入的一部分
        if "input_ids" in model_inputs:
            # 如果模型类名以 "ForMultipleChoice" 结尾,则文本维度为 3
            if self.__class__.__name__.endswith("ForMultipleChoice"):
                text_dims = 3
            else:
                text_dims = 2
            # 遍历预定义的输入名称列表
            for input_name in (
                "input_ids",
                "attention_mask",
                "token_type_ids",
                "decoder_input_ids",
                "decoder_attention_mask",
            ):
                # 如果当前遍历的输入名称存在于模型输入中
                if input_name in model_inputs:
                    # 将输入名称作为键,创建对应的 tf.TensorSpec 对象,指定形状和数据类型
                    sig[input_name] = tf.TensorSpec([None] * text_dims, tf.int32, name=input_name)
        
        # 检查是否存在 "pixel_values" 作为模型输入的一部分
        if "pixel_values" in model_inputs:
            # 初始化像素值的形状,None 表示任意长度或尺寸
            pixel_values_shape = [None, None, None, None]
            # 根据配置获取视觉输入的配置信息
            if hasattr(self.config, "vision_config"):
                vision_config = self.config.vision_config
            else:
                vision_config = self.config
            # 如果配置中包含 num_channels 属性,则将其设置为像素值形状的第二维度
            if hasattr(vision_config, "num_channels"):
                pixel_values_shape[1] = vision_config.num_channels
            else:
                # 如果无法从配置中推断出通道数,则抛出未实现错误
                raise NotImplementedError(
                    "Could not infer number of channels from config, please override input_signature to specify input shapes."
                )
            # 根据配置中的图像大小信息设置像素值的高度和宽度
            if hasattr(vision_config, "image_size"):
                pixel_values_shape[2] = pixel_values_shape[3] = vision_config.image_size
            elif hasattr(vision_config, "input_size"):
                pixel_values_shape[2] = pixel_values_shape[3] = vision_config.input_size
            else:
                # 如果无法推断输入图像的形状,则抛出未实现错误
                raise NotImplementedError(
                    "Could not infer input image shape from config, please override input_signature to specify input shapes."
                )
            # 将 "pixel_values" 添加到输入签名字典中,创建对应的 tf.TensorSpec 对象
            sig["pixel_values"] = tf.TensorSpec(pixel_values_shape, tf.float32, name="pixel_values")
        
        # 如果模型需要 "input_features" 作为输入,则抛出未实现错误,要求手动定义输入签名
        if "input_features" in model_inputs:
            raise NotImplementedError("Audio models need a manually defined input_signature")
        
        # 返回构建好的输入签名字典
        return sig
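

# Illustrative example (an aside, not actual library output) of the kind of dictionary the
# `input_signature` property returns for a plain text model whose `call` accepts 2-D int32 inputs;
# the exact keys depend on the parameters of `call`.
import tensorflow as tf

example_signature = {
    "input_ids": tf.TensorSpec([None, None], tf.int32, name="input_ids"),
    "attention_mask": tf.TensorSpec([None, None], tf.int32, name="attention_mask"),
    "token_type_ids": tf.TensorSpec([None, None], tf.int32, name="token_type_ids"),
}
print(example_signature["input_ids"])   # TensorSpec(shape=(None, None), dtype=tf.int32, name='input_ids')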
    def serving_output(self, output):
        """
        Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
        """
        # 检查输出是否为ModelOutput类型,如果不是,则直接返回输出
        if not isinstance(output, ModelOutput):
            return output
        # 遍历输出的键
        for key in output:
            # 如果键以"hidden_states"结尾且配置中未设置输出隐藏状态,则将对应值设为None
            if key.endswith("hidden_states") and not getattr(self.config, "output_hidden_states", False):
                output[key] = None
            # 如果键以"attentions"结尾且配置中未设置输出注意力权重,则将对应值设为None
            elif key.endswith("attentions") and not getattr(self.config, "output_attentions", False):
                output[key] = None
            # 如果键为"past_key_values"且配置中未设置使用缓存,则将对应值设为None
            elif key == "past_key_values" and not getattr(self.config, "use_cache", False):
                output[key] = None
            # 如果键为"cross_attentions"且配置中未同时设置输出注意力权重和使用交叉注意力,则将对应值设为None
            elif key == "cross_attentions" and not (
                getattr(self.config, "output_attentions", False) and getattr(self.config, "add_cross_attention", False)
            ):
                output[key] = None
            # 如果值为tuple或list类型,尝试将其转换为TensorFlow张量
            if isinstance(output[key], (tuple, list)):
                try:
                    output[key] = tf.convert_to_tensor(output[key])
                except (ValueError, tf.errors.InvalidArgumentError):
                    pass  # 可能由于层的维度不同而无法转换
        return output

    @classmethod
    def can_generate(cls) -> bool:
        """
        Returns whether this model can generate sequences with `.generate()`.

        Returns:
            `bool`: Whether this model can generate sequences with `.generate()`.
        """
        # 检测是否已覆盖了`prepare_inputs_for_generation`方法,这是生成序列的要求之一
        # 或者模型可能有自定义的`generate`函数
        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
            return False
        return True

    def get_input_embeddings(self) -> keras.layers.Layer:
        """
        Returns the model's input embeddings layer.

        Returns:
            `tf.Variable`: The embeddings layer mapping vocabulary to hidden states.
        """
        # 获取模型的输入嵌入层
        main_layer = getattr(self, self.base_model_prefix, self)

        # 如果main_layer不是self,即存在基础模型前缀,则返回其输入嵌入层
        if main_layer is not self:
            return main_layer.get_input_embeddings()
        else:
            # 否则抛出未实现错误,要求子类实现该方法
            raise NotImplementedError
    # 定义一个方法用于保存模型检查点,将模型参数保存到指定的目录中
    def _save_checkpoint(self, checkpoint_dir, epoch):
        # 如果指定的检查点目录不存在,则创建该目录
        if not os.path.isdir(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        
        # 定义权重文件的保存路径为指定目录下的"weights.h5"
        weights_path = os.path.join(checkpoint_dir, "weights.h5")
        # 调用模型的保存权重方法,将模型的权重保存到weights_path中
        self.save_weights(weights_path)
        
        # 准备额外的数据,包括当前的训练轮数(epoch)和优化器的状态
        extra_data = {"epoch": epoch, "optimizer_state": self.optimizer.get_weights()}
        # 定义额外数据文件的保存路径为指定目录下的"extra_data.pickle"
        extra_data_path = os.path.join(checkpoint_dir, "extra_data.pickle")
        
        # 使用 pickle 序列化额外数据,并保存到extra_data_path中
        with open(extra_data_path, "wb") as f:
            pickle.dump(extra_data, f)
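

# Illustrative counterpart to `_save_checkpoint` above (a sketch, not a library API): restore the
# weights, the optimizer state and the epoch counter from the same directory layout.
import os
import pickle


def load_hf_checkpoint(model, checkpoint_dir):
    """Assumes `model` is already built and compiled with the same optimizer class."""
    model.load_weights(os.path.join(checkpoint_dir, "weights.h5"))
    with open(os.path.join(checkpoint_dir, "extra_data.pickle"), "rb") as f:
        extra_data = pickle.load(f)
    model.optimizer.set_weights(extra_data["optimizer_state"])
    return extra_data["epoch"]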

    # 定义一个方法用于准备 TensorFlow 数据集
    def prepare_tf_dataset(
        self,
        dataset: "datasets.Dataset",  # noqa:F821
        batch_size: int = 8,
        shuffle: bool = True,
        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
        collate_fn: Optional[Callable] = None,
        collate_fn_args: Optional[Dict[str, Any]] = None,
        drop_remainder: Optional[bool] = None,
        prefetch: bool = True,
    ):
    
    # 定义一个方法用于编译模型,设置优化器、损失函数、评估指标等
    def compile(
        self,
        optimizer="rmsprop",
        loss="auto_with_warning",
        metrics=None,
        loss_weights=None,
        weighted_metrics=None,
        run_eagerly=None,
        steps_per_execution=None,
        **kwargs,
    ):
        """
        This is a thin wrapper that sets the model's loss output head as the loss if the user does not specify a loss
        function themselves.
        """
        # 如果用户没有指定损失函数,则将模型的损失输出头部设置为损失函数
        if loss in ("auto_with_warning", "passthrough"):  # "passthrough" for workflow backward compatibility
            # 如果在compile()中没有指定损失函数,将使用模型的内部损失计算作为损失
            logger.info(
                "No loss specified in compile() - the model's internal loss computation will be used as the "
                "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
                "To disable this behaviour please pass a loss argument, or explicitly pass "
                "`loss=None` if you do not want your model to compute a loss. You can also specify `loss='auto'` to "
                "get the internal loss without printing this info string."
            )
            # 设置损失为"auto",表示使用默认的虚拟损失函数
            loss = "auto"
        if loss == "auto":
            # 如果损失为"auto",则将损失设置为虚拟损失函数dummy_loss,并标记为使用了虚拟损失函数
            loss = dummy_loss
            self._using_dummy_loss = True
        else:
            # 否则,标记为没有使用虚拟损失函数
            self._using_dummy_loss = False
        # 获取父类方法compile()的参数列表
        parent_args = list(inspect.signature(keras.Model.compile).parameters.keys())
        # 检查是否支持参数"steps_per_execution"
        if "steps_per_execution" in parent_args:
            # 如果支持,调用父类方法compile(),使用参数"steps_per_execution"
            super().compile(
                optimizer=optimizer,
                loss=loss,
                metrics=metrics,
                loss_weights=loss_weights,
                weighted_metrics=weighted_metrics,
                run_eagerly=run_eagerly,
                steps_per_execution=steps_per_execution,
                **kwargs,
            )
        else:
            # 否则,调用父类方法compile(),使用参数"experimental_steps_per_execution"(兼容旧版本命名)
            super().compile(
                optimizer=optimizer,
                loss=loss,
                metrics=metrics,
                loss_weights=loss_weights,
                weighted_metrics=weighted_metrics,
                run_eagerly=run_eagerly,
                experimental_steps_per_execution=steps_per_execution,
                **kwargs,
            )
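

# Usage sketch implied by the wrapper above (the model class and checkpoint name are only
# examples): calling compile() without a `loss` argument makes the model fall back to its
# internal loss computation via the dummy loss.
from transformers import TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model.compile(optimizer="adam")                 # loss defaults to "auto_with_warning" -> internal loss
model.compile(optimizer="adam", loss="auto")    # same behaviour, without the info message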

    def compute_loss(self, *args, **kwargs):
        # 检查是否有方法"compute_loss"存在于keras.Model中
        if hasattr(keras.Model, "compute_loss"):
            # 如果是True(TF 2.8或更高版本),调用父类方法compute_loss()
            return super().compute_loss(*args, **kwargs)
        else:
            # 否则,发出警告,指出旧版本的compute_loss方法已弃用,建议使用hf_compute_loss()方法
            warnings.warn(
                "The old compute_loss method is deprecated as it conflicts with the Keras compute_loss "
                "method added in TF 2.8. If you want the original HF compute_loss, please call "
                "hf_compute_loss() instead. From TF versions >= 2.8, or Transformers versions >= 5, "
                "calling compute_loss() will get the Keras method instead.",
                FutureWarning,
            )
            # 返回使用hf_compute_loss()方法计算的损失值
            return self.hf_compute_loss(*args, **kwargs)
    # 获取标签到输出名称的映射关系函数
    def get_label_to_output_name_mapping(self):
        # 使用 Python inspect 模块获取当前函数调用的参数名列表
        arg_names = list(inspect.signature(self.call).parameters)
        # 如果已经存在标签到输出映射关系,直接返回
        if self._label_to_output_map is not None:
            return self._label_to_output_map
        # 根据不同的参数名情况,返回对应的映射关系字典
        elif "start_positions" in arg_names:
            return {"start_positions": "start_logits", "end_positions": "end_logits"}
        elif "sentence_order_label" in arg_names:
            return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
        elif "next_sentence_label" in arg_names:
            return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
        elif "mc_labels" in arg_names:
            return {"labels": "logits", "mc_labels": "mc_logits"}
        else:
            # 默认情况下,返回空的映射关系字典
            return {}

    # 创建模型卡函数,用于生成模型卡片的描述
    def create_model_card(
        self,
        output_dir,
        model_name: str,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Optional[str] = None,
        dataset_tags: Optional[Union[str, List[str]]] = None,
        dataset: Optional[Union[str, List[str]]] = None,
        dataset_args: Optional[Union[str, List[str]]] = None,
        ):
            """
            Creates a draft of a model card using the information available to the `Trainer`.

            Args:
                output_dir (`str` or `os.PathLike`):
                    The folder in which to create the model card.
                model_name (`str`, *optional*):
                    The name of the model.
                language (`str`, *optional*):
                    The language of the model (if applicable)
                license (`str`, *optional*):
                    The license of the model. Will default to the license of the pretrained model used, if the original
                    model given to the `Trainer` comes from a repo on the Hub.
                tags (`str` or `List[str]`, *optional*):
                    Some tags to be included in the metadata of the model card.
                finetuned_from (`str`, *optional*):
                    The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
                    of the original model given to the `Trainer` (if it comes from the Hub).
                tasks (`str` or `List[str]`, *optional*):
                    One or several task identifiers, to be included in the metadata of the model card.
                dataset_tags (`str` or `List[str]`, *optional*):
                    One or several dataset tags, to be included in the metadata of the model card.
                dataset (`str` or `List[str]`, *optional*):
                    One or several dataset identifiers, to be included in the metadata of the model card.
                dataset_args (`str` or `List[str]`, *optional*):
                   One or several dataset arguments, to be included in the metadata of the model card.
            """
            # Avoids a circular import by doing this when necessary.
            from .modelcard import TrainingSummary  # tests_ignore

            # 使用 TrainingSummary 类的静态方法 from_keras 创建训练摘要
            training_summary = TrainingSummary.from_keras(
                self,
                keras_history=self.history,
                language=language,
                license=license,
                tags=tags,
                model_name=model_name,
                finetuned_from=finetuned_from,
                tasks=tasks,
                dataset_tags=dataset_tags,
                dataset=dataset,
                dataset_args=dataset_args,
            )
            # 将训练摘要转换为模型卡
            model_card = training_summary.to_model_card()
            # 打开指定路径下的 README.md 文件,以写入模型卡内容
            with open(os.path.join(output_dir, "README.md"), "w") as f:
                f.write(model_card)
    # 设置模型的输入嵌入
    def set_input_embeddings(self, value):
        """
        Set model's input embeddings

        Args:
            value (`tf.Variable`):
                The new weights mapping hidden states to vocabulary.
        """
        # 获取主要的模型层
        main_layer = getattr(self, self.base_model_prefix)

        # 如果主模型层为空,抛出未实现错误
        if main_layer is None:
            raise NotImplementedError("The model does not implements the base_model_prefix attribute.")

        try:
            # 尝试设置输入嵌入到主模型层
            main_layer.set_input_embeddings(value)
        except AttributeError:
            # 如果出现属性错误,记录日志并构建模型
            logger.info("Building the model")
            self.build_in_name_scope()
            # 再次尝试设置输入嵌入到主模型层
            main_layer.set_input_embeddings(value)

    # 获取模型的输出嵌入
    def get_output_embeddings(self) -> Union[None, keras.layers.Layer]:
        """
        Returns the model's output embeddings

        Returns:
            `tf.Variable`: The new weights mapping vocabulary to hidden states.
        """
        # 如果模型有语言模型头部
        if self.get_lm_head() is not None:
            lm_head = self.get_lm_head()

            try:
                # 尝试获取输出嵌入层
                return lm_head.get_output_embeddings()
            except AttributeError:
                # 如果出现属性错误,记录日志并构建模型
                logger.info("Building the model")
                self.build_in_name_scope()

                # 再次尝试获取输出嵌入层
                return lm_head().get_output_embeddings()

        # 如果没有语言模型头部,返回None(适用于没有输出嵌入的模型)
        return None  # Overwrite for models with output embeddings

    # 设置模型的输出嵌入
    def set_output_embeddings(self, value):
        """
        Set model's output embeddings

        Args:
            value (`tf.Variable`):
                The new weights mapping hidden states to vocabulary.
        """
        # 如果模型有语言模型头部
        if self.get_lm_head() is not None:
            lm_head = self.get_lm_head()
            try:
                # 尝试设置输出嵌入到语言模型头部
                lm_head.set_output_embeddings(value)
            except AttributeError:
                # 如果出现属性错误,记录日志并构建模型,然后再次尝试设置输出嵌入
                logger.info("Building the model")
                self.build_in_name_scope()
                lm_head.set_output_embeddings(value)

    # 获取带有偏置的输出层,用于处理模型带有与嵌入权重绑定的偏置属性
    def get_output_layer_with_bias(self) -> Union[None, keras.layers.Layer]:
        """
        Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the
        embeddings

        Return:
            `keras.layers.Layer`: The layer that handles the bias, None if not an LM model.
        """
        warnings.warn(
            "The method get_output_layer_with_bias is deprecated. Please use `get_lm_head` instead.", FutureWarning
        )
        # 返回语言模型头部(如果有)
        return self.get_lm_head()

    # 获取模型名称到父层的前缀偏置名称
    def get_prefix_bias_name(self) -> Union[None, str]:
        """
        Get the concatenated _prefix name of the bias from the model name to the parent layer

        Return:
            `str`: The _prefix name of the bias.
        """
        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        # 返回None,因为这个方法已经被废弃
        return None
    def get_bias(self) -> Union[None, Dict[str, tf.Variable]]:
        """
        获取 LM 头部的偏置字典。键表示偏置属性的名称。

        Return:
            `tf.Variable`: 表示偏置的权重,如果不是 LM 模型则返回 None。
        """
        if self.get_lm_head() is not None:
            # 获取 LM 头部的引用
            lm_head = self.get_lm_head()
            try:
                # 尝试获取 LM 头部的偏置
                return lm_head.get_bias()
            except AttributeError:
                # 如果 LM 头部没有 get_bias 方法,则建立名称作用域并尝试再次获取偏置
                self.build_in_name_scope()
                return lm_head.get_bias()
        return None

    def set_bias(self, value):
        """
        设置 LM 头部所有的偏置。

        Args:
            value (`Dict[tf.Variable]`):
                LM 头部新的偏置字典。
        """
        if self.get_lm_head() is not None:
            # 获取 LM 头部的引用
            lm_head = self.get_lm_head()
            try:
                # 尝试设置 LM 头部的偏置
                lm_head.set_bias(value)
            except AttributeError:
                # 如果 LM 头部没有 set_bias 方法,则建立名称作用域并尝试再次设置偏置
                self.build_in_name_scope()
                lm_head.set_bias(value)

    def get_lm_head(self) -> keras.layers.Layer:
        """
        LM 头部层。所有包含 LM 头部的模型必须重写此方法。

        Return:
            `keras.layers.Layer`: 如果模型有 LM 头部则返回该层,否则返回 None。
        """
        return None

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None
    ) -> Union[keras.layers.Embedding, tf.Variable]:
        """
        调整模型输入标记嵌入矩阵的大小,如果 `new_num_tokens != config.vocab_size`。

        在之后处理权重嵌入时要注意是否模型类有 `tie_weights()` 方法。

        Arguments:
            new_num_tokens (`int`, *optional*):
                嵌入矩阵中的新标记数量。增加大小将在末尾添加新初始化的向量,减小大小将从末尾删除向量。如果未提供或为 `None`,则仅返回输入标记的指针而不执行任何操作。

        Return:
            `tf.Variable` 或 `keras.layers.Embedding`: 模型输入标记的指针。
        """
        # TODO (joao): 因嵌入重构,此方法被标记为待替换(由 `_v2_resized_token_embeddings` 替代)

        # 如果模型具有 keras 嵌入层,则运行新代码路径
        if isinstance(self.get_input_embeddings(), keras.layers.Embedding):
            return self._v2_resized_token_embeddings(new_num_tokens)

        # 如果 new_num_tokens 为 None 或等于 config.vocab_size,则返回当前输入标记的权重
        if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
            return self._get_word_embedding_weight(self.get_input_embeddings())

        # 否则调整标记嵌入大小并返回模型嵌入
        model_embeds = self._resize_token_embeddings(new_num_tokens)

        # 更新基础模型和当前模型配置的词汇大小
        self.config.vocab_size = new_num_tokens

        return model_embeds
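

# Usage sketch for `resize_token_embeddings` (model and tokenizer names are only examples): after
# adding tokens to a tokenizer, the embedding matrix is resized to the new vocabulary size; the
# added rows are freshly initialized.
from transformers import AutoTokenizer, TFAutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = TFAutoModelForMaskedLM.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["<new_token_1>", "<new_token_2>"])
model.resize_token_embeddings(len(tokenizer))   # config.vocab_size is updated as well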
    # 调整模型的输入标记嵌入矩阵大小,如果 `new_num_tokens != config.vocab_size`。
    # 如果 `new_num_tokens` 为 `None` 或者与当前配置中的词汇表大小相同,则返回模型的输入标记嵌入指针。
    def _v2_resized_token_embeddings(self, new_num_tokens: Optional[int] = None) -> keras.layers.Embedding:
        if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
            return self.get_input_embeddings()

        # 调整标记嵌入矩阵的大小,并获取调整后的模型嵌入层
        model_embeds = self._v2_resize_token_embeddings(new_num_tokens)

        # 更新基础模型和当前模型配置中的词汇表大小
        self.config.vocab_size = new_num_tokens

        # 返回调整后的模型嵌入层
        return model_embeds

    # 获取词嵌入权重的函数
    def _get_word_embedding_weight(model, embedding_layer):
        # TODO (joao): 根据嵌入重构的需求标记为删除

        # 如果 `embedding_layer` 是 `tf.Tensor` 类型,则返回它本身
        if isinstance(embedding_layer, tf.Tensor):
            return embedding_layer

        # 否则,尝试从层的属性中获取权重
        embeds = getattr(embedding_layer, "weight", None)
        if embeds is not None:
            return embeds

        # 尝试从层的 `decoder` 属性获取权重
        embeds = getattr(embedding_layer, "decoder", None)
        if embeds is not None:
            return embeds

        # 如果属性不存在可能是因为模型尚未构建,因此尝试在构建模型后再次获取
        model.build_in_name_scope()

        # 再次尝试从层的 `weight` 属性获取权重
        embeds = getattr(embedding_layer, "weight", None)
        if embeds is not None:
            return embeds

        # 再次尝试从层的 `decoder` 属性获取权重
        embeds = getattr(embedding_layer, "decoder", None)
        if embeds is not None:
            return embeds

        # 如果无法获取权重,则返回 `None`
        return None
    def _resize_token_embeddings(self, new_num_tokens):
        # TODO (joao): flagged for replacement (by `_v2_resize_token_embeddings`) due to embeddings refactor
        # 获取当前模型的词嵌入权重
        old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
        # 调用私有方法,根据新的词汇量大小调整词嵌入权重
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)

        # 如果词嵌入没有被绑定,确保语言模型头部偏置也被调整大小
        if self.get_bias() is not None:
            # 获取当前的语言模型头部偏置
            old_lm_head_bias = self.get_bias()
            # 根据新的词汇量大小调整语言模型头部偏置
            new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
            # 设置调整后的语言模型头部偏置
            self.set_bias(new_lm_head_bias)

        # 如果词嵌入没有被绑定,确保语言模型头部解码器也被调整大小
        if self.get_output_embeddings() is not None:
            # 获取当前语言模型头部解码器的词嵌入权重
            old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
            # 根据新的词汇量大小调整语言模型头部解码器
            new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
            # 设置调整后的语言模型头部解码器
            self.set_output_embeddings(new_lm_head_decoder)

        # 设置调整后的输入词嵌入
        self.set_input_embeddings(new_embeddings)

        # 返回调整后的输入词嵌入
        return self.get_input_embeddings()

    def _v2_resize_token_embeddings(self, new_num_tokens):
        # 获取当前模型的输入词嵌入权重
        old_embeddings = self.get_input_embeddings()
        # 根据新的词汇量大小调整输入词嵌入权重
        new_embeddings = self._v2_get_resized_embeddings(old_embeddings, new_num_tokens)
        # 设置调整后的输入词嵌入权重
        self.set_input_embeddings(new_embeddings)

        # 如果词嵌入没有被绑定,确保语言模型头部偏置也被调整大小
        if self.get_bias() is not None:
            # 获取当前的语言模型头部偏置
            old_lm_head_bias = self.get_bias()
            # 根据新的词汇量大小调整语言模型头部偏置
            new_lm_head_bias = self._v2_get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
            # 设置调整后的语言模型头部偏置
            self.set_bias(new_lm_head_bias)

        # 如果词嵌入没有被绑定,确保语言模型头部解码器也被调整大小
        tied_weights = self.get_input_embeddings() == self.get_output_embeddings()
        if self.get_output_embeddings() is not None and not tied_weights:
            # 获取当前语言模型头部解码器的词嵌入权重
            old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
            # TODO (joao): this one probably needs a v2 version with other models
            # 根据新的词汇量大小调整语言模型头部解码器
            new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
            # 设置调整后的语言模型头部解码器
            self.set_output_embeddings(new_lm_head_decoder)

        # 返回调整后的输入词嵌入权重
        return self.get_input_embeddings()
    def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens):
        """
        Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
        Reducing the size will remove vectors from the end

        Args:
            old_lm_head_bias (`tf.Variable`):
                Old lm head bias to be resized.
            new_num_tokens (`int`, *optional*):
                New number of tokens in the linear matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or `None`, just returns None

        Return:
            `tf.Variable`: Pointer to the resized bias.
        """
        # TODO (joao): flagged for replacement (by `_v2_get_resized_lm_head_bias`) due to embeddings refactor
        # Initialize an empty dictionary to store new biases
        new_lm_head_bias = {}

        # Iterate through each attribute and its corresponding weight in old_lm_head_bias
        for attr, weight in old_lm_head_bias.items():
            # Determine the shape of the weight tensor
            first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight)
            # Calculate the difference in size between old and new number of tokens
            size_diff = new_num_tokens - old_num_tokens
            # Define the final shape of the bias tensor after resizing
            final_shape = [new_num_tokens] if first_dim is None else [first_dim, new_num_tokens]

            # Initialize or slice the current bias based on size difference
            if tf.math.greater(size_diff, 0):
                padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]]
                current_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape), constant_values=-1)
                num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
                mask_shape = [num_tokens_to_copy] if first_dim is None else [1, num_tokens_to_copy]
                bias_mask = tf.fill(tf.convert_to_tensor(mask_shape), True)
                bias_mask = tf.pad(bias_mask, tf.convert_to_tensor(padding_shape), constant_values=False)
            else:
                slice_from = [0] if first_dim is None else [0, 0]
                current_bias = tf.slice(
                    weight.value(), tf.convert_to_tensor(slice_from), tf.convert_to_tensor(final_shape)
                )
                bias_mask = tf.fill(tf.convert_to_tensor(final_shape), True)

            # Create a new bias variable and initialize it
            new_bias = self.add_weight(
                shape=final_shape,
                initializer="zeros",
                trainable=True,
                name=weight.name.split(":")[0],
            )
            init_bias = tf.where(bias_mask, current_bias, new_bias.value())

            # Assign the initialized bias to the new_bias variable and store it in new_lm_head_bias
            new_bias.assign(init_bias)
            new_lm_head_bias[attr] = new_bias

        # Return the dictionary containing resized biases
        return new_lm_head_bias
    def _v2_get_resized_lm_head_bias(
        self, old_lm_head_bias: Dict[str, tf.Variable], new_num_tokens: int
    ) -> Dict[str, tf.Tensor]:
        """
        Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
        Reducing the size will remove vectors from the end

        Args:
            old_lm_head_bias (`Dict[str, tf.Variable]`):
                Old lm head bias to be resized.
            new_num_tokens (`int`):
                New number of tokens in the linear matrix. Increasing the size will add newly initialized vectors at
                the end. Reducing the size will remove vectors from the end.

        Return:
            `tf.Tensor`: Values for the resized bias.
        """
        # Initialize an empty dictionary to store resized biases
        new_lm_head_bias = {}

        # Iterate over each attribute and its corresponding weight in the old_lm_head_bias dictionary
        for attr, weight in old_lm_head_bias.items():
            # Determine the shape of the weight tensor and calculate the size difference
            first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight)
            size_diff = new_num_tokens - old_num_tokens

            # Copy the old bias values to the new bias tensor
            if old_num_tokens > new_num_tokens:
                # Trim the weight tensor if reducing size
                new_bias = weight.value()[..., :new_num_tokens]
            else:
                # Pad the weight tensor with zeros if increasing size
                padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]]
                new_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape))

            # Store the resized bias tensor in the new_lm_head_bias dictionary
            new_lm_head_bias[attr] = new_bias

        # Return the dictionary containing resized bias tensors
        return new_lm_head_bias
    def _get_resized_lm_head_decoder(self, old_lm_head_decoder, new_num_tokens):
        """
        Build a resized decoder from the old ones. Increasing the size will add newly initialized vectors at the end.
        Reducing the size will remove vectors from the end

        Args:
            old_lm_head_decoder (`tf.Variable`):
                Old lm head decoder to be resized.
            new_num_tokens (`int`, *optional*):
                New number of tokens in the linear matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or `None`, just returns None

        Return:
            `tf.Variable`: Pointer to the resized decoder or None if the output embeddings are different from the input
            ones.
        """
        # 将新的 lm head 解码器初始化为旧的 lm head 解码器
        new_lm_head_decoder = old_lm_head_decoder

        # 检查输入嵌入矩阵和旧 lm head 解码器是否相同
        is_input_output_equals = tf.reduce_any(
            self._get_word_embedding_weight(self.get_input_embeddings()) == old_lm_head_decoder
        )

        # 如果旧 lm head 解码器不为 None 并且输入输出不相同
        if old_lm_head_decoder is not None and not is_input_output_equals:
            # 获取旧 lm head 解码器的维度
            old_embedding_dim = shape_list(old_lm_head_decoder)[1]

            # 初始化复制嵌入和解码器掩码
            decoder_mask, current_decoder = init_copy_embeddings(old_lm_head_decoder, new_num_tokens)

            # 创建新的 lm head 解码器,使用零初始化,可训练
            new_lm_head_decoder = self.add_weight(
                shape=(new_num_tokens, old_embedding_dim),
                initializer="zeros",
                trainable=True,
                name=old_lm_head_decoder.name.split(":")[0],
            )

            # 根据解码器掩码选择初始化策略
            init_decoder = tf.where(decoder_mask, current_decoder, new_lm_head_decoder.value())

            # 将初始化的解码器赋给新的 lm head 解码器
            new_lm_head_decoder.assign(init_decoder)

        # 返回新的 lm head 解码器
        return new_lm_head_decoder
    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
        """
        Build a resized Embedding weights from a provided token Embedding weights. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end

        Args:
            old_embeddings (`tf.Variable`):
                Old embeddings to be resized.
            new_num_tokens (`int`, *optional*):
                New number of tokens in the embedding matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
                `tf.Variable` module of the model without doing anything.

        Return:
            `tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is
            `None`
        """
        # TODO (joao): flagged for replacement (by `_v2_get_resized_embeddings`) due to embeddings refactor
        # 获取旧嵌入维度
        old_embedding_dim = shape_list(old_embeddings)[1]
        # 从配置中获取初始化范围
        init_range = getattr(self.config, "initializer_range", 0.02)
        # 初始化嵌入层并生成嵌入掩码和当前嵌入
        embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens)
        # 添加新权重,根据指定的形状和初始化器
        new_embeddings = self.add_weight(
            name=old_embeddings.name.split(":")[0],
            shape=[new_num_tokens, old_embedding_dim],
            initializer=get_initializer(init_range),
            dtype=tf.float32,
        )
        # 根据嵌入掩码选择初始化的嵌入值
        init_embeddings = tf.where(embeddings_mask, current_embeddings, new_embeddings.value())

        # 将初始化的嵌入值赋给新的嵌入层
        new_embeddings.assign(init_embeddings)

        # 返回新的嵌入层
        return new_embeddings
    def _v2_get_resized_embeddings(
        self, old_embeddings: keras.layers.Embedding, new_num_tokens: Optional[int] = None
    ) -> keras.layers.Embedding:
        """
        Build a resized Embedding layer from a provided Embedding layer. Increasing the size will add newly initialized
        vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_embeddings (`keras.layers.Embedding`):
                Old embeddings to be resized.
            new_num_tokens (`int`, *optional*):
                New number of tokens in the embedding matrix.

        Return:
            `keras.layers.Embedding`: Resized Embedding layer.
        """

        # Get the initialization range for the embeddings
        init_range = 0.02  # default value

        # Define potential variable names for initialization range
        potential_initialization_variable_names = [
            "initializer_range",  # most common
            "initializer_factor",  # e.g. T5
            "init_std",  # e.g BART
        ]

        # Iterate through potential variable names to find the correct initialization range
        for var_name in potential_initialization_variable_names:
            if hasattr(self.config, var_name):
                init_range = getattr(self.config, var_name)

        # Create a new Embedding layer with the specified parameters
        new_embeddings = keras.layers.Embedding(
            input_dim=new_num_tokens,
            output_dim=old_embeddings.output_dim,
            embeddings_initializer=keras.initializers.TruncatedNormal(stddev=init_range),
            name=old_embeddings.embeddings.name[:-13],  # exact same scoped name except "/embeddings:0"
        )
        
        # Initialize the new Embedding layer with a dummy input
        new_embeddings(tf.constant([[0]]))

        # Copy the old embeddings to the new embeddings
        if old_embeddings.input_dim >= new_num_tokens:
            init_embeddings = old_embeddings.embeddings[:new_num_tokens]
        else:
            init_embeddings = tf.concat(
                [old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim :]], axis=0
            )
        # Assign the initialized embeddings to the new embeddings layer
        new_embeddings.embeddings.assign(init_embeddings)
        
        # Return the resized Embedding layer
        return new_embeddings
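

# Self-contained sketch (illustrative aside) of the copy-then-concat logic above, using plain
# Keras layers outside of any model class: the first rows are copied from the old table and the
# remaining rows keep their fresh initialization.
import tensorflow as tf
from tensorflow import keras

old = keras.layers.Embedding(input_dim=4, output_dim=2)
old(tf.constant([[0]]))            # build the layer so `old.embeddings` exists
new = keras.layers.Embedding(
    input_dim=6,
    output_dim=2,
    embeddings_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
)
new(tf.constant([[0]]))
init = tf.concat([old.embeddings, new.embeddings[old.input_dim:]], axis=0)
new.embeddings.assign(init)        # rows 0-3 copied, rows 4-5 newly initialized
print(new.embeddings.shape)        # (6, 2)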

    def prune_heads(self, heads_to_prune):
        """
        Prunes heads of the base model.

        Arguments:
            heads_to_prune (`Dict[int, List[int]]`):
                Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads
                to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on
                layer 1 and heads 2 and 3 on layer 2.
        """
        raise NotImplementedError

    def save_pretrained(
        self,
        save_directory,
        saved_model=False,
        version=1,
        push_to_hub=False,
        signatures=None,
        max_shard_size: Union[int, str] = "10GB",
        create_pr: bool = False,
        safe_serialization: bool = False,
        token: Optional[Union[str, bool]] = None,
        **kwargs,
    ):
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ):
        """
        从预训练模型加载模型或者模型配置。

        Args:
            pretrained_model_name_or_path (str or os.PathLike):
                预训练模型名称或路径。
            *model_args:
                模型特定的额外参数。
            config (PretrainedConfig, str, os.PathLike, optional):
                可选的模型配置。
            cache_dir (str or os.PathLike, optional):
                可选的缓存目录。
            ignore_mismatched_sizes (bool):
                是否忽略大小不匹配的警告,默认为 False。
            force_download (bool):
                是否强制下载模型,默认为 False。
            local_files_only (bool):
                是否只使用本地文件,默认为 False。
            token (str or bool, optional):
                可选的身份验证令牌。
            revision (str):
                模型版本,默认为 "main"。
            use_safetensors (bool, optional):
                是否使用安全张量,默认为 None。
            **kwargs:
                其他未指定的关键字参数。
        """

    def push_to_hub(
        self,
        repo_id: str,
        use_temp_dir: Optional[bool] = None,
        commit_message: Optional[str] = None,
        private: Optional[bool] = None,
        max_shard_size: Optional[Union[int, str]] = "10GB",
        token: Optional[Union[bool, str]] = None,
        # (`use_auth_token` is deprecated: we have to keep it here as we don't have **kwargs)
        use_auth_token: Optional[Union[bool, str]] = None,
        create_pr: bool = False,
        **base_model_card_args,
    ):
        """
        将模型推送到模型中心(Hub)的指定仓库。

        Args:
            repo_id (str):
                仓库的唯一标识符。
            use_temp_dir (bool, optional):
                是否使用临时目录,默认为 None。
            commit_message (str, optional):
                提交消息,用于版本控制。
            private (bool, optional):
                是否将仓库设置为私有,默认为 None。
            max_shard_size (int or str, optional):
                最大的分片大小限制,默认为 "10GB"。
            token (bool or str, optional):
                身份验证令牌。
            use_auth_token (bool or str, optional):
                (已弃用)身份验证令牌,用于兼容目的。
            create_pr (bool):
                是否创建 Pull Request,默认为 False。
            **base_model_card_args:
                其他基本模型卡片参数。
        """
    @classmethod
    def register_for_auto_class(cls, auto_class="TFAutoModel"):
        """
        注册当前类到给定的自动加载类中。这仅用于自定义模型,因为库中的模型已经与自动加载类映射。

        <Tip warning={true}>
        该 API 是实验性的,可能在未来的发布中有些许更改。
        </Tip>

        Args:
            auto_class (str or type, optional, defaults to "TFAutoModel"):
                要注册新模型的自动加载类。
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        import transformers.models.auto as auto_module

        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} 不是有效的自动加载类名。")

        cls._auto_class = auto_class
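

# Usage sketch (the class names are examples, not library classes): registering a custom TF model
# so the auto-class machinery can resolve it, e.g. when its code is shipped alongside a repo.
from transformers import PretrainedConfig, TFPreTrainedModel


class MyTFConfig(PretrainedConfig):
    model_type = "my-tf-model"


class MyTFModel(TFPreTrainedModel):
    config_class = MyTFConfig


MyTFModel.register_for_auto_class("TFAutoModel")
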
# 定义一个自定义的 1D 卷积层,按照 Radford 等人在 OpenAI GPT 中定义的方式(也用于 GPT-2)。

class TFConv1D(keras.layers.Layer):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed.

    Args:
        nf (`int`):
            The number of output features.
        nx (`int`):
            The number of input features.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation to use to initialize the weights.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
    """

    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)
        self.nf = nf  # 输出特征的数量
        self.nx = nx  # 输入特征的数量
        self.initializer_range = initializer_range  # 初始化权重时的标准差

    def build(self, input_shape):
        if self.built:
            return
        self.built = True
        # 添加权重变量:weight 的形状为 [nx, nf],使用指定标准差的初始化器初始化
        self.weight = self.add_weight(
            "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
        )
        # 添加偏置变量:bias 的形状为 [1, nf],使用零初始化器初始化
        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        bz, sl = shape_list(x)[:2]  # 获取输入张量 x 的批量大小和序列长度

        x = tf.reshape(x, [-1, self.nx])  # 将输入张量 x 重塑为二维张量
        x = tf.matmul(x, self.weight) + self.bias  # 执行矩阵乘法和偏置加法操作

        x = tf.reshape(x, [bz, sl, self.nf])  # 将结果重新塑造为原始序列张量的形状

        return x
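

# Minimal numeric sketch (illustrative aside) of what TFConv1D computes: for an input of shape
# [batch, seq, nx] it returns x @ weight + bias with weight of shape [nx, nf], i.e. a dense layer
# whose kernel is stored "transposed" relative to the usual Conv1D naming.
import tensorflow as tf

batch, seq, nx, nf = 2, 3, 4, 5
x = tf.random.normal([batch, seq, nx])
weight = tf.random.normal([nx, nf])
bias = tf.zeros([1, nf])
out = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), weight) + bias, [batch, seq, nf])
print(out.shape)   # (2, 3, 5)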


class TFSharedEmbeddings(keras.layers.Layer):
    r"""
    Construct shared token embeddings.

    The weights of the embedding layer is usually shared with the weights of the linear decoder when doing language
    modeling.

    Args:
        vocab_size (`int`):
            The size of the vocabulary, e.g., the number of unique tokens.
        hidden_size (`int`):
            The size of the embedding vectors.
        initializer_range (`float`, *optional*):
            The standard deviation to use when initializing the weights. If no value is provided, it will default to
            \\(1/\sqrt{hidden\_size}\\).
        kwargs (`Dict[str, Any]`, *optional*):
            Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
    """

    # TODO (joao): flagged for deletion due to embeddings refactor

    def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size  # 词汇表大小,即唯一标记的数量
        self.hidden_size = hidden_size  # 嵌入向量的大小
        self.initializer_range = hidden_size**-0.5 if initializer_range is None else initializer_range
        # 如果未提供初始化标准差,则默认为 1/√hidden_size
        warnings.warn(
            "`TFSharedEmbeddings` is scheduled for deletion in v4.32, use `keras.layers.Embedding` instead.",
            DeprecationWarning,
        )
    def build(self, input_shape):
        """
        Build shared token embedding layer.

        This method initializes the layer's weight matrix based on the specified vocabulary size and hidden size.
        The weight matrix is initialized using a custom initializer within the specified range.

        Args:
            input_shape (tuple): Shape tuple describing the input shape.

        Returns:
            None
        """
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
        )
        super().build(input_shape)

    def get_config(self):
        """
        Returns the configuration of the layer.

        Returns:
            dict: Configuration dictionary containing vocab_size, hidden_size, and initializer_range.
        """
        config = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "initializer_range": self.initializer_range,
        }
        base_config = super().get_config()

        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:
        """
        Get token embeddings of inputs or decode final hidden state.

        Args:
            inputs (`tf.Tensor`):
                In embedding mode, should be an int64 tensor with shape `[batch_size, length]`.
                In linear mode, should be a float tensor with shape `[batch_size, length, hidden_size]`.
            mode (`str`, defaults to `"embedding"`):
                A valid value is either `"embedding"` or `"linear"`, indicating the layer's usage mode.

        Returns:
            `tf.Tensor`: Depending on mode,
                - In embedding mode: float32 embedding tensor, shape `[batch_size, length, embedding_size]`.
                - In linear mode: float32 tensor, shape `[batch_size, length, vocab_size]`.

        Raises:
            ValueError: if `mode` is not valid.
        """
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError(f"mode {mode} is not valid.")

    def _embedding(self, input_ids):
        """
        Applies embedding based on input_ids tensor.

        Args:
            input_ids: Tensor containing token indices.

        Returns:
            `tf.Tensor`: Float32 embedding tensor.
        """
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """
        Computes logits by running inputs through a linear layer.

        Args:
            inputs: A float32 tensor with shape [..., hidden_size]

        Returns:
            `tf.Tensor`: Float32 tensor with shape [..., vocab_size].
        """
        first_dims = shape_list(inputs)[:-1]
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])
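
# 补充示例(非原始源码,仅作演示):用独立的几行代码演示 TFSharedEmbeddings 的两种模式——
# "embedding" 模式用 tf.gather 查表得到词向量,"linear" 模式复用同一权重矩阵(transpose_b=True)
# 把隐藏状态投影回词表 logits,即语言建模中常见的权重共享(weight tying)。
# 其中 vocab_size=10、hidden_size=4 等数值均为演示用的假设值,需已安装 tensorflow。
import tensorflow as tf

vocab_size, hidden_size = 10, 4
weight = tf.random.normal([vocab_size, hidden_size], stddev=hidden_size**-0.5)  # 与默认初始化范围一致

input_ids = tf.constant([[1, 2, 3]])                      # [batch=1, length=3]
embeddings = tf.gather(weight, input_ids)                 # "embedding" 模式 -> [1, 3, hidden_size]
logits = tf.matmul(embeddings, weight, transpose_b=True)  # "linear" 模式 -> [1, 3, vocab_size]
print(embeddings.shape, logits.shape)                     # (1, 3, 4) (1, 3, 10)
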
class TFSequenceSummary(keras.layers.Layer):
    """
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.

        initializer_range (`float`, defaults to 0.02): The standard deviation to use to initialize the weights.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
    """
    # 初始化函数,接受预训练配置和其他可选参数
    def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)

        # 根据配置确定摘要类型:若配置中定义了 summary_use_proj 属性,则取 config.summary_type,否则默认为 "last"
        self.summary_type = config.summary_type if hasattr(config, "summary_use_proj") else "last"
        
        # 如果摘要类型为 "attn",抛出未实现错误,建议使用标准的多头注意力模块
        if self.summary_type == "attn":
            raise NotImplementedError
       
        # 判断配置中是否有 summary_use_proj 属性并且其值为 True,表示需要进行投影操作
        self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
        if self.has_summary:
            # 如果配置中定义了 summary_proj_to_labels 并且其值为 True,并且 num_labels 大于 0
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            # 创建一个全连接层,用于摘要投影,输出维度为 num_classes
            self.summary = keras.layers.Dense(
                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
            )

        # 判断配置中是否定义了 summary_activation 属性,如果有则设置相应的激活函数
        self.has_activation = False
        activation_string = getattr(config, "summary_activation", None)
        if activation_string is not None:
            self.has_activation = True
            self.activation = get_tf_activation(activation_string)

        # 判断配置中是否定义了 summary_first_dropout 属性并且其值大于 0,如果是则创建首层 Dropout
        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
        if self.has_first_dropout:
            self.first_dropout = keras.layers.Dropout(config.summary_first_dropout)

        # 判断配置中是否定义了 summary_last_dropout 属性并且其值大于 0,如果是则创建末层 Dropout
        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = keras.layers.Dropout(config.summary_last_dropout)
        
        # 将隐藏大小设置为配置中定义的 hidden_size
        self.hidden_size = config.hidden_size
    # 定义一个方法 `call`,用于执行模型的前向传播
    def call(self, inputs, cls_index=None, training=False):
        # 检查输入是否为字典、元组或列表,若不是则直接使用 `inputs` 作为隐藏状态
        if not isinstance(inputs, (dict, tuple, list)):
            hidden_states = inputs
        elif isinstance(inputs, (tuple, list)):
            # 若输入为元组或列表,则将第一个元素作为隐藏状态,第二个元素作为 `cls_index`(若有的话)
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
            assert len(inputs) <= 2, "Too many inputs."  # 断言输入的长度不超过2,否则报错
        else:
            # 若输入为字典,则从中获取 `hidden_states` 和 `cls_index`(默认为 None)
            hidden_states = inputs.get("hidden_states")
            cls_index = inputs.get("cls_index", None)

        # 根据 `summary_type` 选择如何汇总隐藏状态
        if self.summary_type == "last":
            output = hidden_states[:, -1]  # 取最后一个时间步的隐藏状态作为输出
        elif self.summary_type == "first":
            output = hidden_states[:, 0]  # 取第一个时间步的隐藏状态作为输出
        elif self.summary_type == "mean":
            output = tf.reduce_mean(hidden_states, axis=1)  # 对隐藏状态在序列长度维度(axis=1)上取平均
        elif self.summary_type == "cls_index":
            # 根据给定的 `cls_index` 从隐藏状态中取出对应位置的向量
            hidden_shape = shape_list(hidden_states)  # 获取隐藏状态的形状信息
            if cls_index is None:
                # 若 `cls_index` 为 None,则默认选择每个样本序列的最后一个位置
                cls_index = tf.fill(
                    hidden_shape[:-2], hidden_shape[-2] - 1
                )  # 创建一个张量,形状为 [batch] 或 [batch, num choices],填充值为序列长度 - 1(即最后一个位置的索引)
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                cls_index = tf.expand_dims(cls_index, axis=-1)  # 在最后一维上扩展 `cls_index`
            # gather 前 hidden_states 形状为 (batch, [num choices,] seq_length, hidden_size)
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            output = tf.squeeze(
                output, axis=len(hidden_shape) - 2
            )  # 压缩 gather 引入的单一维度,输出形状为 (batch, hidden_size) 或 (batch, num choices, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError  # 如果 `summary_type` 是 "attn",则抛出未实现错误

        # 若模型具有第一个 dropout 层,则在输出上应用该 dropout
        if self.has_first_dropout:
            output = self.first_dropout(output, training=training)

        # 若模型具有汇总方法,则将输出传递给汇总方法
        if self.has_summary:
            output = self.summary(output)

        # 若模型具有激活函数,则将输出传递给激活函数
        if self.has_activation:
            output = self.activation(output)

        # 若模型具有最后一个 dropout 层,则在输出上应用该 dropout
        if self.has_last_dropout:
            output = self.last_dropout(output, training=training)

        return output

    # 构建模型,在输入形状已知的情况下进行构建
    def build(self, input_shape):
        if self.built:
            return  # 如果模型已经构建过,则直接返回
        self.built = True  # 标记模型已经构建
        if getattr(self, "summary", None) is not None:
            with tf.name_scope("summary"):
                self.summary.build(self.hidden_size)  # 以 hidden_size 作为输入维度构建摘要投影层
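
# 补充示例(非原始源码,仅作演示):演示上面 call() 中 summary_type == "cls_index" 分支里
# tf.gather(batch_dims=...) 与 tf.squeeze 的组合效果:按 cls_index 取出每个样本指定位置的隐藏状态。
# 张量取值均为演示用的假设数据。
import tensorflow as tf

hidden_states = tf.reshape(tf.range(2 * 5 * 3, dtype=tf.float32), [2, 5, 3])  # [batch=2, seq_len=5, hidden=3]
cls_index = tf.constant([4, 2])                              # 每个样本中分类 token 的位置
cls_index = tf.expand_dims(cls_index, axis=-1)               # 扩展为 [2, 1]
output = tf.gather(hidden_states, cls_index, batch_dims=1)   # gather 结果形状为 [2, 1, 3]
output = tf.squeeze(output, axis=1)                          # 压缩后形状为 [2, 3]
print(output.numpy())  # 分别是样本 0 第 4 个位置、样本 1 第 2 个位置的隐藏状态
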
# 定义一个函数,用于创建具有指定范围的截断正态分布初始化器
def get_initializer(initializer_range: float = 0.02) -> keras.initializers.TruncatedNormal:
    """
    Creates a `keras.initializers.TruncatedNormal` with the given range.

    Args:
        initializer_range (*float*, defaults to 0.02): Standard deviation of the initializer range.

    Returns:
        `keras.initializers.TruncatedNormal`: The truncated normal initializer.
    """
    # 返回一个截断正态分布初始化器对象,其标准差由参数 initializer_range 指定
    return keras.initializers.TruncatedNormal(stddev=initializer_range)

.\modeling_utils.py

# 导入 Python 内置和第三方库
import collections  # 导入 collections 模块,用于扩展内置容器数据类型
import copy  # 导入 copy 模块,用于对象复制操作
import functools  # 导入 functools 模块,用于高阶函数操作
import gc  # 导入 gc 模块,Python 的垃圾回收模块
import importlib.metadata  # 导入 importlib.metadata 模块,用于元数据获取
import inspect  # 导入 inspect 模块,用于解析源码
import itertools  # 导入 itertools 模块,用于创建和操作迭代器的函数
import json  # 导入 json 模块,用于 JSON 数据的编解码
import os  # 导入 os 模块,用于与操作系统交互
import re  # 导入 re 模块,用于正则表达式操作
import shutil  # 导入 shutil 模块,用于文件操作的高级函数
import tempfile  # 导入 tempfile 模块,用于创建临时文件和目录
import warnings  # 导入 warnings 模块,用于警告控制
from contextlib import contextmanager  # 导入 contextmanager,供下文的 no_init_weights 上下文管理器使用
from functools import partial  # 导入 partial,供加载分片检查点时构造 loader 函数使用
from zipfile import is_zipfile  # 导入 is_zipfile,用于判断检查点文件是否为 zip 归档

# 导入 typing 模块中的类型
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

# 导入第三方库 torch
import torch  # 导入 PyTorch 深度学习库
from packaging import version  # 从 packaging 模块导入 version 子模块
from torch import Tensor, nn  # 从 torch 模块导入 Tensor 和 nn(神经网络)子模块
from torch.nn import CrossEntropyLoss, Identity  # 从 torch.nn 模块导入 CrossEntropyLoss 和 Identity 类
from torch.utils.checkpoint import checkpoint  # 从 torch.utils.checkpoint 模块导入 checkpoint 函数

# 导入本地的模块和函数
from .activations import get_activation  # 从当前目录的 activiations 模块导入 get_activation 函数
from .configuration_utils import PretrainedConfig  # 从当前目录的 configuration_utils 模块导入 PretrainedConfig 类
from .dynamic_module_utils import custom_object_save  # 从当前目录的 dynamic_module_utils 模块导入 custom_object_save 函数
from .generation import GenerationConfig, GenerationMixin  # 从当前目录的 generation 模块导入 GenerationConfig 和 GenerationMixin 类
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled  # 从当前目录的 integrations 模块导入若干函数和类
from .pytorch_utils import (  # 从当前目录的 pytorch_utils 模块导入若干函数和类,忽略 F401 错误
    Conv1D,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    id_tensor_storage,
    is_torch_greater_or_equal_than_1_13,
    prune_conv1d_layer,
    prune_layer,
    prune_linear_layer,
)
from .quantizers import AutoHfQuantizer, HfQuantizer  # 从当前目录的 quantizers 模块导入 AutoHfQuantizer 和 HfQuantizer 类
from .quantizers.quantizers_utils import get_module_from_name  # 从当前目录的 quantizers.quantizers_utils 模块导入 get_module_from_name 函数
from .safetensors_conversion import auto_conversion  # 从当前目录的 safetensors_conversion 模块导入 auto_conversion 函数
from .utils import (  # 从当前目录的 utils 模块导入若干函数、常量和类
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    CONFIG_NAME,
    DUMMY_INPUTS,
    FLAX_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    TF2_WEIGHTS_NAME,
    TF_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    ContextManagers,
    ModelOutput,
    PushToHubMixin,
    cached_file,
    copy_func,
    download_url,
    extract_commit_hash,
    has_file,
    is_accelerate_available,
    is_bitsandbytes_available,
    is_flash_attn_2_available,
    is_offline_mode,
    is_optimum_available,
    is_peft_available,
    is_remote_url,
    is_safetensors_available,
    is_torch_sdpa_available,
    is_torch_xla_available,
    logging,
    replace_return_docstrings,
    strtobool,
)
from .utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files  # 从当前目录的 utils.hub 模块导入若干函数
from .utils.import_utils import (  # 从当前目录的 utils.import_utils 模块导入若干函数和常量
    ENV_VARS_TRUE_VALUES,
    is_sagemaker_mp_enabled,
    is_torch_fx_proxy,
    is_torchdynamo_compiling,
)
# 导入所需模块和变量
from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
# 设置环境变量 XLA_USE_BF16,指定默认值为 "0" 并转换为大写
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
# 设置环境变量 XLA_DOWNCAST_BF16,指定默认值为 "0" 并转换为大写
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()

# 如果加速库可用
if is_accelerate_available():
    # 导入加速库相关模块和函数
    from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
    from accelerate.hooks import add_hook_to_module
    from accelerate.utils import (
        check_tied_parameters_on_same_device,
        find_tied_parameters,
        get_balanced_memory,
        get_max_memory,
        load_offloaded_weights,
        offload_weight,
        save_offload_index,
        set_module_tensor_to_device,
    )

# 如果 SafeTensors 库可用
if is_safetensors_available():
    # 导入 SafeTensors 库相关函数
    from safetensors import safe_open
    from safetensors.torch import load_file as safe_load_file
    from safetensors.torch import save_file as safe_save_file

# 获取日志记录器
logger = logging.get_logger(__name__)

# 初始化权重标记
_init_weights = True

# 检查是否启用了 FSDP(Fully Sharded Data Parallelism)
def is_fsdp_enabled():
    return (
        torch.distributed.is_available()  # 检查是否支持分布式训练
        and torch.distributed.is_initialized()  # 检查是否已初始化分布式环境
        and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1  # 检查是否启用 FSDP
        and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1  # 检查是否启用 FSDP CPU 和 RAM 的高效加载
    )

# 检查当前进程是否是本地分布式训练的主进程(rank 0)
def is_local_dist_rank_0():
    return (
        torch.distributed.is_available()  # 检查是否支持分布式训练
        and torch.distributed.is_initialized()  # 检查是否已初始化分布式环境
        and int(os.environ.get("LOCAL_RANK", -1)) == 0  # 检查本地进程的分布式训练排名是否为 0
    )

# 如果 SageMaker Model Parallelism 可用
if is_sagemaker_mp_enabled():
    # 导入 SageMaker Model Parallelism 相关模块和函数
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    # 检查是否为 SageMaker MP 1.10 版本之后
    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
else:
    IS_SAGEMAKER_MP_POST_1_10 = False

# 如果 PEFT 可用
if is_peft_available():
    # 从 utils 模块中导入 find_adapter_config_file 函数
    from .utils import find_adapter_config_file

# 定义 Torch 初始化函数字典
TORCH_INIT_FUNCTIONS = {
    "uniform_": nn.init.uniform_,
    "normal_": nn.init.normal_,
    "trunc_normal_": nn.init.trunc_normal_,
    "constant_": nn.init.constant_,
    "xavier_uniform_": nn.init.xavier_uniform_,
    "xavier_normal_": nn.init.xavier_normal_,
    "kaiming_uniform_": nn.init.kaiming_uniform_,
    "kaiming_normal_": nn.init.kaiming_normal_,
    "uniform": nn.init.uniform,
    "normal": nn.init.normal,
    "xavier_uniform": nn.init.xavier_uniform,
    "xavier_normal": nn.init.xavier_normal,
    "kaiming_uniform": nn.init.kaiming_uniform,
    "kaiming_normal": nn.init.kaiming_normal,
}

# 上下文管理器,用于全局禁用模型初始化权重以加快大模型加载速度
@contextmanager
def no_init_weights(_enable=True):
    """
    Context manager to globally disable weight initialization to speed up loading large models.

    TODO(Patrick): Delete safety argument `_enable=True` at next major version. .
    """
    global _init_weights
    old_init_weights = _init_weights

    if _enable:
        _init_weights = False

        def _skip_init(*args, **kwargs):
            pass

        # 临时替换 Torch 初始化函数为 _skip_init 函数
        for name, init_func in TORCH_INIT_FUNCTIONS.items():
            setattr(torch.nn.init, name, _skip_init)
    try:
        yield
    finally:
        # 恢复原始的初始化权重函数
        _init_weights = old_init_weights
        if _enable:
            # 如果启用了初始化函数替换
            # 遍历 TORCH_INIT_FUNCTIONS 字典中的每一项
            for name, init_func in TORCH_INIT_FUNCTIONS.items():
                # 将 torch.nn.init 中的初始化函数名 name 恢复为原始函数 init_func
                setattr(torch.nn.init, name, init_func)
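
# 补充示例(非原始源码,仅作演示):no_init_weights 在加载预训练权重前临时跳过随机初始化,
# 以加快大模型的构建。下面用一个小的 nn.Linear 验证:上下文内 nn.init.normal_ 被替换为空操作,
# 退出上下文后初始化函数被恢复。示例假设上文定义的 no_init_weights 已在当前作用域中。
import torch
from torch import nn

layer = nn.Linear(4, 4)
before = layer.weight.clone()
with no_init_weights():
    nn.init.normal_(layer.weight)        # 上下文内为 no-op,权重保持不变
assert torch.equal(layer.weight, before)
nn.init.normal_(layer.weight)            # 退出上下文后恢复原始初始化函数,权重被重新随机化
assert not torch.equal(layer.weight, before)
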
def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
    try:
        # 尝试获取参数的第一个参数并返回其设备信息
        return next(parameter.parameters()).device
    except StopIteration:
        # 对于 nn.DataParallel 在 PyTorch 1.5 及以上版本的兼容性处理

        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        # 从参数中获取命名成员的生成器
        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        # 获取第一个生成器产生的元组,并返回其设备信息
        first_tuple = next(gen)
        return first_tuple[1].device


def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
    """
    Returns the first parameter dtype (can be non-floating) or asserts if none were found.
    """
    try:
        # 尝试获取参数的第一个参数并返回其数据类型
        return next(parameter.parameters()).dtype
    except StopIteration:
        # 对于 nn.DataParallel 在 PyTorch 大于 1.5 版本的兼容性处理

        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        # 从参数中获取命名成员的生成器
        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        # 获取第一个生成器产生的元组,并返回其数据类型
        first_tuple = next(gen)
        return first_tuple[1].dtype


def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
    """
    Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
    """
    last_dtype = None
    # 遍历参数的所有参数
    for t in parameter.parameters():
        last_dtype = t.dtype
        if t.is_floating_point():
            # 添加修复 https://github.com/pytorch/xla/issues/4152
            # 修复模型代码传递的数值超出 XLA_USE_BF16=1 和 XLA_DOWNCAST_BF16=1 的范围,导致转换为 -inf 的问题
            # 注意: `is_torch_xla_available()` 是最后检查的,因为它会在 torch dynamo 中引入图形断裂
            if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
                return torch.bfloat16
            if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
                if t.dtype == torch.float:
                    return torch.bfloat16
                if t.dtype == torch.double:
                    return torch.float32
            return t.dtype

    # 如果找不到浮点数据类型,则返回最后一个找到的数据类型
    if last_dtype is not None:
        return last_dtype

    # 对于 nn.DataParallel 在 PyTorch 大于 1.5 版本的兼容性处理
    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
        tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
        return tuples

    # 从参数中获取命名成员的生成器
    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
    last_tuple = None
    # 遍历生成器中的元组
    for tuple in gen:
        last_tuple = tuple
        if tuple[1].is_floating_point():
            return tuple[1].dtype
    # 如果 last_tuple 不是 None,则返回 last_tuple 中第二个元素的数据类型作为结果
    if last_tuple is not None:
        return last_tuple[1].dtype
    
    # 如果 last_tuple 是 None,则尝试使用 parameter 中的缓冲区的数据类型作为结果
    for t in parameter.buffers():
        # 记录每次迭代中 t 的数据类型到 last_dtype
        last_dtype = t.dtype
        # 如果 t 是浮点数类型,则返回 t 的数据类型作为结果
        if t.is_floating_point():
            return t.dtype
    
    # 如果所有缓冲区都不是浮点数类型,则返回最后一次迭代中记录的数据类型作为结果
    return last_dtype
# 返回 `state_dict` 中第一个浮点数据类型,如果没有则抛出异常
def get_state_dict_float_dtype(state_dict):
    for t in state_dict.values():  # 遍历 `state_dict` 中的每个值
        if t.is_floating_point():  # 检查当前值是否为浮点数类型
            return t.dtype  # 返回该值的数据类型

    raise ValueError("couldn't find any floating point dtypes in state_dict")  # 如果没有找到浮点数据类型则抛出异常


# 返回 `state_dict` 中第一个浮点数据类型,如果没有则返回第一个数据类型
def get_state_dict_dtype(state_dict):
    for t in state_dict.values():  # 遍历 `state_dict` 中的每个值
        if t.is_floating_point():  # 检查当前值是否为浮点数类型
            return t.dtype  # 返回该值的数据类型

    # 如果没有找到浮点数据类型,则返回 `state_dict` 中第一个值的数据类型
    else:
        return next(iter(state_dict.values())).dtype


# 返回指定数据类型 `dtype` 的参数占据的字节数
def dtype_byte_size(dtype):
    if dtype == torch.bool:  # 如果数据类型是布尔类型
        return 1 / 8  # 返回布尔类型参数占据的字节数
    bit_search = re.search(r"[^\d](\d+)$", str(dtype))  # 从数据类型字符串中搜索位数信息
    if bit_search is None:  # 如果未找到有效的数据类型
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")  # 抛出数据类型无效的异常
    bit_size = int(bit_search.groups()[0])  # 提取数据类型的位数
    return bit_size // 8  # 返回数据类型参数占据的字节数
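
# 补充示例(非原始源码,仅作演示):dtype_byte_size 通过数据类型名末尾的位数推断每个元素占用的字节数,
# 布尔类型按 1/8 字节处理。示例假设上文定义的 dtype_byte_size 已在当前作用域中。
import torch

print(dtype_byte_size(torch.float32))   # 4
print(dtype_byte_size(torch.bfloat16))  # 2
print(dtype_byte_size(torch.int8))      # 1
print(dtype_byte_size(torch.bool))      # 0.125
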


# 将模型状态字典 `state_dict` 分割为多个子检查点,使每个子检查点的最终大小不超过指定大小
def shard_checkpoint(
    state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
):
    max_shard_size = convert_file_size_to_int(max_shard_size)  # 将最大分片大小转换为整数形式

    sharded_state_dicts = [{}]  # 初始化一个空的分片状态字典列表
    last_block_size = 0  # 初始化最后一个分片的大小
    total_size = 0  # 初始化总大小
    storage_id_to_block = {}  # 初始化存储 ID 到分片索引的映射表
    # 遍历状态字典中的每个键值对,其中键为参数名,值为参数的权重
    for key, weight in state_dict.items():
        # 如果权重是字符串类型,跳过当前循环,因为在序列化时使用了 BNB,可能出现这种情况
        # 可参考:https://github.com/huggingface/transformers/pull/24416 获取更多细节
        if isinstance(weight, str):
            continue
        else:
            # 获取权重张量的存储 ID
            storage_id = id_tensor_storage(weight)

        # 如果某个权重共享相同的底层存储,则将该权重放入相同的“块”中
        if storage_id in storage_id_to_block:
            block_id = storage_id_to_block[storage_id]
            sharded_state_dicts[block_id][key] = weight
            continue

        # 计算当前权重的字节大小
        weight_size = weight.numel() * dtype_byte_size(weight.dtype)

        # 如果当前块的总大小加上当前权重的大小超过了最大分片大小,并且当前块中至少有一个权重,
        # 则将当前块分片,创建一个新的空字典作为新的块,并重置当前块的大小
        if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
            sharded_state_dicts.append({})
            last_block_size = 0

        # 将当前权重添加到当前块中
        sharded_state_dicts[-1][key] = weight
        # 更新当前块的大小,并累计所有权重的总大小(用于索引元数据)
        last_block_size += weight_size
        total_size += weight_size
        # 将当前权重的存储 ID 映射到对应的块索引
        storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1

    # 如果只有一个分片,直接返回该分片
    if len(sharded_state_dicts) == 1:
        return {weights_name: sharded_state_dicts[0]}, None

    # 否则,构建索引
    weight_map = {}
    shards = {}
    # 遍历所有分片,为每个分片创建一个文件名,并将分片及其对应的键添加到 shards 和 weight_map 中
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
        shard_file = shard_file.replace(
            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
        )
        shards[shard_file] = shard
        for key in shard.keys():
            weight_map[key] = shard_file

    # 添加元数据
    metadata = {"total_size": total_size}
    # 构建索引结构,包括元数据和权重映射
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
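
# 补充示例(非原始源码,仅作演示):用一个很小的 state_dict 演示 shard_checkpoint 的分片行为。
# 这里把 max_shard_size 设为 100 字节以强制触发分片(真实场景默认值为 "10GB"),
# 示例假设上文定义的 shard_checkpoint 及其依赖已在当前作用域中。
import torch

tiny_state_dict = {
    "layer1.weight": torch.zeros(4, 4),  # float32,4*4*4 = 64 字节
    "layer2.weight": torch.zeros(4, 4),  # 再加 64 字节,超过 100 字节 -> 开启新分片
}
shards, index = shard_checkpoint(tiny_state_dict, max_shard_size=100)
print(list(shards.keys()))  # 预期为两个分片文件名,如 pytorch_model-00001-of-00002.bin 等
print(index["weight_map"])  # 记录每个参数名到分片文件名的映射
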
# 加载分片检查点的函数,用于从文件夹中加载模型的状态字典
def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
    """
    This is the same as
    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
    but for a sharded checkpoint.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        model (`torch.nn.Module`): The model in which to load the checkpoint.
        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
        strict (`bool`, *optional*, defaults to `True`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
        prefer_safe (`bool`, *optional*, defaults to `True`):
            If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
            safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.

    Returns:
        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
            - `missing_keys` is a list of str containing the missing keys
            - `unexpected_keys` is a list of str containing the unexpected keys
    """
    # 拼接索引文件的路径
    index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
    # 拼接安全索引文件的路径
    safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)

    # 检查索引文件和安全索引文件是否存在
    index_present = os.path.isfile(index_file)
    safe_index_present = os.path.isfile(safe_index_file)

    # 如果既没有索引文件也没有安全索引文件,则抛出错误
    if not index_present and not (safe_index_present and is_safetensors_available()):
        filenames = (
            (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) if is_safetensors_available() else (WEIGHTS_INDEX_NAME,)
        )
        raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")

    # 根据 prefer_safe 的设置确定加载哪种索引文件
    load_safe = False
    if safe_index_present:
        if prefer_safe:
            if is_safetensors_available():
                load_safe = True  # 根据偏好加载安全索引文件
            else:
                logger.warning(
                    f"Cannot load sharded checkpoint at {folder} safely since safetensors is not installed!"
                )
        elif not index_present:
            load_safe = True  # 因为没有其他选择,所以加载安全索引文件

    load_index = safe_index_file if load_safe else index_file

    # 使用 utf-8 编码打开加载索引文件,并解析为 JSON 格式
    with open(load_index, "r", encoding="utf-8") as f:
        index = json.load(f)

    # 获取所有分片文件的路径
    shard_files = list(set(index["weight_map"].values()))

    # 如果 strict=True,则在加载任何状态字典之前检查错误
    loaded_keys = index["weight_map"].keys()
    model_keys = model.state_dict().keys()

    # 查找模型中缺失的键和索引中未预料到的键
    missing_keys = [key for key in model_keys if key not in loaded_keys]
    unexpected_keys = [key for key in loaded_keys if key not in model_keys]
    # 如果 strict 为 True 并且存在缺失的键或者不期望的键
    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
        # 构建错误信息,指明加载 state_dict 时出错的模型类名
        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
        
        # 如果存在缺失的键
        if len(missing_keys) > 0:
            # 构建缺失键的字符串表示,用逗号分隔
            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
            error_message += f"\nMissing key(s): {str_missing_keys}."
        
        # 如果存在不期望的键
        if len(unexpected_keys) > 0:
            # 构建不期望键的字符串表示,用逗号分隔
            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
        
        # 抛出运行时异常,显示错误信息
        raise RuntimeError(error_message)

    # 根据 torch 版本创建用于加载文件的 loader 函数
    weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)

    # 遍历每个分片文件
    for shard_file in shard_files:
        # 使用 loader 加载分片文件的 state_dict
        state_dict = loader(os.path.join(folder, shard_file))
        
        # 将加载的 state_dict 应用到模型中,strict 设置为 False
        model.load_state_dict(state_dict, strict=False)

        # 在加载下一个 state_dict 之前确保释放内存
        del state_dict
        gc.collect()

    # 返回与 PyTorch load_state_dict 函数相同的对象,用于处理不兼容键
    return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
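
# 补充示例(非原始源码,仅作演示):load_sharded_checkpoint 的典型用法——先用 save_pretrained
# 按 max_shard_size 保存分片检查点,再把各分片逐个加载回相同架构的模型。
# 其中路径 "path/to/sharded_ckpt" 与变量 model、new_model 均为演示用的假设对象:
#
#     model.save_pretrained("path/to/sharded_ckpt", max_shard_size="200MB")
#     result = load_sharded_checkpoint(new_model, "path/to/sharded_ckpt", strict=True)
#     print(result.missing_keys, result.unexpected_keys)
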
def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_quantized: bool = False):
    """
    Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
    """

    # 如果检查点文件以 ".safetensors" 结尾且安全张量可用
    if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
        # 检查归档格式
        with safe_open(checkpoint_file, framework="pt") as f:
            metadata = f.metadata()
        # 如果归档中的元数据格式不在有效列表 ["pt", "tf", "flax", "mlx"] 中,则抛出异常
        if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
            raise OSError(
                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
                "you save your model with the `save_pretrained` method."
            )
        # 加载安全张量文件
        return safe_load_file(checkpoint_file)

    try:
        # 处理特定条件下的 `map_location`
        if (
            (is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0)
            or (is_fsdp_enabled() and not is_local_dist_rank_0())
        ) and not is_quantized:
            map_location = "meta"
        else:
            map_location = "cpu"

        extra_args = {}
        # 如果 `checkpoint_file` 是字符串,并且不是 `meta` `map_location`,且 PyTorch 版本 >= 2.1.0,并且是 Zip 格式文件,则启用 `mmap`
        if (
            isinstance(checkpoint_file, str)
            and map_location != "meta"
            and version.parse(torch.__version__) >= version.parse("2.1.0")
            and is_zipfile(checkpoint_file)
        ):
            extra_args = {"mmap": True}

        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
        
        # 使用 PyTorch 加载检查点文件
        return torch.load(
            checkpoint_file,
            map_location=map_location,
            **weights_only_kwarg,
            **extra_args,
        )

    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                # 检查文件是否以 "version" 开头,如果是,则可能是未安装 git-lfs 的情况
                if f.read(7) == "version":
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            # 如果无法读取文件内容,抛出加载异常
            raise OSError(
                f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
                f"at '{checkpoint_file}'. "
                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
            )


def set_initialized_submodules(model, state_dict_keys):
    """
    Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
    dict.
    """
    # 创建一个空字典,用于存储未初始化的子模块
    not_initialized_submodules = {}
    # 遍历模型中所有命名的模块及其对应的名称
    for module_name, module in model.named_modules():
        # 从状态字典键集合中提取加载的键集合,去除模块名称前缀
        loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")}
        # 检查加载的键集合是否完全包含模块的状态字典的所有键
        if loaded_keys.issuperset(module.state_dict()):
            # 如果是,则标记模块为已由Hugging Face初始化
            module._is_hf_initialized = True
        else:
            # 否则将未初始化的模块添加到未初始化子模块字典中
            not_initialized_submodules[module_name] = module
    # 返回所有未初始化的子模块字典
    return not_initialized_submodules
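
# 补充示例(非原始源码,仅作演示):演示 set_initialized_submodules 如何根据 state_dict 的键
# 标记已被检查点覆盖的子模块。示例假设上文定义的 set_initialized_submodules 已在当前作用域中。
import torch
from torch import nn

tiny_model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
# 假设检查点中只包含第 0 层的权重
not_initialized = set_initialized_submodules(tiny_model, ["0.weight", "0.bias"])
print(getattr(tiny_model[0], "_is_hf_initialized", False))  # True:第 0 层的所有权重都在检查点中
print("1" in not_initialized)                               # True:第 1 层未被覆盖,仍需随机初始化
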
# 将给定的模型加载状态字典到模型中,修改旧格式为新格式(如果需要)
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    # 查找所有含有特定关键词的键,将其转换为新的键名
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    # 替换旧键为新键
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    # 复制状态字典以便 _load_from_state_dict 可以修改它
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch 的 `_load_from_state_dict` 不会复制模块子类中的参数,
    # 所以需要递归应用该函数
    def load(module: nn.Module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # 模块及其子模块的参数将以给定的前缀开头,如果在状态字典中不存在这些参数,则可以提前退出
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            if is_deepspeed_zero3_enabled():
                import deepspeed

                # 在分片模型中,每个分片只有部分完整状态字典,因此只收集当前状态字典中存在的参数
                named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
                params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
                if len(params_to_gather) > 0:
                    # 因为 zero3 在模型参数中放置占位符,所以这个上下文管理器会收集(取消分片)当前层的参数,
                    # 然后从状态字典中加载,再重新分片
                    with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                        if torch.distributed.get_rank() == 0:
                            module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)

        # 递归加载子模块的参数
        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    # 开始递归加载模型
    load(model_to_load, state_dict, prefix=start_prefix)
    # 删除 `state_dict`,以便更早地由 GC 回收。注意 `state_dict` 是参数的副本,因此可以安全删除它。
    del state_dict

    return error_msgs
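
# 补充示例(非原始源码,仅作演示):_load_state_dict_into_model 会先把旧格式检查点中的
# "gamma"/"beta" 键重命名为 "weight"/"bias",再递归调用 _load_from_state_dict 加载。
# 下面在一个 LayerNorm 上演示整体效果,数值为演示用的假设数据。
import torch
from torch import nn

ln = nn.LayerNorm(3)
old_style_state_dict = {"gamma": torch.full((3,), 2.0), "beta": torch.zeros(3)}  # 旧命名方式
errors = _load_state_dict_into_model(ln, old_style_state_dict, start_prefix="")
print(errors)          # []
print(ln.weight.data)  # tensor([2., 2., 2.]),gamma 已被当作 weight 加载进来
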
def find_submodule_and_param_name(model, long_key, start_prefix):
    # 辅助函数:查找最后一个子模块及其参数/缓冲区名称。如果提供了 `start_prefix`,则将其从键的开头移除。
    if len(start_prefix) > 0 and long_key.startswith(start_prefix):
        # 如果 `start_prefix` 长度大于零且 `long_key` 以 `start_prefix` 开头,则移除 `start_prefix`
        long_key = ".".join(long_key.split(".")[1:])
    
    # 按照点号分割长键名
    split_key = long_key.split(".")
    # 从模型开始查找子模块
    submodule = model
    while len(split_key) > 1:
        if hasattr(submodule, split_key[0]):
            # 如果模块具有当前分割键名对应的属性,则获取该属性作为下一级子模块
            submodule = getattr(submodule, split_key[0])
            # 删除已处理的键名
            del split_key[0]
        else:
            # 如果模块不具有当前分割键名对应的属性,则子模块置为 None,跳出循环
            submodule = None
            break
    
    # 如果最终找到的子模块仍然是初始的模型,说明未找到匹配的子模块
    if submodule == model:
        submodule = None
    # 返回最后找到的子模块及剩余的键名
    return submodule, split_key[0]
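
# 补充示例(非原始源码,仅作演示):演示上面的辅助函数(在 transformers 源码中名为
# find_submodule_and_param_name)如何按 "." 逐级下钻到最后一级子模块并返回参数名,
# 以及 start_prefix 如何剥离模型名前缀。
from torch import nn

root = nn.Module()
root.encoder = nn.Module()
root.encoder.dense = nn.Linear(2, 2)

sub, name = find_submodule_and_param_name(root, "encoder.dense.weight", start_prefix="")
print(sub is root.encoder.dense, name)   # True weight

sub, name = find_submodule_and_param_name(root, "bert.encoder.dense.weight", start_prefix="bert.")
print(sub is root.encoder.dense, name)   # True weight,"bert." 前缀已被移除
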
    # 以下为 `_load_state_dict_into_meta_model` 函数的开头部分(其函数签名与中间的大段加载逻辑在本摘录中省略)。
    # 该函数将 `loaded_state_dict_keys` 中的参数加载到元设备(meta device)上的模型中,从而释放这些参数占用的内存空间。
    # `start_prefix` 用于包含模型名称的模型键,例如在 `bert.pooler.dense.weight` 中的 `bert`。

    # 初始化错误信息列表
    error_msgs = []

    # 初始化旧键和新键列表,用于处理特定的参数重命名情况
    old_keys = []
    new_keys = []

    # 检查是否进行了量化操作
    is_quantized = hf_quantizer is not None

    # 遍历 `state_dict` 中的所有键
    for key in state_dict.keys():
        new_key = None

        # 替换特定键名中的 "gamma" 为 "weight"
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")

        # 替换特定键名中的 "beta" 为 "bias"
        if "beta" in key:
            new_key = key.replace("beta", "bias")

        # 如果有新的键名生成,则将原键名添加到旧键列表,将新键名添加到新键列表
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    # 遍历两个列表 old_keys 和 new_keys,依次将 state_dict 中 old_key 对应的值替换为 new_key,并更新 state_dict。
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)
    
    # 返回三个变量作为结果:error_msgs(错误消息列表)、offload_index(卸载索引)、state_dict_index(状态字典索引)。
    return error_msgs, offload_index, state_dict_index
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    # 如果 variant 参数不为 None,则修改 weights_name 中的文件扩展名
    if variant is not None:
        # 将 weights_name 按照 '.' 分割成列表
        splits = weights_name.split(".")
        # 替换列表中倒数第二项为 variant
        splits = splits[:-1] + [variant] + splits[-1:]
        # 将列表重新组合成字符串形式的 weights_name
        weights_name = ".".join(splits)

    # 返回修改后的 weights_name
    return weights_name
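
# 补充示例(非原始源码,仅作演示):_add_variant 把变体名插入到权重文件名的扩展名之前。
print(_add_variant("pytorch_model.bin", "fp16"))   # pytorch_model.fp16.bin
print(_add_variant("model.safetensors", "ema"))    # model.ema.safetensors
print(_add_variant("pytorch_model.bin"))           # 不传 variant 时原样返回 pytorch_model.bin
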


class ModuleUtilsMixin:
    """
    A few utilities for `torch.nn.Modules`, to be used as a mixin.
    """

    @staticmethod
    def _hook_rss_memory_pre_forward(module, *args, **kwargs):
        try:
            import psutil
        except ImportError:
            # 如果导入 psutil 失败,则抛出 ImportError
            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")

        # 获取当前进程的 psutil.Process 对象
        process = psutil.Process(os.getpid())
        # 获取当前进程的内存信息
        mem = process.memory_info()
        # 将当前进程的内存占用 RSS 存储到 module 对象的 mem_rss_pre_forward 属性中
        module.mem_rss_pre_forward = mem.rss
        # 返回 None
        return None

    @staticmethod
    def _hook_rss_memory_post_forward(module, *args, **kwargs):
        try:
            import psutil
        except ImportError:
            # 如果导入 psutil 失败,则抛出 ImportError
            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")

        # 获取当前进程的 psutil.Process 对象
        process = psutil.Process(os.getpid())
        # 获取当前进程的内存信息
        mem = process.memory_info()
        # 将当前进程的内存占用 RSS 存储到 module 对象的 mem_rss_post_forward 属性中
        module.mem_rss_post_forward = mem.rss
        # 计算前后两次内存占用的差值
        mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
        # 将差值累加到 module 对象的 mem_rss_diff 属性中
        module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
        # 返回 None
        return None

    def add_memory_hooks(self):
        """
        Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.

        Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero
        with `model.reset_memory_hooks_state()`.
        """
        # 遍历当前对象的所有子模块
        for module in self.modules():
            # 注册前向传播前的钩子函数 _hook_rss_memory_pre_forward
            module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
            # 注册前向传播后的钩子函数 _hook_rss_memory_post_forward
            module.register_forward_hook(self._hook_rss_memory_post_forward)
        # 调用 reset_memory_hooks_state 方法,重置所有模块的内存钩子状态
        self.reset_memory_hooks_state()

    def reset_memory_hooks_state(self):
        """
        Reset the `mem_rss_diff` attribute of each module (see [`~modeling_utils.ModuleUtilsMixin.add_memory_hooks`]).
        """
        # 遍历当前对象的所有子模块
        for module in self.modules():
            # 将每个模块的 mem_rss_diff、mem_rss_post_forward 和 mem_rss_pre_forward 属性重置为 0
            module.mem_rss_diff = 0
            module.mem_rss_post_forward = 0
            module.mem_rss_pre_forward = 0

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        # 调用 get_parameter_device 函数获取当前模块所在的设备,并返回设备对象
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        # 调用 get_parameter_dtype 函数获取当前模块的数据类型,并返回数据类型对象
        return get_parameter_dtype(self)
    def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
        """
        Invert an attention mask (e.g., switches 0. and 1.).

        Args:
            encoder_attention_mask (`torch.Tensor`): An attention mask.

        Returns:
            `torch.Tensor`: The inverted attention mask.
        """
        # 如果注意力遮罩是三维的,则在第二个维度上扩展为四维
        if encoder_attention_mask.dim() == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        # 如果注意力遮罩是二维的,则在第二个和第三个维度上扩展为四维
        if encoder_attention_mask.dim() == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
        
        # T5有一个可以比较序列ID的遮罩,这里通过转置来模拟
        # 参考:https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
        # /transformer/transformer_layers.py#L270
        # 将注意力遮罩转换为模型数据类型,以支持fp16(半精度浮点数)计算
        encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype)
        # 计算反转的注意力遮罩,将0变为最小的负浮点数
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min

        return encoder_extended_attention_mask

    @staticmethod
    def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
        if device is not None:
            warnings.warn(
                "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
            )
        else:
            device = attention_mask.device
        
        batch_size, seq_length = input_shape
        # 创建一个序列ID张量,长度为seq_length,设备为指定的设备
        seq_ids = torch.arange(seq_length, device=device)
        # 创建一个因果遮罩,用于decoder,形状为[batch_size, seq_length, seq_length]
        causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
        # 将因果遮罩转换为与注意力遮罩相同的数据类型
        causal_mask = causal_mask.to(attention_mask.dtype)

        # 如果因果遮罩的长度小于注意力遮罩的长度,则需要在因果遮罩前添加一个全1的遮罩
        if causal_mask.shape[1] < attention_mask.shape[1]:
            prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
            causal_mask = torch.cat(
                [
                    torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
                    causal_mask,
                ],
                axis=-1,
            )

        # 创建扩展的注意力遮罩,是因果遮罩和输入的注意力遮罩的点积
        extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        return extended_attention_mask
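
# 补充示例(非原始源码,仅作演示):create_extended_attention_mask_for_decoder 是静态方法,
# 可以直接在 ModuleUtilsMixin 上调用。它把二维填充遮罩与下三角因果遮罩相乘,
# 得到形状为 [batch, 1, seq_len, seq_len] 的解码器注意力遮罩。示例数据为假设值。
import torch

pad_mask = torch.tensor([[1, 1, 1, 0]])  # 最后一个位置是 padding
extended = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
    input_shape=(1, 4), attention_mask=pad_mask
)
print(extended.shape)   # torch.Size([1, 1, 4, 4])
print(extended[0, 0])   # 下三角为 1,且最后一列(padding 位置)整体为 0
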

    def get_extended_attention_mask(
        self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
        """
        if dtype is None:
            dtype = self.dtype  # 如果未指定 dtype,则使用对象自身的 dtype

        if not (attention_mask.dim() == 2 and self.config.is_decoder):
            # 如果 attention_mask 的维度不是二维或模型不是解码器,发出警告
            # 仅在不在 `create_extended_attention_mask_for_decoder` 中显示时才显示此警告
            if device is not None:
                warnings.warn(
                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
                )
        
        # 如果 attention_mask 的维度是三维,则扩展为 [batch_size, 1, from_seq_length, to_seq_length]
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # 如果提供了维度为 [batch_size, seq_length] 的填充 mask
            # - 如果模型是解码器,则除了填充 mask 外还应用因果 mask
            # - 如果模型是编码器,则将 mask 扩展为 [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
                    input_shape, attention_mask, device
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            # 如果 attention_mask 维度不符合要求,抛出 ValueError
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # 将 extended_attention_mask 转换为指定的 dtype,用于 fp16 兼容性
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)
        # 将所有值为 1.0 的位置变为 0.0,所有值为 0.0 的位置变为 dtype 的最小值
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask
    def get_head_mask(
        self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
    ) -> Tensor:
        """
        Prepare the head mask if needed.

        Args:
            head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
            num_hidden_layers (`int`):
                The number of hidden layers in the model.
            is_attention_chunked (`bool`, *optional*, defaults to `False`):
                Whether or not the attention scores are computed by chunks or not.

        Returns:
            `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
            `[None]` for each layer.
        """
        if head_mask is not None:
            # Convert head_mask to a 5-dimensional tensor if it's 1-dimensional
            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
            # Modify head_mask shape if attention scores are computed by chunks
            if is_attention_chunked is True:
                head_mask = head_mask.unsqueeze(-1)
        else:
            # Set head_mask to a list of None for each layer if head_mask is None
            head_mask = [None] * num_hidden_layers

        return head_mask

    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
        """
        Convert `head_mask` to a 5-dimensional tensor `[num_hidden_layers x batch x num_heads x seq_length x seq_length]`.

        Args:
            head_mask (`torch.Tensor`):
                The input head_mask tensor with shape `[num_heads]` or `[num_hidden_layers x num_heads]`.
            num_hidden_layers (`int`):
                The number of hidden layers in the model.

        Returns:
            `torch.Tensor`:
                The converted head_mask tensor with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]`.
        """
        if head_mask.dim() == 1:
            # Expand the head_mask tensor to match the desired shape
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            # Expand the head_mask tensor to include each layer if it's 2-dimensional
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
        head_mask = head_mask.to(dtype=self.dtype)  # Convert to specified dtype for compatibility
        return head_mask
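
# 补充示例(非原始源码,仅作演示):独立复现 _convert_head_mask_to_5d 对一维 head_mask 的广播逻辑——
# [num_heads] 会被扩展成 [num_hidden_layers, 1, num_heads, 1, 1],以便逐层、逐头与注意力权重相乘。
import torch

num_hidden_layers, num_heads = 3, 2
head_mask = torch.tensor([1.0, 0.0])  # 保留第 1 个注意力头,屏蔽第 2 个
mask_5d = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
mask_5d = mask_5d.expand(num_hidden_layers, -1, -1, -1, -1)
print(mask_5d.shape)  # torch.Size([3, 1, 2, 1, 1])
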
    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
        """
        Get number of (optionally, trainable or non-embeddings) parameters in the module.

        Args:
            only_trainable (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of trainable parameters

            exclude_embeddings (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of non-embeddings parameters

        Returns:
            `int`: The number of parameters.
        """

        # Check if embeddings should be excluded from the parameter count
        if exclude_embeddings:
            # Generate a list of parameter names that belong to embedding layers
            embedding_param_names = [
                f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
            ]
            # Filter out embedding parameters from the total parameters
            total_parameters = [
                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
            ]
        else:
            # If not excluding embeddings, include all parameters of the module
            total_parameters = list(self.parameters())

        # Initialize an empty list to store the number of elements (numel) in each parameter tensor
        total_numel = []
        
        # Check if the model has been loaded in 4bit precision
        is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)

        # If loaded in 4bit precision, additional considerations are needed
        if is_loaded_in_4bit:
            # Check if the bitsandbytes library is available
            if is_bitsandbytes_available():
                import bitsandbytes as bnb
            else:
                # Raise an error if bitsandbytes is not installed but 4bit precision is indicated
                raise ValueError(
                    "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
                    " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. "
                )

        # Iterate through each parameter to calculate the number of elements (numel)
        for param in total_parameters:
            # Check if the parameter requires gradient or if only trainable parameters are considered
            if param.requires_grad or not only_trainable:
                # For 4bit models, adjust the numel calculation due to storage considerations
                if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
                    total_numel.append(
                        param.numel() * 2 * self.hf_quantizer.quantization_config.bnb_4bit_quant_storage.itemsize
                    )
                else:
                    # Standard numel calculation for regular tensors
                    total_numel.append(param.numel())

        # Return the sum of all calculated numels, representing the total number of parameters
        return sum(total_numel)
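
# 补充示例(非原始源码,仅作演示):独立复现 num_parameters(exclude_embeddings=True) 的计数思路——
# 先收集所有 nn.Embedding 的 "<模块名>.weight",再对其余参数的 numel 求和。模型结构为假设示例。
import torch
from torch import nn

toy = nn.ModuleDict({"embed": nn.Embedding(10, 4), "proj": nn.Linear(4, 2)})
embedding_param_names = [
    f"{name}.weight" for name, module in toy.named_modules() if isinstance(module, nn.Embedding)
]
non_embedding_numel = sum(p.numel() for n, p in toy.named_parameters() if n not in embedding_param_names)
print(embedding_param_names)   # ['embed.weight']
print(non_embedding_numel)     # 10:proj 的 4*2 个权重 + 2 个偏置
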
    # Helper function to estimate the total number of tokens from the model inputs.
    def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
        """
        Helper function to estimate the total number of tokens from the model inputs.

        Args:
            inputs (`dict`): The model inputs.

        Returns:
            `int`: The total number of tokens.
        """
        # Initialize a dictionary to track warnings if not already initialized
        if not hasattr(self, "warnings_issued"):
            self.warnings_issued = {}
        
        # Check if the main input name exists in the input dictionary
        if self.main_input_name in input_dict:
            # Return the number of elements in the tensor corresponding to the main input
            return input_dict[self.main_input_name].numel()
        # If main input name does not exist, issue a warning
        elif "estimate_tokens" not in self.warnings_issued:
            logger.warning(
                "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
            )
            # Mark that a warning for 'estimate_tokens' has been issued
            self.warnings_issued["estimate_tokens"] = True
        
        # Return 0 if unable to estimate tokens
        return 0

    # Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
    # batch with this transformer model.
    def floating_point_ops(
        self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
    ) -> int:
        """
        Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
        batch with this transformer model. Default approximation neglects the quadratic dependency on the number of
        tokens (valid if `12 * d_model << sequence_length`) as laid out in [this
        paper](https://arxiv.org/pdf/2001.08361.pdf) section 2.1. Should be overridden for transformers with parameter
        re-use e.g. Albert or Universal Transformers, or if doing long-range modeling with very high sequence lengths.

        Args:
            batch_size (`int`):
                The batch size for the forward pass.

            sequence_length (`int`):
                The number of tokens in each line of the batch.

            exclude_embeddings (`bool`, *optional*, defaults to `True`):
                Whether or not to count embedding and softmax operations.

        Returns:
            `int`: The number of floating-point operations.
        """

        # Calculate the number of floating-point operations based on an approximation
        # 6 operations per token times the estimated number of tokens times the number of model parameters
        return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
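
# 补充示例(非原始源码,仅作演示):floating_point_ops 采用 6 * token 数 * 参数量 的近似
# (前向约 2 次乘加、反向约 4 次)。例如 batch_size=8、序列长度 512、参数量 1.1 亿
# (大致相当于 BERT-base 规模,均为假设数值)时:
tokens = 8 * 512
params = 110_000_000
print(6 * tokens * params)  # 2703360000000,约 2.7 TFLOPs / 批次(前向 + 反向)
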
# 定义一个继承自多个Mixin类的模型基类,用于所有模型的基础功能实现
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
    r"""
    Base class for all models.

    [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
    downloading and saving models as well as a few methods common to all models to:

        - resize the input embeddings,
        - prune heads in the self-attention heads.

    Class attributes (overridden by derived classes):

        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
          for this model architecture.
        - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
          taking as arguments:

            - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
            - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
            - **path** (`str`) -- A path to the TensorFlow checkpoint.

        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
          classes of the same architecture adding modules on top of the base model.
        - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
          models, `pixel_values` for vision models and `input_values` for speech models).
    """

    # 配置类,派生类需覆盖
    config_class = None
    # 基础模型前缀,派生类需覆盖
    base_model_prefix = ""
    # 主要输入名称,默认为 `input_ids`
    main_input_name = "input_ids"
    # 模型标签,初始化为 None
    model_tags = None

    # 内部使用的属性,以下几个属性初始化为 None
    _auto_class = None
    _no_split_modules = None
    _skip_keys_device_placement = None
    _keep_in_fp32_modules = None

    # 用于加载时忽略的 `state_dict` 键的模式列表,初始化为 None
    _keys_to_ignore_on_load_missing = None
    # 用于加载时忽略的 `state_dict` 键的模式列表,初始化为 None
    _keys_to_ignore_on_load_unexpected = None
    # 用于保存模型时忽略的 `state_dict` 键的列表,初始化为 None
    _keys_to_ignore_on_save = None
    # 可能与另一个键绑定的 `state_dict` 键的列表,初始化为 None
    _tied_weights_keys = None

    # 是否支持模型并行化,默认为 False
    is_parallelizable = False
    # 是否支持梯度检查点,默认为 False
    supports_gradient_checkpointing = False

    # 是否支持 Flash Attention 2,默认为 False
    _supports_flash_attn_2 = False

    # 是否支持 SDPA,默认为 False
    _supports_sdpa = False

    # 是否支持将 `Cache` 实例用作 `past_key_values`,默认为 False
    _supports_cache_class = False

    @property
    def dummy_inputs(self) -> Dict[str, torch.Tensor]:
        """
        `Dict[str, torch.Tensor]`: 返回用于网络前向传播的虚拟输入数据字典。
        """
        return {"input_ids": torch.tensor(DUMMY_INPUTS)}

    @property
    def framework(self) -> str:
        """
        :str: 标识这是一个基于 PyTorch 的模型。
        """
        return "pt"

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
                "`PretrainedConfig`. To create a model from a pretrained model use "
                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        # 保存配置和预训练权重的来源,如果在模型中给出的话
        config = self._autoset_attn_implementation(
            config, torch_dtype=torch.get_default_dtype(), check_device_map=False
        )
        self.config = config

        self.name_or_path = config.name_or_path
        self.warnings_issued = {}
        # 如果模型支持生成,将生成配置从模型配置中创建
        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
        # 重写类属性以将其变为实例属性,这样像 `InstructBlipForConditionalGeneration` 这样的模型可以动态更新它,
        # 而不需要修改类属性,当使用不同的组件(例如语言模型)时。
        self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)

    def post_init(self):
        """
        在每次 Transformer 模型初始化结束时执行的方法,用于执行需要模型模块正确初始化的代码(例如权重初始化)。
        """
        self.init_weights()
        self._backward_compatibility_gradient_checkpointing()

    def _backward_compatibility_gradient_checkpointing(self):
        if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
            self.gradient_checkpointing_enable()
            # 现在已经使用了该属性,从配置中删除它,这样它就不会被保存在配置中。
            delattr(self.config, "gradient_checkpointing")
    def add_model_tags(self, tags: Union[List[str], str]) -> None:
        r"""
        Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
        not overwrite existing tags in the model.

        Args:
            tags (`Union[List[str], str]`):
                The desired tags to inject in the model

        Examples:

        ```
        from transformers import AutoModel

        model = AutoModel.from_pretrained("google-bert/bert-base-cased")

        model.add_model_tags(["custom", "custom-bert"])

        # Push the model to your namespace with the name "my-custom-bert".
        model.push_to_hub("my-custom-bert")
        ```
        """
        if isinstance(tags, str):
            tags = [tags]  # 如果tags是字符串,转换为单元素列表

        if self.model_tags is None:
            self.model_tags = []  # 如果当前模型标签为空,初始化为空列表

        for tag in tags:
            if tag not in self.model_tags:
                self.model_tags.append(tag)  # 添加不重复的标签到模型标签列表

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.

        Args:
            torch_dtype (`torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model under this dtype.
        """
        torch_dtype = kwargs.pop("torch_dtype", None)
        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)

        # override default dtype if needed
        dtype_orig = None
        if torch_dtype is not None:
            dtype_orig = cls._set_default_torch_dtype(torch_dtype)  # 如果指定了torch_dtype,则设置默认dtype为指定的dtype

        config = copy.deepcopy(config)  # 创建配置的深拷贝,避免在_from_config中直接修改原始配置
        config._attn_implementation = kwargs.pop("attn_implementation", None)  # 设置配置中的注意力实现方式

        config = cls._autoset_attn_implementation(
            config,
            use_flash_attention_2=use_flash_attention_2,
            check_device_map=False,
            torch_dtype=torch_dtype,
        )  # 调用自动设置注意力实现的方法,根据参数设置config的相关属性

        if is_deepspeed_zero3_enabled():
            import deepspeed

            logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
            # this immediately partitions the model across all gpus, to avoid the overhead in time
            # and memory copying it on CPU or each GPU first
            with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
                model = cls(config, **kwargs)  # 在DeepSpeed ZeRO-3环境下使用deepspeed.zero.Init初始化模型
        else:
            model = cls(config, **kwargs)  # 在非DeepSpeed ZeRO-3环境下常规初始化模型

        # restore default dtype if it was modified
        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)  # 如果修改了默认dtype,则恢复为修改前的dtype

        return model
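
`_from_config` is the code path behind the public `from_config` entry points; a minimal usage sketch (the config values are illustrative, not a recommended architecture):

```
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2", n_layer=2, n_head=2, n_embd=64)
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)

print(model.dtype)                 # torch.bfloat16; the global default dtype is restored afterwards
print(model.config is not config)  # True: _from_config works on a deep copy of the config
```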

    @classmethod
    def _autoset_attn_implementation(
        cls,
        config,
        use_flash_attention_2: bool = False,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        check_device_map: bool = True,
    ):
        """
        Automatically sets the attention implementation in the provided config.

        Args:
            config: The model configuration to modify.
            use_flash_attention_2: Whether to use the Flash Attention 2 implementation.
            torch_dtype: Optional, override the default torch.dtype for initialization.
            device_map: Optional device mapping.
            check_device_map: Whether to check device map validity.

        Returns:
            The modified config with the attention implementation set.
        """
        # 简化的示意实现(实际源码中的检查逻辑更复杂):根据参数设置 `config._attn_implementation`
        if use_flash_attention_2:
            config._attn_implementation = "flash_attention_2"

        if config._attn_implementation == "flash_attention_2":
            config = cls._check_and_enable_flash_attn_2(
                config, torch_dtype=torch_dtype, device_map=device_map, check_device_map=check_device_map
            )
        elif config._attn_implementation is None:
            # 未显式指定实现时,尝试启用 SDPA(若当前类与环境支持,否则保持默认)
            config = cls._check_and_enable_sdpa(config, hard_check_only=False)

        return config
    @classmethod
    def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
        """
        Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
        under specific dtype.

        Args:
            dtype (`torch.dtype`):
                a floating dtype to set to.

        Returns:
            `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
            modified. If it wasn't, returns `None`.

        Note `set_default_dtype` currently only works with floating-point types and asserts if, for example,
        `torch.int64` is passed. So if a non-float `dtype` is passed, this function will throw an exception.
        """
        if not dtype.is_floating_point:
            raise ValueError(
                f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
            )

        logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
        # 获取当前的默认 dtype
        dtype_orig = torch.get_default_dtype()
        # 设置新的默认 dtype
        torch.set_default_dtype(dtype)
        return dtype_orig
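
The save-and-restore pattern used by `_from_config`/`from_pretrained` around model construction looks like this in isolation (a sketch, independent of any model class):

```
import torch
import torch.nn as nn

dtype_orig = torch.get_default_dtype()   # typically torch.float32
torch.set_default_dtype(torch.float16)
layer = nn.Linear(4, 4)                  # parameters are created directly in float16
torch.set_default_dtype(dtype_orig)      # restore so unrelated code is unaffected

print(layer.weight.dtype)                # torch.float16
print(torch.get_default_dtype())         # torch.float32 again
```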

    @property
    def base_model(self) -> nn.Module:
        """
        `torch.nn.Module`: The main body of the model.
        """
        # 返回当前实例的 `base_model_prefix` 属性,如果不存在则返回自身
        return getattr(self, self.base_model_prefix, self)

    @classmethod
    def can_generate(cls) -> bool:
        """
        Returns whether this model can generate sequences with `.generate()`.

        Returns:
            `bool`: Whether this model can generate sequences with `.generate()`.
        """
        # 检查类是否重写了 `prepare_inputs_for_generation` 或 `generate`
        # 如果两者仍是 `GenerationMixin` 中的默认实现(即均未被重写),说明模型不支持生成,返回 False
        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
            return False
        return True

    @classmethod
    def _check_and_enable_flash_attn_2(
        cls,
        config,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        check_device_map: bool = True,
        hard_check_only: bool = False,
    ):
        """
        Check and potentially enable the Flash Attention 2 features based on the provided configuration.

        Args:
            config: The configuration object for the model.
            torch_dtype (Optional[torch.dtype]): The desired dtype to set as default.
            device_map (Optional[Union[str, Dict[str, int]]]): Device mapping information.
            check_device_map (bool): Whether to check device map.
            hard_check_only (bool): Whether to perform a hard check only.

        This function checks if certain conditions are met in the provided configuration to enable Flash Attention 2.
        """
        # 此处应该有代码实现,用于检查和启用 Flash Attention 2 的相关特性
        pass
    # 检查并启用 SDPA(Scaled Dot-Product Attention)的类方法。如果所有检查通过且 `hard_check_only` 为 False,
    # 则将配置属性 `_attn_implementation` 设置为 "sdpa",以便模型初始化相应的注意力模块。
    @classmethod
    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
        if hard_check_only:
            # 如果仅进行严格检查并且当前类不支持 SDPA,则抛出值错误
            if not cls._supports_sdpa:
                raise ValueError(
                    f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
                    " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
                    ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
                )
            # 如果未安装 PyTorch SDPA,则抛出导入错误
            if not is_torch_sdpa_available():
                raise ImportError(
                    "PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1."
                )

        # 如果未安装 PyTorch SDPA 或当前类不支持 SDPA,则直接返回配置
        if not is_torch_sdpa_available() or not cls._supports_sdpa:
            return config

        # 读取类属性 `use_bettertransformer`,判断模型是否已转换为 BetterTransformer 模式
        _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
        # 如果是 BetterTransformer 模式,则返回配置
        if _is_bettertransformer:
            return config

        # 如果不是严格检查模式,将配置的 `_attn_implementation` 设置为 "sdpa"
        if not hard_check_only:
            config._attn_implementation = "sdpa"
        # 返回更新后的配置
        return config
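
These checks are normally triggered through the `attn_implementation` argument of `from_pretrained`. A hedged sketch (the checkpoint name is a placeholder; `"sdpa"` additionally requires `torch>=2.1.1` and an architecture whose class sets `_supports_sdpa = True`):

```
from transformers import AutoModelForCausalLM

# Tiny Llama-architecture checkpoint used purely as a placeholder.
name = "hf-internal-testing/tiny-random-LlamaForCausalLM"

model = AutoModelForCausalLM.from_pretrained(name, attn_implementation="sdpa")
print(model.config._attn_implementation)  # "sdpa" when the checks above pass

# Force the original eager attention implementation instead.
model = AutoModelForCausalLM.from_pretrained(name, attn_implementation="eager")
print(model.config._attn_implementation)  # "eager"
```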

    # 启用输入嵌入的梯度计算的方法。用于在固定模型权重的同时微调适配器权重。
    def enable_input_require_grads(self):
        # 定义一个函数 `make_inputs_require_grads`,用于设置输出的梯度要求为 True
        def make_inputs_require_grads(module, input, output):
            output.requires_grad_(True)

        # 注册前向钩子 `_require_grads_hook` 到输入嵌入模块上
        self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)

    # 移除输入嵌入梯度计算的方法。
    def disable_input_require_grads(self):
        # 移除前向钩子 `_require_grads_hook`
        self._require_grads_hook.remove()

    # 获取模型的输入嵌入的方法,返回一个 `nn.Module` 模块,将词汇映射到隐藏状态。
    def get_input_embeddings(self) -> nn.Module:
        # 获取基础模型,若存在,则递归调用其 `get_input_embeddings` 方法
        base_model = getattr(self, self.base_model_prefix, self)
        # 若 `base_model` 不是当前对象本身,则调用其 `get_input_embeddings` 方法
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            # 否则抛出未实现错误
            raise NotImplementedError
    def set_input_embeddings(self, value: nn.Module):
        """
        Set model's input embeddings.

        Args:
            value (`nn.Module`): A module mapping vocabulary to hidden states.
        """
        # 获取当前模型的基础模型(可能是自身或者其它模型)
        base_model = getattr(self, self.base_model_prefix, self)
        # 如果基础模型不是当前对象本身,则递归调用基础模型的设置输入嵌入方法
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            # 如果基础模型是当前对象本身,则抛出未实现的错误
            raise NotImplementedError

    def get_output_embeddings(self) -> nn.Module:
        """
        Returns the model's output embeddings.

        Returns:
            `nn.Module`: A torch module mapping hidden states to vocabulary.
        """
        # 对于没有输出嵌入的模型,返回空值
        return None  # Overwrite for models with output embeddings

    def _init_weights(self, module):
        """
        Initialize the weights. This method should be overridden by derived class and is
        the only initialization method that will be called when loading a checkpoint
        using `from_pretrained`. Any attempt to initialize outside of this function
        will be useless as the torch.nn.init function are all replaced with skip.
        """
        # 初始化权重的方法,应当由派生类重写。在使用 `from_pretrained` 加载检查点时,这是唯一会被调用的初始化方法。

    def _initialize_weights(self, module):
        """
        Initialize the weights if they are not already initialized.
        """
        # 如果模块已经被初始化,则直接返回
        if getattr(module, "_is_hf_initialized", False):
            return
        # 否则调用初始化权重的具体方法
        self._init_weights(module)
        # 标记模块已经被初始化
        module._is_hf_initialized = True

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.

        If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
        weights instead.
        """
        # 如果配置中设置了 `tie_word_embeddings`,则尝试绑定输入嵌入和输出嵌入的权重
        if getattr(self.config, "tie_word_embeddings", True):
            # 获取输出嵌入
            output_embeddings = self.get_output_embeddings()
            # 如果输出嵌入不为空,则尝试绑定或克隆权重
            if output_embeddings is not None:
                self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

        # 如果配置中设置了 `is_encoder_decoder` 和 `tie_encoder_decoder`,则尝试绑定编码器-解码器的权重
        if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
            # 如果存在基础模型前缀,则将当前对象替换为基础模型
            if hasattr(self, self.base_model_prefix):
                self = getattr(self, self.base_model_prefix)
            # 调用内部方法绑定编码器-解码器权重
            self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)

        # 对于模型中的每一个模块,如果模块具有 `_tie_weights` 属性,则调用其绑定权重方法
        for module in self.modules():
            if hasattr(module, "_tie_weights"):
                module._tie_weights()
    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """根据是否使用 TorchScript 来共享或克隆模块的权重"""
        if self.config.torchscript:
            # 如果使用 TorchScript,则克隆输入 embeddings 的权重到输出 embeddings
            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
        else:
            # 否则,直接共享输入 embeddings 的权重给输出 embeddings
            output_embeddings.weight = input_embeddings.weight

        # 如果输出 embeddings 存在偏置项
        if getattr(output_embeddings, "bias", None) is not None:
            # 对输出 embeddings 的偏置进行填充,以匹配权重的形状
            output_embeddings.bias.data = nn.functional.pad(
                output_embeddings.bias.data,
                (
                    0,
                    output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
                ),
                "constant",
                0,
            )
        # 如果输出 embeddings 具有 'out_features' 属性,并且输入 embeddings 具有 'num_embeddings' 属性
        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
            # 设置输出 embeddings 的 out_features 属性为输入 embeddings 的 num_embeddings
            output_embeddings.out_features = input_embeddings.num_embeddings
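
Whether tying actually shared the parameter (rather than cloning it, as in the TorchScript branch) can be checked directly; a sketch assuming a BERT-style masked-LM checkpoint:

```
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")  # placeholder checkpoint

# With config.tie_word_embeddings=True and torchscript=False the two modules
# share the exact same Parameter object.
tied = model.get_input_embeddings().weight is model.get_output_embeddings().weight
print(tied)  # True
```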

    def _get_no_split_modules(self, device_map: str):
        """
        获取在使用 device_map 时不应分割的模块。我们遍历模块以获取底层的 `_no_split_modules`。

        Args:
            device_map (`str`):
                设备映射值。选项有 ["auto", "balanced", "balanced_low_0", "sequential"]

        Returns:
            `List[str]`: 不应分割的模块列表
        """
        _no_split_modules = set()
        modules_to_check = [self]
        while len(modules_to_check) > 0:
            module = modules_to_check.pop(-1)
            # 如果模块不在 _no_split_modules 中,则继续检查其子模块
            if module.__class__.__name__ not in _no_split_modules:
                if isinstance(module, PreTrainedModel):
                    if module._no_split_modules is None:
                        raise ValueError(
                            f"{module.__class__.__name__} 不支持 `device_map='{device_map}'`。要实现支持,模型类需要实现 `_no_split_modules` 属性。"
                        )
                    else:
                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
                # 将当前模块的子模块加入待检查列表
                modules_to_check += list(module.children())
        return list(_no_split_modules)

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        """
        调整模型的输入 token embeddings 矩阵大小,如果 `new_num_tokens != config.vocab_size` 的话。

        调整后负责在需要时绑定权重 embeddings,如果模型类有 `tie_weights()` 方法的话。

        参数:
            new_num_tokens (`int`, *可选*):
                embedding 矩阵中的新 token 数量。增加大小会在末尾添加新初始化的向量。减少大小会从末尾移除向量。
                如果未提供或为 `None`,仅返回指向模型输入 token 的 `torch.nn.Embedding` 模块的指针,不执行任何操作。
            pad_to_multiple_of (`int`, *可选*):
                如果设置,将填充 embedding 矩阵至提供的值的倍数。如果 `new_num_tokens` 设置为 `None`,则仅将 embedding
                填充至 `pad_to_multiple_of` 的倍数。

                这对于启用 NVIDIA 硬件的 Tensor Cores(计算能力 `>= 7.5`,Volta)或者利用 TPUs 时特别有用,这些硬件
                在序列长度为 128 的倍数时效果最佳。有关更多详细信息或调整大小的正确值的帮助,请参考此指南:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        返回:
            `torch.nn.Embedding`: 指向模型输入 tokens Embedding 模块的指针。
        """
        # 调整 token embeddings 大小并返回
        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        # 如果 new_num_tokens 和 pad_to_multiple_of 都为 None,直接返回调整后的模型 embeddings
        if new_num_tokens is None and pad_to_multiple_of is None:
            return model_embeds

        # 更新基础模型和当前模型配置中的词汇大小
        self.config.vocab_size = model_embeds.weight.shape[0]
        self.vocab_size = model_embeds.weight.shape[0]

        # 如果需要,重新绑定权重
        self.tie_weights()

        # 返回调整后的模型 embeddings
        return model_embeds
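
A typical call site extends the tokenizer first and then resizes the embedding matrix, optionally padding it to a Tensor Core friendly shape (checkpoint name is a placeholder):

```
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

tokenizer.add_tokens(["<obs>", "<act>"])  # two new special-purpose tokens
embeddings = model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)

print(embeddings.weight.shape[0] % 64 == 0)                   # True: padded up to a multiple of 64
print(model.config.vocab_size == embeddings.weight.shape[0])  # config kept in sync
```
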
    # 调整模型的 token embeddings 的大小
    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
        # 获取当前的输入 embeddings
        old_embeddings = self.get_input_embeddings()
        # 调整 embeddings 的大小
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
        
        # 如果旧的 embeddings 带有 _hf_hook 属性,将其挂钩移到新的 embeddings 上
        if hasattr(old_embeddings, "_hf_hook"):
            hook = old_embeddings._hf_hook
            add_hook_to_module(new_embeddings, hook)
        
        # 复制旧的 embeddings 是否需要梯度到新的 embeddings
        old_embeddings_requires_grad = old_embeddings.weight.requires_grad
        new_embeddings.requires_grad_(old_embeddings_requires_grad)
        
        # 设置模型的输入 embeddings 为新调整大小后的 embeddings
        self.set_input_embeddings(new_embeddings)
        
        # 检查是否量化了模型
        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
        
        # 更新 new_num_tokens,确保其与新 embeddings 的实际大小一致
        if pad_to_multiple_of is not None:
            # 如果使用了 deepspeed.zero3 并且未量化,则使用 deepspeed.zero.GatheredParameters 调整大小
            if is_deepspeed_zero3_enabled() and not is_quantized:
                import deepspeed

                with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
                    new_num_tokens = new_embeddings.weight.shape[0]
            else:
                # 否则,直接使用新 embeddings 的大小
                new_num_tokens = new_embeddings.weight.shape[0]
        
        # 如果输出 embeddings 存在且未绑定 word embeddings,调整 lm head 的大小
        if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
            # 获取旧的 lm head
            old_lm_head = self.get_output_embeddings()
            # 调整 lm head 的大小
            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
            
            # 如果旧的 lm head 带有 _hf_hook 属性,将其挂钩移到新的 lm head 上
            if hasattr(old_lm_head, "_hf_hook"):
                hook = old_lm_head._hf_hook
                add_hook_to_module(new_lm_head, hook)
            
            # 复制旧的 lm head 是否需要梯度到新的 lm head
            old_lm_head_requires_grad = old_lm_head.weight.requires_grad
            new_lm_head.requires_grad_(old_lm_head_requires_grad)
            
            # 设置模型的输出 embeddings 为新调整大小后的 lm head
            self.set_output_embeddings(new_lm_head)
        
        # 返回调整后的输入 embeddings
        return self.get_input_embeddings()

    # 获取调整大小后的 embeddings
    def _get_resized_embeddings(
        self,
        old_embeddings: nn.Embedding,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ):
        ...

    # 获取调整大小后的 lm head
    def _get_resized_lm_head(
        self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
    ):
        ...

    # 将原始 lm head 复制到调整大小后的 lm head
    def _copy_lm_head_original_to_resized(
        self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
    ):
        # 将旧的 lm head 权重复制到新的 lm head
        if not transposed:
            new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
        else:
            new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]

        # 如果新的 lm head 存在偏置,将旧的 lm head 偏置复制到新的 lm head
        if has_new_lm_head_bias:
            new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        # 抛出未实现错误,提示用户在子类中实现这个方法来调整位置嵌入
        raise NotImplementedError(
            f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
            f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
        )

    def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
        # 抛出未实现错误,提示用户在子类中实现这个方法来获取位置嵌入
        raise NotImplementedError(
            f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
            f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
        )

    def init_weights(self):
        """
        If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
        initialization logic in `_init_weights`.
        """
        # 如果需要修剪头部,则调用修剪方法
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

        # 如果模块级全局标志 `_init_weights` 为 True(即未处于 `no_init_weights` 上下文中),则执行权重初始化
        if _init_weights:
            # 调用 `apply` 递归地对所有子模块执行 `_initialize_weights`
            self.apply(self._initialize_weights)

            # 如果不是初始化所有权重,则不应该绑定权重
            # 因为from_pretrained(...)方法会自动绑定权重
            self.tie_weights()

    def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
        """
        Prunes heads of the base model.

        Arguments:
            heads_to_prune (`Dict[int, List[int]]`):
                Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads
                to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on
                layer 1 and heads 2 and 3 on layer 2.
        """
        # 将新修剪的头部集合保存为先前存储的修剪头部集合与新修剪头部集合的并集
        for layer, heads in heads_to_prune.items():
            union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
            self.config.pruned_heads[layer] = list(union_heads)  # 不幸的是,我们必须将其存储为列表以便进行JSON序列化

        # 调用基础模型的内部方法来修剪头部
        self.base_model._prune_heads(heads_to_prune)
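
Pruned heads accumulate across calls because of the union above; a small sketch with a BERT encoder (placeholder checkpoint):

```
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")  # placeholder checkpoint

model.prune_heads({0: [0, 2], 1: [1]})
model.prune_heads({0: [3]})  # merged with the heads already pruned in layer 0

print(model.config.pruned_heads)  # e.g. {0: [0, 2, 3], 1: [1]} (order may vary)
print(model.encoder.layer[0].attention.self.num_attention_heads)  # 12 - 3 = 9
```
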
    # 激活当前模型的梯度检查点功能。
    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """
        Activates gradient checkpointing for the current model.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".

        We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
        the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2

        Args:
            gradient_checkpointing_kwargs (dict, *optional*):
                Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
        """
        # 如果当前模型不支持梯度检查点,则抛出异常。
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

        # 如果未提供梯度检查点参数,则使用默认值 {"use_reentrant": True}。
        if gradient_checkpointing_kwargs is None:
            gradient_checkpointing_kwargs = {"use_reentrant": True}

        # 创建一个偏函数,用于调用 `torch.utils.checkpoint.checkpoint` 函数,并传入梯度检查点参数。
        gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)

        # 对于旧的梯度检查点格式(transformers < 4.35.0),对于在Hub上存在的模型,我们将回退到重写的 `_set_gradient_checkpointing` 方法。
        _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters

        # 如果不是使用旧格式,则调用 `self._set_gradient_checkpointing` 方法启用梯度检查点。
        if not _is_using_old_format:
            self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
        # 否则,应用部分应用 `self._set_gradient_checkpointing` 方法,传入参数 `value=True`。
        else:
            self.apply(partial(self._set_gradient_checkpointing, value=True))
            # 记录警告信息,提示使用了已废弃的梯度检查点格式。
            logger.warn(
                "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
                "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
            )

        # 如果存在 `_hf_peft_config_loaded` 属性,则需要确保输入的 `requires_grad` 为 True。
        if getattr(self, "_hf_peft_config_loaded", False):
            # 当使用 PEFT + 梯度检查点 + Trainer 时,需要确保输入的 `requires_grad` 为 True。
            # 这也适用于 PEFT:https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
            # 在使用 PEFT 进行训练时,只有 LoRA 层的 `requires_grad` 被设置为 True,但冻结层的输出需要传播梯度,以确保梯度的流动。
            self.enable_input_require_grads()
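
Typical fine-tuning usage; with recent PyTorch the non-reentrant checkpoint variant is usually preferred (checkpoint name is a placeholder):

```
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
print(model.is_gradient_checkpointing)  # True

model.gradient_checkpointing_disable()
print(model.is_gradient_checkpointing)  # False
```
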
    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
        is_gradient_checkpointing_set = False

        # Apply gradient checkpointing setting to the top-level module if supported,
        # such as LongT5Stack inheriting from `PreTrainedModel`.
        if hasattr(self, "gradient_checkpointing"):
            # Set the checkpointing function for the top-level module
            self._gradient_checkpointing_func = gradient_checkpointing_func
            # Enable or disable gradient checkpointing
            self.gradient_checkpointing = enable
            is_gradient_checkpointing_set = True

        # Apply gradient checkpointing setting to all modules recursively
        for module in self.modules():
            if hasattr(module, "gradient_checkpointing"):
                # Set the checkpointing function for the current module
                module._gradient_checkpointing_func = gradient_checkpointing_func
                # Enable or disable gradient checkpointing for the current module
                module.gradient_checkpointing = enable
                is_gradient_checkpointing_set = True

        # If no module supports gradient checkpointing, raise an error
        if not is_gradient_checkpointing_set:
            raise ValueError(
                f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
                " `gradient_checkpointing` to modules of the model that uses checkpointing."
            )

    def gradient_checkpointing_disable(self):
        """
        Deactivates gradient checkpointing for the current model.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".
        """
        # Check if gradient checkpointing is supported
        if self.supports_gradient_checkpointing:
            # For older format (transformers < 4.35.0) or models on the Hub,
            # fall back to the deprecated `_set_gradient_checkpointing` method
            _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
            if not _is_using_old_format:
                # Disable gradient checkpointing using the modern method
                self._set_gradient_checkpointing(enable=False)
            else:
                # Warn about using deprecated checkpointing format
                logger.warn(
                    "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
                    "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
                )
                # Apply partial method to disable gradient checkpointing
                self.apply(partial(self._set_gradient_checkpointing, value=False))

        # Disable input require gradients if a PEFT adapter config was loaded
        if getattr(self, "_hf_peft_config_loaded", False):
            self.disable_input_require_grads()

    @property
    def is_gradient_checkpointing(self) -> bool:
        """
        Whether gradient checkpointing is activated for this model or not.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".
        """
        # Check if any module in the model has gradient checkpointing enabled
        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        state_dict: Optional[dict] = None,
        save_function: Callable = torch.save,
        push_to_hub: bool = False,
        max_shard_size: Union[int, str] = "5GB",
        safe_serialization: bool = True,
        variant: Optional[str] = None,
        token: Optional[Union[str, bool]] = None,
        save_peft_format: bool = True,
        **kwargs,
    ):
        """
        Save the model to the specified directory.

        Arguments:
            save_directory (`Union[str, os.PathLike]`):
                Directory where the model should be saved.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Flag indicating if the current process is the main one.
            state_dict (`Optional[dict]`, *optional*):
                Optional dictionary containing the state of the model.
            save_function (`Callable`, *optional*):
                Function used for saving the model (default is `torch.save`).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether to push the saved model to a model hub (if supported).
            max_shard_size (`Union[int, str]`, *optional*, defaults to `"5GB"`):
                Maximum size of each shard when saving large models.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to ensure safe serialization of the model.
            variant (`Optional[str]`, *optional*):
                Variant of the model being saved (if applicable).
            token (`Optional[Union[str, bool]]`, *optional*):
                Token used for authentication or authorization.
            save_peft_format (`bool`, *optional*, defaults to `True`):
                Whether to save the model in PEFT format.
            **kwargs:
                Additional keyword arguments for customizing the saving process.
        """
    @wraps(PushToHubMixin.push_to_hub)
    def push_to_hub(self, *args, **kwargs):
        """
        Push the model to a model hub with specified tags.

        Arguments:
            *args:
                Positional arguments for the push operation.
            **kwargs:
                Keyword arguments for customizing the push operation.

        Returns:
            Result of the super class's `push_to_hub` method.
        """
        tags = self.model_tags if self.model_tags is not None else []

        tags_kwargs = kwargs.get("tags", [])
        if isinstance(tags_kwargs, str):
            tags_kwargs = [tags_kwargs]

        for tag in tags_kwargs:
            if tag not in tags:
                tags.append(tag)

        if tags:
            kwargs["tags"] = tags
        return super().push_to_hub(*args, **kwargs)

    def get_memory_footprint(self, return_buffers=True):
        """
        Get the memory footprint of the model.

        Arguments:
            return_buffers (`bool`, *optional*, defaults to `True`):
                Whether to include buffer tensors in the memory footprint calculation.

        Returns:
            Memory footprint of the model in bytes.
        """
        mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
        if return_buffers:
            mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
            mem = mem + mem_bufs
        return mem

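The returned byte count is simply Σ numel × element_size over parameters (plus buffers), so it can be compared directly against, e.g., half-precision variants (placeholder checkpoint):

```
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")  # placeholder checkpoint

total = model.get_memory_footprint()                     # parameters + buffers, in bytes
params_only = model.get_memory_footprint(return_buffers=False)

print(f"{total / 1024**2:.1f} MB")
print(params_only <= total)  # True
```
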
    @wraps(torch.nn.Module.cuda)
    def cuda(self, *args, **kwargs):
        """
        Move the model to CUDA device, if not quantized.

        Arguments:
            *args:
                Positional arguments for the CUDA operation.
            **kwargs:
                Keyword arguments for customizing the CUDA operation.

        Returns:
            Result of the super class's `cuda` method.

        Raises:
            ValueError: If the model is 4-bit or 8-bit quantized.
        """
        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
            raise ValueError(
                "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. "
                "Please use the model as it is, since the model has already been set to the "
                "correct devices and casted to the correct `dtype`."
            )
        else:
            return super().cuda(*args, **kwargs)

    @wraps(torch.nn.Module.to)
    def to(self, *args, **kwargs):
        ...  # 方法体省略:与 `cuda()` 类似,对量化模型限制设备/dtype 转换
    # 定义一个类方法,用于从预训练模型加载模型实例
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ):
        ...  # 方法体省略:完整的 from_pretrained 实现较长

    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict,
        loaded_keys,
        resolved_archive_file,
        pretrained_model_name_or_path,
        ignore_mismatched_sizes=False,
        sharded_metadata=None,
        _fast_init=True,
        low_cpu_mem_usage=False,
        device_map=None,
        offload_folder=None,
        offload_state_dict=None,
        dtype=None,
        hf_quantizer=None,
        keep_in_fp32_modules=None,
    ):
        """
        Load a pretrained model using the provided state_dict and configuration.

        Args:
            model: The model to load the pretrained weights into.
            state_dict: The pretrained weights as a state dictionary.
            loaded_keys: Keys of the loaded state_dict.
            resolved_archive_file: Path to the resolved archive file.
            pretrained_model_name_or_path: Name or path of the pretrained model.
            ignore_mismatched_sizes: If True, ignore mismatched tensor sizes.
            sharded_metadata: Metadata related to sharding.
            _fast_init: Whether to perform fast initialization.
            low_cpu_mem_usage: If True, use low CPU memory mode.
            device_map: Mapping of devices.
            offload_folder: Folder for offloading.
            offload_state_dict: State dictionary for offloading.
            dtype: Data type of the model weights.
            hf_quantizer: Quantizer for Hugging Face models.
            keep_in_fp32_modules: Modules to keep in FP32 format.

        Returns:
            A list of error messages produced while loading the state dict (empty when loading succeeded).
        """

        # Implementation of pretrained model loading logic
        _move_model_to_meta(model, loaded_keys, "")  # Move model to meta device

        # Load state_dict from resolved archive file
        state_dict = load_state_dict(resolved_archive_file)

        # Placeholder for expected keys handling
        expected_keys = loaded_keys  # TODO: Replace with proper expected keys handling

        # Load state_dict into meta model and retrieve error messages if any
        error_msgs = _load_state_dict_into_meta_model(
            model,
            state_dict,
            loaded_keys,
            "",
            expected_keys=expected_keys,
            hf_quantizer=hf_quantizer,
        )

        return error_msgs

    def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
        """
        Retrieve modules from the model based on provided module names.

        Args:
            names: List of module names to retrieve.
            add_prefix: Whether to add a prefix to retrieved module names.
            remove_prefix: Whether to remove a prefix from retrieved module names.

        Returns:
            List: Retrieved modules based on the provided names.
        """

        # Create a set of module keys from the provided names
        module_keys = {".".join(key.split(".")[:-1]) for key in names}

        # Special case handling for torch.nn.ParameterList
        module_keys = module_keys.union(
            {".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()}
        )

        retrieved_modules = []

        # Retrieve modules that match the module keys
        for name, module in self.named_modules():
            if remove_prefix:
                _prefix = f"{self.base_model_prefix}."
                name = name[len(_prefix) :] if name.startswith(_prefix) else name
            elif add_prefix:
                name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix

            if name in module_keys:
                retrieved_modules.append(module)

        return retrieved_modules

    @staticmethod
    def _load_pretrained_model_low_mem(
        model, loaded_state_dict_keys, resolved_archive_file, start_prefix="", hf_quantizer=None
    ):
        """
        This is an experimental function that loads the model using ~1.x model size CPU memory

        Before you call it do:

        1. save which state_dict keys are available
        2. drop state_dict before model is created, since the latter takes 1x model size memory

        Here then we continue:

        3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict
        4. load state_dict 2nd time
        5. replace the params/buffers from the state_dict

        Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed. To
        handle bitsandbytes, needs non-empty hf_quantizer argument.
        """
        _move_model_to_meta(model, loaded_state_dict_keys, start_prefix)  # Move model to meta device
        state_dict = load_state_dict(resolved_archive_file)  # Load state_dict from archive file
        expected_keys = loaded_state_dict_keys  # Placeholder for expected keys
        error_msgs = _load_state_dict_into_meta_model(
            model,
            state_dict,
            loaded_state_dict_keys,
            start_prefix,
            expected_keys=expected_keys,
            hf_quantizer=hf_quantizer,
        )
        return error_msgs
    # 注册自定义模型类到指定的自动模型类中
    @classmethod
    def register_for_auto_class(cls, auto_class="AutoModel"):
        """
        Register this class with a given auto class. This should only be used for custom models as the ones in the
        library are already mapped with an auto class.

        <Tip warning={true}>

        This API is experimental and may have some slight breaking changes in the next releases.

        </Tip>

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`):
                The auto class to register this new model with.
        """
        # 如果 `auto_class` 不是字符串,则将其转换为类名字符串
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        # 导入自动模型模块
        import transformers.models.auto as auto_module

        # 检查是否存在给定名称的自动模型类
        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        # 将自动模型类名赋值给当前类的 `_auto_class` 属性
        cls._auto_class = auto_class
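
For a custom architecture this is usually combined with registering the config and model with the auto classes; a hedged sketch where `MyConfig`/`MyModel` are purely illustrative names:

```
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

class MyConfig(PretrainedConfig):
    model_type = "my-model"

class MyModel(PreTrainedModel):
    config_class = MyConfig

    def __init__(self, config):
        super().__init__(config)
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

# Record that this custom class should be loaded through AutoModel
# (written into the config when saving / pushing custom code).
MyModel.register_for_auto_class("AutoModel")

# Make the pair resolvable locally via the auto classes as well.
AutoConfig.register("my-model", MyConfig)
AutoModel.register(MyConfig, MyModel)
```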

    # 将模型转换为 BetterTransformer
    def to_bettertransformer(self) -> "PreTrainedModel":
        """
        Converts the model to use [PyTorch's native attention
        implementation](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html), integrated to
        Transformers through [Optimum library](https://huggingface.co/docs/optimum/bettertransformer/overview). Only a
        subset of all Transformers models are supported.

        PyTorch's attention fastpath allows to speed up inference through kernel fusions and the use of [nested
        tensors](https://pytorch.org/docs/stable/nested.html). Detailed benchmarks can be found in [this blog
        post](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2).

        Returns:
            [`PreTrainedModel`]: The model converted to BetterTransformer.
        """
        # 检查是否安装了 Optimum 库,如果没有则抛出 ImportError
        if not is_optimum_available():
            raise ImportError("The package `optimum` is required to use Better Transformer.")

        # 导入 Optimum 库的版本信息
        from optimum.version import __version__ as optimum_version

        # 检查 Optimum 库的版本是否满足要求,如果不满足则抛出 ImportError
        if version.parse(optimum_version) < version.parse("1.7.0"):
            raise ImportError(
                f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found."
            )

        # 导入 BetterTransformer 类
        from optimum.bettertransformer import BetterTransformer

        # 使用 BetterTransformer 类将当前模型转换为 BetterTransformer
        return BetterTransformer.transform(self)
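
Round-tripping through BetterTransformer requires `optimum>=1.7.0`; a sketch with a placeholder checkpoint:

```
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")  # placeholder checkpoint

model = model.to_bettertransformer()        # swap in the PyTorch fastpath attention
# ... run inference ...
model = model.reverse_bettertransformer()   # restore the original modeling code
model.save_pretrained("saved-bert")         # safe to save only after reversing
```
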
    def reverse_bettertransformer(self):
        """
        Reverts the transformation from [`~PreTrainedModel.to_bettertransformer`] so that the original modeling is
        used, for example in order to save the model.

        Returns:
            [`PreTrainedModel`]: The model converted back to the original modeling.
        """
        # 检查是否已安装 optimum 包,否则抛出 ImportError
        if not is_optimum_available():
            raise ImportError("The package `optimum` is required to use Better Transformer.")

        # 导入 optimum 版本信息,并检查是否符合最低要求版本 1.7.0
        from optimum.version import __version__ as optimum_version

        if version.parse(optimum_version) < version.parse("1.7.0"):
            raise ImportError(
                f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found."
            )

        # 导入 BetterTransformer 类并调用其 reverse 方法,将模型转换回原始建模
        from optimum.bettertransformer import BetterTransformer

        return BetterTransformer.reverse(self)

    def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):
        """
        Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given.
        """

        # 在 TorchFX 代理或 Torch 脚本跟踪时跳过检查
        if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling():
            return

        # 如果 attention_mask 不为 None 或者模型配置中 pad_token_id 为 None,则跳过警告
        if (attention_mask is not None) or (self.config.pad_token_id is None):
            return

        # 仅检查输入中的第一个和最后一个 token 是否包含 pad_token_id,以减少开销
        if self.config.pad_token_id in input_ids[:, [-1, 0]]:
            warn_string = (
                "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See "
                "https://huggingface.co/docs/transformers/troubleshooting"
                "#incorrect-output-when-padding-tokens-arent-masked."
            )

            # 如果 pad_token_id 等于 BOS、EOS 或 SEP 中的任何一个,显示额外警告信息
            if (
                (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
                or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
                or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id)
            ):
                warn_string += (
                    f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
                    f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
                    f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded."
                )

            # 发出一次性的警告,用 logger 记录
            logger.warning_once(warn_string)
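
The warning can be avoided by always passing an explicit `attention_mask`; the tokenizer already returns one, and it can also be rebuilt from `input_ids` (a sketch, placeholder checkpoint):

```
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no dedicated pad token

enc = tokenizer(["short", "a slightly longer input"], padding=True, return_tensors="pt")
print(enc["attention_mask"])  # 0 marks the padded positions

# Manual reconstruction from input_ids; note that when pad_token_id == eos_token_id
# this heuristic would also mask genuine EOS tokens, which is exactly the ambiguity
# the extra warning above points out.
manual_mask = (enc["input_ids"] != tokenizer.pad_token_id).long()
print(torch.equal(manual_mask, enc["attention_mask"]))  # True for this batch
```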

    @property
    def _is_quantized_training_enabled(self):
        # 发出警告:`_is_quantized_training_enabled` 将在 transformers 4.39.0 中弃用,建议改用 `model.hf_quantizer.is_trainable`
        warnings.warn(
            "`_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead",
            FutureWarning,
        )

        # 如果当前对象没有 `hf_quantizer` 属性,则返回 False
        if not hasattr(self, "hf_quantizer"):
            return False

        # 返回 `hf_quantizer.is_trainable` 的值
        return self.hf_quantizer.is_trainable
# 将 PreTrainedModel 类的 push_to_hub 方法复制一份,赋值给自身,以备后续修改
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)

# 如果 push_to_hub 方法有文档字符串,则格式化文档字符串,插入模型、AutoModel 和模型文件的相关信息
if PreTrainedModel.push_to_hub.__doc__ is not None:
    PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format(
        object="model", object_class="AutoModel", object_files="model file"
    )

# 定义一个计算 SQuAD 起始位置 logit 的神经网络模块
class PoolerStartLogits(nn.Module):
    """
    Compute SQuAD start logits from sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        # 使用全连接层将隐藏状态映射到一个数值
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(
        self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                The final hidden states of the model.
            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
                should be masked.

        Returns:
            `torch.FloatTensor`: The start logits for SQuAD.
        """
        # 使用全连接层计算起始位置的 logit,并将结果压缩维度
        x = self.dense(hidden_states).squeeze(-1)

        if p_mask is not None:
            # 根据模型参数的数据类型,对无效位置的 logit 进行处理,使用不同的填充值
            if get_parameter_dtype(self) == torch.float16:
                x = x * (1 - p_mask) - 65500 * p_mask
            else:
                x = x * (1 - p_mask) - 1e30 * p_mask

        return x
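
The masking trick above (multiply by `1 - p_mask`, then subtract a large constant) can be exercised with dummy tensors; a sketch assuming a minimal config that only provides `hidden_size`:

```
import torch
from transformers import PretrainedConfig
from transformers.modeling_utils import PoolerStartLogits

config = PretrainedConfig(hidden_size=8)   # minimal config for illustration
pooler = PoolerStartLogits(config)

hidden_states = torch.randn(2, 5, 8)                 # (batch, seq_len, hidden)
p_mask = torch.tensor([[0., 0., 0., 1., 1.],
                       [0., 0., 1., 1., 1.]])        # 1.0 = position to mask

logits = pooler(hidden_states, p_mask=p_mask)
print(logits.shape)           # torch.Size([2, 5])
print(logits[0, 3] < -1e29)   # tensor(True): masked positions sit near -1e30
```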


# 定义一个计算 SQuAD 结束位置 logit 的神经网络模块
class PoolerEndLogits(nn.Module):
    """
    Compute SQuAD end logits from sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
            to use.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        # 第一个全连接层将两倍的隐藏状态映射到隐藏大小
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        # 激活函数为双曲正切函数
        self.activation = nn.Tanh()
        # 使用 LayerNorm 对隐藏大小进行归一化
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 第二个全连接层将隐藏状态映射到一个数值
        self.dense_1 = nn.Linear(config.hidden_size, 1)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        p_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                模型的最终隐藏状态。
            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                标记范围内第一个标记的隐藏状态。
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                标记范围内第一个标记的位置。
            p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                用于无效位置的掩码,如查询和特殊符号(PAD、SEP、CLS)。1.0 表示该标记应被屏蔽。

        <Tip>

        `start_states` 或 `start_positions` 中的一个必须不为 `None`。如果两者都设置了,`start_positions` 会覆盖 `start_states`。

        </Tip>

        Returns:
            `torch.FloatTensor`: SQuAD 任务的结束位置logits。
        """
        assert (
            start_states is not None or start_positions is not None
        ), "One of start_states, start_positions should be not None"
        if start_positions is not None:
            slen, hsz = hidden_states.shape[-2:]
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)

        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
        x = self.activation(x)
        x = self.LayerNorm(x)
        x = self.dense_1(x).squeeze(-1)

        if p_mask is not None:
            if get_parameter_dtype(self) == torch.float16:
                x = x * (1 - p_mask) - 65500 * p_mask
            else:
                x = x * (1 - p_mask) - 1e30 * p_mask

        return x
class PoolerAnswerClass(nn.Module):
    """
    Compute SQuAD 2.0 answer class from classification and start tokens hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model.
    """

    def __init__(self, config):
        super().__init__()
        # Initialize a linear layer that maps concatenated hidden states to the hidden size
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        # Activation function for the dense layer
        self.activation = nn.Tanh()
        # Final linear layer to compute logits for SQuAD answer class
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        cls_index: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                The final hidden states of the model.
            start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                The hidden states of the first tokens for the labeled span.
            start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                The position of the first token for the labeled span.
            cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Position of the CLS token for each sentence in the batch. If `None`, takes the last token.

        <Tip>

        One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
        `start_states`.

        </Tip>

        Returns:
            `torch.FloatTensor`: The SQuAD 2.0 answer class.
        """
        # Ensure the hidden state size is retrieved correctly
        hsz = hidden_states.shape[-1]
        # Ensure that either start_states or start_positions is provided
        assert (
            start_states is not None or start_positions is not None
        ), "One of start_states, start_positions should be not None"

        # If start_positions is provided, derive start_states from hidden_states
        if start_positions is not None:
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)

        # If cls_index is provided, derive cls_token_state from hidden_states
        if cls_index is not None:
            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
        else:
            # Otherwise, take the last token's hidden state as cls_token_state
            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)

        # Concatenate start_states and cls_token_state, apply dense layers and activation
        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
        x = self.activation(x)
        # Apply final linear layer and squeeze to get SQuAD answer class logits
        x = self.dense_1(x).squeeze(-1)

        return x


@dataclass
class SquadHeadOutput(ModelOutput):
    """
    Base class for outputs of question answering models using a [`~modeling_utils.SQuADHead`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification
            losses.
        start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
        start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Indices for the top config.start_n_top start token possibilities (beam-search).
        end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
            (beam-search).
        end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
        cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Log probabilities for the `is_impossible` label of the answers.

    """

    # Optional fields storing the different head outputs; which ones are populated depends on
    # whether `start_positions` / `end_positions` were provided
    loss: Optional[torch.FloatTensor] = None
    start_top_log_probs: Optional[torch.FloatTensor] = None
    start_top_index: Optional[torch.LongTensor] = None
    end_top_log_probs: Optional[torch.FloatTensor] = None
    end_top_index: Optional[torch.LongTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None


class SQuADHead(nn.Module):
    r"""
    A SQuAD head inspired by XLNet.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
            to use.
    """

    def __init__(self, config):
        super().__init__()
        # Top-k values for start and end position candidates used at inference time
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        # Pooler computing the start-position logits
        self.start_logits = PoolerStartLogits(config)
        # Pooler computing the end-position logits
        self.end_logits = PoolerEndLogits(config)
        # Pooler computing the answer-class (is_impossible) logits
        self.answer_class = PoolerAnswerClass(config)

    @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig)
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        cls_index: Optional[torch.LongTensor] = None,
        is_impossible: Optional[torch.LongTensor] = None,
        p_mask: Optional[torch.FloatTensor] = None,
        return_dict: bool = False,
    ):
        """
        Perform forward pass of the SQuAD head module.

        Args:
            hidden_states (torch.FloatTensor): Sequence of hidden states.
            start_positions (Optional[torch.LongTensor]): Tensor of start positions for the answer spans.
            end_positions (Optional[torch.LongTensor]): Tensor of end positions for the answer spans.
            cls_index (Optional[torch.LongTensor]): Index of the classification token if used.
            is_impossible (Optional[torch.LongTensor]): Tensor indicating if the question is unanswerable.
            p_mask (Optional[torch.FloatTensor]): Mask of tokens that cannot be part of the answer (e.g. query tokens and special symbols).
            return_dict (bool): Whether to return a dictionary.

        Returns:
            SquadHeadOutput: Output of the SQuAD head module.
        """
        # The actual forward logic (start/end logits, loss computation, and beam search over
        # top-k spans) is omitted in this listing; a sketch is given below.
        pass
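For reference, here is a hedged reconstruction of what the training branch of this forward pass computes (the inference branch additionally runs a beam search over the top `start_n_top * end_n_top` spans). It reuses the poolers defined above, but it is an illustrative sketch rather than the verbatim library implementation:

        # Sketch of the training branch: start_positions and end_positions are provided.
        start_logits = self.start_logits(hidden_states, p_mask=p_mask)
        end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

        loss_fct = nn.CrossEntropyLoss()
        total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2

        # Optional answerability loss when cls_index and is_impossible are given
        if cls_index is not None and is_impossible is not None:
            cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
            # cast the labels to float for the binary cross-entropy loss
            total_loss += nn.BCEWithLogitsLoss()(cls_logits, is_impossible.float()) * 0.5

        return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)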


class SequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
    """
    # Constructor taking a pretrained config object
    def __init__(self, config: PretrainedConfig):
        super().__init__()

        # Summary type from the config, defaulting to "last" when unspecified
        self.summary_type = getattr(config, "summary_type", "last")

        # "attn" summaries are not implemented; use a standard multi-head attention module instead
        if self.summary_type == "attn":
            raise NotImplementedError

        # The summary projection defaults to Identity (a no-op in the forward pass)
        self.summary = Identity()

        # If the config requests a projection after extracting the summary vector
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            # Project to the number of labels when requested and available
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            # Otherwise project back to the hidden size
            else:
                num_classes = config.hidden_size
            # Linear layer mapping the hidden state to num_classes dimensions
            self.summary = nn.Linear(config.hidden_size, num_classes)

        # Resolve the activation from its name in the config, falling back to Identity
        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else Identity()

        # Dropout applied before the projection; Identity unless a positive probability is configured
        self.first_dropout = Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        # Dropout applied after the projection; Identity unless a positive probability is configured
        self.last_dropout = Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)
    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        """
        # Summarize according to the configured summary type
        if self.summary_type == "last":
            # Take the last hidden state of each sequence
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            # Take the first hidden state of each sequence
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            # Average the hidden states over the whole sequence
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                # Without an explicit cls_index, default to the last token of each sequence
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                # Expand cls_index so it can be used to gather along the sequence dimension
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # Gather the hidden states at the classification-token positions
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            # Attention-based summaries are not implemented
            raise NotImplementedError

        # Dropout before the projection
        output = self.first_dropout(output)
        # Projection (Linear or Identity)
        output = self.summary(output)
        # Activation function
        output = self.activation(output)
        # Dropout after the projection
        output = self.last_dropout(output)

        return output
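To see how the configuration flags listed in the class docstring interact, here is a hedged usage sketch (the `SimpleNamespace` config is hypothetical and only sets the attributes read in `__init__`):

from types import SimpleNamespace

import torch

config = SimpleNamespace(
    summary_type="cls_index",      # gather the hidden state at cls_index
    summary_use_proj=True,
    summary_proj_to_labels=True,
    num_labels=3,                  # project the summary to 3 classes
    hidden_size=8,
    summary_activation="tanh",
    summary_first_dropout=0.1,
    summary_last_dropout=0.0,
)
summary = SequenceSummary(config)

hidden_states = torch.randn(2, 5, 8)   # (batch, seq_len, hidden)
cls_index = torch.tensor([4, 2])       # classification-token position per sequence
print(summary(hidden_states, cls_index).shape)  # torch.Size([2, 3])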
# Recursively unwrap a model from potential containers, e.g. wrappers used in distributed training.
def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Recursively unwraps a model from potential containers (as used in distributed training).

    Args:
        model (`torch.nn.Module`): The model to unwrap.
    """
    # If the model has a `module` attribute it is wrapped, so recurse into that attribute
    if hasattr(model, "module"):
        return unwrap_model(model.module)
    else:
        return model
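A small hedged example of the unwrapping behaviour: `nn.DataParallel` exposes the wrapped model through its `module` attribute, so the recursion bottoms out at the original module even when wrappers are nested.

import torch.nn as nn

model = nn.Linear(4, 4)
wrapped = nn.DataParallel(nn.DataParallel(model))  # nested wrappers, each with a `.module`
assert unwrap_model(wrapped) is model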


# Expand a device map into a mapping from parameter name to device.
def expand_device_map(device_map, param_names, start_prefix):
    """
    Expand a device map to return the correspondence parameter name to device.
    """
    # New expanded device map
    new_device_map = {}
    # Keep only parameter names starting with the given prefix, and strip that prefix
    param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)]
    # For each module in the device map, assign its device to all matching parameters
    for module, device in device_map.items():
        new_device_map.update(
            # A parameter matches if it equals the module name, is nested under it, or the module name is empty
            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
        )
    return new_device_map
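A hypothetical call, for illustration, showing how module-level devices fan out to individual parameter names once the `start_prefix` is stripped:

device_map = {"encoder": 0, "lm_head": "disk"}
param_names = ["model.encoder.layer.0.weight", "model.encoder.layer.0.bias", "model.lm_head.weight"]

expanded = expand_device_map(device_map, param_names, start_prefix="model.")
# expanded == {"encoder.layer.0.weight": 0, "encoder.layer.0.bias": 0, "lm_head.weight": "disk"}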


# Return the list of shard files containing only weights offloaded to disk.
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
    """
    Returns the list of shard files containing only weights offloaded to disk.
    """
    # Weight names (with the prefix stripped) mapped to the shard file that stores them
    weight_map = {
        p[len(start_prefix) :]: v for p, v in sharded_metadata["weight_map"].items() if p.startswith(start_prefix)
    }
    # For every shard file, collect the devices of the weights it contains
    files_content = collections.defaultdict(list)
    for weight_name, filename in weight_map.items():
        # Walk up the module hierarchy until the name appears in the device map
        while len(weight_name) > 0 and weight_name not in device_map:
            weight_name = ".".join(weight_name.split(".")[:-1])
        files_content[filename].append(device_map[weight_name])

    # Keep only the files whose weights all live on disk
    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
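A matching hypothetical example for the disk-only shard detection; the shard file names and weight map below are made up for illustration:

device_map = {"encoder": 0, "lm_head": "disk"}
sharded_metadata = {
    "weight_map": {
        "model.encoder.layer.0.weight": "pytorch_model-00001-of-00002.bin",
        "model.lm_head.weight": "pytorch_model-00002-of-00002.bin",
    }
}

print(get_disk_only_shard_files(device_map, sharded_metadata, start_prefix="model."))
# ['pytorch_model-00002-of-00002.bin']
# the only shard whose weights are all offloaded to disk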