Transformers Source Code Analysis - 75 -

Transformers Source Code Analysis (75)

.\models\mistral\modeling_mistral.py

# File encoding: UTF-8
# Copyright and license notice, based on code from the Mistral AI and HuggingFace Inc. teams
# This code is based on EleutherAI's GPT-NeoX library and on the GPT-NeoX and OPT implementations, modified to accommodate the architectural differences relative to the models trained by the Meta AI team
# Licensed under the Apache License, Version 2.0; you may not use this file except in compliance with the License

""" PyTorch Mistral model. """
# Import Python standard-library and third-party modules
import inspect
import math
import warnings
from typing import List, Optional, Tuple, Union

# Import PyTorch-related libraries
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# Import local (relative) modules
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_mistral import MistralConfig

# Check whether Flash Attention 2 is available
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)

# Get the module-level logger
logger = logging.get_logger(__name__)

# Configuration class name used in docstrings
_CONFIG_FOR_DOC = "MistralConfig"

# Copied from transformers.models.llama.modeling_llama._get_unpad_data
# Computes the indices, cumulative sequence lengths and maximum sequence length of the non-padded tokens from the attention mask
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
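
As a quick illustration of what `_get_unpad_data` computes (a hedged sketch, not part of the original file), consider a batch of two sequences padded to length 4:

```
# Illustrative values only; mirrors the three steps of _get_unpad_data above.
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)   # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten()).flatten()        # tensor([0, 1, 2, 4, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])
max_seqlen_in_batch = seqlens_in_batch.max().item()                # 3
```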

# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
# MistralRMSNorm, an nn.Module implementing RMS normalization for the Mistral model
class MistralRMSNorm(nn.Module):
    # Constructor; defines a custom normalization layer, MistralRMSNorm, functionally equivalent to T5's LayerNorm
    def __init__(self, hidden_size, eps=1e-6):
        """
        MistralRMSNorm is equivalent to T5LayerNorm
        """
        # Call the parent class constructor
        super().__init__()
        # Learnable scale parameter
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Small constant added to the variance for numerical stability
        self.variance_epsilon = eps

    # Forward pass: returns the normalized hidden states
    def forward(self, hidden_states):
        # Remember the input dtype
        input_dtype = hidden_states.dtype
        # Cast the hidden states to float32
        hidden_states = hidden_states.to(torch.float32)
        # Compute the variance (mean of squares) over the last dimension
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # Normalize by the reciprocal square root of (variance + epsilon) for numerical stability
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Return the normalized hidden states scaled by the learnable weight, cast back to the input dtype
        return self.weight * hidden_states.to(input_dtype)
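
A small sketch of the computation above (illustrative only, with made-up shapes): RMSNorm rescales each feature vector by the root mean square of its entries, with no mean subtraction and no bias term, unlike standard LayerNorm.

```
import torch

x = torch.randn(2, 5, 8)                     # (batch, seq_len, hidden_size)
weight = torch.ones(8)                       # stands in for the learnable scale of MistralRMSNorm
eps = 1e-6

variance = x.pow(2).mean(-1, keepdim=True)   # mean of squares over the hidden dimension
y = weight * (x * torch.rsqrt(variance + eps))
print(y.shape)                               # torch.Size([2, 5, 8])
```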
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
# TODO @Arthur no longer copied from LLama once static cache is in place
class MistralRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Compute the inverse frequencies used for the sine/cosine terms
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build the cos/sin cache here so that `torch.jit.trace` works
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        # Compute the sine and cosine caches
        freqs = torch.outer(t, self.inv_freq)
        # Different from the paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
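
A hedged usage sketch of the class above (the shapes are made up; only `seq_len` and the dtype/device of `x` matter to `forward`):

```
import torch

rope = MistralRotaryEmbedding(dim=64, max_position_embeddings=2048, base=10000)
x = torch.randn(1, 8, 16, 64)      # (batch, num_heads, seq_len, head_dim)
cos, sin = rope(x, seq_len=16)
print(cos.shape, sin.shape)        # torch.Size([16, 64]) torch.Size([16, 64])
```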


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# TODO @Arthur no longer copied from LLama once static cache is in place
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.
    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # Unsqueezing cos and sin tensors along the specified dimension to match q and k tensor shapes
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    # Applying rotary position embedding to q and k tensors
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
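
A shape sketch for `apply_rotary_pos_emb` (illustrative values only): `cos[position_ids]` and `sin[position_ids]` have shape `[batch, seq_len, head_dim]`, and `unsqueeze_dim=1` lets them broadcast against `[batch, heads, seq_len, head_dim]` queries and keys, so the output shapes are unchanged.

```
import torch

q = torch.randn(1, 8, 16, 64)                 # (batch, num_query_heads, seq_len, head_dim)
k = torch.randn(1, 2, 16, 64)                 # fewer key/value heads (grouped-query attention)
cos = torch.randn(16, 64)                     # stands in for the rotary embedding's cos cache
sin = torch.randn(16, 64)
position_ids = torch.arange(16).unsqueeze(0)  # (batch, seq_len)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)               # (1, 8, 16, 64) and (1, 2, 16, 64), unchanged
```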
class MistralMLP(nn.Module):
    # MistralMLP defines the feed-forward (gated MLP) block of the model
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size  # Hidden size taken from the config
        self.intermediate_size = config.intermediate_size  # Intermediate size taken from the config
        # Gate projection: a bias-free linear layer mapping hidden_size -> intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # Up projection: a bias-free linear layer mapping hidden_size -> intermediate_size
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # Down projection: a bias-free linear layer mapping intermediate_size -> hidden_size
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Activation function selected via the config's hidden_act
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # Forward pass: gate projection, activation, elementwise product with the up projection, then down projection
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
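
The forward pass above is the gated ("SwiGLU"-style) MLP: `down_proj(act(gate_proj(x)) * up_proj(x))`. A standalone sketch with made-up sizes, assuming `hidden_act="silu"`:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 64
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 5, hidden_size)
out = down_proj(F.silu(gate_proj(x)) * up_proj(x))   # same formula as MistralMLP.forward
print(out.shape)                                     # torch.Size([2, 5, 16])
```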


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    # Repeat hidden_states n_rep times along dimension 1
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Expand hidden_states with an extra dimension before reshaping
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
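
A shape-only sketch of `repeat_kv` under grouped-query attention (illustrative sizes): with 8 query heads and 2 key/value heads, each key/value head is repeated `n_rep = 4` times.

```
import torch

kv = torch.randn(1, 2, 16, 64)      # (batch, num_key_value_heads, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=4)
print(expanded.shape)               # torch.Size([1, 8, 16, 64])
```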


class MistralAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """
    # MistralAttention implements multi-head attention from 'Attention Is All You Need', with support for sliding-window attention
    # Constructor: creates a new Mistral attention layer
    def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
        # Call the parent class constructor
        super().__init__()
        # Store the configuration object
        self.config = config
        # Store the layer index
        self.layer_idx = layer_idx
        # If no layer index was passed, warn that caching may fail during the forward call
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        # Hidden size from the config
        self.hidden_size = config.hidden_size
        # Number of attention heads from the config
        self.num_heads = config.num_attention_heads
        # Per-head dimension
        self.head_dim = self.hidden_size // self.num_heads
        # Number of key/value heads from the config
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads per key/value head (grouped-query attention)
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # Maximum number of position embeddings from the config
        self.max_position_embeddings = config.max_position_embeddings
        # RoPE theta (rotary base) from the config
        self.rope_theta = config.rope_theta
        # This attention is causal
        self.is_causal = True
        # Attention dropout rate from the config
        self.attention_dropout = config.attention_dropout

        # The hidden size must be divisible by the number of attention heads
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Query projection: maps hidden states to the query-head space
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        # Key projection: maps hidden states to the key/value-head space
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        # Value projection: maps hidden states to the key/value-head space
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        # Output projection: maps the attention output back to the hidden-state space
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Rotary embedding used to inject positional information
        self.rotary_emb = MistralRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    # Reshapes a tensor into (batch, num_heads, seq_len, head_dim) for the attention computation
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    # Forward pass of the Mistral attention layer
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
# MistralFlashAttention2 inherits from MistralAttention.
# It is the Mistral flash-attention module; the weights are inherited from MistralAttention unchanged.
class MistralFlashAttention2(MistralAttention):
    """
    Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    # Constructor: forwards all arguments to the parent class constructor
    def __init__(self, *args, **kwargs):
        # Call the parent class constructor
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment,
        # which is the default for flash_attn>=2.1. This attribute is used to handle that difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for q_seqlen == 1) produces a wrong (top-left) mask.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    # Forward pass
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        # This method defines the module's forward pass. hidden_states is the required tensor argument;
        # attention_mask, position_ids, past_key_value, output_attentions and use_cache are optional.
        # **kwargs accepts any additional keyword arguments.

    # Definition of the private helper _flash_attention_forward
    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

        # If the key/value sequence length does not match the attention-mask length, slice the attention mask accordingly
        if kv_seq_len != attention_mask.shape[-1]:
            attention_mask_num_tokens = attention_mask.shape[-1]
            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]

        # Get the indices of the non-padded tokens and the associated sequence-length information
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

        # Re-index the key and value layers so that they align with the unpadded tokens
        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)

        if query_length == kv_seq_len:
            # If the query length equals the key/value sequence length, reuse the same indices and sequence lengths
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # If the query length is 1, handle the sequence-length information and the query layer as a special case
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # Otherwise, assume left padding: slice the attention mask and let unpad_input handle the query layer
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        # Return the updated query/key/value layers together with the indices and sequence-length information
        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
# TODO @Arthur no longer copied from LLama once static cache is in place
class MistralSdpaAttention(MistralAttention):
    """
    Mistral 注意力模块使用 torch.nn.functional.scaled_dot_product_attention。该模块继承自
    `MistralAttention`,模块的权重保持不变。唯一的改动在于前向传播部分以适应 SDPA API。
    """

    # 改编自 MistralAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        """
        前向传播方法用于执行注意力计算。

        Args:
            hidden_states (torch.Tensor): 输入的隐藏状态张量。
            attention_mask (Optional[torch.Tensor], optional): 注意力掩码张量,默认为None。
            position_ids (Optional[torch.LongTensor], optional): 位置标识符张量,默认为None。
            past_key_value (Optional[Cache], optional): 过去的键值对缓存,默认为None。
            output_attentions (bool, optional): 是否输出注意力权重,默认为False。
            use_cache (bool, optional): 是否使用缓存,默认为False。

        Returns:
            根据模块的具体实现不同,返回不同的结果。
        """
        # 实现具体的注意力计算逻辑
        # (具体实现部分可能包括 scaled_dot_product_attention 的调用或其它实现方式)

MISTRAL_ATTENTION_CLASSES = {
    "eager": MistralAttention,
    "flash_attention_2": MistralFlashAttention2,
    "sdpa": MistralSdpaAttention,
}
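
A hedged sketch of how this mapping is used (it mirrors the lookup in `MistralDecoderLayer.__init__` below): the attention implementation is selected via `config._attn_implementation`.

```
config = MistralConfig(num_hidden_layers=2)   # defaults are enough for illustration
attn_cls = MISTRAL_ATTENTION_CLASSES[config._attn_implementation]
print(attn_cls.__name__)                      # "MistralAttention" for a plain config;
                                              # model loading may switch it to "sdpa" or "flash_attention_2"
```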

class MistralDecoderLayer(nn.Module):
    def __init__(self, config: MistralConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Instantiate the self-attention module; the implementation class is selected from the config
        self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        # MLP block
        self.mlp = MistralMLP(config)

        # Pre-attention layer norm, implemented with MistralRMSNorm
        self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Post-attention layer norm, also implemented with MistralRMSNorm
        self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # If a `padding_mask` argument was passed, warn that it is deprecated and will be removed in v4.37; use `attention_mask` instead
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of shape `(batch, sequence_length)`,
                where padding elements are indicated by 0.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        # Pre-attention layer norm
        hidden_states = self.input_layernorm(hidden_states)

        # Self-attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        # Residual connection
        hidden_states = residual + hidden_states

        # Fully connected block with its own layer norm
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        # Residual connection
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
# Long docstring describing the MistralPreTrainedModel class hierarchy and how to use it
MISTRAL_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MistralConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Docstring for MistralPreTrainedModel: a model that outputs raw hidden states without any specific head on top
@add_start_docstrings(
    "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
    MISTRAL_START_DOCSTRING,
)
class MistralPreTrainedModel(PreTrainedModel):
    # Configuration class for this model
    config_class = MistralConfig
    # Prefix of the base model: "model"
    base_model_prefix = "model"
    # Gradient checkpointing is supported
    supports_gradient_checkpointing = True
    # Modules that must not be split across devices
    _no_split_modules = ["MistralDecoderLayer"]
    # Keys to skip during device placement: "past_key_values"
    _skip_keys_device_placement = "past_key_values"
    # Flash Attention 2 is supported
    _supports_flash_attn_2 = True
    # SDPA is supported
    _supports_sdpa = True
    # Cache classes are supported
    _supports_cache_class = True

    # Weight initialization, depending on the module type
    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


MISTRAL_INPUTS_DOCSTRING = r"""
"""


# Docstring for MistralModel: a Transformer decoder made up of multiple MistralDecoderLayer layers
@add_start_docstrings(
    "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
    MISTRAL_START_DOCSTRING,
)
class MistralModel(MistralPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]

    Args:
        config: MistralConfig
    """

    # Constructor; takes a MistralConfig configuration object
    def __init__(self, config: MistralConfig):
        super().__init__(config)
        # Padding index taken from the config's pad_token_id
        self.padding_idx = config.pad_token_id
        # Vocabulary size taken from the config's vocab_size
        self.vocab_size = config.vocab_size

        # Token embedding layer with the given vocabulary size, hidden size and padding index
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Stack of num_hidden_layers MistralDecoderLayer layers
        self.layers = nn.ModuleList(
            [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Attention implementation taken from the config's _attn_implementation
        self._attn_implementation = config._attn_implementation
        # Final RMS normalization layer with the given hidden size and epsilon
        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Gradient checkpointing is disabled by default
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
    # Returns the model's input embeddings
    def get_input_embeddings(self):
        return self.embed_tokens

    # Sets the model's input embeddings
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # The forward method's docstring is extended with MISTRAL_INPUTS_DOCSTRING (via add_start_docstrings_to_model_forward)
    def forward(
        self,
        input_ids: torch.LongTensor = None,  # Input token IDs, as a LongTensor
        attention_mask: Optional[torch.Tensor] = None,  # Optional attention mask tensor
        position_ids: Optional[torch.LongTensor] = None,  # Optional position IDs, as a LongTensor
        past_key_values: Optional[List[torch.FloatTensor]] = None,  # Optional list of past key/value tensors
        inputs_embeds: Optional[torch.FloatTensor] = None,  # Optional pre-computed input embeddings
        use_cache: Optional[bool] = None,  # Optional flag: whether to use the cache
        output_attentions: Optional[bool] = None,  # Optional flag: whether to return attention weights
        output_hidden_states: Optional[bool] = None,  # Optional flag: whether to return hidden states
        return_dict: Optional[bool] = None,  # Optional flag: whether to return a dict-style output
class MistralForCausalLM(MistralPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        # Build the underlying MistralModel
        self.model = MistralModel(config)
        # Vocabulary size
        self.vocab_size = config.vocab_size
        # Linear head mapping hidden states to vocabulary logits, without bias
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Return the model's input embedding layer
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        # Set the model's input embedding layer
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        # Return the output embedding layer (the LM head)
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # Set the output embedding layer (the LM head)
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        # Set the decoder
        self.model = decoder

    def get_decoder(self):
        # Get the decoder
        return self.model

    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Forward pass of the model; see the decorators above for the detailed documentation
        pass

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        # Prepares the inputs for generation: input IDs, past key/values, attention mask and input embeddings
        # Check whether past_key_values was provided and, if so, process it according to its type
        if past_key_values is not None:
            # If past_key_values is a Cache object, query its attributes
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()  # Length of the cached sequence
                past_length = past_key_values.seen_tokens  # Number of tokens already processed
                max_cache_length = past_key_values.get_max_length()  # Maximum cache length
            else:
                # Otherwise assume past_key_values is a tuple; use the third dimension of its first entry as both cache_length and past_length
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of attention_mask exceeds the length of input_ids, some inputs were passed exclusively as part of the cache
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If past_length is smaller than the length of input_ids, input_ids holds all input tokens; discard the first past_length tokens
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), assume input_ids contains only the unprocessed tokens

            # If we are about to exceed the maximum cache length, crop the input attention mask accordingly
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        # Fetch the optional position_ids; if attention_mask exists and position_ids is None, build position_ids on the fly for batched generation
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1  # A cumulative sum over the attention mask gives the position ids
            position_ids.masked_fill_(attention_mask == 0, 1)  # Positions where the attention mask is 0 are filled with 1
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]  # With a cache, keep only the positions matching input_ids

        # If inputs_embeds was passed and there is no cache yet, only use it for the first generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}  # Use inputs_embeds as the model input
        else:
            model_inputs = {"input_ids": input_ids}  # Otherwise use input_ids as the model input

        # Extend model_inputs with position_ids, past_key_values, use_cache and attention_mask
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs  # Return the final dictionary of model inputs

    @staticmethod
    # Reorders the `past_key_values` cache so that it matches the selected beams
    def _reorder_cache(past_key_values, beam_idx):
        # Tuple collecting the reordered cache
        reordered_past = ()
        # Iterate over the cached states of every layer
        for layer_past in past_key_values:
            # Reorder each layer's cache and append the result to reordered_past
            reordered_past += (
                # index_select each past_state along the batch dimension using beam_idx, moved to past_state's device
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        # Return the reordered cache
        return reordered_past
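
A toy illustration of the reordering (legacy tuple cache format; the values are made up): each cached tensor has the beam dimension first, and `index_select` gathers the rows of the surviving beams.

```
import torch

past_state = torch.tensor([[0.0], [1.0], [2.0]])     # one cached tensor, beam dimension first
layer_past = (past_state, past_state.clone())        # (key, value) of a single layer
beam_idx = torch.tensor([2, 2, 0])                   # beams 0 and 1 are replaced by copies of beam 2

reordered = MistralForCausalLM._reorder_cache((layer_past,), beam_idx)
print(reordered[0][0].squeeze(-1))                   # tensor([2., 2., 0.])
```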
# Mistral model with a sequence-classification head (a linear layer) on top.
# The model uses the last token for classification, as other causal models (e.g. GPT-2) do.
# If a `pad_token_id` is defined in the configuration, the last non-padding token of each row is used for classification.
# If no `pad_token_id` is defined, the last value of each row in the batch is used instead.
# The same strategy (taking the last value of each row) applies when `inputs_embeds` is passed instead of `input_ids`, since the padding tokens cannot be inferred.
@add_start_docstrings(
    """
    The Mistral Model transformer with a sequence classification head on top (linear layer).

    [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    MISTRAL_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral
class MistralForSequenceClassification(MistralPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = MistralModel(config)
        # Classification head: a bias-free linear layer projecting to num_labels
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Return the model's input embedding layer
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        # Set the model's input embedding layer
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,

.\models\mistral\__init__.py

# Import the modules and functions we need
from typing import TYPE_CHECKING

# Import the optional-dependency exception and the lazy-loading helpers
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available

# Define the import structure of the module
_import_structure = {
    "configuration_mistral": ["MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MistralConfig"],
}

# Check whether Torch is available; raise the dedicated exception if it is not
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If Torch is available, register the Mistral modeling imports
    _import_structure["modeling_mistral"] = [
        "MistralForCausalLM",
        "MistralModel",
        "MistralPreTrainedModel",
        "MistralForSequenceClassification",
    ]

# Check whether Flax is available; raise the dedicated exception if it is not
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If Flax is available, register the FlaxMistral modeling imports
    _import_structure["modeling_flax_mistral"] = [
        "FlaxMistralForCausalLM",
        "FlaxMistralModel",
        "FlaxMistralPreTrainedModel",
    ]

# When type checking
if TYPE_CHECKING:
    # Import the Mistral configuration classes
    from .configuration_mistral import MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MistralConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Import the Mistral modeling classes
        from .modeling_mistral import (
            MistralForCausalLM,
            MistralForSequenceClassification,
            MistralModel,
            MistralPreTrainedModel,
        )

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Import the FlaxMistral modeling classes
        from .modeling_flax_mistral import (
            FlaxMistralForCausalLM,
            FlaxMistralModel,
            FlaxMistralPreTrainedModel,
        )

# Otherwise (no type checking), set the module up for lazy loading
else:
    import sys

    # Replace the current module with a lazy-loading module
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\mixtral\configuration_mixtral.py

# File encoding: UTF-8
# Copyright notice for the company and team
# Licensed under the Apache License, Version 2.0;
# you may not use this file except in compliance with the License (see the linked text)
# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
# OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations.

""" Mixtral model configuration"""

# Import the pretrained configuration base class PretrainedConfig from the transformers library
from ...configuration_utils import PretrainedConfig
# Import the logging module
from ...utils import logging

# Get the logger object
logger = logging.get_logger(__name__)

# Mapping from Mixtral pretrained models to the URLs of their configuration files
MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "mistral-ai/Mixtral-8x7B": "https://huggingface.co/mistral-ai/Mixtral-8x7B/resolve/main/config.json",
}

# MixtralConfig, a subclass of PretrainedConfig
class MixtralConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an
    Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Mixtral-7B-v0.1 or Mixtral-7B-Instruct-v0.1.

    [mixtralai/Mixtral-8x7B](https://huggingface.co/mixtralai/Mixtral-8x7B)
    [mixtralai/Mixtral-7B-Instruct-v0.1](https://huggingface.co/mixtralai/Mixtral-7B-Instruct-v0.1)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    ```
    >>> from transformers import MixtralModel, MixtralConfig

    >>> # Initializing a Mixtral 7B style configuration
    >>> configuration = MixtralConfig()

    >>> # Initializing a model from the Mixtral 7B style configuration
    >>> model = MixtralModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    # Model type: "mixtral"
    model_type = "mixtral"
    # Keys ignored at inference time
    keys_to_ignore_at_inference = ["past_key_values"]

    # Constructor defining all MixtralConfig parameters
    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=1e6,
        sliding_window=None,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_local_experts=8,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        **kwargs,
    ):
        # Initialize the model hyper-parameters: vocabulary size, maximum position embeddings, hidden size,
        # intermediate size, number of hidden layers, number of attention heads, sliding-window size
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window

        # For backward compatibility:
        # if num_key_value_heads was not provided, default it to num_attention_heads
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Number of key/value heads
        self.num_key_value_heads = num_key_value_heads
        # Hidden activation function
        self.hidden_act = hidden_act
        # Weight initialization range
        self.initializer_range = initializer_range
        # Epsilon used by the RMS normalization
        self.rms_norm_eps = rms_norm_eps
        # Whether to use the cache
        self.use_cache = use_cache
        # RoPE theta (rotary base)
        self.rope_theta = rope_theta
        # Attention dropout probability
        self.attention_dropout = attention_dropout

        # Number of experts used per token
        self.num_experts_per_tok = num_experts_per_tok
        # Number of local experts
        self.num_local_experts = num_local_experts
        # Whether to output the router logits
        self.output_router_logits = output_router_logits
        # Coefficient of the router auxiliary loss
        self.router_aux_loss_coef = router_aux_loss_coef

        # Call the parent class constructor to set the special token ids and the remaining keyword arguments
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

.\models\mixtral\convert_mixtral_weights_to_hf.py

# Import the required libraries and modules
import argparse  # Command-line argument parsing
import json  # Handling JSON data
import os  # Operating-system utilities

import torch  # PyTorch deep-learning library

from transformers import (  # Import the required classes from the transformers library
    MixtralConfig,  # Configuration class for Mixtral models
    MixtralForCausalLM,  # Causal language-model class for Mixtral
)

"""
示例用法:


python src/transformers/models/mixtral/convert_mixtral_weights_to_hf.py \
    --input_dir /path/to/downloaded/mixtral/weights --model_size 7B --output_dir /output/path


之后,可以通过以下方式加载模型:


from transformers import MixtralForCausalLM

model = MixtralForCausalLM.from_pretrained("/output/path")


重要说明:你需要能够将整个模型加载到内存中以执行此脚本(即使最大版本被分成多个检查点,每个检查点都包含模型权重的一部分,因此我们需要将它们全部加载到内存中)。
"""


def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    # Compute the intermediate size, rounded up to the requested multiple
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
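
A worked example of the rounding (plain arithmetic): for `n = 4096` with the defaults, `8 * 4096 / 3 ≈ 10922`, which rounded up to the next multiple of 256 gives 11008; with `ffn_dim_multiplier=1.3` the same rounding yields 14336.

```
print(compute_intermediate_size(4096))                          # 11008
print(compute_intermediate_size(4096, ffn_dim_multiplier=1.3))  # 14336
```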


def read_json(path):
    # Read a JSON file and return its contents
    with open(path, "r") as f:
        return json.load(f)


def write_json(text, path):
    # Write the given content as JSON to the specified path
    with open(path, "w") as f:
        json.dump(text, f)


def write_model(model_path, input_base_path, model_size, safe_serialization=True):
    # Create the output directory if it does not exist
    os.makedirs(model_path, exist_ok=True)

    # Read the JSON file holding the model parameters
    params = read_json(os.path.join(input_base_path, "params.json"))
    num_shards = 1

    # Read the sliding-window size from params.json (if present)
    sliding_window = int(params["sliding_window"]) if "sliding_window" in params else None
    n_layers = params["num_hidden_layers"]  # Number of hidden layers
    n_heads = params["num_attention_heads"]  # Number of attention heads
    n_heads_per_shard = n_heads // num_shards  # Number of attention heads per shard
    dim = params["hidden_size"]  # Hidden size
    dims_per_head = dim // n_heads  # Dimension of each attention head
    base = params.get("rope_theta", 10000.0)  # rope_theta parameter, defaulting to 10000.0
    max_position_embeddings = 4096 * 8  # Maximum number of position embeddings
    num_local_experts = params["num_local_experts"]  # Number of local experts
    ffn_dim = params["intermediate_size"]  # Intermediate (FFN) size

    vocab_size = params["vocab_size"]  # Vocabulary size

    if "num_key_value_heads" in params:
        num_key_value_heads = params["num_key_value_heads"]  # Number of key/value heads (for GQA / MQA)
        num_local_key_value_heads = num_key_value_heads // num_shards  # Number of key/value heads per shard
        key_value_dim = dims_per_head * num_local_key_value_heads  # Key/value dimension
    else:  # Compatibility with other checkpoints
        num_key_value_heads = n_heads
        num_local_key_value_heads = n_heads_per_shard
        key_value_dim = dim

    # Permutation used for sliced rotary embeddings
    def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
        # Rearranges the tensor `w` for later processing
        return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
    # The message above indicates that all parameters are being loaded from the given path

    # Load the list of weight files
    loaded = [
        torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pt"), map_location="cpu") for i in range(8)
    ]

    # Initialize the merged state dict
    merged_state_dict = {}
    # Merge all loaded state dicts
    for state_dict in loaded:
        merged_state_dict.update(state_dict)

    # Initialize the output state dict
    state_dict = {}

    # Add the model-level weights (norm, token embeddings, LM head) to the output state dict
    state_dict.update(
        {
            "model.norm.weight": merged_state_dict["norm.weight"],
            "model.embed_tokens.weight": merged_state_dict["tok_embeddings.weight"],
            "lm_head.weight": merged_state_dict["output.weight"],
        }
    )

    # Build the Mixtral model configuration
    config = MixtralConfig(
        hidden_size=dim,
        intermediate_size=ffn_dim,
        num_attention_heads=params["num_attention_heads"],
        num_hidden_layers=params["num_hidden_layers"],
        rms_norm_eps=params["rms_norm_eps"],
        num_key_value_heads=num_key_value_heads,
        vocab_size=vocab_size,
        rope_theta=base,
        max_position_embeddings=max_position_embeddings,
        sliding_window=sliding_window,
        num_local_experts=num_local_experts,
    )

    # The message below indicates that the checkpoint is being loaded into a Mixtral model
    print("Loading the checkpoint in a Mixtral model.")
    # Initialize the Mixtral model on the meta device
    with torch.device("meta"):
        model = MixtralForCausalLM(config)
    # Remove the saved path from the config to avoid leaking it
    del model.config._name_or_path
    # Set the config's torch dtype to float16
    model.config.torch_dtype = torch.float16
    # The message below indicates that the model is being saved in the Transformers format
    print("Saving in the Transformers format.")

    # Load the state dict into the model
    model.load_state_dict(state_dict, strict=True, assign=True)

    # Check all parameters to make sure none of them is still on the `meta` device
    for n, p in model.named_parameters():
        assert p.device.type != "meta", f"{n} has not been loaded!"

    # Save the model in the pretrained format to the given path
    model.save_pretrained(model_path, safe_serialization=safe_serialization)
# Entry point of the script
def main():
    # Create the command-line argument parser
    parser = argparse.ArgumentParser()
    
    # Add the --input_dir argument: location of the Mixtral weights, containing tokenizer.model and the model folders
    parser.add_argument(
        "--input_dir",
        help="Location of Mixtral weights, which contains tokenizer.model and model folders",
        required=True,
    )
    
    # Add the --model_size argument: which model size to convert; defaults to "7B", matching the official Mixtral release
    parser.add_argument(
        "--model_size",
        choices=["7B"],
        help="'f' models correspond to the finetuned versions, and are specific to the Mixtral official release. For more details on Mixtral, checkout the original repo: https://huggingface.co/mistral-ai",
        default="7B",
    )
    
    # Add the --output_dir argument: where to write the HF model
    parser.add_argument("--output_dir", help="Location to write HF model", required=True)
    
    # Add the --safe_serialization argument: whether to save using `safetensors`
    parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
    
    # Parse the command-line arguments
    args = parser.parse_args()
    
    # Call write_model with the parsed arguments to write out the model
    write_model(
        model_path=args.output_dir,
        input_base_path=args.input_dir,
        model_size=args.model_size,
        safe_serialization=args.safe_serialization,
    )


# Script entry point: call main() when this file is executed directly
if __name__ == "__main__":
    main()

.\models\mixtral\modeling_mixtral.py

# File encoding: UTF-8
# Copyright and license notice, under the Apache License, Version 2.0
# This code is based on EleutherAI's GPT-NeoX library, including modifications of the GPT-NeoX and OPT implementations,
# to accommodate the architectural differences relative to the models trained by the Meta AI team.
""" PyTorch Mixtral model."""

# Import inspect, used to introspect function signatures
import inspect
# Import math, providing mathematical functions
import math
# Import warnings, used to emit warnings
import warnings
# Import typing helpers
from typing import List, Optional, Tuple, Union

# Import PyTorch
import torch
# Import PyTorch functional utilities and checkpointing
import torch.nn.functional as F
import torch.utils.checkpoint
# Import the nn module from PyTorch
from torch import nn
# Import loss functions from PyTorch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# Import the activation-function mapping
from ...activations import ACT2FN
# Import the cache utilities
from ...cache_utils import Cache, DynamicCache
# Import the attention-mask helper functions
from ...modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
# Import the model output classes
from ...modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
    SequenceClassifierOutputWithPast,
)
# Import the pretrained-model base class
from ...modeling_utils import PreTrainedModel
# Import PyTorch utility helpers
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
# Import general utility functions
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
# Import the import-related utility functions
from ...utils.import_utils import is_torch_fx_available
# Import the Mixtral configuration class
from .configuration_mixtral import MixtralConfig

# If Flash Attention 2 is available, import the corresponding modules and functions
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

    # Check whether the flash-attention function supports the window_size argument
    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)

# If Torch FX is available, wrap _prepare_4d_causal_attention_mask so that it is treated as a leaf function in FX graphs
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        import torch.fx

    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)

# Get the logger for this module
logger = logging.get_logger(__name__)

# Configuration class name used in docstrings
_CONFIG_FOR_DOC = "MixtralConfig"


def load_balancing_loss_func(
    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
) -> float:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        attention_mask (`torch.Tensor`, None):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.
        num_experts (`int`, *optional*):
            Number of experts

    Returns:
        The auxiliary loss.
    """
    # If gate_logits is None or not a tuple, return 0
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    # If gate_logits is a tuple, pick a compute device and concatenate the per-layer gate logits
    if isinstance(gate_logits, tuple):
        compute_device = gate_logits[0].device
        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    # Compute the routing weights as a softmax over the gate logits
    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    # Select the top_k experts according to the routing weights
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    # Build a one-hot mask of the selected experts
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Without an attention_mask, compute the fraction of tokens routed to each expert
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average routing probability per expert
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Build the expert attention mask used to account for padding tokens
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the fraction of tokens routed to each expert, taking the attention_mask into account
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Build the router-probability attention mask used to account for padding tokens
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average routing probability per expert, taking the attention_mask into account
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )
    # Multiply the token fraction of each expert by its router probability and sum over all experts
    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    # Scale the total by the number of experts to obtain the final loss
    return overall_loss * num_experts
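
A hedged numeric sketch of the function above (toy tensors, one layer of gate logits): with perfectly uniform logits the loss evaluates to `top_k` (here 2), while routing every token to the same expert pushes it higher.

```
import torch

num_experts, top_k = 4, 2
uniform = (torch.zeros(6, num_experts),)                     # 6 tokens, identical logits -> balanced routing
print(load_balancing_loss_func(uniform, num_experts=num_experts, top_k=top_k))   # tensor(2.)

skewed = (torch.tensor([[10.0, 0.0, 0.0, 0.0]]).repeat(6, 1),)  # every token prefers expert 0
print(load_balancing_loss_func(skewed, num_experts=num_experts, top_k=top_k))    # close to 4, i.e. > 2
```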
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
# Computes the indices of the non-padded tokens, the cumulative sequence lengths and the maximum sequence length in the batch
def _get_unpad_data(attention_mask):
    # Sum the attention mask to get the length of each sequence in the batch
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Find the indices of all non-padded positions
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # Get the maximum sequence length in the batch
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Compute the cumulative sequence lengths and left-pad with a zero
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
# MixtralRMSNorm mirrors T5LayerNorm and implements RMS (root-mean-square) normalization
class MixtralRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MixtralRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        # Learnable scale parameter
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Small epsilon added to the variance
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Remember the input dtype
        input_dtype = hidden_states.dtype
        # Cast the input to float32
        hidden_states = hidden_states.to(torch.float32)
        # Compute the variance (mean of squares) of the input tensor
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # Apply the RMS normalization
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Return the scaled, normalized result cast back to the input dtype
        return self.weight * hidden_states.to(input_dtype)


# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
# MixtralRotaryEmbedding builds the rotary position embeddings (cos/sin caches) used in self-attention
class MixtralRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # Store the dimension, the maximum position embeddings and the base
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Compute the inverse frequencies used to build the sine and cosine values
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        # Register the inverse frequencies as a (non-persistent) buffer for later use
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build the sine and cosine caches of the rotary embedding
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Record the maximum cached sequence length
        self.max_seq_len_cached = seq_len
        # Create the position index tensor
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        # Compute and cache the sine and cosine values
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # If the requested sequence length exceeds the cached one, rebuild the sine/cosine caches
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        # Return the cached sine and cosine values for the requested sequence length
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    # Slice the first half of the hidden dimension of x
    x1 = x[..., : x.shape[-1] // 2]
    # Slice the second half of the hidden dimension of x
    x2 = x[..., x.shape[-1] // 2 :]
    # Negate the second half and concatenate it with the first half, which implements the rotation
    return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # Unsqueezes cos and sin tensors along unsqueeze_dim to match dimensions of q and k
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    # Apply rotary position embedding to q and k tensors
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
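
下面是一个使用上述 `MixtralRotaryEmbedding` 与 `apply_rotary_pos_emb` 的最小调用示意(张量形状均为假设,沿用上文的定义):

rotary = MixtralRotaryEmbedding(dim=8, max_position_embeddings=32)
q = torch.randn(1, 4, 5, 8)                  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 4, 5, 8)
position_ids = torch.arange(5).unsqueeze(0)  # (batch, seq_len)
cos, sin = rotary(q, seq_len=5)              # 各自形状为 (5, 8)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
# q_rot / k_rot 形状保持 (1, 4, 5, 8),仅在最后一维上注入了与位置相关的相位信息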


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    # Extract dimensions from hidden_states tensor
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    # If n_rep is 1, return the original hidden_states tensor
    if n_rep == 1:
        return hidden_states
    # Expand hidden_states tensor to repeat along the specified dimension
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    # Reshape expanded tensor to the desired shape
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
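
一个说明 `repeat_kv` 形状变化的最小示例(数值为假设):

kv = torch.randn(2, 2, 6, 16)        # (batch, num_key_value_heads=2, seq_len, head_dim)
expanded = repeat_kv(kv, n_rep=4)    # -> (2, 8, 6, 16):每个 KV 头被复制 4 次,以对齐 8 个查询头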


# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
class MixtralAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """
    # 初始化函数,接受配置参数和可选的层索引
    def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
        # 调用父类的初始化方法
        super().__init__()
        # 保存传入的配置参数
        self.config = config
        # 保存传入的层索引
        self.layer_idx = layer_idx
        # 如果未提供层索引,发出警告,因为在使用缓存时可能会导致前向调用中的错误
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        # 从配置中获取隐藏层大小
        self.hidden_size = config.hidden_size
        # 从配置中获取注意力头数
        self.num_heads = config.num_attention_heads
        # 计算每个注意力头的维度
        self.head_dim = self.hidden_size // self.num_heads
        # 从配置中获取键值头数
        self.num_key_value_heads = config.num_key_value_heads
        # 计算每组键值头的数量
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # 从配置中获取最大位置嵌入数
        self.max_position_embeddings = config.max_position_embeddings
        # 从配置中获取旋转嵌入的基础值
        self.rope_theta = config.rope_theta
        # 设定是否是因果注意力
        self.is_causal = True
        # 从配置中获取注意力丢弃率
        self.attention_dropout = config.attention_dropout

        # 检查隐藏层大小是否能被注意力头数整除,否则抛出数值错误
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # 初始化查询投影层
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        # 初始化键投影层
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        # 初始化值投影层
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        # 初始化输出投影层
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # 初始化旋转嵌入层
        self.rotary_emb = MixtralRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    # 根据给定的张量形状,调整其形状以适应注意力头数和头维度的结构
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    # 前向传播函数,接收隐藏状态、注意力掩码、位置ID、过去的键值对缓存等参数
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        # 具体的注意力前向实现(QKV 投影、应用 RoPE、重复 KV 头并计算注意力输出)在此省略
        pass
# 从 `transformers.models.mistral.modeling_mistral.MistralFlashAttention2` 复制的 `MixtralFlashAttention2` 类,将 Mistral 更名为 Mixtral
class MixtralFlashAttention2(MixtralAttention):
    """
    Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # 从 `transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__` 复制的构造函数
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: 在 Flash Attention for RoCm 版本升级到 2.1 之后应该移除这段注释。
        # flash_attn<2.1 生成左上对齐的因果蒙版,而这里需要右下对齐的默认效果。此属性用于处理这种差异。参考:https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0。
        # 注意,对于 flash_attn<2.1,除了 q_seqlen == 1 的情况外,使用 q_seqlen != k_seqlen 会产生错误的蒙版(左上对齐)。
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Override of the forward method to integrate Mixtral flash attention with handling of padding tokens.
        """
        # 具体实现(处理填充 token 并调用 _flash_attention_forward)在此省略
        pass

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    ):
        # 具体实现(无填充时调用 flash_attn_func;有填充时先去填充再调用 flash_attn_varlen_func)在此省略
        pass

    # 定义一个方法 `_upad_input`,该方法接受多个输入参数:query_layer, key_layer, value_layer, attention_mask, query_length
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        # 获取 key_layer 的形状信息,分别为 batch_size, kv_seq_len, num_heads, head_dim
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

        # 如果 kv_seq_len 不等于 attention_mask 的最后一个维度长度,需要重新创建 padding mask
        if kv_seq_len != attention_mask.shape[-1]:
            # 获取 attention_mask 的最后一个维度长度
            attention_mask_num_tokens = attention_mask.shape[-1]
            # 更新 attention_mask,保留 kv_seq_len 长度的部分
            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]

        # 调用 _get_unpad_data 函数,获取解压后的数据 indices_k, cu_seqlens_k, max_seqlen_in_batch_k
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

        # 通过索引操作,对 key_layer 进行重新组织,形状变为 (batch_size * kv_seq_len, num_heads, head_dim)
        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        # 对 value_layer 进行类似的重新组织
        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)

        # 根据 query_length 的不同情况进行不同的处理
        if query_length == kv_seq_len:
            # 如果 query_length 等于 kv_seq_len,则对 query_layer 进行索引操作
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # 如果 query_length 等于 1,则将 query_layer 的形状调整,并生成相应的索引和长度信息
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # 这里有一个 memcpy 操作,非常不好。
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # 否则,根据 -query_length: 切片假设左填充,更新 attention_mask,并调用 unpad_input 函数处理 query_layer
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        # 返回处理后的结果,包括 query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k)
        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
# 从`transformers.models.mistral.modeling_mistral.MistralSdpaAttention`复制而来,将"Mistral"改为"Mixtral"
class MixtralSdpaAttention(MixtralAttention):
    """
    使用`torch.nn.functional.scaled_dot_product_attention`的Mixtral注意力模块。此模块继承自`MixtralAttention`,
    其权重保持不变。唯一的更改在于前向传递,以适应SDPA API。
    """

    # 从MixtralAttention.forward进行调整
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        # 具体实现调用 torch.nn.functional.scaled_dot_product_attention 完成注意力计算,在此省略
        pass

# 定义了Mixtral注意力类别的映射字典
MIXTRAL_ATTENTION_CLASSES = {
    "eager": MixtralAttention,
    "flash_attention_2": MixtralFlashAttention2,
    "sdpa": MixtralSdpaAttention,  # 将sdpa映射到MixtralSdpaAttention类
}


class MixtralBlockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        # 线性层定义
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        # 激活函数从ACT2FN字典中选择
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        # 前向传递计算
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states
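
下面是调用该专家 MLP 的最小示意(假设 hidden_size=4、intermediate_size=8,并假设 hidden_act 为默认的 "silu"):

cfg = MixtralConfig(hidden_size=4, intermediate_size=8)
mlp = MixtralBlockSparseTop2MLP(cfg)
x = torch.randn(2, 4)
y = mlp(x)    # 等价于 w2(silu(w1(x)) * w3(x)),输出形状仍为 (2, 4)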


# MixtralBLockSparseTop2MLP被废弃,用MixtralBlockSparseTop2MLP代替,发出一次警告
class MixtralBLockSparseTop2MLP(MixtralBlockSparseTop2MLP):
    def __init__(self, *args, **kwargs):
        logger.warning_once(
            "MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40."
        )
        super().__init__(*args, **kwargs)


class MixtralSparseMoeBlock(nn.Module):
    """
    这一实现严格等价于具有全部容量(不丢弃任何 token)的标准 MoE。它之所以更快,是因为把 MoE 操作
    表述为块稀疏(block-sparse)操作,从而适应 token 在各专家间的不均衡分配;而标准 MoE 实现要么
    (1) 丢弃超出容量的 token(损害效果),要么 (2) 把容量因子设为专家数(对填充部分浪费计算与显存)。
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # gating
        # gating机制的线性层
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        # 创建MixtralBlockSparseTop2MLP模块列表,用于每个专家
        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        接收隐藏状态作为输入,返回处理后的隐藏状态和路由器的logits值。

        Args:
            hidden_states (torch.Tensor): 输入的隐藏状态张量,形状为(batch_size, sequence_length, hidden_dim)

        Returns:
            torch.Tensor: 处理后的最终隐藏状态张量,形状为(batch_size, sequence_length, hidden_dim)
            torch.Tensor: 路由器的logits张量,形状为(batch * sequence_length, n_experts)
        """

        # 获取输入张量的维度信息
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        # 将输入的三维张量重塑为二维张量,以便进行路由器的计算
        hidden_states = hidden_states.view(-1, hidden_dim)

        # 使用路由器模型计算路由器的logits
        router_logits = self.gate(hidden_states)

        # 使用softmax函数对logits进行归一化处理,得到路由权重
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

        # 从每个路由权重中选择top-k的值,并重新归一化
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        # 将归一化后的路由权重转换为输入张量的数据类型
        routing_weights = routing_weights.to(hidden_states.dtype)

        # 初始化一个全零张量,用于存储最终的隐藏状态
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # 使用one-hot编码创建选定专家的专家掩码
        # 这将用于轻松地索引哪个专家将被调用
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # 遍历模型中所有可用的专家,并在每个专家上执行计算
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                continue

            # 将top_x张量转换为Python列表,以便在PyTorch中更快地索引
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            # 根据索引从隐藏状态中获取正确的隐藏状态,并计算当前专家的隐藏状态
            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

            # 使用index_add_方法将当前专家的隐藏状态加到最终隐藏状态中
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        # 将最终隐藏状态张量重塑回原始形状(batch_size, sequence_length, hidden_dim)
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

        # 返回最终的隐藏状态张量和路由器的logits张量
        return final_hidden_states, router_logits
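
下面用一个很小的数值示例(4 个专家、top_k=2,数值为假设)演示上面路由权重的计算方式:

router_logits = torch.tensor([[1.0, 2.0, 0.5, -1.0]])          # 单个 token 对 4 个专家的打分
weights = F.softmax(router_logits, dim=1, dtype=torch.float)
top_w, top_idx = torch.topk(weights, k=2, dim=-1)              # 选出得分最高的 2 个专家
top_w = top_w / top_w.sum(dim=-1, keepdim=True)                # 重新归一化,使两者权重之和为 1
# top_idx -> tensor([[1, 0]]):该 token 会被分派给第 1 号和第 0 号专家
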
# 定义 MixtralDecoderLayer 类,继承自 nn.Module,用于实现 Mixtral 模型的解码器层
class MixtralDecoderLayer(nn.Module):
    # 初始化方法,接受 MixtralConfig 和层索引作为参数
    def __init__(self, config: MixtralConfig, layer_idx: int):
        super().__init__()
        # 设置隐藏层大小
        self.hidden_size = config.hidden_size

        # 初始化自注意力机制,根据配置选择不同的注意力实现类进行初始化
        self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        # 初始化块稀疏混合专家(MoE)模块
        self.block_sparse_moe = MixtralSparseMoeBlock(config)

        # 初始化输入层归一化模块
        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # 初始化注意力后归一化模块
        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    # 前向传播方法,接受隐藏状态、注意力掩码、位置 ID、过去的键值对等参数
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        if "padding_mask" in kwargs:
            # 如果传入了 `padding_mask` 参数,发出警告提示
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        
        """
        Args:
            hidden_states (`torch.FloatTensor`): 输入到层的张量,形状为 `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *可选*): 注意力掩码张量,形状为 `(batch, sequence_length)`,其中填充元素为0
            past_key_value (`Tuple(torch.FloatTensor)`, *可选*): 缓存的过去键值投影状态
            output_attentions (`bool`, *可选*):
                是否返回所有注意力层的注意力张量。详见返回的张量中的 `attentions` 了解更多细节。
            output_router_logits (`bool`, *可选*):
                是否返回所有路由器的logits。这对计算路由器损失很有用,在推理时不应返回。
            use_cache (`bool`, *可选*):
                如果设置为 `True`,则返回 `past_key_values` 键值状态,可用于加速解码 (参见 `past_key_values`).
        """

        residual = hidden_states

        # 输入层归一化
        hidden_states = self.input_layernorm(hidden_states)

        # 自注意力层
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # 全连接层
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, router_logits = self.block_sparse_moe(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        if output_router_logits:
            outputs += (router_logits,)

        return outputs
# MIXTRAL_START_DOCSTRING 是一个多行原始字符串,用于描述 MixtralPreTrainedModel 类的文档字符串。
# 它包含了关于模型继承自 PreTrainedModel 的信息,以及如何使用 PyTorch 的说明和参数列表。
MIXTRAL_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MixtralConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# add_start_docstrings 是一个装饰器,用于为 MixtralPreTrainedModel 类添加文档字符串。
# 第一个参数是描述该模型输出原始隐藏状态的概述性文本。
# 第二个参数是 MIXTRAL_START_DOCSTRING,用于详细描述该类的配置和参数信息。
@add_start_docstrings(
    "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
    MIXTRAL_START_DOCSTRING,
)
# MixtralPreTrainedModel 类继承自 PreTrainedModel,用于 Mixtral 模型的预训练和初始化。
class MixtralPreTrainedModel(PreTrainedModel):
    # 配置类,指定了 Mixtral 模型的配置信息。
    config_class = MixtralConfig
    # 基础模型的前缀,通常用于命名前缀。
    base_model_prefix = "model"
    # 是否支持梯度检查点。
    supports_gradient_checkpointing = True
    # 不需要拆分的模块列表。
    _no_split_modules = ["MixtralDecoderLayer"]
    # 跳过键的设备放置。
    _skip_keys_device_placement = "past_key_values"
    # 是否支持 Flash Attention 2。
    _supports_flash_attn_2 = True
    # 是否支持 SDPA(Scaled Dot-Product Attention)。
    _supports_sdpa = True
    # 是否支持缓存类。
    _supports_cache_class = True

    # 初始化权重的函数。
    def _init_weights(self, module):
        std = self.config.initializer_range
        # 如果是线性层,初始化权重和偏置。
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        # 如果是嵌入层,初始化权重。
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


# MIXTRAL_INPUTS_DOCSTRING 是一个未填充的多行原始字符串,可能用于描述 MixtralModel 类的输入信息。
MIXTRAL_INPUTS_DOCSTRING = r"""
"""


# add_start_docstrings 是一个装饰器,用于为 MixtralModel 类添加文档字符串。
# 第一个参数是描述该模型输出原始隐藏状态的概述性文本。
# 第二个参数是 MIXTRAL_START_DOCSTRING,用于详细描述该类的配置和参数信息。
@add_start_docstrings(
    "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
    MIXTRAL_START_DOCSTRING,
)
# MixtralModel 类继承自 MixtralPreTrainedModel,代表了 Mixtral 模型的具体实现。
class MixtralModel(MixtralPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]

    Args:
        config: MixtralConfig
    """
    # 初始化函数,接受一个 MixtralConfig 类型的参数 config
    def __init__(self, config: MixtralConfig):
        # 调用父类的初始化函数,传入 config 参数
        super().__init__(config)
        # 设置 padding_idx 属性为 config 的 pad_token_id
        self.padding_idx = config.pad_token_id
        # 设置 vocab_size 属性为 config 的 vocab_size
        self.vocab_size = config.vocab_size

        # 创建一个嵌入层对象 embed_tokens,用于将输入的 token 转换为向量表示
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        
        # 创建一个由多个 MixtralDecoderLayer 组成的层列表,每层通过不同的 layer_idx 构建
        self.layers = nn.ModuleList(
            [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        
        # 设置 _attn_implementation 属性为 config 的 _attn_implementation
        self._attn_implementation = config._attn_implementation
        
        # 创建一个 MixtralRMSNorm 对象 norm,用于进行归一化处理
        self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # 初始化梯度检查点标志为 False
        self.gradient_checkpointing = False
        
        # 调用 post_init 函数,完成权重初始化和最终处理
        self.post_init()

    # 返回 embed_tokens 属性,即输入嵌入层对象
    def get_input_embeddings(self):
        return self.embed_tokens

    # 设置 embed_tokens 属性为 value
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # 下面的装饰器为 forward 函数补充输入参数的文档字符串(原代码此处标注 Ignore copy)
    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 具体的前向实现(嵌入查找、构造注意力掩码、逐层解码、最终归一化并组装输出)在此省略
        pass
# MixtralForCausalLM 类,继承自 MixtralPreTrainedModel 类,用于混合专家模型的因果语言建模任务

class MixtralForCausalLM(MixtralPreTrainedModel):
    # 定义被绑定权重的键值,用于共享权重
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # 调用父类的初始化方法,传入配置对象 config
        super().__init__(config)
        # 初始化 MixtralModel 模型,根据传入的配置对象 config
        self.model = MixtralModel(config)
        # 设置词汇表大小为配置对象中的词汇表大小
        self.vocab_size = config.vocab_size
        # 初始化 lm_head,使用线性层将隐藏状态映射到词汇表大小,无偏置
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # 设置路由辅助损失系数为配置对象中的路由辅助损失系数
        self.router_aux_loss_coef = config.router_aux_loss_coef
        # 设置本地专家的数量为配置对象中的本地专家数量
        self.num_experts = config.num_local_experts
        # 设置每个令牌的专家数量为配置对象中的每个令牌专家数量
        self.num_experts_per_tok = config.num_experts_per_tok
        # 调用后处理初始化方法,用于初始化权重并应用最终处理
        self.post_init()

    # 获取输入嵌入层,返回 MixtralModel 模型的嵌入 tokens
    def get_input_embeddings(self):
        return self.model.embed_tokens

    # 设置输入嵌入层的值
    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    # 获取输出嵌入层,返回 lm_head 线性层
    def get_output_embeddings(self):
        return self.lm_head

    # 设置输出嵌入层的值
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # 设置解码器,用于设置 MixtralModel 模型的 decoder
    def set_decoder(self, decoder):
        self.model = decoder

    # 获取解码器,返回当前 MixtralModel 模型
    def get_decoder(self):
        return self.model

    # 前向传播函数,接受多种输入参数,返回 MoeCausalLMOutputWithPast 类型的输出
    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 前向传播:调用 self.model 得到隐藏状态,经 lm_head 得到 logits,
        # 并在提供 labels 时计算语言建模损失与路由辅助损失;具体实现在此省略
        pass

    # 为生成准备输入的函数
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        output_router_logits=False,
        **kwargs,
    ):
        # 为生成准备输入:裁剪已被缓存覆盖的 token、按需构造 position_ids,并组装 model_inputs
        # Omit tokens covered by past_key_values
        # 如果 past_key_values 不为空,则跳过已被处理的 token

        if past_key_values is not None:
            # Check if past_key_values is an instance of Cache
            # 检查 past_key_values 是否为 Cache 类的实例
            if isinstance(past_key_values, Cache):
                # Get sequence length from past_key_values
                # 从 past_key_values 中获取序列长度
                cache_length = past_key_values.get_seq_length()
                # Get seen tokens count from past_key_values
                # 从 past_key_values 中获取已看到的 token 数量
                past_length = past_key_values.seen_tokens
                # Get maximum cache length from past_key_values
                # 从 past_key_values 中获取最大缓存长度
                max_cache_length = past_key_values.get_max_length()
            else:
                # Assume past_key_values is a tuple and get dimensions from it
                # 假设 past_key_values 是一个元组,并从中获取维度信息
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 保留未处理的 token:

            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            # 如果 attention_mask 的长度超过 input_ids 的长度,则说明部分输入作为缓存的一部分传递(例如将 input_embeds 作为输入)

            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            # 如果 past_length 小于 input_ids 的长度,则 input_ids 包含所有的输入 token。根据 past_length 可以丢弃 input_ids 的部分 token。

            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            # 否则(past_length >= input_ids.shape[1]),假设 input_ids 只包含未处理的 token。

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            # 如果即将超出最大缓存长度,我们需要裁剪输入的 attention mask。

            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        # Get position_ids from kwargs if not provided
        # 如果 attention_mask 不为空且 position_ids 为空,则动态创建 position_ids 以进行批量生成

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        # 如果传递了 `inputs_embeds`,我们只想在第一代步骤中使用它们

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # Update model_inputs with various parameters
        # 使用各种参数更新 model_inputs

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "output_router_logits": output_router_logits,
            }
        )
        # Return the constructed model_inputs dictionary
        # 返回构建的 model_inputs 字典

        return model_inputs
    # 定义静态方法 `_reorder_cache`,用于在 beam search 时按 beam_idx 重新排序缓存数据
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # 初始化一个空的元组,用于存储重新排序后的缓存数据
        reordered_past = ()
        # 遍历每层的缓存数据
        for layer_past in past_key_values:
            # 对每层的缓存数据进行重新排序,并将结果作为元组加入到 `reordered_past` 中
            reordered_past += (
                # 对每个 `past_state` 根据 `beam_idx` 进行索引选择,并放到对应设备上
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        # 返回重新排序后的缓存数据
        return reordered_past
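
一个说明 `_reorder_cache` 中 index_select 作用的最小示例(数值为假设):

past_state = torch.arange(4.0).view(4, 1, 1, 1)     # 模拟某层的 key/value 缓存,batch*num_beams=4
beam_idx = torch.tensor([2, 2, 0, 3])                # beam search 选出的来源 beam
reordered = past_state.index_select(0, beam_idx)     # 第 0、1 个新 beam 都复制自旧的第 2 个 beam
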
"""
The Mixtral Model transformer with a sequence classification head on top (linear layer).

[`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.

Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
"""
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
class MixtralForSequenceClassification(MixtralPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = MixtralModel(config)  # 初始化 Mixtral 模型
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)  # 线性层用于分类得分

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens  # 返回输入嵌入的模型

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value  # 设置输入嵌入的模型

    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass for MixtralForSequenceClassification.

        Args:
            input_ids (torch.LongTensor, optional): Input token IDs.
            attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding tokens.
            position_ids (torch.LongTensor, optional): IDs to mark each token's position.
            past_key_values (List[torch.FloatTensor], optional): Cached key/value states for faster decoding.
            inputs_embeds (torch.FloatTensor, optional): Precomputed embeddings for the input tokens.
            labels (torch.LongTensor, optional): Labels for computing the sequence classification loss.
            use_cache (bool, optional): Whether or not to use cached key/value states.
            output_attentions (bool, optional): Whether or not to output attentions weights.
            output_hidden_states (bool, optional): Whether or not to output hidden states.
            return_dict (bool, optional): Whether or not to return a dictionary as the output.

        Returns:
            Depending on `return_dict`, either a model output dictionary or a tuple of logits and loss.

        Notes:
            This method defines how inputs are processed through the Mixtral model for sequence classification.
        """
        # 具体实现:取每个序列最后一个非填充 token 的隐藏状态,经 self.score 得到分类 logits,
        # 并根据 problem_type 选择回归 / 单标签 / 多标签损失;在此省略
        pass
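
上面文档字符串中提到的“取最后一个非填充 token”可以用下面的最小示意来理解(假设 pad_token_id=0,数值为假设):

input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [3, 4, 0, 0, 0]])
sequence_lengths = torch.eq(input_ids, 0).int().argmax(-1) - 1   # -> tensor([2, 1])
# 分类得分取自 sequence_lengths 所指位置(每行最后一个非填充 token)的输出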

.\models\mixtral\__init__.py

# 导入需要的模块和函数
from typing import TYPE_CHECKING
# 从项目内部工具中导入必要的异常和工具函数
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)

# 定义模块的导入结构,指定哪些类和函数可以被外部导入
_import_structure = {
    "configuration_mixtral": ["MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MixtralConfig"],
}

# 检查是否 Torch 可用,如果不可用则抛出异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果 Torch 可用,则添加额外的模型类到导入结构中
    _import_structure["modeling_mixtral"] = [
        "MixtralForCausalLM",
        "MixtralModel",
        "MixtralPreTrainedModel",
        "MixtralForSequenceClassification",
    ]

# 如果是类型检查阶段,则从配置和模型模块导入特定类和常量
if TYPE_CHECKING:
    from .configuration_mixtral import MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MixtralConfig

    # 再次检查 Torch 是否可用,如果不可用则忽略
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 如果 Torch 可用,则从模型模块导入额外的模型类
        from .modeling_mixtral import (
            MixtralForCausalLM,
            MixtralForSequenceClassification,
            MixtralModel,
            MixtralPreTrainedModel,
        )

# 如果不是类型检查阶段,则将当前模块注册为一个 LazyModule
else:
    import sys

    # 动态将当前模块替换为 LazyModule 对象,这样在导入时会延迟加载模块内容
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
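
得益于上述延迟导入结构,使用方可以直接从 transformers 顶层导入这些类;下面是一个最小的使用示意(配置数值为假设,仅用于演示):

from transformers import MixtralConfig, MixtralForCausalLM   # 访问属性时 _LazyModule 才真正加载对应子模块

config = MixtralConfig(hidden_size=128, intermediate_size=256, num_hidden_layers=2,
                       num_attention_heads=4, num_key_value_heads=2)
model = MixtralForCausalLM(config)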

.\models\mluke\convert_mluke_original_pytorch_checkpoint_to_pytorch.py

# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert mLUKE checkpoint."""

import argparse
import json
import os
from collections import OrderedDict

import torch

from transformers import LukeConfig, LukeForMaskedLM, MLukeTokenizer, XLMRobertaTokenizer
from transformers.tokenization_utils_base import AddedToken


@torch.no_grad()
def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, pytorch_dump_folder_path, model_size):
    # 从元数据文件中加载配置信息
    with open(metadata_path) as metadata_file:
        metadata = json.load(metadata_file)
    # 根据元数据配置创建 LukeConfig 对象
    config = LukeConfig(use_entity_aware_attention=True, **metadata["model_config"])

    # 加载来自 checkpoint_path 的模型权重
    state_dict = torch.load(checkpoint_path, map_location="cpu")["module"]

    # 加载实体词汇表文件
    entity_vocab = load_original_entity_vocab(entity_vocab_path)
    # 添加一个新条目用于 [MASK2]
    entity_vocab["[MASK2]"] = max(entity_vocab.values()) + 1
    config.entity_vocab_size += 1

    # 根据元数据中指定的 BERT 模型名称加载 tokenizer
    tokenizer = XLMRobertaTokenizer.from_pretrained(metadata["model_config"]["bert_model_name"])

    # 为下游任务向 token 词汇表添加特殊 token
    entity_token_1 = AddedToken("<ent>", lstrip=False, rstrip=False)
    entity_token_2 = AddedToken("<ent2>", lstrip=False, rstrip=False)
    tokenizer.add_special_tokens({"additional_special_tokens": [entity_token_1, entity_token_2]})
    config.vocab_size += 2

    # 打印信息,保存 tokenizer 到指定路径
    print(f"Saving tokenizer to {pytorch_dump_folder_path}")
    tokenizer.save_pretrained(pytorch_dump_folder_path)

    # 更新 tokenizer 配置文件
    with open(os.path.join(pytorch_dump_folder_path, "tokenizer_config.json"), "r") as f:
        tokenizer_config = json.load(f)
    tokenizer_config["tokenizer_class"] = "MLukeTokenizer"
    with open(os.path.join(pytorch_dump_folder_path, "tokenizer_config.json"), "w") as f:
        json.dump(tokenizer_config, f)

    # 将实体词汇表写入指定路径
    with open(os.path.join(pytorch_dump_folder_path, MLukeTokenizer.vocab_files_names["entity_vocab_file"]), "w") as f:
        json.dump(entity_vocab, f)

    # 从保存路径加载 MLukeTokenizer
    tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path)

    # 初始化特殊 token 的嵌入向量
    ent_init_index = tokenizer.convert_tokens_to_ids(["@"])[0]
    ent2_init_index = tokenizer.convert_tokens_to_ids(["#"])[0]

    # 获取词嵌入权重
    word_emb = state_dict["embeddings.word_embeddings.weight"]
    # 提取第一个特殊 token 的嵌入向量并扩展维度
    ent_emb = word_emb[ent_init_index].unsqueeze(0)
    # 获取第二个实体的嵌入向量,并添加一个维度使其成为二维张量
    ent2_emb = word_emb[ent2_init_index].unsqueeze(0)
    
    # 将词嵌入、第一个实体嵌入和第二个实体嵌入连接起来,更新到模型的权重中
    state_dict["embeddings.word_embeddings.weight"] = torch.cat([word_emb, ent_emb, ent2_emb])
    
    # 为 'entity_predictions.bias' 添加特殊的标记
    for bias_name in ["lm_head.decoder.bias", "lm_head.bias"]:
        # 获取当前偏置的张量
        decoder_bias = state_dict[bias_name]
        # 获取第一个实体的偏置并添加一个维度使其成为二维张量
        ent_decoder_bias = decoder_bias[ent_init_index].unsqueeze(0)
        # 获取第二个实体的偏置并添加一个维度使其成为二维张量
        ent2_decoder_bias = decoder_bias[ent2_init_index].unsqueeze(0)
        # 将三个偏置连接起来,更新到模型的偏置中
        state_dict[bias_name] = torch.cat([decoder_bias, ent_decoder_bias, ent2_decoder_bias])

    # 初始化实体感知自注意力机制中查询层的权重和偏置
    for layer_index in range(config.num_hidden_layers):
        for matrix_name in ["query.weight", "query.bias"]:
            prefix = f"encoder.layer.{layer_index}.attention.self."
            # 复制查询层权重和偏置到不同的实体组合中
            state_dict[prefix + "w2e_" + matrix_name] = state_dict[prefix + matrix_name]
            state_dict[prefix + "e2w_" + matrix_name] = state_dict[prefix + matrix_name]
            state_dict[prefix + "e2e_" + matrix_name] = state_dict[prefix + matrix_name]

    # 使用 '[MASK]' 实体的嵌入来初始化 '[MASK2]' 实体的嵌入,用于下游任务
    entity_emb = state_dict["entity_embeddings.entity_embeddings.weight"]
    entity_mask_emb = entity_emb[entity_vocab["[MASK]"]].unsqueeze(0)
    state_dict["entity_embeddings.entity_embeddings.weight"] = torch.cat([entity_emb, entity_mask_emb])
    
    # 为 'entity_predictions.bias' 添加 '[MASK2]' 实体的偏置
    entity_prediction_bias = state_dict["entity_predictions.bias"]
    entity_mask_bias = entity_prediction_bias[entity_vocab["[MASK]"]].unsqueeze(0)
    state_dict["entity_predictions.bias"] = torch.cat([entity_prediction_bias, entity_mask_bias])

    # 初始化 Luke 模型作为一个评估模型
    model = LukeForMaskedLM(config=config).eval()

    # 移除不需要的权重
    state_dict.pop("entity_predictions.decoder.weight")
    state_dict.pop("lm_head.decoder.weight")
    state_dict.pop("lm_head.decoder.bias")
    
    # 创建一个有序字典,以适应 Hugging Face 模型的加载要求
    state_dict_for_hugging_face = OrderedDict()
    for key, value in state_dict.items():
        if not (key.startswith("lm_head") or key.startswith("entity_predictions")):
            state_dict_for_hugging_face[f"luke.{key}"] = state_dict[key]
        else:
            state_dict_for_hugging_face[key] = state_dict[key]

    # 使用加载字典更新模型的权重,并忽略严格检查模式
    missing_keys, unexpected_keys = model.load_state_dict(state_dict_for_hugging_face, strict=False)

    # 检查是否存在不期望的键
    if set(unexpected_keys) != {"luke.embeddings.position_ids"}:
        raise ValueError(f"Unexpected unexpected_keys: {unexpected_keys}")
    
    # 检查是否存在缺失的键
    if set(missing_keys) != {
        "lm_head.decoder.weight",
        "lm_head.decoder.bias",
        "entity_predictions.decoder.weight",
    }:
        raise ValueError(f"Unexpected missing_keys: {missing_keys}")

    # 对模型的权重进行绑定
    model.tie_weights()
    
    # 断言 Luke 模型的词嵌入与 lm_head 解码器的权重完全相等
    assert (model.luke.embeddings.word_embeddings.weight == model.lm_head.decoder.weight).all()
    
    # 断言 Luke 模型的实体嵌入与 entity_predictions 解码器的权重完全相等
    assert (model.luke.entity_embeddings.entity_embeddings.weight == model.entity_predictions.decoder.weight).all()

    # 检查输出
    # 从预训练模型文件夹路径加载 MLukeTokenizer,用于实体分类任务的标记化
    tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path, task="entity_classification")
    
    # 定义要输入的文本字符串及其实体范围
    text = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
    span = (0, 9)
    # 使用加载的 tokenizer 对文本进行编码,指定实体范围,并返回 PyTorch 张量
    encoding = tokenizer(text, entity_spans=[span], return_tensors="pt")
    
    # 使用模型进行推理,传入编码后的文本
    outputs = model(**encoding)
    
    # 验证词级别的隐藏状态
    if model_size == "large":
        raise NotImplementedError
    else:  # base
        expected_shape = torch.Size((1, 33, 768))  # 预期的隐藏状态张量形状
        expected_slice = torch.tensor([[0.0892, 0.0596, -0.2819], [0.0134, 0.1199, 0.0573], [-0.0169, 0.0927, 0.0644]])  # 预期的部分张量切片
    
    # 检查模型输出的最后隐藏状态张量的形状是否符合预期
    if not (outputs.last_hidden_state.shape == expected_shape):
        raise ValueError(
            f"Outputs.last_hidden_state.shape is {outputs.last_hidden_state.shape}, Expected shape is {expected_shape}"
        )
    # 检查模型输出的最后隐藏状态张量的部分切片是否与预期的张量切片在指定容差下相似
    if not torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
        raise ValueError
    
    # 验证实体级别的隐藏状态
    if model_size == "large":
        raise NotImplementedError
    else:  # base
        expected_shape = torch.Size((1, 1, 768))  # 预期的实体级别隐藏状态张量形状
        expected_slice = torch.tensor([[-0.1482, 0.0609, 0.0322]])  # 预期的实体级别隐藏状态张量切片
    
    # 检查模型输出的实体级别隐藏状态张量的形状是否符合预期
    if not (outputs.entity_last_hidden_state.shape == expected_shape):
        raise ValueError(
            f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is"
            f" {expected_shape}"
        )
    # 检查模型输出的实体级别隐藏状态张量的部分切片是否与预期的张量切片在指定容差下相似
    if not torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
        raise ValueError
    
    # 验证掩码词/实体预测
    # 重新加载 tokenizer(可能是为了覆盖先前的实体分类配置)
    tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path)
    text = "Tokyo is the capital of <mask>."
    span = (24, 30)
    # 使用重新加载的 tokenizer 对新文本进行编码,指定实体范围,并返回 PyTorch 张量
    encoding = tokenizer(text, entity_spans=[span], return_tensors="pt")
    
    # 使用模型进行推理,传入编码后的文本
    outputs = model(**encoding)
    
    # 获取输入的 token_ids 并找到 <mask> 的位置
    input_ids = encoding["input_ids"][0].tolist()
    mask_position_id = input_ids.index(tokenizer.convert_tokens_to_ids("<mask>"))
    # 在模型输出的 logits 中找到预测的 token_id
    predicted_id = outputs.logits[0][mask_position_id].argmax(dim=-1)
    assert "Japan" == tokenizer.decode(predicted_id)  # 断言预测的实体是 "Japan"
    
    # 在实体 logits 中找到预测的实体 ID,并根据 tokenizer 的实体词汇表找到对应的多语言实体
    predicted_entity_id = outputs.entity_logits[0][0].argmax().item()
    multilingual_predicted_entities = [
        entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id == predicted_entity_id
    ]
    assert [e for e in multilingual_predicted_entities if e.startswith("en:")][0] == "en:Japan"  # 断言多语言实体是 "en:Japan"
    
    # 最后,保存 PyTorch 模型和 tokenizer 到指定路径
    print("Saving PyTorch model to {}".format(pytorch_dump_folder_path))
    model.save_pretrained(pytorch_dump_folder_path)
# 加载原始实体词汇表的函数
def load_original_entity_vocab(entity_vocab_path):
    # 定义特殊的标记列表
    SPECIAL_TOKENS = ["[MASK]", "[PAD]", "[UNK]"]

    # 打开实体词汇表文件,逐行加载JSON数据
    data = [json.loads(line) for line in open(entity_vocab_path)]

    # 创建一个新的映射字典
    new_mapping = {}
    # 遍历加载的每个实体词汇表条目
    for entry in data:
        # 获取实体的唯一标识符
        entity_id = entry["id"]
        # 遍历每个实体的名称和语言信息
        for entity_name, language in entry["entities"]:
            # 如果实体名称在特殊标记列表中,则将其映射到对应的实体ID
            if entity_name in SPECIAL_TOKENS:
                new_mapping[entity_name] = entity_id
                break
            # 否则,将实体名称和语言组合成新的实体名称
            new_entity_name = f"{language}:{entity_name}"
            # 将新的实体名称映射到对应的实体ID
            new_mapping[new_entity_name] = entity_id

    # 返回创建的新映射字典
    return new_mapping
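
为便于理解,下面给出一个假设的 entity_vocab 文件内容及其转换结果(其中的实体与 id 均为虚构示例):

# 输入文件为 JSON Lines,每行形如:
#   {"id": 0, "entities": [["[PAD]", null]]}
#   {"id": 3, "entities": [["Japan", "en"], ["日本", "ja"]]}
# 经 load_original_entity_vocab 处理后得到:
#   {"[PAD]": 0, "en:Japan": 3, "ja:日本": 3}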


if __name__ == "__main__":
    # 创建参数解析器
    parser = argparse.ArgumentParser()
    # 添加必需的参数
    parser.add_argument("--checkpoint_path", type=str, help="Path to a pytorch_model.bin file.")
    parser.add_argument(
        "--metadata_path", default=None, type=str, help="Path to a metadata.json file, defining the configuration."
    )
    parser.add_argument(
        "--entity_vocab_path",
        default=None,
        type=str,
        help="Path to an entity_vocab.tsv file, containing the entity vocabulary.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to where to dump the output PyTorch model."
    )
    parser.add_argument(
        "--model_size", default="base", type=str, choices=["base", "large"], help="Size of the model to be converted."
    )
    # 解析命令行参数
    args = parser.parse_args()

    # 调用函数,转换LUKE模型的检查点
    convert_luke_checkpoint(
        args.checkpoint_path,
        args.metadata_path,
        args.entity_vocab_path,
        args.pytorch_dump_folder_path,
        args.model_size,
    )

.\models\mluke\tokenization_mluke.py

# 导入所需的模块和库
import itertools  # 导入itertools模块,用于高效循环操作
import json  # 导入json模块,用于处理JSON格式的数据
import os  # 导入os模块,用于与操作系统进行交互
from collections.abc import Mapping  # 从collections.abc模块导入Mapping抽象基类
from shutil import copyfile  # 从shutil模块导入copyfile函数,用于复制文件
from typing import Any, Dict, List, Optional, Tuple, Union  # 导入类型提示所需的类和函数

import numpy as np  # 导入numpy库,用于数值计算
import sentencepiece as spm  # 导入sentencepiece库,用于处理文本分词

# 导入transformers库中的相关模块和函数
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import (
    ENCODE_KWARGS_DOCSTRING,
    AddedToken,
    BatchEncoding,
    EncodedInput,
    PaddingStrategy,
    TensorType,
    TextInput,
    TextInputPair,
    TruncationStrategy,
    to_py_obj,
)
from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging  # 导入一些工具函数和类


logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象

EntitySpan = Tuple[int, int]  # 定义EntitySpan类型为(int, int)元组
EntitySpanInput = List[EntitySpan]  # 定义EntitySpanInput类型为元素为EntitySpan的列表
Entity = str  # 定义Entity类型为字符串
EntityInput = List[Entity]  # 定义EntityInput类型为元素为Entity的列表

SPIECE_UNDERLINE = "▁"  # 定义SPIECE_UNDERLINE为"▁"

VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "entity_vocab_file": "entity_vocab.json"}  # 定义VOCAB_FILES_NAMES字典

# 定义PRETRAINED_VOCAB_FILES_MAP字典,包含预训练模型和其对应的文件路径
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "studio-ousia/mluke-base": "https://huggingface.co/studio-ousia/mluke-base/resolve/main/vocab.json",
    },
    "merges_file": {
        "studio-ousia/mluke-base": "https://huggingface.co/studio-ousia/mluke-base/resolve/main/merges.txt",
    },
    "entity_vocab_file": {
        "studio-ousia/mluke-base": "https://huggingface.co/studio-ousia/mluke-base/resolve/main/entity_vocab.json",
    },
}

# 定义PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES字典,包含预训练模型和其对应的位置嵌入大小
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "studio-ousia/mluke-base": 512,
}


class MLukeTokenizer(PreTrainedTokenizer):
    """
    Adapted from [`XLMRobertaTokenizer`] and [`LukeTokenizer`]. Based on
    [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Attributes:
        sp_model (`SentencePieceProcessor`):
            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
    """

    # 定义MLukeTokenizer类,继承自PreTrainedTokenizer类

    vocab_files_names = VOCAB_FILES_NAMES  # 设置类属性vocab_files_names为VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP  # 设置类属性pretrained_vocab_files_map为PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES  # 设置类属性max_model_input_sizes为PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]  # 设置类属性model_input_names为包含"input_ids"和"attention_mask"的列表
    # 初始化方法,用于创建一个新的 tokenizer 对象
    def __init__(
        self,
        vocab_file,
        entity_vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        task=None,
        max_entity_length=32,
        max_mention_length=30,
        entity_token_1="<ent>",
        entity_token_2="<ent2>",
        entity_unk_token="[UNK]",
        entity_pad_token="[PAD]",
        entity_mask_token="[MASK]",
        entity_mask2_token="[MASK2]",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        # 实际实现会加载 SentencePiece 模型与实体词汇表、注册各类特殊 token,
        # 并把相应参数传给父类构造方法;此处仅保留简化的调用示意
        super().__init__(**kwargs)

    @property
    # 计算属性:返回 tokenizer 的词汇量大小
    # 从 transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer.vocab_size 复制而来
    def vocab_size(self):
        return len(self.sp_model) + self.fairseq_offset + 1  # 添加 <mask> token 的数量到词汇量中

    # 获取词汇表的方法
    # 从 transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer.get_vocab 复制而来
    def get_vocab(self):
        # 创建词汇表字典,包括转换后的 token 到 id 的映射
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    # 分词方法
    # 从 transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer._tokenize 复制而来
    def _tokenize(self, text: str) -> List[str]:
        # 使用 SentencePiece 模型进行文本编码
        # TODO 检查是否适用于 t5/llama PR
        return self.sp_model.encode(text, out_type=str)

    # 将 token 转换为 id 的方法
    # 从 transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer._convert_token_to_id 复制而来
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # 如果 token 存在于 fairseq_tokens_to_ids 中,直接返回对应的 id
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        # 否则,使用 SentencePiece 模型将 token 转换为 id
        spm_id = self.sp_model.PieceToId(token)

        # 如果 spm_id 为 0,返回未知 token 的 id
        return spm_id + self.fairseq_offset if spm_id else self.unk_token_id

    # 将 id 转换为 token 的方法
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        # 如果 index 存在于 fairseq_ids_to_tokens 中,直接返回对应的 token
        if index in self.fairseq_ids_to_tokens:
            return self.fairseq_ids_to_tokens[index]
        # 否则,使用 SentencePiece 模型将 id 转换为 token
        return self.sp_model.IdToPiece(index - self.fairseq_offset)

    # 将 token 序列转换为字符串的方法
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) in a single string."""
        # 将 token 序列连接成字符串,并替换 SPIECE_UNDERLINE 为空格,然后去除首尾空格
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    # 序列化对象状态的方法
    def __getstate__(self):
        state = self.__dict__.copy()
        # 清空 sp_model,因为其不可序列化
        state["sp_model"] = None
        # 将 sp_model_proto 序列化并保存在状态中
        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
        return state

    # 反序列化对象状态的方法
    def __setstate__(self, d):
        self.__dict__ = d

        # 为了向后兼容,如果不存在 sp_model_kwargs,设置为空字典
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        # 根据 sp_model_kwargs 创建 SentencePieceProcessor 对象,并从序列化的 proto 中加载模型
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
    # 将两个通用的编码参数说明追加到 __call__ 的文档字符串末尾
    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    # 从 transformers.models.luke.tokenization_luke.LukeTokenizer.__call__ 复制而来的函数定义
    def __call__(
        self,
        text: Union[TextInput, List[TextInput]],
        text_pair: Optional[Union[TextInput, List[TextInput]]] = None,
        entity_spans: Optional[Union[EntitySpanInput, List[EntitySpanInput]]] = None,
        entity_spans_pair: Optional[Union[EntitySpanInput, List[EntitySpanInput]]] = None,
        entities: Optional[Union[EntityInput, List[EntityInput]]] = None,
        entities_pair: Optional[Union[EntityInput, List[EntityInput]]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        max_entity_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: Optional[bool] = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        # __call__ 的具体实现(对单条或批量输入分别进行编码并返回 BatchEncoding)在此省略
        pass

    # 从 transformers.models.luke.tokenization_luke.LukeTokenizer._encode_plus 复制而来的函数定义
    def _encode_plus(
        self,
        text: Union[TextInput],
        text_pair: Optional[Union[TextInput]] = None,
        entity_spans: Optional[EntitySpanInput] = None,
        entity_spans_pair: Optional[EntitySpanInput] = None,
        entities: Optional[EntityInput] = None,
        entities_pair: Optional[EntityInput] = None,
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        max_entity_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: Optional[bool] = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        # 如果要求返回偏移映射,则抛出 NotImplementedError 异常
        if return_offsets_mapping:
            raise NotImplementedError(
                "return_offset_mapping is not available when using Python tokenizers. "
                "To use this feature, change your tokenizer to one deriving from "
                "transformers.PreTrainedTokenizerFast. "
                "More information on available tokenizers at "
                "https://github.com/huggingface/transformers/pull/2674"
            )

        # 如果指定 is_split_into_words 参数为 True,则抛出 NotImplementedError 异常
        if is_split_into_words:
            raise NotImplementedError("is_split_into_words is not supported in this tokenizer.")

        (
            # 调用内部方法 _create_input_sequence,生成输入序列所需的各个 ID 和实体标记跨度
            first_ids,
            second_ids,
            first_entity_ids,
            second_entity_ids,
            first_entity_token_spans,
            second_entity_token_spans,
        ) = self._create_input_sequence(
            # 传入文本和其配对文本、实体和其配对实体、实体标记跨度等参数
            text=text,
            text_pair=text_pair,
            entities=entities,
            entities_pair=entities_pair,
            entity_spans=entity_spans,
            entity_spans_pair=entity_spans_pair,
            **kwargs,  # 接受其他可能的关键字参数
        )

        # 调用 prepare_for_model 方法,生成模型输入所需的 attention_mask 和 token_type_ids
        # 返回结果作为函数结果
        return self.prepare_for_model(
            first_ids,
            pair_ids=second_ids,
            entity_ids=first_entity_ids,
            pair_entity_ids=second_entity_ids,
            entity_token_spans=first_entity_token_spans,
            pair_entity_token_spans=second_entity_token_spans,
            add_special_tokens=add_special_tokens,
            padding=padding_strategy.value,
            truncation=truncation_strategy.value,
            max_length=max_length,
            max_entity_length=max_entity_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            prepend_batch_axis=True,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            verbose=verbose,
        )

    # 从 transformers.models.luke.tokenization_luke.LukeTokenizer._batch_encode_plus 复制而来的代码
    # 定义一个方法用于批量编码文本或文本对,同时处理实体和实体跨度输入
    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs: Union[List[TextInput], List[TextInputPair]],
        batch_entity_spans_or_entity_spans_pairs: Optional[
            Union[List[EntitySpanInput], List[Tuple[EntitySpanInput, EntitySpanInput]]]
        ] = None,
        batch_entities_or_entities_pairs: Optional[
            Union[List[EntityInput], List[Tuple[EntityInput, EntityInput]]]
        ] = None,
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        max_entity_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: Optional[bool] = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        # (body elided) the batched counterpart of _encode_plus

    # Checks the entity input format, making sure entity spans are lists of matching length
    def _check_entity_input_format(self, entities, entity_spans):
        # If entity_spans is not a list, raise a ValueError
        if not isinstance(entity_spans, list):
            raise ValueError("entity_spans should be given as a list")
        # 如果 entity_spans 长度大于 0 且第一个元素不是元组,抛出 ValueError
        elif len(entity_spans) > 0 and not isinstance(entity_spans[0], tuple):
            raise ValueError(
                "entity_spans should be given as a list of tuples containing the start and end character indices"
            )

        # 如果指定了 entities,则需检查其格式是否正确
        if entities is not None:
            # 如果 entities 不是列表,抛出 ValueError
            if not isinstance(entities, list):
                raise ValueError("If you specify entities, they should be given as a list")
            # 如果 entities 长度大于 0 且第一个元素不是字符串,抛出 ValueError
            if len(entities) > 0 and not isinstance(entities[0], str):
                raise ValueError("If you specify entities, they should be given as a list of entity names")
            # 如果 entities 和 entity_spans 长度不一致,抛出 ValueError
            if len(entities) != len(entity_spans):
                raise ValueError("If you specify entities, entities and entity_spans must be the same length")

    # 创建输入序列的方法,用于构建模型输入
    def _create_input_sequence(
        self,
        text: Union[TextInput],
        text_pair: Optional[Union[TextInput]] = None,
        entities: Optional[EntityInput] = None,
        entities_pair: Optional[EntityInput] = None,
        entity_spans: Optional[EntitySpanInput] = None,
        entity_spans_pair: Optional[EntitySpanInput] = None,
        **kwargs,
    ):
        pass  # 此方法的实现可能包括文本编码、实体标记等操作,这里未提供具体实现

    # Batch counterpart of prepare_for_model: builds the model inputs for every example in a batch
    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    def _batch_prepare_for_model(
        self,
        # Each sample's (ids, None) pair in the batch, i.e. [(ids, None), ...]
        batch_ids_pairs: List[Tuple[List[int], None]],
        # 批次中每个样本的实体ID对,即[(entity_ids1, entity_ids2), ...]
        batch_entity_ids_pairs: List[Tuple[Optional[List[int]], Optional[List[int]]]],
        # 批次中每个样本的实体token span对,即[(token_spans1, token_spans2), ...]
        batch_entity_token_spans_pairs: List[Tuple[Optional[List[Tuple[int, int]]], Optional[List[Tuple[int, int]]]]],
        # 是否添加特殊token,如[CLS]和[SEP]
        add_special_tokens: bool = True,
        # 填充策略,默认不填充
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        # 截断策略,默认不截断
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        # 最大长度限制,可以为None
        max_length: Optional[int] = None,
        # 实体最大长度限制,可以为None
        max_entity_length: Optional[int] = None,
        # 步进值,默认为0
        stride: int = 0,
        # 填充到某个倍数,默认为None
        pad_to_multiple_of: Optional[int] = None,
        # 返回的张量类型,可以为None
        return_tensors: Optional[str] = None,
        # 是否返回token类型ID,可以为None
        return_token_type_ids: Optional[bool] = None,
        # 是否返回注意力掩码,可以为None
        return_attention_mask: Optional[bool] = None,
        # 是否返回溢出的token,默认为False
        return_overflowing_tokens: bool = False,
        # 是否返回特殊token掩码,默认为False
        return_special_tokens_mask: bool = False,
        # 是否返回长度信息,默认为False
        return_length: bool = False,
        # 是否输出详细信息,默认为 True
        verbose: bool = True,
    ) -> BatchEncoding:
        """
        Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
        adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
        manages a moving window (with user defined stride) for overflowing tokens

        Args:
            batch_ids_pairs: list of tokenized input ids or input ids pairs
            batch_entity_ids_pairs: list of entity ids or entity ids pairs
            batch_entity_token_spans_pairs: list of entity spans or entity spans pairs
            max_entity_length: The maximum length of the entity sequence.
        """

        batch_outputs = {}  # 初始化一个空字典用于存储每个批次的输出结果
        for input_ids, entity_ids, entity_token_span_pairs in zip(
            batch_ids_pairs, batch_entity_ids_pairs, batch_entity_token_spans_pairs
        ):
            first_ids, second_ids = input_ids  # 将输入的 ids 对分解为第一个和第二个序列
            first_entity_ids, second_entity_ids = entity_ids  # 将实体 ids 对分解为第一个和第二个实体序列
            first_entity_token_spans, second_entity_token_spans = entity_token_span_pairs  # 将实体 token spans 对分解为第一个和第二个实体 token spans

            # 调用 self.prepare_for_model 方法处理输入,准备模型输入
            outputs = self.prepare_for_model(
                first_ids,
                second_ids,
                entity_ids=first_entity_ids,
                pair_entity_ids=second_entity_ids,
                entity_token_spans=first_entity_token_spans,
                pair_entity_token_spans=second_entity_token_spans,
                add_special_tokens=add_special_tokens,
                padding=PaddingStrategy.DO_NOT_PAD.value,  # 在之后的批次中进行填充
                truncation=truncation_strategy.value,
                max_length=max_length,
                max_entity_length=max_entity_length,
                stride=stride,
                pad_to_multiple_of=None,  # 在之后的批次中进行填充
                return_attention_mask=False,  # 在之后的批次中返回注意力掩码
                return_token_type_ids=return_token_type_ids,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_length=return_length,
                return_tensors=None,  # 最终将整个批次转换为张量
                prepend_batch_axis=False,
                verbose=verbose,
            )

            # 将每个输出添加到 batch_outputs 字典中对应的列表中
            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                batch_outputs[key].append(value)

        # 调用 self.pad 方法对批次进行填充处理
        batch_outputs = self.pad(
            batch_outputs,
            padding=padding_strategy.value,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )

        # 将处理后的 batch_outputs 转换为 BatchEncoding 对象
        batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)

        # 返回最终的 batch_outputs
        return batch_outputs
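
A hedged usage sketch of how these batched entry points are reached through `__call__`: passing lists of texts together with character-level entity spans goes through `_batch_encode_plus` and, in turn, `_batch_prepare_for_model` (checkpoint name and spans below are illustrative):

```
from transformers import MLukeTokenizer

tokenizer = MLukeTokenizer.from_pretrained("studio-ousia/mluke-base")

texts = ["Beyoncé lives in Los Angeles.", "ISO 639-3 uses the code fas"]
entity_spans = [[(0, 7), (17, 28)], [(0, 9)]]  # character-level (start, end) spans

batch = tokenizer(texts, entity_spans=entity_spans, padding=True, return_tensors="pt")
print(batch["input_ids"].shape, batch["entity_ids"].shape)
```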

    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer.prepare_for_model
    def prepare_for_model(
        self,
        ids: List[int],
        pair_ids: Optional[List[int]] = None,
        entity_ids: Optional[List[int]] = None,
        pair_entity_ids: Optional[List[int]] = None,
        entity_token_spans: Optional[List[Tuple[int, int]]] = None,
        pair_entity_token_spans: Optional[List[Tuple[int, int]]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        max_entity_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        prepend_batch_axis: bool = False,
        **kwargs,
    ):
        # 准备输入数据以供模型使用,可以设置添加特殊标记、填充、截断等策略
        pass
    
    # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer.pad
    def pad(
        self,
        encoded_inputs: Union[
            BatchEncoding,
            List[BatchEncoding],
            Dict[str, EncodedInput],
            Dict[str, List[EncodedInput]],
            List[Dict[str, EncodedInput]],
        ],
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        max_entity_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        verbose: bool = True,
    ):
        # 对输入进行填充处理,支持不同的填充策略和最大长度设置
        pass
    
    # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer._pad
    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        max_entity_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ):
        # 内部方法:根据指定的填充策略对输入进行填充,支持最大长度和多样化填充倍数
        pass
    # 将词汇表保存到指定目录下的文件中,并返回保存的文件路径
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, str]:
        # 检查保存目录是否存在,如果不存在则记录错误并返回
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        
        # 构建输出词汇表文件路径,根据可选的前缀和文件名构造
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # 如果当前词汇表文件路径与输出路径不同且当前词汇表文件存在,则复制当前词汇表文件到输出路径
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        # 如果当前词汇表文件不存在,则将模型序列化后的内容写入输出文件
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        # 构建实体词汇表文件路径
        entity_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["entity_vocab_file"]
        )

        # 将实体词汇表以 JSON 格式写入文件
        with open(entity_vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.entity_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        # 返回保存的词汇表文件路径和实体词汇表文件路径
        return out_vocab_file, entity_vocab_file
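
A small sketch of calling `save_vocabulary` directly, with a hypothetical output directory; the file names come from `VOCAB_FILES_NAMES`:

```
import os

from transformers import MLukeTokenizer

tokenizer = MLukeTokenizer.from_pretrained("studio-ousia/mluke-base")
os.makedirs("mluke-tokenizer", exist_ok=True)  # the target directory must already exist
vocab_file, entity_vocab_file = tokenizer.save_vocabulary("mluke-tokenizer")
print(vocab_file, entity_vocab_file)  # paths of the SentencePiece model and entity_vocab.json
```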

    # 从 XLM-RoBERTa Tokenizer 类中复制的方法:构建带有特殊标记的输入序列
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. An XLM-RoBERTa sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """

        # 如果只有一个输入序列,添加起始和结束标记
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        # 如果有两个输入序列,添加起始标记、两个结束标记和分隔标记
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
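
A quick illustration of the layout described in the docstring, built by hand and assuming the usual XLM-RoBERTa special-token ids (`<s>` = 0, `</s>` = 2):

```
cls_id, sep_id = 0, 2  # assumed ids for <s> and </s>
ids_a = [10, 11, 12]
ids_b = [20, 21]

single = [cls_id] + ids_a + [sep_id]
pair = [cls_id] + ids_a + [sep_id, sep_id] + ids_b + [sep_id]

print(single)  # [0, 10, 11, 12, 2]                -> <s> A </s>
print(pair)    # [0, 10, 11, 12, 2, 2, 20, 21, 2]  -> <s> A </s></s> B </s>
```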

    # 从 XLM-RoBERTa Tokenizer 类中复制的方法:获取特殊标记的掩码
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            # 如果已经存在特殊标记,直接调用父类的方法获取特殊标记掩码
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is None:
            # 如果没有第二个序列,返回一个列表:开始标记、token_ids_0 的长度个零、结束标记
            return [1] + ([0] * len(token_ids_0)) + [1]
        # 如果有第二个序列,返回一个列表:开始标记、token_ids_0 的长度个零、结束标记、两个分隔符、token_ids_1 的长度个零、结束标记
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
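
A tiny sketch of the resulting mask for a pair of sequences of lengths 3 and 2; ones mark the positions that will hold special tokens:

```
len_a, len_b = 3, 2
mask = [1] + [0] * len_a + [1, 1] + [0] * len_b + [1]
print(mask)  # [1, 0, 0, 0, 1, 1, 0, 0, 1]
```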

    # 从 transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer.create_token_type_ids_from_sequences 复制
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does
        not make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.

        """

        sep = [self.sep_token_id]  # 分隔符的 token ID
        cls = [self.cls_token_id]  # 开始标记的 token ID

        if token_ids_1 is None:
            # 如果没有第二个序列,返回一个全零列表,长度为开始标记、token_ids_0、分隔符的总长度
            return len(cls + token_ids_0 + sep) * [0]
        # 如果有第二个序列,返回一个全零列表,长度为开始标记、token_ids_0、两个分隔符、token_ids_1、分隔符的总长度
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

.\models\mluke\__init__.py

# 引入类型检查模块
from typing import TYPE_CHECKING

# 引入自定义的异常类和延迟加载模块的工具函数
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available

# 定义一个空的导入结构字典
_import_structure = {}

# 尝试检测是否存在 SentencePiece 库,如果不存在则引发自定义的异常
try:
    if not is_sentencepiece_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # 如果依赖项不可用,则忽略异常继续执行
    pass
else:
    # 如果依赖项可用,则将 MLukeTokenizer 添加到导入结构中
    _import_structure["tokenization_mluke"] = ["MLukeTokenizer"]

# 如果类型检查开启
if TYPE_CHECKING:
    try:
        # 再次检测是否存在 SentencePiece 库,如果不存在则引发自定义的异常
        if not is_sentencepiece_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 如果依赖项不可用,则忽略异常继续执行
        pass
    else:
        # 如果依赖项可用,则从 tokenization_mluke 模块导入 MLukeTokenizer 类
        from .tokenization_mluke import MLukeTokenizer

# 如果不是类型检查模式
else:
    # 导入 sys 模块
    import sys

    # 将当前模块替换为一个延迟加载模块,使用 _LazyModule 进行延迟加载
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
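
A brief usage sketch (assuming sentencepiece is installed and the public `studio-ousia/mluke-base` checkpoint is reachable); accessing the class on the lazy module is what triggers the import of `tokenization_mluke`:

```
from transformers import MLukeTokenizer

tokenizer = MLukeTokenizer.from_pretrained("studio-ousia/mluke-base")
print(type(tokenizer).__name__)  # MLukeTokenizer
```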

.\models\mobilebert\configuration_mobilebert.py

# coding=utf-8
# 指定文件编码为 UTF-8

# Copyright 2020 The HuggingFace Team. All rights reserved.
# 版权声明,保留所有权利

# Licensed under the Apache License, Version 2.0 (the "License");
# 根据 Apache License, Version 2.0 进行许可,允许在特定条件下使用、复制、修改和分发本软件
# you may not use this file except in compliance with the License.
# 除非符合许可协议,否则不能使用此文件

# You may obtain a copy of the License at
# 您可以在以下网址获取许可协议的副本
#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 除非适用法律要求或书面同意,否则依据 "原样" 分发本软件,不提供任何明示或暗示的保证或条件

# See the License for the specific language governing permissions and
# limitations under the License.
# 查阅许可协议,了解权限和限制

""" MobileBERT model configuration"""
# MobileBERT 模型配置

from collections import OrderedDict
# 导入 OrderedDict 类,用于有序字典的支持
from typing import Mapping
# 导入 Mapping 类型提示,用于支持映射类型的提示

from ...configuration_utils import PretrainedConfig
# 从配置工具中导入预训练配置类 PretrainedConfig
from ...onnx import OnnxConfig
# 从 onnx 模块导入 OnnxConfig
from ...utils import logging
# 从 utils 中导入 logging 模块

logger = logging.get_logger(__name__)
# 获取当前模块的日志记录器

MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/mobilebert-uncased": "https://huggingface.co/google/mobilebert-uncased/resolve/main/config.json"
}
# 预训练模型配置存档映射,提供模型名称到预训练配置文件的 URL 映射

class MobileBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MobileBertModel`] or a [`TFMobileBertModel`]. It
    is used to instantiate a MobileBERT model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the MobileBERT
    [google/mobilebert-uncased](https://huggingface.co/google/mobilebert-uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Examples:

    ```
    >>> from transformers import MobileBertConfig, MobileBertModel

    >>> # Initializing a MobileBERT configuration
    >>> configuration = MobileBertConfig()

    >>> # Initializing a model (with random weights) from the configuration above
    >>> model = MobileBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```

    Attributes: pretrained_config_archive_map (Dict[str, str]): A dictionary containing all the available pre-trained
    checkpoints.
    """
    # MobileBERT 配置类,用于存储 [`MobileBertModel`] 或 [`TFMobileBertModel`] 的配置。
    # 根据指定参数实例化 MobileBERT 模型,定义模型架构。
    # 使用默认配置实例化将产生与 MobileBERT [google/mobilebert-uncased](https://huggingface.co/google/mobilebert-uncased) 架构类似的配置。

    # 配置对象继承自 [`PretrainedConfig`],可用于控制模型输出。详细信息请阅读 [`PretrainedConfig`] 的文档。

    pretrained_config_archive_map = MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
    # 预训练模型配置存档映射,存储所有可用的预训练检查点

    model_type = "mobilebert"
    # 模型类型设定为 "mobilebert"
    # Initialization method for the MobileBERT configuration, defining the model's hyperparameters
    def __init__(
        self,
        vocab_size=30522,  # 词汇表大小,默认为30522
        hidden_size=512,  # 隐藏层大小,默认为512
        num_hidden_layers=24,  # 隐藏层的数量,默认为24层
        num_attention_heads=4,  # 注意力头的数量,默认为4个
        intermediate_size=512,  # 中间层大小,默认为512
        hidden_act="relu",  # 隐藏层激活函数,默认为ReLU
        hidden_dropout_prob=0.0,  # 隐藏层的dropout概率,默认为0.0(不使用)
        attention_probs_dropout_prob=0.1,  # 注意力机制的dropout概率,默认为0.1
        max_position_embeddings=512,  # 最大位置嵌入大小,默认为512
        type_vocab_size=2,  # 类型词汇表大小,默认为2
        initializer_range=0.02,  # 初始化范围,默认为0.02
        layer_norm_eps=1e-12,  # 层归一化的epsilon值,默认为1e-12
        pad_token_id=0,  # 填充token的ID,默认为0
        embedding_size=128,  # 嵌入大小,默认为128
        trigram_input=True,  # 是否使用trigram输入,默认为True
        use_bottleneck=True,  # 是否使用瓶颈结构,默认为True
        intra_bottleneck_size=128,  # 瓶颈内部大小,默认为128
        use_bottleneck_attention=False,  # 是否使用瓶颈的注意力,默认为False
        key_query_shared_bottleneck=True,  # 键和查询是否共享瓶颈,默认为True
        num_feedforward_networks=4,  # 前馈网络的数量,默认为4
        normalization_type="no_norm",  # 归一化类型,默认为"no_norm"
        classifier_activation=True,  # 分类器是否激活,默认为True
        classifier_dropout=None,  # 分类器的dropout概率,默认为None
        **kwargs,
    ):
        # 调用父类的初始化方法,传递填充token的ID和其他关键字参数
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # 初始化模型的各种参数
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.embedding_size = embedding_size
        self.trigram_input = trigram_input
        self.use_bottleneck = use_bottleneck
        self.intra_bottleneck_size = intra_bottleneck_size
        self.use_bottleneck_attention = use_bottleneck_attention
        self.key_query_shared_bottleneck = key_query_shared_bottleneck
        self.num_feedforward_networks = num_feedforward_networks
        self.normalization_type = normalization_type
        self.classifier_activation = classifier_activation

        # 根据是否使用瓶颈结构来确定真实的隐藏层大小
        if self.use_bottleneck:
            self.true_hidden_size = intra_bottleneck_size
        else:
            self.true_hidden_size = hidden_size

        # 分类器的dropout概率
        self.classifier_dropout = classifier_dropout
# 从 transformers.models.bert.configuration_bert.BertOnnxConfig 复制的代码,创建了 MobileBertOnnxConfig 类,用于配置 MobileBert 模型的 ONNX 格式设置。
class MobileBertOnnxConfig(OnnxConfig):
    # 定义 inputs 属性,返回一个映射,其中键为字符串,值为映射,映射的键为整数,值为字符串。
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # 如果任务是多选项问题 ("multiple-choice"),则定义动态轴为 {0: "batch", 1: "choice", 2: "sequence"}。
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        # 否则,定义动态轴为 {0: "batch", 1: "sequence"}。
        else:
            dynamic_axis = {0: "batch", 1: "sequence"}
        # 返回一个有序字典,包含输入名称到动态轴的映射。
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),  # 输入名称 "input_ids" 映射到 dynamic_axis 中的值。
                ("attention_mask", dynamic_axis),  # 输入名称 "attention_mask" 映射到 dynamic_axis 中的值。
                ("token_type_ids", dynamic_axis),  # 输入名称 "token_type_ids" 映射到 dynamic_axis 中的值。
            ]
        )
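
A quick check of the dynamic-axis mapping defined above, shown as a sketch for the default task:

```
from transformers import MobileBertConfig
from transformers.models.mobilebert.configuration_mobilebert import MobileBertOnnxConfig

onnx_config = MobileBertOnnxConfig(MobileBertConfig(), task="default")
print(onnx_config.inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
#              ('attention_mask', {0: 'batch', 1: 'sequence'}),
#              ('token_type_ids', {0: 'batch', 1: 'sequence'})])
```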

.\models\mobilebert\convert_mobilebert_original_tf_checkpoint_to_pytorch.py

# 导入必要的模块和库
import argparse  # 用于解析命令行参数

import torch  # 导入PyTorch库

# 从transformers库中导入MobileBertConfig、MobileBertForPreTraining和load_tf_weights_in_mobilebert函数
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert

# 从transformers.utils中导入logging模块
from transformers.utils import logging

# 设置日志输出级别为info
logging.set_verbosity_info()

# 定义函数:将TensorFlow的checkpoint转换为PyTorch的模型
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path):
    # 从配置文件中加载MobileBERT模型的配置
    config = MobileBertConfig.from_json_file(mobilebert_config_file)
    # 打印配置信息
    print(f"Building PyTorch model from configuration: {config}")
    # 根据配置创建MobileBERT的预训练模型
    model = MobileBertForPreTraining(config)
    # 加载TensorFlow的checkpoint中的权重到PyTorch模型中
    model = load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path)
    # 打印保存PyTorch模型的路径
    print(f"Save PyTorch model to {pytorch_dump_path}")
    # 将PyTorch模型的状态字典保存到指定路径
    torch.save(model.state_dict(), pytorch_dump_path)


# 主程序入口
if __name__ == "__main__":
    # 创建命令行参数解析器
    parser = argparse.ArgumentParser()
    # 添加必选参数:TensorFlow的checkpoint路径
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    # 添加必选参数:MobileBERT模型配置文件的路径
    parser.add_argument(
        "--mobilebert_config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained MobileBERT model. \n"
            "This specifies the model architecture."
        ),
    )
    # 添加必选参数:输出的PyTorch模型路径
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    # 解析命令行参数
    args = parser.parse_args()
    # 调用转换函数,将TensorFlow的checkpoint转换为PyTorch模型
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.mobilebert_config_file, args.pytorch_dump_path)
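
The script is normally invoked from the command line with the three required arguments; as a sketch (with hypothetical paths, and TensorFlow installed), the conversion function can also be called directly:

```
from transformers.models.mobilebert.convert_mobilebert_original_tf_checkpoint_to_pytorch import (
    convert_tf_checkpoint_to_pytorch,
)

convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="path/to/mobilebert/tf_checkpoint",    # hypothetical
    mobilebert_config_file="path/to/mobilebert_config.json",  # hypothetical
    pytorch_dump_path="path/to/pytorch_model.bin",            # hypothetical
)
```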

.\models\mobilebert\modeling_mobilebert.py

# 导入必要的库和模块
import math  # 导入数学库,用于数学运算
import os  # 导入操作系统库,用于操作系统相关功能
import warnings  # 导入警告模块,用于处理警告信息
from dataclasses import dataclass  # 导入 dataclass 模块,用于创建数据类
from typing import Optional, Tuple, Union  # 导入类型提示相关模块

import torch  # 导入 PyTorch 模块
from torch import nn  # 导入神经网络模块
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss  # 导入损失函数

# 导入相关的模型输出类和工具函数
from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_mobilebert import MobileBertConfig  # 导入 MobileBert 配置类

logger = logging.get_logger(__name__)  # 获取日志记录器

_CHECKPOINT_FOR_DOC = "google/mobilebert-uncased"  # 预训练模型的文档说明
_CONFIG_FOR_DOC = "MobileBertConfig"  # MobileBert 配置文档说明

# TokenClassification 文档字符串和期望输出
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "mrm8488/mobilebert-finetuned-ner"
_TOKEN_CLASS_EXPECTED_OUTPUT = "['I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC']"
_TOKEN_CLASS_EXPECTED_LOSS = 0.03

# QuestionAnswering 文档字符串和期望输出
_CHECKPOINT_FOR_QA = "csarron/mobilebert-uncased-squad-v2"
_QA_EXPECTED_OUTPUT = "'a nice puppet'"
_QA_EXPECTED_LOSS = 3.98
_QA_TARGET_START_INDEX = 12
_QA_TARGET_END_INDEX = 13

# SequenceClassification 文档字符串和期望输出
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "lordtt13/emo-mobilebert"
_SEQ_CLASS_EXPECTED_OUTPUT = "'others'"
_SEQ_CLASS_EXPECTED_LOSS = "4.72"

MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["google/mobilebert-uncased"]  # 预训练模型的存档列表


def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    # Loads the TensorFlow weights of MobileBERT and copies them into the PyTorch model.
    #
    # Args:
    #     model: the MobileBERT model instance (e.g. MobileBertForPreTraining) to load weights into.
    #     config (MobileBertConfig): the MobileBERT configuration object.
    #     tf_checkpoint_path (str): path to the TensorFlow checkpoint.
    #
    # Returns:
    #     The same model, with the TensorFlow weights loaded into its parameters.
    #
    # Raises:
    #     ImportError: if TensorFlow cannot be imported.

    try:
        import re  # 导入正则表达式模块
        import numpy as np  # 导入 NumPy 库
        import tensorflow as tf  # 导入 TensorFlow 库
    except ImportError:
        # 如果导入失败,记录错误信息并抛出异常
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    tf_path = os.path.abspath(tf_checkpoint_path)
    # 获取 TensorFlow 检查点文件的绝对路径

    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # 记录日志信息,指示正在从 TensorFlow 检查点文件 {tf_path} 转换

    # Load weights from TF model
    # 从 TensorFlow 模型中加载权重
    init_vars = tf.train.list_variables(tf_path)
    # 获取 TensorFlow 模型中所有变量列表

    names = []
    arrays = []
    for name, shape in init_vars:
        # 遍历每个变量名和其形状
        logger.info(f"Loading TF weight {name} with shape {shape}")
        # 记录日志信息,指示正在加载 TensorFlow 权重 {name},形状为 {shape}
        array = tf.train.load_variable(tf_path, name)
        # 加载 TensorFlow 模型中的变量数据
        names.append(name)
        arrays.append(array)
    # 遍历names和arrays,每次迭代处理一个名字和对应的数组
    for name, array in zip(names, arrays):
        # 替换name中的特定字符串以简化模型参数名
        name = name.replace("ffn_layer", "ffn")
        name = name.replace("FakeLayerNorm", "LayerNorm")
        name = name.replace("extra_output_weights", "dense/kernel")
        name = name.replace("bert", "mobilebert")
        # 将name按"/"分割成列表
        name = name.split("/")
        
        # 检查name中是否包含不需要的变量名,若包含则跳过此次迭代
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        
        pointer = model
        
        # 遍历name中的每个部分,逐级访问model的属性
        for m_name in name:
            # 如果m_name匹配形如"A-Za-z+_\d+"的字符串,则按下划线分割为多个部分
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            
            # 根据scope_names的第一个部分选择指针位置
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    # 如果属性不存在,则记录日志并跳过当前迭代
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            
            # 如果scope_names有多个部分,则进一步访问指定位置的属性
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        
        # 如果m_name的结尾是"_embeddings",则访问pointer的"weight"属性
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # 如果m_name为"kernel",则对array进行转置操作
            array = np.transpose(array)
        
        # 检查pointer的形状和array的形状是否相匹配,若不匹配则抛出异常
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        
        # 记录初始化操作的日志信息
        logger.info(f"Initialize PyTorch weight {name}")
        
        # 将array转换为torch.Tensor,并赋值给pointer的data属性
        pointer.data = torch.from_numpy(array)
    
    # 返回处理后的模型
    return model

class MobileBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        # Initialize the MobileBertEmbeddings instance
        super().__init__()
        # Whether trigram input is used, plus the embedding and hidden sizes
        self.trigram_input = config.trigram_input
        self.embedding_size = config.embedding_size
        self.hidden_size = config.hidden_size

        # Word, position and token type embedding tables
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # With trigram input the embedded width is tripled, otherwise left unchanged
        embed_dim_multiplier = 3 if self.trigram_input else 1
        embedded_input_size = self.embedding_size * embed_dim_multiplier

        # Projection that maps the (possibly concatenated) embeddings to the hidden size
        self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size)

        # Normalization and dropout layers
        self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Register the position id buffer (not persisted in the state dict)
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        # 如果输入的 input_ids 不为 None,则获取其形状作为 input_shape
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            # 否则,获取 inputs_embeds 的形状去掉最后一个维度作为 input_shape
            input_shape = inputs_embeds.size()[:-1]

        # 获取序列长度,即 input_shape 的第二个维度
        seq_length = input_shape[1]

        # 如果未提供 position_ids,则使用预设的 position_ids 切片,长度为 seq_length
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        # 如果未提供 token_type_ids,则创建一个与 input_shape 相同的零张量作为 token_type_ids
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        # 如果未提供 inputs_embeds,则使用 input_ids 通过 word_embeddings 获取其嵌入表示
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # 如果启用了 trigram_input
        if self.trigram_input:
            # 根据 MobileBERT 论文中的描述,对输入的嵌入进行扩展处理
            inputs_embeds = torch.cat(
                [
                    nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0),
                    inputs_embeds,
                    nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0),
                ],
                dim=2,
            )

        # 如果启用了 trigram_input 或者嵌入维度不等于隐藏层维度
        if self.trigram_input or self.embedding_size != self.hidden_size:
            # 对输入的嵌入进行额外的变换处理
            inputs_embeds = self.embedding_transformation(inputs_embeds)

        # 添加位置嵌入和 token 类型嵌入到输入嵌入中
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + position_embeddings + token_type_embeddings

        # 对嵌入结果进行层归一化
        embeddings = self.LayerNorm(embeddings)

        # 对归一化后的嵌入结果进行 Dropout 处理
        embeddings = self.dropout(embeddings)

        # 返回处理后的嵌入结果
        return embeddings
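
A toy illustration of the trigram concatenation used above: for every position, the embeddings of the previous, current and next token are stacked along the feature dimension (a sketch with batch=1, seq_len=4, embedding_size=2):

```
import torch
import torch.nn.functional as F

x = torch.arange(8.0).view(1, 4, 2)  # (batch, seq_len, embedding_size)
trigram = torch.cat(
    [
        F.pad(x[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0),    # next-token embeddings, shifted left
        x,                                                  # current-token embeddings
        F.pad(x[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0),    # previous-token embeddings, shifted right
    ],
    dim=2,
)
print(trigram.shape)  # torch.Size([1, 4, 6]) -> the embedding size is tripled
```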
class MobileBertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads  # 设置注意力头的数量
        self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads)  # 计算每个注意力头的大小
        self.all_head_size = self.num_attention_heads * self.attention_head_size  # 计算所有注意力头的总大小

        self.query = nn.Linear(config.true_hidden_size, self.all_head_size)  # 创建查询线性层
        self.key = nn.Linear(config.true_hidden_size, self.all_head_size)  # 创建键线性层
        self.value = nn.Linear(
            config.true_hidden_size if config.use_bottleneck_attention else config.hidden_size, self.all_head_size
        )  # 创建值线性层,根据是否使用瓶颈注意力选择隐藏大小

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)  # 创建Dropout层用于注意力概率

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)  # 调整张量形状以适应多头注意力计算
        return x.permute(0, 2, 1, 3)  # 转置张量以便进行注意力得分计算

    def forward(
        self,
        query_tensor: torch.Tensor,
        key_tensor: torch.Tensor,
        value_tensor: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(query_tensor)  # 计算混合查询层
        mixed_key_layer = self.key(key_tensor)  # 计算混合键层
        mixed_value_layer = self.value(value_tensor)  # 计算混合值层

        query_layer = self.transpose_for_scores(mixed_query_layer)  # 转置并准备查询张量
        key_layer = self.transpose_for_scores(mixed_key_layer)  # 转置并准备键张量
        value_layer = self.transpose_for_scores(mixed_value_layer)  # 转置并准备值张量

        # 计算原始的注意力分数,即查询与键的点积
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)  # 缩放注意力分数的平方根

        if attention_mask is not None:
            # 应用预计算的注意力掩码(适用于BertModel的所有层)
            attention_scores = attention_scores + attention_mask

        # 将注意力分数规范化为概率
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # 通过Dropout层来实现随机遮盖整个待注意的标记,这在原始Transformer论文中也有提到
        attention_probs = self.dropout(attention_probs)

        # 如果需要,掩盖特定的注意力头
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)  # 计算上下文张量
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()  # 调整上下文张量的维度顺序

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)  # 调整上下文张量的形状
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)  # 准备输出结果

        return outputs
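
A shape-only sketch of the multi-head attention computed above, assuming 4 heads with head size 32 (so true_hidden_size = 128), batch size 2 and sequence length 8:

```
import math

import torch

batch, seq_len, num_heads, head_size = 2, 8, 4, 32
q = torch.randn(batch, num_heads, seq_len, head_size)  # layout after transpose_for_scores
k = torch.randn(batch, num_heads, seq_len, head_size)
v = torch.randn(batch, num_heads, seq_len, head_size)

scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_size)  # (2, 4, 8, 8)
probs = torch.softmax(scores, dim=-1)
context = torch.matmul(probs, v)                                      # (2, 4, 8, 32)
context = context.permute(0, 2, 1, 3).reshape(batch, seq_len, num_heads * head_size)
print(context.shape)  # torch.Size([2, 8, 128])
```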

# MobileBertSelfOutput: projects the self-attention output and adds the residual connection
class MobileBertSelfOutput(nn.Module):
    # Initialization method that takes a configuration object
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 根据配置决定是否使用瓶颈层
        self.use_bottleneck = config.use_bottleneck
        # 创建一个线性层,输入和输出大小都是 true_hidden_size
        self.dense = nn.Linear(config.true_hidden_size, config.true_hidden_size)
        # 根据配置选择合适的归一化方法,并初始化
        self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps)
        # 如果不使用瓶颈层,则初始化一个丢弃层
        if not self.use_bottleneck:
            self.dropout = nn.Dropout(config.hidden_dropout_prob)

    # 前向传播方法,接收隐藏状态和残差张量作为输入,返回一个张量
    def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
        # 将隐藏状态输入到线性层中进行计算
        layer_outputs = self.dense(hidden_states)
        # 如果没有使用瓶颈层,则对输出进行丢弃操作
        if not self.use_bottleneck:
            layer_outputs = self.dropout(layer_outputs)
        # 将丢弃后的输出与残差张量相加,并通过 LayerNorm 进行归一化处理
        layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
        # 返回处理后的输出张量
        return layer_outputs
# MobileBertAttention 类定义
class MobileBertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化 self 层,即 MobileBertSelfAttention 类的实例化
        self.self = MobileBertSelfAttention(config)
        # 初始化 output 层,即 MobileBertSelfOutput 类的实例化
        self.output = MobileBertSelfOutput(config)
        # 初始化一个空集合,用于存储已剪枝的注意力头的索引
        self.pruned_heads = set()

    # 剪枝注意力头的方法
    def prune_heads(self, heads):
        # 如果 heads 列表为空,则直接返回
        if len(heads) == 0:
            return
        # 调用 find_pruneable_heads_and_indices 函数找到可剪枝的注意力头及其索引
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # 剪枝线性层
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # 更新超参数并存储剪枝后的头部索引
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    # 前向传播方法
    def forward(
        self,
        query_tensor: torch.Tensor,
        key_tensor: torch.Tensor,
        value_tensor: torch.Tensor,
        layer_input: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
    ) -> Tuple[torch.Tensor]:
        # 使用 self 层进行自注意力计算
        self_outputs = self.self(
            query_tensor,
            key_tensor,
            value_tensor,
            attention_mask,
            head_mask,
            output_attentions,
        )
        # 将 self 层的输出经过 output 层的线性投影并添加残差连接
        attention_output = self.output(self_outputs[0], layer_input)
        # If attention weights were requested, append them to the outputs
        outputs = (attention_output,) + self_outputs[1:]
        return outputs


# MobileBertIntermediate 类定义
class MobileBertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 线性层,用于转换隐藏状态的维度
        self.dense = nn.Linear(config.true_hidden_size, config.intermediate_size)
        # 判断 config.hidden_act 是否是字符串,选择对应的激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    # 前向传播方法
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 经过线性层进行维度转换
        hidden_states = self.dense(hidden_states)
        # 经过激活函数
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


# OutputBottleneck 类定义
class OutputBottleneck(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 线性层,用于输出的维度转换
        self.dense = nn.Linear(config.true_hidden_size, config.hidden_size)
        # 归一化层,根据 config.normalization_type 选择对应的归一化函数
        self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size, eps=config.layer_norm_eps)
        # Dropout 层,用于防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
    # 定义前向传播方法,接受隐藏状态和残差张量作为输入,返回处理后的张量
    def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
        # 使用全连接层对隐藏状态进行线性变换
        layer_outputs = self.dense(hidden_states)
        # 对线性变换后的结果应用丢弃部分神经元的dropout操作
        layer_outputs = self.dropout(layer_outputs)
        # 将dropout后的结果与残差张量相加,并通过层归一化处理
        layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
        # 返回处理后的层输出张量
        return layer_outputs
# 定义 MobileBertOutput 类,继承自 nn.Module,用于处理 MobileBERT 模型的输出
class MobileBertOutput(nn.Module):
    # 初始化方法,接收一个配置对象 config
    def __init__(self, config):
        super().__init__()
        # 根据配置确定是否使用瓶颈层
        self.use_bottleneck = config.use_bottleneck
        # 创建一个线性层,将 intermediate_size 映射到 true_hidden_size
        self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size)
        # 根据配置选择规范化层类型并初始化
        self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size)
        # 如果不使用瓶颈层,则创建一个丢弃层,用于随机丢弃节点以防止过拟合
        if not self.use_bottleneck:
            self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 如果使用瓶颈层,则创建一个输出瓶颈对象
        else:
            self.bottleneck = OutputBottleneck(config)

    # 前向传播方法,接收三个张量作为输入,返回一个张量
    def forward(
        self, intermediate_states: torch.Tensor, residual_tensor_1: torch.Tensor, residual_tensor_2: torch.Tensor
    ) -> torch.Tensor:
        # 将 intermediate_states 输入到线性层中得到 layer_output
        layer_output = self.dense(intermediate_states)
        # 如果不使用瓶颈层,则对 layer_output 进行丢弃操作
        if not self.use_bottleneck:
            layer_output = self.dropout(layer_output)
            # 将丢弃后的输出与 residual_tensor_1 相加,并通过规范化层 LayerNorm 处理
            layer_output = self.LayerNorm(layer_output + residual_tensor_1)
        # 如果使用瓶颈层,则直接使用瓶颈层处理 layer_output 和 residual_tensor_2
        else:
            layer_output = self.LayerNorm(layer_output + residual_tensor_1)
            layer_output = self.bottleneck(layer_output, residual_tensor_2)
        # 返回处理后的输出张量
        return layer_output


# 定义 BottleneckLayer 类,继承自 nn.Module,用于 MobileBERT 中的瓶颈层处理
class BottleneckLayer(nn.Module):
    # 初始化方法,接收一个配置对象 config
    def __init__(self, config):
        super().__init__()
        # 创建一个线性层,将 hidden_size 映射到 intra_bottleneck_size
        self.dense = nn.Linear(config.hidden_size, config.intra_bottleneck_size)
        # 根据配置选择规范化层类型并初始化
        self.LayerNorm = NORM2FN[config.normalization_type](config.intra_bottleneck_size, eps=config.layer_norm_eps)

    # 前向传播方法,接收一个张量作为输入,返回一个张量
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 将 hidden_states 输入到线性层中得到 layer_input
        layer_input = self.dense(hidden_states)
        # 通过规范化层 LayerNorm 处理 layer_input
        layer_input = self.LayerNorm(layer_input)
        # 返回处理后的输出张量
        return layer_input


# 定义 Bottleneck 类,继承自 nn.Module,用于 MobileBERT 中的瓶颈处理
class Bottleneck(nn.Module):
    # 初始化方法,接收一个配置对象 config
    def __init__(self, config):
        super().__init__()
        # 根据配置确定是否共享键值查询瓶颈
        self.key_query_shared_bottleneck = config.key_query_shared_bottleneck
        # 根据配置确定是否使用瓶颈注意力
        self.use_bottleneck_attention = config.use_bottleneck_attention
        # 创建一个输入瓶颈层对象
        self.input = BottleneckLayer(config)
        # 如果共享键值查询瓶颈,则创建一个瓶颈注意力层对象
        if self.key_query_shared_bottleneck:
            self.attention = BottleneckLayer(config)
    # 定义一个方法 `forward`,接收一个名为 `hidden_states` 的张量参数,并返回一个元组
    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
        # 该方法可以返回三种不同的元组值。这些不同的值利用了瓶颈层,
        # 这些线性层用于将隐藏状态投影到低维向量,从而降低内存使用。
        # 这些线性层在训练过程中学习其权重。

        # 如果 `config.use_bottleneck_attention` 为真,则返回经过瓶颈层处理的四个值,
        # 分别用于注意力层中的键、查询、值以及“层输入”。
        # 这个瓶颈层用于投影隐藏状态。层输入将作为注意力自我输出的残差张量使用,
        # 在计算注意力分数后添加到输出中。

        # 如果不使用 `config.use_bottleneck_attention`,但使用了 `config.key_query_shared_bottleneck`,
        # 则返回四个值,其中三个值经过了瓶颈层处理:查询和键通过相同的瓶颈层处理,
        # 而残差层则通过另一个瓶颈层处理,将应用于注意力自我输出。

        # 最后,如果都不满足,则查询、键和值的值为未经过瓶颈处理的隐藏状态,
        # 而残差层为经过瓶颈处理的隐藏状态。

        # 使用 `self.input` 方法对隐藏状态进行瓶颈处理
        bottlenecked_hidden_states = self.input(hidden_states)
        # 根据条件返回相应的元组值
        if self.use_bottleneck_attention:
            return (bottlenecked_hidden_states,) * 4
        elif self.key_query_shared_bottleneck:
            shared_attention_input = self.attention(hidden_states)
            return (shared_attention_input, shared_attention_input, hidden_states, bottlenecked_hidden_states)
        else:
            return (hidden_states, hidden_states, hidden_states, bottlenecked_hidden_states)
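
A hedged sketch of the default branch above (use_bottleneck_attention=False, key_query_shared_bottleneck=True): query and key go through the shared bottleneck, the value keeps the full hidden size, and the layer input is the bottlenecked residual:

```
import torch

from transformers import MobileBertConfig
from transformers.models.mobilebert.modeling_mobilebert import Bottleneck

config = MobileBertConfig(hidden_size=512, intra_bottleneck_size=128)
bottleneck = Bottleneck(config)

hidden = torch.randn(1, 4, config.hidden_size)
query, key, value, layer_input = bottleneck(hidden)
print(query.shape, value.shape, layer_input.shape)
# torch.Size([1, 4, 128]) torch.Size([1, 4, 512]) torch.Size([1, 4, 128])
```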
# MobileBertLayer 类定义,继承自 nn.Module
class MobileBertLayer(nn.Module):
    def __init__(self, config):
        # 调用父类构造函数进行初始化
        super().__init__()
        
        # 根据配置文件初始化各种属性
        self.use_bottleneck = config.use_bottleneck
        self.num_feedforward_networks = config.num_feedforward_networks
        
        # 创建 MobileBertAttention 对象,用于处理注意力机制
        self.attention = MobileBertAttention(config)
        
        # 创建 MobileBertIntermediate 对象,用于处理中间层输出
        self.intermediate = MobileBertIntermediate(config)
        
        # 创建 MobileBertOutput 对象,用于处理最终输出
        self.output = MobileBertOutput(config)
        
        # 如果配置中设置了使用瓶颈层,则创建 Bottleneck 对象
        if self.use_bottleneck:
            self.bottleneck = Bottleneck(config)
        
        # 如果配置中指定了多个前馈网络,则创建对应数量的 FFNLayer 对象组成列表
        if config.num_feedforward_networks > 1:
            self.ffn = nn.ModuleList([FFNLayer(config) for _ in range(config.num_feedforward_networks - 1)])

    # 前向传播函数定义,接收输入 hidden_states 和可选的各种掩码、标志
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
    ) -> Tuple[torch.Tensor]:
            # 如果使用瓶颈模块
            if self.use_bottleneck:
                # 使用瓶颈模块处理隐藏状态,返回查询张量、键张量、值张量和层输入
                query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
            else:
                # 否则直接复制隐藏状态到查询张量、键张量、值张量、层输入
                query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4
    
            # 调用 self.attention 方法,处理查询张量、键张量、值张量、层输入、注意力掩码、头部掩码等参数,获取自注意力输出
            self_attention_outputs = self.attention(
                query_tensor,
                key_tensor,
                value_tensor,
                layer_input,
                attention_mask,
                head_mask,
                output_attentions=output_attentions,
            )
            # 获取自注意力输出的第一个元素作为 attention_output
            attention_output = self_attention_outputs[0]
            # 创建一个元组 s,包含 attention_output
            s = (attention_output,)
            # 如果输出注意力权重,则将其添加到 outputs 中
            outputs = self_attention_outputs[1:]  # 如果我们输出注意力权重,添加自注意力权重
    
            # 如果存在多个前馈网络
            if self.num_feedforward_networks != 1:
                # 对于每个前馈网络模块 ffn_module 在 self.ffn 中的枚举 i
                for i, ffn_module in enumerate(self.ffn):
                    # 使用前馈网络模块处理 attention_output
                    attention_output = ffn_module(attention_output)
                    # 将处理后的 attention_output 添加到元组 s 中
                    s += (attention_output,)
    
            # 使用 intermediate 方法处理 attention_output,得到 intermediate_output
            intermediate_output = self.intermediate(attention_output)
            # 使用 output 方法处理 intermediate_output、attention_output 和 hidden_states,得到 layer_output
            layer_output = self.output(intermediate_output, attention_output, hidden_states)
            # 构建 outputs 元组,包含 layer_output 和之前的 outputs、固定的一些张量数据和 s 元组
            outputs = (
                (layer_output,)
                + outputs
                + (
                    torch.tensor(1000),  # 固定值 1000
                    query_tensor,  # 查询张量
                    key_tensor,  # 键张量
                    value_tensor,  # 值张量
                    layer_input,  # 层输入
                    attention_output,  # 自注意力输出
                    intermediate_output,  # intermediate 输出
                )
                + s
            )
            # 返回 outputs 结果
            return outputs
class MobileBertEncoder(nn.Module):
    # MobileBERT 编码器模型,继承自 nn.Module 类
    def __init__(self, config):
        super().__init__()
        # 初始化 MobileBERT 编码器的层列表,每层由 MobileBertLayer 组成
        self.layer = nn.ModuleList([MobileBertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutput]:
        # 如果需要输出隐藏状态,则初始化空元组 all_hidden_states
        all_hidden_states = () if output_hidden_states else None
        # 如果需要输出注意力权重,则初始化空元组 all_attentions
        all_attentions = () if output_attentions else None
        # 遍历每一层 MobileBERT 编码器
        for i, layer_module in enumerate(self.layer):
            # 如果需要输出隐藏状态,将当前隐藏状态添加到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 调用当前层的 forward 方法,计算层的输出
            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                head_mask[i],
                output_attentions,
            )
            # 更新隐藏状态为当前层的输出的第一个元素
            hidden_states = layer_outputs[0]

            # 如果需要输出注意力权重,将当前层的注意力权重添加到 all_attentions 中
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # 添加最后一层的隐藏状态到 all_hidden_states 中
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果不返回字典形式的输出,则根据需要返回不同的元组或 BaseModelOutput 对象
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class MobileBertPooler(nn.Module):
    # MobileBERT 池化层模型,继承自 nn.Module 类
    def __init__(self, config):
        super().__init__()
        # 根据配置文件选择是否激活分类器激活函数
        self.do_activate = config.classifier_activation
        if self.do_activate:
            # 如果需要激活,初始化一个线性层 dense
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 池化模型通过简单地选择第一个 token 对应的隐藏状态来实现
        first_token_tensor = hidden_states[:, 0]
        # 如果不需要激活,则直接返回第一个 token 的隐藏状态
        if not self.do_activate:
            return first_token_tensor
        else:
            # 否则,通过线性层和 tanh 激活函数计算池化输出
            pooled_output = self.dense(first_token_tensor)
            pooled_output = torch.tanh(pooled_output)
            return pooled_output
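

# === Usage sketch (standalone; not part of modeling_mobilebert.py) =================
# The pooler simply takes the hidden state of the first ([CLS]) token and, when
# config.classifier_activation is True, applies an extra Linear + tanh. A tiny shape
# sketch (all sizes here are illustrative assumptions):
import torch

batch_size, seq_len, hidden_size = 2, 8, 512
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

first_token = hidden_states[:, 0]              # (batch_size, hidden_size), the [CLS] position
dense = torch.nn.Linear(hidden_size, hidden_size)
pooled = torch.tanh(dense(first_token))        # what MobileBertPooler returns when do_activate=True
print(pooled.shape)                            # torch.Size([2, 512])
# ====================================================================================
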


class MobileBertPredictionHeadTransform(nn.Module):
    # MobileBERT 预测头转换层模型,继承自 nn.Module 类
    def __init__(self, config):
        super().__init__()
        # 初始化线性层 dense,用于特征变换
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 根据配置选择激活函数,支持字符串形式和函数形式
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        # 初始化 LayerNorm 层,用于归一化
        self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, eps=config.layer_norm_eps)
    # 定义一个方法 `forward`,用于前向传播计算
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 使用全连接层 `self.dense` 对输入的隐藏状态进行线性变换
        hidden_states = self.dense(hidden_states)
        # 对变换后的隐藏状态应用激活函数 `self.transform_act_fn`
        hidden_states = self.transform_act_fn(hidden_states)
        # 对激活后的隐藏状态进行层归一化 `self.LayerNorm`
        hidden_states = self.LayerNorm(hidden_states)
        # 返回处理后的隐藏状态作为输出
        return hidden_states
# MobileBertLMPredictionHead 类定义,继承自 nn.Module
class MobileBertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 使用 MobileBertPredictionHeadTransform 类对隐藏状态进行转换
        self.transform = MobileBertPredictionHeadTransform(config)
        # 创建一个线性层,用于预测输出权重,输入维度为词汇表大小,输出维度为隐藏大小减去嵌入大小,无偏置
        self.dense = nn.Linear(config.vocab_size, config.hidden_size - config.embedding_size, bias=False)
        # Decoder linear layer projecting from embedding_size to vocab_size, without bias
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False)
        # 创建一个参数化的偏置向量,维度为词汇表大小
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # 确保输出偏置与预测层的偏置相连接,以便与 `resize_token_embeddings` 方法正确调整大小
        self.decoder.bias = self.bias

    # 前向传播函数,接受隐藏状态输入,返回预测分数张量
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 对隐藏状态进行转换
        hidden_states = self.transform(hidden_states)
        # 计算预测分数
        hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
        # 加上预测偏置
        hidden_states += self.decoder.bias
        return hidden_states
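

# === Shape check (standalone sketch; not part of modeling_mobilebert.py) ============
# The matmul above builds a (hidden_size, vocab_size) projection out of the tied
# decoder weight (vocab_size, embedding_size) plus the extra, untied dense weight
# (hidden_size - embedding_size, vocab_size). With MobileBERT's default sizes
# (hidden_size=512, embedding_size=128, vocab_size=30522, assumed here only for
# illustration) the shapes line up as follows:
import torch

hidden_size, embedding_size, vocab_size = 512, 128, 30522

decoder_weight = torch.randn(vocab_size, embedding_size)              # tied to the word embeddings
dense_weight = torch.randn(hidden_size - embedding_size, vocab_size)  # extra output rows

projection = torch.cat([decoder_weight.t(), dense_weight], dim=0)     # (hidden_size, vocab_size)
hidden_states = torch.randn(2, 8, hidden_size)
logits = hidden_states.matmul(projection)                             # (2, 8, vocab_size)
print(projection.shape, logits.shape)  # torch.Size([512, 30522]) torch.Size([2, 8, 30522])
# ====================================================================================
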


# MobileBertOnlyMLMHead 类定义,继承自 nn.Module
class MobileBertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 创建 MobileBertLMPredictionHead 实例作为预测
        self.predictions = MobileBertLMPredictionHead(config)

    # 前向传播函数,接受序列输出输入,返回预测分数
    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        # 调用预测头进行预测
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


# MobileBertPreTrainingHeads 类定义,继承自 nn.Module
class MobileBertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 创建 MobileBertLMPredictionHead 实例作为预测
        self.predictions = MobileBertLMPredictionHead(config)
        # 创建线性层,用于序列关系分类,输入维度为隐藏大小,输出维度为2(二元分类)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    # 前向传播函数,接受序列输出和池化输出输入,返回预测分数和序列关系分数的元组
    def forward(self, sequence_output: torch.Tensor, pooled_output: torch.Tensor) -> Tuple[torch.Tensor]:
        # 调用预测头进行预测
        prediction_scores = self.predictions(sequence_output)
        # 计算序列关系分数
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


# MobileBertPreTrainedModel 类定义,继承自 PreTrainedModel
class MobileBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 指定配置类为 MobileBertConfig
    config_class = MobileBertConfig
    # 预训练模型归档映射
    pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
    # 加载 TensorFlow 权重的方法
    load_tf_weights = load_tf_weights_in_mobilebert
    # 基础模型前缀名
    base_model_prefix = "mobilebert"
    def _init_weights(self, module):
        """Initialize the weights"""
        # 如果 module 是 nn.Linear 类型
        if isinstance(module, nn.Linear):
            # 使用正态分布初始化权重,均值为 0.0,标准差为模型配置中的 initializer_range
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果存在偏置项,将其初始化为零
            if module.bias is not None:
                module.bias.data.zero_()
        # 如果 module 是 nn.Embedding 类型
        elif isinstance(module, nn.Embedding):
            # 使用正态分布初始化权重,均值为 0.0,标准差为模型配置中的 initializer_range
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果指定了 padding_idx,则将对应位置的权重初始化为零
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # 如果 module 是 nn.LayerNorm 或者 NoNorm 类型
        elif isinstance(module, (nn.LayerNorm, NoNorm)):
            # 将偏置项初始化为零
            module.bias.data.zero_()
            # 将权重初始化为全 1.0
            module.weight.data.fill_(1.0)
@dataclass
class MobileBertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`MobileBertForPreTraining`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    seq_relationship_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


MOBILEBERT_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MobileBertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

MOBILEBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            # 输入序列的标记索引,在词汇表中找到对应的标记
            Indices of input sequence tokens in the vocabulary.

            # 可以使用 `AutoTokenizer` 获取这些索引。详见 `PreTrainedTokenizer.encode` 和 `PreTrainedTokenizer.__call__`。
            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            # 用来避免在填充的标记索引上执行注意力机制,值为 `[0, 1]`:

            - 1 表示 **未被屏蔽** 的标记,
            - 0 表示 **被屏蔽** 的标记。
            
            # 注意屏蔽令牌的作用是什么?
            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            # 指示输入的第一部分和第二部分的分段标记索引。索引选在 `[0, 1]`:

            - 0 对应 *句子 A* 的标记,
            - 1 对应 *句子 B* 的标记。
            
            # 什么是分段标记 ID?
            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            # 每个输入序列标记在位置嵌入中的位置索引。选在范围 `[0, config.max_position_embeddings - 1]`。

            # 什么是位置 ID?
            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            # 用于屏蔽自注意力模块中的特定头部的掩码。掩码值选在 `[0, 1]`:

            - 1 表示头部 **未被屏蔽**,
            - 0 表示头部 **被屏蔽**。
            
            # 控制自注意力头部屏蔽的作用是什么?
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            # 可选,可以直接传递嵌入表示,而不是传递 `input_ids`。如果需要更多对转换 `input_ids` 到相关向量的控制权,则很有用。
            # 这对于控制模型内部嵌入查找矩阵的转换方式很有用。
        output_attentions (`bool`, *optional*):
            # 是否返回所有注意力层的注意力张量。详见返回张量中的 `attentions` 以获取更多细节。
        output_hidden_states (`bool`, *optional*):
            # 是否返回所有层的隐藏状态。详见返回张量中的 `hidden_states` 以获取更多细节。
        return_dict (`bool`, *optional*):
            # 是否返回 [`~utils.ModelOutput`] 而不是普通元组。
"""
@add_start_docstrings(
    "The bare MobileBert Model transformer outputting raw hidden-states without any specific head on top.",
    MOBILEBERT_START_DOCSTRING,
)
class MobileBertModel(MobileBertPreTrainedModel):
    """
    MobileBertModel class implementing the MobileBERT architecture.

    https://arxiv.org/pdf/2004.02984.pdf
    """

    def __init__(self, config, add_pooling_layer=True):
        """
        Initializes a MobileBertModel instance.

        Args:
            config (MobileBertConfig): Configuration class for MobileBERT.
            add_pooling_layer (bool): Whether to add a pooling layer. Defaults to True.
        """
        super().__init__(config)
        self.config = config
        self.embeddings = MobileBertEmbeddings(config)  # Initialize MobileBertEmbeddings layer
        self.encoder = MobileBertEncoder(config)        # Initialize MobileBertEncoder layer

        self.pooler = MobileBertPooler(config) if add_pooling_layer else None  # Initialize MobileBertPooler if add_pooling_layer is True

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """
        Retrieves the input word embeddings from MobileBertEmbeddings.

        Returns:
            torch.nn.Embedding: The word embedding layer.
        """
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """
        Sets the input word embeddings in MobileBertEmbeddings.

        Args:
            value (torch.Tensor): New tensor for word embeddings.
        """
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model.

        Args:
            heads_to_prune (dict): Dictionary of {layer_num: list of heads to prune in this layer}.

        See base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(
        MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
    )
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ):
        """
        Forward pass for the MobileBertModel.

        Args:
            input_ids (Optional[torch.LongTensor]): Input ids of shape (batch_size, sequence_length).
            attention_mask (Optional[torch.FloatTensor]): Attention mask of shape (batch_size, sequence_length).
            token_type_ids (Optional[torch.LongTensor]): Token type ids of shape (batch_size, sequence_length).
            position_ids (Optional[torch.LongTensor]): Position ids of shape (batch_size, sequence_length).
            head_mask (Optional[torch.FloatTensor]): Mask to nullify selected heads of shape (num_heads,).
            inputs_embeds (Optional[torch.FloatTensor]): Embedded inputs of shape (batch_size, sequence_length, embedding_size).
            output_hidden_states (Optional[bool]): Whether to return hidden states.
            output_attentions (Optional[bool]): Whether to return attentions.
            return_dict (Optional[bool]): Whether to return a dictionary.

        Returns:
            BaseModelOutputWithPooling or tuple:
                A `BaseModelOutputWithPooling` when `return_dict=True`, otherwise a
                tuple of `torch.FloatTensor`.

        """
        # 如果未指定输出注意力,使用配置中的默认值
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 如果未指定输出隐藏状态,使用配置中的默认值
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果未指定返回字典,使用配置中的默认值
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 如果同时指定了 input_ids 和 inputs_embeds,则抛出 ValueError
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            # 如果指定了 input_ids,则检查是否存在填充并没有注意力掩码
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            # 获取 input_ids 的形状
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            # 如果指定了 inputs_embeds,则获取其形状的前几维
            input_shape = inputs_embeds.size()[:-1]
        else:
            # 如果既未指定 input_ids 也未指定 inputs_embeds,则抛出 ValueError
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # 确定设备是 input_ids 的设备还是 inputs_embeds 的设备
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # 如果 attention_mask 未指定,则创建全为 1 的注意力掩码张量
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        # 如果 token_type_ids 未指定,则创建全为 0 的张量作为 token 类型 IDs
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # 生成扩展的注意力掩码张量,以匹配多头注意力的维度需求
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # 准备头部掩码(如果需要)
        # 在头部掩码中,1.0 表示保留该头部
        # attention_probs 的形状为 bsz x n_heads x N x N
        # input head_mask 的形状为 [num_heads] 或 [num_hidden_layers x num_heads]
        # head_mask 被转换为形状 [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # 生成嵌入输出
        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        # 编码器的输出
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 获取序列输出
        sequence_output = encoder_outputs[0]
        # 如果定义了池化器,生成池化输出;否则为 None
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        # 如果不要求返回字典,则返回一个元组
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        # 如果要求返回字典,则返回一个 BaseModelOutputWithPooling 对象
        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
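

# === Usage sketch (standalone; not part of modeling_mobilebert.py) =================
# End-to-end pass through the bare MobileBertModel with a small, randomly initialised
# configuration (no pretrained weights are downloaded). All sizes below are
# illustrative assumptions, not the values of the public checkpoint.
import torch
from transformers import MobileBertConfig, MobileBertModel

config = MobileBertConfig(vocab_size=1000, num_hidden_layers=2, max_position_embeddings=64)
model = MobileBertModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)

print(outputs.last_hidden_state.shape)  # torch.Size([2, 16, 512])
print(outputs.pooler_output.shape)      # torch.Size([2, 512])
# ====================================================================================
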
@add_start_docstrings(
    """
    MobileBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `next sentence prediction (classification)` head.
    """,
    MOBILEBERT_START_DOCSTRING,
)
class MobileBertForPreTraining(MobileBertPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)
        # 初始化 MobileBert 模型
        self.mobilebert = MobileBertModel(config)
        # 初始化 MobileBert 的预训练头部
        self.cls = MobileBertPreTrainingHeads(config)

        # 初始化权重并应用最终处理
        self.post_init()

    def get_output_embeddings(self):
        # 返回预测头部的解码器,用于输出嵌入
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        # 设置新的输出嵌入到预测头部的解码器中
        self.cls.predictions.decoder = new_embeddings

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
        # 调整标记嵌入的大小,首先调整密集输出嵌入
        self.cls.predictions.dense = self._get_resized_lm_head(
            self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
        )

        # 调用父类的方法来调整标记嵌入的大小
        return super().resize_token_embeddings(new_num_tokens=new_num_tokens)

    @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        next_sentence_label: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
        # 调整模型中的token嵌入大小,首先调整密集输出层的嵌入
        self.cls.predictions.dense = self._get_resized_lm_head(
            self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
        )
        # 调用父类方法以完成token嵌入的调整,并返回调整后的嵌入层
        return super().resize_token_embeddings(new_num_tokens=new_num_tokens)

    @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="'paris'",
        expected_loss=0.57,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        # 如果return_dict为None,则使用配置文件中的默认设置来确定是否返回字典形式的输出
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用MobileBERT模型,传入各种输入参数,并获取输出结果
        outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从模型的输出中提取序列输出
        sequence_output = outputs[0]
        # 使用预测头部对序列输出进行预测得分计算
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        # 如果提供了标签,则计算掩码语言建模损失
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # 使用交叉熵损失函数,-100索引表示填充标记
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        # 如果不返回字典形式的输出,则组装输出结果并返回
        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        # 返回MaskedLMOutput对象,包括损失、预测logits、隐藏状态和注意力分布
        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
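

# === Usage sketch (standalone; not part of modeling_mobilebert.py) =================
# The forward pass above computes the masked-LM loss, i.e. it belongs to the library's
# MobileBertForMaskedLM head model. A hedged mask-filling sketch with the public
# google/mobilebert-uncased checkpoint (this downloads weights; the sentence mirrors
# the expected_output="'paris'" noted in the decorator above):
import torch
from transformers import AutoTokenizer, MobileBertForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertForMaskedLM.from_pretrained("google/mobilebert-uncased").eval()

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

mask_positions = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
predicted_ids = logits[0, mask_positions].argmax(dim=-1)
print(tokenizer.decode(predicted_ids))  # expected: "paris"
# ====================================================================================
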
# 创建一个新的类 MobileBertOnlyNSPHead,继承自 nn.Module
class MobileBertOnlyNSPHead(nn.Module):
    # 初始化方法,接受一个配置参数 config
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 定义一个线性层,用于预测下一个句子的关系,输入大小为 config.hidden_size,输出大小为 2
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    # 前向传播方法,接受一个 Tensor 参数 pooled_output,返回预测的下一个句子关系的分数 Tensor
    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        # 计算下一个句子关系的分数,使用 seq_relationship 线性层
        seq_relationship_score = self.seq_relationship(pooled_output)
        # 返回计算得到的分数 Tensor
        return seq_relationship_score


# 添加文档字符串和注解到 MobileBertForNextSentencePrediction 类
@add_start_docstrings(
    """MobileBert Model with a `next sentence prediction (classification)` head on top.""",
    MOBILEBERT_START_DOCSTRING,
)
# 定义 MobileBertForNextSentencePrediction 类,继承自 MobileBertPreTrainedModel
class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
    # 初始化方法,接受一个配置参数 config
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)

        # 创建一个 MobileBertModel 对象,传入配置参数 config
        self.mobilebert = MobileBertModel(config)
        # 创建一个 MobileBertOnlyNSPHead 对象,传入配置参数 config
        self.cls = MobileBertOnlyNSPHead(config)

        # 调用额外的初始化方法来初始化权重并应用最终处理
        self.post_init()

    # 添加文档字符串和注解到 forward 方法,描述输入参数和返回输出
    @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    # 替换返回文档字符串中的输出类型为 NextSentencePredictorOutput,使用指定的配置类 _CONFIG_FOR_DOC
    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    # 前向传播方法,接受多个可选的 Tensor 输入参数和其他关键字参数
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
        ) -> Union[Tuple, NextSentencePredictorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring) Indices should be in `[0, 1]`.

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Returns:
            Depending on `return_dict`:
            - If `return_dict` is `False`, returns a tuple with `seq_relationship_score` and additional outputs.
            - If `return_dict` is `True`, returns a `NextSentencePredictorOutput` object.

        Examples:
        Example usage of the `MobileBertForNextSentencePrediction` model.

        ```
        >>> from transformers import AutoTokenizer, MobileBertForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
        >>> model = MobileBertForNextSentencePrediction.from_pretrained("google/mobilebert-uncased")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```"""

        if "next_sentence_label" in kwargs:
            # Issue a warning that the argument `next_sentence_label` is deprecated and suggest using `labels` instead
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            # Replace `next_sentence_label` with `labels` if found in kwargs
            labels = kwargs.pop("next_sentence_label")

        # Determine whether to return a dictionary based on the provided argument or the default setting
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Pass inputs through the MobileBERT model for next sentence prediction
        outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Extract the pooled output from MobileBERT's outputs
        pooled_output = outputs[1]
        # Compute scores for next sentence prediction using a classification layer
        seq_relationship_score = self.cls(pooled_output)

        next_sentence_loss = None
        # Compute loss if labels are provided
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), labels.view(-1))

        # Decide the output format based on `return_dict`
        if not return_dict:
            # Return a tuple with `seq_relationship_score` and optionally other outputs
            output = (seq_relationship_score,) + outputs[2:]
            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

        # Return a `NextSentencePredictorOutput` object containing loss, logits, hidden states, and attentions
        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# MobileBert 模型变换器,顶部带有序列分类/回归头部(线性层在池化输出之上),例如用于 GLUE 任务。
# 此类从 transformers.models.bert.modeling_bert.BertForSequenceClassification 复制,将 Bert 替换为 MobileBert 并全小写处理。
class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels  # 设置标签数量
        self.config = config

        self.mobilebert = MobileBertModel(config)  # 初始化 MobileBert 模型
        # 根据配置初始化分类器的丢弃率,若未指定,则使用隐藏层丢弃率
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)  # 定义丢弃层
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)  # 定义分类器线性层

        # 初始化权重并应用最终处理
        self.post_init()

    # 添加输入文档字符串,描述模型前向传播的输入参数
    # 从 MOBILEBERT_INPUTS_DOCSTRING 格式化得到输入参数的说明
    # 添加代码示例文档字符串,包含序列分类检查点、输出类型、配置类和预期输出/损失
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # 设置返回字典,如果未指定则使用模型配置中的默认值
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用MobileBERT模型进行前向传播
        outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从模型输出中获取汇聚输出(pooled output)
        pooled_output = outputs[1]

        # 对汇聚输出进行dropout处理
        pooled_output = self.dropout(pooled_output)
        # 使用分类器进行分类任务的预测
        logits = self.classifier(pooled_output)

        # 初始化损失为None
        loss = None
        # 如果提供了标签,计算损失
        if labels is not None:
            # 如果问题类型未指定,则根据标签类型和类别数确定问题类型
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # 根据问题类型选择相应的损失函数
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        # 如果不要求返回字典,则组织输出结果为元组
        if not return_dict:
            output = (logits,) + outputs[2:]  # 包括额外的hidden_states
            return ((loss,) + output) if loss is not None else output

        # 返回包含损失和其他输出的SequenceClassifierOutput对象
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
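

# === Usage sketch (standalone; not part of modeling_mobilebert.py) =================
# Illustration of the problem_type branching above with a tiny randomly initialised
# configuration (sizes are illustrative assumptions). Integer labels together with
# num_labels > 1 select the single_label_classification (cross-entropy) branch.
import torch
from transformers import MobileBertConfig, MobileBertForSequenceClassification

config = MobileBertConfig(vocab_size=1000, num_hidden_layers=2, num_labels=3)
model = MobileBertForSequenceClassification(config)

input_ids = torch.randint(0, config.vocab_size, (4, 12))
labels = torch.tensor([0, 2, 1, 1])

outputs = model(input_ids=input_ids, labels=labels)
print(outputs.logits.shape)           # torch.Size([4, 3])
print(model.config.problem_type)      # "single_label_classification"
print(outputs.loss)                   # scalar cross-entropy loss
# ====================================================================================
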
@add_start_docstrings(
    """
    MobileBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    MOBILEBERT_START_DOCSTRING,
)
# Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering with Bert->MobileBert (all casings)
class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # 设置分类任务的标签数目
        self.num_labels = config.num_labels

        # 创建MobileBert模型,不包含池化层
        self.mobilebert = MobileBertModel(config, add_pooling_layer=False)
        # 创建线性层,用于输出分类标签
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并应用最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_QA,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        qa_target_start_index=_QA_TARGET_START_INDEX,
        qa_target_end_index=_QA_TARGET_END_INDEX,
        expected_output=_QA_EXPECTED_OUTPUT,
        expected_loss=_QA_EXPECTED_LOSS,
    )
    # 定义前向传播函数,处理输入并返回模型输出
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        # 如果 return_dict 参数为 None,则使用 self.config.use_return_dict 决定是否返回字典形式的输出
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 MobileBERT 模型进行前向传播
        outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取模型输出的序列输出
        sequence_output = outputs[0]

        # 将序列输出传入 QA 输出层,得到起始和结束位置的 logits
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()  # 去除多余的维度并保持连续性
        end_logits = end_logits.squeeze(-1).contiguous()  # 去除多余的维度并保持连续性

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # On multi-GPU, start_positions/end_positions may carry an extra dimension; squeeze it away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # 将超出模型输入的起始/结束位置限制在有效范围内
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # 定义交叉熵损失函数,忽略 ignored_index 处的预测
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        # 如果不需要返回字典形式的输出,按元组形式返回结果
        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]  # 包括额外的 hidden_states 和 attentions
            return ((total_loss,) + output) if total_loss is not None else output

        # 返回 QuestionAnsweringModelOutput 类型的对象,包含损失、起始和结束 logits、隐藏状态和注意力权重
        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
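

# === Usage sketch (standalone; not part of modeling_mobilebert.py) =================
# How the start/end logits above are turned into an answer span. A randomly
# initialised model is used here, so the predicted span is meaningless; the point is
# the decoding mechanics. The tokenizer checkpoint and the question/context strings
# are illustrative assumptions.
import torch
from transformers import AutoTokenizer, MobileBertConfig, MobileBertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertForQuestionAnswering(MobileBertConfig(num_hidden_layers=2)).eval()

question = "Who proposed MobileBERT?"
context = "MobileBERT was proposed by researchers at Google and CMU."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

start = outputs.start_logits.argmax(dim=-1).item()
end = max(outputs.end_logits.argmax(dim=-1).item(), start)   # guard against end < start
print(tokenizer.decode(inputs.input_ids[0, start : end + 1]))
# ====================================================================================
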
# 定义一个带有多选分类头的 MobileBert 模型,用于例如 RocStories/SWAG 任务。
# 这个类继承自 MobileBertPreTrainedModel。
class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # 初始化 MobileBert 模型
        self.mobilebert = MobileBertModel(config)
        
        # 确定分类器的 dropout 比率,如果未指定则使用隐藏层 dropout 比率
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        # 定义一个 dropout 层,用于分类器
        self.dropout = nn.Dropout(classifier_dropout)
        
        # 定义一个线性层作为分类器,输入大小为隐藏层大小,输出大小为1
        self.classifier = nn.Linear(config.hidden_size, 1)

        # 初始化权重并进行最终处理
        self.post_init()

    # 添加文档字符串到 forward 方法,描述输入参数
    @add_start_docstrings_to_model_forward(
        MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    # 添加代码示例文档字符串到 forward 方法,描述模型的输出类型和相关配置
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # 初始化返回字典,如果未提供则使用配置中的默认设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 计算选择项的数量,根据 input_ids 的第二维度确定
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # 重新塑形输入张量,以便适应模型要求
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        # 将输入传递给 MobileBERT 模型进行处理
        outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 提取池化后的输出
        pooled_output = outputs[1]

        # 应用 dropout 正则化
        pooled_output = self.dropout(pooled_output)
        # 将池化后的输出传递给分类器得到 logits
        logits = self.classifier(pooled_output)
        # 重新塑形 logits,以匹配选择项的形状
        reshaped_logits = logits.view(-1, num_choices)

        # 计算损失
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        # 根据 return_dict 决定输出结果的格式
        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 构造 MultipleChoiceModelOutput 对象作为返回结果
        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
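

# === Usage sketch (standalone; not part of modeling_mobilebert.py) =================
# The multiple-choice head expects inputs of shape (batch, num_choices, seq_len); the
# reshaping above flattens them to (batch * num_choices, seq_len) and the logits come
# back as one score per choice. Tiny random-weight config; sizes are illustrative.
import torch
from transformers import MobileBertConfig, MobileBertForMultipleChoice

config = MobileBertConfig(vocab_size=1000, num_hidden_layers=2)
model = MobileBertForMultipleChoice(config)

batch_size, num_choices, seq_len = 2, 4, 10
input_ids = torch.randint(0, config.vocab_size, (batch_size, num_choices, seq_len))
labels = torch.tensor([1, 3])                      # index of the correct choice per example

outputs = model(input_ids=input_ids, labels=labels)
print(outputs.logits.shape)                        # torch.Size([2, 4]): one score per choice
print(outputs.loss)
# ====================================================================================
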
# 添加文档字符串描述 MobileBert 模型与顶部的标记分类头部(线性层),例如用于命名实体识别(NER)任务
@add_start_docstrings(
    """
    MobileBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    MOBILEBERT_START_DOCSTRING,
)
# 从 transformers.models.bert.modeling_bert.BertForTokenClassification 复制并修改为 MobileBert,保持所有大小写一致
class MobileBertForTokenClassification(MobileBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # 使用 MobileBertModel 初始化
        self.mobilebert = MobileBertModel(config, add_pooling_layer=False)
        # 如果配置中指定了 classifier_dropout,则使用该值;否则使用 hidden_dropout_prob
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        # 定义一个 Dropout 层,用于分类器
        self.dropout = nn.Dropout(classifier_dropout)
        # 定义一个线性层作为分类器,输入大小为 hidden_size,输出大小为 num_labels
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并进行最终处理
        self.post_init()

    # 添加文档字符串描述模型的 forward 方法
    @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    # 添加示例代码文档字符串,展示输入输出的示例
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
        expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # Input arguments: input_ids (token IDs), attention_mask, token_type_ids, position_ids, etc.
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        # 根据输入的 return_dict 参数确定是否返回字典形式的输出
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 MobileBERT 模型进行前向传播
        outputs = self.mobilebert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取模型输出的序列特征向量
        sequence_output = outputs[0]

        # 对序列特征向量应用 dropout 操作
        sequence_output = self.dropout(sequence_output)
        
        # 对 dropout 后的特征向量进行分类预测
        logits = self.classifier(sequence_output)

        # 初始化损失为 None
        loss = None
        # 如果提供了标签,则计算交叉熵损失
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # 如果 return_dict 为 False,则返回 logits 和额外的 hidden_states
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 如果 return_dict 为 True,则返回 TokenClassifierOutput 对象
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
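

# === Usage sketch (standalone; not part of modeling_mobilebert.py) =================
# Per-token classification: logits have shape (batch, seq_len, num_labels), and label
# positions set to -100 are ignored by the default CrossEntropyLoss (its default
# ignore_index). Tiny random-weight config; all sizes are illustrative assumptions.
import torch
from transformers import MobileBertConfig, MobileBertForTokenClassification

config = MobileBertConfig(vocab_size=1000, num_hidden_layers=2, num_labels=5)
model = MobileBertForTokenClassification(config)

input_ids = torch.randint(0, config.vocab_size, (2, 8))
labels = torch.randint(0, config.num_labels, (2, 8))
labels[:, -2:] = -100                              # e.g. exclude padding positions from the loss

outputs = model(input_ids=input_ids, labels=labels)
predictions = outputs.logits.argmax(dim=-1)        # (2, 8): one label id per token
print(outputs.loss, predictions.shape)
# ====================================================================================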