Transformers-源码解析-七十二-

Transformers 源码解析(七十二)

.\models\maskformer\modeling_maskformer.py

# coding=utf-8
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

### PyTorch MaskFormer模型的基础元类
""" 
import math
from dataclasses import dataclass  
from numbers import Number
from typing import Dict, List, Optional, Tuple

import numpy as np  
import torch  
from torch import Tensor, nn  

from ...activations import ACT2FN                 
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask  
from ...modeling_outputs import BaseModelOutputWithCrossAttentions  
from ...modeling_utils import PreTrainedModel  
from ...pytorch_utils import is_torch_greater_or_equal_than_2_1  
from ...utils import (
    ModelOutput,                                   # 基本模型输出类,用于封装模型输出项
    add_start_docstrings,                         # 用于添加模型的开始说明文档
    add_start_docstrings_to_model_forward,        # 用于添加模型输入参数的文档
    is_accelerate_available,                      # 检查加速器模块是否可用
    is_scipy_available,                           # 检查科学计算库是否可用
    logging,                                     # 日志记录模块
    replace_return_docstrings,                    # 替换返回文档说明的函数
    requires_backends,                           # 要求特定后端支持的装饰器
)    
from ...utils.backbone_utils import load_backbone  
from ..detr import DetrConfig  
from .configuration_maskformer import MaskFormerConfig  
from .configuration_maskformer_swin import MaskFormerSwinConfig  

if is_accelerate_available():                      # 检查加速器模块是否存在
    from accelerate import PartialState     
    from accelerate.utils import reduce  

if is_scipy_available():                           # 检查科学计算库存在且可用
    from scipy.optimize import linear_sum_assignment  

logger = logging.get_logger(__name__)               # 创建日志记录器

# "MaskFormerConfig"类实例,用于指定模型配置
_CONFIG_FOR_DOC = "MaskFormerConfig"
# "facebook/maskformer-swin-base-ade"模型的预训练模型地址
_CHECKPOINT_FOR_DOC = "facebook/maskformer-swin-base-ade"

# MaskFormer预训练模型列表
MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/maskformer-swin-base-ade",]

@dataclass                
# 定义"DetrDecoderOutput"类,扩展了"BaseModelOutputWithCrossAttentions"类,用于处理"DETR"解码器的输出项
class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
    """
    "DetrDecoderOutput"类继承自"BaseModelOutputWithCrossAttentions"类,用于封装"DETREncoder"模块的输出。
    接收一个"CrossAttentions"对象作为属性,并在此基础上添加了一个可选的解码器中间层激活堆栈。
    用于单辅助解码器损失训练时提供额外的特征信息。
    """
    
    # 定义函数的参数列表,包括最后一个隐藏层的输出状态
    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            模型最后一层的隐藏状态的序列输出。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            可选参数,当传递 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回,
            包含元组中的 `torch.FloatTensor`(一个用于嵌入层输出,每层输出一个)的形状为 `(batch_size, sequence_length, hidden_size)`。
            模型每一层的隐藏状态,以及初始嵌入层的输出。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            可选参数,当传递 `output_attentions=True` 或 `config.output_attentions=True` 时返回,
            包含元组中的 `torch.FloatTensor`(每一层一个)的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
            经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
            可选参数,当同时传递 `output_attentions=True` 和 `config.add_cross_attention=True` 或 `config.output_attentions=True` 时返回,
            包含元组中的 `torch.FloatTensor`(每一层一个)的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
            解码器交叉注意力层的注意力权重,经过注意力 softmax 后,用于计算交叉注意力头中的加权平均值。
        intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
            可选参数,当传递 `config.auxiliary_loss=True` 时返回,
            形状为 `(config.decoder_layers, batch_size, num_queries, hidden_size)` 的中间解码器激活状态。
            每个解码器层的中间激活状态,每个状态经过了层归一化。
    """
    
    # intermediate_hidden_states 变量定义为可选的 `torch.FloatTensor` 类型,表示中间隐藏状态,默认为 None
    intermediate_hidden_states: Optional[torch.FloatTensor] = None
# 定义一个数据类 `MaskFormerPixelLevelModuleOutput`,继承自 `ModelOutput`,表示 MaskFormer 的像素级模块的输出
@dataclass
class MaskFormerPixelLevelModuleOutput(ModelOutput):
    """
    MaskFormer's pixel level module output. It returns both the last and (optionally) the hidden states from the
    `encoder` and `decoder`. By default, the `encoder` is a MaskFormerSwin Transformer and the `decoder` is a Feature
    Pyramid Network (FPN).

    The `encoder_last_hidden_state` are referred on the paper as **images features**, while `decoder_last_hidden_state`
    as **pixel embeddings**

    Args:
        encoder_last_hidden_state (`torch.FloatTensor` of shape`(batch_size, num_channels, height, width)`):
            Last hidden states (final feature map) of the last stage of the encoder.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at
            the output of each stage.
        decoder_last_hidden_state (`torch.FloatTensor` of shape`(batch_size, num_channels, height, width)`):
            Last hidden states (final feature map) of the last stage of the decoder.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at
            the output of each stage.
    """

    # 定义属性 `encoder_last_hidden_state`,表示编码器最后一个隐藏状态的张量,形状为 `(batch_size, num_channels, height, width)`
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # 定义属性 `decoder_last_hidden_state`,表示解码器最后一个隐藏状态的张量,形状为 `(batch_size, num_channels, height, width)`
    decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # 定义属性 `encoder_hidden_states`,表示编码器的隐藏状态的元组,每个元素是一个形状为 `(batch_size, num_channels, height, width)` 的张量
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 定义属性 `decoder_hidden_states`,表示解码器的隐藏状态的元组,每个元素是一个形状为 `(batch_size, num_channels, height, width)` 的张量
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class MaskFormerPixelDecoderOutput(ModelOutput):
    """
    MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state
    and (optionally) the hidden states.
    """
    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            模型最后阶段的最后隐藏状态(最终特征图)。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, 当 `output_hidden_states=True` 时返回或当 `config.output_hidden_states=True` 时返回):
            包含多个元素的元组,每个元素是 `torch.FloatTensor`,形状为 `(batch_size, num_channels, height, width)`。
            模型在每一层输出的隐藏状态,还包括初始嵌入的输出。
        attentions (`tuple(torch.FloatTensor)`, *optional*, 当 `output_attentions=True` 时返回或当 `config.output_attentions=True` 时返回):
            包含多个元素的元组,每个元素是 `torch.FloatTensor`,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
            Detr 解码器中注意力权重经过 attention softmax 后的输出,用于计算自注意力头中的加权平均值。
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
# 定义一个数据类,用于存储 [`MaskFormerModel`] 的输出。这个类返回计算 logits 所需的所有隐藏状态。

@dataclass
class MaskFormerModelOutput(ModelOutput):
    """
    Class for outputs of [`MaskFormerModel`]. This class returns all the needed hidden states to compute the logits.

    Args:
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Last hidden states (final feature map) of the last stage of the encoder model (backbone).
        pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Last hidden states (final feature map) of the last stage of the pixel decoder model (FPN).
        transformer_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Last hidden states (final feature map) of the last stage of the transformer decoder model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
            model at the output of each stage.
        pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
            decoder model at the output of each stage.
        transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the
            transformer decoder at the output of each stage.
        hidden_states `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` containing `encoder_hidden_states`, `pixel_decoder_hidden_states` and
            `decoder_hidden_states`
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights from Detr's decoder after the attention softmax, used to compute the
            weighted average in the self-attention heads.
    """
    # 定义可选的 torch.FloatTensor 类型变量,用于存储编码器的最后隐藏状态
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # 定义可选的 torch.FloatTensor 类型变量,用于存储像素解码器的最后隐藏状态
    pixel_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # 定义可选的 torch.FloatTensor 类型变量,用于存储变换器解码器的最后隐藏状态
    transformer_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # 定义可选的 Tuple[torch.FloatTensor] 类型变量,用于存储编码器的隐藏状态序列
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 定义可选的 Tuple[torch.FloatTensor] 类型变量,用于存储像素解码器的隐藏状态序列
    pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 定义可选的 Tuple[torch.FloatTensor] 类型变量,用于存储变换器解码器的隐藏状态序列
    transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 定义可选的 Tuple[torch.FloatTensor] 类型变量,用于存储隐藏状态序列
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 定义可选的 Tuple[torch.FloatTensor] 类型变量,用于存储注意力分布序列
    attentions: Optional[Tuple[torch.FloatTensor]] = None
# 数据类装饰器,用于定义实例分割输出的数据结构,继承自ModelOutput类
@dataclass
class MaskFormerForInstanceSegmentationOutput(ModelOutput):
    """
    Class for outputs of [`MaskFormerForInstanceSegmentation`].

    This output can be directly passed to [`~MaskFormerImageProcessor.post_process_semantic_segmentation`] or or
    [`~MaskFormerImageProcessor.post_process_instance_segmentation`] or
    [`~MaskFormerImageProcessor.post_process_panoptic_segmentation`] depending on the task. Please, see
    [`~MaskFormerImageProcessor] for details regarding usage.

    """

    # 损失值,可选的浮点张量
    loss: Optional[torch.FloatTensor] = None
    # 类别查询的逻辑张量
    class_queries_logits: torch.FloatTensor = None
    # 掩码查询的逻辑张量
    masks_queries_logits: torch.FloatTensor = None
    # 辅助逻辑张量
    auxiliary_logits: torch.FloatTensor = None
    # 编码器最后隐藏状态,可选的浮点张量
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # 像素解码器最后隐藏状态,可选的浮点张量
    pixel_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # 变换器解码器最后隐藏状态,可选的浮点张量
    transformer_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # 编码器隐藏状态,可选的浮点张量元组
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 像素解码器隐藏状态,可选的浮点张量元组
    pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 变换器解码器隐藏状态,可选的浮点张量元组
    transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 隐藏状态,可选的浮点张量元组
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 注意力分数,可选的浮点张量元组
    attentions: Optional[Tuple[torch.FloatTensor]] = None


# 重新实现自原始实现的函数
def upsample_like(pixel_values: Tensor, like: Tensor, mode: str = "bilinear") -> Tensor:
    """
    An utility function that upsamples `pixel_values` to match the dimension of `like`.

    Args:
        pixel_values (`torch.Tensor`):
            The tensor we wish to upsample.
        like (`torch.Tensor`):
            The tensor we wish to use as size target.
        mode (str, *optional*, defaults to `"bilinear"`):
            The interpolation mode.

    Returns:
        `torch.Tensor`: The upsampled tensor
    """
    # 获取`like`张量的高度和宽度维度
    _, _, height, width = like.shape
    # 使用双线性插值法对`pixel_values`进行上采样,使其大小与`like`相匹配
    upsampled = nn.functional.interpolate(pixel_values, size=(height, width), mode=mode, align_corners=False)
    return upsampled


# 计算DICE损失的函数
def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
    r"""
    Compute the DICE loss, similar to generalized IOU for masks as follows:

    $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$

    In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow

    $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$

    Args:
        inputs (`torch.Tensor`):
            A tensor representing a mask.
        labels (`torch.Tensor`):
            A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks (`int`):
            The number of masks present in the current batch, used for normalization.

    Returns:
        `torch.Tensor`: The computed loss.
    """
    # 对输入进行sigmoid操作并展平为一维张量,得到预测概率
    probs = inputs.sigmoid().flatten(1)
    # 计算DICE损失的分子部分:2 * 预测概率 * 真实标签的交集
    numerator = 2 * (probs * labels).sum(-1)
    # 计算概率和标签在最后一个维度上的和,分别求和
    denominator = probs.sum(-1) + labels.sum(-1)
    # 计算损失值,使用给定的数值计算公式
    loss = 1 - (numerator + 1) / (denominator + 1)
    # 将所有损失值求和并除以遮罩数量,得到平均损失
    loss = loss.sum() / num_masks
    # 返回计算得到的平均损失值
    return loss
# 从原始实现重构而来的函数,计算逐对的 Sigmoid Focal Loss
def sigmoid_focal_loss(
    inputs: Tensor, labels: Tensor, num_masks: int, alpha: float = 0.25, gamma: float = 2
) -> Tensor:
    r"""
    Focal loss,最初在 [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) 中提出,最初用于 RetinaNet。该损失计算如下:

    $$ \mathcal{L}_{\text{focal loss}} = -(1 - p_t)^{\gamma}\log{(p_t)} $$

    其中 \\(CE(p_t) = -\log{(p_t)}}\\),CE 是标准交叉熵损失。

    请参考论文中的方程式 (1,2,3) 以获得更好的理解。

    Args:
        inputs (`torch.Tensor`):
            任意形状的浮点张量。
        labels (`torch.Tensor`):
            与 inputs 相同形状的张量。存储每个元素的二元分类标签 (0 表示负类,1 表示正类)。
        num_masks (`int`):
            当前批次中存在的掩码数量,用于归一化。
        alpha (float, *可选*, 默认为 0.25):
            在范围 (0,1) 内的加权因子,用于平衡正负例。
        gamma (float, *可选*, 默认为 2.0):
            调整因子 \\(1 - p_t\\) 的指数,用于平衡简单与困难的例子。

    Returns:
        `torch.Tensor`: 计算得到的损失。
    """
    # 使用带 logits 的二元交叉熵损失,不进行归一化
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    # 对输入进行 sigmoid 操作得到概率
    probs = inputs.sigmoid()
    # 计算标准交叉熵损失
    cross_entropy_loss = criterion(inputs, labels)
    # 计算 p_t
    p_t = probs * labels + (1 - probs) * (1 - labels)
    # 计算 focal loss
    loss = cross_entropy_loss * ((1 - p_t) ** gamma)

    # 如果 alpha 大于等于 0,计算 alpha_t
    if alpha >= 0:
        alpha_t = alpha * labels + (1 - alpha) * (1 - labels)
        loss = alpha_t * loss

    # 计算平均损失并进行归一化
    loss = loss.mean(1).sum() / num_masks
    return loss


# 从原始实现重构而来的函数,计算逐对的 Dice Loss
def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
    """
    Dice Loss 的逐对版本,参见 `dice_loss` 以了解用法。

    Args:
        inputs (`torch.Tensor`):
            表示掩码的张量
        labels (`torch.Tensor`):
            与 inputs 相同形状的张量。存储每个元素的二元分类标签 (0 表示负类,1 表示正类)。

    Returns:
        `torch.Tensor`: 每对之间计算得到的损失。
    """
    # 对输入进行 sigmoid 操作并展平为一维
    inputs = inputs.sigmoid().flatten(1)
    # 计算分子,使用矩阵乘法
    numerator = 2 * torch.matmul(inputs, labels.T)
    # 使用广播获取 [num_queries, NUM_CLASSES] 矩阵
    # 计算分母
    denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
    # 计算 Dice Loss
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss


# 从原始实现重构而来的函数,计算逐对的 Sigmoid Focal Loss
def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float = 0.25, gamma: float = 2.0) -> Tensor:
    r"""
    Sigmoid Focal Loss 的逐对版本,参见 `sigmoid_focal_loss` 以了解用法。
    ```
    # 如果alpha小于0,则引发值错误异常
    if alpha < 0:
        raise ValueError("alpha must be positive")

    # 获取输入张量的高度和宽度(假设输入是一个二维张量)
    height_and_width = inputs.shape[1]

    # 使用二元交叉熵损失函数,但是禁止自动平均(即不对每个样本的损失求平均)
    criterion = nn.BCEWithLogitsLoss(reduction="none")

    # 计算输入张量的sigmoid函数值,即转换为概率
    prob = inputs.sigmoid()

    # 计算正样本的交叉熵损失
    cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))

    # 计算焦点损失的正样本部分,用于聚焦于困难的样本
    focal_pos = ((1 - prob) ** gamma) * cross_entropy_loss_pos
    focal_pos *= alpha

    # 计算负样本的交叉熵损失
    cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))

    # 计算焦点损失的负样本部分,用于聚焦于容易的样本
    focal_neg = (prob**gamma) * cross_entropy_loss_neg
    focal_neg *= 1 - alpha

    # 计算最终的损失值,分别乘以标签的转置以加权正负样本
    loss = torch.matmul(focal_pos, labels.T) + torch.matmul(focal_neg, (1 - labels).T)

    # 返回归一化后的损失,即平均每个元素的损失
    return loss / height_and_width
# Copied from transformers.models.detr.modeling_detr.DetrAttention
class DetrAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim  # 初始化注意力机制的嵌入维度
        self.num_heads = num_heads  # 初始化注意力头的数量
        self.dropout = dropout  # 初始化dropout率
        self.head_dim = embed_dim // num_heads  # 计算每个注意力头的维度
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {num_heads})."
            )
        self.scaling = self.head_dim**-0.5  # 缩放因子,用于注意力计算

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # 用于投影键的线性层
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # 用于投影值的线性层
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # 用于投影查询的线性层
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # 输出投影的线性层

    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        # 重塑张量形状,以便进行多头注意力操作

    def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor], **kwargs):
        position_embeddings = kwargs.pop("position_embeddings", None)

        if kwargs:
            raise ValueError(f"Unexpected arguments {kwargs.keys()}")

        if position_embeddings is not None and object_queries is not None:
            raise ValueError(
                "Cannot specify both position_embeddings and object_queries. Please use just object_queries"
            )

        if position_embeddings is not None:
            logger.warning_once(
                "position_embeddings has been deprecated and will be removed in v4.34. Please use object_queries instead"
            )
            object_queries = position_embeddings

        return tensor if object_queries is None else tensor + object_queries
        # 添加位置嵌入到输入张量中的查询,支持使用对象查询

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        object_queries: Optional[torch.Tensor] = None,
        key_value_states: Optional[torch.Tensor] = None,
        spatial_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        **kwargs,
    ):
        # 前向传播函数,实现注意力机制的计算过程
    # 初始化方法,用于初始化一个DetrDecoderLayer对象
    def __init__(self, config: DetrConfig):
        # 调用父类的初始化方法
        super().__init__()
        # 设置嵌入维度等于配置文件中的d_model值
        self.embed_dim = config.d_model

        # 创建一个自注意力层对象,使用DetrAttention类实现
        self.self_attn = DetrAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        # 设置Dropout层的概率为配置文件中的dropout值
        self.dropout = config.dropout
        # 设置激活函数为配置文件中指定的激活函数
        self.activation_fn = ACT2FN[config.activation_function]
        # 设置激活函数后的Dropout概率为配置文件中的activation_dropout值
        self.activation_dropout = config.activation_dropout

        # 对自注意力层输出进行LayerNorm归一化处理
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        
        # 创建一个编码器注意力层对象,使用DetrAttention类实现
        self.encoder_attn = DetrAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        # 对编码器注意力层输出进行LayerNorm归一化处理
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        
        # 使用线性层进行特征变换,输入维度为embed_dim,输出维度为配置文件中的decoder_ffn_dim
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        # 对fc1层输出进行线性变换,输出维度为embed_dim
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        # 对最终输出进行LayerNorm归一化处理
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    # 前向传播方法,定义了如何处理输入数据的流程
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        object_queries: Optional[torch.Tensor] = None,
        query_position_embeddings: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
class DetrDecoder(nn.Module):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`].

    The decoder updates the query embeddings through multiple self-attention and cross-attention layers.

    Some small tweaks for DETR:

    - object_queries and query_position_embeddings are added to the forward pass.
    - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.

    Args:
        config: DetrConfig
    """

    def __init__(self, config: DetrConfig):
        super().__init__()
        self.config = config
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop

        # Initialize layers as a list of DetrDecoderLayer modules based on config.decoder_layers
        self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
        
        # Apply LayerNorm to the output of the last decoder layer
        self.layernorm = nn.LayerNorm(config.d_model)

        # Gradient checkpointing is disabled by default
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        object_queries=None,
        query_position_embeddings=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        # Forward pass through the decoder layers
        # Each layer updates the query embeddings using self-attention and cross-attention mechanisms
        # object_queries and query_position_embeddings are incorporated if provided
        # If auxiliary_loss is True, also returns hidden states from all decoding layers
        # The method returns a dictionary of output values
        pass  # Placeholder for actual implementation


# refactored from original implementation
class MaskFormerHungarianMatcher(nn.Module):
    """This class computes an assignment between the labels and the predictions of the network.

    For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more
    predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0):
        """Creates the matcher

        Params:
            cost_class (float, *optional*, defaults to 1.0):
                This is the relative weight of the classification error in the matching cost.
            cost_mask (float, *optional*,  defaults to 1.0):
                This is the relative weight of the focal loss of the binary mask in the matching cost.
            cost_dice (float, *optional*, defaults to 1.0):
                This is the relative weight of the dice loss of the binary mask in the matching cost
        """
        super().__init__()
        if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
            raise ValueError("All costs cant be 0")
        
        # Initialize the relative weights for classification, mask focal loss, and dice loss
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice

    @torch.no_grad()
    def forward(self):
        pass  # Placeholder for actual implementation
    # 返回对象的字符串表示形式,用于调试和显示
    def __repr__(self):
        # 构建字符串的头部,表示对象的类和名称
        head = "Matcher " + self.__class__.__name__
        # 构建字符串的主体部分,包括成本类、掩码和Dice成本的信息
        body = [
            f"cost_class: {self.cost_class}",  # 显示成本类的数值
            f"cost_mask: {self.cost_mask}",    # 显示成本掩码的数值
            f"cost_dice: {self.cost_dice}",    # 显示Dice成本的数值
        ]
        _repr_indent = 4  # 设置缩进量
        # 将头部和主体部分结合起来,每一行前面加上指定的缩进量
        lines = [head] + [" " * _repr_indent + line for line in body]
        # 将所有行连接成一个多行字符串并返回
        return "\n".join(lines)
# 从原始实现中复制并调整
class MaskFormerLoss(nn.Module):
    def __init__(
        self,
        num_labels: int,
        matcher: MaskFormerHungarianMatcher,
        weight_dict: Dict[str, float],
        eos_coef: float,
    ):
        """
        MaskFormer Loss类。损失计算与DETR非常类似。过程分为两步:
        1) 计算真实标签掩码与模型输出之间的匈牙利分配
        2) 监督每对匹配的真实标签/预测(监督类别和掩码)

        Args:
            num_labels (`int`):
                类别数量。
            matcher (`MaskFormerHungarianMatcher`):
                计算预测和标签之间分配的Torch模块。
            weight_dict (`Dict[str, float]`):
                不同损失要应用的权重字典。
            eos_coef (`float`):
                应用于空类别的权重。
        """

        super().__init__()
        requires_backends(self, ["scipy"])
        self.num_labels = num_labels
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        # 创建一个权重张量,包含所有类别和一个额外的EOS类别
        empty_weight = torch.ones(self.num_labels + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

    def _max_by_axis(self, the_list: List[List[int]]) -> List[int]:
        # 按轴找到列表中的最大值
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]:
        # 获取批次中的最大尺寸
        max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
        batch_size = len(tensors)
        # 计算最终形状
        batch_shape = [batch_size] + max_size
        b, _, h, w = batch_shape
        # 获取元数据
        dtype = tensors[0].dtype
        device = tensors[0].device
        # 创建零填充的张量和填充掩码
        padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
        padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device)
        # 将张量填充到最大尺寸
        for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
            padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
            padding_mask[: tensor.shape[1], : tensor.shape[2]] = False

        return padded_tensors, padding_masks

    def loss_labels(
        self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
    ) -> Dict[str, Tensor]:
        """Compute the losses related to the labels using cross entropy.

        Args:
            class_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, num_labels`
            class_labels (`List[torch.Tensor]`):
                List of class labels of shape `(labels)`.
            indices (`Tuple[np.array])`:
                The indices computed by the Hungarian matcher.

        Returns:
            `Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
            - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
        """

        pred_logits = class_queries_logits
        batch_size, num_queries, _ = pred_logits.shape
        # Define CrossEntropyLoss criterion with empty_weight
        criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
        # Obtain indices for permutation based on the Hungarian matcher output
        idx = self._get_predictions_permutation_indices(indices)
        # Concatenate target classes for each query in the batch
        # Shape after concatenation: (batch_size, num_queries)
        target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
        # Initialize target_classes tensor with default values
        # Shape: (batch_size, num_queries)
        target_classes = torch.full(
            (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
        )
        # Update target_classes tensor using the permutation indices
        target_classes[idx] = target_classes_o
        # Transpose pred_logits from "batch_size x num_queries x num_labels" to "batch_size x num_labels x num_queries"
        pred_logits_transposed = pred_logits.transpose(1, 2)
        # Compute cross entropy loss between transposed pred_logits and target_classes
        loss_ce = criterion(pred_logits_transposed, target_classes)
        # Prepare losses dictionary with cross entropy loss
        losses = {"loss_cross_entropy": loss_ce}
        return losses
    ) -> Dict[str, Tensor]:
        """Compute the losses related to the masks using focal and dice loss.

        Args:
            masks_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, height, width`
            mask_labels (`torch.Tensor`):
                List of mask labels of shape `(labels, height, width)`.
            indices (`Tuple[np.array])`:
                The indices computed by the Hungarian matcher.
            num_masks (`int)`:
                The number of masks, used for normalization.

        Returns:
            `Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
            - **loss_mask** -- The loss computed using sigmoid focal loss on the predicted and ground truth masks.
            - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth
              masks.
        """
        # Get permutation indices for predictions based on Hungarian matcher results
        src_idx = self._get_predictions_permutation_indices(indices)
        # Get permutation indices for targets based on Hungarian matcher results
        tgt_idx = self._get_targets_permutation_indices(indices)

        # Select predicted masks using the permutation indices
        pred_masks = masks_queries_logits[src_idx]

        # Pad and stack target masks to match the shape of predictions
        target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
        target_masks = target_masks[tgt_idx]

        # Upsample predicted masks to match the size of target masks
        pred_masks = nn.functional.interpolate(
            pred_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
        )
        # Flatten the predictions and targets for loss computation
        pred_masks = pred_masks[:, 0].flatten(1)
        target_masks = target_masks.flatten(1)

        # Compute losses using sigmoid focal loss and dice loss
        losses = {
            "loss_mask": sigmoid_focal_loss(pred_masks, target_masks, num_masks),
            "loss_dice": dice_loss(pred_masks, target_masks, num_masks),
        }
        return losses

    def _get_predictions_permutation_indices(self, indices):
        # Concatenate batch indices for predictions
        batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        # Concatenate prediction indices based on permutation results
        predictions_indices = torch.cat([src for (src, _) in indices])
        return batch_indices, predictions_indices

    def _get_targets_permutation_indices(self, indices):
        # Concatenate batch indices for targets
        batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        # Concatenate target indices based on permutation results
        target_indices = torch.cat([tgt for (_, tgt) in indices])
        return batch_indices, target_indices

    def forward(
        self,
        masks_queries_logits: Tensor,
        class_queries_logits: Tensor,
        mask_labels: List[Tensor],
        class_labels: List[Tensor],
        auxiliary_predictions: Optional[Dict[str, Tensor]] = None,
        """
        This performs the loss computation.

        Args:
            masks_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, height, width`
                表示查询掩码的logits张量,形状为 `batch_size, num_queries, height, width`
            class_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, num_labels`
                表示查询类别的logits张量,形状为 `batch_size, num_queries, num_labels`
            mask_labels (`torch.Tensor`):
                List of mask labels of shape `(labels, height, width)`.
                掩码标签列表,形状为 `(labels, height, width)`
            class_labels (`List[torch.Tensor]`):
                List of class labels of shape `(labels)`.
                类别标签列表,形状为 `(labels)`
            auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*):
                if `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], then it contains the logits from the
                inner layers of the Detr's Decoder.
                可选参数,如果在 `MaskFormerConfig` 中设置了 `use_auxiliary_loss` 为 `true`,则包含来自 Detr 解码器内部层的logits。

        Returns:
            `Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
            - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
              使用交叉熵计算预测标签和真实标签之间的损失。
            - **loss_mask** -- The loss computed using sigmoid focal loss on the predicted and ground truth masks.
              使用sigmoid focal loss计算预测掩码和真实掩码之间的损失。
            - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
              使用dice loss计算预测掩码和真实掩码之间的损失。
            if `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], the dictionary contains additional losses
            for each auxiliary predictions.
            如果在 [`MaskFormerConfig`] 中设置了 `use_auxiliary_loss` 为 `true`,则字典包含每个辅助预测的额外损失。
        """

        # retrieve the matching between the outputs of the last layer and the labels
        # 获取最后一层输出与标签之间的匹配
        indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)

        # compute the average number of target masks for normalization purposes
        # 计算平均目标掩码数量,用于归一化
        num_masks: Number = self.get_num_masks(class_labels, device=class_labels[0].device)

        # get all the losses
        # 获取所有的损失
        losses: Dict[str, Tensor] = {
            **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
            **self.loss_labels(class_queries_logits, class_labels, indices),
        }

        # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        # 如果存在辅助损失,则对每个中间层的输出重复此过程。
        if auxiliary_predictions is not None:
            for idx, aux_outputs in enumerate(auxiliary_predictions):
                masks_queries_logits = aux_outputs["masks_queries_logits"]
                class_queries_logits = aux_outputs["class_queries_logits"]
                loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
                loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
                losses.update(loss_dict)

        return losses
    # 定义一个方法,计算批次中目标掩码的平均数量,用于归一化目的。
    def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
        """
        Computes the average number of target masks across the batch, for normalization purposes.
        计算批次中目标掩码的平均数量,用于归一化目的。
        """
        # 计算所有类别标签中的掩码总数
        num_masks = sum([len(classes) for classes in class_labels])
        
        # 将掩码总数转换为张量,并指定数据类型和设备
        num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
        
        # 默认单进程世界大小
        world_size = 1
        
        # 如果加速库可用
        if is_accelerate_available():
            # 如果共享状态非空
            if PartialState._shared_state != {}:
                # 使用共享状态中的减少功能处理掩码总数
                num_masks = reduce(num_masks)
                # 获取部分状态对象的进程数量
                world_size = PartialState().num_processes
        
        # 将掩码总数除以进程数量进行归一化,并确保至少为1
        num_masks = torch.clamp(num_masks / world_size, min=1)
        
        # 返回归一化后的掩码数量
        return num_masks
class MaskFormerFPNConvLayer(nn.Module):
    def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, padding: int = 1):
        """
        A basic module that executes conv - norm - in sequence used in MaskFormer.

        Args:
            in_features (`int`):
                The number of input features (channels).
            out_features (`int`):
                The number of outputs features (channels).
        """
        super().__init__()
        # Define layers for convolution, group normalization, and ReLU activation
        self.layers = [
            nn.Conv2d(in_features, out_features, kernel_size=kernel_size, padding=padding, bias=False),
            nn.GroupNorm(32, out_features),
            nn.ReLU(inplace=True),
        ]
        # Add each layer to the module and name them with their index
        for i, layer in enumerate(self.layers):
            self.add_module(str(i), layer)

    def forward(self, input: Tensor) -> Tensor:
        # Apply each layer sequentially to the input tensor
        hidden_state = input
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


class MaskFormerFPNLayer(nn.Module):
    def __init__(self, in_features: int, lateral_features: int):
        """
        A Feature Pyramid Network Layer (FPN) layer. It creates a feature map by aggregating features from the previous
        and backbone layer. Due to the spatial mismatch, the tensor coming from the previous layer is upsampled.

        Args:
            in_features (`int`):
                The number of input features (channels).
            lateral_features (`int`):
                The number of lateral features (channels).
        """
        super().__init__()
        # Project features from the lateral connection to match in_features using 1x1 convolution and group normalization
        self.proj = nn.Sequential(
            nn.Conv2d(lateral_features, in_features, kernel_size=1, padding=0, bias=False),
            nn.GroupNorm(32, in_features),
        )
        # Create a convolutional block for further processing of features
        self.block = MaskFormerFPNConvLayer(in_features, in_features)

    def forward(self, down: Tensor, left: Tensor) -> Tensor:
        # Project features from the lateral connection
        left = self.proj(left)
        # Upsample the downsampled features to match the size of the lateral features
        down = nn.functional.interpolate(down, size=left.shape[-2:], mode="nearest")
        # Aggregate features by element-wise addition
        down += left
        # Process the aggregated features using the convolutional block
        down = self.block(down)
        return down


class MaskFormerFPNModel(nn.Module):
    # This class definition continues in the actual code and is incomplete here.
    pass
    # 初始化方法,定义特征金字塔网络的结构
    def __init__(self, in_features: int, lateral_widths: List[int], feature_size: int = 256):
        """
        Feature Pyramid Network, given an input tensor and a set of feature maps of different feature/spatial sizes,
        it creates a list of feature maps with the same feature size.

        Args:
            in_features (`int`):
                The number of input features (channels).
            lateral_widths (`List[int]`):
                A list with the feature (channel) sizes of each lateral connection.
            feature_size (int, *optional*, defaults to 256):
                The feature (channel) size of the resulting feature maps.
        """
        # 调用父类的初始化方法
        super().__init__()
        # 定义特征金字塔网络的起始卷积层
        self.stem = MaskFormerFPNConvLayer(in_features, feature_size)
        # 定义特征金字塔网络的中间层序列,每层是一个MaskFormerFPNLayer对象
        self.layers = nn.Sequential(
            *[MaskFormerFPNLayer(feature_size, lateral_width) for lateral_width in lateral_widths[::-1]]
        )

    # 前向传播方法,计算特征金字塔网络的输出特征图列表
    def forward(self, features: List[Tensor]) -> List[Tensor]:
        # 初始化一个空列表,用于存储特征金字塔网络的输出特征图
        fpn_features = []
        # 获取最后一个特征图
        last_feature = features[-1]
        # 获取除了最后一个特征图外的其他特征图列表
        other_features = features[:-1]
        # 将最后一个特征图送入起始卷积层stem计算
        output = self.stem(last_feature)
        # 逐层处理特征金字塔网络的每一层
        for layer, left in zip(self.layers, other_features[::-1]):
            # 使用当前层处理输出特征图和对应的左侧特征图,得到新的输出特征图
            output = layer(output, left)
            # 将处理后的特征图加入到特征金字塔网络输出列表中
            fpn_features.append(output)
        # 返回特征金字塔网络的所有输出特征图列表
        return fpn_features
# 定义了一个名为 MaskFormerPixelDecoder 的神经网络模块类
class MaskFormerPixelDecoder(nn.Module):
    # 初始化方法,设置模块的参数和属性
    def __init__(self, *args, feature_size: int = 256, mask_feature_size: int = 256, **kwargs):
        r"""
        Pixel Decoder Module proposed in [Per-Pixel Classification is Not All You Need for Semantic
        Segmentation](https://arxiv.org/abs/2107.06278). It first runs the backbone's features into a Feature Pyramid
        Network creating a list of feature maps. Then, it projects the last one to the correct `mask_size`.

        Args:
            feature_size (`int`, *optional*, defaults to 256):
                The feature size (channel dimension) of the FPN feature maps.
            mask_feature_size (`int`, *optional*, defaults to 256):
                The features (channels) of the target masks size \\(C_{\epsilon}\\) in the paper.
        """
        super().__init__()

        # 创建 MaskFormerFPNModel 实例,用于生成特征金字塔网络的特征图列表
        self.fpn = MaskFormerFPNModel(*args, feature_size=feature_size, **kwargs)
        # 使用卷积层将最后一个特征图投影到正确的 mask 尺寸
        self.mask_projection = nn.Conv2d(feature_size, mask_feature_size, kernel_size=3, padding=1)

    # 前向传播方法,处理输入数据并返回输出
    def forward(
        self, features: List[Tensor], output_hidden_states: bool = False, return_dict: bool = True
    ) -> MaskFormerPixelDecoderOutput:
        # 使用特征金字塔网络处理输入特征列表,生成特征金字塔的特征图列表
        fpn_features = self.fpn(features)
        # 获取最后一个特征图并进行投影
        last_feature_projected = self.mask_projection(fpn_features[-1])

        # 根据 return_dict 参数返回不同形式的输出
        if not return_dict:
            return (last_feature_projected, tuple(fpn_features)) if output_hidden_states else (last_feature_projected,)

        # 如果 return_dict 为 True,则返回 MaskFormerPixelDecoderOutput 对象
        return MaskFormerPixelDecoderOutput(
            last_hidden_state=last_feature_projected, hidden_states=tuple(fpn_features) if output_hidden_states else ()
        )


# 复制并改编自原始实现,与 DetrSinePositionEmbedding 实现几乎相同
class MaskFormerSinePositionEmbedding(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.
    """

    # 初始化方法,设置位置嵌入的参数和属性
    def __init__(
        self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
    ):
        super().__init__()
        # 如果指定了 scale 参数但未设置 normalize 参数,则抛出异常
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        # 初始化位置特征数量、温度参数、标准化标志和缩放比例
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale
    # 实现 Transformer 模型中的位置编码生成函数
    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        # 如果没有给定掩码,创建一个全零的掩码张量,与输入张量的维度匹配
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        
        # 计算反掩码,将掩码取反,转换为输入张量的数据类型
        not_mask = (~mask).to(x.dtype)
        
        # 在不被掩码遮挡的区域上计算累积和,作为位置编码的一部分
        y_embed = not_mask.cumsum(1)  # 在第二个维度上进行累积和
        x_embed = not_mask.cumsum(2)  # 在第三个维度上进行累积和
        
        # 如果需要归一化位置编码
        if self.normalize:
            eps = 1e-6
            # 对 y 轴和 x 轴的位置编码进行归一化处理,并乘以缩放因子 self.scale
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # 生成维度张量,用于计算位置编码
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=x.device).type_as(x)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)

        # 根据维度张量计算 x 和 y 的位置编码
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t

        # 使用正弦和余弦函数堆叠 x 和 y 的位置编码
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)

        # 将 x 和 y 的位置编码连接起来,并将维度顺序转换为 (batch, channels, height, width)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        
        # 返回位置编码张量
        return pos
class PredictionBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None:
        super().__init__()
        # 创建一个包含线性层和激活函数的层列表
        self.layers = [nn.Linear(in_dim, out_dim), activation]
        # 将每个层作为子模块添加到当前模块中,以便在 forward 方法中能够正确调用
        for i, layer in enumerate(self.layers):
            self.add_module(str(i), layer)

    def forward(self, input: Tensor) -> Tensor:
        hidden_state = input
        # 逐层应用网络层和激活函数
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


class MaskformerMLPPredictionHead(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3):
        """
        A classic Multi Layer Perceptron (MLP).

        Args:
            input_dim (`int`):
                输入维度。
            hidden_dim (`int`):
                隐藏层维度。
            output_dim (`int`):
                输出维度。
            num_layers (int, *optional*, defaults to 3):
                层数。
        """
        super().__init__()
        # 构建输入和输出维度的列表,用于每个预测块的创建
        in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
        out_dims = [hidden_dim] * (num_layers - 1) + [output_dim]

        self.layers = []
        # 根据给定维度创建预测块列表
        for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
            # 对于除了最后一层外的每一层使用ReLU激活函数,最后一层使用恒等激活函数
            activation = nn.ReLU() if i < num_layers - 1 else nn.Identity()
            # 创建预测块对象
            layer = PredictionBlock(in_dim, out_dim, activation=activation)
            self.layers.append(layer)
            # 将预测块作为子模块添加到当前模块中,使用索引作为名称
            self.add_module(str(i), layer)

    def forward(self, input: Tensor) -> Tensor:
        hidden_state = input
        # 逐层应用预测块
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


class MaskFormerPixelLevelModule(nn.Module):
    pass  # 空模块,暂无具体实现
    def __init__(self, config: MaskFormerConfig):
        """
        Pixel Level Module proposed in [Per-Pixel Classification is Not All You Need for Semantic
        Segmentation](https://arxiv.org/abs/2107.06278). It runs the input image through a backbone and a pixel
        decoder, generating an image feature map and pixel embeddings.

        Args:
            config ([`MaskFormerConfig`]):
                The configuration used to instantiate this model.
        """
        super().__init__()  # 调用父类的初始化方法

        # 检查配置中是否有`backbone_config`属性,并且其`model_type`为"swin"时
        if getattr(config, "backbone_config") is not None and config.backbone_config.model_type == "swin":
            # 为了向后兼容,创建一个新的`backbone_config`,并从字典形式转换而来
            backbone_config = config.backbone_config
            backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict())
            # 设置新的`out_features`,这里设置为固定的阶段名称列表
            backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"]
            config.backbone_config = backbone_config

        # 加载指定配置的背骨网络
        self.encoder = load_backbone(config)

        # 获取背骨网络最后一层的特征通道数
        feature_channels = self.encoder.channels

        # 初始化像素级解码器,传入参数为最后一个特征层的通道数、FPN特征大小、掩码特征大小、以及其它特征层的宽度
        self.decoder = MaskFormerPixelDecoder(
            in_features=feature_channels[-1],
            feature_size=config.fpn_feature_size,
            mask_feature_size=config.mask_feature_size,
            lateral_widths=feature_channels[:-1],
        )

    def forward(
        self, pixel_values: Tensor, output_hidden_states: bool = False, return_dict: bool = True
    ) -> MaskFormerPixelLevelModuleOutput:
        # 将像素值传入编码器,获取特征映射
        features = self.encoder(pixel_values).feature_maps

        # 将特征映射传入解码器,获取解码器的输出
        decoder_output = self.decoder(features, output_hidden_states, return_dict=return_dict)

        # 如果`return_dict`为False,返回特定格式的输出元组
        if not return_dict:
            last_hidden_state = decoder_output[0]  # 解码器输出的最后隐藏状态
            outputs = (features[-1], last_hidden_state)  # 输出包括编码器的最后一个特征映射和解码器的最后隐藏状态
            if output_hidden_states:
                hidden_states = decoder_output[1]  # 解码器的所有隐藏状态
                outputs = outputs + (tuple(features),) + (hidden_states,)  # 输出扩展为包括所有特征映射和隐藏状态
            return outputs

        # 如果`return_dict`为True,构造并返回`MaskFormerPixelLevelModuleOutput`对象
        return MaskFormerPixelLevelModuleOutput(
            encoder_last_hidden_state=features[-1],  # 编码器的最后一个特征映射
            decoder_last_hidden_state=decoder_output.last_hidden_state,  # 解码器的最后隐藏状态
            encoder_hidden_states=tuple(features) if output_hidden_states else (),  # 所有编码器特征映射
            decoder_hidden_states=decoder_output.hidden_states if output_hidden_states else (),  # 所有解码器隐藏状态
        )
class MaskFormerTransformerModule(nn.Module):
    """
    The MaskFormer's transformer module.
    """

    def __init__(self, in_features: int, config: MaskFormerConfig):
        super().__init__()
        hidden_size = config.decoder_config.hidden_size
        should_project = in_features != hidden_size
        # 初始化位置编码器,用于对象查询的位置信息嵌入
        self.position_embedder = MaskFormerSinePositionEmbedding(num_pos_feats=hidden_size // 2, normalize=True)
        # 初始化查询的嵌入层,根据配置的查询数量和隐藏大小
        self.queries_embedder = nn.Embedding(config.decoder_config.num_queries, hidden_size)
        # 如果输入特征与隐藏大小不同,进行卷积投影
        self.input_projection = nn.Conv2d(in_features, hidden_size, kernel_size=1) if should_project else None
        # 初始化解码器
        self.decoder = DetrDecoder(config=config.decoder_config)

    def forward(
        self,
        image_features: Tensor,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
        return_dict: Optional[bool] = None,
    ) -> DetrDecoderOutput:
        if self.input_projection is not None:
            # 如果存在输入投影层,对图像特征进行投影
            image_features = self.input_projection(image_features)
        # 生成对象查询的位置嵌入
        object_queries = self.position_embedder(image_features)
        # 重复查询嵌入以匹配批次大小
        batch_size = image_features.shape[0]
        queries_embeddings = self.queries_embedder.weight.unsqueeze(0).repeat(batch_size, 1, 1)
        # 初始化输入嵌入(用零填充),将会被模型修改
        inputs_embeds = torch.zeros_like(queries_embeddings, requires_grad=True)

        batch_size, num_channels, height, width = image_features.shape
        # 重新排列图像特征和对象查询的维度以便匹配解码器的输入格式
        image_features = image_features.view(batch_size, num_channels, height * width).permute(0, 2, 1)
        object_queries = object_queries.view(batch_size, num_channels, height * width).permute(0, 2, 1)

        # 调用解码器进行前向传播
        decoder_output: DetrDecoderOutput = self.decoder(
            inputs_embeds=inputs_embeds,
            attention_mask=None,
            encoder_hidden_states=image_features,
            encoder_attention_mask=None,
            object_queries=object_queries,
            query_position_embeddings=queries_embeddings,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 返回解码器的输出结果
        return decoder_output


注释:
    # Args 定义了此函数的输入参数
    Args:
        # `pixel_values` 是一个 FloatTensor,表示像素值,形状为 `(batch_size, num_channels, height, width)`
        # 像素值可以通过 `AutoImageProcessor` 获得。详见 `MaskFormerImageProcessor.__call__` 的说明。
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        
        # `pixel_mask` 是一个 LongTensor,形状为 `(batch_size, height, width)`,可选参数
        # 用于避免在填充像素值上执行注意力操作。掩码的取值范围为 `[0, 1]`:
        #
        # - 1 表示真实像素(即 **未掩码**),
        # - 0 表示填充像素(即 **已掩码**)。
        #
        # [什么是注意力掩码?](../glossary#attention-mask)
        pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
        
        # `output_hidden_states` 是一个布尔值,可选参数
        # 是否返回所有层的隐藏状态。更多细节请参见返回的张量中的 `hidden_states`。
        output_hidden_states (`bool`, *optional*):
        
        # `output_attentions` 是一个布尔值,可选参数
        # 是否返回 Detr 解码器注意力层的注意力张量。
        output_attentions (`bool`, *optional*):
        
        # `return_dict` 是一个布尔值,可选参数
        # 是否返回 `~MaskFormerModelOutput` 而不是普通的元组。
        return_dict (`bool`, *optional*):
"""
Defines the MaskFormerModel class which extends MaskFormerPreTrainedModel.

@add_start_docstrings(
    "The bare MaskFormer Model outputting raw hidden-states without any specific head on top.",
    MASKFORMER_START_DOCSTRING,
)
"""
class MaskFormerModel(MaskFormerPreTrainedModel):
    """
    Initializes a MaskFormerModel instance.

    Args:
        config (MaskFormerConfig): Configuration object specifying model parameters.

    Inherits:
        MaskFormerPreTrainedModel: Base class for MaskFormerModel, pre-trained model.

    Attributes:
        pixel_level_module (MaskFormerPixelLevelModule): Pixel-level module for MaskFormer.
        transformer_module (MaskFormerTransformerModule): Transformer module for MaskFormer.
    """

    def __init__(self, config: MaskFormerConfig):
        """
        Constructor for MaskFormerModel.

        Args:
            config (MaskFormerConfig): Configuration object specifying model parameters.

        Calls super() to initialize from MaskFormerPreTrainedModel, initializes:
            - pixel_level_module (MaskFormerPixelLevelModule): Module for pixel-level operations.
            - transformer_module (MaskFormerTransformerModule): Transformer module for MaskFormer.

        Post-initialization handled by self.post_init().
        """
        super().__init__(config)
        self.pixel_level_module = MaskFormerPixelLevelModule(config)
        self.transformer_module = MaskFormerTransformerModule(
            in_features=self.pixel_level_module.encoder.channels[-1], config=config
        )

        self.post_init()
    # 定义一个方法 `forward`,用于模型的前向传播
    def forward(
        # 输入参数 `pixel_values`,类型为 Tensor,表示输入的像素值
        self,
        # 输入参数 `pixel_mask`,可选的 Tensor 类型,表示像素的掩码,用于指示哪些像素是有效的
        pixel_values: Tensor,
        # 输入参数 `output_hidden_states`,可选的布尔值,控制是否输出隐藏状态
        output_hidden_states: Optional[bool] = None,
        # 输入参数 `output_attentions`,可选的布尔值,控制是否输出注意力权重
        output_attentions: Optional[bool] = None,
        # 输入参数 `return_dict`,可选的布尔值,控制是否返回字典形式的输出
        return_dict: Optional[bool] = None,
class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
    def __init__(self, config: MaskFormerConfig):
        super().__init__(config)
        # 初始化 MaskFormerModel 模型
        self.model = MaskFormerModel(config)
        # 从配置中获取隐藏层大小
        hidden_size = config.decoder_config.hidden_size
        # 创建一个线性层用于类别预测,输出维度为 num_labels + 1(增加一个“空”类别)
        self.class_predictor = nn.Linear(hidden_size, config.num_labels + 1)
        # 创建 MaskformerMLPPredictionHead 实例,用于掩码嵌入
        self.mask_embedder = MaskformerMLPPredictionHead(hidden_size, hidden_size, config.mask_feature_size)

        # 创建 MaskFormerHungarianMatcher 实例,用于匹配器
        self.matcher = MaskFormerHungarianMatcher(
            cost_class=1.0, cost_dice=config.dice_weight, cost_mask=config.mask_weight
        )

        # 设置损失权重字典,用于损失函数 MaskFormerLoss
        self.weight_dict: Dict[str, float] = {
            "loss_cross_entropy": config.cross_entropy_weight,
            "loss_mask": config.mask_weight,
            "loss_dice": config.dice_weight,
        }

        # 创建 MaskFormerLoss 损失函数实例
        self.criterion = MaskFormerLoss(
            config.num_labels,
            matcher=self.matcher,
            weight_dict=self.weight_dict,
            eos_coef=config.no_object_weight,
        )

        # 运行初始化后处理方法
        self.post_init()

    # 计算并返回损失字典
    def get_loss_dict(
        self,
        masks_queries_logits: Tensor,
        class_queries_logits: Tensor,
        mask_labels: Tensor,
        class_labels: Tensor,
        auxiliary_logits: Dict[str, Tensor],
    ) -> Dict[str, Tensor]:
        loss_dict: Dict[str, Tensor] = self.criterion(
            masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits
        )
        # 根据权重字典调整每个损失值
        for key, weight in self.weight_dict.items():
            for loss_key, loss in loss_dict.items():
                if key in loss_key:
                    loss *= weight

        return loss_dict

    # 计算并返回总损失值
    def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
        return sum(loss_dict.values())

    # 前向传播函数,接受多个输入和输出参数,包括像素值、掩码和类别标签等
    @add_start_docstrings_to_model_forward(MASKFORMER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MaskFormerForInstanceSegmentationOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Tensor,
        mask_labels: Optional[List[Tensor]] = None,
        class_labels: Optional[List[Tensor]] = None,
        pixel_mask: Optional[Tensor] = None,
        output_auxiliary_logits: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 省略了前向传播函数的其余部分,因为没有要注释的代码
        pass

.\models\maskformer\modeling_maskformer_swin.py

# coding=utf-8
# 声明代码文件使用 UTF-8 编码

# 版权声明及许可协议,这里使用 Apache License 2.0
# 详细说明了使用条件,允许了如何使用和分发代码
# 可以在 http://www.apache.org/licenses/LICENSE-2.0 获取许可协议的副本

"""MaskFormer Swin Transformer. The reason Swin Transformer is implemented here is because MaskFormer uses the hidden
states before downsampling, which is different from the default Swin Transformer."""

import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import Tensor, nn

# 导入自定义的激活函数映射表
from ...activations import ACT2FN
# 导入文件工具函数
from ...file_utils import ModelOutput
# 导入模型输出类,用于承载模型输出结果
from ...modeling_outputs import BackboneOutput
# 导入预训练模型基类
from ...modeling_utils import PreTrainedModel
# 导入 PyTorch 工具函数,如头部剪枝,网格操作等
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
# 导入支撑函数,BackboneMixin 类,支持 Swin Transformer 模型
from ...utils.backbone_utils import BackboneMixin
# 导入 MaskFormer Swin 的配置类
from .configuration_maskformer_swin import MaskFormerSwinConfig


@dataclass
# 继承自 ModelOutput 类,增加了包含隐藏状态空间维度的输出类
class MaskFormerSwinModelOutputWithPooling(ModelOutput):
    """
    Class for MaskFormerSwinModel's outputs that also contains the spatial dimensions of the hidden states.
    """
    """
    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            模型最后一层的隐藏状态序列。
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            经过平均池化操作后的最后一层隐藏状态。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            一个元组,包含每一层的隐藏状态,形状为 `(batch_size, sequence_length, hidden_size)`。

            模型在每一层的输出隐藏状态,以及初始嵌入输出。
        hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
            包含每个隐藏状态的空间维度元组,用于将 `hidden_states` 重塑为 `batch, channels, height, width` 的形式。
            由于填充存在,无法在 `forward` 方法之前推断它们的空间大小。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            一个元组,包含每一层的注意力权重,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。

            注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
# 数据类装饰器,定义了一个输出模型的基类
@dataclass
class MaskFormerSwinBaseModelOutput(ModelOutput):
    """
    SwinEncoder模型输出的类。

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            模型最后一层的隐藏状态序列。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, 当 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回):
            一个元组的 `torch.FloatTensor`(对应每层的输出和初始嵌入输出),
            形状为 `(batch_size, sequence_length, hidden_size)`。

            模型每一层的隐藏状态加上初始嵌入输出。
        hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
            包含每个 `hidden_state` 的空间维度的元组,用于将 `hidden_states` 重塑为 `batch, channels, height, width`。
            由于填充,它们的空间大小在 `forward` 方法之前无法推断。
        attentions (`tuple(torch.FloatTensor)`, *optional*, 当 `output_attentions=True` 或 `config.output_attentions=True` 时返回):
            一个元组的 `torch.FloatTensor`(每层一个),形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。

            经过注意力 softmax 后的注意力权重,用于在自注意力头中计算加权平均值。
    """

    last_hidden_state: torch.FloatTensor = None  # 最后一层的隐藏状态
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None  # 每层的隐藏状态
    hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None  # 隐藏状态的空间维度
    attentions: Optional[Tuple[torch.FloatTensor]] = None  # 注意力权重


# 从transformers.models.swin.modeling_swin.window_partition复制过来的函数
def window_partition(input_feature, window_size):
    """
    将给定输入分割为窗口。
    """
    batch_size, height, width, num_channels = input_feature.shape
    input_feature = input_feature.view(
        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
    )
    windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
    return windows


# 从transformers.models.swin.modeling_swin.window_reverse复制过来的函数
def window_reverse(windows, window_size, height, width):
    """
    合并窗口以产生更高分辨率的特征。
    """
    num_channels = windows.shape[-1]
    windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
    windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
    return windows


# 从transformers.models.swin.modeling_swin.drop_path复制过来的函数
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    实现丢弃路径(drop path)操作。

    Args:
        input (torch.Tensor): 输入张量。
        drop_prob (float, optional): 丢弃概率。默认为0.0。
        training (bool, optional): 是否处于训练模式。默认为False。

    Returns:
        torch.Tensor: 处理后的张量。
    """
    # 略
    # 如果 drop_prob 等于 0.0 或者不处于训练状态,直接返回输入,不进行 Drop Path 操作
    if drop_prob == 0.0 or not training:
        return input
    # 计算保留概率
    keep_prob = 1 - drop_prob
    # 计算输出张量的形状,适用于各种维度的张量,而不仅仅是二维卷积神经网络
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    # 生成随机张量,与输入张量相同形状,用于二值化
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # 对随机张量进行二值化处理
    # 计算输出,将输入张量除以保留概率,再乘以二值化后的随机张量
    output = input.div(keep_prob) * random_tensor
    # 返回处理后的输出张量
    return output
# 定义一个名为 MaskFormerSwinEmbeddings 的 PyTorch 模块,用于构建补丁和位置嵌入。
class MaskFormerSwinEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings.
    """

    # 初始化方法,接收一个 config 对象作为参数。
    def __init__(self, config):
        super().__init__()

        # 使用 MaskFormerSwinPatchEmbeddings 类创建补丁嵌入对象。
        self.patch_embeddings = MaskFormerSwinPatchEmbeddings(config)
        # 获取补丁数量
        num_patches = self.patch_embeddings.num_patches
        # 获取补丁网格大小
        self.patch_grid = self.patch_embeddings.grid_size

        # 根据配置选择是否使用绝对位置嵌入
        if config.use_absolute_embeddings:
            # 如果使用绝对位置嵌入,则创建一个全零的可学习参数张量
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
        else:
            # 否则位置嵌入设为 None
            self.position_embeddings = None

        # LayerNorm 层,用于归一化嵌入向量
        self.norm = nn.LayerNorm(config.embed_dim)
        # Dropout 层,用于随机失活以防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    # 前向传播方法,接收像素值作为输入
    def forward(self, pixel_values):
        # 使用补丁嵌入对象处理像素值,得到嵌入张量和输出维度信息
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        # 对嵌入张量进行归一化
        embeddings = self.norm(embeddings)

        # 如果位置嵌入不为 None,则将位置嵌入加到嵌入张量上
        if self.position_embeddings is not None:
            embeddings = embeddings + self.position_embeddings

        # 对嵌入张量进行随机失活
        embeddings = self.dropout(embeddings)

        # 返回处理后的嵌入张量和输出维度信息
        return embeddings, output_dimensions


# 从 transformers.models.swin.modeling_swin.SwinPatchEmbeddings 复制而来的类
class MaskFormerSwinPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    # 初始化方法,接收一个 config 对象作为参数
    def __init__(self, config):
        super().__init__()
        # 从配置中获取图像大小和补丁大小
        image_size, patch_size = config.image_size, config.patch_size
        # 从配置中获取通道数和嵌入维度大小
        num_channels, hidden_size = config.num_channels, config.embed_dim
        # 如果图像大小和补丁大小不是可迭代对象,则转换为元组形式
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        # 计算补丁数量
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        # 将初始化的图像大小、补丁大小、通道数、嵌入维度等保存为类属性
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

        # 使用卷积层将输入的像素值转换为补丁嵌入的隐藏状态
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    # 可能的填充方法,用于在图像尺寸不是补丁的整数倍时进行填充
    def maybe_pad(self, pixel_values, height, width):
        # 如果宽度不是补丁大小的整数倍,则在宽度方向进行填充
        if width % self.patch_size[1] != 0:
            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        # 如果高度不是补丁大小的整数倍,则在高度方向进行填充
        if height % self.patch_size[0] != 0:
            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        # 返回填充后的像素值张量
        return pixel_values
    # 定义前向传播函数,接受像素值作为输入,返回嵌入向量和输出尺寸元组
    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
        # 获取像素值张量的形状信息,包括通道数、高度和宽度
        _, num_channels, height, width = pixel_values.shape
        
        # 检查通道数是否与配置中设置的通道数相匹配,如果不匹配则抛出错误
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        
        # 如果需要,对输入进行填充,使其能够被 self.patch_size 整除
        pixel_values = self.maybe_pad(pixel_values, height, width)
        
        # 将像素值投影到嵌入空间
        embeddings = self.projection(pixel_values)
        
        # 获取投影后嵌入张量的形状信息,包括通道数、高度和宽度
        _, _, height, width = embeddings.shape
        
        # 计算最终输出的高度和宽度,并存储为元组
        output_dimensions = (height, width)
        
        # 将嵌入张量按第二维展平,并交换第一和第二维的顺序
        embeddings = embeddings.flatten(2).transpose(1, 2)

        # 返回处理后的嵌入向量和输出尺寸元组
        return embeddings, output_dimensions
# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
class MaskFormerSwinPatchMerging(nn.Module):
    """
    Patch Merging Layer.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.input_resolution = input_resolution  # 保存输入特征的分辨率信息
        self.dim = dim  # 输入通道数
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)  # 线性变换层,用于特征维度的变换
        self.norm = norm_layer(4 * dim)  # 标准化层,对输入特征进行标准化处理

    def maybe_pad(self, input_feature, height, width):
        should_pad = (height % 2 == 1) or (width % 2 == 1)
        if should_pad:
            pad_values = (0, 0, 0, width % 2, 0, height % 2)
            input_feature = nn.functional.pad(input_feature, pad_values)
        return input_feature  # 可能对输入特征进行填充操作,使其尺寸符合要求

    def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
        height, width = input_dimensions
        batch_size, dim, num_channels = input_feature.shape

        input_feature = input_feature.view(batch_size, height, width, num_channels)  # 将输入特征重塑为四维张量
        input_feature = self.maybe_pad(input_feature, height, width)  # 可能对输入特征进行填充操作
        input_feature_0 = input_feature[:, 0::2, 0::2, :]  # 提取输入特征的子区块
        input_feature_1 = input_feature[:, 1::2, 0::2, :]  # 提取输入特征的子区块
        input_feature_2 = input_feature[:, 0::2, 1::2, :]  # 提取输入特征的子区块
        input_feature_3 = input_feature[:, 1::2, 1::2, :]  # 提取输入特征的子区块
        input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)  # 按最后一个维度拼接特征
        input_feature = input_feature.view(batch_size, -1, 4 * num_channels)  # 将特征重塑为三维张量

        input_feature = self.norm(input_feature)  # 对特征进行标准化处理
        input_feature = self.reduction(input_feature)  # 对特征进行线性变换

        return input_feature


# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->MaskFormerSwin
class MaskFormerSwinDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob  # 初始化丢弃概率

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)  # 调用外部函数 drop_path 进行随机深度丢弃操作

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)  # 返回描述实例状态的字符串
# 从 transformers.models.swin.modeling_swin.SwinSelfAttention 复制而来,修改为 MaskFormerSwinSelfAttention
class MaskFormerSwinSelfAttention(nn.Module):
    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        # 设置注意力头数和每个头的大小
        self.num_attention_heads = num_heads
        self.attention_head_size = int(dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.window_size = (
            window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
        )

        # 创建相对位置偏置表的可学习参数
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
        )

        # 计算窗口内每个位置对之间的相对位置索引
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        # 定义查询、键、值的线性变换层
        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)

        # 定义 dropout 层
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    # 将输入张量重塑为注意力分数的形状
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    # 前向传播函数,接受隐藏状态、注意力掩码、头部掩码和是否输出注意力分数作为输入
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        ) -> Tuple[torch.Tensor]:  # 定义函数签名,指定返回类型为包含单个张量的元组
        batch_size, dim, num_channels = hidden_states.shape  # 获取隐藏状态的形状信息
        mixed_query_layer = self.query(hidden_states)  # 使用查询函数处理隐藏状态

        key_layer = self.transpose_for_scores(self.key(hidden_states))  # 使用键函数处理隐藏状态并转置
        value_layer = self.transpose_for_scores(self.value(hidden_states))  # 使用值函数处理隐藏状态并转置
        query_layer = self.transpose_for_scores(mixed_query_layer)  # 处理混合查询层并转置

        # 计算注意力分数,即查询和键的点积
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)  # 对注意力分数进行缩放

        # 获取相对位置偏置并调整形状
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        relative_position_bias = relative_position_bias.view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )

        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # 调整相对位置偏置的维度顺序
        attention_scores = attention_scores + relative_position_bias.unsqueeze(0)  # 添加相对位置偏置到注意力分数中

        if attention_mask is not None:
            # 应用预先计算好的注意力掩码(适用于MaskFormerSwinModel forward()函数的所有层)
            mask_shape = attention_mask.shape[0]
            attention_scores = attention_scores.view(
                batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
            )
            attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
            attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)

        # 将注意力分数归一化为注意力概率
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # 使用dropout进行注意力概率的处理
        attention_probs = self.dropout(attention_probs)

        # 如果指定了头部掩码,则应用头部掩码
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)  # 使用注意力概率加权值层得到上下文层
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()  # 调整上下文层的维度顺序
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)  # 调整上下文层的形状

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)  # 返回模型输出

        return outputs  # 返回上下文层和注意力概率的元组
# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin
class MaskFormerSwinSelfOutput(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        # 创建一个全连接层,输入维度为dim,输出维度为dim
        self.dense = nn.Linear(dim, dim)
        # 创建一个dropout层,使用config中指定的dropout概率
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 对输入的hidden_states进行线性变换
        hidden_states = self.dense(hidden_states)
        # 对线性变换后的结果进行dropout
        hidden_states = self.dropout(hidden_states)
        # 返回处理后的hidden_states
        return hidden_states


# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->MaskFormerSwin
class MaskFormerSwinAttention(nn.Module):
    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        # 创建MaskFormerSwinSelfAttention对象
        self.self = MaskFormerSwinSelfAttention(config, dim, num_heads, window_size)
        # 创建MaskFormerSwinSelfOutput对象
        self.output = MaskFormerSwinSelfOutput(config, dim)
        # 初始化一个空集合,用于存储被剪枝的注意力头索引
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        # 查找可剪枝的注意力头及其索引
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # 剪枝线性层
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # 更新超参数并存储被剪枝的头索引
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # 执行自注意力机制,并返回self_outputs
        self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
        # 使用self_outputs[0]和hidden_states作为输入,执行输出层操作
        attention_output = self.output(self_outputs[0], hidden_states)
        # 如果需要输出注意力信息,则将其添加到outputs中
        outputs = (attention_output,) + self_outputs[1:]  # 如果输出注意力信息,将其添加到outputs中
        return outputs


# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->MaskFormerSwin
class MaskFormerSwinIntermediate(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        # 创建一个全连接层,输入维度为dim,输出维度为config.mlp_ratio * dim
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        # 如果config.hidden_act是字符串类型,则使用ACT2FN字典中对应的激活函数,否则直接使用config.hidden_act作为激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
    # 定义一个方法 `forward`,接受一个名为 `hidden_states` 的张量作为输入,并返回一个张量作为输出
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 将输入张量 `hidden_states` 传递给 `self.dense` 层,执行线性变换
        hidden_states = self.dense(hidden_states)
        # 将经过线性变换后的张量 `hidden_states` 应用激活函数 `self.intermediate_act_fn`
        hidden_states = self.intermediate_act_fn(hidden_states)
        # 返回经过激活函数处理后的张量 `hidden_states`
        return hidden_states
# 从 transformers.models.swin.modeling_swin.SwinOutput 复制的类,将 Swin 替换为 MaskFormerSwin
class MaskFormerSwinOutput(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        # 创建一个线性层,输入维度为 config.mlp_ratio * dim,输出维度为 dim
        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
        # 定义一个 Dropout 层,使用 config.hidden_dropout_prob 作为丢弃概率
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 前向传播函数,先通过线性层处理 hidden_states
        hidden_states = self.dense(hidden_states)
        # 然后对处理后的结果进行 Dropout
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class MaskFormerSwinLayer(nn.Module):
    def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
        super().__init__()
        # 初始化 MaskFormerSwinLayer 类,设置一些初始参数
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
        # 添加 LayerNorm 层,对输入进行归一化,eps 参数为 config.layer_norm_eps
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        # 定义 MaskFormerSwinAttention 层,处理注意力相关计算
        self.attention = MaskFormerSwinAttention(config, dim, num_heads, self.window_size)
        # 如果 config.drop_path_rate 大于 0.0,则添加 MaskFormerSwinDropPath 层,否则添加一个恒等映射
        self.drop_path = (
            MaskFormerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
        )
        # 添加 LayerNorm 层,对输入进行归一化,eps 参数为 config.layer_norm_eps
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        # 定义 MaskFormerSwinIntermediate 层,处理中间过渡层的计算
        self.intermediate = MaskFormerSwinIntermediate(config, dim)
        # 定义 MaskFormerSwinOutput 层,处理最终输出层的计算
        self.output = MaskFormerSwinOutput(config, dim)

    def get_attn_mask(self, input_resolution):
        if self.shift_size > 0:
            # 如果 shift_size 大于 0,则计算用于 SW-MSA 的注意力掩码
            height, width = input_resolution
            # 创建一个全零张量作为图像掩码,维度为 (1, height, width, 1)
            img_mask = torch.zeros((1, height, width, 1))
            # 定义高度和宽度的切片区域,根据 window_size 和 shift_size 的值生成不同的切片
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            count = 0
            # 填充图像掩码张量的不同区域
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

            # 将图像掩码切分成窗口,并展平为二维张量
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # 计算注意力掩码,使对角线元素为 0,其余元素分别填充为 -100.0 或 0.0
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask
    # 定义一个方法用于可能的填充操作,用于保证输入张量的高度和宽度能被窗口大小整除
    def maybe_pad(self, hidden_states, height, width):
        # 计算左边和顶部需要填充的像素数,默认为0
        pad_left = pad_top = 0
        # 计算右边需要填充的像素数,确保能够被窗口大小整除
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        # 计算底部需要填充的像素数,确保能够被窗口大小整除
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        # 组装填充的数值,顺序为 (前填充高度, 后填充高度, 左填充宽度, 右填充宽度, 顶部填充高度, 底部填充高度)
        pad_values = (0, 0, pad_left, pad_right, pad_top, pad_bottom)
        # 对隐藏状态张量进行填充操作,使用给定的填充数值
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        # 返回填充后的隐藏状态张量以及填充数值,用于后续可能的反填充操作
        return hidden_states, pad_values
    def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
        # 解构输入维度元组
        height, width = input_dimensions
        # 获取隐藏状态张量的批大小、维度和通道数
        batch_size, dim, channels = hidden_states.size()
        # 保存原始隐藏状态张量
        shortcut = hidden_states

        # Layer normalization 在注意力机制之前应用于隐藏状态张量
        hidden_states = self.layernorm_before(hidden_states)
        # 将隐藏状态张量重新排列为四维张量(批大小、高度、宽度、通道)
        hidden_states = hidden_states.view(batch_size, height, width, channels)
        # 可能需要对隐藏状态张量进行填充,使其大小成为窗口大小的倍数
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

        # 获取填充后张量的维度信息
        _, height_pad, width_pad, _ = hidden_states.shape
        # 如果设置了 cyclic shift
        if self.shift_size > 0:
            # 在指定维度上对隐藏状态进行循环移位
            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_hidden_states = hidden_states

        # 将隐藏状态分割成窗口
        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
        # 将分割后的窗口张量重新视图为二维张量
        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
        # 获取注意力掩码
        attn_mask = self.get_attn_mask((height_pad, width_pad))
        # 如果存在注意力掩码,则将其转移到与隐藏状态窗口相同的设备上
        if attn_mask is not None:
            attn_mask = attn_mask.to(hidden_states_windows.device)

        # 执行自注意力机制
        self_attention_outputs = self.attention(
            hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
        )

        # 获取自注意力机制的输出
        attention_output = self_attention_outputs[0]

        # 如果需要输出注意力权重,则将其添加到输出中
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # 将注意力输出视图为四维张量
        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
        # 反转窗口分割,恢复到原始大小
        shifted_windows = window_reverse(
            attention_windows, self.window_size, height_pad, width_pad
        )  # B height' width' C

        # 如果设置了 cyclic shift,将注意力窗口进行反向循环移位
        if self.shift_size > 0:
            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            attention_windows = shifted_windows

        # 如果存在填充,截取注意力窗口以移除填充部分
        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        # 将注意力窗口重新视图为三维张量
        attention_windows = attention_windows.view(batch_size, height * width, channels)

        # 将原始隐藏状态张量与注意力窗口加上 drop path 结果相加
        hidden_states = shortcut + self.drop_path(attention_windows)

        # 在注意力机制之后应用 layer normalization
        layer_output = self.layernorm_after(hidden_states)
        # 中间层处理
        layer_output = self.intermediate(layer_output)
        # 在隐藏状态上添加输出层结果
        layer_output = hidden_states + self.output(layer_output)

        # 将层输出添加到总体输出中
        outputs = (layer_output,) + outputs

        # 返回所有输出
        return outputs
# 基于 transformers.models.swin.modeling_swin.SwinStage.__init__ 复制而来的 MaskFormerSwinStage 类
class MaskFormerSwinStage(nn.Module):
    # 初始化函数,接收配置参数 config,特征维度 dim,输入分辨率 input_resolution,深度 depth,注意力头数目 num_heads,丢弃路径 drop_path,降采样 downsample
    def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
        super().__init__()
        # 保存传入的配置参数
        self.config = config
        # 保存特征维度
        self.dim = dim
        # 创建包含多个 MaskFormerSwinLayer 模块的模块列表 blocks
        self.blocks = nn.ModuleList(
            [
                MaskFormerSwinLayer(
                    config=config,
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    shift_size=0 if (i % 2 == 0) else config.window_size // 2,
                )
                for i in range(depth)
            ]
        )

        # 如果有降采样函数,创建降采样层 self.downsample,否则设为 None
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
        else:
            self.downsample = None

        # 初始时设置 pointing 属性为 False
        self.pointing = False

    # 前向传播函数,接收隐藏状态 hidden_states,输入维度 input_dimensions,头部掩码 head_mask,是否输出注意力 output_attentions,是否输出隐藏状态 output_hidden_states
    def forward(
        self, hidden_states, input_dimensions, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        # 如果需要输出隐藏状态,则初始化空的元组 all_hidden_states 用于存储所有隐藏状态
        all_hidden_states = () if output_hidden_states else None

        # 获取输入维度的高度和宽度
        height, width = input_dimensions
        # 遍历所有 blocks 中的模块
        for i, block_module in enumerate(self.blocks):
            # 如果需要输出隐藏状态,则将当前隐藏状态添加到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 获取当前层的头部掩码,如果没有传入头部掩码则为 None
            layer_head_mask = head_mask[i] if head_mask is not None else None

            # 调用当前 block_module 的前向传播函数,计算该模块的隐藏状态
            block_hidden_states = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

            # 更新隐藏状态为当前模块计算得到的隐藏状态的第一个元素
            hidden_states = block_hidden_states[0]

            # 如果需要输出隐藏状态,则将当前隐藏状态添加到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        # 如果存在降采样层 self.downsample
        if self.downsample is not None:
            # 计算降采样后的高度和宽度
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            # 计算输出维度,包括原始和降采样后的尺寸
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            # 调用降采样层的前向传播函数,对隐藏状态进行降采样处理
            hidden_states = self.downsample(hidden_states, input_dimensions)
        else:
            # 如果不存在降采样层,则输出维度与输入维度相同
            output_dimensions = (height, width, height, width)

        # 返回最终的隐藏状态、输出维度以及所有的隐藏状态(如果需要输出)
        return hidden_states, output_dimensions, all_hidden_states


# 基于 transformers.models.swin.modeling_swin.SwinEncoder.__init__ 复制而来的 MaskFormerSwinEncoder 类
class MaskFormerSwinEncoder(nn.Module):
    pass  # 这里暂时没有任何代码,仅为占位符,具体实现可能会在后续添加
    # 初始化函数,接受配置和网格大小作为参数
    def __init__(self, config, grid_size):
        # 调用父类的初始化方法
        super().__init__()
        # 计算网络层数
        self.num_layers = len(config.depths)
        # 保存配置信息
        self.config = config
        # 根据配置中的 drop_path_rate 参数生成一个线性空间的列表,转换为 Python 列表类型
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
        # 创建一个 nn.ModuleList,包含多个 MaskFormerSwinStage 模块
        self.layers = nn.ModuleList(
            [
                MaskFormerSwinStage(
                    config=config,
                    # 计算当前层的嵌入维度
                    dim=int(config.embed_dim * 2**i_layer),
                    # 计算当前层的输入分辨率
                    input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
                    # 当前层的深度
                    depth=config.depths[i_layer],
                    # 当前层的注意力头数
                    num_heads=config.num_heads[i_layer],
                    # 当前层的 drop_path 参数,根据当前层的深度切片生成
                    drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
                    # 是否进行下采样,最后一层不进行下采样
                    downsample=MaskFormerSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
                )
                for i_layer in range(self.num_layers)  # 循环创建每一层的 MaskFormerSwinStage 模块
            ]
        )

        # 梯度检查点设置为 False
        self.gradient_checkpointing = False

    # 前向传播函数
    def forward(
        self,
        hidden_states,
        input_dimensions,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        ):
            # 如果不输出隐藏状态,则初始化为空元组;否则设为 None
            all_hidden_states = () if output_hidden_states else None
            # 初始化所有输入维度为空元组
            all_input_dimensions = ()
            # 如果不输出注意力,则初始化为 None
            all_self_attentions = () if output_attentions else None

            # 如果需要输出隐藏状态
            if output_hidden_states:
                # 将当前层的隐藏状态添加到 all_hidden_states 中
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 遍历所有的层,并获取每层的模块和屏蔽头掩码
            for i, layer_module in enumerate(self.layers):
                layer_head_mask = head_mask[i] if head_mask is not None else None

                # 如果启用了梯度检查点且处于训练模式
                if self.gradient_checkpointing and self.training:
                    # 使用梯度检查点函数执行当前层的调用,并获取隐藏状态、输出维度和所有隐藏状态
                    layer_hidden_states, output_dimensions, layer_all_hidden_states = self._gradient_checkpointing_func(
                        layer_module.__call__,
                        hidden_states,
                        layer_head_mask,
                        output_attentions,
                    )
                else:
                    # 否则,直接调用当前层模块,并获取隐藏状态、输出维度和所有隐藏状态
                    layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module(
                        hidden_states,
                        input_dimensions,
                        layer_head_mask,
                        output_attentions,
                        output_hidden_states,
                    )

                # 更新输入维度为当前输出维度的最后两个维度
                input_dimensions = (output_dimensions[-2], output_dimensions[-1])
                # 将当前输入维度添加到 all_input_dimensions 中
                all_input_dimensions += (input_dimensions,)
                # 如果需要输出隐藏状态,则将当前层的所有隐藏状态添加到 all_hidden_states 中
                if output_hidden_states:
                    all_hidden_states += (layer_all_hidden_states,)

                # 更新隐藏状态为当前层的隐藏状态
                hidden_states = layer_hidden_states

                # 如果需要输出注意力,则将当前层的第二个隐藏状态添加到 all_self_attentions 中
                if output_attentions:
                    all_self_attentions = all_self_attentions + (layer_all_hidden_states[1],)

            # 如果不返回字典,则返回所有非空的结果元组
            if not return_dict:
                return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

            # 否则,返回 MaskFormerSwinBaseModelOutput 对象,包含最后的隐藏状态、所有隐藏状态、空间维度和注意力
            return MaskFormerSwinBaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                hidden_states_spatial_dimensions=all_input_dimensions,
                attentions=all_self_attentions,
            )
# 从 transformers.models.swin.modeling_swin.SwinPreTrainedModel 复制代码,修改为 MaskFormerSwinPreTrainedModel,类用于 MaskFormerSwin 模型
class MaskFormerSwinPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 使用 MaskFormerSwinConfig 作为配置类
    config_class = MaskFormerSwinConfig
    # 基础模型的前缀为 "model"
    base_model_prefix = "model"
    # 主输入名称为 "pixel_values"
    main_input_name = "pixel_values"
    # 支持梯度检查点
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # 对于线性层和卷积层,使用正态分布初始化权重,标准差为 self.config.initializer_range
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                # 如果存在偏置项,则将其初始化为零
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # 对于 LayerNorm 层,将偏置项初始化为零,权重初始化为 1.0
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        self.num_layers = len(config.depths)
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        # 初始化 MaskFormerSwin 模型的嵌入层
        self.embeddings = MaskFormerSwinEmbeddings(config)
        # 初始化 MaskFormerSwin 模型的编码器
        self.encoder = MaskFormerSwinEncoder(config, self.embeddings.patch_grid)

        # 初始化层归一化层,输入特征数为 self.num_features,epsilon 为 config.layer_norm_eps
        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        # 如果设置了 add_pooling_layer 为 True,则初始化自适应平均池化层
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

    def get_input_embeddings(self):
        # 返回输入嵌入层的 patch_embeddings
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # 遍历要修剪的头部 heads_to_prune 字典
        for layer, heads in heads_to_prune.items():
            # 在编码器的每一层中修剪指定的注意力头部
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        pixel_values=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        ):
        # 设置输出注意力矩阵选项,默认使用配置文件中的设置
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 设置输出隐藏状态选项,默认使用配置文件中的设置
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 设置是否返回字典的选项,默认使用配置文件中的设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 如果未提供像素值,则抛出数值错误
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # 准备头部掩码(如果需要)
        # 在 head_mask 中为 1.0 表示保留对应的注意力头
        # attention_probs 的形状为 bsz x n_heads x N x N
        # 输入的 head_mask 形状为 [num_heads] 或者 [num_hidden_layers x num_heads]
        # 将 head_mask 转换为形状 [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        # 对像素值进行嵌入操作
        embedding_output, input_dimensions = self.embeddings(pixel_values)

        # 编码器处理阶段
        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 如果 return_dict 为 True,则使用字典返回
        sequence_output = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        # 如果存在池化器,则进行池化操作
        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output.transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        # 如果 return_dict 为 False,则返回元组形式的输出
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        # 计算隐藏状态的空间维度
        hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions

        # 使用 MaskFormerSwinModelOutputWithPooling 类封装返回结果
        return MaskFormerSwinModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            hidden_states_spatial_dimensions=hidden_states_spatial_dimensions,
            attentions=encoder_outputs.attentions,
        )
        # MaskFormerSwinBackbone 类定义,继承自 MaskFormerSwinPreTrainedModel 和 BackboneMixin
        """
        MaskFormerSwin backbone, designed especially for the MaskFormer framework.

        This classes reshapes `hidden_states` from (`batch_size, sequence_length, hidden_size)` to (`batch_size,
        num_channels, height, width)`). It also adds additional layernorms after each stage.

        Args:
            config (`MaskFormerSwinConfig`):
                The configuration used by [`MaskFormerSwinModel`].
        """
        # 初始化方法,接收 MaskFormerSwinConfig 类型的参数 config
        def __init__(self, config: MaskFormerSwinConfig):
            # 调用父类 MaskFormerSwinPreTrainedModel 的初始化方法
            super().__init__(config)
            # 调用父类 BackboneMixin 的初始化方法
            super()._init_backbone(config)

            # 创建 MaskFormerSwinModel 的实例,并赋值给 self.model
            self.model = MaskFormerSwinModel(config)
            # 检查是否在 out_features 中包含 'stem',若包含则抛出 ValueError 异常
            if "stem" in self.out_features:
                raise ValueError("This backbone does not support 'stem' in the `out_features`.")
            
            # 计算特征图的通道数列表,根据 config 中的 embed_dim 和 depths 参数计算
            self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
            
            # 创建包含各层规范化操作的 nn.ModuleList,每层规范化操作的输入通道数对应 num_features 中的后续元素
            self.hidden_states_norms = nn.ModuleList(
                [nn.LayerNorm(num_channels) for num_channels in self.num_features[1:]]
            )

            # 调用 post_init 方法进行权重初始化和最终处理
            self.post_init()

        # 前向传播方法,接收输入 pixel_values 和可选的输出控制参数
        def forward(
            self,
            pixel_values: Tensor,
            output_hidden_states: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            return_dict: Optional[bool] = None,
        ) -> BackboneOutput:
        # 确定是否返回字典类型的结果,若未指定则使用配置中的默认设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 确定是否输出隐藏状态,若未指定则使用配置中的默认设置
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 确定是否输出注意力权重,若未指定则使用配置中的默认设置
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        # 使用模型进行前向传播,指定输出隐藏状态和注意力权重,并以字典类型返回结果
        outputs = self.model(
            pixel_values, output_hidden_states=True, output_attentions=output_attentions, return_dict=True
        )

        # 跳过模型的stem部分,即第一个隐藏状态
        hidden_states = outputs.hidden_states[1:]

        # 将隐藏状态重塑回原始的空间维度
        # 空间维度包含每个阶段的所有高度和宽度,包括嵌入后的维度
        spatial_dimensions: Tuple[Tuple[int, int]] = outputs.hidden_states_spatial_dimensions
        feature_maps = ()
        for i, (hidden_state, stage, (height, width)) in enumerate(
            zip(hidden_states, self.stage_names[1:], spatial_dimensions)
        ):
            norm = self.hidden_states_norms[i]
            # 获取经过最后一个块输出但未经过补丁合并的隐藏状态
            hidden_state_unpolled = hidden_state[-1]
            # 对隐藏状态进行归一化处理
            hidden_state_norm = norm(hidden_state_unpolled)
            # 像素解码器(FPN)需要3D张量(特征)
            batch_size, _, hidden_size = hidden_state_norm.shape
            # 重塑张量形状为 "b (h w) d -> b d h w"
            hidden_state_permuted = (
                hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous()
            )
            if stage in self.out_features:
                feature_maps += (hidden_state_permuted,)

        # 如果不返回字典类型的结果,则构造输出元组
        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            if output_attentions:
                output += (outputs.attentions,)
            return output

        # 返回BackboneOutput对象,包含特征图、隐藏状态(如果输出)、注意力权重(如果输出)
        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )

.\models\maskformer\__init__.py

# 版权声明和许可信息
#
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入类型检查相关的模块
from typing import TYPE_CHECKING

# 导入相关的工具和模块
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available

# 定义导入结构,指定各模块的导入内容
_import_structure = {
    "configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"],
    "configuration_maskformer_swin": ["MaskFormerSwinConfig"],
}

# 检查视觉相关的依赖是否可用,若不可用则抛出OptionalDependencyNotAvailable异常
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 若可用,则导入特征提取和图像处理相关模块
    _import_structure["feature_extraction_maskformer"] = ["MaskFormerFeatureExtractor"]
    _import_structure["image_processing_maskformer"] = ["MaskFormerImageProcessor"]

# 检查Torch相关的依赖是否可用,若不可用则抛出OptionalDependencyNotAvailable异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 若可用,则导入模型相关的模块
    _import_structure["modeling_maskformer"] = [
        "MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "MaskFormerForInstanceSegmentation",
        "MaskFormerModel",
        "MaskFormerPreTrainedModel",
    ]
    _import_structure["modeling_maskformer_swin"] = [
        "MaskFormerSwinBackbone",
        "MaskFormerSwinModel",
        "MaskFormerSwinPreTrainedModel",
    ]

# 如果在类型检查模式下
if TYPE_CHECKING:
    # 导入具体的配置和模型相关内容
    from .configuration_maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig
    from .configuration_maskformer_swin import MaskFormerSwinConfig

    try:
        # 再次检查视觉相关依赖是否可用,若不可用则忽略
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 若可用,则导入特征提取和图像处理相关模块
        from .feature_extraction_maskformer import MaskFormerFeatureExtractor
        from .image_processing_maskformer import MaskFormerImageProcessor

    try:
        # 再次检查Torch相关依赖是否可用,若不可用则忽略
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 若可用,则导入模型相关的模块
        from .modeling_maskformer import (
            MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
            MaskFormerForInstanceSegmentation,
            MaskFormerModel,
            MaskFormerPreTrainedModel,
        )
        from .modeling_maskformer_swin import (
            MaskFormerSwinBackbone,
            MaskFormerSwinModel,
            MaskFormerSwinPreTrainedModel,
        )

# 如果不是类型检查模式,则设置模块为LazyModule
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)

.\models\mbart\configuration_mbart.py

# coding=utf-8
# 上面是指定文件编码为 UTF-8,确保支持多语言字符集
# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
# 版权声明,指出代码版权归 Facebook AI Research Team 和 HuggingFace Inc. 团队所有
#
# Licensed under the Apache License, Version 2.0 (the "License");
# 指明采用 Apache 许可证 2.0 版本
# you may not use this file except in compliance with the License.
# 在符合许可证条件的情况下才能使用该文件
# You may obtain a copy of the License at
# 可以在以下网址获取许可证副本
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 软件在适用法律要求或书面同意的情况下按“原样”分发
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 没有任何明示或暗示的担保或条件
# See the License for the specific language governing permissions and
# limitations under the License.
# 请查看许可证,了解特定语言管理权限和限制

""" MBART model configuration"""
# 导入必要的模块和类
from collections import OrderedDict  # 导入 OrderedDict 类,用于有序字典
from typing import Any, Mapping, Optional  # 导入必要的类型声明,如 Any、Mapping、Optional

from ... import PreTrainedTokenizer  # 导入预训练 Tokenizer
from ...configuration_utils import PretrainedConfig  # 导入预训练配置类
from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast  # 导入 ONNX 相关配置类
from ...onnx.utils import compute_effective_axis_dimension  # 导入计算有效轴维度的函数
from ...utils import TensorType, is_torch_available, logging  # 导入 TensorType、is_torch_available 和 logging

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象

MBART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/mbart-large-cc25": "https://huggingface.co/facebook/mbart-large-cc25/resolve/main/config.json",
    # 预训练模型映射字典,指定 MBART 大型模型的配置文件地址
    # 查看所有 MBART 模型地址 https://huggingface.co/models?filter=mbart
}


class MBartConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MBartModel`]. It is used to instantiate an MBART
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the MBART
    [facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Example:

    ```
    >>> from transformers import MBartConfig, MBartModel

    >>> # Initializing a MBART facebook/mbart-large-cc25 style configuration
    >>> configuration = MBartConfig()

    >>> # Initializing a model (with random weights) from the facebook/mbart-large-cc25 style configuration
    >>> model = MBartModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    # MBART 配置类,用于存储 MBART 模型的配置信息

    model_type = "mbart"  # 模型类型为 mbart
    keys_to_ignore_at_inference = ["past_key_values"]  # 推断时忽略的键名列表,包含 "past_key_values"
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
    # 属性映射字典,将 num_attention_heads 映射为 encoder_attention_heads,hidden_size 映射为 d_model
    # 定义一个初始化方法,用于初始化Transformer模型的参数和配置
    def __init__(
        self,
        vocab_size=50265,                          # 词汇表大小,默认为50265
        max_position_embeddings=1024,              # 最大位置嵌入长度,默认为1024
        encoder_layers=12,                         # 编码器层数,默认为12层
        encoder_ffn_dim=4096,                      # 编码器中FFN层的维度,默认为4096
        encoder_attention_heads=16,                # 编码器中注意力头的数量,默认为16个
        decoder_layers=12,                         # 解码器层数,默认为12层
        decoder_ffn_dim=4096,                      # 解码器中FFN层的维度,默认为4096
        decoder_attention_heads=16,                # 解码器中注意力头的数量,默认为16个
        encoder_layerdrop=0.0,                     # 编码器层的层丢弃率,默认为0.0(不丢弃)
        decoder_layerdrop=0.0,                     # 解码器层的层丢弃率,默认为0.0(不丢弃)
        use_cache=True,                            # 是否使用缓存,默认为True
        is_encoder_decoder=True,                   # 是否是编码-解码结构,默认为True
        activation_function="gelu",                # 激活函数,默认为GELU
        d_model=1024,                              # 模型维度,默认为1024
        dropout=0.1,                               # 全局Dropout率,默认为0.1
        attention_dropout=0.0,                     # 注意力Dropout率,默认为0.0
        activation_dropout=0.0,                    # 激活函数Dropout率,默认为0.0
        init_std=0.02,                             # 权重初始化标准差,默认为0.02
        classifier_dropout=0.0,                    # 分类器Dropout率,默认为0.0
        scale_embedding=False,                     # 是否缩放嵌入,默认为False
        pad_token_id=1,                            # 填充token的ID,默认为1
        bos_token_id=0,                            # 起始token的ID,默认为0
        eos_token_id=2,                            # 终止token的ID,默认为2
        forced_eos_token_id=2,                     # 强制终止token的ID,默认为2
        **kwargs,                                  # 其他参数,作为关键字参数传递
    ):
        self.vocab_size = vocab_size                # 初始化词汇表大小
        self.max_position_embeddings = max_position_embeddings  # 初始化最大位置嵌入长度
        self.d_model = d_model                      # 初始化模型维度
        self.encoder_ffn_dim = encoder_ffn_dim      # 初始化编码器中FFN层的维度
        self.encoder_layers = encoder_layers        # 初始化编码器层数
        self.encoder_attention_heads = encoder_attention_heads  # 初始化编码器中注意力头的数量
        self.decoder_ffn_dim = decoder_ffn_dim      # 初始化解码器中FFN层的维度
        self.decoder_layers = decoder_layers        # 初始化解码器层数
        self.decoder_attention_heads = decoder_attention_heads  # 初始化解码器中注意力头的数量
        self.dropout = dropout                      # 初始化全局Dropout率
        self.attention_dropout = attention_dropout  # 初始化注意力Dropout率
        self.activation_dropout = activation_dropout  # 初始化激活函数Dropout率
        self.activation_function = activation_function  # 初始化激活函数类型
        self.init_std = init_std                    # 初始化权重初始化标准差
        self.encoder_layerdrop = encoder_layerdrop  # 初始化编码器层的层丢弃率
        self.decoder_layerdrop = decoder_layerdrop  # 初始化解码器层的层丢弃率
        self.classifier_dropout = classifier_dropout  # 初始化分类器Dropout率
        self.use_cache = use_cache                  # 初始化是否使用缓存
        self.num_hidden_layers = encoder_layers     # 初始化隐藏层的数量,与编码器层数相同
        self.scale_embedding = scale_embedding      # 初始化是否缩放嵌入的标志
        super().__init__(                            # 调用父类的初始化方法
            pad_token_id=pad_token_id,               # 传递填充token的ID
            bos_token_id=bos_token_id,               # 传递起始token的ID
            eos_token_id=eos_token_id,               # 传递终止token的ID
            is_encoder_decoder=is_encoder_decoder,   # 传递是否是编码-解码结构的标志
            forced_eos_token_id=forced_eos_token_id, # 传递强制终止token的ID
            **kwargs                                 # 传递其他关键字参数
        )
# 从Bart配置类BartOnnxConfig复制,并将Bart改为MBart
class MBartOnnxConfig(OnnxSeq2SeqConfigWithPast):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # 如果任务是"default"或"seq2seq-lm",则设置通用输入字典
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )

            # 如果使用过去信息,则设置解码器的输入ID和注意力掩码
            if self.use_past:
                common_inputs["decoder_input_ids"] = {0: "batch"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
            else:
                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}

            # 如果使用过去信息,则填充输入中的过去键值对
            if self.use_past:
                self.fill_with_past_key_values_(common_inputs, direction="inputs")
        elif self.task == "causal-lm":
            # TODO: 需要处理这种情况。
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )
            # 如果使用过去信息,则为每个编码器层设置过去键和值
            if self.use_past:
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        else:
            # 否则设置完整的输入字典,包括解码器相关信息
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
                ]
            )

        return common_inputs

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        # 如果任务是"default"或"seq2seq-lm",则获取默认的输出字典
        if self.task in ["default", "seq2seq-lm"]:
            common_outputs = super().outputs
        else:
            # 否则调用父类的输出方法获取输出字典,并为每个编码器层设置当前键和值
            common_outputs = super(OnnxConfigWithPast, self).outputs
            if self.use_past:
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        return common_outputs
    # 定义一个方法 `_generate_dummy_inputs_for_default_and_seq2seq_lm`,用于生成默认和序列到序列语言模型的虚拟输入数据
    def _generate_dummy_inputs_for_default_and_seq2seq_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, seq_length, is_pair, framework
        )
        # 生成编码器的输入数据,用于序列分类和问答任务的虚拟输入
        # 根据参数生成编码器的输入数据,包括tokenization对象、批量大小、序列长度、是否为成对输入、框架类型

        # Generate decoder inputs
        decoder_seq_length = seq_length if not self.use_past else 1
        # 计算解码器的序列长度,若使用过去状态则设为1,否则设为与编码器相同的序列长度
        decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, decoder_seq_length, is_pair, framework
        )
        # 生成解码器的输入数据,用于序列分类和问答任务的虚拟输入,根据参数生成解码器的输入数据

        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
        # 将解码器输入数据的键名修改为带有前缀"decoder_"的形式

        common_inputs = dict(**encoder_inputs, **decoder_inputs)
        # 将编码器和解码器的输入数据合并成一个字典,作为公共输入数据

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            # 检查是否使用过去状态,并验证是否安装了PyTorch

            batch, encoder_seq_length = common_inputs["input_ids"].shape
            # 获取批量大小和编码器序列长度

            decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
            # 获取解码器的序列长度

            num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
            # 获取编码器和解码器的注意力头数目

            encoder_shape = (
                batch,
                num_encoder_attention_heads,
                encoder_seq_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )
            # 计算编码器的形状

            decoder_past_length = decoder_seq_length + 3
            # 计算解码器的过去状态长度

            decoder_shape = (
                batch,
                num_decoder_attention_heads,
                decoder_past_length,
                self._config.hidden_size // num_decoder_attention_heads,
            )
            # 计算解码器的形状

            common_inputs["decoder_attention_mask"] = torch.cat(
                [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
            )
            # 将解码器的注意力遮罩扩展到包括过去状态长度的维度

            common_inputs["past_key_values"] = []
            # 初始化过去键值列表

            # If the number of encoder and decoder layers are present in the model configuration, both are considered
            num_encoder_layers, num_decoder_layers = self.num_layers
            # 获取编码器和解码器的层数

            min_num_layers = min(num_encoder_layers, num_decoder_layers)
            # 计算最小层数

            max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
            # 计算最大层数

            remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
            # 根据层数的差异确定剩余的一方是编码器还是解码器

            for _ in range(min_num_layers):
                common_inputs["past_key_values"].append(
                    (
                        torch.zeros(decoder_shape),
                        torch.zeros(decoder_shape),
                        torch.zeros(encoder_shape),
                        torch.zeros(encoder_shape),
                    )
                )
            # 为每一层编码器和解码器生成零张量,并添加到过去键值列表中

            # TODO: test this.
            shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
            # 根据剩余一方的名称确定形状

            for _ in range(min_num_layers, max_num_layers):
                common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
            # 为剩余层数生成零张量,并添加到过去键值列表中

        return common_inputs
    # 生成用于因果语言模型的虚拟输入数据集,返回一个映射字典
    def _generate_dummy_inputs_for_causal_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # 调用另一个方法生成用于序列分类和问答的虚拟输入数据集
        common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        if self.use_past:
            # 检查是否需要使用过去键值(past_key_values)
            if not is_torch_available():
                # 如果没有安装 PyTorch,抛出异常
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            batch, seqlen = common_inputs["input_ids"].shape
            # 计算过去键值的长度,不使用与输入相同的长度
            past_key_values_length = seqlen + 2
            # 获取编码器层和注意力头的数量
            num_encoder_layers, _ = self.num_layers
            num_encoder_attention_heads, _ = self.num_attention_heads
            # 定义过去键值的形状
            past_shape = (
                batch,
                num_encoder_attention_heads,
                past_key_values_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )

            # 获取注意力掩码的数据类型
            mask_dtype = common_inputs["attention_mask"].dtype
            # 扩展现有的注意力掩码,增加过去键值的长度
            common_inputs["attention_mask"] = torch.cat(
                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
            )
            # 初始化过去键值列表
            common_inputs["past_key_values"] = [
                (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
            ]
        # 返回生成的输入数据集字典
        return common_inputs

    # 生成用于序列分类和问答的虚拟输入数据集
    def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # 从 OnnxConfig.generate_dummy_inputs 复制此方法
        # 为了代码清晰性,没有使用 super(OnnxConfigWithPast, self).generate_dummy_inputs
        # 计算有效的轴维度,以避免 ONNX 的优化影响,固定样本维度为2个样本
        batch_size = compute_effective_axis_dimension(
            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
        )

        # 计算要添加的特殊标记的数量,并计算有效的序列维度,固定令牌维度为8个令牌
        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
        seq_length = compute_effective_axis_dimension(
            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
        )

        # 根据计算的批次和序列生成虚拟输入数据
        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
        # 使用 tokenizer 将虚拟输入转换为张量并返回作为字典
        common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
        return common_inputs
    # 生成虚拟输入数据的方法,返回一个包含各种任务通用输入的字典
    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # 如果任务是"default"或"seq2seq-lm",调用特定方法生成对应任务的虚拟输入
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )
        # 如果任务是"causal-lm",调用特定方法生成对应任务的虚拟输入
        elif self.task == "causal-lm":
            common_inputs = self._generate_dummy_inputs_for_causal_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )
        # 对于其他任务,调用特定方法生成适用于序列分类和问答的虚拟输入
        else:
            common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )

        # 返回生成的通用输入字典
        return common_inputs

    # 根据任务类型调用不同的方法来扁平化过去的键值对
    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
        # 如果任务是"default"或"seq2seq-lm",调用父类方法来处理
        if self.task in ["default", "seq2seq-lm"]:
            flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
        # 对于其他任务,调用继承类"OnnxSeq2SeqConfigWithPast"的方法来处理
        else:
            flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
                flattened_output, name, idx, t
            )

.\models\mbart\convert_mbart_original_checkpoint_to_pytorch.py

# 导入必要的库和模块
import argparse  # 用于解析命令行参数

import torch  # PyTorch库,用于机器学习和深度学习任务
from torch import nn  # PyTorch的神经网络模块

from transformers import MBartConfig, MBartForConditionalGeneration  # Hugging Face Transformers库中的MBart配置和生成模型


def remove_ignore_keys_(state_dict):
    # 定义要从状态字典中移除的键列表
    ignore_keys = [
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "_float_tensor",
        "decoder.output_projection.weight",
    ]
    # 遍历并从状态字典中移除指定的键
    for k in ignore_keys:
        state_dict.pop(k, None)


def make_linear_from_emb(emb):
    # 从嵌入层创建线性层
    vocab_size, emb_size = emb.weight.shape  # 获取嵌入层的词汇量和嵌入维度大小
    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)  # 创建一个线性层,没有偏置项
    lin_layer.weight.data = emb.weight.data  # 将嵌入层的权重数据复制到线性层的权重中
    return lin_layer  # 返回创建的线性层


def convert_fairseq_mbart_checkpoint_from_disk(
    checkpoint_path, hf_config_path="facebook/mbart-large-en-ro", finetuned=False, mbart_50=False
):
    # 从磁盘加载Fairseq MBart模型的检查点并转换为Hugging Face MBart模型
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]  # 加载检查点的状态字典
    remove_ignore_keys_(state_dict)  # 移除状态字典中指定的键

    vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]  # 获取编码器嵌入词汇表大小

    # 根据预训练配置路径加载MBart配置
    mbart_config = MBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size)
    if mbart_50 and finetuned:
        mbart_config.activation_function = "relu"  # 如果是50层MBart并且是微调模型,则设置激活函数为ReLU

    # 将decoder的嵌入权重设为shared.weight
    state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]

    # 创建MBart条件生成模型
    model = MBartForConditionalGeneration(mbart_config)
    model.model.load_state_dict(state_dict)  # 加载状态字典到模型中

    if finetuned:
        model.lm_head = make_linear_from_emb(model.model.shared)  # 如果是微调模型,则使用make_linear_from_emb创建lm_head

    return model  # 返回转换后的Hugging Face MBart模型


if __name__ == "__main__":
    parser = argparse.ArgumentParser()  # 创建参数解析器

    # 添加必需的参数
    parser.add_argument(
        "fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
    )
    parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")

    # 添加可选参数
    parser.add_argument(
        "--hf_config",
        default="facebook/mbart-large-cc25",
        type=str,
        help="Which huggingface architecture to use: mbart-large",
    )
    parser.add_argument("--mbart_50", action="store_true", help="whether the model is mMART-50 checkpoint")
    parser.add_argument("--finetuned", action="store_true", help="whether the model is a fine-tuned checkpoint")

    args = parser.parse_args()  # 解析命令行参数
    model = convert_fairseq_mbart_checkpoint_from_disk(
        args.fairseq_path, hf_config_path=args.hf_config, finetuned=args.finetuned, mbart_50=args.mbart_50
    )
    # 调用模型对象的save_pretrained方法,将模型保存到指定的路径args.pytorch_dump_folder_path中
    model.save_pretrained(args.pytorch_dump_folder_path)

.\models\mbart\modeling_flax_mbart.py

# coding=utf-8
# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax MBart model."""

# 导入必要的库和模块
import math
import random
from functools import partial
from typing import Callable, Optional, Tuple

import flax.linen as nn  # 导入 Flax 的 linen 模块作为 nn 别名
import jax  # 导入 JAX 库
import jax.numpy as jnp  # 导入 JAX 的 NumPy 实现作为 jnp 别名
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # 导入 Flax 的 FrozenDict 和相关函数
from flax.linen import combine_masks, make_causal_mask  # 导入 Flax 的 combine_masks 和 make_causal_mask 函数
from flax.linen.attention import dot_product_attention_weights  # 导入 Flax 的 dot_product_attention_weights 函数
from flax.traverse_util import flatten_dict, unflatten_dict  # 导入 Flax 的 flatten_dict 和 unflatten_dict 函数
from jax import lax  # 导入 JAX 的 lax 模块
from jax.random import PRNGKey  # 导入 JAX 的 PRNGKey 类

# 导入模型输出相关的类
from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
    FlaxSeq2SeqModelOutput,
    FlaxSeq2SeqQuestionAnsweringModelOutput,
    FlaxSeq2SeqSequenceClassifierOutput,
)

# 导入模型工具函数和类
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)

# 导入工具函数和类
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings

# 导入 MBart 的配置类
from .configuration_mbart import MBartConfig

# 获取 logger 对象
logger = logging.get_logger(__name__)

# 用于文档的预训练模型检查点和配置信息
_CHECKPOINT_FOR_DOC = "facebook/mbart-large-cc25"
_CONFIG_FOR_DOC = "MBartConfig"

# MBart 模型的起始文档字符串
MBART_START_DOCSTRING = r"""
    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
    Parameters:
        config ([`MBartConfig`]): Model configuration class with all the parameters of the model.
            初始化模型配置类,包含模型的所有参数。
            使用配置文件初始化不会加载模型的权重,只加载配置。可以查看 [`~FlaxPreTrainedModel.from_pretrained`] 方法加载模型权重。

        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            计算的数据类型。可以是 `jax.numpy.float32`、`jax.numpy.float16`(在GPU上)、`jax.numpy.bfloat16`(在TPU上)之一。

            可用于在GPU或TPU上启用混合精度训练或半精度推断。如果指定了dtype,则所有计算将使用给定的 `dtype`。

            **注意,这仅指定计算的dtype,并不影响模型参数的dtype。**

            如果要更改模型参数的dtype,请参阅 [`~FlaxPreTrainedModel.to_fp16`] 和 [`~FlaxPreTrainedModel.to_bf16`]。
"""
MBART_INPUTS_DOCSTRING = r"""
"""


MBART_ENCODE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            输入序列标记在词汇表中的索引。默认情况下,将忽略填充部分。

            可以使用 [`AutoTokenizer`] 获取这些索引。详见 [`PreTrainedTokenizer.encode`] 和 [`PreTrainedTokenizer.__call__`]。

            [什么是输入 ID?](../glossary#input-ids)
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            避免在填充标记索引上执行注意力操作的掩码。掩码值选择在 `[0, 1]`:

            - 1 表示 **未被掩码** 的标记,
            - 0 表示 **被掩码** 的标记。

            [什么是注意力掩码?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            每个输入序列标记在位置嵌入中的位置索引。选择范围是 `[0, config.max_position_embeddings - 1]`。
        output_attentions (`bool`, *optional*):
            是否返回所有注意力层的注意力张量。查看返回的张量下的 `attentions` 获取更多细节。
        output_hidden_states (`bool`, *optional*):
            是否返回所有层的隐藏状态。查看返回的张量下的 `hidden_states` 获取更多细节。
        return_dict (`bool`, *optional*):
            是否返回 [`~utils.ModelOutput`] 而不是普通的元组。
"""

MBART_DECODE_INPUTS_DOCSTRING = r"""
"""


def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray:
    """
    将输入 ID 向右移动一个标记,并包装最后一个非填充标记(<LID> 标记)。注意,与其他类似 Bart 模型不同,MBart 没有单一的 `decoder_start_token_id`。
    """
    prev_output_tokens = jnp.array(input_ids).copy()

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    # 用 `pad_token_id` 替换标签中可能的 -100 值
    prev_output_tokens = jnp.where(prev_output_tokens == -100, pad_token_id, input_ids)
    index_of_eos = (jnp.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
    decoder_start_tokens = jnp.array(
        [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=jnp.int32
    ).squeeze()

    prev_output_tokens = prev_output_tokens.at[:, 1:].set(prev_output_tokens[:, :-1])
    prev_output_tokens = prev_output_tokens.at[:, 0].set(decoder_start_tokens)

    return prev_output_tokens


# 从 transformers.models.bart.modeling_flax_bart.FlaxBartAttention 复制,将 Bart 改为 MBart
class FlaxMBartAttention(nn.Module):
    config: MBartConfig
    embed_dim: int
    num_heads: int
    # 定义 dropout 参数,默认为 0.0
    dropout: float = 0.0
    # 定义 causal 参数,默认为 False,表示是否使用因果注意力
    causal: bool = False
    # 定义 bias 参数,默认为 True,表示是否使用偏置
    bias: bool = True
    # 定义 dtype 参数,默认为 jnp.float32,表示计算中使用的数据类型

    def setup(self) -> None:
        # 计算每个注意力头的维度
        self.head_dim = self.embed_dim // self.num_heads
        # 检查 embed_dim 必须能被 num_heads 整除,否则抛出 ValueError 异常
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        # 部分应用 nn.Dense 函数,创建对应的全连接层,并使用指定的参数
        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        # 创建 q_proj, k_proj, v_proj 和 out_proj 四个全连接层
        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        # 创建 dropout 层,用于随机丢弃输入元素以防止过拟合
        self.dropout_layer = nn.Dropout(rate=self.dropout)

        # 如果设置了 causal=True,创建因果注意力掩码,限制注意力只能关注之前的位置
        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    # 将隐藏状态按注意力头拆分
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    # 将拆分后的注意力头合并回原始维度
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    # 使用 nn.compact 装饰器,标志着这是一个 JAX 用来定义层的函数
    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slighly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # 检测是否通过变量"cache"中的"cached_key"来初始化缓存数据。
        is_initialized = self.has_variable("cache", "cached_key")
        # 获取或初始化缓存的键和值,若不存在则创建并初始化为零矩阵,形状和数据类型与输入的key和value相同。
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        # 获取或初始化缓存的索引,若不存在则创建并初始化为0。
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            # 获取缓存的维度信息,包括批处理维度、最大长度、注意力头数和每头深度。
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # 使用新的1D空间切片更新键和值缓存。
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            # 更新缓存中的键和值。
            cached_key.value = key
            cached_value.value = value
            # 更新缓存索引,增加已更新的缓存向量数。
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # 生成用于缓存的因果掩码:单个查询位置仅应关注已生成并缓存的键位置,而不是剩余的零元素。
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            # 结合输入的注意力掩码和生成的填充掩码。
            attention_mask = combine_masks(pad_mask, attention_mask)
        # 返回更新后的键、值和注意力掩码。
        return key, value, attention_mask
# MBart 编码器层定义,继承自 nn.Module 类
class FlaxMBartEncoderLayer(nn.Module):
    # MBart 配置参数
    config: MBartConfig
    # 计算时的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 初始化方法
    def setup(self) -> None:
        # 设置嵌入维度为配置中的模型维度
        self.embed_dim = self.config.d_model
        # 创建 MBart 自注意力层对象
        self.self_attn = FlaxMBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        # 自注意力层后的 LayerNorm 层
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # 随机失活层
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        # 激活函数
        self.activation_fn = ACT2FN[self.config.activation_function]
        # 激活函数后的随机失活层
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
        # 第一个全连接层
        self.fc1 = nn.Dense(
            self.config.encoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # 第二个全连接层,最终输出维度与嵌入维度相同
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        # 最终的 LayerNorm 层
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    # 对象调用方法,执行编码器层的前向传播
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        # 保存残差连接
        residual = hidden_states
        # 应用自注意力层的 LayerNorm
        hidden_states = self.self_attn_layer_norm(hidden_states)
        # 执行自注意力计算,返回计算结果和注意力权重
        hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
        # 应用随机失活层
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # 添加残差连接
        hidden_states = residual + hidden_states

        # 保存残差连接
        residual = hidden_states
        # 应用最终的 LayerNorm
        hidden_states = self.final_layer_norm(hidden_states)
        # 应用激活函数和第一个全连接层
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        # 应用激活函数后的随机失活层
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        # 应用第二个全连接层
        hidden_states = self.fc2(hidden_states)
        # 应用最终的随机失活层
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # 添加残差连接
        hidden_states = residual + hidden_states

        # 构建输出元组,包含最终的隐藏状态
        outputs = (hidden_states,)

        # 如果需要输出注意力权重,添加到输出元组中
        if output_attentions:
            outputs += (attn_weights,)

        return outputs


# 从 transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection 复制并修改为 MBart
class FlaxMBartEncoderLayerCollection(nn.Module):
    # MBart 配置参数
    config: MBartConfig
    # 计算时的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32  # 计算的数据类型

    # 初始化方法
    def setup(self):
        # 创建多层 MBart 编码器层列表
        self.layers = [
            FlaxMBartEncoderLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.encoder_layers)
        ]
        # 编码器层的层级丢弃率
        self.layerdrop = self.config.encoder_layerdrop
    # 定义一个调用方法,用于执行模型的前向传播
    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 初始化存储所有注意力权重的变量,如果不需要输出注意力权重则置为 None
        all_attentions = () if output_attentions else None
        # 初始化存储所有隐藏状态的变量,如果不需要输出隐藏状态则置为 None
        all_hidden_states = () if output_hidden_states else None

        # 遍历每个编码器层
        for encoder_layer in self.layers:
            # 如果需要输出隐藏状态,则将当前隐藏状态添加到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # 添加层级丢弃功能,参见论文 https://arxiv.org/abs/1909.11556
            dropout_probability = random.uniform(0, 1)
            # 如果不是确定性执行且随机数小于层级丢弃率,则跳过当前层
            if not deterministic and (dropout_probability < self.layerdrop):
                layer_outputs = (None, None)  # 表示跳过当前层,输出为空
            else:
                # 否则,执行当前编码器层的前向传播,获取输出
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                    deterministic,
                )
            # 更新隐藏状态为当前层的输出的第一个元素
            hidden_states = layer_outputs[0]
            # 如果需要输出注意力权重,则将当前层的注意力权重添加到 all_attentions 中
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # 如果需要输出隐藏状态,则将最终的隐藏状态添加到 all_hidden_states 中
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 汇总模型输出,包括最终隐藏状态、所有隐藏状态和所有注意力权重
        outputs = (hidden_states, all_hidden_states, all_attentions)

        # 如果不需要以字典形式返回结果,则返回一个元组,排除 None 的部分
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        # 否则,以 FlaxBaseModelOutput 对象形式返回结果,包括最终隐藏状态、所有隐藏状态和所有注意力权重
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
class FlaxMBartDecoderLayer(nn.Module):
    # 定义类变量 config,类型为 MBartConfig,用于存储配置信息
    config: MBartConfig
    # 定义类变量 dtype,默认为 jnp.float32,指定数据类型为 32 位浮点数

    def setup(self) -> None:
        # 初始化函数,设置层的参数和模型结构

        # 将 self.embed_dim 设置为 self.config.d_model,表示嵌入维度等于模型维度
        self.embed_dim = self.config.d_model

        # 初始化 self.self_attn 为 FlaxMBartAttention 类实例,用于自注意力机制
        self.self_attn = FlaxMBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            dtype=self.dtype,
        )

        # 初始化 self.dropout_layer 为 Dropout 层,用于随机失活
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        # 根据配置选择激活函数,并初始化 self.activation_fn
        self.activation_fn = ACT2FN[self.config.activation_function]

        # 初始化 self.activation_dropout_layer 为 Dropout 层,用于激活函数的随机失活
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)

        # 初始化 self.self_attn_layer_norm 为 LayerNorm 层,用于自注意力机制后的归一化
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

        # 初始化 self.encoder_attn 为 FlaxMBartAttention 类实例,用于编码器-解码器注意力机制
        self.encoder_attn = FlaxMBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )

        # 初始化 self.encoder_attn_layer_norm 为 LayerNorm 层,用于编码器-解码器注意力机制后的归一化
        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

        # 初始化 self.fc1 为 Dense(全连接)层,用于第一个前馈神经网络(FFN)层
        self.fc1 = nn.Dense(
            self.config.decoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        # 初始化 self.fc2 为 Dense 层,用于第二个前馈神经网络(FFN)层
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )

        # 初始化 self.final_layer_norm 为 LayerNorm 层,用于最终输出的归一化
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
        # 定义 __call__ 方法,用于模型调用时的前向传播
    ) -> Tuple[jnp.ndarray]:  
        # 将输入的隐藏状态作为残差保存
        residual = hidden_states  
        # 对当前的隐藏状态进行自注意力层归一化处理
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # 自注意力机制
        # 调用自注意力层处理隐藏状态,返回处理后的新隐藏状态和自注意力权重
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
        )
        # 对自注意力层输出的隐藏状态进行 dropout 处理
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # 将残差与处理后的隐藏状态相加,形成新的隐藏状态
        hidden_states = residual + hidden_states

        # 跨注意力块
        cross_attn_weights = None
        # 如果存在编码器的隐藏状态
        if encoder_hidden_states is not None:
            # 将当前的隐藏状态作为残差保存
            residual = hidden_states

            # 对当前的隐藏状态进行编码器注意力层归一化处理
            hidden_states = self.encoder_attn_layer_norm(hidden_states)
            # 调用编码器注意力层处理隐藏状态,返回处理后的新隐藏状态和编码器注意力权重
            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            # 对编码器注意力层输出的隐藏状态进行 dropout 处理
            hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
            # 将残差与处理后的隐藏状态相加,形成新的隐藏状态
            hidden_states = residual + hidden_states

        # 全连接层处理
        # 将当前的隐藏状态作为残差保存
        residual = hidden_states
        # 对当前的隐藏状态进行最终归一化处理
        hidden_states = self.final_layer_norm(hidden_states)
        # 应用激活函数到全连接层的第一层
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        # 对全连接层的第一层输出进行 dropout 处理
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        # 继续传播到全连接层的第二层
        hidden_states = self.fc2(hidden_states)
        # 对全连接层的第二层输出进行 dropout 处理
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # 将残差与处理后的隐藏状态相加,形成新的隐藏状态
        hidden_states = residual + hidden_states

        # 准备输出
        outputs = (hidden_states,)

        # 如果需要输出注意力权重信息,则将自注意力权重和编码器注意力权重添加到输出中
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        # 返回输出结果
        return outputs
# 从 transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection 复制代码并将 Bart 改为 MBart
class FlaxMBartDecoderLayerCollection(nn.Module):
    config: MBartConfig  # 类型注解,指定 MBartConfig 类型的 config 变量
    dtype: jnp.dtype = jnp.float32  # 计算过程中使用的数据类型,默认为 jnp.float32

    def setup(self):
        # 创建 self.layers 列表,包含 self.config.decoder_layers 个 FlaxMBartDecoderLayer 对象
        self.layers = [
            FlaxMBartDecoderLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.decoder_layers)
        ]
        self.layerdrop = self.config.decoder_layerdrop  # 设置 layerdrop 参数为 config 中的 decoder_layerdrop 值

    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # decoder layers
        all_hidden_states = () if output_hidden_states else None  # 如果不输出 hidden_states,则设置为 None
        all_self_attns = () if output_attentions else None  # 如果不输出 self-attention,则设置为 None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None  # 如果不输出 cross-attention 或者没有 encoder_hidden_states,则设置为 None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)  # 将当前 hidden_states 添加到 all_hidden_states 中
                # 添加 LayerDrop(参见 https://arxiv.org/abs/1909.11556 进行描述)

            dropout_probability = random.uniform(0, 1)  # 随机生成一个 dropout 概率
            if not deterministic and (dropout_probability < self.layerdrop):
                layer_outputs = (None, None, None)  # 如果未指定 deterministic 或者 dropout 概率小于 layerdrop,则输出为 None
            else:
                # 调用 decoder_layer 进行解码层计算
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    init_cache=init_cache,
                    output_attentions=output_attentions,
                    deterministic=deterministic,
                )

            hidden_states = layer_outputs[0]  # 更新 hidden_states 为当前层的输出第一个元素
            if output_attentions:
                all_self_attns += (layer_outputs[1],)  # 将当前层的 self-attention 输出添加到 all_self_attns 中

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)  # 如果存在 encoder_hidden_states,则将 cross-attention 输出添加到 all_cross_attentions 中

        # 添加来自最后解码层的 hidden states
        if output_hidden_states:
            all_hidden_states += (hidden_states,)  # 将最后一个 hidden_states 添加到 all_hidden_states 中

        # 构建输出列表
        outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)  # 如果不返回字典,则返回非 None 的元组

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
    """Head for sentence-level classification tasks."""
    
    # 定义一个类用于处理句子级别的分类任务,以下是其成员变量和方法的定义和说明

    config: MBartConfig
    # 用于存储配置信息的变量,类型为 MBartConfig 类型的对象

    inner_dim: int
    # 用于存储中间维度大小的整数变量,表示神经网络中间层的维度

    num_classes: int
    # 用于存储分类类别数量的整数变量,表示分类任务的输出类别数目

    pooler_dropout: float
    # 用于存储池化层的dropout率的浮点数变量,控制神经网络在训练中的丢弃比例

    dtype: jnp.dtype = jnp.float32
    # 数据类型,默认为 jax 的浮点数类型 jnp.float32

    def setup(self):
        # 初始化方法,用于设置类的各个成员变量

        self.dense = nn.Dense(
            self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        # 创建一个全连接层对象 self.dense,设置输入维度为 inner_dim,数据类型为 dtype,并使用正态分布初始化权重

        self.dropout = nn.Dropout(rate=self.pooler_dropout)
        # 创建一个 dropout 层对象 self.dropout,设置丢弃率为 pooler_dropout

        self.out_proj = nn.Dense(
            self.num_classes,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # 创建一个输出投影层对象 self.out_proj,设置输出维度为 num_classes,数据类型为 dtype,并使用正态分布初始化权重

    def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
        # 类的调用方法,用于实现类的前向传播过程

        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # 对输入的 hidden_states 应用 dropout 操作,根据 deterministic 参数决定是否使用确定性丢弃

        hidden_states = self.dense(hidden_states)
        # 将经过 dropout 后的 hidden_states 输入到全连接层 self.dense 中进行线性变换

        hidden_states = jnp.tanh(hidden_states)
        # 对全连接层的输出应用双曲正切激活函数

        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # 对激活后的 hidden_states 再次应用 dropout 操作

        hidden_states = self.out_proj(hidden_states)
        # 将经过 dropout 的 hidden_states 输入到输出投影层 self.out_proj 中进行线性变换,得到最终的分类结果

        return hidden_states
        # 返回处理后的结果 hidden_states
# 定义一个名为 FlaxMBartEncoder 的类,继承自 nn.Module,用于 MBart 编码器模型
class FlaxMBartEncoder(nn.Module):
    # 类属性:MBart 的配置对象
    config: MBartConfig
    # 类属性:嵌入层对象,用于输入的词嵌入
    embed_tokens: nn.Embed
    # 类属性:计算中使用的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32  # 计算时使用的数据类型

    # 定义初始化方法 setup()
    def setup(self):
        # 初始化 dropout 层,根据配置中的 dropout 率
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        # 获取嵌入维度大小
        embed_dim = self.config.d_model
        # 获取填充符号的索引
        self.padding_idx = self.config.pad_token_id
        # 获取源序列的最大位置编码长度
        self.max_source_positions = self.config.max_position_embeddings
        # 根据配置设置嵌入的缩放因子
        self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

        # MBart 的特殊设置:如果指定了 padding_idx,则将嵌入的 id 偏移 2,并相应地调整 num_embeddings。其他模型不需要这个处理
        self.offset = 2
        # 初始化位置嵌入层
        self.embed_positions = nn.Embed(
            self.config.max_position_embeddings + self.offset,  # 位置嵌入层的大小,考虑了偏移量
            embed_dim,  # 嵌入的维度大小
            embedding_init=jax.nn.initializers.normal(self.config.init_std),  # 初始化方法,使用正态分布
        )
        # 初始化编码器层集合
        self.layers = FlaxMBartEncoderLayerCollection(self.config, self.dtype)
        # 初始化嵌入层的 LayerNorm
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # 初始化输出层的 LayerNorm
        self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    # 定义调用方法 __call__(),用于执行编码器的前向传播
    def __call__(
        self,
        input_ids,  # 输入的 token ids
        attention_mask,  # 注意力掩码,用于指示哪些位置是填充的
        position_ids,  # 位置 ids
        output_attentions: bool = False,  # 是否输出注意力权重
        output_hidden_states: bool = False,  # 是否输出隐藏状态
        return_dict: bool = True,  # 是否以字典形式返回结果
        deterministic: bool = True,  # 是否使用确定性计算
    ):
        # 获取输入的形状信息
        input_shape = input_ids.shape
        # 将输入 ids 展平为二维张量
        input_ids = input_ids.reshape(-1, input_shape[-1])

        # 使用嵌入 tokens 和缩放因子来嵌入输入的 ids
        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        # 根据位置 ids 和偏移量获取位置嵌入
        embed_pos = self.embed_positions(position_ids + self.offset)

        # 将输入嵌入和位置嵌入相加得到隐藏状态
        hidden_states = inputs_embeds + embed_pos
        # 对输入嵌入的 LayerNorm 处理
        hidden_states = self.layernorm_embedding(hidden_states)
        # 应用 dropout
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        # 调用编码器层集合的前向传播
        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取编码器层集合的最后隐藏状态
        last_hidden_states = outputs[0]
        # 对最后隐藏状态进行 LayerNorm 处理
        last_hidden_states = self.layer_norm(last_hidden_states)

        # 如果需要输出隐藏状态,则更新 `hidden_states` 中的最后一个元素,应用上面的 `layernorm`
        hidden_states = None
        if output_hidden_states:
            hidden_states = outputs[1]
            hidden_states = hidden_states[:-1] + (last_hidden_states,)

        # 如果不以字典形式返回结果,则将结果组合成元组返回
        if not return_dict:
            outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
            return tuple(v for v in outputs if v is not None)

        # 以 FlaxBaseModelOutput 对象的形式返回结果
        return FlaxBaseModelOutput(
            last_hidden_state=last_hidden_states,
            hidden_states=hidden_states,
            attentions=outputs.attentions,
        )


# 定义一个名为 FlaxMBartDecoder 的类,目前还未完整给出,继承自 nn.Module
class FlaxMBartDecoder(nn.Module):
    # 类属性:MBart 的配置对象
    config: MBartConfig
    # 类属性:嵌入层对象,用于输入的词嵌入
    embed_tokens: nn.Embed
    # 设置默认数据类型为 jnp.float32,用于计算过程中的数据类型
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    # 初始化函数,在对象创建时调用,用于设置各种属性和参数
    def setup(self):
        # 初始化一个丢弃层,根据配置中的丢弃率设置
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        # 获取嵌入维度,从配置中获取填充索引、最大目标位置和嵌入缩放因子
        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
        self.max_target_positions = self.config.max_position_embeddings
        self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

        # 如果是 MBart 模型,根据填充索引偏移量设置嵌入位置
        # 其他模型不需要此偏移量的调整
        self.offset = 2
        self.embed_positions = nn.Embed(
            self.config.max_position_embeddings + self.offset,  # 设置嵌入的最大位置数量
            embed_dim,  # 嵌入的维度
            embedding_init=jax.nn.initializers.normal(self.config.init_std),  # 使用正态分布初始化嵌入
        )

        # 初始化多层解码器层集合
        self.layers = FlaxMBartDecoderLayerCollection(self.config, self.dtype)
        # 初始化嵌入层的 LayerNorm
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # 初始化通用的 LayerNorm
        self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    # 对象调用函数,实现模型的前向传播
    def __call__(
        self,
        input_ids,  # 输入的 token id
        attention_mask,  # 注意力掩码
        position_ids,  # 位置 id
        encoder_hidden_states: Optional[jnp.ndarray] = None,  # 编码器的隐藏状态(可选)
        encoder_attention_mask: Optional[jnp.ndarray] = None,  # 编码器的注意力掩码(可选)
        init_cache: bool = False,  # 是否初始化缓存(默认为 False)
        output_attentions: bool = False,  # 是否输出注意力权重(默认为 False)
        output_hidden_states: bool = False,  # 是否输出隐藏状态(默认为 False)
        return_dict: bool = True,  # 是否以字典形式返回结果(默认为 True)
        deterministic: bool = True,  # 是否确定性计算(默认为 True)
        ):
            # 获取输入张量的形状
            input_shape = input_ids.shape
            # 将输入张量重塑为二维张量
            input_ids = input_ids.reshape(-1, input_shape[-1])

            # 使用模型的嵌入层对输入张量进行嵌入,并乘以嵌入缩放因子
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

            # 嵌入位置信息
            positions = self.embed_positions(position_ids + self.offset)

            # 将输入嵌入和位置嵌入相加
            hidden_states = inputs_embeds + positions
            # 应用嵌入层归一化
            hidden_states = self.layernorm_embedding(hidden_states)

            # 对隐藏状态应用 dropout
            hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

            # 将隐藏状态传递给层堆栈进行处理
            outputs = self.layers(
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            # 获取输出中的最后隐藏状态
            last_hidden_states = outputs[0]
            # 对最后隐藏状态应用层归一化
            last_hidden_states = self.layer_norm(last_hidden_states)

            # 如果需要输出隐藏状态,更新 `hidden_states` 中的最后一个元素
            hidden_states = None
            if output_hidden_states:
                hidden_states = outputs[1]
                hidden_states = hidden_states[:-1] + (last_hidden_states,)

            # 如果不返回字典形式的结果,构建输出元组
            if not return_dict:
                outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
                return tuple(v for v in outputs if v is not None)

            # 返回带有过去和交叉注意力的 FlaxBaseModelOutputWithPastAndCrossAttentions 对象
            return FlaxBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=last_hidden_states,
                hidden_states=hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->MBart
# 定义了 FlaxMBartModule 类,继承自 nn.Module
class FlaxMBartModule(nn.Module):
    # 类属性 config,指定为 MBartConfig 类型
    config: MBartConfig
    # 类属性 dtype,指定为 jnp.float32,用于计算的数据类型

    # 初始化方法 setup,用于设置模块内部的各个组件
    def setup(self):
        # 创建一个共享的嵌入层 nn.Embed 对象
        self.shared = nn.Embed(
            self.config.vocab_size,  # 嵌入层的词汇表大小
            self.config.d_model,      # 嵌入的维度大小
            embedding_init=jax.nn.initializers.normal(self.config.init_std),  # 嵌入层参数初始化方法
            dtype=self.dtype,         # 嵌入层的数据类型
        )

        # 创建 MBartEncoder 对象,用于编码输入数据
        self.encoder = FlaxMBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
        # 创建 MBartDecoder 对象,用于解码器的解码过程
        self.decoder = FlaxMBartDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

    # 获取编码器模块的方法
    def _get_encoder_module(self):
        return self.encoder

    # 获取解码器模块的方法
    def _get_decoder_module(self):
        return self.decoder

    # 实现 __call__ 方法,定义了模型的调用过程
    def __call__(
        self,
        input_ids,                 # 输入的编码器输入 ID
        attention_mask,            # 编码器的注意力掩码
        decoder_input_ids,         # 解码器的输入 ID
        decoder_attention_mask,    # 解码器的注意力掩码
        position_ids,              # 位置 ID
        decoder_position_ids,      # 解码器的位置 ID
        output_attentions: bool = False,         # 是否输出注意力权重
        output_hidden_states: bool = False,      # 是否输出隐藏状态
        return_dict: bool = True,                # 是否以字典形式返回结果
        deterministic: bool = True,              # 是否确定性计算结果
    ):
        # 编码器的前向传播过程,返回编码器的输出结果
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        # 解码器的前向传播过程,返回解码器的输出结果
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],  # 使用编码器的隐藏状态作为输入
            encoder_attention_mask=attention_mask,     # 使用编码器的注意力掩码
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        # 如果 return_dict 为 False,则将解码器和编码器的输出结果连接起来返回
        if not return_dict:
            return decoder_outputs + encoder_outputs

        # 如果 return_dict 为 True,则返回 FlaxSeq2SeqModelOutput 对象,包含了完整的模型输出信息
        return FlaxSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


# 定义 FlaxMBartPreTrainedModel 类,继承自 FlaxPreTrainedModel
class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
    # 类属性 config_class,指定为 MBartConfig 类
    config_class = MBartConfig
    # 类属性 base_model_prefix,指定为字符串 "model"
    base_model_prefix: str = "model"
    # 类属性 module_class,默认为 None,用于指定模型的主模块

    # 初始化方法,用于创建 FlaxMBartPreTrainedModel 对象
    def __init__(
        self,
        config: MBartConfig,                  # MBart 模型的配置
        input_shape: Tuple[int] = (1, 1),     # 输入形状,默认为 (1, 1)
        seed: int = 0,                        # 随机种子,默认为 0
        dtype: jnp.dtype = jnp.float32,       # 计算数据类型,默认为 jnp.float32
        _do_init: bool = True,                # 是否初始化,默认为 True
        **kwargs,                             # 其他关键字参数
    ):
        # 调用父类的初始化方法,初始化模型的基本配置
        super().__init__(config=config, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init, **kwargs)
    ):
        # 使用给定的配置和参数实例化模块对象
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # 调用父类初始化方法,传递配置、模块对象、输入形状、种子、数据类型以及是否初始化的标志
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    # 初始化模型权重
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 初始化输入张量,全零张量,数据类型为'i4'
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # 确保初始化可以为FlaxMBartForSequenceClassificationModule正常工作
        input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
        # 创建注意力遮罩,与input_ids形状相同,全为1
        attention_mask = jnp.ones_like(input_ids)
        # 初始化decoder输入为input_ids
        decoder_input_ids = input_ids
        # decoder的注意力遮罩与input_ids形状相同,全为1
        decoder_attention_mask = jnp.ones_like(input_ids)

        # 获取批次大小和序列长度
        batch_size, sequence_length = input_ids.shape
        # 生成位置编码,广播形状为(batch_size, sequence_length)
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
        # decoder的位置编码与position_ids相同
        decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # 划分随机数生成器,用于参数和dropout
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # 使用初始化方法初始化模型,返回参数字典
        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            position_ids,
            decoder_position_ids,
        )["params"]

        # 如果提供了params,则用随机参数填充缺失的键
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            # 冻结并返回填充后的参数字典
            return freeze(unflatten_dict(params))
        else:
            # 否则,返回随机生成的参数字典
            return random_params

    # 从transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache复制,替换Bart为MBart
    # 初始化缓存以支持快速自回归解码。
    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
            batch_size (`int`):
                用于快速自回归解码的批大小。定义初始化缓存时的批大小。
            max_length (`int`):
                自回归解码的最大可能长度。定义初始化缓存时的序列长度。
            encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
                `encoder_outputs` 包括 (`last_hidden_state`, *可选*: `hidden_states`, *可选*: `attentions`)。
                `last_hidden_state` 的形状为 `(batch_size, sequence_length, hidden_size)`,*可选* 是编码器最后一层的隐藏状态的序列,
                用于解码器的交叉注意力。
        """
        # 初始化用于检索缓存的输入变量
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
        )

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            # 获取解码器模块
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )

        # 使用给定的输入初始化模型变量,其中 `method` 指定了仅需调用解码器来初始化缓存
        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,
        )
        # 返回冻结的缓存部分
        return unfreeze(init_variables["cache"])

    @add_start_docstrings(MBART_ENCODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=MBartConfig)
    # 编码方法,根据输入的参数编码输入序列
    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```
        >>> from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration

        >>> model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax")
        >>> encoder_outputs = model.encode(**inputs)
        ```
        """
        # Determine whether to output attentions based on input or default configuration
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # Determine whether to output hidden states based on input or default configuration
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Determine whether to return a dictionary of outputs based on input or default configuration
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # If attention mask is not provided, create a mask where all elements are 1
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        # If position IDs are not provided, create a broadcasted array from 0 to sequence length
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Handle any pseudo-random number generators needed for dropout
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # Define a nested function to forward input through the encoder module
        def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(input_ids, attention_mask, position_ids, **kwargs)

        # Apply the Flax module with specified parameters and inputs
        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    @add_start_docstrings(MBART_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=MBartConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        """
        # Function definition continues in the next segment
    # 定义一个特殊方法,使得对象可以像函数一样被调用
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # 如果未指定输出注意力机制,则使用配置中的默认设置
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 如果未指定输出隐藏状态,则使用配置中的默认设置
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果未指定返回字典形式,则使用配置中的默认设置
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # 准备编码器的输入
        if attention_mask is None:
            # 如果未提供注意力遮罩,则创建一个全为1的遮罩,形状与输入的input_ids相同
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            # 如果未提供位置编码,则根据输入的input_ids的形状自动创建位置编码
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # 准备解码器的输入
        if decoder_input_ids is None:
            # 如果未提供解码器输入的token_ids,则根据编码器的输入右移一位,同时使用配置中的pad_token_id填充
            decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
        if decoder_attention_mask is None:
            # 如果未提供解码器的注意力遮罩,则创建一个全为1的遮罩,形状与decoder_input_ids相同
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            # 如果未提供解码器的位置编码,则根据decoder_input_ids的形状自动创建位置编码
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # 处理可能需要的任何伪随机数生成器
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        # 调用self.module的apply方法,传递参数和输入,执行模型计算
        return self.module.apply(
            {"params": params or self.params},  # 参数字典,如果params为None则使用self.params
            input_ids=jnp.array(input_ids, dtype="i4"),  # 将input_ids转换为jnp的整型数组
            attention_mask=jnp.array(attention_mask, dtype="i4"),  # 将attention_mask转换为jnp的整型数组
            position_ids=jnp.array(position_ids, dtype="i4"),  # 将position_ids转换为jnp的整型数组
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),  # 将decoder_input_ids转换为jnp的整型数组
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),  # 将decoder_attention_mask转换为jnp的整型数组
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),  # 将decoder_position_ids转换为jnp的整型数组
            output_attentions=output_attentions,  # 输出注意力权重的标志
            output_hidden_states=output_hidden_states,  # 输出隐藏状态的标志
            return_dict=return_dict,  # 返回字典的标志
            deterministic=not train,  # 是否确定性计算,如果为False则进行随机dropout
            rngs=rngs,  # 伪随机数生成器字典
        )
# 为 FlaxMBartModel 类添加文档字符串,描述其作用为在 MBart 模型上输出原始隐藏状态而无需特定的顶部头部。
@add_start_docstrings(
    "The bare MBart Model transformer outputting raw hidden-states without any specific head on top.",
    MBART_START_DOCSTRING,
)
# 声明 FlaxMBartModel 类,继承自 FlaxMBartPreTrainedModel
class FlaxMBartModel(FlaxMBartPreTrainedModel):
    # 使用 MBartConfig 作为配置
    config: MBartConfig
    # 计算中使用的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32
    # 模块的类别设定为 FlaxMBartModule
    module_class = FlaxMBartModule

# 调用函数 append_call_sample_docstring,为 FlaxMBartModel 类添加调用示例的文档字符串
append_call_sample_docstring(FlaxMBartModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)

# 从 transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule 复制并修改为 FlaxMBartForConditionalGenerationModule 类
class FlaxMBartForConditionalGenerationModule(nn.Module):
    # 使用 MBartConfig 作为配置
    config: MBartConfig
    # 计算中使用的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32
    # 偏置初始化函数,使用 jax.nn.initializers.zeros
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    # 模块的设置方法
    def setup(self):
        # 创建 FlaxMBartModule 模块,使用给定的配置和数据类型
        self.model = FlaxMBartModule(config=self.config, dtype=self.dtype)
        # 创建 lm_head 密集层,输出维度为 self.model.shared.num_embeddings,不使用偏置,使用指定的初始化器
        self.lm_head = nn.Dense(
            self.model.shared.num_embeddings,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # 创建 final_logits_bias 参数,维度为 (1, self.model.shared.num_embeddings),使用偏置初始化器
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))

    # 获取编码器模块的方法
    def _get_encoder_module(self):
        return self.model.encoder

    # 获取解码器模块的方法
    def _get_decoder_module(self):
        return self.model.decoder

    # 类的调用方法,定义了模型的前向传播逻辑和相关参数
    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        # 调用模型进行前向推断,获取模型输出
        outputs = self.model(
            input_ids=input_ids,  # 输入的编码器输入 ID
            attention_mask=attention_mask,  # 编码器的注意力遮罩
            decoder_input_ids=decoder_input_ids,  # 解码器的输入 ID
            decoder_attention_mask=decoder_attention_mask,  # 解码器的注意力遮罩
            position_ids=position_ids,  # 位置 ID,用于编码器
            decoder_position_ids=decoder_position_ids,  # 解码器的位置 ID
            output_attentions=output_attentions,  # 是否输出注意力权重
            output_hidden_states=output_hidden_states,  # 是否输出隐藏状态
            return_dict=return_dict,  # 是否返回字典格式的输出
            deterministic=deterministic,  # 是否确定性运行(不随机性质)
        )

        # 获取模型的隐藏状态作为 LM 的 logits
        hidden_states = outputs[0]

        # 如果配置要求共享词嵌入
        if self.config.tie_word_embeddings:
            # 获取共享的嵌入矩阵
            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
            # 应用共享的嵌入矩阵作为 LM 头的核心参数
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            # 使用原始的 LM 头计算 logits
            lm_logits = self.lm_head(hidden_states)

        # 将最终 logits 加上偏置项(停止梯度传播)
        lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))

        # 如果不要求返回字典格式的输出
        if not return_dict:
            # 组装输出,包括 logits 和其他模型输出
            output = (lm_logits,) + outputs[1:]
            return output

        # 返回 FlaxSeq2SeqLMOutput 类型的输出
        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
@add_start_docstrings(
    "The MMBart Model with a language modeling head. Can be used for summarization.", MBART_START_DOCSTRING
)
class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel):
    module_class = FlaxMBartForConditionalGenerationModule
    dtype: jnp.dtype = jnp.float32

    @add_start_docstrings(MBART_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=MBartConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """
        Performs decoding with the model.

        Args:
            decoder_input_ids: Input IDs for the decoder.
            encoder_outputs: Outputs from the encoder.
            encoder_attention_mask: Optional attention mask for the encoder outputs.
            decoder_attention_mask: Optional attention mask for the decoder inputs.
            decoder_position_ids: Optional position IDs for the decoder inputs.
            past_key_values: Cached key values for efficient generation.
            output_attentions: Whether to output attentions.
            output_hidden_states: Whether to output hidden states.
            return_dict: Whether to return a dictionary or a tuple.
            train: Whether in training mode.
            params: Optional parameters.
            dropout_rng: Dropout random number generator key.

        Returns:
            Model output with cross attentions.

        """
        # Function body omitted for brevity as it is straightforward with provided docstrings.

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        Prepares inputs for generation.

        Args:
            decoder_input_ids: Input IDs for the decoder.
            max_length: Maximum length for generation.
            attention_mask: Optional attention mask for the encoder outputs.
            decoder_attention_mask: Optional attention mask for the decoder inputs.
            encoder_outputs: Outputs from the encoder.
            **kwargs: Additional keyword arguments.

        Returns:
            Dictionary with prepared inputs for generation.

        """
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        # Initialize past key values for efficient generation
        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)

        # Create an extended attention mask
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            # Adjust position IDs based on decoder_attention_mask
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            # Use default position IDs if decoder_attention_mask is not provided
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """
        Updates inputs for generation.

        Args:
            model_outputs: Model outputs from the generation.
            model_kwargs: Original model keyword arguments.

        Returns:
            Updated model keyword arguments.

        """
        # Update past_key_values and decoder_position_ids for next generation step
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs


FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING = r"""
    Returns:

    Summarization example:

    ```
    >>> from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration, MBartConfig
    >>> model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
    
    从预训练的MBart模型和tokenizer中加载Facebook的mbart-large-cc25模型和标记器。
    
    
    >>> ARTICLE_TO_SUMMARIZE = "Meine Freunde sind cool, aber sie essen zu viel Kuchen."
    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="np")
    
    定义要进行摘要的文章,并使用tokenizer将其转换为模型所需的输入格式。
    
    
    >>> # Generate Summary
    >>> summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=5).sequences
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
    
    使用模型生成文章的摘要,指定生成4个束(beam),最大长度为5,然后解码生成的摘要并打印。
    
    
    >>> from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration
    
    >>> model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
    
    再次加载MBart模型和标记器,确保环境准备好用于示例。
    
    
    >>> # de_DE is the language symbol id <LID> for German
    >>> TXT = "</s> Meine Freunde sind <mask> nett aber sie essen zu viel Kuchen. </s> de_DE"
    >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="np")["input_ids"]
    
    定义一个包含掩码填充的例子,`TXT`包含一个掩码标记`<mask>`,表示需要填充的位置。将`TXT`编码为模型可接受的输入格式。
    
    
    >>> logits = model(input_ids).logits
    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
    >>> probs = logits[0, masked_index].softmax(dim=0)
    >>> values, predictions = probs.topk(5)
    
    使用模型预测掩码位置的概率分布,并选择最高的五个概率值。
    
    
    >>> tokenizer.decode(predictions).split()
    
    将预测的结果解码为文本序列,并分割为单词列表。
"""

# 调用函数`overwrite_call_docstring`,用于重写模型类的文档字符串
overwrite_call_docstring(
    FlaxMBartForConditionalGeneration, MBART_INPUTS_DOCSTRING + FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING
)
# 调用函数`append_replace_return_docstrings`,用于追加或替换模型类的返回值文档字符串
append_replace_return_docstrings(
    FlaxMBartForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)


# 从`transformers.models.bart.modeling_flax_bart.FlaxBartForSequenceClassificationModule`复制代码,将Bart改为MBart
class FlaxMBartForSequenceClassificationModule(nn.Module):
    config: MBartConfig  # 定义MBart配置
    dtype: jnp.dtype = jnp.float32  # 设置数据类型为32位浮点数
    num_labels: Optional[int] = None  # 可选的标签数量

    def setup(self):
        # 初始化MBart模型和分类头部
        self.model = FlaxMBartModule(config=self.config, dtype=self.dtype)
        self.classification_head = FlaxMBartClassificationHead(
            config=self.config,
            inner_dim=self.config.d_model,
            num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels,
            pooler_dropout=self.config.classifier_dropout,
        )

    def _get_encoder_module(self):
        # 获取编码器模块
        return self.model.encoder

    def _get_decoder_module(self):
        # 获取解码器模块
        return self.model.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
        # 定义模型调用接口,接受多个输入参数和控制参数

        # 返回字典格式的结果,控制是否返回注意力权重和隐藏状态
        return self.model(
            input_ids,
            attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )
        ):
            # 调用模型进行推理
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                position_ids=position_ids,
                decoder_position_ids=decoder_position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                deterministic=deterministic,
            )

            # 获取模型输出的最后一个隐藏状态
            hidden_states = outputs[0]  # 最后一个隐藏状态

            # 创建一个掩码,用于标记输入中的结束符(<eos>)
            eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)

            # 处理 JAX 编译时的类型错误
            if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer:
                # 检查所有样本是否具有相同数量的 <eos> 标记
                if len(jnp.unique(eos_mask.sum(1))) > 1:
                    raise ValueError("所有示例必须具有相同数量的 <eos> 标记。")

                # 检查输入中是否有缺失的 <eos> 标记
                if any(eos_mask.sum(1) == 0):
                    raise ValueError("输入中缺少 <eos> 标记。")

                # 确保每个示例只保留最后一个 <eos> 标记
                eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6
                eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0)

            # 使用 <eos> 标记计算句子表示
            sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1)

            # 使用分类头部对句子表示进行分类预测
            logits = self.classification_head(sentence_representation, deterministic=deterministic)

            # 如果不返回字典,则返回元组
            if not return_dict:
                output = (logits,) + outputs[1:]
                return output

            # 如果返回字典,则返回 FlaxSeq2SeqSequenceClassifierOutput 类的实例
            return FlaxSeq2SeqSequenceClassifierOutput(
                logits=logits,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )
@add_start_docstrings(
    """
    MBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
    tasks.
    """,
    MBART_START_DOCSTRING,
)
"""
使用`add_start_docstrings`装饰器为`FlaxMBartForSequenceClassification`类添加文档字符串,描述其作为带有顶部序列分类/头部的MBart模型。
"""

class FlaxMBartForSequenceClassification(FlaxMBartPreTrainedModel):
    """
    MBart序列分类模型,继承自`FlaxMBartPreTrainedModel`。
    """
    module_class = FlaxMBartForSequenceClassificationModule
    dtype = jnp.float32

append_call_sample_docstring(
    FlaxMBartForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSeq2SeqSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)
"""
使用`append_call_sample_docstring`函数为`FlaxMBartForSequenceClassification`类添加示例调用文档字符串。
"""

# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForQuestionAnsweringModule with Bart->MBart
"""
从`transformers.models.bart.modeling_flax_bart.FlaxBartForQuestionAnsweringModule`复制代码,并将Bart替换为MBart。
"""

class FlaxMBartForQuestionAnsweringModule(nn.Module):
    """
    MBart问答模块定义,继承自`nn.Module`。
    """
    config: MBartConfig
    dtype: jnp.dtype = jnp.float32
    num_labels = 2

    def setup(self):
        """
        设置方法,初始化模型和输出层。
        """
        self.model = FlaxMBartModule(config=self.config, dtype=self.dtype)
        self.qa_outputs = nn.Dense(
            self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )

    def _get_encoder_module(self):
        """
        获取编码器模块的私有方法。
        """
        return self.model.encoder

    def _get_decoder_module(self):
        """
        获取解码器模块的私有方法。
        """
        return self.model.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        模型调用方法,接受多个输入和参数,返回包含多个输出的字典或元组。

        Args:
            input_ids: 输入的编码器输入id。
            attention_mask: 编码器的注意力掩码。
            decoder_input_ids: 解码器输入id。
            decoder_attention_mask: 解码器的注意力掩码。
            position_ids: 输入的位置id。
            decoder_position_ids: 解码器的位置id。
            output_attentions: 是否输出注意力权重。
            output_hidden_states: 是否输出隐藏状态。
            return_dict: 是否以字典形式返回输出。
            deterministic: 是否确定性计算。

        Returns:
            根据return_dict返回不同结构的输出。
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = jnp.split(logits, logits.shape[-1], axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return output

        return FlaxSeq2SeqQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
    MBart Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    MBART_START_DOCSTRING,



# MBart 模型,使用顶部的跨度分类头部用于抽取式问答任务,如 SQuAD(在隐藏状态输出之上的线性层,用于计算“起始跨度对数”和“结束跨度对数”)。
MBart Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
"""
# 导入 MBART_START_DOCSTRING,可能是一个模型文档字符串的起始标记或常量
MBART_START_DOCSTRING,
)
# 在此处代码似乎存在语法错误,可能是由于括号未正确闭合引起的问题
class FlaxMBartForQuestionAnswering(FlaxMBartPreTrainedModel):
    # 将模块类指定为 FlaxMBartForQuestionAnsweringModule
    module_class = FlaxMBartForQuestionAnsweringModule
    # 指定数据类型为 jnp.float32
    dtype = jnp.float32


# 向 FlaxMBartForQuestionAnswering 类附加一个调用样本文档字符串的函数
append_call_sample_docstring(
    FlaxMBartForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxSeq2SeqQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)

.\models\mbart\modeling_mbart.py

# 设置文件编码为 UTF-8
# 版权声明和许可信息,指明代码版权归 Facebook AI Research Team 和 HuggingFace Inc. 团队所有,使用 Apache License, Version 2.0 授权
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面同意,否则不得使用此文件中的代码
""" PyTorch MBART 模型定义 """
# 导入必要的库和模块
import copy  # 导入深拷贝功能
import math  # 导入数学函数
from typing import List, Optional, Tuple, Union  # 引入类型提示

import torch  # 导入 PyTorch 库
import torch.nn.functional as F  # 导入 PyTorch 的函数库
import torch.utils.checkpoint  # 导入 PyTorch 的检查点功能
from torch import nn  # 从 PyTorch 中导入神经网络模块
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss  # 导入损失函数

# 从本地或者上层模块导入所需的工具函数和类
from ...activations import ACT2FN  # 导入激活函数映射表
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask  # 导入注意力掩码处理工具函数
from ...modeling_outputs import (  # 导入模型输出类
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel  # 导入预训练模型基类
from ...utils import (  # 导入工具函数和类
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_mbart import MBartConfig  # 从当前模块导入 MBART 配置类

# 如果支持 Flash Attention 2.0,导入相关函数和模块
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

# 获取日志记录器
logger = logging.get_logger(__name__)

# 模型文档中使用的检查点名称
_CHECKPOINT_FOR_DOC = "facebook/mbart-large-cc25"
# 模型文档中使用的配置名称
_CONFIG_FOR_DOC = "MBartConfig"

# 期望的输出形状
_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]

# MBART 预训练模型的存档列表
MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/mbart-large-cc25",
    # 查看所有 MBART 模型列表:https://huggingface.co/models?filter=mbart
]

# 从 transformers.models.llama.modeling_llama._get_unpad_data 复制的函数,用于获取未填充数据
def _get_unpad_data(attention_mask):
    # 计算每个样本序列的长度
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # 找到非填充位置的索引
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # 获取批次中最大的序列长度
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # 计算累积序列长度
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )

# 将输入的 ID 向右移动一个位置,用于生成输入序列的右移版本,用于 MBart 模型的输入处理
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
    """
    将输入的 ID 向右移动一个位置,并包装最后一个非填充标记(即 <LID> 标记)。需要注意的是,与其他类似 Bart 的模型不同,MBart 没有单一的 `decoder_start_token_id`。
    """
    # 复制输入的 token 序列作为输出的初始 token 序列
    prev_output_tokens = input_ids.clone()

    # 如果未定义 pad_token_id,则抛出数值错误异常
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    
    # 将 labels 中可能存在的值为 -100 的部分替换为 pad_token_id
    prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)

    # 找到每个样本中最后一个非 pad_token_id 的位置,形成一个索引张量
    index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)

    # 根据 index_of_eos 中的索引,获取每个样本中的 decoder 起始 token
    decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()

    # 将 prev_output_tokens 中每个样本的 token 序列整体左移一位
    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
    
    # 将 prev_output_tokens 中每个样本的第一个 token 替换为 decoder_start_tokens
    prev_output_tokens[:, 0] = decoder_start_tokens

    # 返回处理后的 prev_output_tokens,即新的输出 token 序列
    return prev_output_tokens
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
class MBartLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        # Call the constructor of nn.Embedding with adjusted num_embeddings
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids' shape is expected to be [bsz x seqlen]."""

        # Extract batch size and sequence length from input_ids tensor
        bsz, seq_len = input_ids.shape[:2]
        # Generate positions tensor starting from past_key_values_length up to past_key_values_length + seq_len
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        ).expand(bsz, -1)

        # Return the positional embeddings by adding self.offset to positions
        return super().forward(positions + self.offset)


# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MBart
class MBartAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[MBartConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        # Check if embed_dim is divisible by num_heads
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # Scaling factor for dot product attention
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # Linear projections for key, value, query and output
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape tensor to [batch_size, num_heads, seq_len, head_dim]
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        ):
        # Forward pass through the multi-headed attention mechanism
        # ...
        pass

# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MBart
class MBartFlashAttention2(MBartAttention):
    """
    Placeholder class for future extension or modification.
    """
    pass
    # MBart flash attention module. This module inherits from `MBartAttention` as the weights of the module stays
    # untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    # flash attention and deal with padding tokens in case the input contains any of them.

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    # 初始化函数,继承自父类,初始化模块。
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        # 控制属性,用于处理 Flash Attention 不同版本之间的差异。
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    # Reshape 操作,将张量重新排列成指定形状。
    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    # Flash Attention 的前向传播函数,处理查询、键、值、注意力掩码等参数。
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        # Determine if causal attention is needed based on current settings and conditions
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # Temporary workaround for Flash Attention on RoCm platform; check not needed after version 2.1
            causal = self.is_causal and query_length != 1

        # Check if there are any padding tokens in the input sequence
        if attention_mask is not None:
            # Get batch size from query_states tensor
            batch_size = query_states.shape[0]
            
            # Unpad input tensors based on attention_mask
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            # Unpacked variables from cu_seq_lens and max_seq_lens
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            # Perform Flash Attention with variable length support
            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            # Pad the attention output to match the original input sequence length
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            # Perform regular Flash Attention without padding
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        # Return the computed attention output
        return attn_output
    # 定义一个私有方法 `_upad_input`,用于处理注意力机制的输入数据
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        # 获取非填充数据的索引、当前序列长度及批次中的最大序列长度
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        
        # 获取 key_layer 的形状信息
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        # 根据 indices_k 重新索引并重新组织 key_layer 和 value_layer
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )

        # 根据 query_length 的不同情况处理 query_layer
        if query_length == kv_seq_len:
            # 如果 query_length 等于 kv_seq_len,则按 indices_k 重新索引 query_layer
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # 如果 query_length 等于 1,则处理单个 query 的情况
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # 这里有一个 memcpy,非常糟糕。
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # 否则,假设存在左填充,根据 query_length 和 attention_mask 进行处理
            # 注意:这里的 -query_length 切片假设存在左填充。
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        # 返回处理后的 query_layer、key_layer、value_layer,以及相关的索引和长度信息
        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
MBART_ATTENTION_CLASSES = {
    "eager": MBartAttention,  # 定义一个字典,将字符串映射到对应的注意力机制类
    "flash_attention_2": MBartFlashAttention2,
}

class MBartEncoderLayer(nn.Module):
    def __init__(self, config: MBartConfig):
        super().__init__()
        self.embed_dim = config.d_model  # 从配置中获取嵌入维度大小

        # 初始化自注意力层,根据配置选择不同的注意力机制类
        self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )

        # 初始化自注意力层的 LayerNorm
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # 获取配置中的 dropout 概率
        self.dropout = config.dropout

        # 根据配置选择激活函数
        self.activation_fn = ACT2FN[config.activation_function]

        # 获取配置中的激活函数 dropout 概率
        self.activation_dropout = config.activation_dropout

        # 第一个线性层,将嵌入维度映射到编码器前馈网络维度
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)

        # 第二个线性层,将编码器前馈网络维度映射回嵌入维度
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)

        # 最终的 LayerNorm 层,对输出进行归一化
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        # 保存输入的隐藏状态作为残差连接的基础
        residual = hidden_states
        # 对隐藏状态进行 LayerNorm 处理
        hidden_states = self.self_attn_layer_norm(hidden_states)
        # 使用自注意力机制进行计算,得到新的隐藏状态、注意力权重和注意力概率(此处第三个返回值用下划线 `_` 表示)
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        # 对输出的隐藏状态进行 dropout 处理
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        # 将残差与处理后的隐藏状态相加,实现残差连接
        hidden_states = residual + hidden_states

        # 保存当前状态作为下一步的残差连接基础
        residual = hidden_states
        # 对最终输出的隐藏状态进行 LayerNorm 处理
        hidden_states = self.final_layer_norm(hidden_states)
        # 使用激活函数处理第一个全连接层的输出
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        # 对第一个全连接层的输出进行 dropout 处理
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        # 经过第二个全连接层
        hidden_states = self.fc2(hidden_states)
        # 对第二个全连接层的输出进行 dropout 处理
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        # 将残差与处理后的隐藏状态相加,实现残差连接
        hidden_states = residual + hidden_states

        # 如果隐藏状态的数据类型是 torch.float16,并且存在无穷大或 NaN 的情况,则进行 clamp 操作
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        # 将最终的隐藏状态作为输出
        outputs = (hidden_states,)

        # 如果需要输出注意力权重,则将注意力权重也加入到输出中
        if output_attentions:
            outputs += (attn_weights,)

        return outputs
# 定义 MBartDecoderLayer 类,继承自 nn.Module,用于 MBart 解码器层的实现
class MBartDecoderLayer(nn.Module):
    def __init__(self, config: MBartConfig):
        super().__init__()
        self.embed_dim = config.d_model

        # 初始化自注意力机制,根据配置选择实现类,并设置相关参数
        self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        # 初始化自注意力层规范化器
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # 初始化编码器注意力机制,根据配置选择实现类,并设置相关参数
        self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        # 初始化编码器注意力层规范化器
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # 初始化第一个线性层(前馈神经网络的第一层)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)

        # 初始化第二个线性层(前馈神经网络的第二层)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)

        # 初始化最终层规范化器
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    # 前向传播函数,定义了层的数据流向
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ):
        # 以下为具体的层操作实现
        # 注意力机制和规范化
        ...


# 定义 MBartClassificationHead 类,用于 MBart 模型的分类任务头部
class MBartClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        num_classes: int,
        pooler_dropout: float,
    ):
        super().__init__()
        # 初始化线性层,用于将输入维度映射到内部维度
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        # 输出投影层,将内部维度映射到类别数量
        self.out_proj = nn.Linear(inner_dim, num_classes)

    # 前向传播函数,定义了头部的数据流向
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 应用 dropout 操作
        hidden_states = self.dropout(hidden_states)
        # 通过线性层进行映射和激活函数
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # 最终通过输出投影层得到分类结果
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


# 定义 MBartPreTrainedModel 类,继承自 PreTrainedModel,作为 MBart 模型的基类
class MBartPreTrainedModel(PreTrainedModel):
    config_class = MBartConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
    _supports_flash_attn_2 = True
    # 初始化神经网络模块的权重
    def _init_weights(self, module):
        std = self.config.init_std  # 获取初始化标准差
        if isinstance(module, nn.Linear):  # 如果当前模块是线性层
            module.weight.data.normal_(mean=0.0, std=std)  # 初始化权重为正态分布
            if module.bias is not None:  # 如果存在偏置项
                module.bias.data.zero_()  # 将偏置项初始化为零
        elif isinstance(module, nn.Embedding):  # 如果当前模块是嵌入层
            module.weight.data.normal_(mean=0.0, std=std)  # 初始化嵌入矩阵的权重为正态分布
            if module.padding_idx is not None:  # 如果指定了填充索引
                module.weight.data[module.padding_idx].zero_()  # 将填充索引位置的权重初始化为零

    @property
    # 获取一个虚拟输入示例的属性方法
    def dummy_inputs(self):
        pad_token = self.config.pad_token_id  # 获取填充标记的 ID
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)  # 创建输入 ID 张量
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),  # 生成注意力掩码,排除填充标记
            "input_ids": input_ids,  # 将输入 ID 放入字典
        }
        return dummy_inputs  # 返回虚拟输入字典
MBART_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MBartConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

MBART_GENERATION_EXAMPLE = r"""
    Translation example:

    ```
    >>> from transformers import AutoTokenizer, MBartForConditionalGeneration

    >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")

    >>> example_english_phrase = "42 is the answer"
    >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")

    >>> # Translate
    >>> generated_ids = model.generate(**inputs, num_beams=4, max_length=5)
    >>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    '42 este răspuns'
    ```

    Mask filling example:

    ```
    >>> from transformers import AutoTokenizer, MBartForConditionalGeneration

    >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

    >>> # de_DE is the language symbol id <LID> for German
    >>> TXT = "</s> Meine Freunde sind <mask> nett aber sie essen zu viel Kuchen. </s> de_DE"

    >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt")["input_ids"]
    >>> logits = model(input_ids).logits

    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
    >>> probs = logits[0, masked_index].softmax(dim=0)
    >>> values, predictions = probs.topk(5)

    >>> tokenizer.decode(predictions).split()
    ['nett', 'sehr', 'ganz', 'nicht', 'so']
    ```
"""

MBART_INPUTS_DOCSTRING = r"""
"""


class MBartEncoder(MBartPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`MBartEncoderLayer`].

    Args:
        config: MBartConfig
            Model configuration class with all the parameters of the model.
        embed_tokens (nn.Embedding): output embedding
            The output embedding for the model.
    """
    # 初始化函数,接受一个 MBartConfig 对象和一个可选的嵌入词向量对象作为参数
    def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
        # 调用父类的初始化方法
        super().__init__(config)

        # 从配置中获取 dropout 和 encoder_layerdrop 的数值
        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        # 从配置中获取嵌入的维度,并设置填充 token 的索引
        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        # 设置最大的源序列长度
        self.max_source_positions = config.max_position_embeddings
        # 根据配置选择是否对嵌入进行缩放
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        # 创建嵌入层,vocab_size 是词汇表的大小,embed_dim 是嵌入的维度,padding_idx 是填充 token 的索引
        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

        # 如果提供了外部的嵌入词向量,则使用它来初始化嵌入层的权重
        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        # 创建学习的位置嵌入对象,max_position_embeddings 是位置的最大数量,embed_dim 是嵌入的维度
        self.embed_positions = MBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        
        # 创建一系列 MBartEncoderLayer 层,并存储在 layers 中,数量由 encoder_layers 决定
        self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
        
        # 根据配置决定是否使用 flash_attention_2
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        
        # 对嵌入进行 layernorm 处理,embed_dim 是嵌入的维度
        self.layernorm_embedding = nn.LayerNorm(embed_dim)
        
        # 对编码器的输出进行 layernorm 处理,config.d_model 是模型的维度
        self.layer_norm = nn.LayerNorm(config.d_model)

        # 是否启用梯度检查点,默认为 False
        self.gradient_checkpointing = False

        # 调用 post_init 方法来初始化权重并进行最终的处理
        self.post_init()

    # 用于向后兼容梯度检查点,如果配置中设置了 gradient_checkpointing,则启用梯度检查点
    def _backward_compatibility_gradient_checkpointing(self):
        # 不删除配置中的梯度检查点属性
        if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
            self.gradient_checkpointing_enable()

    # 前向传播函数,接受多个输入参数,包括 input_ids、attention_mask 等
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
class MBartDecoder(MBartPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MBartDecoderLayer`]

    Args:
        config: MBartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout  # 从配置中获取dropout率
        self.layerdrop = config.decoder_layerdrop  # 从配置中获取层dropout率
        self.padding_idx = config.pad_token_id  # 获取填充token的索引
        self.max_target_positions = config.max_position_embeddings  # 获取目标位置的最大数目
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0  # 计算嵌入尺度

        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)  # 初始化嵌入层

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight  # 如果提供了预训练的嵌入,使用它们

        self.embed_positions = MBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )  # 初始化位置编码器

        self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])  # 创建多层解码器层
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"  # 检查是否使用了Flash Attention 2
        self.layernorm_embedding = nn.LayerNorm(config.d_model)  # 初始化嵌入层的LayerNorm
        self.layer_norm = nn.LayerNorm(config.d_model)  # 初始化层的LayerNorm

        self.gradient_checkpointing = False  # 梯度检查点设为False
        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens  # 返回输入的嵌入层

    def set_input_embeddings(self, value):
        self.embed_tokens = value  # 设置输入的嵌入层

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 函数签名和参数说明
        """
        MBartDecoder的前向传播函数

        Args:
            input_ids (torch.LongTensor, optional): 输入的token IDs
            attention_mask (torch.Tensor, optional): 注意力掩码
            encoder_hidden_states (torch.FloatTensor, optional): 编码器的隐藏状态
            encoder_attention_mask (torch.LongTensor, optional): 编码器的注意力掩码
            head_mask (torch.Tensor, optional): 多头注意力的头部掩码
            cross_attn_head_mask (torch.Tensor, optional): 跨注意力头部的掩码
            past_key_values (Tuple[Tuple[torch.FloatTensor]], optional): 缓存的键值对
            inputs_embeds (torch.FloatTensor, optional): 输入的嵌入表示
            use_cache (bool, optional): 是否使用缓存
            output_attentions (bool, optional): 是否输出注意力
            output_hidden_states (bool, optional): 是否输出隐藏状态
            return_dict (bool, optional): 是否返回字典

        Returns:
            根据配置返回不同的输出
        """
        pass  # 此处省略了实际的前向传播逻辑,需要补充完整
    # 返回共享的输入嵌入
    def get_input_embeddings(self):
        return self.shared

    # 设置输入嵌入,并将其分配给编码器和解码器的嵌入
    def set_input_embeddings(self, value):
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    # 返回编码器对象
    def get_encoder(self):
        return self.encoder

    # 返回解码器对象
    def get_decoder(self):
        return self.decoder

    # 如果配置要求,绑定编码器和解码器的词嵌入权重
    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.get_input_embeddings())
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())

    # 前向传播函数,接收多个输入和控制参数,输出Seq2Seq模型的结果
    @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
# 为 MBART 模型添加文档字符串,描述其具有语言建模头部的功能,适用于摘要生成,需要在预训练模型微调后使用。
@add_start_docstrings(
    "The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models.",
    MBART_START_DOCSTRING,
)
class MBartForConditionalGeneration(MBartPreTrainedModel):
    # 基础模型的前缀,用于加载模型时忽略的键
    base_model_prefix = "model"
    # 在加载模型时忽略的缺失键
    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
    # 共享权重的键列表
    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]

    # 初始化方法,接受一个 MBART 配置对象
    def __init__(self, config: MBartConfig):
        super().__init__(config)
        # 使用给定的配置创建 MBartModel 模型
        self.model = MBartModel(config)
        # 注册一个缓冲区,用于存储最终对数偏置,维度为 (1, num_embeddings)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        # 创建一个线性层 lm_head,用于语言建模,输入维度为 config.d_model,输出维度为 num_embeddings
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # 初始化权重并进行最终处理
        self.post_init()

    # 获取编码器模型
    def get_encoder(self):
        return self.model.get_encoder()

    # 获取解码器模型
    def get_decoder(self):
        return self.model.get_decoder()

    # 调整 token embeddings 的大小
    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
        # 调用父类方法,调整 token embeddings 大小
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # 调整最终对数偏置的大小以匹配新的 token embeddings
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    # 调整最终对数偏置的大小
    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        # 注册调整后的最终对数偏置
        self.register_buffer("final_logits_bias", new_bias)

    # 获取输出 embeddings
    def get_output_embeddings(self):
        return self.lm_head

    # 设置输出 embeddings
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # 添加模型前向方法的文档字符串,包括输入的详细说明
    @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
    # 替换返回值的文档字符串,指定输出类型为 Seq2SeqLMOutput
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    # 添加模型前向方法的结尾文档字符串,包括生成示例
    @add_end_docstrings(MBART_GENERATION_EXAMPLE)
    # 定义模型的前向传播函数,接收多个输入参数
    def forward(
        self,
        input_ids: torch.LongTensor = None,  # 输入序列的token IDs
        attention_mask: Optional[torch.Tensor] = None,  # 输入序列的注意力掩码
        decoder_input_ids: Optional[torch.LongTensor] = None,  # 解码器的输入token IDs
        decoder_attention_mask: Optional[torch.LongTensor] = None,  # 解码器的注意力掩码
        head_mask: Optional[torch.Tensor] = None,  # 多头注意力机制的掩码
        decoder_head_mask: Optional[torch.Tensor] = None,  # 解码器多头注意力机制的掩码
        cross_attn_head_mask: Optional[torch.Tensor] = None,  # 跨注意力机制的多头掩码
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,  # 编码器的输出
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,  # 缓存的键值对
        inputs_embeds: Optional[torch.FloatTensor] = None,  # 输入的嵌入表示
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,  # 解码器输入的嵌入表示
        labels: Optional[torch.LongTensor] = None,  # 模型的标签
        use_cache: Optional[bool] = None,  # 是否使用缓存
        output_attentions: Optional[bool] = None,  # 是否输出注意力权重
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态
        return_dict: Optional[bool] = None,  # 是否返回字典格式的结果
    ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
            Either a Seq2SeqLMOutput containing loss, logits, and other optional outputs, or a tuple of
            torch.FloatTensor containing logits and optional outputs.

        """
        # Determine whether to return results in a dictionary format or not
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Handle special case where labels are provided
        if labels is not None:
            # Adjust `use_cache` to False when `labels` are provided, with a warning message
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            # If `decoder_input_ids` or `decoder_inputs_embeds` are not provided, generate `decoder_input_ids`
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

        # Pass inputs to the model for forward computation
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Calculate logits for language modeling head and apply bias
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        # Calculate masked language modeling loss if `labels` are provided
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        # Depending on `return_dict`, construct and return output tuple or Seq2SeqLMOutput
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        # Return results as Seq2SeqLMOutput with specified outputs
        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
        # 如果使用了过去的键值对(past_key_values),则根据其长度截断 decoder_input_ids
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # 一些生成方法可能只传递最后一个输入 ID
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # 默认行为:只保留最后一个 ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            # 从 decoder_input_ids 中截取需要的部分
            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        # 返回一个包含各种模型输入和设置的字典
        return {
            "input_ids": None,  # encoder_outputs 已定义,不需要 input_ids
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # 将此项更改以避免缓存(可能用于调试目的)
        }

    # 从标签(labels)准备解码器输入的静态方法
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id)

    @staticmethod
    # 重新排序缓存中的 past_key_values,以匹配给定的 beam_idx
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            # 缓存的交叉注意力状态无需重新排序 -> 它们始终保持不变
            reordered_past += (
                # 对每个层的过去状态执行索引选择,以匹配给定的 beam_idx
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],  # 添加未修改的剩余部分
            )
        return reordered_past
@add_start_docstrings(
    """
    MBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
    tasks.
    """,
    MBART_START_DOCSTRING,
)
# 定义 MBart 序列分类模型,建立在 MBartPreTrainedModel 基础上
class MBartForSequenceClassification(MBartPreTrainedModel):
    # 需要共享权重的键列表,用于 tied weights 功能
    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]

    def __init__(self, config: MBartConfig, **kwargs):
        # 调用父类构造函数初始化模型
        super().__init__(config, **kwargs)
        # 初始化 MBart 模型
        self.model = MBartModel(config)
        # 初始化序列分类头部
        self.classification_head = MBartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classifier_dropout,
        )

        # 初始化权重并进行最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # 从 transformers.models.bart.modeling_bart.BartForSequenceClassification.forward 复制而来
    # 前向传播函数定义
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,



@add_start_docstrings(
    """
    MBART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    MBART_START_DOCSTRING,
)
# 定义 MBart 问答模型,用于类似 SQuAD 的抽取式问答任务
class MBartForQuestionAnswering(MBartPreTrainedModel):
    # 需要共享权重的键列表,用于 tied weights 功能
    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]

    def __init__(self, config):
        # 调用父类构造函数初始化模型
        super().__init__(config)

        # 设定分类标签数目为 2
        config.num_labels = 2
        self.num_labels = config.num_labels

        # 初始化 MBart 模型
        self.model = MBartModel(config)
        # 初始化问答输出层
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并进行最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
    # 定义 BartForQuestionAnswering 模型的前向传播方法,接受多个输入参数
    
    def forward(
        self,
        input_ids: torch.Tensor = None,  # 输入的 token IDs 张量
        attention_mask: Optional[torch.Tensor] = None,  # 注意力遮罩张量,用于指示哪些 token 是需要注意的
        decoder_input_ids: Optional[torch.LongTensor] = None,  # 解码器的输入 token IDs 张量
        decoder_attention_mask: Optional[torch.LongTensor] = None,  # 解码器的注意力遮罩张量
        head_mask: Optional[torch.Tensor] = None,  # 多头注意力的掩码
        decoder_head_mask: Optional[torch.Tensor] = None,  # 解码器的多头注意力掩码
        cross_attn_head_mask: Optional[torch.Tensor] = None,  # 跨层注意力的掩码
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,  # 编码器的输出列表
        start_positions: Optional[torch.LongTensor] = None,  # 答案开始位置的张量
        end_positions: Optional[torch.LongTensor] = None,  # 答案结束位置的张量
        inputs_embeds: Optional[torch.FloatTensor] = None,  # 输入的嵌入向量
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,  # 解码器输入的嵌入向量
        use_cache: Optional[bool] = None,  # 是否使用缓存
        output_attentions: Optional[bool] = None,  # 是否输出注意力权重
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态
        return_dict: Optional[bool] = None,  # 是否以字典形式返回结果
# 从 transformers.models.bart.modeling_bart.BartDecoderWrapper 复制而来,将 Bart 替换为 MBart
class MBartDecoderWrapper(MBartPreTrainedModel):
    """
    这个包装类是一个辅助类,用于在因果语言模型与 [`EncoderDecoderModel`] 框架结合使用时正确加载预训练检查点。
    """

    def __init__(self, config):
        # 调用父类构造函数,初始化 MBartDecoderWrapper 对象
        super().__init__(config)
        # 创建 MBartDecoder 对象
        self.decoder = MBartDecoder(config)

    def forward(self, *args, **kwargs):
        # 前向传播函数,调用 MBartDecoder 的前向传播方法
        return self.decoder(*args, **kwargs)


# 从 transformers.models.bart.modeling_bart.BartForCausalLM 复制而来,将 Bart 替换为 MBart,facebook/bart-base 替换为 facebook/mbart-large-cc25
class MBartForCausalLM(MBartPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # 深拷贝配置对象,配置为解码器,不是编码-解码器结构
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        # 调用父类构造函数,初始化 MBartForCausalLM 对象
        super().__init__(config)
        # 创建 MBartDecoderWrapper 对象作为模型的核心部分
        self.model = MBartDecoderWrapper(config)

        # 初始化线性层,用于语言模型的输出
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        # 获取输入嵌入层,即 MBartDecoder 的嵌入标记
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        # 设置输入嵌入层,即 MBartDecoder 的嵌入标记
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        # 获取输出嵌入层,即语言模型头部
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # 设置输出嵌入层,即语言模型头部
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        # 设置解码器,用于动态设置 MBartDecoderWrapper 的解码器
        self.model.decoder = decoder

    def get_decoder(self):
        # 获取解码器,即 MBartDecoderWrapper 的解码器
        return self.model.decoder

    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 前向传播函数,详细参数见函数声明
        pass

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
    ):
        # 为生成准备输入的函数,详细参数见函数声明
        pass
    ):
        # 如果模型用作编码器-解码器模型中的解码器,解码器注意力遮罩在需要时动态创建
        if attention_mask is None:
            # 如果注意力遮罩为空,则创建一个与输入形状相同的全1张量作为注意力遮罩
            attention_mask = input_ids.new_ones(input_ids.shape)

        if past_key_values:
            # 获取过去关键值的长度
            past_length = past_key_values[0][0].shape[2]

            # 有些生成方法已经只传递了最后一个输入 ID
            if input_ids.shape[1] > past_length:
                # 如果当前输入长度大于过去长度,则移除前缀的长度为过去长度
                remove_prefix_length = past_length
            else:
                # 否则,默认行为是只保留最后一个输入 ID
                remove_prefix_length = input_ids.shape[1] - 1

            # 截取输入序列,移除前缀长度
            input_ids = input_ids[:, remove_prefix_length:]
        # 第一步,解码器缓存状态为空
        return {
            "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            # 重新排序过去的关键值,根据 beam_idx 对每一层的 past_state 进行索引选择
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
posted @ 2024-06-29 17:01  绝不原创的飞龙  阅读(7)  评论(0编辑  收藏  举报