Transformers Source Code Analysis (126)

.\models\xlm\modeling_xlm.py

# coding=utf-8
# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
 PyTorch XLM model.
"""

import itertools
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import gelu
from ...modeling_outputs import (
    BaseModelOutput,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_xlm import XLMConfig

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "FacebookAI/xlm-mlm-en-2048"
_CONFIG_FOR_DOC = "XLMConfig"

XLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "FacebookAI/xlm-mlm-en-2048",
    "FacebookAI/xlm-mlm-ende-1024",
    "FacebookAI/xlm-mlm-enfr-1024",
    "FacebookAI/xlm-mlm-enro-1024",
    "FacebookAI/xlm-mlm-tlm-xnli15-1024",
    "FacebookAI/xlm-mlm-xnli15-1024",
    "FacebookAI/xlm-clm-enfr-1024",
    "FacebookAI/xlm-clm-ende-1024",
    "FacebookAI/xlm-mlm-17-1280",
    "FacebookAI/xlm-mlm-100-1280",
    # See all XLM models at https://huggingface.co/models?filter=xlm
]


def create_sinusoidal_embeddings(n_pos, dim, out):
    """
    Create sinusoidal positional embeddings.
    
    Args:
    - n_pos (int): Number of positions.
    - dim (int): Dimension of embeddings.
    - out (Tensor): Output tensor to store the embeddings.
    
    This function computes sinusoidal embeddings based on position and dimension,
    storing them in the provided output tensor.
    """
    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()
    out.requires_grad = False


def get_masks(slen, lengths, causal, padding_mask=None):
    """
    Generate masks for hidden states and optionally an attention mask.
    
    Args:
    - slen (int): Sequence length.
    - lengths (Tensor): Lengths of each sequence in a batch.
    - causal (bool): If True, generate a causal (triangular) attention mask.
    - padding_mask (Tensor, optional): Mask indicating padded elements.
    
    Returns:
    - Tuple[Tensor, Tensor]: The padding mask of shape `(bs, slen)` and the attention mask,
      which is `(bs, slen, slen)` when `causal` is True and identical to the padding mask otherwise.

    This function generates a mask to hide elements beyond the actual length
    of each sequence, and optionally a causal attention mask.
    """
    alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
    if padding_mask is not None:
        mask = padding_mask
    else:
        assert lengths.max().item() <= slen
        mask = alen < lengths[:, None]

    # attention mask is the same as mask, or triangular inferior attention (causal)
    bs = lengths.size(0)
    # If `causal` is True, build an attention mask of shape (bs, slen, slen) by repeating
    # `alen` and checking, for each query position, which key positions are at or before it.
    # If `causal` is False, reuse the padding mask directly as the attention mask.

    if causal:
        attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
    else:
        attn_mask = mask

    # Sanity check: the padding mask must have shape (bs, slen) ...
    assert mask.size() == (bs, slen)
    # ... and, in the causal case, the attention mask must have shape (bs, slen, slen)
    assert causal is False or attn_mask.size() == (bs, slen, slen)

    # Return both the padding mask and the attention mask
    return mask, attn_mask
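
For example, for a padded batch of two sequences with true lengths 3 and 5:

```python
import torch

lengths = torch.tensor([3, 5])  # true lengths of the two sequences
slen = 5                        # padded sequence length

mask, attn_mask = get_masks(slen, lengths, causal=False)
print(mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])

# With causal=True the attention mask gains a (bs, slen, slen) triangular structure:
_, causal_mask = get_masks(slen, lengths, causal=True)
print(causal_mask.shape)  # torch.Size([2, 5, 5])
```
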
# Multi-head attention module
class MultiHeadAttention(nn.Module):
    # Class-level counter used to assign a unique ID to every layer instance
    NEW_ID = itertools.count()

    def __init__(self, n_heads, dim, config):
        super().__init__()
        # Assign a fresh layer ID to this instance
        self.layer_id = next(MultiHeadAttention.NEW_ID)
        self.dim = dim  # model dimension of the attention block
        self.n_heads = n_heads  # number of attention heads
        self.dropout = config.attention_dropout  # dropout probability on the attention weights
        assert self.dim % self.n_heads == 0  # the dimension must be divisible by the head count

        # Linear layers for the queries (Q), keys (K), values (V) and the output projection
        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        self.out_lin = nn.Linear(dim, dim)

        # Indices of attention heads that have been pruned
        self.pruned_heads = set()

    # Prune the given attention heads
    def prune_heads(self, heads):
        attention_head_size = self.dim // self.n_heads  # size of each attention head
        if len(heads) == 0:
            return

        # Find the prunable heads and the indices of the weights to keep
        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)

        # Prune the linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)

        # Update the hyperparameters: head count and total attention dimension
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads

        # Record the pruned heads
        self.pruned_heads = self.pruned_heads.union(heads)
    def forward(self, input, mask, kv=None, cache=None, head_mask=None, output_attentions=False):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).
        """
        # Input is (bs, qlen, dim)
        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
        bs, qlen, dim = input.size()  # batch size, query length, feature dimension
        if kv is None:
            # self-attention: keys span the current sequence plus any cached prefix
            klen = qlen if cache is None else cache["slen"] + qlen
        else:
            # cross-attention: keys span the source sequence
            klen = kv.size(1)
        # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
        n_heads = self.n_heads
        dim_per_head = self.dim // n_heads  # feature dimension of each head
        # reshape target for the mask: (bs, 1, qlen, klen) for 3D masks, (bs, 1, 1, klen) otherwise
        mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)

        def shape(x):
            """projection: (bs, len, dim) -> (bs, n_heads, len, dim_per_head)"""
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """compute context: (bs, n_heads, len, dim_per_head) -> (bs, len, dim)"""
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(input))  # queries, (bs, n_heads, qlen, dim_per_head)
        if kv is None:
            k = shape(self.k_lin(input))  # keys, (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v_lin(input))  # values, (bs, n_heads, qlen, dim_per_head)
        elif cache is None or self.layer_id not in cache:
            k = v = kv
            k = shape(self.k_lin(k))  # keys from the source sentence, (bs, n_heads, klen, dim_per_head)
            v = shape(self.v_lin(v))  # values from the source sentence, (bs, n_heads, klen, dim_per_head)

        if cache is not None:
            if self.layer_id in cache:
                if kv is None:
                    # self-attention: concatenate the cached keys/values with the new ones
                    k_, v_ = cache[self.layer_id]
                    k = torch.cat([k_, k], dim=2)  # (bs, n_heads, klen, dim_per_head)
                    v = torch.cat([v_, v], dim=2)  # (bs, n_heads, klen, dim_per_head)
                else:
                    # cross-attention: the source projections never change, reuse them
                    k, v = cache[self.layer_id]

            cache[self.layer_id] = (k, v)  # update the cache for this layer

        q = q / math.sqrt(dim_per_head)  # scale the queries for numerical stability
        scores = torch.matmul(q, k.transpose(2, 3))  # attention scores, (bs, n_heads, qlen, klen)
        # expand the mask to the score shape and flip it: True marks positions to suppress
        mask = (mask == 0).view(mask_reshape).expand_as(scores)
        scores.masked_fill_(mask, torch.finfo(scores.dtype).min)  # suppress masked positions

        # softmax over the key dimension, computed in float32 then cast back
        weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
        weights = nn.functional.dropout(weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # weighted sum of the values, (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)  # back to (bs, qlen, dim)

        outputs = (self.out_lin(context),)  # output projection, (bs, qlen, dim)
        if output_attentions:
            outputs = outputs + (weights,)  # optionally expose the attention weights

        return outputs
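
A minimal smoke test of the module above; the config is a stand-in (only `attention_dropout` is read by `__init__`):

```python
import torch
from types import SimpleNamespace

config = SimpleNamespace(attention_dropout=0.1)  # stand-in config
attn = MultiHeadAttention(n_heads=4, dim=32, config=config)

x = torch.randn(2, 7, 32)                  # (bs, qlen, dim)
mask = torch.ones(2, 7, dtype=torch.long)  # non-causal padding mask, (bs, klen)

(out,) = attn(x, mask)
print(out.shape)  # torch.Size([2, 7, 32])

out, weights = attn(x, mask, output_attentions=True)
print(weights.shape)  # torch.Size([2, 4, 7, 7]) -> (bs, n_heads, qlen, klen)
```
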
class TransformerFFN(nn.Module):
    def __init__(self, in_dim, dim_hidden, out_dim, config):
        super().__init__()
        self.dropout = config.dropout  # dropout rate from the config
        self.lin1 = nn.Linear(in_dim, dim_hidden)  # first linear layer: in_dim -> dim_hidden
        self.lin2 = nn.Linear(dim_hidden, out_dim)  # second linear layer: dim_hidden -> out_dim
        self.act = gelu if config.gelu_activation else nn.functional.relu  # GELU or ReLU, per the config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward  # chunk size for the feed-forward pass
        self.seq_len_dim = 1  # the sequence-length dimension is dim 1

    def forward(self, input):
        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)

    def ff_chunk(self, input):
        x = self.lin1(input)  # first linear layer
        x = self.act(x)  # activation
        x = self.lin2(x)  # second linear layer
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)  # dropout
        return x
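
`apply_chunking_to_forward` runs `ff_chunk` on slices of the sequence dimension when `chunk_size_feed_forward > 0` (the chunk size must divide the sequence length), lowering peak activation memory; with a chunk size of 0 it simply calls `ff_chunk` on the whole input. A minimal sketch with a stand-in config:

```python
import torch
from types import SimpleNamespace

# stand-in config: only these three fields are read by TransformerFFN
config = SimpleNamespace(dropout=0.1, gelu_activation=True, chunk_size_feed_forward=0)
ffn = TransformerFFN(in_dim=32, dim_hidden=128, out_dim=32, config=config)

x = torch.randn(2, 8, 32)
print(ffn(x).shape)  # torch.Size([2, 8, 32])

# chunk_size_feed_forward=4 would process the 8 positions in two slices of 4,
# producing the same output with a smaller intermediate (dim_hidden) activation.
```
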


class XLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = XLMConfig  # configuration class
    load_tf_weights = None  # no TensorFlow weight loading
    base_model_prefix = "transformer"  # prefix of the base model

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    @property
    def dummy_inputs(self):
        inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])  # dummy input IDs
        attns_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])  # dummy attention mask
        if self.config.use_lang_emb and self.config.n_langs > 1:
            langs_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])  # dummy language IDs
        else:
            langs_list = None
        return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Embedding):  # embedding layers
            if self.config is not None and self.config.embed_init_std is not None:
                nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)  # normal initialization
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()  # zero the padding embedding, if any
        if isinstance(module, nn.Linear):  # linear layers
            if self.config is not None and self.config.init_std is not None:
                nn.init.normal_(module.weight, mean=0, std=self.config.init_std)  # normal initialization
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)  # zero the bias
        if isinstance(module, nn.LayerNorm):  # LayerNorm layers
            module.bias.data.zero_()  # zero the bias
            module.weight.data.fill_(1.0)  # fill the weight with 1.0


@dataclass
class XLMForQuestionAnsweringOutput(ModelOutput):
    """
    Base class for outputs of question answering models using a `SquadHead`.
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
            分类损失,作为起始标记和结束标记分类损失的总和(如果提供了 `start_positions` 和 `end_positions`)。
        start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            开始标记可能性的对数概率,对应于前 `config.start_n_top` 个可能性(使用 Beam Search)。
        start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            开始标记可能性的索引,对应于前 `config.start_n_top` 个可能性(使用 Beam Search)。
        end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            结束标记可能性的对数概率,对应于前 `config.start_n_top * config.end_n_top` 个可能性(使用 Beam Search)。
        end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            结束标记可能性的索引,对应于前 `config.start_n_top * config.end_n_top` 个可能性(使用 Beam Search)。
        cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            答案是否不可能的标签的对数概率。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            模型每层的隐藏状态,包括初始嵌入输出,形状为 `(batch_size, sequence_length, hidden_size)`。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            自注意力机制注意力权重,用于计算自注意力头中的加权平均值,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)

This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.

Parameters:
    config ([`XLMConfig`]): Model configuration class with all the parameters of the model.
        Initializing with a config file does not load the weights associated with the model, only the
        configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

XLM_INPUTS_DOCSTRING = r"""
"""

@add_start_docstrings(
    "The bare XLM Model transformer outputting raw hidden-states without any specific head on top.",
    XLM_START_DOCSTRING,
)
class XLMModel(XLMPreTrainedModel):
    """
    XLM Model class inheriting from XLMPreTrainedModel.
    """

    def get_input_embeddings(self):
        """
        Returns the input embeddings of the model.
        """
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        """
        Set the input embeddings of the model to new_embeddings.
        """
        self.embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model.
        
        Args:
            heads_to_prune (dict): Dictionary of {layer_num: list of heads to prune in this layer}.
                See base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.attentions[layer].prune_heads(heads)

    @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        langs: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        cache: Optional[Dict[str, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass of the XLM model.

        Args:
            input_ids (torch.Tensor, optional): Indices of input sequence tokens in the vocabulary.
            attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding token indices.
            langs (torch.Tensor, optional): Language IDs indicating the language of each token, for multilingual models.
            token_type_ids (torch.Tensor, optional): Segment token indices to indicate first and second portions of the inputs.
            position_ids (torch.Tensor, optional): Indices of positions of each input sequence tokens in the position embeddings.
            lengths (torch.Tensor, optional): Lengths of each sequence to avoid masking beyond the sequence length.
            cache (Dict[str, torch.Tensor], optional): Dictionary with precomputed hidden-states.
            head_mask (torch.Tensor, optional): Mask to nullify selected heads of the self-attention modules.
            inputs_embeds (torch.Tensor, optional): External embeddings for the input tokens.
            output_attentions (bool, optional): Whether to output the attentions weights.
            output_hidden_states (bool, optional): Whether to output the hidden states.
            return_dict (bool, optional): Whether to return a dictionary instead of a tuple of outputs.

        Returns:
            BaseModelOutput: Model output that contains various elements depending on the configuration.
        """
        # Implementation of the forward pass is omitted here as it's a part of the model's internal details.

class XLMPredLayer(nn.Module):
    """
    Prediction layer (cross_entropy or adaptive_softmax).
    """

    def __init__(self, config):
        super().__init__()
        self.asm = config.asm  # whether to use an adaptive softmax
        self.n_words = config.n_words  # vocabulary size
        self.pad_index = config.pad_index  # padding index
        dim = config.emb_dim  # embedding dimension

        if config.asm is False:
            # without adaptive softmax: a plain linear projection onto the vocabulary
            self.proj = nn.Linear(dim, config.n_words, bias=True)
        else:
            # with adaptive softmax: an adaptive log-softmax-with-loss layer
            self.proj = nn.AdaptiveLogSoftmaxWithLoss(
                in_features=dim,
                n_classes=config.n_words,
                cutoffs=config.asm_cutoffs,
                div_value=config.asm_div_value,
                head_bias=True,  # default is False
            )

    def forward(self, x, y=None):
        """Compute the loss, and optionally the scores."""
        outputs = ()  # accumulate the outputs in a tuple

        if self.asm is False:
            # plain projection: compute the scores
            scores = self.proj(x)
            outputs = (scores,) + outputs
            if y is not None:
                # with labels, compute the cross-entropy loss
                loss = nn.functional.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction="mean")
                outputs = (loss,) + outputs
        else:
            # adaptive softmax: compute log-probabilities as the scores
            scores = self.proj.log_prob(x)
            outputs = (scores,) + outputs
            if y is not None:
                # with labels, the adaptive softmax forward returns the loss directly
                _, loss = self.proj(x, y)
                outputs = (loss,) + outputs

        return outputs
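
A sketch of the two call patterns of the plain (non-adaptive) branch, with a stand-in config:

```python
import torch
from types import SimpleNamespace

config = SimpleNamespace(asm=False, n_words=100, pad_index=2, emb_dim=32)
pred = XLMPredLayer(config)

hidden = torch.randn(6, 32)          # e.g. masked positions, (n_tokens, emb_dim)
labels = torch.randint(0, 100, (6,))

(scores,) = pred(hidden)             # no labels -> (logits,)
loss, scores = pred(hidden, labels)  # with labels -> (loss, logits)
print(scores.shape, loss.shape)      # torch.Size([6, 100]) torch.Size([])
```
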
"""
The XLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
"""
# 继承自预训练模型基类 XLMPreTrainedModel 的 XLM Model,增加了语言建模头部
class XLMWithLMHeadModel(XLMPreTrainedModel):
    # 定义需要共享权重的层
    _tied_weights_keys = ["pred_layer.proj.weight"]

    def __init__(self, config):
        super().__init__(config)
        # the main Transformer encoder
        self.transformer = XLMModel(config)
        # the language modeling head
        self.pred_layer = XLMPredLayer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.pred_layer.proj

    def set_output_embeddings(self, new_embeddings):
        self.pred_layer.proj = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        mask_token_id = self.config.mask_token_id
        lang_id = self.config.lang_id

        effective_batch_size = input_ids.shape[0]
        # append one mask token per sequence; the model predicts the next token at this position
        mask_token = torch.full((effective_batch_size, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
        input_ids = torch.cat([input_ids, mask_token], dim=1)
        # if a language ID is configured, build a language tensor of the same shape; otherwise None
        if lang_id is not None:
            langs = torch.full_like(input_ids, lang_id)
        else:
            langs = None
        return {"input_ids": input_ids, "langs": langs}

    """
    Forward 方法的函数签名注释,描述了输入参数和输出的相关文档字符串。

    Parameters:
        input_ids (Optional[torch.Tensor]): 输入的 token IDs 张量,默认为 None。
        attention_mask (Optional[torch.Tensor]): 注意力掩码张量,默认为 None。
        langs (Optional[torch.Tensor]): 语言 ID 张量,默认为 None。
        token_type_ids (Optional[torch.Tensor]): token 类型 ID 张量,默认为 None。
        position_ids (Optional[torch.Tensor]): 位置 ID 张量,默认为 None。
        lengths (Optional[torch.Tensor]): 长度张量,默认为 None。
        cache (Optional[Dict[str, torch.Tensor]]): 缓存字典,默认为 None。
        head_mask (Optional[torch.Tensor]): 头部掩码张量,默认为 None。
        inputs_embeds (Optional[torch.Tensor]): 输入嵌入张量,默认为 None。
        labels (Optional[torch.Tensor]): 标签张量,默认为 None。
        output_attentions (Optional[bool]): 是否输出注意力,默认为 None。
        output_hidden_states (Optional[bool]): 是否输出隐藏状态,默认为 None。
        return_dict (Optional[bool]): 是否返回字典,默认为 None。
    """
    @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="<special1>",
    )
    # 模型前向传播方法,接受多个输入参数,并按照预期的格式进行文档化
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        langs: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        cache: Optional[Dict[str, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # 增加了多个文档字符串的装饰器,描述了该方法的使用情况和示例

        # 省略部分参数文档
        ...
        ):
        # 实际方法的具体实现在模型类的实际应用中完成,不在这里具体展示
        pass
        ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        # use the passed return_dict if given, otherwise fall back to the config default
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # run the Transformer over the inputs
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # the last hidden states
        output = transformer_outputs[0]

        # run the prediction head: returns (loss, logits) if labels are provided, else (logits,)
        outputs = self.pred_layer(output, labels)

        # tuple output: prepend the head outputs to the remaining transformer outputs
        if not return_dict:
            return outputs + transformer_outputs[1:]

        # dict output: wrap loss, logits, hidden states and attentions in a MaskedLMOutput
        return MaskedLMOutput(
            loss=outputs[0] if labels is not None else None,
            logits=outputs[0] if labels is None else outputs[1],
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
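
End-to-end usage of the LM head model as a masked LM (this downloads the pretrained checkpoint; recall that XLM's mask token is `<special1>`):

```python
import torch
from transformers import XLMTokenizer, XLMWithLMHeadModel

tokenizer = XLMTokenizer.from_pretrained("FacebookAI/xlm-mlm-en-2048")
model = XLMWithLMHeadModel.from_pretrained("FacebookAI/xlm-mlm-en-2048")

inputs = tokenizer("The capital of France is <special1>.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# pick the highest-scoring token at the masked position
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
print(tokenizer.decode(logits[0, mask_pos].argmax(dim=-1)))
```
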


# XLM model for sequence classification/regression tasks, e.g. GLUE
@add_start_docstrings(
    """
    XLM Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g.
    for GLUE tasks.
    """,
    XLM_START_DOCSTRING,
)
class XLMForSequenceClassification(XLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # number of labels
        self.num_labels = config.num_labels
        # keep a reference to the config
        self.config = config

        # the XLM encoder and the sequence-summary pooling head
        self.transformer = XLMModel(config)
        self.sequence_summary = SequenceSummary(config)

        # Initialize weights and apply final processing
        self.post_init()

    # The decorators attach the input documentation and a code sample to `forward`
    @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        langs: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        cache: Optional[Dict[str, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # use the passed return_dict if given, otherwise fall back to the config default
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # run the Transformer over the inputs
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # the last hidden states
        output = transformer_outputs[0]
        # pool the hidden states into per-sequence logits
        logits = self.sequence_summary(output)

        loss = None
        if labels is not None:
            # infer the problem type from the label count and dtype if it is not set
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # pick the loss function matching the problem type
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        # tuple output: prepend the loss if present
        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # dict output: wrap loss, logits, hidden states and attentions in a SequenceClassifierOutput
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
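
A small end-to-end check of the label handling, using a tiny randomly initialized config (sizes are illustrative):

```python
import torch
from transformers import XLMConfig, XLMForSequenceClassification

config = XLMConfig(vocab_size=64, emb_dim=32, n_layers=1, n_heads=2, num_labels=3)
model = XLMForSequenceClassification(config).eval()

input_ids = torch.randint(3, 64, (2, 6))  # avoid the pad index (2) so lengths stay full
labels = torch.tensor([0, 2])             # integer labels -> "single_label_classification"
out = model(input_ids, labels=labels)
print(out.loss, out.logits.shape)         # scalar cross-entropy loss, torch.Size([2, 3])
```
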

@add_start_docstrings(
    """
    XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    XLM_START_DOCSTRING,
)
class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.transformer = XLMModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        langs: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        cache: Optional[Dict[str, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:


The decorator documents the model as a whole: an XLM transformer with a span classification head for extractive question answering such as SQuAD. The constructor builds an `XLMModel` backbone (`transformer`) and a linear layer (`qa_outputs`) that predicts the start and end logits of the answer span, then calls `post_init()` to initialize the weights. The two `forward` decorators attach the shared input documentation and a code sample.

The parameter list below is the `forward` signature of the beam-search question-answering variant built on `SQuADHead` (its class body is not shown in this excerpt). In addition to the arguments above, it accepts `is_impossible`, `cls_index` and `p_mask`:

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        # input token IDs
        attention_mask: Optional[torch.Tensor] = None,
        # mask indicating which positions should be attended to
        langs: Optional[torch.Tensor] = None,
        # language IDs of the input sequence
        token_type_ids: Optional[torch.Tensor] = None,
        # token type IDs distinguishing different segments
        position_ids: Optional[torch.Tensor] = None,
        # position of each token within the sequence
        lengths: Optional[torch.Tensor] = None,
        # actual lengths of the input sequences
        cache: Optional[Dict[str, torch.Tensor]] = None,
        # dictionary of precomputed hidden states, used to speed up sequential decoding
        head_mask: Optional[torch.Tensor] = None,
        # mask used to nullify selected attention heads
        inputs_embeds: Optional[torch.Tensor] = None,
        # precomputed input embeddings
        start_positions: Optional[torch.Tensor] = None,
        # gold start positions of the answer span
        end_positions: Optional[torch.Tensor] = None,
        # gold end positions of the answer span
        is_impossible: Optional[torch.Tensor] = None,
        # labels marking whether the question has no answer
        cls_index: Optional[torch.Tensor] = None,
        # position of the CLS token, used for the `is_impossible` classification
        p_mask: Optional[torch.Tensor] = None,
        # mask of tokens that cannot be part of the answer
        output_attentions: Optional[bool] = None,
        # whether to return the attention weights
        output_hidden_states: Optional[bool] = None,
        # whether to return the hidden states
        return_dict: Optional[bool] = None,
        # whether to return a ModelOutput instead of a tuple
    ):


# XLM model with a linear token-classification head on top, e.g. for named entity recognition
@add_start_docstrings(
    """
    XLM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    XLM_START_DOCSTRING,
)
class XLMForTokenClassification(XLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # number of labels
        self.num_labels = config.num_labels

        # the XLM encoder
        self.transformer = XLMModel(config)
        # dropout layer
        self.dropout = nn.Dropout(config.dropout)
        # token-classification head: hidden size -> number of labels
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        langs: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        cache: Optional[Dict[str, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        # use the passed return_dict if given, otherwise fall back to the config default
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # run the Transformer over the inputs
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # the per-token hidden states
        sequence_output = outputs[0]

        # apply dropout, then the classifier, to get per-token logits
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        # with labels, compute the cross-entropy loss over all tokens
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # tuple output: prepend the loss if present
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # dict output: wrap loss, logits, hidden states and attentions in a TokenClassifierOutput
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# XLM model with a multiple-choice classification head: a linear layer and softmax on top of the pooled output
@add_start_docstrings(
    """
    XLM Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    XLM_START_DOCSTRING,
)
class XLMForMultipleChoice(XLMPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # the XLM encoder
        self.transformer = XLMModel(config)

        # pooled sequence summary
        self.sequence_summary = SequenceSummary(config)

        # projection from the summary output to one logit per choice
        self.logits_proj = nn.Linear(config.num_labels, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        langs: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        cache: Optional[Dict[str, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # use the passed return_dict if given, otherwise fall back to the config default
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # the number of choices is the second dimension of input_ids (or of inputs_embeds)
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # flatten the choice dimension: (bs, num_choices, slen) -> (bs * num_choices, slen)
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        langs = langs.view(-1, langs.size(-1)) if langs is not None else None
        # inputs_embeds keeps its trailing hidden dimension: (bs * num_choices, slen, dim)
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        # the `lengths` parameter is not supported here; warn and fall back to the attention mask
        if lengths is not None:
            logger.warning(
                "The `lengths` parameter cannot be used with the XLM multiple choice models. Please use the "
                "attention mask instead."
            )
            lengths = None

        # run the Transformer over the flattened inputs
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # the last hidden states
        output = transformer_outputs[0]
        # pool the hidden states, then project to one score per choice
        logits = self.sequence_summary(output)
        logits = self.logits_proj(logits)
        # restore the choice dimension: (bs * num_choices, 1) -> (bs, num_choices)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        # with labels, compute the cross-entropy loss over the choices
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        # tuple output: prepend the loss if present
        if not return_dict:
            output = (reshaped_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # dict output: wrap loss, reshaped logits, hidden states and attentions in a MultipleChoiceModelOutput
        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
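
The shape bookkeeping is the essential part of this head; a standalone illustration:

```python
import torch

batch_size, num_choices, seq_len = 2, 4, 6
input_ids = torch.randint(3, 64, (batch_size, num_choices, seq_len))

flat = input_ids.view(-1, input_ids.size(-1))  # (bs * num_choices, seq_len)
print(flat.shape)                              # torch.Size([8, 6])

# after transformer + sequence_summary + logits_proj, one logit per flattened row:
logits = torch.randn(batch_size * num_choices, 1)
print(logits.view(-1, num_choices).shape)      # torch.Size([2, 4]) -> softmax over choices
```
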

.\models\xlm\tokenization_xlm.py

# coding=utf-8
# Copyright 2019 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for XLM."""


import json  # JSON handling
import os    # OS utilities
import re    # regular expressions
import sys   # system utilities
import unicodedata  # Unicode character database
from typing import List, Optional, Tuple  # type hints

from ...tokenization_utils import PreTrainedTokenizer  # base class for tokenizers
from ...utils import logging  # logging utilities


logger = logging.get_logger(__name__)  # module-level logger

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",   # vocabulary file name
    "merges_file": "merges.txt",  # BPE merges file name
}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "FacebookAI/xlm-mlm-en-2048": "https://huggingface.co/FacebookAI/xlm-mlm-en-2048/resolve/main/vocab.json",
        "FacebookAI/xlm-mlm-ende-1024": "https://huggingface.co/FacebookAI/xlm-mlm-ende-1024/resolve/main/vocab.json",
        "FacebookAI/xlm-mlm-enfr-1024": "https://huggingface.co/FacebookAI/xlm-mlm-enfr-1024/resolve/main/vocab.json",
        "FacebookAI/xlm-mlm-enro-1024": "https://huggingface.co/FacebookAI/xlm-mlm-enro-1024/resolve/main/vocab.json",
        "FacebookAI/xlm-mlm-tlm-xnli15-1024": "https://huggingface.co/FacebookAI/xlm-mlm-tlm-xnli15-1024/resolve/main/vocab.json",
        "FacebookAI/xlm-mlm-xnli15-1024": "https://huggingface.co/FacebookAI/xlm-mlm-xnli15-1024/resolve/main/vocab.json",
        "FacebookAI/xlm-clm-enfr-1024": "https://huggingface.co/FacebookAI/xlm-clm-enfr-1024/resolve/main/vocab.json",
        "FacebookAI/xlm-clm-ende-1024": "https://huggingface.co/FacebookAI/xlm-clm-ende-1024/resolve/main/vocab.json",
        "FacebookAI/xlm-mlm-17-1280": "https://huggingface.co/FacebookAI/xlm-mlm-17-1280/resolve/main/vocab.json",
        "FacebookAI/xlm-mlm-100-1280": "https://huggingface.co/FacebookAI/xlm-mlm-100-1280/resolve/main/vocab.json",
    },
    # merges_file: maps each model name to the URL of its merges.txt file
    "merges_file": {
        "FacebookAI/xlm-mlm-en-2048": "https://huggingface.co/FacebookAI/xlm-mlm-en-2048/resolve/main/merges.txt",
        "FacebookAI/xlm-mlm-ende-1024": "https://huggingface.co/FacebookAI/xlm-mlm-ende-1024/resolve/main/merges.txt",
        "FacebookAI/xlm-mlm-enfr-1024": "https://huggingface.co/FacebookAI/xlm-mlm-enfr-1024/resolve/main/merges.txt",
        "FacebookAI/xlm-mlm-enro-1024": "https://huggingface.co/FacebookAI/xlm-mlm-enro-1024/resolve/main/merges.txt",
        "FacebookAI/xlm-mlm-tlm-xnli15-1024": "https://huggingface.co/FacebookAI/xlm-mlm-tlm-xnli15-1024/resolve/main/merges.txt",
        "FacebookAI/xlm-mlm-xnli15-1024": "https://huggingface.co/FacebookAI/xlm-mlm-xnli15-1024/resolve/main/merges.txt",
        "FacebookAI/xlm-clm-enfr-1024": "https://huggingface.co/FacebookAI/xlm-clm-enfr-1024/resolve/main/merges.txt",
        "FacebookAI/xlm-clm-ende-1024": "https://huggingface.co/FacebookAI/xlm-clm-ende-1024/resolve/main/merges.txt",
        "FacebookAI/xlm-mlm-17-1280": "https://huggingface.co/FacebookAI/xlm-mlm-17-1280/resolve/main/merges.txt",
        "FacebookAI/xlm-mlm-100-1280": "https://huggingface.co/FacebookAI/xlm-mlm-100-1280/resolve/main/merges.txt",
    },
}

# Positional embedding sizes (i.e. maximum input lengths) of the pretrained models
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "FacebookAI/xlm-mlm-en-2048": 512,
    "FacebookAI/xlm-mlm-ende-1024": 512,
    "FacebookAI/xlm-mlm-enfr-1024": 512,
    "FacebookAI/xlm-mlm-enro-1024": 512,
    "FacebookAI/xlm-mlm-tlm-xnli15-1024": 512,
    "FacebookAI/xlm-mlm-xnli15-1024": 512,
    "FacebookAI/xlm-clm-enfr-1024": 512,
    "FacebookAI/xlm-clm-ende-1024": 512,
    "FacebookAI/xlm-mlm-17-1280": 512,
    "FacebookAI/xlm-mlm-100-1280": 512,
}

# Model-specific initialization configuration for each pretrained checkpoint
PRETRAINED_INIT_CONFIGURATION = {
    "FacebookAI/xlm-mlm-en-2048": {"do_lowercase_and_remove_accent": True},
    "FacebookAI/xlm-mlm-ende-1024": {
        "do_lowercase_and_remove_accent": True,
        "id2lang": {0: "de", 1: "en"},
        "lang2id": {"de": 0, "en": 1},
    },
    "FacebookAI/xlm-mlm-enfr-1024": {
        "do_lowercase_and_remove_accent": True,
        "id2lang": {0: "en", 1: "fr"},
        "lang2id": {"en": 0, "fr": 1},
    },
    "FacebookAI/xlm-mlm-enro-1024": {
        "do_lowercase_and_remove_accent": True,
        "id2lang": {0: "en", 1: "ro"},
        "lang2id": {"en": 0, "ro": 1},
    },
    "FacebookAI/xlm-mlm-tlm-xnli15-1024": {
        "do_lowercase_and_remove_accent": True,
        "id2lang": {
            0: "ar",
            1: "bg",
            2: "de",
            3: "el",
            4: "en",
            5: "es",
            6: "fr",
            7: "hi",
            8: "ru",
            9: "sw",
            10: "th",
            11: "tr",
            12: "ur",
            13: "vi",
            14: "zh",
        },
        "lang2id": {
            "ar": 0,
            "bg": 1,
            "de": 2,
            "el": 3,
            "en": 4,
            "es": 5,
            "fr": 6,
            "hi": 7,
            "ru": 8,
            "sw": 9,
            "th": 10,
            "tr": 11,
            "ur": 12,
            "vi": 13,
            "zh": 14,
        },
    },
    "FacebookAI/xlm-mlm-xnli15-1024": {
        "do_lowercase_and_remove_accent": True,
        "id2lang": {
            0: "ar",
            1: "bg",
            2: "de",
            3: "el",
            4: "en",
            5: "es",
            6: "fr",
            7: "hi",
            8: "ru",
            9: "sw",
            10: "th",
            11: "tr",
            12: "ur",
            13: "vi",
            14: "zh",
        },
        "lang2id": {
            "ar": 0,
            "bg": 1,
            "de": 2,
            "el": 3,
            "en": 4,
            "es": 5,
            "fr": 6,
            "hi": 7,
            "ru": 8,
            "sw": 9,
            "th": 10,
            "tr": 11,
            "ur": 12,
            "vi": 13,
            "zh": 14,
        },
    },
    "FacebookAI/xlm-clm-enfr-1024": {
        "do_lowercase_and_remove_accent": True,
        "id2lang": {0: "en", 1: "fr"},
        "lang2id": {"en": 0, "fr": 1},
    },
    "FacebookAI/xlm-clm-ende-1024": {
        # 执行小写化和去除重音符号的操作,设为 True
        "do_lowercase_and_remove_accent": True,
        # ID 到语言的映射字典
        "id2lang": {0: "de", 1: "en"},
        # 语言到ID的映射字典
        "lang2id": {"de": 0, "en": 1},
    },
    "FacebookAI/xlm-mlm-17-1280": {
        # 执行小写化和去除重音符号的操作,设为 False
        "do_lowercase_and_remove_accent": False,
        # ID 到语言的映射字典,包含17种语言
        "id2lang": {
            0: "ar",
            1: "de",
            2: "en",
            3: "es",
            4: "fr",
            5: "hi",
            6: "it",
            7: "ja",
            8: "ko",
            9: "nl",
            10: "pl",
            11: "pt",
            12: "ru",
            13: "sv",
            14: "tr",
            15: "vi",
            16: "zh",
        },
        # language -> ID mapping, the inverse of id2lang above
        "lang2id": {
            "ar": 0,
            "de": 1,
            "en": 2,
            "es": 3,
            "fr": 4,
            "hi": 5,
            "it": 6,
            "ja": 7,
            "ko": 8,
            "nl": 9,
            "pl": 10,
            "pt": 11,
            "ru": 12,
            "sv": 13,
            "tr": 14,
            "vi": 15,
            "zh": 16,
        },
    },
}


def get_pairs(word):
    """
    Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
    strings)
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        # add the (previous symbol, current symbol) pair to the set
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
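
For example, for the BPE word `("h", "e", "l", "l", "o</w>")`:

```python
word = ("h", "e", "l", "l", "o</w>")
print(get_pairs(word))
# {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o</w>')}  (set order may vary)
```
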


def lowercase_and_remove_accent(text):
    """
    Lowercase and strips accents from a piece of text based on
    https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py
    """
    # join the token list with spaces, then lowercase
    text = " ".join(text)
    text = text.lower()
    # NFD-normalize so that accents become separate combining characters
    text = unicodedata.normalize("NFD", text)
    output = []
    for char in text:
        # look up the Unicode category of the character
        cat = unicodedata.category(char)
        # skip non-spacing marks (Mn), i.e. the accents; keep everything else
        if cat == "Mn":
            continue
        output.append(char)
    # join the kept characters and split on spaces again
    return "".join(output).lower().split(" ")


def replace_unicode_punct(text):
    """
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
    """
    # replace full-width/CJK punctuation with its ASCII equivalent
    text = text.replace(",", ",")
    text = re.sub(r"。\s*", ". ", text)
    text = text.replace("、", ",")
    text = text.replace("”", '"')
    text = text.replace("“", '"')
    text = text.replace("∶", ":")
    text = text.replace(":", ":")
    text = text.replace("?", "?")
    text = text.replace("《", '"')
    text = text.replace("》", '"')
    text = text.replace(")", ")")
    text = text.replace("!", "!")
    text = text.replace("(", "(")
    text = text.replace(";", ";")
    text = text.replace("1", "1")
    text = text.replace("」", '"')
    text = text.replace("「", '"')
    text = text.replace("0", "0")
    text = text.replace("3", "3")
    text = text.replace("2", "2")
    text = text.replace("5", "5")
    text = text.replace("6", "6")
    text = text.replace("9", "9")
    text = text.replace("7", "7")
    text = text.replace("8", "8")
    text = text.replace("4", "4")
    text = re.sub(r".\s*", ". ", text)  # fullwidth full stop
    text = text.replace("~", "~")
    text = text.replace("’", "'")
    text = text.replace("…", "...")
    text = text.replace("━", "-")
    text = text.replace("〈", "<")
    text = text.replace("〉", ">")
    text = text.replace("【", "[")
    text = text.replace("】", "]")
    text = text.replace("%", "%")
    return text


def remove_non_printing_char(text):
    """
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
    """
    output = []
    for char in text:
        # look up the Unicode category of the character
        cat = unicodedata.category(char)
        # categories starting with "C" are control/format (non-printing) characters; skip them
        if cat.startswith("C"):
            continue
        output.append(char)
    # join the kept characters back into a string
    return "".join(output)


def romanian_preprocessing(text):
    """Sennrich's WMT16 scripts for Romanian preprocessing, used by model `FacebookAI/xlm-mlm-enro-1024`"""
    # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
    # normalize the cedilla forms to the correct comma-below forms
    text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219")
    text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b")
    # strip the Romanian diacritics, mapping to plain ASCII letters
    text = text.replace("\u0218", "S").replace("\u0219", "s")  # S/s-comma
    text = text.replace("\u021a", "T").replace("\u021b", "t")  # T/t-comma
    text = text.replace("\u0102", "A").replace("\u0103", "a")  # A/a-breve
    text = text.replace("\u00C2", "A").replace("\u00E2", "a")  # A/a-circumflex
    text = text.replace("\u00CE", "I").replace("\u00EE", "i")  # I/i-circumflex
    return text
class XLMTokenizer(PreTrainedTokenizer):
    """
    Construct an XLM tokenizer. Based on Byte-Pair Encoding. The tokenization process is the following:

    - Moses preprocessing and tokenization for most supported languages.
    - Language specific tokenization for Chinese (Jieba), Japanese (KyTea) and Thai (PyThaiNLP).
    - Optionally lowercases and normalizes all inputs text.
    - The arguments `special_tokens` and the function `set_special_tokens`, can be used to add additional symbols (like
      "__classify__") to a vocabulary.
    - The `lang2id` attribute maps the languages supported by the model with their IDs if provided (automatically set
      for pretrained vocabularies).
    - The `id2lang` attribute does the reverse mapping if provided (automatically set for pretrained vocabularies).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.
    """
    # 定义函数的参数说明文档字符串,描述了每个参数的含义和默认值
    Args:
        vocab_file (`str`):
            Vocabulary file.
        merges_file (`str`):
            Merges file.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"</s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"<special1>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        additional_special_tokens (`List[str]`, *optional*, defaults to `['<special0>', '<special1>', '<special2>', '<special3>', '<special4>', '<special5>', '<special6>', '<special7>', '<special8>', '<special9>']`):
            List of additional special tokens.
        lang2id (`Dict[str, int]`, *optional*):
            Dictionary mapping languages string identifiers to their IDs.
        id2lang (`Dict[int, str]`, *optional*):
            Dictionary mapping language IDs to their string identifiers.
        do_lowercase_and_remove_accent (`bool`, *optional*, defaults to `True`):
            Whether to lowercase and remove accents when tokenizing.
    """

    # Class-level constants: names of the vocab files, the map of pretrained vocab
    # files, pretrained init configurations and maximum model input sizes
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    # Build an XLMTokenizer from the vocab/merges files and special-token settings
    def __init__(
        self,
        vocab_file,
        merges_file,
        unk_token="<unk>",
        bos_token="<s>",
        sep_token="</s>",
        pad_token="<pad>",
        cls_token="</s>",
        mask_token="<special1>",
        additional_special_tokens=[
            "<special0>",
            "<special1>",
            "<special2>",
            "<special3>",
            "<special4>",
            "<special5>",
            "<special6>",
            "<special7>",
            "<special8>",
            "<special9>",
        ],
        lang2id=None,
        id2lang=None,
        do_lowercase_and_remove_accent=True,
        **kwargs,
    ):
        # sacremoses is a hard requirement; fail early with a helpful message
        try:
            import sacremoses
        except ImportError:
            raise ImportError(
                "You need to install sacremoses to use XLMTokenizer. "
                "See https://pypi.org/project/sacremoses/ for installation."
            )

        self.sm = sacremoses

        # Cache of sm.MosesPunctNormalizer instances, keyed by language
        self.cache_moses_punct_normalizer = {}
        # Cache of sm.MosesTokenizer instances, keyed by language
        self.cache_moses_tokenizer = {}

        # Languages that need a custom (non-Moses) tokenizer: Chinese, Thai and Japanese
        self.lang_with_custom_tokenizer = {"zh", "th", "ja"}

        # True for the currently supported models (v1.2.0), False for the XLM-17 & 100 models
        self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
        self.lang2id = lang2id
        self.id2lang = id2lang

        # If both language maps are provided, they must cover the same number of languages
        if lang2id is not None and id2lang is not None:
            assert len(lang2id) == len(id2lang)

        # The Japanese and Chinese tokenizers are created lazily on first use
        self.ja_word_tokenizer = None
        self.zh_word_tokenizer = None

        # Load the vocabulary (token -> id) from the JSON vocab file
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)

        # Build the reverse mapping (id -> token)
        self.decoder = {v: k for k, v in self.encoder.items()}

        # Read the BPE merges file and map each merge pair to its rank
        with open(merges_file, encoding="utf-8") as merges_handle:
            merges = merges_handle.read().split("\n")[:-1]
        merges = [tuple(merge.split()[:2]) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))

        # Cache of already-computed BPE segmentations
        self.cache = {}

        # Forward all arguments, including the special tokens, to the parent class
        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            additional_special_tokens=additional_special_tokens,
            lang2id=lang2id,
            id2lang=id2lang,
            do_lowercase_and_remove_accent=do_lowercase_and_remove_accent,
            **kwargs,
        )

    # Alias: do_lower_case mirrors do_lowercase_and_remove_accent
    @property
    def do_lower_case(self):
        return self.do_lowercase_and_remove_accent

    # Punctuation normalization with sacremoses' MosesPunctNormalizer
    def moses_punct_norm(self, text, lang):
        # Create and cache a MosesPunctNormalizer for this language on first use
        if lang not in self.cache_moses_punct_normalizer:
            punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
            self.cache_moses_punct_normalizer[lang] = punct_normalizer
        else:
            # Reuse the cached instance
            punct_normalizer = self.cache_moses_punct_normalizer[lang]
        # Normalize the punctuation and return the result
        return punct_normalizer.normalize(text)
    # Tokenization with sacremoses' MosesTokenizer, cached per language
    def moses_tokenize(self, text, lang):
        # Create and cache a MosesTokenizer for this language on first use
        if lang not in self.cache_moses_tokenizer:
            moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
            self.cache_moses_tokenizer[lang] = moses_tokenizer
        else:
            # Reuse the cached instance
            moses_tokenizer = self.cache_moses_tokenizer[lang]
        # Tokenize without escaping special characters and return the token list
        return moses_tokenizer.tokenize(text, return_str=False, escape=False)

    # Standard Moses preprocessing pipeline applied before BPE
    def moses_pipeline(self, text, lang):
        # Replace unicode punctuation with its ASCII counterpart
        text = replace_unicode_punct(text)
        # Normalize punctuation for the given language
        text = self.moses_punct_norm(text, lang)
        # Drop non-printing characters
        text = remove_non_printing_char(text)
        return text

    # Japanese word segmentation with KyTea (via the Mykytea wrapper)
    def ja_tokenize(self, text):
        # Lazily initialize the Japanese tokenizer on first use
        if self.ja_word_tokenizer is None:
            try:
                import Mykytea

                self.ja_word_tokenizer = Mykytea.Mykytea(
                    f"-model {os.path.expanduser('~')}/local/share/kytea/model.bin"
                )
            except (AttributeError, ImportError):
                # Log installation instructions and re-raise
                logger.error(
                    "Make sure you install KyTea (https://github.com/neubig/kytea) and its Python wrapper"
                    " (https://github.com/chezou/Mykytea-python) with the following steps"
                )
                logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
                logger.error("2. autoreconf -i")
                logger.error("3. ./configure --prefix=$HOME/local")
                logger.error("4. make && make install")
                logger.error("5. pip install kytea")
                raise
        # Return the word-segmented text as a list of tokens
        return list(self.ja_word_tokenizer.getWS(text))

    # Vocabulary size is the number of entries in the encoder
    @property
    def vocab_size(self):
        return len(self.encoder)

    # Full vocabulary: the base encoder plus any added special tokens
    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def bpe(self, token):
        # Represent the token as a tuple of symbols, with "</w>" marking the end of the word
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        # Return the cached segmentation if this token was processed before
        if token in self.cache:
            return self.cache[token]
        # All adjacent symbol pairs (bigrams) of the word
        pairs = get_pairs(word)

        # A single-symbol token has no pairs; just append the end-of-word marker
        if not pairs:
            return token + "</w>"

        # Repeatedly apply the highest-priority known merge until none is applicable
        while True:
            # The candidate with the lowest rank was learned first and merges first
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            # Stop once the best candidate is not a known merge
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            # Walk through the word, merging every occurrence of (first, second)
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            # Continue with the merged word
            new_word = tuple(new_word)
            word = new_word
            # A fully merged word cannot be reduced further
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)

        # Join the final symbols with spaces
        word = " ".join(word)
        # Special-case the segmentation of the newline token
        if word == "\n  </w>":
            word = "\n</w>"
        # Cache and return the result
        self.cache[token] = word
        return word
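
To make the merge loop concrete, here is a minimal standalone sketch of the same algorithm run on a toy example (the merge ranks are made up for illustration, and `get_pairs` is inlined):

def get_pairs_demo(word):
    # All adjacent symbol pairs of the word tuple
    return {(word[i], word[i + 1]) for i in range(len(word) - 1)}

bpe_ranks_demo = {("l", "o"): 0, ("lo", "w</w>"): 1}  # hypothetical merge table

word = ("l", "o", "w</w>")
while len(word) > 1:
    candidates = [p for p in get_pairs_demo(word) if p in bpe_ranks_demo]
    if not candidates:
        break
    first, second = min(candidates, key=bpe_ranks_demo.get)
    merged, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and (word[i], word[i + 1]) == (first, second):
            merged.append(first + second)
            i += 2
        else:
            merged.append(word[i])
            i += 1
    word = tuple(merged)

print(" ".join(word))  # -> 'low</w>': "low" ends up as a single BPE unit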

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # Out-of-vocabulary tokens fall back to the unknown token's id
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        # Out-of-range ids fall back to the unknown token
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # Concatenate the tokens and turn end-of-word markers back into spaces
        out_string = "".join(tokens).replace("</w>", " ").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. An XLM sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.

        """
        bos = [self.bos_token_id]
        sep = [self.sep_token_id]

        # Single sequence: <s> X </s>
        if token_ids_1 is None:
            return bos + token_ids_0 + sep

        # Pair of sequences: <s> A </s> B </s>
        return bos + token_ids_0 + sep + token_ids_1 + sep
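
The two formats can be checked with plain list arithmetic (a standalone sketch with hypothetical ids, assuming bos_token_id = 0 and sep_token_id = 1):

bos, sep = [0], [1]
token_ids_0, token_ids_1 = [5, 6], [7, 8]
print(bos + token_ids_0 + sep)                      # [0, 5, 6, 1]           <s> A </s>
print(bos + token_ids_0 + sep + token_ids_1 + sep)  # [0, 5, 6, 1, 7, 8, 1]  <s> A </s> B </s>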
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        # If tokens already have special tokens, delegate to superclass method
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # Calculate special tokens mask based on whether there is a second sequence
        if token_ids_1 is not None:
            # Case for sequence pair: [CLS] token_ids_0 [SEP] token_ids_1 [SEP]
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        else:
            # Case for single sequence: [CLS] token_ids_0 [SEP]
            return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLM sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        # Define [SEP] and [CLS] tokens
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        # If only one sequence is provided
        if token_ids_1 is None:
            # Return token type IDs for single sequence: [CLS] token_ids_0 [SEP]
            return len(cls + token_ids_0 + sep) * [0]
        else:
            # Return token type IDs for sequence pair: [CLS] token_ids_0 [SEP] token_ids_1 [SEP]
            return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
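
A quick check of the token-type layout (standalone sketch, hypothetical ids):

cls, sep = [0], [1]
token_ids_0, token_ids_1 = [5, 6], [7, 8]
print(len(cls + token_ids_0 + sep) * [0])                                 # [0, 0, 0, 0]
print(len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1])  # [0, 0, 0, 0, 1, 1, 1]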
    # Save the vocabulary and BPE merges to files in the given directory
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # The save directory must already exist
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        # Path of the vocabulary (JSON) file
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        # Path of the merges file
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        # Dump the encoder as pretty-printed JSON
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        # Write the BPE merges one per line, ordered by rank
        with open(merge_file, "w", encoding="utf-8") as writer:
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                # Non-consecutive indices indicate a corrupted tokenizer
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        # Return both file paths
        return vocab_file, merge_file

    # Serialization state for pickling
    def __getstate__(self):
        state = self.__dict__.copy()
        # The sacremoses module itself cannot be pickled; drop it and re-import on load
        state["sm"] = None
        return state

    # Restore state when unpickling
    def __setstate__(self, d):
        self.__dict__ = d

        # Re-import sacremoses, which was dropped in __getstate__
        try:
            import sacremoses
        except ImportError:
            raise ImportError(
                "You need to install sacremoses to use XLMTokenizer. "
                "See https://pypi.org/project/sacremoses/ for installation."
            )

        self.sm = sacremoses

.\models\xlm\__init__.py

from typing import TYPE_CHECKING

# Lazy-import helper and backend-availability checks from the library utils
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available

# Import structure: maps each submodule to the public names it exposes
_import_structure = {
    "configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMOnnxConfig"],
    "tokenization_xlm": ["XLMTokenizer"],
}

# Register the PyTorch implementations only if torch is installed
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_xlm"] = [
        "XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
        "XLMForMultipleChoice",
        "XLMForQuestionAnswering",
        "XLMForQuestionAnsweringSimple",
        "XLMForSequenceClassification",
        "XLMForTokenClassification",
        "XLMModel",
        "XLMPreTrainedModel",
        "XLMWithLMHeadModel",
    ]

# Register the TensorFlow implementations only if TensorFlow is installed
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_xlm"] = [
        "TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFXLMForMultipleChoice",
        "TFXLMForQuestionAnsweringSimple",
        "TFXLMForSequenceClassification",
        "TFXLMForTokenClassification",
        "TFXLMMainLayer",
        "TFXLMModel",
        "TFXLMPreTrainedModel",
        "TFXLMWithLMHeadModel",
    ]

# Under static type checking, import everything eagerly so the names resolve
if TYPE_CHECKING:
    from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMOnnxConfig
    from .tokenization_xlm import XLMTokenizer

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # PyTorch implementations of the XLM models
        from .modeling_xlm import (
            XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
            XLMForMultipleChoice,
            XLMForQuestionAnswering,
            XLMForQuestionAnsweringSimple,
            XLMForSequenceClassification,
            XLMForTokenClassification,
            XLMModel,
            XLMPreTrainedModel,
            XLMWithLMHeadModel,
        )

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # TensorFlow implementations of the XLM models
        from .modeling_tf_xlm import (
            TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFXLMForMultipleChoice,
            TFXLMForQuestionAnsweringSimple,
            TFXLMForSequenceClassification,
            TFXLMForTokenClassification,
            TFXLMMainLayer,
            TFXLMModel,
            TFXLMPreTrainedModel,
            TFXLMWithLMHeadModel,
        )
else:
    import sys

    # At runtime, replace this module with a lazy module that imports each
    # submodule only on first attribute access
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
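
The same lazy-loading idea can be reproduced in miniature with PEP 562's module-level `__getattr__` (an illustrative sketch only; transformers' `_LazyModule` does more bookkeeping, such as caching the imported submodules):

import importlib

_LAZY = {"XLMModel": ".modeling_xlm"}  # public name -> submodule

def __getattr__(name):
    # Invoked only when `name` is not a regular attribute of this module
    if name in _LAZY:
        module = importlib.import_module(_LAZY[name], __package__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")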

.\models\xlm_prophetnet\configuration_xlm_prophetnet.py

# The file declares UTF-8 encoding and carries the standard copyright notice
# and Apache-2.0 license terms governing use and distribution.
from typing import Callable, Optional, Union

# Pretrained configuration base class
from ...configuration_utils import PretrainedConfig
# Logging utilities
from ...utils import logging

# Module-level logger
logger = logging.get_logger(__name__)

# Maps each pretrained model name to the URL of its configuration file
XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "microsoft/xprophetnet-large-wiki100-cased": (
        "https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased/resolve/main/config.json"
    ),
}

# Configuration class for the XLM-ProphetNet model, inheriting from PretrainedConfig
class XLMProphetNetConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`XLMProphetNetModel`]. It is used to instantiate a
    XLMProphetNet model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the XLMProphetNet
    [microsoft/xprophetnet-large-wiki100-cased](https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    """

    # Model type identifier
    model_type = "xlm-prophetnet"
    # Keys to ignore when gathering outputs at inference time
    keys_to_ignore_at_inference = ["past_key_values"]
    # Maps generic attribute names onto this config's specific parameters
    attribute_map = {
        "num_attention_heads": "num_encoder_attention_heads",
    }
    # Constructor: all architecture and generation options with their defaults
    def __init__(
        self,
        activation_dropout: Optional[float] = 0.1,  # dropout after the FFN activation
        activation_function: Optional[Union[str, Callable]] = "gelu",  # activation used in the FFN blocks
        vocab_size: Optional[int] = 30522,  # vocabulary size
        hidden_size: Optional[int] = 1024,  # hidden dimension
        encoder_ffn_dim: Optional[int] = 4096,  # encoder feed-forward dimension
        num_encoder_layers: Optional[int] = 12,  # number of encoder layers
        num_encoder_attention_heads: Optional[int] = 16,  # encoder attention heads
        decoder_ffn_dim: Optional[int] = 4096,  # decoder feed-forward dimension
        num_decoder_layers: Optional[int] = 12,  # number of decoder layers
        num_decoder_attention_heads: Optional[int] = 16,  # decoder attention heads
        attention_dropout: Optional[float] = 0.1,  # dropout on the attention probabilities
        dropout: Optional[float] = 0.1,  # dropout on fully connected layers
        max_position_embeddings: Optional[int] = 512,  # maximum number of positions
        init_std: Optional[float] = 0.02,  # std of the weight initializer
        is_encoder_decoder: Optional[bool] = True,  # whether this is an encoder-decoder model
        add_cross_attention: Optional[bool] = True,  # whether to add cross-attention layers
        decoder_start_token_id: Optional[int] = 0,  # id the decoder starts generation with
        ngram: Optional[int] = 2,  # how many future tokens each position predicts
        num_buckets: Optional[int] = 32,  # number of relative-position buckets
        relative_max_distance: Optional[int] = 128,  # maximum relative distance
        disable_ngram_loss: Optional[bool] = False,  # whether to disable the n-gram loss
        eps: Optional[float] = 0.0,  # epsilon used for label smoothing in the loss
        use_cache: Optional[bool] = True,  # whether to cache past key/values
        pad_token_id: Optional[int] = 0,  # padding token id
        bos_token_id: Optional[int] = 1,  # beginning-of-sequence token id
        eos_token_id: Optional[int] = 2,  # end-of-sequence token id
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.encoder_ffn_dim = encoder_ffn_dim
        self.num_encoder_layers = num_encoder_layers
        self.num_encoder_attention_heads = num_encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.num_decoder_layers = num_decoder_layers
        self.num_decoder_attention_heads = num_decoder_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.init_std = init_std
        self.activation_function = activation_function

        # Parameters specific to XLMProphetNet
        self.ngram = ngram
        self.num_buckets = num_buckets
        self.relative_max_distance = relative_max_distance
        self.disable_ngram_loss = disable_ngram_loss
        self.eps = eps

        # The three kinds of dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.dropout = dropout

        self.use_cache = use_cache

        # Let the parent class store the shared token ids and model flags
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            add_cross_attention=add_cross_attention,
            decoder_start_token_id=decoder_start_token_id,
            **kwargs,
        )

    @property
    def num_hidden_layers(self) -> int:
        return self.num_encoder_layers + self.num_decoder_layers

    @num_hidden_layers.setter
    def num_hidden_layers(self, value):
        # The total is derived from encoder and decoder depth and cannot be set directly
        raise NotImplementedError(
            "This model does not support the setting of `num_hidden_layers`. Please set `num_encoder_layers` and"
            " `num_decoder_layers`."
        )
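
The derived property can be exercised directly (a sketch assuming `XLMProphetNetConfig` is importable):

config = XLMProphetNetConfig(num_encoder_layers=12, num_decoder_layers=12)
print(config.num_hidden_layers)  # 24: encoder layers + decoder layers
try:
    config.num_hidden_layers = 6
except NotImplementedError as err:
    print(err)  # points the user to num_encoder_layers / num_decoder_layers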

.\models\xlm_prophetnet\modeling_xlm_prophetnet.py

# coding=utf-8
# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch XLM-ProphetNet model."""

# Standard library imports
import copy
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

# PyTorch
import torch
import torch.utils.checkpoint
from torch import Tensor, nn
from torch.nn import LayerNorm

# Activation-function map, shared output classes and utilities
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
# XLM-ProphetNet configuration
from .configuration_xlm_prophetnet import XLMProphetNetConfig

# Module-level logger
logger = logging.get_logger(__name__)

# Configuration name used in the generated docstrings
_CONFIG_FOR_DOC = "XLMProphetNetConfig"

# Pretrained XLM-ProphetNet checkpoints
XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "microsoft/xprophetnet-large-wiki100-cased",
    # See all XLMProphetNet models at https://huggingface.co/models?filter=xprophetnet
]

# Copied from src.transformers.models.prophetnet.modeling_prophetnet.PROPHETNET_START_DOCSTRING
# with ProphetNetConfig -> XLMProphetNetConfig
XLM_PROPHETNET_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    Original ProphetNet code can be found [here](https://github.com/microsoft/ProphetNet). Checkpoints were converted
    from original Fairseq checkpoints. For more information on the checkpoint conversion, please take a look at the
    file `convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py`.

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`XLMProphetNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Copied from src.transformers.models.prophetnet.modeling_prophetnet.PROPHETNET_INPUTS_DOCSTRING
# with ProphetNet -> XLMProphetNet
XLM_PROPHETNET_INPUTS_DOCSTRING = r"""
"""

# Copied from src.transformers.models.prophetnet.modeling_prophetnet.PROPHETNET_STANDALONE_INPUTS_DOCSTRING with ProphetNet->XLMProphetNet
XLM_PROPHETNET_STANDALONE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


# Copied from transformers.models.prophetnet.modeling_prophetnet.softmax
def softmax(hidden_state, dim, onnx_trace=False):
    """
    Applies softmax function along a specific dimension of the input tensor.

    Args:
        hidden_state (torch.Tensor): Input tensor to apply softmax.
        dim (int): Dimension along which softmax will be computed.
        onnx_trace (bool, optional): Whether to trace the operation for ONNX compatibility.

    Returns:
        torch.Tensor: Tensor after applying softmax along the specified dimension.
    """
    if onnx_trace:
        return nn.functional.softmax(hidden_state.float(), dim=dim)
    else:
        return nn.functional.softmax(hidden_state, dim=dim, dtype=torch.float32)


# Copied from transformers.models.prophetnet.modeling_prophetnet.ngram_attention_bias
def ngram_attention_bias(sequence_length, ngram, device, dtype):
    """
    Compute n-gram attention bias tensor for ProphetNet.

    Args:
        sequence_length (int): Length of the input sequence.
        ngram (int): Size of the n-gram.
        device (torch.device): Device on which to allocate the tensors.
        dtype (torch.dtype): Data type of the tensors.

    Returns:
        torch.Tensor: N-gram attention bias tensor of shape (ngram, sequence_length, 2 * sequence_length).
    """
    left_block = (
        torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * torch.finfo(dtype).min
    )
    right_block = left_block.detach().clone()
    # create bias
    for stream_idx in range(ngram):
        right_block[stream_idx].fill_diagonal_(0, wrap=False)
        left_block[stream_idx].triu_(-stream_idx + 1)

    left_block[:, :, 0] = 0
    return torch.cat([left_block, right_block], dim=2)
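
A small standalone check of the layout (assumes the function above is in scope): the left half is the causal mask over the main stream, while the right half lets each predict-stream position attend only to the token it is predicting from.

import torch

bias = ngram_attention_bias(sequence_length=3, ngram=2, device="cpu", dtype=torch.float32)
print(bias.shape)            # torch.Size([2, 3, 6])
print((bias == 0).int()[0])  # 1 marks attendable positions for the first stream:
# tensor([[1, 0, 0, 1, 0, 0],
#         [1, 1, 0, 0, 1, 0],
#         [1, 1, 1, 0, 0, 1]], dtype=torch.int32)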
# Computes the relative-position bucket for each relative position, given the
# number of buckets and the maximum distance
def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):
    """
    This function computes individual parts of the relative position buckets. For more detail, see paper.
    """
    # Work with the negated relative positions
    inv_relative_positions = -relative_positions
    rel_positions_bucket = 0

    if is_bidirectional:
        # Half of the buckets are reserved for each direction
        num_buckets = num_buckets // 2
        # Positions with negative inverted distance go into the upper half of the buckets
        rel_positions_bucket = (
            rel_positions_bucket
            + torch.lt(inv_relative_positions, torch.zeros_like(inv_relative_positions)).int() * num_buckets
        )
        inv_relative_positions = torch.abs(inv_relative_positions)
    else:
        # Unidirectional: clamp negative distances to zero
        inv_relative_positions = torch.max(inv_relative_positions, torch.zeros_like(inv_relative_positions))

    # Distances below max_exact each get their own bucket
    max_exact = num_buckets // 2
    is_small = torch.lt(inv_relative_positions, max_exact)
    # Larger distances share logarithmically spaced buckets
    val_if_large = max_exact + torch.log(inv_relative_positions.float() / max_exact) / math.log(
        max_distance / max_exact
    ) * (num_buckets - max_exact)
    # Clip to the last bucket
    val_if_large = torch.min(val_if_large, torch.ones_like(val_if_large) * (num_buckets - 1)).int()
    # Pick exact buckets for small distances and log buckets for large ones
    rel_positions_bucket = rel_positions_bucket + torch.where(is_small, inv_relative_positions.int(), val_if_large)
    return rel_positions_bucket
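
A standalone sketch of the bucketing (assumes the function above is in scope): with 32 buckets and bidirectional bucketing, negative and positive relative positions land in the lower and upper half of the buckets respectively.

import torch

rel = torch.tensor([[-2, -1, 0, 1, 2]])
print(compute_relative_buckets(32, 128, rel, is_bidirectional=True))
# tensor([[ 2,  1,  0, 17, 18]], dtype=torch.int32)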


# 从transformers.models.prophetnet.modeling_prophetnet.compute_all_stream_relative_buckets复制而来
# 计算所有流的相对位置桶
def compute_all_stream_relative_buckets(num_buckets, max_distance, position_ids):
    """
    This function computes both main and predict relative position buckets. For more detail, see paper.
    """
    # 主流相对位置
    main_stream_relative_positions = position_ids.unsqueeze(1).repeat(1, position_ids.size(-1), 1)
    main_stream_relative_positions = main_stream_relative_positions - position_ids.unsqueeze(-1)

    # 预测流相对位置
    predicting_stream_relative_positions = torch.cat((position_ids - 1, position_ids), dim=-1).unsqueeze(1)
    predicting_stream_relative_positions = predicting_stream_relative_positions.repeat(1, position_ids.size(-1), 1)
    predicting_stream_relative_positions = predicting_stream_relative_positions - position_ids.unsqueeze(-1)

    # 获取主要和预测位置桶
    main_relative_position_buckets = compute_relative_buckets(
        num_buckets, max_distance, main_stream_relative_positions, is_bidirectional=False
    )
    predict_relative_position_buckets = compute_relative_buckets(
        num_buckets, max_distance, predicting_stream_relative_positions, is_bidirectional=False
    )
    # 返回主流和预测流的相对位置桶
    return main_relative_position_buckets, predict_relative_position_buckets


# 从transformers.models.prophetnet.modeling_prophetnet.ProphetNetSeq2SeqLMOutput中复制而来,
# 用于XLMProphetNet的序列到序列语言模型输出
@dataclass
class XLMProphetNetSeq2SeqLMOutput(ModelOutput):
    """
    Base class for sequence-to-sequence language models outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    logits_ngram: Optional[torch.FloatTensor] = None
    # 定义了多个可选类型的 Torch 张量元组变量,用于存储模型解码器的各种状态和注意力机制的输出
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    
    # 定义了一个属性方法,用于获取解码器的交叉注意力机制,同时发出未来移除警告
    @property
    def decoder_cross_attentions(self):
        warnings.warn(
            "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
            " instead.",
            FutureWarning,
        )
        return self.cross_attentions


# Seq2seq model output for XLMProphetNet: encoder outputs plus pre-computed
# decoder states that can speed up sequential decoding
@dataclass
class XLMProphetNetSeq2SeqModelOutput(ModelOutput):
    """
    Base class for model encoder's outputs that also contains pre-computed hidden states that can speed up sequential
    decoding.
    """

    last_hidden_state: torch.FloatTensor
    last_hidden_state_ngram: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Deprecated alias for `cross_attentions`, kept for backward compatibility
    @property
    def decoder_cross_attentions(self):
        warnings.warn(
            "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
            " instead.",
            FutureWarning,
        )
        return self.cross_attentions


@dataclass
# 定义 XLMProphetNetDecoderModelOutput 类,继承自 ModelOutput,用于存储解码器模型的输出结果。
class XLMProphetNetDecoderModelOutput(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
    """

    # 最后一个隐藏状态,类型为 torch.FloatTensor
    last_hidden_state: torch.FloatTensor
    # 可选项,最后一个 n-gram 隐藏状态,类型为 torch.FloatTensor
    last_hidden_state_ngram: Optional[torch.FloatTensor] = None
    # 可选项,过去的键/值对,用于加速顺序解码,类型为 Tuple[torch.FloatTensor]
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    # 可选项,隐藏状态序列,类型为 Tuple[torch.FloatTensor]
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 可选项,n-gram 隐藏状态序列,类型为 Tuple[torch.FloatTensor]
    hidden_states_ngram: Optional[Tuple[torch.FloatTensor]] = None
    # 可选项,注意力权重序列,类型为 Tuple[torch.FloatTensor]
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # 可选项,n-gram 注意力权重序列,类型为 Tuple[torch.FloatTensor]
    ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
    # 可选项,交叉注意力权重序列,类型为 Tuple[torch.FloatTensor]
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
# 定义 XLMProphetNetDecoderLMOutput 类,继承自 ModelOutput,用于存储解码器语言模型的输出结果。
class XLMProphetNetDecoderLMOutput(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
    """

    # 可选项,损失值,类型为 torch.FloatTensor
    loss: Optional[torch.FloatTensor] = None
    # 预测的 logits,类型为 torch.FloatTensor
    logits: torch.FloatTensor = None
    # 可选项,预测的 n-gram logits,类型为 torch.FloatTensor
    logits_ngram: Optional[torch.FloatTensor] = None
    # 可选项,过去的键/值对,用于加速顺序解码,类型为 Tuple[torch.FloatTensor]
    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
    # 可选项,隐藏状态序列,类型为 Tuple[torch.FloatTensor]
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 可选项,n-gram 隐藏状态序列,类型为 Tuple[torch.FloatTensor]
    hidden_states_ngram: Optional[Tuple[torch.FloatTensor]] = None
    # 可选项,注意力权重序列,类型为 Tuple[torch.FloatTensor]
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # 可选项,n-gram 注意力权重序列,类型为 Tuple[torch.FloatTensor]
    ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetPreTrainedModel with ProphetNet->XLMProphetNet
class XLMProphetNetPreTrainedModel(PreTrainedModel):
    config_class = XLMProphetNetConfig
    base_model_prefix = "prophetnet"
    supports_gradient_checkpointing = True

    # Initialize weights depending on the module type
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # Linear layers: normal-distributed weights, zero biases
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # Embeddings: normal-distributed weights, zeroed padding row
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    # Prepare decoder inputs by shifting the target ids one step to the right
    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert decoder_start_token_id is not None, (
            "self.model.config.decoder_start_token_id has to be defined. In XLMProphetNet it is usually set to the"
            " pad_token_id. See XLMProphetNet docs for more information"
        )

        # Shift inputs to the right and prepend the decoder start token
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        # Replace possible -100 values (ignored label positions) with the pad token id
        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        # All resulting ids must be non-negative
        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"

        return shifted_input_ids
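
The shift can be traced on a toy label tensor (standalone sketch with hypothetical token ids; the config above defaults to decoder_start_token_id = pad_token_id = 0):

import torch

labels = torch.tensor([[12, -100, 37]])
decoder_start_token_id, pad_token_id = 0, 0  # defaults from XLMProphetNetConfig

shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = decoder_start_token_id
shifted.masked_fill_(shifted == -100, pad_token_id)
print(shifted)  # tensor([[ 0, 12,  0]])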


# 从transformers.models.prophetnet.modeling_prophetnet.ProphetNetPositionalEmbeddings复制而来,将ProphetNet替换为XLMProphetNet
class XLMProphetNetPositionalEmbeddings(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
    based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
    the forward function.
    """

    def __init__(self, config: XLMProphetNetConfig) -> None:
        # 最大长度为config中的max_position_embeddings
        self.max_length = config.max_position_embeddings
        super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)
    # 定义一个方法 forward,用于模型的前向传播
    def forward(self, inputs_shape, device, attention_mask=None, past_key_values=None, position_ids=None):
        # 断言语句,确保 position_ids 为 None 或者 self.padding_idx 未设置
        assert (position_ids is None) or (
            self.padding_idx is None
        ), "If position_ids is pre-computed then padding_idx should not be set."

        # 如果 position_ids 为 None
        if position_ids is None:
            # 如果 past_key_values 不为 None,则在解码单步时 position_ids 对每个 token 都相同
            if past_key_values is not None:
                # 获取过去键值中的输入 token 数量
                prev_num_input_ids = past_key_values[0][0].shape[2]
                # 计算新的输入 token 数量
                num_input_ids = inputs_shape[1] + prev_num_input_ids
                # 计算新的 position_ids,并将其设为 padding_idx 加上 num_input_ids
                position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * (
                    int(self.padding_idx + num_input_ids)
                )
            else:
                # 如果 attention_mask 为 None,则初始化 attention_mask 为全 1 的张量
                if attention_mask is None:
                    attention_mask = torch.ones(inputs_shape, dtype=torch.long, device=device)

                # 从 input_ids / attention_mask 中获取 position_ids
                position_ids = (
                    torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
                ).long() + self.padding_idx

                # 确保 position_ids 不超过 max_length - 1
                position_ids = position_ids.clamp(0, self.max_length - 1)

        # 调用父类的 forward 方法,并返回其结果以及计算得到的 position_ids
        return super().forward(position_ids), position_ids

    # 定义一个私有方法 _forward,用于调用父类的 forward 方法
    def _forward(self, position_ids):
        return super().forward(position_ids)
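
The cumsum trick in the no-cache branch can be checked in isolation (a standalone sketch; padding_idx = 1 is hypothetical): padded positions keep padding_idx while real tokens get consecutive ids starting right after it.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
padding_idx = 1  # hypothetical padding position id

position_ids = (
    torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
).long() + padding_idx
print(position_ids)
# tensor([[2, 3, 4, 1],
#         [2, 3, 1, 1]])
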
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetAttention with ProphetNet->XLMProphetNet
class XLMProphetNetAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: XLMProphetNetConfig,
        num_attn_heads: int,
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.attention_dropout = config.attention_dropout  # dropout on the attention probabilities
        self.dropout = config.dropout  # dropout on the attention output
        self.num_attn_heads = num_attn_heads
        self.head_dim = hidden_size // num_attn_heads  # dimension of each attention head

        assert self.head_dim * num_attn_heads == hidden_size, (
            "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and"
            " `config.num_decoder_attention_heads`"
        )

        # Projections for keys, values and queries
        self.key_proj = nn.Linear(hidden_size, hidden_size)
        self.value_proj = nn.Linear(hidden_size, hidden_size)
        self.query_proj = nn.Linear(hidden_size, hidden_size)

        # Output projection
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    # Reshape to [batch_size, num_heads, seq_len, head_dim] for multi-head attention
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states,
        key_value_states: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        layer_head_mask: Optional[Tensor] = None,
        past_key_value: Optional[Tuple[Tensor]] = None,
        output_attentions: bool = False,


# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetFeedForward with ProphetNet->XLMProphetNet
class XLMProphetNetFeedForward(nn.Module):
    """
    This is the residual two feed-forward layer block based on the original Transformer implementation.
    """

    def __init__(self, config: XLMProphetNetConfig, ffn_dim: int):
        super().__init__()
        self.activation_fn = ACT2FN[config.activation_function]  # activation between the two projections
        self.intermediate = nn.Linear(config.hidden_size, ffn_dim)  # expand to the FFN dimension
        self.output = nn.Linear(ffn_dim, config.hidden_size)  # project back to the hidden size
        self.activation_dropout = config.activation_dropout
        self.dropout = config.dropout

    def forward(self, hidden_states):
        hidden_states = self.intermediate(hidden_states)  # expansion
        hidden_states = self.activation_fn(hidden_states)  # non-linearity

        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.output(hidden_states)  # back-projection
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        return hidden_states


# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetNgramSelfAttention with ProphetNet->XLMProphetNet
class XLMProphetNetNgramSelfAttention(nn.Module):
    def __init__(self, config: XLMProphetNetConfig):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.num_buckets = config.num_buckets
        self.relative_max_distance = config.relative_max_distance
        self.num_attn_heads = config.num_decoder_attention_heads
        self.dropout = config.dropout
        self.attention_dropout = config.attention_dropout
        self.head_dim = config.hidden_size // self.num_attn_heads
        self.ngram = config.ngram

        assert (
            self.head_dim * self.num_attn_heads == config.hidden_size
        ), "config.hidden_size must be divisible by num_attn_heads"

        # Projections for keys, values and queries
        self.key_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.value_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.query_proj = nn.Linear(config.hidden_size, config.hidden_size)

        # Output projection
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)

        # Produces one relative-position bias per (bucket, head) pair
        self.relative_pos_embeddings = nn.Linear(config.hidden_size, self.num_buckets * self.num_attn_heads)

        # Flag used when exporting the model to ONNX
        self.onnx_trace = False

    # Reshape to [batch_size, num_heads, seq_len, head_dim]
    def _shape(self, tensor, seq_len, batch_size):
        return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        hidden_states,
        past_key_value: Optional[Tuple[Tensor]] = None,
        attention_mask=None,
        layer_head_mask=None,
        extended_predict_attention_mask=None,
        main_relative_position_buckets=None,
        predict_relative_position_buckets=None,
        position_ids=None,
    ):

    # Relative position embeddings for the main stream
    def get_main_relative_pos_embeddings(
        self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
    ):
        # hidden_states: [batch_size, sequence_length, hidden_size]
        # attn_weights: [batch_size, num_heads, sequence_length, sequence_length]
        # position_ids: [batch_size, sequence_length] or [1, 1]
        batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape
        attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len)

        # Compute the relative-position buckets if they were not precomputed
        if main_relative_position_buckets is None:
            batch_size, sequence_length = hidden_states.shape[:2]
            relative_positions = (
                torch.arange(1, attn_weights.shape[-1] + 1)
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, sequence_length, 1)
                .to(position_ids.device)
            )
            # Relative position = key position - query position
            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
            main_relative_position_buckets = compute_relative_buckets(
                self.num_buckets, self.relative_max_distance, relative_positions, False
            )

        # Project hidden states to one bias per (bucket, head):
        # [batch_size, sequence_length, num_buckets * num_heads]
        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
        # -> [batch_size, sequence_length, num_buckets, num_heads]
        rel_pos_embeddings = rel_pos_embeddings.view(
            rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
        )
        # -> [batch_size, num_heads, sequence_length, num_buckets]
        rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2)
        rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,))

        # Broadcast the buckets over all heads and flatten for the gather below
        main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
        main_relative_position_buckets = main_relative_position_buckets.view(
            -1, main_relative_position_buckets.shape[-1]
        ).long()
        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))

        # Select, for every (query, key) pair, the bias of its relative-position bucket
        main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets)
        # -> [batch_size, num_heads, tgt_len, src_len]
        main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1)
        return main_relative_pos_embeddings
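
The gather at the heart of this method is easiest to see on a toy table (standalone sketch): each row holds the biases for all buckets, and the index tensor picks one bucket per key position.

import torch

rel_pos_embeddings = torch.arange(12.0).view(3, 4)  # 3 query rows, 4 buckets
buckets = torch.tensor([[0, 3], [1, 1], [2, 0]])    # bucket id per (query, key)
print(torch.gather(rel_pos_embeddings, dim=1, index=buckets))
# tensor([[ 0.,  3.],
#         [ 5.,  5.],
#         [10.,  8.]])
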
    # Relative position embeddings for the predict stream
    def predict_relative_position_embeddings(
        self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets=None
    ):
        # hidden_states: [batch_size, sequence_length, ngram, hidden_size]
        batch_size, sequence_length = hidden_states.shape[0:2]

        # Compute the relative-position buckets if they were not precomputed
        if predict_relative_position_buckets is None:
            key_sequence_length = attn_weights.shape[-1]
            # position_ids must be of the form 1 2 3 4 5 ... (key_sequence_length - 1)
            assert (
                position_ids[0][0] == key_sequence_length - 1
            ), "`position_ids` are incorrect. They should be of the format 1 2 3 4 5 ... (key_sequence_length - 1)"

            relative_positions = (
                torch.arange(0, key_sequence_length)
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, sequence_length, 1)
                .to(position_ids.device)
            )

            # Relative position = key position - query position
            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)

            predict_relative_position_buckets = compute_relative_buckets(
                self.num_buckets, self.relative_max_distance, relative_positions, False
            )

        # Swap the ngram and sequence_length dimensions
        hidden_states = hidden_states.transpose(1, 2)

        # Project to one bias per (bucket, head) pair
        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)

        # -> [batch_size, ngram, sequence_length, num_buckets, num_heads]
        rel_pos_embeddings = rel_pos_embeddings.view(
            hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
        )

        # Reorder the dimensions so that the buckets come last
        rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)

        # Flatten to [batch_size * ngram * sequence_length * num_heads, num_buckets]
        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)

        # Broadcast the buckets over the ngram and head dimensions
        predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
        predict_relative_position_buckets = predict_relative_position_buckets.repeat(
            self.ngram, 1, self.num_attn_heads, 1
        )

        # Flatten to [ngram * batch_size * num_heads * sequence_length, -1] for the gather
        predict_relative_position_buckets = predict_relative_position_buckets.view(
            -1, predict_relative_position_buckets.size(-1)
        ).long()

        # Select the bias of each (query, key) pair's relative-position bucket
        predict_relative_pos_embeddings = torch.gather(
            rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
        )

        # -> [batch_size, ngram, num_heads, sequence_length, -1]
        predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
            batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
        )

        return predict_relative_pos_embeddings
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetEncoderLayer with ProphetNet->XLMProphetNet, Prophetnet->XLMProphetnet
class XLMProphetNetEncoderLayer(nn.Module):
    """
    Encoder block for XLMProphetnet
    """

    def __init__(self, config: XLMProphetNetConfig):
        super().__init__()
        # 1st residual block
        # Self-attention layer, an XLMProphetNetAttention module with config.num_encoder_attention_heads heads
        self.self_attn = XLMProphetNetAttention(config, config.num_encoder_attention_heads)
        # LayerNorm applied after the self-attention residual connection
        self.self_attn_layer_norm = LayerNorm(config.hidden_size)

        # 2nd residual block
        # Feed-forward layer, an XLMProphetNetFeedForward module with inner dimension config.encoder_ffn_dim
        self.feed_forward = XLMProphetNetFeedForward(config, config.encoder_ffn_dim)
        # LayerNorm applied after the feed-forward residual connection
        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        output_attentions: bool = False,
    ):
        # 1st residual block
        # Run self-attention; returns the attention output, the attention weights, and an unused value
        attention_output, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        # Residual connection with the original hidden states, followed by LayerNorm
        hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)

        # 2nd residual block
        # Run the feed-forward network
        feed_forward_output = self.feed_forward(hidden_states)
        # Residual connection with the hidden states, followed by LayerNorm
        hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)

        # Assemble the outputs
        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)  # append the attention weights if requested

        return outputs
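Both residual blocks above follow the post-LayerNorm pattern `LayerNorm(sublayer(x) + x)`. A minimal standalone sketch of the pattern (hypothetical sizes; a plain `nn.Linear` stands in for the attention and feed-forward sublayers):

```python
import torch
from torch import nn

hidden_size = 8
layer_norm = nn.LayerNorm(hidden_size)
sublayer = nn.Linear(hidden_size, hidden_size)  # stand-in sublayer

x = torch.randn(2, 5, hidden_size)   # [batch, seq_len, hidden]
out = layer_norm(sublayer(x) + x)    # residual add, then LayerNorm
print(out.shape)                     # torch.Size([2, 5, 8])
```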


# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetDecoderLayer with Prophetnet->XLMProphetnet, ProphetNet->XLMProphetNet
class XLMProphetNetDecoderLayer(nn.Module):
    """
    Decoder block for XLMProphetnet
    """

    def __init__(self, config: XLMProphetNetConfig):
        super().__init__()
        # 1st residual block
        # N-gram self-attention layer, an XLMProphetNetNgramSelfAttention module
        self.self_attn = XLMProphetNetNgramSelfAttention(config)
        # LayerNorm applied after the self-attention residual connection
        self.self_attn_layer_norm = LayerNorm(config.hidden_size)

        # 2nd residual block
        # If cross-attention is configured, add an XLMProphetNetAttention module
        # with config.num_decoder_attention_heads heads
        if config.add_cross_attention:
            self.cross_attn = XLMProphetNetAttention(config, config.num_decoder_attention_heads)
            # LayerNorm applied after the cross-attention residual connection
            self.cross_attn_layer_norm = LayerNorm(config.hidden_size)

        # 3rd residual block
        # Feed-forward layer, an XLMProphetNetFeedForward module with inner dimension config.decoder_ffn_dim
        self.feed_forward = XLMProphetNetFeedForward(config, config.decoder_ffn_dim)
        # LayerNorm applied after the feed-forward residual connection
        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attn_mask=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        extended_predict_attention_mask=None,
        main_relative_position_buckets=None,
        predict_relative_position_buckets=None,
        position_ids=None,
        past_key_value=None,
        use_cache: bool = True,
        output_attentions: bool = False,
    ):
        # 1st residual block
        # If cached key/values exist, the self-attention cache sits in the first two positions
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # Run n-gram self-attention: returns the attention output, the main-stream attention weights,
        # the n-gram-stream attention weights, and the updated key/value cache
        ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            extended_predict_attention_mask=extended_predict_attention_mask,
            main_relative_position_buckets=main_relative_position_buckets,
            predict_relative_position_buckets=predict_relative_position_buckets,
            position_ids=position_ids,
        )
        # Residual connection with the original hidden states, followed by LayerNorm
        hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output)

        # If cached key/values exist, the cross-attention cache sits in the last two positions
        cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            # 2nd residual block
            # Run cross-attention over the encoder hidden states: returns the attention output,
            # the cross-attention weights, and the updated cross-attention key/value cache
            attention_output, cross_attn_weights, cross_attn_present_key_value = self.cross_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attn_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            # Residual connection with the hidden states, followed by LayerNorm
            hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states)

            # Append the cross-attention key/values to the last two positions of present_key_value
            present_key_value = present_key_value + cross_attn_present_key_value

        # 3rd residual block
        # Run the feed-forward network
        feed_forward_output = self.feed_forward(hidden_states)
        # Residual connection with the hidden states, followed by LayerNorm
        hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)

        # The final hidden states are the first output
        outputs = (hidden_states,)

        # Append the self-attention and cross-attention weights if requested
        if output_attentions:
            outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights)

        # Append the current key/value cache if caching is enabled
        if use_cache:
            outputs += (present_key_value,)

        # Return the output tuple
        return outputs
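The cache handling above relies on a fixed layout: each layer's `past_key_value` is a 4-tuple whose first two entries are the self-attention key/value states and whose last two are the cross-attention ones. A toy illustration of the slicing (hypothetical shapes):

```python
import torch

# Hypothetical cached states for one decoder layer:
# (self_key, self_value, cross_key, cross_value)
past_key_value = tuple(torch.randn(1, 2, 4, 8) for _ in range(4))

self_attn_past_key_value = past_key_value[:2]    # fed to self-attention
cross_attn_past_key_value = past_key_value[-2:]  # fed to cross-attention
print(len(self_attn_past_key_value), len(cross_attn_past_key_value))  # 2 2
```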
# Add a start docstring describing the standalone encoder part of the XLMProphetNetModel
@add_start_docstrings(
    "The standalone encoder part of the XLMProphetNetModel.",
    XLM_PROPHETNET_START_DOCSTRING,
)
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetEncoder with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel):
    r"""
    word_embeddings  (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*):
        The word embedding parameters. This can be used to initialize [`XLMProphetNetEncoder`] with pre-defined word
        embeddings instead of randomly initialized word embeddings.
    """

    def __init__(self, config: XLMProphetNetConfig, word_embeddings: nn.Embedding = None):
        super().__init__(config)

        # Word embeddings: use the provided ones if given, otherwise randomly initialized with a padding index
        self.word_embeddings = (
            word_embeddings
            if word_embeddings is not None
            else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        )
        # Positional embeddings
        self.position_embeddings = XLMProphetNetPositionalEmbeddings(config)
        # LayerNorm applied to the embeddings
        self.embeddings_layer_norm = LayerNorm(config.hidden_size)

        # Encoder layers: a list of XLMProphetNetEncoderLayer instances
        self.layers = nn.ModuleList([XLMProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])

        # Gradient checkpointing is disabled by default
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Return the input word embeddings
        return self.word_embeddings

    def set_input_embeddings(self, value):
        # Set the input word embeddings
        self.word_embeddings = value

    # Attach XLM_PROPHETNET_STANDALONE_INPUTS_DOCSTRING to the forward method
    @add_start_docstrings_to_model_forward(XLM_PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,

# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetDecoder with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
    # Initialize the model with the given config and optional word embeddings
    def __init__(self, config: XLMProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
        # Call the parent class initializer
        super().__init__(config)

        # Pull the relevant parameters off the config
        self.ngram = config.ngram  # number of future n-gram streams
        self.num_buckets = config.num_buckets  # number of relative-position buckets
        self.relative_max_distance = config.relative_max_distance  # maximum relative distance
        self.dropout = config.dropout  # dropout rate
        self.max_target_positions = config.max_position_embeddings  # maximum number of target positions

        # Use the provided word embeddings if given; otherwise create a new embedding layer
        self.word_embeddings = (
            word_embeddings
            if word_embeddings is not None
            else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        )
        # Positional embeddings
        self.position_embeddings = XLMProphetNetPositionalEmbeddings(config)

        # N-gram embeddings
        self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
        # Decoder layers as a module list
        self.layers = nn.ModuleList([XLMProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
        # LayerNorm applied to the embeddings
        self.embeddings_layer_norm = LayerNorm(config.hidden_size)

        # Gradient checkpointing is disabled by default
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    # Return the model's input word embeddings
    def get_input_embeddings(self):
        return self.word_embeddings

    # Set the model's input word embeddings
    def set_input_embeddings(self, value):
        self.word_embeddings = value

    # Forward pass, with the standalone inputs docstring attached and the return docstring replaced
    @add_start_docstrings_to_model_forward(XLM_PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=XLMProphetNetDecoderModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    # Compute buffered relative buckets for the main and predict streams
    def compute_buffered_relative_buckets(self, position_ids):
        # Read the batch size and sequence length
        batch_size, sequence_length = position_ids.shape

        # Build the integer positions 1 .. max_target_positions - 1 on the right device
        # (a single row; the batch dimension is expanded further below)
        position_ids = torch.arange(1, self.max_target_positions).to(position_ids.device).repeat(1, 1)

        # Compute the main and predict relative buckets
        main_relative_buckets, predict_relative_buckets = compute_all_stream_relative_buckets(
            self.num_buckets, self.relative_max_distance, position_ids
        )

        # Buffer the main relative buckets, cropped to the current sequence length
        main_relative_buckets = main_relative_buckets[:, :sequence_length, :sequence_length].repeat(batch_size, 1, 1)

        # Buffer the predict relative buckets: the current target positions plus the
        # shifted block starting at max_target_positions
        predict_relative_buckets = torch.cat(
            [
                predict_relative_buckets[:, :sequence_length, :sequence_length],
                predict_relative_buckets[
                    :, :sequence_length, self.max_target_positions : self.max_target_positions + sequence_length
                ],
            ],
            2,
        ).repeat(batch_size, 1, 1)

        # Return the main and predict relative buckets
        return main_relative_buckets, predict_relative_buckets

    # Prepare the additive attention mask
    def prepare_attention_mask(self, hidden_states, attention_mask):
        # Read the batch size and sequence length
        batch_size, seq_length = hidden_states.shape[:2]

        # Build the causal mask, filled with the dtype's minimum value
        causal_mask = torch.full(
            (seq_length, seq_length),
            torch.finfo(hidden_states.dtype).min,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        causal_mask = torch.triu(causal_mask, 1)  # keep only the strictly upper triangle as the causal mask

        # Expand the causal mask over the batch and attention-head dimensions
        extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand(
            (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape
        )

        # Add the regular attention mask, if any
        if attention_mask is not None:
            extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(hidden_states.dtype).min
            extended_attention_mask = extended_causal_mask + extended_attention_mask
        else:
            extended_attention_mask = extended_causal_mask

        # Cast the mask to the dtype of hidden_states and return it
        return extended_attention_mask.to(hidden_states.dtype)
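To see the additive causal mask concretely, here is a standalone sketch: `torch.triu(..., 1)` keeps the strictly upper triangle at the dtype's minimum, so adding the mask to attention scores blocks every future position:

```python
import torch

seq_length, dtype = 4, torch.float32
causal_mask = torch.full((seq_length, seq_length), torch.finfo(dtype).min, dtype=dtype)
causal_mask = torch.triu(causal_mask, 1)
print(causal_mask)
# Row i has 0.0 at columns <= i (visible) and the dtype minimum at columns > i (masked).
```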
    # Prepare the additive attention mask for the predict stream
    def prepare_predict_attention_mask(self, hidden_states, attention_mask):
        # Read the batch size and sequence length
        batch_size, seq_length = hidden_states.shape[:2]

        # Build the n-gram causal mask for the predict stream
        predict_causal_mask = ngram_attention_bias(
            self.max_target_positions, self.ngram, hidden_states.device, hidden_states.dtype
        )
        # Concatenate the blocks of the causal mask that the predict stream needs
        predict_causal_mask = torch.cat(
            [
                predict_causal_mask[:, :seq_length, :seq_length],
                predict_causal_mask[
                    :, :seq_length, self.max_target_positions : self.max_target_positions + seq_length
                ],
            ],
            dim=-1,
        )
        # Expand the causal mask over the batch and attention-head dimensions
        extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand(
            (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape
        )

        # Add the regular attention mask, if any
        if attention_mask is not None:
            # Turn the padding mask into an additive mask (dtype minimum at padded positions)
            extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min
            extended_attention_mask = extended_attention_mask.expand(
                (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length)
            )
            # The predict-stream attention mask should always be 0; concatenate zeros for it
            extended_attention_mask = torch.cat(
                [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1
            )
            extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask
        else:
            extended_predict_attention_mask = extended_predict_causal_mask

        # Cast the final extended predict attention mask to the dtype of hidden_states and return it
        return extended_predict_attention_mask.to(hidden_states.dtype)
# Add a docstring describing XLMProphetNetModel as a bare model outputting raw hidden states
@add_start_docstrings(
    "The bare XLMProphetNet Model outputting raw hidden-states without any specific head on top.",
    XLM_PROPHETNET_START_DOCSTRING,
)
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetModel with
# microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone,
# ProphetNetModel->XLMProphetNetModel, and the related constants and strings renamed accordingly
class XLMProphetNetModel(XLMProphetNetPreTrainedModel):
    # Keys of the word-embedding weights shared between encoder and decoder
    _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"]

    def __init__(self, config: XLMProphetNetConfig):
        super().__init__(config)
        # Word embedding layer built from the config's vocab size, hidden size, and padding index
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Deep-copy the config for encoder and decoder so each keeps an independent, consistent copy
        encoder_config = copy.deepcopy(config)
        encoder_config.is_encoder_decoder = False
        encoder_config.use_cache = False
        # Build the encoder from its config and the shared word embeddings
        self.encoder = XLMProphetNetEncoder(encoder_config, self.word_embeddings)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        # Build the decoder from its config and the shared word embeddings
        self.decoder = XLMProphetNetDecoder(decoder_config, self.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    # Return the input word embedding layer
    def get_input_embeddings(self):
        return self.word_embeddings

    # Set the input word embedding layer and propagate it to encoder and decoder
    def set_input_embeddings(self, value):
        self.word_embeddings = value
        self.encoder.word_embeddings = self.word_embeddings
        self.decoder.word_embeddings = self.word_embeddings

    # Tie the encoder and decoder word embedding weights
    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.word_embeddings, self.word_embeddings)
            self._tie_or_clone_weights(self.decoder.word_embeddings, self.word_embeddings)

    # Return the encoder
    def get_encoder(self):
        return self.encoder

    # Return the decoder
    def get_decoder(self):
        return self.decoder

    # Forward pass, taking the full set of seq2seq inputs and returning the model outputs
    @add_start_docstrings_to_model_forward(XLM_PROPHETNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=XLMProphetNetSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    "The XLMProphetNet Model with a language modeling head. Can be used for sequence generation tasks.",

# 定义了一个字符串,描述了带有语言建模头部的 XLMProphetNet 模型,适用于序列生成任务。

    XLM_PROPHETNET_START_DOCSTRING,

# 引用了常量 XLM_PROPHETNET_START_DOCSTRING,可能是用于生成模型文档字符串的起始标记。
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetForConditionalGeneration with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
    _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"]

    def __init__(self, config: XLMProphetNetConfig):
        super().__init__(config)
        self.prophetnet = XLMProphetNetModel(config)  # the underlying ProphetNet model
        self.padding_idx = config.pad_token_id  # padding index
        self.disable_ngram_loss = config.disable_ngram_loss  # whether the n-gram loss is disabled

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # language-modeling head

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head  # return the language-modeling head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings  # set a new language-modeling head

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            # Tie or clone the word embedding weights to the language-modeling head if configured
            self._tie_or_clone_weights(self.prophetnet.word_embeddings, self.lm_head)

    def get_input_embeddings(self):
        return self.prophetnet.word_embeddings  # return the word embeddings of the ProphetNet model

    @add_start_docstrings_to_model_forward(XLM_PROPHETNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=XLMProphetNetSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    # Compute the training loss
    def _compute_loss(self, logits, labels, ignore_index=-100):
        # Allocate a target tensor of shape (ngram, batch, seq) filled with ignore_index
        expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index)

        # Replicate the labels across the ngram dimension for the n-gram loss
        for i in range(self.config.ngram):
            if i > 0 and self.disable_ngram_loss:
                break
            expend_targets[i, :, :] = labels

        # Move the ngram dimension first so the logits line up with the targets
        logits = logits.transpose(0, 1).contiguous()
        # log_softmax over the vocabulary, computed in float32, for the negative log-likelihood loss
        lprobs = nn.functional.log_softmax(
            logits.view(-1, logits.size(-1)),
            dim=-1,
            dtype=torch.float32,
        )

        # Mean negative log-likelihood loss over the flattened tokens
        loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")

        # Apply label smoothing if the smoothing factor eps is positive
        if self.config.eps > 0.0:
            smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
            non_masked_tokens = expend_targets.ne(ignore_index).view(-1)
            smooth_loss = smooth_loss[non_masked_tokens]
            smooth_loss = smooth_loss.mean()

            eps_i = self.config.eps / lprobs.size(-1)
            loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss

        # Return the computed loss
        return loss
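As a sanity check on the smoothing branch, the final loss is the convex combination `(1 - eps) * nll + (eps / vocab_size) * smooth`. A small hedged sketch of the same arithmetic outside the class (random values; no `ignore_index` filtering, which the method above additionally applies):

```python
import torch
import torch.nn.functional as F

eps = 0.1
logits = torch.randn(6, 10)              # [tokens, vocab_size]
targets = torch.randint(0, 10, (6,))
lprobs = F.log_softmax(logits, dim=-1)

nll = F.nll_loss(lprobs, targets, reduction="mean")
smooth = -lprobs.sum(dim=-1).mean()      # uniform-over-vocabulary loss term
loss = (1.0 - eps) * nll + (eps / lprobs.size(-1)) * smooth
print(loss)
```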

    # Prepare the inputs used during generation; returns a dict of generation inputs
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # encoder_outputs must be available during generation
        assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."

        # With cached key/values, only the last decoder token needs to be fed
        if past_key_values:
            decoder_input_ids = decoder_input_ids[:, -1:]

        # Return the inputs needed for generation
        return {
            "input_ids": None,  # encoder_outputs is defined, so input_ids is not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    # Build decoder input ids from the labels, used on the decoder side during training
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    # Reorder the cache during beam search
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            # Reorder each layer's cached self-attention states to match the selected beams;
            # the cross-attention states (positions 2+) are batch-independent and kept as-is
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        return reordered_past
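The `index_select` above permutes the batch dimension of every cached state so it follows the surviving beams. A toy sketch (hypothetical shapes):

```python
import torch

past_state = torch.arange(6).view(3, 2)  # 3 beams, 2 features per beam
beam_idx = torch.tensor([2, 0, 0])       # beams selected at this step
print(past_state.index_select(0, beam_idx))
# tensor([[4, 5],
#         [0, 1],
#         [0, 1]])
```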

    # Return the encoder
    def get_encoder(self):
        return self.prophetnet.encoder

    # Return the decoder
    def get_decoder(self):
        return self.prophetnet.decoder
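For orientation, a hedged usage sketch of this class (the checkpoint name matches the tokenizer mapping later in this file; running it downloads weights, so treat the snippet as illustrative rather than a tested example):

```python
from transformers import XLMProphetNetForConditionalGeneration, XLMProphetNetTokenizer

name = "microsoft/xprophetnet-large-wiki100-cased"
tokenizer = XLMProphetNetTokenizer.from_pretrained(name)
model = XLMProphetNetForConditionalGeneration.from_pretrained(name)

inputs = tokenizer("Hello world", return_tensors="pt")
generated = model.generate(**inputs, num_beams=4, max_length=20)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```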
@add_start_docstrings(
    "The standalone decoder part of the XLMProphetNetModel with a lm head on top. The model can be used for causal"
    " language modeling.",
    XLM_PROPHETNET_START_DOCSTRING,
)
# XLMProphetNetForCausalLM, a subclass of XLMProphetNetPreTrainedModel, is the standalone
# decoder part of the XLMProphetNet model with a language-modeling head on top.
# It can be used for causal language modeling.

class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
    # Class attribute naming the weights that are tied/shared
    _tied_weights_keys = [
        "prophetnet.word_embeddings.weight",
        "prophetnet.decoder.word_embeddings.weight",
        "lm_head.weight",
    ]

    def __init__(self, config: XLMProphetNetConfig):
        # Set up the config for causal language modeling
        config = copy.deepcopy(config)
        config.is_decoder = True  # decoder-only
        config.is_encoder_decoder = False  # not an encoder-decoder model
        super().__init__(config)  # initialize the parent class with the config
        self.prophetnet = XLMProphetNetDecoderWrapper(config)  # the wrapped decoder

        self.padding_idx = config.pad_token_id  # padding index
        self.disable_ngram_loss = config.disable_ngram_loss  # whether the n-gram loss is disabled

        # Language-modeling head: a bias-free linear projection to the vocabulary size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Return the input embeddings, i.e. the word embeddings of the wrapped decoder
        return self.prophetnet.decoder.word_embeddings

    def set_input_embeddings(self, value):
        # Set the input embeddings
        self.prophetnet.decoder.word_embeddings = value

    def get_output_embeddings(self):
        # Return the output embeddings, i.e. the language-modeling head
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # Set the output embeddings
        self.lm_head = new_embeddings

    def _tie_weights(self):
        # Tie the decoder word embeddings and the language-modeling head if configured
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.prophetnet.decoder.word_embeddings, self.lm_head)

    def set_decoder(self, decoder):
        # Set the decoder
        self.prophetnet.decoder = decoder

    def get_decoder(self):
        # Return the decoder
        return self.prophetnet.decoder

    @add_start_docstrings_to_model_forward(XLM_PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=XLMProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
    # forward method, with the model-inputs docstring attached and the return docstring replaced
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    # Compute the loss from the model's logits and the gold labels; ignore_index defaults to -100
    def _compute_loss(self, logits, labels, ignore_index=-100):
        # Allocate a target tensor of shape (self.config.ngram, labels.size(0), labels.size(1)),
        # with the same dtype as labels, filled with ignore_index
        expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index)

        # Replicate the labels once per n-gram stream for the loss computation
        for i in range(self.config.ngram):
            # Stop after the first stream if the n-gram loss is disabled
            if i > 0 and self.disable_ngram_loss:
                break
            # Copy the labels into the i-th stream
            expend_targets[i, :, :] = labels

        # Move the ngram dimension first, matching the targets, and make the tensor contiguous
        logits = logits.transpose(0, 1).contiguous()
        # log_softmax over the last dimension, computed in float32, to get log-probabilities
        lprobs = nn.functional.log_softmax(
            logits.view(-1, logits.size(-1)),  # flatten the logits to 2D
            dim=-1,
            dtype=torch.float32,
        )

        # Mean negative log-likelihood over the flattened tokens
        loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")

        # Apply label smoothing if self.config.eps > 0.0
        if self.config.eps > 0.0:
            # Smoothing term: the negative sum of log-probabilities over the vocabulary
            smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
            # Select the tokens that are not masked out (targets different from ignore_index)
            non_masked_tokens = expend_targets.ne(ignore_index).view(-1)
            # Average the smoothing term over the non-masked tokens
            smooth_loss = smooth_loss[non_masked_tokens].mean()

            # eps_i is eps divided by the vocabulary size
            eps_i = self.config.eps / lprobs.size(-1)
            # Final loss: a convex combination of the NLL loss and the smoothing term
            loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss

        # Return the computed loss
        return loss

    # Prepare the inputs used during generation
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        use_cache=None,
        **kwargs,
    ):
        # If no attention mask is given, default to all ones with the shape of input_ids
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # With cached key/values, only the last token of input_ids needs to be fed
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # Return the prepared inputs: input_ids, attention_mask, head_mask, past_key_values, use_cache
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    @staticmethod
    # Copied from transformers.models.bart.modeling_bart.BartForCausalLM._reorder_cache
    def _reorder_cache(past_key_values, beam_idx):
        # Tuple collecting the reordered past key/values
        reordered_past = ()
        # Iterate over the cached key/values of every layer
        for layer_past in past_key_values:
            # Reorder every cached state along the batch dimension according to beam_idx
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        # Return the reordered cache
        return reordered_past
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetDecoderWrapper with ProphetNet->XLMProphetNet, prophetnet->XLMProphetNet
class XLMProphetNetDecoderWrapper(XLMProphetNetPreTrainedModel):
    """
    This is a wrapper class, so that [`XLMProphetNetForCausalLM`] can correctly be loaded from pretrained
    XLMProphetNet classes.
    """

    def __init__(self, config: XLMProphetNetConfig):
        super().__init__(config)

        # Word embedding layer built from the vocab size, hidden size, and padding token id
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # The wrapped decoder, sharing the word embeddings
        self.decoder = XLMProphetNetDecoder(config, word_embeddings=self.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def _tie_weights(self):
        # Tie the word embedding weights to the decoder's input embeddings
        self._tie_or_clone_weights(self.word_embeddings, self.decoder.get_input_embeddings())

    def forward(self, *args, **kwargs):
        # Forward simply delegates to the wrapped decoder
        return self.decoder(*args, **kwargs)

.\models\xlm_prophetnet\tokenization_xlm_prophetnet.py

# coding=utf-8
# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections  # collections, for OrderedDict
import os  # os, for file-system operations
from shutil import copyfile  # copyfile, for copying the vocabulary file
from typing import Any, Dict, List, Optional, Tuple  # common typing helpers
from ...tokenization_utils import PreTrainedTokenizer  # base class for tokenizers
from ...utils import logging  # the library's logging utilities

# Initialize the logger
logger = logging.get_logger(__name__)

# SentencePiece underline symbol, used to mark subword boundaries
SPIECE_UNDERLINE = "▁"

# Names of the vocabulary files used by the pretrained models
VOCAB_FILES_NAMES = {"vocab_file": "prophetnet.tokenizer"}

# Map from pretrained model identifiers to the locations of their vocabulary files
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "microsoft/xprophetnet-large-wiki100-cased": (
            "https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased/resolve/main/prophetnet.tokenizer"
        ),
    }
}

# Initialization configuration for the pretrained models
PRETRAINED_INIT_CONFIGURATION = {
    "microsoft/xprophetnet-large-wiki100-cased": {"do_lower_case": False},
}

# Positional embedding sizes predefined for the pretrained models
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "microsoft/xprophetnet-large-wiki100-cased": 512,
}

# Load the vocabulary from a vocabulary file
def load_vocab(vocab_file):
    """
    Loads a vocabulary file into a dictionary.

    :param vocab_file: path and name of the vocabulary file
    :type vocab_file: str
    """
    vocab = collections.OrderedDict()  # ordered dictionary preserving insertion order
    with open(vocab_file, "r", encoding="utf-8") as reader:  # open the vocabulary file
        tokens = reader.readlines()  # read all lines
    for index, token in enumerate(tokens):  # iterate over tokens and their indices
        token = token.rstrip("\n")  # strip the trailing newline
        vocab[token] = index  # map the token to its index
    return vocab
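A quick hedged sketch of `load_vocab` against a throwaway file (temporary path, hypothetical tokens):

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "vocab.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write("[PAD]\n[CLS]\nhello\n")
    vocab = load_vocab(path)  # the function defined above
    print(vocab)  # OrderedDict([('[PAD]', 0), ('[CLS]', 1), ('hello', 2)])
```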

# XLMProphetNetTokenizer handles tokenization for the XLMProphetNet models. It inherits from
# PreTrainedTokenizer and:
# - solves the segmentation problem with SentencePiece
# - turns text into sequences the model can process
# - carries the methods and attributes needed to load pretrained tokenizer parameters
class XLMProphetNetTokenizer(PreTrainedTokenizer):
    r"""
    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        bos_token (`str`, *optional*, defaults to `"[SEP]"`):
            Special token for the beginning of a sequence, used during pretraining. When building a sequence, the
            `cls_token` is actually used.
        eos_token (`str`, *optional*, defaults to `"[SEP]"`):
            Special token for the end of a sequence. When building a sequence, the `sep_token` is actually used.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            Separator token, used when building a sequence from multiple sequences, e.g. the question and the passage
            in question answering, or two sequences in sequence classification.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            Unknown token. Tokens not present in the vocabulary are replaced by this token.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            Padding token, used to pad sequences of different lengths when batching.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            Classifier token, the first token of the sequence in sequence classification tasks.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            Mask token, used for masked language modeling; the model will try to predict these tokens.
        sp_model_kwargs (`dict`, *optional*):
            Dictionary of arguments passed to `SentencePieceProcessor.__init__()` to configure the SentencePiece
            model. Available options include `enable_sampling` (enable subword regularization), `nbest_size`
            (sampling parameter for unigram; ineffective for BPE-Dropout), and `alpha` (smoothing parameter for
            unigram sampling, and dropout probability of merge operations for BPE-dropout). See the
            [Python wrapper for SentencePiece](https://github.com/google/sentencepiece/tree/master/python) for more
            information.

    Attributes:
        sp_model (`SentencePieceProcessor`):
            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
    """



    # Class attribute: the names of the vocabulary files used by the model
    vocab_files_names = VOCAB_FILES_NAMES
    # Class attribute: the vocabulary file map for the pretrained models
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # Class attribute: the maximum input sizes of the pretrained models
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # Class attribute: the names of the model inputs
    model_input_names = ["input_ids", "attention_mask"]



    def __init__(
        self,
        vocab_file,
        bos_token="[SEP]",
        eos_token="[SEP]",
        sep_token="[SEP]",
        unk_token="[UNK]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # Default sp_model_kwargs to an empty dict if not provided
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        try:
            # Try to import the sentencepiece library
            import sentencepiece as spm
        except ImportError:
            # If the import fails, warn the user with installation instructions and re-raise
            logger.warning(
                "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece"
                " pip install sentencepiece"
            )
            raise

        # Create the SentencePieceProcessor with the given sp_model_kwargs
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        # Load the vocabulary file into the SentencePiece model
        self.sp_model.Load(str(vocab_file))
        # Remember the vocabulary file path
        self.vocab_file = vocab_file

        # The original fairseq vocab and the spm vocab must be "aligned":
        # Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
        # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
        # spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'

        # Put the special tokens and the [unused] tokens into the vocabulary
        self.fairseq_tokens_to_ids = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "[UNK]": 3, "[MASK]": 4}

        for i in range(10):
            tok = f"[unused{i}]"
            self.fairseq_tokens_to_ids[tok] = 5 + i

        # The first "real" token "," has position 15 in the embedding vocab and position 3 in the spm vocab
        self.fairseq_offset = 12
        # Reverse mapping, used to look up a token by its id
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

        # TODO ArthurZ fairseq_ids_to_tokens should be removed

        # Call the parent initializer with the special tokens and sp_model_kwargs
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            unk_token=unk_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )



    @property
    def can_save_slow_tokenizer(self) -> bool:
        # The slow tokenizer can be saved only if the vocabulary file exists
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    def __getstate__(self):
        # Return the state dict with sp_model set to None so the object can be pickled
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state
    def __setstate__(self, d):
        self.__dict__ = d  # restore the attribute dict from `d`

        try:
            import sentencepiece as spm  # try to import the sentencepiece library
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece"
                " pip install sentencepiece"
            )
            raise  # remind the user that SentencePiece has to be installed

        # For backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}  # default sp_model_kwargs to an empty dict if missing

        # Re-create the SentencePieceProcessor from sp_model_kwargs and reload the vocabulary file
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )  # if special tokens are already present, delegate to the parent implementation

        if token_ids_1 is None:
            return ([0] * len(token_ids_0)) + [1]  # mask for a single sequence
        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]  # mask for a sequence pair
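For example, for a pair of sequences of lengths 2 and 3 the mask marks only the trailing `[SEP]` of each segment (a hedged illustration; `tokenizer` stands for an already-constructed `XLMProphetNetTokenizer`, and the ids are placeholders since only the lengths matter):

```python
mask = tokenizer.get_special_tokens_mask([11, 12], [21, 22, 23])
print(mask)  # [0, 0, 1, 0, 0, 0, 1]
```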

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLMProphetNet
        does not make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.

        """

        sep = [self.sep_token_id]  # the id of the separator token

        if token_ids_1 is None:
            return len(token_ids_0 + sep) * [0]  # token type ids for a single sequence
        return len(token_ids_0 + sep + sep + token_ids_1 + sep) * [0]  # token type ids for a sequence pair

    @property
    def vocab_size(self):
        return len(self.sp_model) + self.fairseq_offset  # the vocabulary size

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}  # build the vocabulary dict
        vocab.update(self.added_tokens_encoder)  # add the extra added tokens
        return vocab  # return the vocabulary dict
    def _tokenize(self, text: str) -> str:
        """Tokenizes the given text with the SentencePiece model, returning a list of subword token strings."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) into its corresponding ID using the vocabulary."""
        # Check if the token exists in the predefined Fairseq tokens to IDs mapping
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        # Obtain the token's ID from the SentencePiece model
        spm_id = self.sp_model.PieceToId(token)
        # Return the ID with an offset specific to Fairseq or the unknown token ID if not found
        return spm_id + self.fairseq_offset if spm_id else self.unk_token_id

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into its corresponding token (str) using the vocabulary."""
        # Check if the index exists in the predefined Fairseq IDs to tokens mapping
        if index in self.fairseq_ids_to_tokens:
            return self.fairseq_ids_to_tokens[index]
        # Convert the index to a token using the SentencePiece model adjusted by Fairseq offset
        return self.sp_model.IdToPiece(index - self.fairseq_offset)

    def convert_tokens_to_string(self, tokens):
        """
        Converts a sequence of tokens (strings for sub-words) into a single concatenated string,
        replacing special sub-word marker with spaces and stripping leading/trailing spaces.
        """
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Saves the current vocabulary to the specified directory.

        Args:
            save_directory (str): Directory path where the vocabulary file should be saved.
            filename_prefix (Optional[str]): Optional prefix for the vocabulary file name.

        Returns:
            Tuple[str]: Tuple containing the path of the saved vocabulary file.
        """
        # Ensure the provided directory path exists; otherwise, log an error
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        # Define the output vocabulary file path based on the provided directory and filename prefix
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # Copy the current vocabulary file if it differs from the destination path and exists
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        # If the current vocabulary file doesn't exist, write the serialized SentencePiece model to the output file
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Builds model inputs by concatenating a sequence or pair of sequences with special tokens.

        Args:
            token_ids_0 (List[int]): List of token IDs for the first sequence.
            token_ids_1 (Optional[List[int]]): Optional list of token IDs for the second sequence in a pair.

        Returns:
            List[int]: List of input IDs with added special tokens for model input.
        """
        # If only one sequence is provided, concatenate it with the separator token
        if token_ids_1 is None:
            return token_ids_0 + [self.sep_token_id]
        # Concatenate both sequences with separator tokens in between
        sep = [self.sep_token_id]
        return token_ids_0 + sep + token_ids_1 + sep
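So a single sequence is encoded as `X [SEP]` and a pair as `A [SEP] B [SEP]` (a hedged illustration; `tokenizer` stands for an already-constructed instance and the ids are placeholders):

```python
single = tokenizer.build_inputs_with_special_tokens([11, 12])
pair = tokenizer.build_inputs_with_special_tokens([11, 12], [21, 22])
# single == [11, 12, sep_id]
# pair   == [11, 12, sep_id, 21, 22, sep_id]
```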

.\models\xlm_prophetnet\__init__.py

# Copyright notice and license information
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import the type-checking flag
from typing import TYPE_CHECKING

# Import the custom exception and the lazy-module utilities
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_torch_available

# Import structure of the module: maps submodules to the names they export
_import_structure = {
    "configuration_xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
}

# Check whether sentencepiece is available; raise OptionalDependencyNotAvailable if not
try:
    if not is_sentencepiece_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # sentencepiece is available: register the tokenization_xlm_prophetnet module
    _import_structure["tokenization_xlm_prophetnet"] = ["XLMProphetNetTokenizer"]

# Check whether torch is available; raise OptionalDependencyNotAvailable if not
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # torch is available: register the modeling_xlm_prophetnet module
    _import_structure["modeling_xlm_prophetnet"] = [
        "XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST",
        "XLMProphetNetDecoder",
        "XLMProphetNetEncoder",
        "XLMProphetNetForCausalLM",
        "XLMProphetNetForConditionalGeneration",
        "XLMProphetNetModel",
        "XLMProphetNetPreTrainedModel",
    ]

# Under type checking, import everything eagerly so static analyzers see the real symbols
if TYPE_CHECKING:
    # Import the config class and constants
    from .configuration_xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig

    try:
        # Check again whether sentencepiece is available
        if not is_sentencepiece_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # sentencepiece is available: import the tokenizer class
        from .tokenization_xlm_prophetnet import XLMProphetNetTokenizer

    try:
        # Check again whether torch is available
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # torch is available: import the modeling classes and constants
        from .modeling_xlm_prophetnet import (
            XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
            XLMProphetNetDecoder,
            XLMProphetNetEncoder,
            XLMProphetNetForCausalLM,
            XLMProphetNetForConditionalGeneration,
            XLMProphetNetModel,
            XLMProphetNetPreTrainedModel,
        )

# At runtime, install a lazy module instead
else:
    # Import sys to patch sys.modules
    import sys

    # Replace the current module with a _LazyModule that resolves attributes on demand
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
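A hedged sketch of what the `_LazyModule` indirection buys: importing the package is cheap, and the heavy torch-backed submodule is only executed when one of its attributes is first accessed:

```python
# Nothing from modeling_xlm_prophetnet is executed yet at this point:
from transformers.models import xlm_prophetnet

# The first attribute access triggers the actual submodule import:
config_cls = xlm_prophetnet.XLMProphetNetConfig
print(config_cls.model_type)  # expected "xlm-prophetnet" (assumption, not verified here)
```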

.\models\xlm_roberta\configuration_xlm_roberta.py

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" XLM-RoBERTa configuration"""

# Import OrderedDict from the collections module
from collections import OrderedDict
# Import the Mapping type from typing
from typing import Mapping

# Import the required classes and helpers from the library
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging

# Logger for this module
logger = logging.get_logger(__name__)

# Map from XLM-RoBERTa pretrained model identifiers to the URLs of their config files
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "FacebookAI/xlm-roberta-base": "https://huggingface.co/FacebookAI/xlm-roberta-base/resolve/main/config.json",
    "FacebookAI/xlm-roberta-large": "https://huggingface.co/FacebookAI/xlm-roberta-large/resolve/main/config.json",
    "FacebookAI/xlm-roberta-large-finetuned-conll02-dutch": (
        "https://huggingface.co/FacebookAI/xlm-roberta-large-finetuned-conll02-dutch/resolve/main/config.json"
    ),
    "FacebookAI/xlm-roberta-large-finetuned-conll02-spanish": (
        "https://huggingface.co/FacebookAI/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/config.json"
    ),
    "FacebookAI/xlm-roberta-large-finetuned-conll03-english": (
        "https://huggingface.co/FacebookAI/xlm-roberta-large-finetuned-conll03-english/resolve/main/config.json"
    ),
    "FacebookAI/xlm-roberta-large-finetuned-conll03-german": (
        "https://huggingface.co/FacebookAI/xlm-roberta-large-finetuned-conll03-german/resolve/main/config.json"
    ),
}

# XLMRoBERTaConfig 类,继承自 PretrainedConfig 类,用于存储 XLM-RoBERTa 模型的配置信息
class XLMRobertaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`XLMRobertaModel`] or a [`TFXLMRobertaModel`]. It
    is used to instantiate a XLM-RoBERTa model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the XLMRoBERTa
    [FacebookAI/xlm-roberta-base](https://huggingface.co/FacebookAI/xlm-roberta-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Examples:

    ```
    >>> from transformers import XLMRobertaConfig, XLMRobertaModel

    >>> # Initializing a XLM-RoBERTa FacebookAI/xlm-roberta-base style configuration
    >>> configuration = XLMRobertaConfig()

    >>> # Initializing a model (with random weights) from the FacebookAI/xlm-roberta-base style configuration
    >>> model = XLMRobertaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "xlm-roberta"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        **kwargs,
    ):
        # Initialize the base class, which handles the special token ids
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        # Store the model hyperparameters
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
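
A quick usage sketch (assumed, not part of this file): defaults can be overridden at construction time, and the config round-trips through `save_pretrained`/`from_pretrained`:

from transformers import XLMRobertaConfig

config = XLMRobertaConfig(num_hidden_layers=6, hidden_size=384, num_attention_heads=6)
config.save_pretrained("./tiny-xlm-roberta")                  # writes ./tiny-xlm-roberta/config.json
config = XLMRobertaConfig.from_pretrained("./tiny-xlm-roberta")
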
# Copied from transformers.models.roberta.configuration_roberta.RobertaOnnxConfig with Roberta->XLMRoberta
class XLMRobertaOnnxConfig(OnnxConfig):
    # The `inputs` property describes the model inputs and their dynamic axes
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # The dynamic axes depend on the task
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            dynamic_axis = {0: "batch", 1: "sequence"}
        # Return an ordered mapping with the dynamic axes of input_ids and attention_mask
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
            ]
        )
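
For illustration, the property above resolves as follows (the constructor signature is assumed from the `OnnxConfig` base class):

from transformers import XLMRobertaConfig
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaOnnxConfig

onnx_config = XLMRobertaOnnxConfig(XLMRobertaConfig(), task="default")
print(onnx_config.inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
#              ('attention_mask', {0: 'batch', 1: 'sequence'})])
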

.\models\xlm_roberta\modeling_flax_xlm_roberta.py

# coding=utf-8
# Copyright 2022 Facebook AI Research and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Flax XLM-RoBERTa model.
"""

from typing import Callable, Optional, Tuple

import flax.linen as nn                    # Flax linen module for defining neural networks
import jax                                 # JAX for autodiff and array operations
import jax.numpy as jnp                    # JAX's NumPy interface, the main numerics library here
import numpy as np                         # plain NumPy for host-side arrays
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # frozen-dict helpers
from flax.linen import combine_masks, make_causal_mask           # mask utilities
from flax.linen import partitioning as nn_partitioning           # partitioning / rematerialization helpers
from flax.linen.attention import dot_product_attention_weights  # attention-weight helper
from flax.traverse_util import flatten_dict, unflatten_dict     # dict flatten/unflatten utilities
from jax import lax                       # low-level JAX operations

from ...modeling_flax_outputs import (     # Flax model output classes
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxBaseModelOutputWithPooling,
    FlaxBaseModelOutputWithPoolingAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring  # Flax model utilities
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging  # docstring helpers and logging
from .configuration_xlm_roberta import XLMRobertaConfig  # model configuration

logger = logging.get_logger(__name__)  # module-level logger

_CHECKPOINT_FOR_DOC = "FacebookAI/xlm-roberta-base"  # checkpoint referenced in the docs
_CONFIG_FOR_DOC = "XLMRobertaConfig"  # config class referenced in the docs

remat = nn_partitioning.remat  # alias for Flax's rematerialization (gradient checkpointing) wrapper

FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [  # pretrained checkpoints
    "FacebookAI/xlm-roberta-base",
    "FacebookAI/xlm-roberta-large",
    # See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta
]


# Copied from transformers.models.roberta.modeling_flax_roberta.create_position_ids_from_input_ids
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: jnp.ndarray   # array of input token ids
        padding_idx: int         # index of the padding symbol

    Returns: jnp.ndarray         # array of position ids
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = (input_ids != padding_idx).astype("i4")  # 1 at non-padding positions, 0 at padding positions
    # If the mask has more than two dimensions, flatten it to 2D first
    if mask.ndim > 2:
        mask = mask.reshape((-1, mask.shape[-1]))
        # Cumulative sum counts real tokens; multiplying by the mask zeroes out the padding positions
        incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
        # Restore the original shape of input_ids
        incremental_indices = incremental_indices.reshape(input_ids.shape)
    else:
        # Two-dimensional case: cumulative sum directly, again zeroing padding positions
        incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask

    # Shift by padding_idx so position numbering starts at padding_idx + 1
    return incremental_indices.astype("i4") + padding_idx
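

A worked example of the function above, using XLM-R's pad token id 1 as `padding_idx`:

import jax.numpy as jnp

input_ids = jnp.array([[0, 5, 7, 1, 1]])  # the trailing 1s are padding
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# [[2 3 4 1 1]] -- real tokens count up from padding_idx + 1, pads stay at padding_idx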
# XLM_ROBERTA_START_DOCSTRING describes the model's base classes, its use as a Flax
# linen module, and the JAX features (jit, grad, vmap, pmap) it supports.
XLM_ROBERTA_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`XLMRobertaConfig`]): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
"""

# XLM_ROBERTA_INPUTS_DOCSTRING documents the arguments accepted by the model's __call__; it is formatted with the input shape.
XLM_ROBERTA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->XLMRoberta
class FlaxXLMRobertaEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: XLMRobertaConfig  # model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Word embedding table: maps token ids to vectors
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        # Position embedding table: encodes each token's position
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        # Token-type embedding table: distinguishes segments (e.g. sentence A vs. sentence B)
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        # LayerNorm over the summed embeddings
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # Dropout applied to the embeddings during training
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        # Embed each id stream
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all embeddings
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds

        # LayerNorm, then dropout
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->XLMRoberta
class FlaxXLMRobertaSelfAttention(nn.Module):
    config: XLMRobertaConfig  # model configuration
    causal: bool = False  # whether to use causal (autoregressive) attention
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Per-head dimension: hidden size divided by the number of heads
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
        # The hidden size must be divisible by the number of attention heads
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )

        # Query, key, and value projections for the attention mechanism
        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        # For causal attention, build a lower-triangular mask once up front
        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    # Split the hidden-state tensor into separate attention heads
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))

    # Merge the attention heads back into a single hidden-state tensor
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))

    # nn.compact lets the method below create cache variables lazily on first use
    @nn.compact
    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # Detect initialization by the absence of existing cache data
        is_initialized = self.has_variable("cache", "cached_key")
        # Fetch (or zero-initialize) the cached key/value tensors, matching the inputs' shape and dtype
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        # Fetch (or zero-initialize) the cache index
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            # Unpack batch dims, max length, number of heads, and per-head depth
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # Update the key and value caches with the new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            # Advance the cache index by the number of vectors just written
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # Causal mask for cached decoder self-attention: the current query positions should only
            # attend to key positions that have already been generated and cached, not the zero padding
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            # Combine the causal pad mask with the provided attention mask
            attention_mask = combine_masks(pad_mask, attention_mask)
        # Return the updated key, value, and attention mask
        return key, value, attention_mask
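
The shape bookkeeping above is easiest to see in isolation. A standalone sketch of the same `lax.dynamic_update_slice` update (dimensions are illustrative, not taken from this file):

import jax.numpy as jnp
from jax import lax

batch, max_length, heads, depth = 1, 8, 2, 4
cached_key = jnp.zeros((batch, max_length, heads, depth))  # pre-allocated cache
new_key = jnp.ones((batch, 1, heads, depth))               # projection for one decoding step
cur_index = 3                                              # steps already cached
updated = lax.dynamic_update_slice(cached_key, new_key, (0, cur_index, 0, 0))
# position 3 along the length axis now holds new_key; cache_index then advances by 1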
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->XLMRoberta
# Self-attention output block: projection, dropout, residual add, LayerNorm
class FlaxXLMRobertaSelfOutput(nn.Module):
    config: XLMRobertaConfig  # model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Output projection back to the hidden size
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        # Project, apply dropout, then add the residual and normalize
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->XLMRoberta
# Attention block: self-attention followed by the output projection
class FlaxXLMRobertaAttention(nn.Module):
    config: XLMRobertaConfig  # model configuration
    causal: bool = False  # whether to use causal attention
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.self = FlaxXLMRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
        self.output = FlaxXLMRobertaSelfOutput(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states=None,
        init_cache=False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        # Run self-attention; attn_outputs[0] is the attention output
        attn_outputs = self.self(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            key_value_states=key_value_states,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]
        # Project the attention output and add the residual
        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_outputs[1],)

        return outputs


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->XLMRoberta
# Feed-forward expansion layer of the transformer block
class FlaxXLMRobertaIntermediate(nn.Module):
    config: XLMRobertaConfig  # model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Dense layer expanding to the intermediate size
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # Activation function chosen by the config (e.g. gelu)
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        # Linear projection followed by the nonlinearity
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->XLMRoberta
class FlaxXLMRobertaOutput(nn.Module):
    config: XLMRobertaConfig  # model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Dense layer projecting back down to the hidden size
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        # Project, apply dropout, then add the attention output as a residual and normalize
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.LayerNorm(hidden_states + attention_output)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->XLMRoberta
class FlaxXLMRobertaLayer(nn.Module):
    config: XLMRobertaConfig  # model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Self-attention (causal when the config marks this model as a decoder)
        self.attention = FlaxXLMRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
        self.intermediate = FlaxXLMRobertaIntermediate(self.config, dtype=self.dtype)
        self.output = FlaxXLMRobertaOutput(self.config, dtype=self.dtype)
        # Optional cross-attention block for encoder-decoder setups
        if self.config.add_cross_attention:
            self.crossattention = FlaxXLMRobertaAttention(self.config, causal=False, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        # Self Attention
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attention_output = attention_outputs[0]

        # Cross-Attention Block: only runs when encoder hidden states are provided
        if encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask=encoder_attention_mask,
                layer_head_mask=layer_head_mask,
                key_value_states=encoder_hidden_states,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]

        # Feed-forward: expand with the intermediate layer, then project back down
        hidden_states = self.intermediate(attention_output)
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)

        outputs = (hidden_states,)

        # Optionally append self-attention (and cross-attention) weights
        if output_attentions:
            outputs += (attention_outputs[1],)
            if encoder_hidden_states is not None:
                outputs += (cross_attention_outputs[1],)

        return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->XLMRoberta
class FlaxXLMRobertaLayerCollection(nn.Module):
    config: XLMRobertaConfig  # model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False  # whether to rematerialize activations in the backward pass

    def setup(self):
        if self.gradient_checkpointing:
            # Wrap the layer in remat so activations are recomputed during backprop
            FlaxXLMRobertaCheckpointLayer = remat(FlaxXLMRobertaLayer, static_argnums=(5, 6, 7))
            # One checkpointed layer per hidden layer, named by its index
            self.layers = [
                FlaxXLMRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]
        else:
            # One plain layer per hidden layer, named by its index
            self.layers = [
                FlaxXLMRobertaLayer(self.config, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # The head mask, when given, must specify one mask per layer
        if head_mask is not None:
            if head_mask.shape[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.shape[0]}."
                )

        # Run each layer in turn
        for i, layer in enumerate(self.layers):
            # Collect the hidden states entering this layer, if requested
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(
                hidden_states,
                attention_mask,
                head_mask[i] if head_mask is not None else None,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                deterministic,
                output_attentions,
            )

            hidden_states = layer_outputs[0]

            # Collect attention (and cross-attention) weights, if requested
            if output_attentions:
                all_attentions += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # Append the final hidden states, if requested
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)

        # Without return_dict, return a tuple of the non-None values
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->XLMRoberta
class FlaxXLMRobertaEncoder(nn.Module):
    config: XLMRobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False  # whether to rematerialize activations

    def setup(self):
        self.layer = FlaxXLMRobertaLayerCollection(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return self.layer(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->XLMRoberta
class FlaxXLMRobertaPooler(nn.Module):
    config: XLMRobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        cls_hidden_state = hidden_states[:, 0]  # take the hidden state at the first (CLS) position
        cls_hidden_state = self.dense(cls_hidden_state)  # project it
        return nn.tanh(cls_hidden_state)  # tanh activation, as in BERT-style pooling
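
Shape-wise the pooler maps `[batch, seq_len, hidden]` to `[batch, hidden]`. A minimal standalone check (assumed usage of the module defined above):

import jax
import jax.numpy as jnp
from transformers import XLMRobertaConfig

pooler = FlaxXLMRobertaPooler(XLMRobertaConfig(hidden_size=16))
hidden = jnp.ones((2, 5, 16))                      # [batch, seq, hidden]
params = pooler.init(jax.random.PRNGKey(0), hidden)
print(pooler.apply(params, hidden).shape)          # (2, 16)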


# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaLMHead with Roberta->XLMRoberta
class FlaxXLMRobertaLMHead(nn.Module):
    config: XLMRobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros  # initializer for the output bias

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)  # layer normalization
        self.decoder = nn.Dense(
            self.config.vocab_size,
            dtype=self.dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))  # per-vocabulary output bias
    # Map hidden states to vocabulary logits, optionally with a tied embedding matrix
    def __call__(self, hidden_states, shared_embedding=None):
        # Dense projection, GELU, then layer norm
        hidden_states = self.dense(hidden_states)
        hidden_states = ACT2FN["gelu"](hidden_states)
        hidden_states = self.layer_norm(hidden_states)

        # With tied embeddings, reuse the (transposed) word-embedding matrix as the decoder kernel
        if shared_embedding is not None:
            hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            hidden_states = self.decoder(hidden_states)

        # Add the output bias
        bias = jnp.asarray(self.bias, self.dtype)
        hidden_states += bias
        return hidden_states
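
The tied-embedding branch above is just a matrix product against the transposed word-embedding table. Equivalent math, with hypothetical shapes:

import jax.numpy as jnp

vocab_size, hidden = 250002, 768                     # XLM-R base dimensions
shared_embedding = jnp.zeros((vocab_size, hidden))   # the word-embedding kernel
hidden_states = jnp.zeros((2, 5, hidden))
logits = hidden_states @ shared_embedding.T          # same result as decoder.apply with a tied kernel
print(logits.shape)                                  # (2, 5, 250002)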
# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaClassificationHead with Roberta->XLMRoberta
class FlaxXLMRobertaClassificationHead(nn.Module):
    config: XLMRobertaConfig  # model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        # Classifier dropout falls back to hidden_dropout_prob when not set explicitly
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        # Output projection to the number of labels
        self.out_proj = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = hidden_states[:, 0, :]  # take the <s> token's hidden state (equivalent to [CLS])
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.dense(hidden_states)
        hidden_states = nn.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaPreTrainedModel with Roberta->XLMRoberta, roberta->xlm-roberta, ROBERTA->XLM_ROBERTA
class FlaxXLMRobertaPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = XLMRobertaConfig
    base_model_prefix = "xlm-roberta"

    module_class: nn.Module = None

    def __init__(
        self,
        config: XLMRobertaConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        # Build the wrapped module from the config, dtype, and gradient-checkpointing flag
        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    # 从 transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing 复制而来
    def enable_gradient_checkpointing(self):
        self._module = self.module_class(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=True,
        )
    # Initialize the model's weights
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # Dummy input tensors used to trace the module once
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.ones_like(input_ids)
        position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
        attention_mask = jnp.ones_like(input_ids)
        head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        # Split the PRNG key into one stream for parameters and one for dropout
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            # With cross-attention, also provide dummy encoder hidden states and mask
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            # Otherwise initialize with the basic inputs only
            module_init_outputs = self.module.init(
                rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
            )

        # Pull the freshly initialized parameters out of the init outputs
        random_params = module_init_outputs["params"]

        if params is not None:
            # When pretrained params are given, fill in any missing keys with the random ones
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # Dummy inputs used to trace the module and materialize the cache variables
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # Run module.init with init_cache=True and return only the cache collection
        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])
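
An assumed usage sketch of `init_cache` (the checkpoint name is illustrative; in practice a decoder-style checkpoint would be used): allocate the cache once, then pass it back in as `past_key_values` while decoding one step at a time.

import jax.numpy as jnp

model = FlaxXLMRobertaForCausalLM.from_pretrained("FacebookAI/xlm-roberta-base")
cache = model.init_cache(batch_size=1, max_length=16)
input_ids = jnp.array([[0]])  # start with the <s> token
outputs = model(input_ids, past_key_values=cache, return_dict=True)
next_token_logits = outputs.logits[:, -1, :]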

    @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,                                    # token ids of the input sequence
        attention_mask=None,                          # marks which tokens are real vs. padding
        token_type_ids=None,                          # segment ids for multi-segment inputs
        position_ids=None,                            # position index of each token
        head_mask=None,                               # marks which attention heads are active
        encoder_hidden_states=None,                   # encoder hidden states for cross-attention
        encoder_attention_mask=None,                  # mask over the encoder hidden states
        params: dict = None,                          # optional parameter dict to run with
        dropout_rng: jax.random.PRNGKey = None,       # PRNG key for dropout
        train: bool = False,                          # whether dropout is active
        output_attentions: Optional[bool] = None,     # whether to return attention weights
        output_hidden_states: Optional[bool] = None,  # whether to return all hidden states
        return_dict: Optional[bool] = None,           # whether to return a ModelOutput instead of a tuple
        past_key_values: dict = None,                 # cached key/value states for fast decoding
    ):
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->XLMRoberta
class FlaxXLMRobertaModule(nn.Module):
    config: XLMRobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True  # whether to add the pooling layer
    gradient_checkpointing: bool = False  # whether to rematerialize activations

    def setup(self):
        # Embeddings, encoder stack, and pooler
        self.embeddings = FlaxXLMRobertaEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxXLMRobertaEncoder(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.pooler = FlaxXLMRobertaPooler(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Default token_type_ids to zeros when not provided
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # Default position_ids to 0..seq_len-1 broadcast over the batch when not provided
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # Compute the embedding output
        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        # Run the encoder stack
        outputs = self.encoder(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            deterministic=deterministic,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        # Pool the first position when the pooling layer is enabled
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        if not return_dict:
            # Drop the pooled output from the tuple when it is None
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


@add_start_docstrings(
    "The bare XLM RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
    XLM_ROBERTA_START_DOCSTRING,
)
class FlaxXLMRobertaModel(FlaxXLMRobertaPreTrainedModel):
    module_class = FlaxXLMRobertaModule


append_call_sample_docstring(FlaxXLMRobertaModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)
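
An assumed end-to-end usage of the bare model:

from transformers import AutoTokenizer, FlaxXLMRobertaModel

tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
model = FlaxXLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base")

inputs = tokenizer("Hello, world!", return_tensors="np")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, seq_len, 768)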


# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForMaskedLMModule with Roberta->XLMRoberta
class FlaxXLMRobertaForMaskedLMModule(nn.Module):
    config: XLMRobertaConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        # Backbone model without the pooling layer
        self.roberta = FlaxXLMRobertaModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # Language-modeling head
        self.lm_head = FlaxXLMRobertaLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Run the backbone model
        outputs = self.roberta(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        # When word embeddings are tied, pass the shared embedding to the LM head
        if self.config.tie_word_embeddings:
            shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings("""XLM RoBERTa Model with a `language modeling` head on top.""", XLM_ROBERTA_START_DOCSTRING)
class FlaxXLMRobertaForMaskedLM(FlaxXLMRobertaPreTrainedModel):
    module_class = FlaxXLMRobertaForMaskedLMModule


append_call_sample_docstring(
    FlaxXLMRobertaForMaskedLM,
    _CHECKPOINT_FOR_DOC,
    FlaxBaseModelOutputWithPooling,
    _CONFIG_FOR_DOC,
    mask="<mask>",
)
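
An assumed usage sketch of the masked-LM head (the prediction shown is what the pretrained checkpoint is likely to produce, not a guaranteed output):

import numpy as np
from transformers import AutoTokenizer, FlaxXLMRobertaForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
model = FlaxXLMRobertaForMaskedLM.from_pretrained("FacebookAI/xlm-roberta-base")

inputs = tokenizer("Paris is the <mask> of France.", return_tensors="np")
logits = model(**inputs).logits

mask_pos = int(np.argmax(inputs["input_ids"][0] == tokenizer.mask_token_id))
predicted_id = int(logits[0, mask_pos].argmax())
print(tokenizer.decode([predicted_id]))  # likely: "capital"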


# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForSequenceClassificationModule with Roberta->XLMRoberta
class FlaxXLMRobertaForSequenceClassificationModule(nn.Module):
    config: XLMRobertaConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False
    def setup(self):
        # Backbone without the pooling layer; classification uses the <s> hidden state instead
        self.roberta = FlaxXLMRobertaModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.classifier = FlaxXLMRobertaClassificationHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Run the backbone model
        outputs = self.roberta(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        logits = self.classifier(sequence_output, deterministic=deterministic)  # classify from the sequence output

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    XLM Roberta Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    XLM_ROBERTA_START_DOCSTRING,
)
class FlaxXLMRobertaForSequenceClassification(FlaxXLMRobertaPreTrainedModel):
    module_class = FlaxXLMRobertaForSequenceClassificationModule


append_call_sample_docstring(
    FlaxXLMRobertaForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)

# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->XLMRoberta, with self.bert->self.roberta
class FlaxXLMRobertaForMultipleChoiceModule(nn.Module):
    config: XLMRobertaConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.roberta = FlaxXLMRobertaModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        num_choices = input_ids.shape[1]
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None

        # Model
        outputs = self.roberta(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        reshaped_logits = logits.reshape(-1, num_choices)

        if not return_dict:
            return (reshaped_logits,) + outputs[2:]

        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
"""
XLM Roberta模型,带有多选分类头部(即池化输出的顶部线性层和)。
"""

@add_start_docstrings(
    """
    XLM Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
    """,
    XLM_ROBERTA_START_DOCSTRING,
)
"""
XLM Roberta模型,带有多选分类头部(即池化输出的顶部线性层和)
"""
    a softmax) e.g. for RocStories/SWAG tasks.
    """
    XLM-RoBERTa 模型的起始文档字符串,用于生成模型文档说明。
    """
)
class FlaxXLMRobertaForMultipleChoice(FlaxXLMRobertaPreTrainedModel):
    module_class = FlaxXLMRobertaForMultipleChoiceModule


overwrite_call_docstring(
    FlaxXLMRobertaForMultipleChoice, XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
    FlaxXLMRobertaForMultipleChoice,
    _CHECKPOINT_FOR_DOC,
    FlaxMultipleChoiceModelOutput,
    _CONFIG_FOR_DOC,
)
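
The expected input layout for multiple choice is `(batch_size, num_choices, seq_len)`; the module flattens the choice axis, scores each (prompt, choice) pair, and reshapes back. An assumed usage sketch (the base checkpoint's choice head is untrained, so scores are only meaningful after fine-tuning):

from transformers import AutoTokenizer, FlaxXLMRobertaForMultipleChoice

tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
model = FlaxXLMRobertaForMultipleChoice.from_pretrained("FacebookAI/xlm-roberta-base")

prompt = "France is located in"
choices = ["Europe.", "the Pacific Ocean."]
inputs = tokenizer([prompt, prompt], choices, return_tensors="np", padding=True)
inputs = {k: v[None, :] for k, v in inputs.items()}   # (num_choices, seq) -> (1, num_choices, seq)
logits = model(**inputs).logits                        # shape (1, 2), one score per choice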



# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->XLMRoberta, with self.bert->self.roberta
class FlaxXLMRobertaForTokenClassificationModule(nn.Module):
    config: XLMRobertaConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        # Backbone model without the pooling layer
        self.roberta = FlaxXLMRobertaModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # Classifier dropout falls back to hidden_dropout_prob when not set explicitly
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        # Per-token projection to the number of labels
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Run the backbone model
        outputs = self.roberta(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Apply dropout to the token-level hidden states, then classify each token
        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        logits = self.classifier(hidden_states)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    XLM Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    XLM_ROBERTA_START_DOCSTRING,
)
class FlaxXLMRobertaForTokenClassification(FlaxXLMRobertaPreTrainedModel):
    module_class = FlaxXLMRobertaForTokenClassificationModule


append_call_sample_docstring(
    FlaxXLMRobertaForTokenClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxTokenClassifierOutput,
    _CONFIG_FOR_DOC,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->XLMRoberta, with self.bert->self.roberta
class FlaxXLMRobertaForQuestionAnsweringModule(nn.Module):
    config: XLMRobertaConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        # Backbone model without the pooling layer
        self.roberta = FlaxXLMRobertaModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # Projection producing one logit per label (start and end) for every token
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Run the backbone model
        outputs = self.roberta(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        # Compute start and end logits for the answer span
        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    XLM Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    XLM_ROBERTA_START_DOCSTRING,
)
# XLM-RoBERTa question-answering model, built on FlaxXLMRobertaPreTrainedModel (a usage sketch follows the class)
class FlaxXLMRobertaForQuestionAnswering(FlaxXLMRobertaPreTrainedModel):
    # Use FlaxXLMRobertaForQuestionAnsweringModule as the underlying module
    module_class = FlaxXLMRobertaForQuestionAnsweringModule
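
# Hedged usage sketch for the question-answering model above (illustrative
# only; the checkpoint name is an assumption, and the QA head of a base
# checkpoint is randomly initialized, so the extracted span is meaningless
# before fine-tuning).
from transformers import AutoTokenizer, FlaxXLMRobertaForQuestionAnswering

qa_tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
qa_model = FlaxXLMRobertaForQuestionAnswering.from_pretrained("FacebookAI/xlm-roberta-base")

qa_inputs = qa_tokenizer("Who develops XLM-R?", "XLM-R is developed by Facebook AI.", return_tensors="np")
qa_outputs = qa_model(**qa_inputs)
span_start = qa_outputs.start_logits.argmax(-1)  # most likely span start per example
span_end = qa_outputs.end_logits.argmax(-1)      # most likely span end per example
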


# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForCausalLMModule with Roberta->XLMRoberta
class FlaxXLMRobertaForCausalLMModule(nn.Module):
    # Model configuration class
    config: XLMRobertaConfig
    # Computation dtype, defaults to jnp.float32
    dtype: jnp.dtype = jnp.float32
    # Whether to use gradient checkpointing, defaults to False
    gradient_checkpointing: bool = False

    # setup initializes the RoBERTa backbone and the language-modeling head
    def setup(self):
        self.roberta = FlaxXLMRobertaModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.lm_head = FlaxXLMRobertaLMHead(config=self.config, dtype=self.dtype)

    # Forward pass of the causal language model
    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        token_type_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Run the encoder, forwarding all relevant arguments
        outputs = self.roberta(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Take the hidden states as input features for the LM head
        hidden_states = outputs[0]

        # If word embeddings are tied, reuse the input embedding matrix in the LM head
        # (see the standalone sketch after this module)
        if self.config.tie_word_embeddings:
            shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute prediction scores with the language-modeling head
        logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)

        # If return_dict is False, return a plain tuple
        if not return_dict:
            return (logits,) + outputs[1:]

        # Return the causal-LM output, including cross-attentions
        return FlaxCausalLMOutputWithCrossAttentions(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
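
# Minimal standalone sketch (not library code) of the weight tying above: when
# config.tie_word_embeddings is True, the LM head projects hidden states with
# the transpose of the input embedding matrix instead of a separate weight.
import jax.numpy as jnp

_vocab, _hidden, _seqlen = 10, 4, 3
_embedding = jnp.ones((_vocab, _hidden))        # stands in for word_embeddings
_hidden_states = jnp.ones((_seqlen, _hidden))   # encoder output for one sequence
_lm_logits = _hidden_states @ _embedding.T      # tied projection -> (seq, vocab)
assert _lm_logits.shape == (_seqlen, _vocab)
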
# The decorator adds a docstring: an XLM-RoBERTa model with a language modeling head on top, for autoregressive tasks
@add_start_docstrings(
    """
    XLM Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
    autoregressive tasks.
    """,
    XLM_ROBERTA_START_DOCSTRING,
)
# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForCausalLM with Roberta->XLMRoberta
class FlaxXLMRobertaForCausalLM(FlaxXLMRobertaPreTrainedModel):
    # Use FlaxXLMRobertaForCausalLMModule as the underlying module
    module_class = FlaxXLMRobertaForCausalLMModule

    # Prepare inputs for generation: takes the prompt token IDs and the maximum generation length
    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # Initialize the decoding cache
        batch_size, seq_length = input_ids.shape

        # Initialize past key/value states via self.init_cache
        past_key_values = self.init_cache(batch_size, max_length)

        # Note: usually one would have to put 0s in the attention_mask for positions x > input_ids.shape[-1]
        # and x < cache_length, but since the decoder uses a causal mask, those positions are masked anyway.
        # We can therefore build a single static attention_mask here, which is more efficient to compile.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        
        # If an attention_mask is provided, derive position_ids from its cumulative sum and copy the mask
        # into the static one (see the standalone sketch after this class)
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            # Otherwise, broadcast position_ids of shape (batch_size, seq_length)
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    # Between generation steps, carry over past_key_values and advance position_ids by one
    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
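
# Standalone sketch (not library code) of how prepare_inputs_for_generation
# builds position_ids and the static attention mask for a left-padded prompt:
# padded positions get position_id -1, real tokens are numbered from 0, and
# the prompt mask is copied into an all-ones mask of length max_length.
import jax.numpy as jnp
from jax import lax

_mask = jnp.array([[0, 0, 1, 1]], dtype="i4")      # left-padded prompt, 2 real tokens
_position_ids = _mask.cumsum(axis=-1) - 1          # [[-1, -1, 0, 1]]
_extended = jnp.ones((1, 6), dtype="i4")           # max_length = 6
_extended = lax.dynamic_update_slice(_extended, _mask, (0, 0))  # [[0, 0, 1, 1, 1, 1]]
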

# Append a sample-call docstring to FlaxXLMRobertaForCausalLM showing how to call it
append_call_sample_docstring(
    FlaxXLMRobertaForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutputWithCrossAttentions,
    _CONFIG_FOR_DOC,
)
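
# Hedged end-to-end generation sketch (illustrative only). The checkpoint name
# is an assumption; XLM-RoBERTa is pretrained as a masked LM, so the causal
# head needs is_decoder=True and fine-tuning before its outputs are useful.
from transformers import AutoTokenizer, FlaxXLMRobertaForCausalLM

clm_tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
clm_model = FlaxXLMRobertaForCausalLM.from_pretrained("FacebookAI/xlm-roberta-base", is_decoder=True)

clm_inputs = clm_tokenizer("Hello world", return_tensors="np")
generated = clm_model.generate(clm_inputs["input_ids"], max_length=12)
print(clm_tokenizer.batch_decode(generated.sequences, skip_special_tokens=True))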