Transformers Source Code Analysis (Part 87)

.\models\pegasus\modeling_pegasus.py

# File encoding: UTF-8
# Copyright notice: Google and the HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0; you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including but not limited to
# warranties of merchantability or fitness for a particular purpose. See the License for details.

""" PyTorch PEGASUS model."""

import copy  # deep-copy helper
import math  # math functions
from typing import List, Optional, Tuple, Union  # typing helpers

import numpy as np  # numpy
import torch  # PyTorch
import torch.utils.checkpoint  # PyTorch gradient-checkpointing utilities
from torch import nn  # neural network modules
from torch.nn import CrossEntropyLoss  # cross-entropy loss

from ...activations import ACT2FN  # activation function mapping
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask  # attention mask helpers
from ...modeling_outputs import (  # model output classes
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel  # base class for pretrained models
from ...utils import (  # generic utilities
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_pegasus import PegasusConfig  # Pegasus configuration

logger = logging.get_logger(__name__)  # module logger

_CHECKPOINT_FOR_DOC = "google/pegasus-large"  # checkpoint name used in the docs
_CONFIG_FOR_DOC = "PegasusConfig"  # config class name used in the docs

PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST = [  # list of pretrained checkpoints
    "google/pegasus-large",
    # See all PEGASUS models at https://huggingface.co/models?filter=pegasus
]


# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input tokens one position to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)  # zero tensor with the same shape as input_ids
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()  # shift input_ids one position to the right
    shifted_input_ids[:, 0] = decoder_start_token_id  # place decoder_start_token_id in the first position

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels with pad_token_id
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
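
As a quick illustration (not part of the original file; the token IDs, `pad_token_id=0` and `decoder_start_token_id=0` are chosen purely for demonstration), this is what `shift_tokens_right` does to a toy batch of labels:

```python
import torch

# Hypothetical toy labels; -100 marks positions ignored by the loss.
labels = torch.tensor([[5, 6, -100, -100]])
shifted = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=0)
print(shifted)  # tensor([[0, 5, 6, 0]]): start token prepended, last label dropped, -100 replaced by pad
```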


# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
        super().__init__(num_positions, embedding_dim)  # call the parent constructor
        self.weight = self._init_weight(self.weight)  # initialize the sinusoidal weights

    @staticmethod
    def _init_weight(out: nn.Parameter) -> nn.Parameter:
        """
        Initialize positional embeddings for transformer model.

        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
        the 2nd half of the vector. [dim // 2:]
        """
        n_pos, dim = out.shape  # number of positions and embedding dimension
        # build the position-encoding matrix holding one embedding vector per position
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        out.requires_grad = False  # set early to avoid an error in pytorch 1.8+
        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1  # index separating the sin and cos halves
        # fill the sin and cos values into the two halves of the output tensor
        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()  # detach so the tensor is not modified by later computation
        return out

    @torch.no_grad()
    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
        """
        Perform forward pass of the transformer model.

        `input_ids_shape` is expected to be [bsz x seqlen].
        """
        bsz, seq_len = input_ids_shape[:2]  # batch size and sequence length
        # build the position index tensor, offset by past_key_values_length to account for cached keys/values
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        )
        return super().forward(positions)
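
The layout produced by `_init_weight` can be seen in a small standalone sketch (illustrative only, same formula as above): sine features occupy the first half of each embedding vector and cosine features the second half.

```python
import numpy as np

n_pos, dim = 4, 6  # hypothetical sizes
position_enc = np.array(
    [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
table = np.zeros((n_pos, dim))
table[:, : dim // 2] = np.sin(position_enc[:, 0::2])  # first half: sin features
table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])  # second half: cos features
print(table.round(3))  # row 0 is [0, 0, 0, 1, 1, 1]: sin(0)=0, cos(0)=1
```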


# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Pegasus
class PegasusAttention(nn.Module):
    """Multi-headed attention from the 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[PegasusConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim  # embedding dimension
        self.num_heads = num_heads  # number of attention heads
        self.dropout = dropout  # dropout probability
        self.head_dim = embed_dim // num_heads  # dimension of each attention head
        self.config = config  # Pegasus configuration object

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5  # scaling factor for the dot products
        self.is_decoder = is_decoder  # whether this layer is used in a decoder
        self.is_causal = is_causal  # whether the attention is causal

        # linear projection layers
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # reshape the tensor to (bsz, num_heads, seq_len, head_dim) for multi-head attention
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        # forward pass of PegasusAttention (body omitted in this walkthrough)
        pass
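
Since the body of `PegasusAttention.forward` is not reproduced in this walkthrough, the following is a condensed, hypothetical sketch of the standard scaled dot-product computation that the projection layers above feed into (self-attention only, no masking, caching or head masking):

```python
import torch

def attention_sketch(attn: PegasusAttention, hidden_states: torch.Tensor) -> torch.Tensor:
    # hidden_states: (bsz, seq_len, embed_dim)
    bsz, seq_len, _ = hidden_states.size()
    q = attn._shape(attn.q_proj(hidden_states) * attn.scaling, seq_len, bsz)  # (bsz, heads, seq, head_dim)
    k = attn._shape(attn.k_proj(hidden_states), seq_len, bsz)
    v = attn._shape(attn.v_proj(hidden_states), seq_len, bsz)
    weights = torch.softmax(q @ k.transpose(-1, -2), dim=-1)                  # (bsz, heads, seq, seq)
    context = (weights @ v).transpose(1, 2).reshape(bsz, seq_len, -1)         # back to (bsz, seq, embed_dim)
    return attn.out_proj(context)
```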

# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS
class PegasusEncoderLayer(nn.Module):
    def __init__(self, config: PegasusConfig):
        super().__init__()
        self.embed_dim = config.d_model  # embedding dimension

        # build the self-attention layer using the attention implementation selected in the config
        self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)  # LayerNorm for the self-attention block
        self.dropout = config.dropout  # dropout probability
        self.activation_fn = ACT2FN[config.activation_function]  # activation function
        self.activation_dropout = config.activation_dropout  # dropout applied after the activation
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)  # first feed-forward layer
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)  # second feed-forward layer
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)  # final LayerNorm

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        # keep the input for the residual connection
        residual = hidden_states
        # apply Layer Normalization to the incoming hidden_states (pre-norm)
        hidden_states = self.self_attn_layer_norm(hidden_states)
        # run self-attention; returns the new hidden_states, the attention weights and extra state
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        # apply dropout to the attention output
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        # add the residual back
        hidden_states = residual + hidden_states

        # keep the input of the feed-forward block for the residual connection
        residual = hidden_states
        # apply Layer Normalization before the feed-forward block
        hidden_states = self.final_layer_norm(hidden_states)
        # first linear layer followed by the activation function
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        # dropout after the activation
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        # second linear layer
        hidden_states = self.fc2(hidden_states)
        # dropout on the feed-forward output
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        # add the residual back to form the final hidden_states
        hidden_states = residual + hidden_states

        # if the hidden_states are in float16 and contain infs or NaNs, clamp them
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            # clamp to stay within the representable range of the dtype
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        # collect the outputs
        outputs = (hidden_states,)

        # optionally add the attention weights to the outputs
        if output_attentions:
            outputs += (attn_weights,)

        # return the outputs
        return outputs
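
The `float16` clamp above exists because half precision overflows to infinity slightly above 65504; a tiny illustration (not from the original file) of what the guard does:

```python
import torch

x = torch.tensor([70000.0, -70000.0]).to(torch.float16)   # overflows to +/-inf in fp16
clamp_value = torch.finfo(torch.float16).max - 1000
print(torch.clamp(x, min=-clamp_value, max=clamp_value))   # finite values just below the fp16 limit
```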
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS
class PegasusDecoderLayer(nn.Module):
    def __init__(self, config: PegasusConfig):
        super().__init__()
        self.embed_dim = config.d_model  # embedding dimension taken from the model size in the config

        # self-attention layer; the implementation is chosen from the config, with head count, dropout, etc.
        self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout  # dropout probability
        self.activation_fn = ACT2FN[config.activation_function]  # activation function from the config
        self.activation_dropout = config.activation_dropout  # dropout applied after the activation

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)  # LayerNorm for the self-attention block

        # encoder-decoder (cross) attention layer, again built from the configured implementation
        self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)  # LayerNorm for the cross-attention block

        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)  # first feed-forward layer
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)  # second feed-forward layer
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)  # final LayerNorm

    # forward pass (signature only; the body is omitted in this walkthrough)
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,



# PegasusPreTrainedModel, the base class inheriting from PreTrainedModel
class PegasusPreTrainedModel(PreTrainedModel):
    config_class = PegasusConfig  # configuration class
    base_model_prefix = "model"  # prefix of the base model
    supports_gradient_checkpointing = True  # gradient checkpointing is supported

    # weight initialization
    def _init_weights(self, module):
        std = self.config.init_std  # standard deviation from the config
        if isinstance(module, nn.Linear):  # linear layers
            module.weight.data.normal_(mean=0.0, std=std)  # normal initialization of the weights
            if module.bias is not None:
                module.bias.data.zero_()  # zero the bias, if any
        elif isinstance(module, PegasusSinusoidalPositionalEmbedding):
            pass  # sinusoidal positional embeddings keep their precomputed weights
        elif isinstance(module, nn.Embedding):  # embedding layers
            module.weight.data.normal_(mean=0.0, std=std)  # normal initialization of the weights
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()  # zero the row of the padding index, if any
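
`_init_weights` is applied to every submodule when `post_init()` runs, so the branch that skips `PegasusSinusoidalPositionalEmbedding` is what keeps the precomputed sin/cos table intact. A hypothetical check with a deliberately tiny config:

```python
from transformers import PegasusConfig, PegasusModel

config = PegasusConfig(
    vocab_size=64, d_model=16, max_position_embeddings=32,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = PegasusModel(config)
pos = model.encoder.embed_positions.weight
print(pos[0])  # first half ~0 (sin of position 0), second half ~1 (cos of position 0): untouched by random init
```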



# PEGASUS_START_DOCSTRING: shared docstring prepended to the model classes
PEGASUS_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`PegasusConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

PEGASUS_GENERATION_EXAMPLE = r"""
    Summarization example:

    ```
    >>> from transformers import AutoTokenizer, PegasusForConditionalGeneration

    >>> model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")

    >>> ARTICLE_TO_SUMMARIZE = (
    ...     "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    ...     "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    ...     "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
    ... )
    >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="pt")

    >>> # Generate Summary
    >>> summary_ids = model.generate(inputs["input_ids"])
    >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "California's largest electricity provider has turned off power to hundreds of thousands of customers."
    ```
"""

PEGASUS_INPUTS_DOCSTRING = r"""
"""


class PegasusEncoder(PegasusPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`PegasusEncoderLayer`].

    Args:
        config: PegasusConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        # Initialize embedding tokens with padding index if provided, otherwise default
        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

        # Initialize sinusoidal positional embeddings
        self.embed_positions = PegasusSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
            self.padding_idx,
        )

        # Create encoder layers based on config
        self.layers = nn.ModuleList([PegasusEncoderLayer(config) for _ in range(config.encoder_layers)])

        # Layer normalization for encoder output
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
        config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embeddings. If position embeddings are learned, increasing the size will add
                newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
                position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
                add correct vectors at the end following the position encoding algorithm, whereas reducing the size
                will remove vectors from the end.
        """
        # log the new maximum number of position embeddings
        logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
        # update the maximum number of position embeddings in the model config
        self.config.max_position_embeddings = new_num_position_embeddings

        # rebuild the sinusoidal position embeddings with the new length and the model dimension
        self.embed_positions = PegasusSinusoidalPositionalEmbedding(
            self.config.max_position_embeddings,
            self.config.d_model,
            self.padding_idx,
        )
        # move the new position embeddings to the model's device (typically a GPU)
        self.embed_positions.to(self.device)

    def get_position_embeddings(self) -> nn.Embedding:
        """
        Returns the position embeddings matrix
        """
        # return the current position embedding module
        return self.embed_positions

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        ...  # (forward body omitted in this walkthrough)


class PegasusDecoder(PegasusPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PegasusDecoderLayer`]

    Args:
        config: PegasusConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout  # dropout probability from the config
        self.layerdrop = config.decoder_layerdrop  # LayerDrop probability from the config
        self.padding_idx = config.pad_token_id  # padding token index from the config
        self.max_target_positions = config.max_position_embeddings  # maximum number of target positions
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0  # embedding scaling factor

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens  # reuse the provided token embeddings, otherwise create a new nn.Embedding
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)

        # sinusoidal position embeddings
        self.embed_positions = PegasusSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            self.padding_idx,
        )

        # stack of decoder layers
        self.layers = nn.ModuleList([PegasusDecoderLayer(config) for _ in range(config.decoder_layers)])

        # layer normalization applied to the decoder output
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False  # gradient checkpointing is disabled by default

        # initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens  # return the input embeddings

    def set_input_embeddings(self, value):
        self.embed_tokens = value  # set the input embeddings

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
        config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embeddings. If position embeddings are learned, increasing the size will add
                newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
                position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
                add correct vectors at the end following the position encoding algorithm, whereas reducing the size
                will remove vectors from the end.
        """
        logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
        self.config.max_position_embeddings = new_num_position_embeddings

        # rebuild the position embedding matrix with the new number of positions
        self.embed_positions = PegasusSinusoidalPositionalEmbedding(
            self.config.max_position_embeddings,
            self.config.d_model,
            self.padding_idx,
        )
        self.embed_positions.to(self.device)  # move the position embeddings to the model's device

    def get_position_embeddings(self) -> nn.Embedding:
        """
        Returns the position embeddings matrix
        """
        return self.embed_positions  # return the position embedding matrix

    # forward pass of the decoder; the parameters are annotated below (body omitted in this walkthrough)
    def forward(
        self,
        input_ids=None,                    # input token IDs
        attention_mask=None,               # attention mask marking padding / special positions
        encoder_hidden_states=None,        # hidden states of the encoder (for cross-attention)
        encoder_attention_mask=None,       # attention mask over the encoder inputs
        head_mask=None,                    # mask over the self-attention heads
        cross_attn_head_mask=None,         # mask over the cross-attention heads
        past_key_values=None,              # cached key/value states for faster decoding
        inputs_embeds=None,                # precomputed input embeddings instead of input_ids
        use_cache=None,                    # whether to use the key/value cache
        output_attentions=None,            # whether to return the attention weights
        output_hidden_states=None,         # whether to return all hidden states
        return_dict=None,                  # whether to return a ModelOutput dict instead of a tuple
    ):
        ...  # (forward body omitted in this walkthrough)


# The `add_start_docstrings` decorator (with PEGASUS_START_DOCSTRING) documents PegasusModel as the bare
# PEGASUS model outputting raw hidden-states without any specific head on top.
class PegasusModel(PegasusPreTrainedModel):
    # keys of weights shared between the encoder and decoder embedding layers
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: PegasusConfig):
        super().__init__(config)

        # padding index and vocabulary size
        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        # shared embedding layer used for both the encoder and the decoder
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        # build the encoder and the decoder, passing in the shared embedding layer
        self.encoder = PegasusEncoder(config, self.shared)
        self.decoder = PegasusDecoder(config, self.shared)

        # initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # return the shared embedding layer used for the input tokens
        return self.shared

    def set_input_embeddings(self, value):
        # set a new shared embedding layer and propagate it to the encoder and decoder
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def get_encoder(self):
        # return the encoder
        return self.encoder

    def get_decoder(self):
        # return the decoder
        return self.decoder

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
        config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embeddings. If position embeddings are learned, increasing the size will
                add newly initialized vectors at the end, whereas reducing the size will remove vectors from the end.
                If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size
                will add correct vectors at the end following the position encoding algorithm, whereas reducing the
                size will remove vectors from the end.
        """
        self.config.max_position_embeddings = new_num_position_embeddings
        # resize the position embeddings of both the encoder and the decoder
        self.encoder.resize_position_embeddings(new_num_position_embeddings)
        self.decoder.resize_position_embeddings(new_num_position_embeddings)

    def get_position_embeddings(self) -> Tuple[nn.Embedding]:
        """
        Returns the position embeddings matrix
        """
        return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings())
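
Because the position embeddings are sinusoidal rather than learned, resizing them simply rebuilds the table at the new length. A hypothetical usage sketch (downloads the public checkpoint):

```python
from transformers import PegasusModel

model = PegasusModel.from_pretrained("google/pegasus-large")
enc_pos, dec_pos = model.get_position_embeddings()
print(enc_pos.weight.shape)  # (config.max_position_embeddings, d_model)

model.resize_position_embeddings(2048)  # rebuild the sinusoidal tables for sequences up to 2048 tokens
print(model.get_position_embeddings()[0].weight.shape)  # now (2048, d_model)
```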

    @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    # forward pass of the model (body omitted in this walkthrough)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,              # input token IDs
        attention_mask: Optional[torch.Tensor] = None,          # attention mask for the encoder inputs
        decoder_input_ids: Optional[torch.Tensor] = None,       # decoder input token IDs
        decoder_attention_mask: Optional[torch.Tensor] = None,  # attention mask for the decoder inputs
        head_mask: Optional[torch.Tensor] = None,               # mask over the encoder attention heads
        decoder_head_mask: Optional[torch.Tensor] = None,       # mask over the decoder attention heads
        cross_attn_head_mask: Optional[torch.Tensor] = None,    # mask over the cross-attention heads
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,  # precomputed encoder outputs
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,  # cached key/value states
        inputs_embeds: Optional[torch.Tensor] = None,           # precomputed encoder input embeddings
        decoder_inputs_embeds: Optional[torch.Tensor] = None,   # precomputed decoder input embeddings
        use_cache: Optional[bool] = None,                       # whether to use the key/value cache
        output_attentions: Optional[bool] = None,               # whether to return attention weights
        output_hidden_states: Optional[bool] = None,            # whether to return all hidden states
        return_dict: Optional[bool] = None,                     # whether to return a ModelOutput dict
    ):
        ...  # (forward body omitted in this walkthrough)


@add_start_docstrings(
    "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING
)
class PegasusForConditionalGeneration(PegasusPreTrainedModel):
    # PEGASUS model with a language modeling head, usable for summarization

    base_model_prefix = "model"  # prefix of the base model

    # keys that may be missing when loading a checkpoint
    _keys_to_ignore_on_load_missing = ["final_logits_bias"]

    # weights tied together: encoder/decoder token embeddings and the LM head
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: PegasusConfig):
        super().__init__(config)

        # instantiate the underlying PEGASUS encoder-decoder model
        self.model = PegasusModel(config)

        # register a zero-initialized buffer used as the final logits bias
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))

        # language modeling head: project from d_model to the vocabulary size, without bias
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        # return the encoder of the underlying model
        return self.model.get_encoder()

    def get_decoder(self):
        # return the decoder of the underlying model
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
        # resize the token embedding matrix via the parent implementation
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # resize the final logits bias to match the new embedding matrix
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        # return the resized embedding matrix
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        # resize the final logits bias (internal helper)
        old_num_tokens = self.final_logits_bias.shape[-1]  # current size of the bias
        if new_num_tokens <= old_num_tokens:
            # shrink: keep only the first new_num_tokens entries of the bias
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            # grow: append zero-filled entries so the bias matches the new vocabulary size
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        # register the resized bias buffer
        self.register_buffer("final_logits_bias", new_bias)
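
Keeping `final_logits_bias` in sync with the vocabulary size matters because it is added directly to the LM head output. A small hypothetical check with a tiny config:

```python
from transformers import PegasusConfig, PegasusForConditionalGeneration

config = PegasusConfig(
    vocab_size=100, d_model=16, max_position_embeddings=32,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = PegasusForConditionalGeneration(config)
print(model.final_logits_bias.shape)  # torch.Size([1, 100])

model.resize_token_embeddings(120)    # grows the embeddings and the lm_head ...
print(model.final_logits_bias.shape)  # ... and zero-pads the bias: torch.Size([1, 120])
```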

    def get_output_embeddings(self):
        # return the language modeling head
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # set a new language modeling head
        self.lm_head = new_embeddings

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
        config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embeddings. If position embeddings are learned, increasing the size will
                add newly initialized vectors at the end, whereas reducing the size will remove vectors from the end.
                If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size
                will add correct vectors at the end following the position encoding algorithm, whereas reducing the
                size will remove vectors from the end.
        """
        # store the new maximum number of position embeddings in the config
        self.config.max_position_embeddings = new_num_position_embeddings
        # resize the position embedding matrices of the encoder and the decoder
        self.model.encoder.resize_position_embeddings(new_num_position_embeddings)
        self.model.decoder.resize_position_embeddings(new_num_position_embeddings)

    def get_position_embeddings(self) -> Tuple[nn.Embedding]:
        """
        Returns the position embeddings matrix
        """
        # return the position embedding matrices of the encoder and the decoder as a tuple
        return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings())

    # document the forward inputs with PEGASUS_INPUTS_DOCSTRING
    @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
    # replace the return docstring with the Seq2SeqLMOutput type, using the _CONFIG_FOR_DOC config class
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    # append PEGASUS_GENERATION_EXAMPLE to the end of the forward docstring
    @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # if past key/values are used, trim decoder_input_ids accordingly
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # some generation methods already pass only the last input ID
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # default to the old behavior: keep only the final ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        # return the prepared inputs as a dictionary
        return {
            "input_ids": None,  # encoder_outputs is defined, so input_ids are not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        # shift the labels one position to the right to build the decoder inputs
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            # cached cross-attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        # return the reordered past key/values
        return reordered_past
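
During beam search, `_reorder_cache` keeps the cached self-attention key/value tensors aligned with the surviving beams while leaving the cross-attention cache untouched. A toy, self-contained illustration (shapes and values are hypothetical):

```python
import torch

# One layer's cache: (self-attn key, self-attn value, cross-attn key, cross-attn value),
# each shaped (batch * num_beams, num_heads, seq_len, head_dim) in real use.
layer_past = tuple(torch.arange(4.0).reshape(4, 1, 1, 1) + i for i in range(4))
past_key_values = (layer_past,)

beam_idx = torch.tensor([2, 2, 0, 1])  # which beam each new hypothesis continues from
reordered = PegasusForConditionalGeneration._reorder_cache(past_key_values, beam_idx)
print(reordered[0][0].flatten())  # tensor([2., 2., 0., 1.]): self-attn states follow the beams
print(reordered[0][2].flatten())  # tensor([2., 3., 4., 5.]): cross-attn states are passed through unchanged
```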
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus
class PegasusDecoderWrapper(PegasusPreTrainedModel):
    """
    This wrapper class is a helper class that makes it possible to correctly load pretrained checkpoints when the
    causal language model is combined with the EncoderDecoderModel framework.
    """

    def __init__(self, config):
        super().__init__(config)
        # build the Pegasus decoder
        self.decoder = PegasusDecoder(config)

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)


class PegasusForCausalLM(PegasusPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        # build the model through the PegasusDecoderWrapper
        self.model = PegasusDecoderWrapper(config)

        # LM head linear layer
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # return the input embedding layer
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        # set the input embedding layer
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        # return the output embedding layer (the LM head)
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # set the output embedding layer (the LM head)
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        # set the decoder
        self.model.decoder = decoder

    def get_decoder(self):
        # return the decoder
        return self.model.decoder

    def get_position_embeddings(self) -> nn.Embedding:
        """
        Returns the position embeddings matrix
        """
        return self.model.decoder.get_position_embeddings()

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
        config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embeddings. If position embeddings are learned, increasing the size will
                add newly initialized vectors at the end, whereas reducing the size will remove vectors from the end.
                If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size
                will add correct vectors at the end following the position encoding algorithm, whereas reducing the
                size will remove vectors from the end.
        """
        self.config.max_position_embeddings = new_num_position_embeddings
        self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    # Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large
    # forward pass (body omitted in this walkthrough)
    def forward(
        self,
        input_ids: torch.LongTensor = None,  # input token IDs
        attention_mask: Optional[torch.Tensor] = None,  # attention mask
        encoder_hidden_states: Optional[torch.FloatTensor] = None,  # encoder hidden states (for cross-attention)
        encoder_attention_mask: Optional[torch.FloatTensor] = None,  # encoder attention mask
        head_mask: Optional[torch.Tensor] = None,  # mask over the self-attention heads
        cross_attn_head_mask: Optional[torch.Tensor] = None,  # mask over the cross-attention heads
        past_key_values: Optional[List[torch.FloatTensor]] = None,  # cached key/value states
        inputs_embeds: Optional[torch.FloatTensor] = None,  # precomputed input embeddings
        labels: Optional[torch.LongTensor] = None,  # labels for the language modeling loss
        use_cache: Optional[bool] = None,  # whether to use the key/value cache
        output_attentions: Optional[bool] = None,  # whether to return attention weights
        output_hidden_states: Optional[bool] = None,  # whether to return all hidden states
        return_dict: Optional[bool] = None,  # whether to return a ModelOutput dict
    ):
        ...  # (forward body omitted in this walkthrough)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
    ):
        # if no attention mask is given, create an all-ones mask with the same shape as input_ids
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # if past key/value states were passed in
        if past_key_values:
            past_length = past_key_values[0][0].shape[2]

            # some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # default behavior: keep only the final input ID
                remove_prefix_length = input_ids.shape[1] - 1

            # drop the prefix portion of the input ID sequence
            input_ids = input_ids[:, remove_prefix_length:]

        # return the prepared inputs as a dictionary
        return {
            "input_ids": input_ids,  # input token IDs
            "attention_mask": attention_mask,  # attention mask
            "past_key_values": past_key_values,  # cached key/value states
            "use_cache": use_cache,  # whether to use the cache
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        # reorder the past states of every layer
        for layer_past in past_key_values:
            reordered_past += (
                # index-select each past state along the batch dimension using the beam indices
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
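
A hypothetical usage sketch of the decoder-only wrapper: `PegasusForCausalLM` behaves like a plain causal language model, so it can be run (or trained) without an encoder at all.

```python
import torch
from transformers import PegasusConfig, PegasusForCausalLM

config = PegasusConfig(
    vocab_size=96, d_model=16, max_position_embeddings=32,
    decoder_layers=1, decoder_attention_heads=2, decoder_ffn_dim=32,
)
model = PegasusForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (2, 8))
outputs = model(input_ids=input_ids, labels=input_ids)
print(outputs.loss, outputs.logits.shape)  # scalar loss, torch.Size([2, 8, 96])
```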

.\models\pegasus\modeling_tf_pegasus.py

# shift_tokens_right: shift the input token sequence one position to the right to build the decoder inputs
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
    # cast pad_token_id and decoder_start_token_id to the dtype of input_ids
    pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
    decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)

    # build a column of start tokens, one per example in the batch
    start_tokens = tf.fill(
        (shape_list(input_ids)[0], 1),  # shape_list gives the batch size; fill a single column
        tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)  # cast decoder_start_token_id
    )

    # concatenate the start tokens with input_ids shifted right by one position
    shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)

    # replace possible -100 values in labels with pad_token_id
    shifted_input_ids = tf.where(
        shifted_input_ids == -100,  # positions whose value is -100
        tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),  # replace with pad_token_id
        shifted_input_ids,  # otherwise keep the original value
    )

    # assert that all values in shifted_input_ids are >= 0, i.e. labels were non-negative or -100
    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))

    # make sure the assertion op is executed by wrapping the result in an identity op
    with tf.control_dependencies([assert_gte0]):
        shifted_input_ids = tf.identity(shifted_input_ids)

    # return the shifted input IDs
    return shifted_input_ids


# build the causal mask used for the decoder's self-attention
def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz = input_ids_shape[0]  # batch size
    tgt_len = input_ids_shape[1]  # target length
    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE  # start from a mask filled with a large negative value
    mask_cond = tf.range(shape_list(mask)[-1])  # range over the last mask dimension

    # zero out the lower triangle (including the diagonal) to obtain the causal mask
    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)

    if past_key_values_length > 0:
        # if there are cached past key/values, prepend a block of zeros on the left
        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)

    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
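
A tiny illustrative check of what `_make_causal_mask` returns (assuming the module's helpers such as `shape_list` and `LARGE_NEGATIVE` are in scope, as they are in the full file): each position may only attend to itself and earlier positions, and any `past_key_values_length` columns on the left are fully visible.

```python
import tensorflow as tf

mask = _make_causal_mask(tf.TensorShape([1, 3]), past_key_values_length=1)
print(mask.shape)         # (1, 1, 3, 4): batch, 1, tgt_len, past + tgt_len
print(mask[0, 0] == 0.0)  # True where attending is allowed: the past column plus the lower triangle
```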


# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    src_len = shape_list(mask)[1]  # source length of the mask
    tgt_len = tgt_len if tgt_len is not None else src_len  # fall back to the source length if no target length is given
    one_cst = tf.constant(1.0)  # constant 1.0
    mask = tf.cast(mask, dtype=one_cst.dtype)  # cast the mask to the same dtype as one_cst
    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))  # broadcast the mask to the target length

    return (one_cst - expanded_mask) * LARGE_NEGATIVE  # invert the mask and scale by a large negative value


# Copied from transformers.models.marian.modeling_tf_marian.TFMarianSinusoidalPositionalEmbedding with Marian->Pegasus
class TFPegasusSinusoidalPositionalEmbedding(keras.layers.Layer):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
        super().__init__(**kwargs)

        if embedding_dim % 2 != 0:
            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")

        self.embedding_dim = embedding_dim  # embedding dimension
        self.num_positions = num_positions  # number of positions

    def build(self, input_shape: tf.TensorShape):
        """
        Build shared token embedding layer Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """

        weight = self._init_weight(self.num_positions, self.embedding_dim)  # initialize the sinusoidal weights

        self.weight = self.add_weight(
            name="embeddings",
            shape=[self.num_positions, self.embedding_dim],
        )
        weight = tf.cast(weight, dtype=self.weight.dtype)  # cast to the dtype of self.weight

        self.weight.assign(weight)  # assign the precomputed weights

        super().build(input_shape)  # call the parent build

    @staticmethod
    def _init_weight(n_pos: int, dim: int):
        """
        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
        the 2nd half of the vector. [dim // 2:]
        """
        # build a 2-D array where each row is the position-encoding vector for one position (same formula as the Transformer paper)
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        # allocate a zero array with the same shape as position_enc
        table = np.zeros_like(position_enc)
        # copy the sin values into the first half of each row
        table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
        # copy the cos values into the second half of each row
        table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
        # convert the table to a TensorFlow tensor
        table = tf.convert_to_tensor(table)
        # stop gradients from flowing through the table
        tf.stop_gradient(table)
        return table

    def call(
        self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: tf.Tensor | None = None
    ):
        """Input is expected to be of size [bsz x seqlen]."""
        # if no position_ids are provided, derive them from the input shape
        if position_ids is None:
            seq_len = input_shape[1]
            position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
        # gather the position embedding vectors for the requested positions from the precomputed table
        return tf.gather(self.weight, position_ids)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Pegasus
class TFPegasusAttention(keras.layers.Layer):
    """Multi-headed attention from "Attention Is All You Need"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim  # embedding dimension of the attention layer

        self.num_heads = num_heads  # number of attention heads
        self.dropout = keras.layers.Dropout(dropout)  # dropout layer
        self.head_dim = embed_dim // num_heads  # dimension of each attention head
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5  # scaling factor for numerical stability of the attention scores
        self.is_decoder = is_decoder  # whether this layer is used in a decoder

        self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")  # key projection
        self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")  # query projection
        self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")  # value projection
        self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")  # output projection

    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
        # reshape the tensor to (bsz, num_heads, seq_len, head_dim) for multi-head attention
        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))

    def call(
        self,
        hidden_states: tf.Tensor,
        key_value_states: tf.Tensor | None = None,
        past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
        attention_mask: tf.Tensor | None = None,
        layer_head_mask: tf.Tensor | None = None,
        training: Optional[bool] = False,
    ):
        # attention forward pass (body omitted in this walkthrough)
        ...

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build([None, None, self.embed_dim])  # 构建 K 矩阵的投影层
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build([None, None, self.embed_dim])  # 构建 Q 矩阵的投影层
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build([None, None, self.embed_dim])  # 构建 V 矩阵的投影层
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                self.out_proj.build([None, None, self.embed_dim])  # 构建输出投影层

class TFPegasusEncoderLayer(keras.layers.Layer):
    # constructor of the encoder layer
    def __init__(self, config: PegasusConfig, **kwargs):
        # call the parent constructor
        super().__init__(**kwargs)
        # embedding dimension, taken from the model size in the config
        self.embed_dim = config.d_model
        # self-attention layer with the configured number of heads and attention dropout
        self.self_attn = TFPegasusAttention(
            self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
        )
        # layer normalization for the self-attention block
        self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
        # dropout layer with the configured rate
        self.dropout = keras.layers.Dropout(config.dropout)
        # activation function selected from the config
        self.activation_fn = get_tf_activation(config.activation_function)
        # dropout applied after the activation
        self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
        # first dense layer, projecting to the encoder feed-forward dimension
        self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
        # second dense layer, projecting back to the embedding dimension
        self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
        # final layer normalization
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
        # keep a reference to the config for later use
        self.config = config

    # forward computation of the encoder layer
    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        layer_head_mask: tf.Tensor,
        training: Optional[bool] = False,
    ):
        """
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
            attention_mask (`tf.Tensor`): attention mask of size *(batch, 1, tgt_len, src_len)*,
                where padding elements are indicated by very large negative values.
            layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size *(encoder_attention_heads,)*
            training (`bool`, *optional*): whether the layer runs in training mode, defaults to False.
        """
        # keep the input for the residual connection
        residual = hidden_states
        # layer-normalize the input of the self-attention block (pre-norm)
        hidden_states = self.self_attn_layer_norm(hidden_states)
        # run self-attention and collect the attention weights
        hidden_states, self_attn_weights, _ = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

        # assert that self-attention did not change the shape of the query
        tf.debugging.assert_equal(
            shape_list(hidden_states),
            shape_list(residual),
            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
        )

        # apply dropout to the attention output
        hidden_states = self.dropout(hidden_states, training=training)
        # add the residual back
        hidden_states = residual + hidden_states

        # keep the input of the feed-forward block for the residual connection
        residual = hidden_states
        # layer-normalize before the feed-forward block
        hidden_states = self.final_layer_norm(hidden_states)
        # first dense layer followed by the activation function
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        # dropout after the activation
        hidden_states = self.activation_dropout(hidden_states, training=training)
        # second dense layer
        hidden_states = self.fc2(hidden_states)
        # dropout on the feed-forward output
        hidden_states = self.dropout(hidden_states, training=training)
        # add the residual back
        hidden_states = residual + hidden_states

        # return the layer output and the self-attention weights
        return hidden_states, self_attn_weights
    # build the sublayers; return immediately if the layer was already built
    def build(self, input_shape=None):
        if self.built:
            return

        # mark the layer as built
        self.built = True

        # build the self-attention layer, if present
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)

        # build the layer normalization of the self-attention block, if present
        if getattr(self, "self_attn_layer_norm", None) is not None:
            with tf.name_scope(self.self_attn_layer_norm.name):
                self.self_attn_layer_norm.build([None, None, self.embed_dim])

        # build the first dense layer, if present
        if getattr(self, "fc1", None) is not None:
            with tf.name_scope(self.fc1.name):
                self.fc1.build([None, None, self.embed_dim])

        # build the second dense layer, if present
        if getattr(self, "fc2", None) is not None:
            with tf.name_scope(self.fc2.name):
                self.fc2.build([None, None, self.config.encoder_ffn_dim])

        # build the final layer normalization, if present
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.embed_dim])
# 从transformers.models.mbart.modeling_tf_mbart.TFMBartDecoderLayer复制到TFPegasusDecoderLayer,用MBart->Pegasus进行替换
class TFPegasusDecoderLayer(keras.layers.Layer):
    # 初始化方法,接受PegasusConfig类型的config对象和其他关键字参数
    def __init__(self, config: PegasusConfig, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)
        # 设定嵌入维度为config.d_model
        self.embed_dim = config.d_model
        # self注意力层,使用TFPegasusAttention,设定嵌入维度、注意力头数、dropout率,用于解码器自注意力
        self.self_attn = TFPegasusAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            name="self_attn",
            is_decoder=True,
        )
        # dropout层,使用config.dropout作为dropout率
        self.dropout = keras.layers.Dropout(config.dropout)
        # 激活函数,根据配置获取相应的TensorFlow激活函数
        self.activation_fn = get_tf_activation(config.activation_function)
        # 激活函数的dropout层,使用config.activation_dropout作为dropout率
        self.activation_dropout = keras.layers.Dropout(config.activation_dropout)

        # self注意力层归一化,使用LayerNormalization,epsilon设定为1e-5
        self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
        # encoder注意力层,使用TFPegasusAttention,设定嵌入维度、注意力头数、dropout率,用于编码器-解码器注意力
        self.encoder_attn = TFPegasusAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            name="encoder_attn",
            is_decoder=True,
        )
        # encoder注意力层归一化,使用LayerNormalization,epsilon设定为1e-5
        self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
        # 全连接层1,使用Dense层,输出维度为config.decoder_ffn_dim
        self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
        # 全连接层2,使用Dense层,输出维度为self.embed_dim
        self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
        # 最终归一化层,使用LayerNormalization,epsilon设定为1e-5
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
        # 存储配置对象
        self.config = config

    # call方法,实现层的调用逻辑,接受多个输入张量和可选的训练标志
    def call(
        self,
        hidden_states: tf.Tensor,  # 隐藏状态张量,输入形状为(batch_size, seq_len, embed_dim)
        attention_mask: tf.Tensor | None = None,  # 注意力掩码张量,用于屏蔽无效位置
        encoder_hidden_states: tf.Tensor | None = None,  # 编码器隐藏状态张量,形状为(batch_size, enc_seq_len, embed_dim)
        encoder_attention_mask: tf.Tensor | None = None,  # 编码器注意力掩码张量,用于编码器-解码器注意力
        layer_head_mask: tf.Tensor | None = None,  # 层级头掩码张量,用于多头注意力机制
        cross_attn_layer_head_mask: tf.Tensor | None = None,  # 交叉注意力层级头掩码张量,用于编码器-解码器注意力的多头机制
        past_key_value: Tuple[tf.Tensor] | None = None,  # 过去的键值元组,用于实现增量解码
        training: Optional[bool] = False,  # 训练标志,控制是否启用训练模式
    ):
        # call 方法的具体前向传播逻辑(自注意力、交叉注意力、前馈三个 pre-LN 残差子层)在此省略

    # 根据输入形状构建模型,如果已经构建过则直接返回
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果存在 self_attn 属性,则构建 self attention 层
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)
        # 如果存在 self_attn_layer_norm 属性,则构建 self attention 层的 Layer Normalization
        if getattr(self, "self_attn_layer_norm", None) is not None:
            with tf.name_scope(self.self_attn_layer_norm.name):
                self.self_attn_layer_norm.build([None, None, self.embed_dim])
        # 如果存在 encoder_attn 属性,则构建 encoder attention 层
        if getattr(self, "encoder_attn", None) is not None:
            with tf.name_scope(self.encoder_attn.name):
                self.encoder_attn.build(None)
        # 如果存在 encoder_attn_layer_norm 属性,则构建 encoder attention 层的 Layer Normalization
        if getattr(self, "encoder_attn_layer_norm", None) is not None:
            with tf.name_scope(self.encoder_attn_layer_norm.name):
                self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
        # 如果存在 fc1 属性,则构建第一个全连接层
        if getattr(self, "fc1", None) is not None:
            with tf.name_scope(self.fc1.name):
                self.fc1.build([None, None, self.embed_dim])
        # 如果存在 fc2 属性,则构建第二个全连接层
        if getattr(self, "fc2", None) is not None:
            with tf.name_scope(self.fc2.name):
                self.fc2.build([None, None, self.config.decoder_ffn_dim])
        # 如果存在 final_layer_norm 属性,则构建最终的 Layer Normalization 层
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.embed_dim])
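
上文 TFPegasusDecoderLayer 的 call 方法体在此省略,其结构与编码器层类似,只是多了一个交叉注意力子层:依次是带因果掩码的自注意力、对编码器输出的交叉注意力、前馈网络,每个子层都采用 pre-LayerNorm + 残差。下面用 Keras 自带的 MultiHeadAttention 给出一个只保留计算顺序的最小示意;维度、头数等均为假设,且省略了因果掩码与 past_key_value 缓存。

import tensorflow as tf
from tensorflow import keras

# 最小示意:解码器层的三段 pre-LN 残差子层(非 TFPegasusAttention,维度均为假设)
d_model, heads = 8, 2
self_attn = keras.layers.MultiHeadAttention(num_heads=heads, key_dim=d_model // heads)
cross_attn = keras.layers.MultiHeadAttention(num_heads=heads, key_dim=d_model // heads)
ln1 = keras.layers.LayerNormalization(epsilon=1e-5)
ln2 = keras.layers.LayerNormalization(epsilon=1e-5)
ln3 = keras.layers.LayerNormalization(epsilon=1e-5)
ffn = keras.Sequential([keras.layers.Dense(32, activation="relu"), keras.layers.Dense(d_model)])

x = tf.random.normal([1, 4, d_model])    # 解码器端输入
enc = tf.random.normal([1, 6, d_model])  # 编码器输出

h = x + self_attn(ln1(x), ln1(x))        # 自注意力 + 残差(真实实现中还有因果掩码与缓存)
h = h + cross_attn(ln2(h), enc)          # 交叉注意力 + 残差
out = h + ffn(ln3(h))                    # 前馈 + 残差
print(out.shape)                         # (1, 4, 8)
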
class TFPegasusPreTrainedModel(TFPreTrainedModel):
    # 设置模型配置类为 PegasusConfig
    config_class = PegasusConfig
    # 模型参数名前缀为 "model"
    base_model_prefix = "model"
    ...     "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    ...     "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    ...     "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
    ... )


# 定义一段新闻文章内容,描述了 PG&E 因高风险天气和干燥条件而安排的停电计划,目的是减少火灾风险,影响约 80 万客户,预计持续到明天中午。
ARTICLE_TO_SUMMARIZE = (
    "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
)

# 使用分词器对文章进行预处理,设置最大长度为 1024,并返回 TensorFlow 格式的张量
inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="tf")

# 生成摘要
summary_ids = model.generate(inputs["input_ids"])  # 用编码得到的 input_ids 生成摘要的 token 序列
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
"""

PEGASUS_INPUTS_DOCSTRING = r"""
"""

# 使用 keras_serializable 装饰器将类标记为可序列化
@keras_serializable
class TFPegasusEncoder(keras.layers.Layer):
    # 使用 PegasusConfig 类作为配置类
    config_class = PegasusConfig

    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`TFPegasusEncoderLayer`].

    Args:
        config: PegasusConfig
    """

    def __init__(self, config: PegasusConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.dropout = keras.layers.Dropout(config.dropout)  # 使用指定的 dropout 率创建 Dropout 层
        self.layerdrop = config.encoder_layerdrop  # 从配置中获取层 dropout 率
        self.padding_idx = config.pad_token_id  # 获取配置中的填充索引
        self.max_source_positions = config.max_position_embeddings  # 获取配置中的最大位置嵌入
        self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0  # 计算嵌入比例因子

        self.embed_tokens = embed_tokens  # 设置嵌入 tokens
        self.embed_positions = TFPegasusSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            name="embed_positions",
        )  # 使用 sinusoidal 位置嵌入创建位置嵌入层

        self.layers = [TFPegasusEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]  # 创建多层编码器层
        self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")  # 创建层归一化层

    def get_embed_tokens(self):
        return self.embed_tokens  # 返回嵌入 tokens

    def set_embed_tokens(self, embed_tokens):
        self.embed_tokens = embed_tokens  # 设置嵌入 tokens

    # 使用 unpack_inputs 装饰器来展开输入参数
    @unpack_inputs
    def call(
        self,
        input_ids: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        attention_mask: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ):
        # call 方法的具体前向传播逻辑(嵌入缩放、位置编码、逐层编码、最终层归一化)在此省略

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embed_positions", None) is not None:
            with tf.name_scope(self.embed_positions.name):
                self.embed_positions.build(None)  # 构建位置嵌入层
        if getattr(self, "layer_norm", None) is not None:
            with tf.name_scope(self.layer_norm.name):
                self.layer_norm.build([None, None, self.config.d_model])  # 构建层归一化层
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)  # 构建每一层编码器层
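
编码器的前向过程(call 方法体在上文省略)大致是:词嵌入乘以 embed_scale(scale_embedding 为 True 时等于 sqrt(d_model)),加上正弦位置嵌入,dropout 后依次通过各编码器层,最后做 layer_norm。下面用 numpy 演示"缩放 + 正弦位置编码"这一步;位置编码的特征排列方式与库中实现未必完全一致,数值仅作示意。

import numpy as np

# 最小示意:正弦位置编码 + 嵌入缩放(排列细节与库中实现可能不同,仅展示思路)
d_model, n_pos = 8, 4
pos = np.arange(n_pos)[:, None]                           # (n_pos, 1)
i = np.arange(d_model)[None, :]                           # (1, d_model)
angle = pos / np.power(10000, 2 * (i // 2) / d_model)     # 角度矩阵
pe = np.where(i % 2 == 0, np.sin(angle), np.cos(angle))   # 偶数维 sin,奇数维 cos

embed_scale = np.sqrt(float(d_model))               # scale_embedding=True 时的缩放因子
token_embeddings = np.random.randn(n_pos, d_model)  # 假设的词嵌入
hidden = token_embeddings * embed_scale + pe        # 缩放后加位置编码
print(hidden.shape)                                 # (4, 8)
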


@keras_serializable
class TFPegasusDecoder(keras.layers.Layer):
    config_class = PegasusConfig

    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFPegasusDecoderLayer`]

    Args:
        config: PegasusConfig
        embed_tokens: output embedding
    """
    # 初始化方法,接收配置参数 config,嵌入词标记 embed_tokens 和其他关键字参数
    def __init__(self, config: PegasusConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)
        # 将配置参数 config 保存到实例变量中
        self.config = config
        # 设置填充索引为配置中的 pad_token_id
        self.padding_idx = config.pad_token_id
        # 将嵌入词标记 embed_tokens 保存到实例变量中
        self.embed_tokens = embed_tokens
        # 设置层丢弃率为配置中的 decoder_layerdrop
        self.layerdrop = config.decoder_layerdrop
        # 使用 TF 的 PegasusSinusoidalPositionalEmbedding 创建位置嵌入
        self.embed_positions = TFPegasusSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            name="embed_positions",
        )
        # 如果配置中指定了缩放嵌入,则使用 sqrt(d_model) 缩放因子;否则为 1.0
        self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
        # 创建多层解码器层列表,每层使用配置中的参数和命名
        self.layers = [TFPegasusDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
        # 创建层归一化层,设置 epsilon 为 1e-5
        self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")

        # 创建 Dropout 层,设置丢弃率为配置中的 dropout
        self.dropout = keras.layers.Dropout(config.dropout)

    # 获取嵌入词标记 embed_tokens
    def get_embed_tokens(self):
        return self.embed_tokens

    # 设置嵌入词标记 embed_tokens
    def set_embed_tokens(self, embed_tokens):
        self.embed_tokens = embed_tokens

    # 使用装饰器 unpack_inputs 对输入参数进行解包处理
    @unpack_inputs
    def call(
        self,
        input_ids: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        attention_mask: tf.Tensor | None = None,
        position_ids: tf.Tensor | None = None,
        encoder_hidden_states: tf.Tensor | None = None,
        encoder_attention_mask: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        cross_attn_head_mask: tf.Tensor | None = None,
        past_key_values: Tuple[Tuple[tf.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ):
        # call 方法的具体前向传播逻辑(因果掩码、交叉注意力、past_key_values 缓存等)在此省略

    # 构建模型,根据输入形状建立相应的层和嵌入
    def build(self, input_shape=None):
        # 如果模型已经构建过,则直接返回
        if self.built:
            return
        # 标记模型已经构建
        self.built = True
        # 如果存在 embed_positions 属性,则建立位置嵌入
        if getattr(self, "embed_positions", None) is not None:
            with tf.name_scope(self.embed_positions.name):
                self.embed_positions.build(None)
        # 如果存在 layer_norm 属性,则建立层归一化层
        if getattr(self, "layer_norm", None) is not None:
            with tf.name_scope(self.layer_norm.name):
                # 建立层归一化层,输入形状为 [None, None, self.config.d_model]
                self.layer_norm.build([None, None, self.config.d_model])
        # 如果存在 layers 属性,则逐层建立解码器层
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
# 使用装饰器标记这个类是可以被 Keras 序列化的
@keras_serializable
class TFPegasusMainLayer(keras.layers.Layer):
    # 指定配置类
    config_class = PegasusConfig

    # 初始化方法,接受 PegasusConfig 对象作为参数,并调用父类的初始化方法
    def __init__(self, config: PegasusConfig, **kwargs):
        super().__init__(**kwargs)

        # 将传入的配置对象赋值给实例变量 self.config
        self.config = config

        # 创建一个共享的 Embedding 层,用于模型的输入
        self.shared = keras.layers.Embedding(
            input_dim=config.vocab_size,
            output_dim=config.d_model,
            embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std),
            name="model.shared",
        )

        # 设置一个额外的属性,指定加载/存储权重时预期的命名范围
        self.shared.load_weight_prefix = "model.shared"

        # 创建 Pegasus 编码器和解码器对象
        self.encoder = TFPegasusEncoder(config, self.shared, name="encoder")
        self.decoder = TFPegasusDecoder(config, self.shared, name="decoder")

    # 获取输入 Embedding 层的方法
    def get_input_embeddings(self):
        return self.shared

    # 设置输入 Embedding 层的方法,同时更新编码器和解码器的 embed_tokens 属性
    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    # 使用装饰器标记的调用方法,接受多个输入参数和可选的控制参数
    @unpack_inputs
    def call(
        self,
        input_ids: tf.Tensor | None = None,
        attention_mask: tf.Tensor | None = None,
        decoder_input_ids: tf.Tensor | None = None,
        decoder_attention_mask: tf.Tensor | None = None,
        decoder_position_ids: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        decoder_head_mask: tf.Tensor | None = None,
        cross_attn_head_mask: tf.Tensor | None = None,
        encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
        past_key_values: Tuple[Tuple[tf.Tensor]] = None,
        inputs_embeds: tf.Tensor | None = None,
        decoder_inputs_embeds: tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
        **kwargs,
    ):
        # call 方法:若未提供 encoder_outputs 则先运行编码器,再运行解码器,最后按 return_dict 组装输出
        # 如果没有传入解码器的输入ID和嵌入向量,则不使用缓存
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            use_cache = False

        # 如果输出隐藏状态为None,则使用模型配置中的默认设置
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # 如果没有传入编码器输出,则调用编码器进行前向传播
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                training=training,
            )
        # 如果用户传入了一个元组形式的encoder_outputs,在return_dict=True时,将其封装为TFBaseModelOutput
        elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
            encoder_outputs = TFBaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )
        # 如果用户传入了TFBaseModelOutput形式的encoder_outputs,在return_dict=False时,将其封装为元组形式
        elif not return_dict and not isinstance(encoder_outputs, tuple):
            encoder_outputs = encoder_outputs.to_tuple()

        # 调用解码器进行前向传播
        decoder_outputs = self.decoder(
            decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # 如果return_dict为False,则返回解码器和编码器的输出作为元组形式
        if not return_dict:
            return decoder_outputs + encoder_outputs

        # 如果return_dict为True,则封装解码器和编码器的输出为TFSeq2SeqModelOutput
        return TFSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
    def build(self, input_shape=None):
        # 如果已经构建过模型,则直接返回,避免重复构建
        if self.built:
            return
        # 设置标志位,表示模型已经构建
        self.built = True
        
        # 为了确保共享/绑定的权重位于模型基本命名空间中
        # 在 tf.name_scope 末尾添加 "/"(不是开头!)将其放置在根命名空间而不是当前命名空间
        with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
            # 构建共享的权重
            self.shared.build(None)
        
        # 如果存在编码器,则在其命名空间下构建
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        
        # 如果存在解码器,则在其命名空间下构建
        if getattr(self, "decoder", None) is not None:
            with tf.name_scope(self.decoder.name):
                self.decoder.build(None)
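
TFPegasusMainLayer 用同一个 keras Embedding(self.shared)同时充当编码器与解码器的词嵌入,set_input_embeddings 会把新嵌入同时写回两端,因此权重只有一份。下面的独立小示意(类名、维度均为假设)演示"共享同一层"的效果。

import tensorflow as tf
from tensorflow import keras

# 最小示意:编码器与解码器共享同一个 Embedding 层(类名与维度均为假设)
shared = keras.layers.Embedding(input_dim=100, output_dim=8)

class TinyEncoder(keras.layers.Layer):
    def __init__(self, embed_tokens):
        super().__init__()
        self.embed_tokens = embed_tokens

    def call(self, ids):
        return self.embed_tokens(ids)

class TinyDecoder(TinyEncoder):
    pass

encoder, decoder = TinyEncoder(shared), TinyDecoder(shared)
ids = tf.constant([[1, 2, 3]])
# 两端查到的是同一份权重,因此结果完全相同,层对象也是同一个
print(bool(tf.reduce_all(encoder(ids) == decoder(ids))))   # True
print(encoder.embed_tokens is decoder.embed_tokens)        # True
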
# 使用装饰器添加模型文档字符串,描述该类是一个不带特定头部的原始 PEGASUS 模型
@add_start_docstrings(
    "The bare PEGASUS Model outputting raw hidden-states without any specific head on top.",
    PEGASUS_START_DOCSTRING,
)
# 定义 TFPegasusModel 类,继承自 TFPegasusPreTrainedModel 类
class TFPegasusModel(TFPegasusPreTrainedModel):
    
    # 初始化方法,接受 PegasusConfig 类型的配置对象及其他输入参数
    def __init__(self, config: PegasusConfig, *inputs, **kwargs):
        # 调用父类的初始化方法,传入配置及其他输入参数
        super().__init__(config, *inputs, **kwargs)
        
        # 创建 TFPegasusMainLayer 对象作为模型的主要层,使用给定的配置对象及名称
        self.model = TFPegasusMainLayer(config, name="model")

    # 返回模型的编码器部分
    def get_encoder(self):
        return self.model.encoder

    # 返回模型的解码器部分
    def get_decoder(self):
        return self.model.decoder

    # 使用装饰器添加模型正向传播的文档字符串,描述输入参数及其作用
    @unpack_inputs
    @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSeq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # 定义 call 方法,接受多个输入参数,并返回 TFSeq2SeqModelOutput 类型的输出
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
        encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
        **kwargs,
    ) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]:
        # 调用模型的主要层的 __call__ 方法,传递所有输入参数
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # 返回模型的输出
        return outputs

    # 从 transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output 复制的方法
    # 定义一个方法用于处理模型输出,接受一个输出对象作为参数
    def serving_output(self, output):
        # 如果配置中使用缓存,则提取输出对象中的过去键值对(past_key_values)的第二个元素,否则为 None
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        # 如果配置中输出隐藏状态(output_hidden_states),则将输出对象中的解码器隐藏状态转换为 TensorFlow 张量,否则为 None
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
        # 如果配置中输出注意力权重(output_attentions),则将输出对象中的解码器注意力权重转换为 TensorFlow 张量,否则为 None
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
        # 如果配置中输出注意力权重(output_attentions),则将输出对象中的交叉注意力权重转换为 TensorFlow 张量,否则为 None
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
        # 如果配置中输出隐藏状态(output_hidden_states),则将输出对象中的编码器隐藏状态转换为 TensorFlow 张量,否则为 None
        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
        # 如果配置中输出注意力权重(output_attentions),则将输出对象中的编码器注意力权重转换为 TensorFlow 张量,否则为 None
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None

        # 返回一个 TFSeq2SeqModelOutput 对象,包含了模型输出的相关信息
        return TFSeq2SeqModelOutput(
            last_hidden_state=output.last_hidden_state,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            cross_attentions=cross_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
        )

    # 定义一个方法用于构建模型
    def build(self, input_shape=None):
        # 如果已经构建过模型,则直接返回
        if self.built:
            return
        # 标记为已经构建
        self.built = True
        # 如果存在模型属性
        if getattr(self, "model", None) is not None:
            # 使用模型的名称作为命名空间,在该命名空间下构建模型,传入 None 作为输入形状
            with tf.name_scope(self.model.name):
                self.model.build(None)
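
下面是 TFPegasusModel 的一个最小前向示意:为了不依赖预训练权重,这里用一个缩小的随机初始化配置构建模型(所有超参数均为示意值,不是任何真实 checkpoint 的配置);真实使用时通常改为 TFPegasusModel.from_pretrained(...)。

import tensorflow as tf
from transformers import PegasusConfig, TFPegasusModel

# 最小示意:小配置 + 随机初始化,跑一次编码器-解码器前向
config = PegasusConfig(
    vocab_size=128, d_model=16, max_position_embeddings=64,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = TFPegasusModel(config)
input_ids = tf.constant([[5, 6, 7, 1]])       # 1 为 eos(默认配置下的假设)
decoder_input_ids = tf.constant([[0, 5, 6]])  # 0 为 decoder_start_token(默认配置下的假设)
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
print(outputs.last_hidden_state.shape)        # (1, 3, 16)
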
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(keras.layers.Layer):
    """
    Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis,
    so all weights have to be registered in a layer.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # 注:在序列化时,这个变量的名称不会被作用域化,即它不会以"outer_layer/inner_layer/.../name:0"的格式出现。
        # 而是直接是"name:0"。详情请参考:
        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
        # 添加一个可训练的偏置权重,用于模型层的操作
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        # 将偏置加到输入张量上,并返回结果
        return x + self.bias
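
BiasLayer 的作用只是把一个偏置向量包装成 Keras 层,让它成为可被 save_weights 序列化的权重。下面是一个不依赖本文件、按同样思路单独定义的小示意。

import tensorflow as tf
from tensorflow import keras

# 最小示意:把偏置包装成层,使其注册为可序列化的权重
class TinyBias(keras.layers.Layer):
    def __init__(self, shape, **kwargs):
        super().__init__(**kwargs)
        self.bias = self.add_weight(name="bias", shape=shape, initializer="zeros", trainable=False)

    def call(self, x):
        return x + self.bias

bias_layer = TinyBias(shape=[1, 5], name="final_logits_bias")
logits = tf.zeros([2, 3, 5])
print(bias_layer(logits).shape)              # (2, 3, 5),在词表维度上加偏置
print([w.name for w in bias_layer.weights])  # 偏置出现在层的权重列表中
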


@add_start_docstrings(
    "The PEGASUS Model with a language modeling head. Can be used for summarization.",
    PEGASUS_START_DOCSTRING,
)
class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss):
    _keys_to_ignore_on_load_unexpected = [
        r"model.encoder.embed_tokens.weight",
        r"model.decoder.embed_tokens.weight",
    ]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # 创建 PEGASUS 主模型层,并命名为"model"
        self.model = TFPegasusMainLayer(config, name="model")
        # 根据配置创建一个偏置层用于最终的对数概率偏置,这个层在 pytorch 中被注册为一个缓冲区,为保持一致性设置为不可训练
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )

    def get_decoder(self):
        # 获取模型的解码器
        return self.model.decoder

    def get_encoder(self):
        # 获取模型的编码器
        return self.model.encoder

    def get_output_embeddings(self):
        # 获取输出的嵌入层
        return self.get_input_embeddings()

    def set_output_embeddings(self, value):
        # 设置输出的嵌入层
        self.set_input_embeddings(value)

    def get_bias(self):
        # 返回模型的偏置信息
        return {"final_logits_bias": self.bias_layer.bias}

    def set_bias(self, value):
        # 替换现有的包含偏置的层以进行正确的(反)序列化
        vocab_size = value["final_logits_bias"].shape[-1]
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
        )
        # 分配给偏置层新的偏置值
        self.bias_layer.bias.assign(value["final_logits_bias"])

    @unpack_inputs
    @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
    # 定义一个方法 `call`,用于模型的前向传播和推理过程
    def call(
        # 输入序列的 token IDs,可以是 TensorFlow 的输入类型或者 None
        input_ids: TFModelInputType | None = None,
        # 注意力掩码,可以是 NumPy 数组、TensorFlow 张量或者 None
        attention_mask: np.ndarray | tf.Tensor | None = None,
        # 解码器的输入 token IDs,可以是 NumPy 数组、TensorFlow 张量或者 None
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        # 解码器的注意力掩码,可以是 NumPy 数组、TensorFlow 张量或者 None
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        # 解码器的位置 IDs,可以是 NumPy 数组、TensorFlow 张量或者 None
        decoder_position_ids: np.ndarray | tf.Tensor | None = None,
        # 头掩码用于屏蔽不同注意力头的特定位置,可以是 NumPy 数组、TensorFlow 张量或者 None
        head_mask: np.ndarray | tf.Tensor | None = None,
        # 解码器头部掩码,可以是 NumPy 数组、TensorFlow 张量或者 None
        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
        # 跨注意力头部掩码,可以是 NumPy 数组、TensorFlow 张量或者 None
        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
        # 编码器的输出,类型为 TFBaseModelOutput 或者 None
        encoder_outputs: Optional[TFBaseModelOutput] = None,
        # 用于存储过去的键值对的元组,元素为 NumPy 数组或 TensorFlow 张量的元组的元组
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        # 输入的嵌入向量,可以是 NumPy 数组、TensorFlow 张量或者 None
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        # 解码器的输入嵌入向量,可以是 NumPy 数组、TensorFlow 张量或者 None
        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
        # 是否使用缓存,布尔值或者 None
        use_cache: Optional[bool] = None,
        # 是否输出注意力权重,布尔值或者 None
        output_attentions: Optional[bool] = None,
        # 是否输出隐藏状态,布尔值或者 None
        output_hidden_states: Optional[bool] = None,
        # 是否返回字典格式的输出,布尔值或者 None
        return_dict: Optional[bool] = None,
        # 标签数据,可以是 NumPy 数组、TensorFlow 张量或者 None
        labels: np.ndarray | tf.Tensor | None = None,
        # 是否处于训练模式,布尔值,默认为 False
        training: bool = False,
    ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
        """
        labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
            Either TFSeq2SeqLMOutput or a tuple containing tf.Tensor outputs.

        """

        if labels is not None:
            # Convert padding tokens in labels to -100 to ignore them during loss computation
            labels = tf.where(
                labels == self.config.pad_token_id,
                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                labels,
            )
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                # Shift labels to the right for decoder input
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        # Forward pass through the model
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        
        # Calculate language modeling logits
        lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
        lm_logits = self.bias_layer(lm_logits)
        # Compute masked language modeling loss if labels are provided
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

        if not return_dict:
            # Prepare output tuple without returning a dictionary
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        # Return TFSeq2SeqLMOutput object if return_dict is True
        return TFSeq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,  # index 1 of d outputs
            decoder_hidden_states=outputs.decoder_hidden_states,  # index 2 of d outputs
            decoder_attentions=outputs.decoder_attentions,  # index 3 of d outputs
            cross_attentions=outputs.cross_attentions,  # index 4 of d outputs
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,  # index 0 of encoder outputs
            encoder_hidden_states=outputs.encoder_hidden_states,  # 1 of e out
            encoder_attentions=outputs.encoder_attentions,  # 2 of e out
        )

    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
    # 定义一个方法用于处理模型的输出,生成适合序列到序列模型的输出对象
    def serving_output(self, output):
        # 如果配置中启用了缓存,则从输出中提取过去键值对的第二个元素
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        # 如果配置中启用了输出隐藏状态,则将输出的解码器隐藏状态转换为张量
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
        # 如果配置中启用了输出注意力权重,则将输出的解码器注意力权重转换为张量
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
        # 如果配置中启用了输出注意力权重,则将输出的交叉注意力权重转换为张量
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
        # 如果配置中启用了输出隐藏状态,则将输出的编码器隐藏状态转换为张量
        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
        # 如果配置中启用了输出注意力权重,则将输出的编码器注意力权重转换为张量
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None

        # 返回一个 TFSeq2SeqLMOutput 对象,包含输出的各种信息
        return TFSeq2SeqLMOutput(
            logits=output.logits,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            cross_attentions=cross_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
        )

    # 从 transformers 库中 TF 版本的 BART 模型类中复制的方法,用于为生成准备输入
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # 如果使用了过去的键值对,截断 decoder_input_ids
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        # 如果存在 decoder_attention_mask,则在最后一个位置累加以获取 decoder_position_ids
        if decoder_attention_mask is not None:  # xla
            decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
        # 如果没有使用 XLA 且存在过去的键值对,则从 past_key_values 中获取 decoder_position_ids
        elif past_key_values is not None:  # no xla + past_key_values
            decoder_position_ids = past_key_values[0][0].shape[2]
        # 否则,生成一个范围为 decoder_input_ids.shape[1] 的 decoder_position_ids
        else:  # no xla + no past_key_values
            decoder_position_ids = tf.range(decoder_input_ids.shape[1])

        # 返回一个包含各种生成所需输入的字典
        return {
            "input_ids": None,  # encoder_outputs 已定义,不需要 input_ids
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_position_ids": decoder_position_ids,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # 修改此处以避免缓存(可能用于调试)
        }

    # 定义一个方法用于根据标签从配置中获取的开始和填充令牌 ID 调整 decoder_input_ids
    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
    # 定义模型建立函数,指定输入形状(可选)
    def build(self, input_shape=None):
        # 如果模型已经建立过,则直接返回
        if self.built:
            return
        # 将标志位设置为已建立
        self.built = True
        
        # 如果存在模型属性,则按模型名称创建命名空间,并建立模型
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)
        
        # 如果存在偏置层属性,则按偏置层名称创建命名空间,并建立偏置层
        if getattr(self, "bias_layer", None) is not None:
            with tf.name_scope(self.bias_layer.name):
                self.bias_layer.build(None)
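
上文 call 方法在传入 labels 时做了两件事:把 labels 中等于 pad_token_id 的位置替换成 -100(损失中忽略),再用 shift_tokens_right 把 labels 右移一位、开头插入 decoder_start_token_id 作为 decoder_input_ids。下面用具体数字复现这两步(token id 与 pad/decoder_start 的取值均为假设,并非真实分词结果)。

import tensorflow as tf

pad_token_id, decoder_start_token_id = 0, 0   # 假设值(Pegasus 默认配置中恰好也是 0)
labels = tf.constant([[45, 12, 7, 1, 0, 0]])  # 1 为 eos,末尾两个 0 为 padding(假设)

# 第一步:padding 位置替换为 -100,使其不参与损失计算
masked_labels = tf.where(labels == pad_token_id, tf.fill(tf.shape(labels), -100), labels)
print(masked_labels.numpy())                  # [[  45   12    7    1 -100 -100]]

# 第二步:右移一位并在开头插入 decoder_start_token_id;-100 再替换回 pad_token_id
decoder_input_ids = tf.concat(
    [tf.fill([tf.shape(labels)[0], 1], decoder_start_token_id), masked_labels[:, :-1]], axis=-1
)
decoder_input_ids = tf.where(
    decoder_input_ids == -100, tf.fill(tf.shape(decoder_input_ids), pad_token_id), decoder_input_ids
)
print(decoder_input_ids.numpy())              # [[ 0 45 12  7  1  0]]
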

.\models\pegasus\tokenization_pegasus.py

# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入标准库 os
import os
# 从 shutil 库中导入 copyfile 函数
from shutil import copyfile
# 从 typing 库中导入类型提示 Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

# 导入 sentencepiece 库,作为 spm 的别名
import sentencepiece as spm

# 从 tokenization_utils 中导入 AddedToken, PreTrainedTokenizer 类
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
# 从 utils 中导入 logging 模块
from ...utils import logging

# 定义 SPIECE_UNDERLINE 常量
SPIECE_UNDERLINE = "▁"

# 定义 VOCAB_FILES_NAMES 字典常量
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

# 定义 PRETRAINED_VOCAB_FILES_MAP 字典常量
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {"google/pegasus-xsum": "https://huggingface.co/google/pegasus-xsum/resolve/main/spiece.model"}
}

# 定义 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 字典常量
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/pegasus-xsum": 512,
}

# 获取 logger 对象
logger = logging.get_logger(__name__)


# TODO ArthurZ refactor this to only use the added_tokens_encoder
# 定义 PegasusTokenizer 类,继承自 PreTrainedTokenizer
class PegasusTokenizer(PreTrainedTokenizer):
    r"""
    Construct a PEGASUS tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    """

    # 设置类变量 vocab_files_names
    vocab_files_names = VOCAB_FILES_NAMES
    # 设置类变量 pretrained_vocab_files_map
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # 设置类变量 max_model_input_sizes
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # 设置类变量 model_input_names
    model_input_names = ["input_ids", "attention_mask"]

    # 初始化方法
    def __init__(
        self,
        vocab_file,
        pad_token="<pad>",
        eos_token="</s>",
        unk_token="<unk>",
        mask_token="<mask_2>",
        mask_token_sent="<mask_1>",
        additional_special_tokens=None,
        offset=103,  # entries 2 - 104 are only used for pretraining
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        # 调用父类的初始化方法
        super().__init__(
            # 传递给父类的参数
            pad_token=pad_token,
            eos_token=eos_token,
            unk_token=unk_token,
            mask_token=mask_token,
            mask_token_sent=mask_token_sent,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        # 设置实例变量 offset
        self.offset = offset  # entries 2 - 104 are only used for pretraining
        # 设置实例变量 sp_model_kwargs
        self.sp_model_kwargs = sp_model_kwargs if sp_model_kwargs is not None else {}
        # 保存词汇表文件路径
        self.vocab_file = vocab_file
        # 创建 SentencePieceProcessor 并加载词汇表文件
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    # 定义 vocab_size 属性方法,返回词汇表大小
    @property
    def vocab_size(self) -> int:
        return len(self.sp_model) + self.offset

    # 定义 get_vocab 方法,返回词汇表的字典
    def get_vocab(self) -> Dict[str, int]:
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    # 定义 __getstate__ 方法,返回对象状态的字典表示,忽略 sp_model
    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    # 定义 __setstate__ 方法,设置对象状态,重新加载 sp_model 和 vocab_file
    def __setstate__(self, d):
        self.__dict__ = d

        # for backward compatibility
        # 向后兼容性处理
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        # 加载 SentencePieceProcessor 到 self.sp_model
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        # 加载词汇表文件到 self.vocab_file
        self.sp_model.Load(self.vocab_file)
    def _tokenize(self, text: str) -> List[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
        # 使用 sentencepiece 模型对输入文本进行编码,返回字符串的列表(标记)
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token: str) -> int:
        """Converts a token (str) to an id using the vocab."""
        # 使用 sentencepiece 模型将 token(字符串)转换为对应的 id
        sp_id = self.sp_model.piece_to_id(token)
        return sp_id + self.offset

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) to a token (str) using the vocab."""
        # 如果 index 小于 offset,则直接使用 sentencepiece 模型将 index 转换为 token
        if index < self.offset:
            return self.sp_model.IdToPiece(index)
        # 否则,减去 offset 后再转换为 token
        token = self.sp_model.IdToPiece(index - self.offset)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        for token in tokens:
            # 确保特殊的 token 不会被 sentencepiece 模型解码
            if token in self.all_special_tokens:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string.strip()

    def num_special_tokens_to_add(self, pair=False):
        """Just EOS"""
        # 返回要添加的特殊 token 数量(这里只有 EOS)
        return 1

    def _special_token_mask(self, seq):
        all_special_ids = set(self.all_special_ids)  # 一次性创建所有特殊 token 的集合
        all_special_ids.remove(self.unk_token_id)  # <unk> 只有在某些情况下是特殊的

        # 创建一个 mask 列表,标记哪些 token 是特殊 token
        return [1 if x in all_special_ids else 0 for x in seq]

    def get_special_tokens_mask(
        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """Get list where entries are [1] if a token is [eos] or [pad] else 0."""
        if already_has_special_tokens:
            # 如果已经有特殊 token,则直接调用 _special_token_mask 处理 token_ids_0
            return self._special_token_mask(token_ids_0)
        elif token_ids_1 is None:
            # 如果没有第二个 token 列表,对 token_ids_0 进行处理并添加一个额外的特殊 token
            return self._special_token_mask(token_ids_0) + [1]
        else:
            # 否则,合并 token_ids_0 和 token_ids_1 后处理,并添加一个额外的特殊 token
            return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. A PEGASUS sequence has the following format, where `X` represents the sequence:

        - single sequence: `X </s>`
        - pair of sequences: `A B </s>` (not intended use)

        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
        separator.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        # 如果只有一个序列,直接在末尾添加结束符号的特殊token
        if token_ids_1 is None:
            return token_ids_0 + [self.eos_token_id]
        
        # 如果有两个序列,将两个序列连接起来,并在最后添加结束符号的特殊token
        # 尽管不推荐使用两个序列,但为了API的一致性保留了对两个序列的处理逻辑
        return token_ids_0 + token_ids_1 + [self.eos_token_id]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary to a directory. If the directory does not exist, an error is logged.

        Args:
            save_directory (str):
                Directory path where the vocabulary will be saved.
            filename_prefix (str, *optional*):
                Optional prefix for the saved vocabulary file.

        Returns:
            `Tuple[str]`: Tuple containing the path to the saved vocabulary file.
        """
        # 检查保存目录是否存在,如果不存在则记录错误并返回
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        
        # 拼接输出的词汇表文件路径
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        
        # 如果当前词汇表文件路径与目标文件路径不同且当前词汇表文件存在,则复制当前文件到目标路径
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        # 如果当前词汇表文件不存在,则将序列化后的模型写入目标路径
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)
        
        # 返回保存的词汇表文件路径
        return (out_vocab_file,)
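
PegasusTokenizer 在 SentencePiece 的 id 之上整体加了 offset(默认 103),把低位 id 留给 pad/eos/mask 等特殊 token(源码注释提到 2-104 号仅用于预训练);_convert_token_to_id 加偏移,_convert_id_to_token 在 index 小于 offset 时直接查表、否则先减去偏移。下面用字典代替 SentencePiece,单独复现这对映射(所有 id 数值均为假设)。

# 最小示意:用字典代替 sp_model,演示 offset 偏移的 id 映射(数值均为假设)
offset = 103
sp_piece_to_id = {"▁hello": 42, "<unk>": 0}                  # 模拟 sp_model.piece_to_id
sp_id_to_piece = {v: k for k, v in sp_piece_to_id.items()}   # 模拟 sp_model.IdToPiece

def convert_token_to_id(token):
    return sp_piece_to_id.get(token, sp_piece_to_id["<unk>"]) + offset

def convert_id_to_token(index):
    if index < offset:                        # 低位 id 保留给特殊 token
        return sp_id_to_piece.get(index, "<unk>")
    return sp_id_to_piece.get(index - offset, "<unk>")

print(convert_token_to_id("▁hello"))   # 145 = 42 + 103
print(convert_id_to_token(145))        # ▁hello
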

.\models\pegasus\tokenization_pegasus_fast.py

# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model PEGASUS."""

# 导入标准库和第三方库
import os
from shutil import copyfile
from typing import List, Optional, Tuple

# 导入所需的tokenization工具和logging
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging

# 如果sentencepiece可用,则导入对应的PEGASUS tokenizer类,否则设为None
if is_sentencepiece_available():
    from .tokenization_pegasus import PegasusTokenizer
else:
    PegasusTokenizer = None

# 获取当前模块的logger
logger = logging.get_logger(__name__)

# 定义一个常量,用于表示词块的前缀
SPIECE_UNDERLINE = "▁"

# 定义词汇文件的名称映射
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}

# 预训练模型的词汇文件映射
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {"google/pegasus-xsum": "https://huggingface.co/google/pegasus-xsum/resolve/main/spiece.model"},
    "tokenizer_file": {
        "google/pegasus-xsum": "https://huggingface.co/google/pegasus-xsum/resolve/main/tokenizer.json"
    },
}

# 预训练模型的位置编码大小映射
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/pegasus-xsum": 512,
}


class PegasusTokenizerFast(PreTrainedTokenizerFast):
    r"""
    Construct a "fast" PEGASUS tokenizer (backed by HuggingFace's *tokenizers* library). Based on
    [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        mask_token (`str`, *optional*, defaults to `"<mask_2>"`):
            The token used for masking single token values. This is the token used when training this model with masked
            language modeling (MLM). This is the token that the PEGASUS encoder will try to predict during pretraining.
            It corresponds to *[MASK2]* in [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive
            Summarization](https://arxiv.org/pdf/1912.08777.pdf).
        mask_token_sent (`str`, *optional*, defaults to `"<mask_1>"`):
            The token used for masking whole target sentences. This is the token used when training this model with gap
            sentences generation (GSG). This is the sentence that the PEGASUS decoder will try to predict during
            pretraining. It corresponds to *[MASK1]* in [PEGASUS: Pre-training with Extracted Gap-sentences for
            Abstractive Summarization](https://arxiv.org/pdf/1912.08777.pdf).
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer. If no additional_special_tokens are provided <mask_2> and
            <unk_2, ..., unk_102> are used as additional special tokens corresponding to the [original PEGASUS
            tokenizer](https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66)
            that uses the tokens 2 - 104 only for pretraining
    """

    vocab_files_names = VOCAB_FILES_NAMES
    # 存储词汇文件名的常量字典

    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # 获取存储了预训练模型词汇文件映射的常量字典

    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # 获取存储了预训练模型输入最大尺寸的常量字典

    slow_tokenizer_class = PegasusTokenizer
    # 获取 PegasusTokenizer 类,用于慢速模式的分词器

    model_input_names = ["input_ids", "attention_mask"]
    # 定义了模型输入的名称列表,包括 input_ids 和 attention_mask
    # 初始化函数,用于创建一个新的对象
    def __init__(
        self,
        vocab_file=None,  # 词汇表文件路径,默认为None
        tokenizer_file=None,  # 分词器文件路径,默认为None
        pad_token="<pad>",  # 填充标记,默认为"<pad>"
        eos_token="</s>",  # 结束标记,默认为"</s>"
        unk_token="<unk>",  # 未知标记,默认为"<unk>"
        mask_token="<mask_2>",  # 掩码标记,默认为"<mask_2>"
        mask_token_sent="<mask_1>",  # 用于句子级别掩码的标记,默认为"<mask_1>"
        additional_special_tokens=None,  # 额外的特殊标记列表,默认为None
        offset=103,  # 前期训练中使用的偏移量,默认为103,索引2到104用于预训练
        **kwargs,  # 其他关键字参数
    ):
        self.offset = offset  # 初始化对象的偏移量属性为给定的offset值

        if additional_special_tokens is not None:
            if not isinstance(additional_special_tokens, list):
                # 如果额外特殊标记不是列表类型,则引发类型错误异常
                raise TypeError(
                    f"additional_special_tokens should be of type {type(list)}, but is"
                    f" {type(additional_special_tokens)}"
                )

            # 如果mask_token_sent不在additional_special_tokens中且不为None,则将其添加到额外特殊标记列表中
            additional_special_tokens_extended = (
                ([mask_token_sent] + additional_special_tokens)
                if mask_token_sent not in additional_special_tokens and mask_token_sent is not None
                else additional_special_tokens
            )

            # 填充额外特殊标记列表直到达到offset - 1的长度,并以"<unk_x>"命名
            additional_special_tokens_extended += [
                f"<unk_{i}>" for i in range(len(additional_special_tokens_extended), self.offset - 1)
            ]

            # 如果额外特殊标记列表中存在重复的标记,则引发值错误异常
            if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended):
                raise ValueError(
                    "Please make sure that the provided additional_special_tokens do not contain an incorrectly"
                    f" shifted list of <unk_x> tokens. Found {additional_special_tokens_extended}."
                )
            additional_special_tokens = additional_special_tokens_extended
        else:
            # 如果额外特殊标记为None,则创建一个新的额外特殊标记列表,包含mask_token_sent和"<unk_x>"标记
            additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else []
            additional_special_tokens += [f"<unk_{i}>" for i in range(2, self.offset)]

        # 如果from_slow参数未提供,则从kwargs中获取或初始化为None
        from_slow = kwargs.pop("from_slow", None)
        # 如果pad_token、eos_token、unk_token中有一个与默认值不同,则设置from_slow为True
        from_slow = from_slow or str(pad_token) != "<pad>" or str(eos_token) != "</s>" or str(unk_token) != "<unk>"

        # 从kwargs中移除added_tokens_decoder键的值
        kwargs.pop("added_tokens_decoder", {})

        # 调用父类的初始化方法,传递所需的参数和关键字参数
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            pad_token=pad_token,
            eos_token=eos_token,
            unk_token=unk_token,
            mask_token=mask_token,
            mask_token_sent=mask_token_sent,
            offset=offset,
            additional_special_tokens=additional_special_tokens,
            from_slow=from_slow,
            **kwargs,
        )
        self.vocab_file = vocab_file  # 设置对象的词汇表文件属性为给定的vocab_file路径

    @property
    def can_save_slow_tokenizer(self) -> bool:
        # 检查对象是否能够保存慢速分词器,前提是vocab_file文件存在
        return os.path.isfile(self.vocab_file) if self.vocab_file else False
    def _special_token_mask(self, seq):
        # 将所有特殊标记的 ID 放入集合中,并移除未知标记的 ID
        all_special_ids = set(self.all_special_ids)  # 一次性调用,而不是在列表推导式中调用
        all_special_ids.remove(self.unk_token_id)  # <unk> 只有在某些情况下才是特殊的

        # 检查特殊标记的数量是否正确
        if all_special_ids != set(range(len(self.additional_special_tokens) + 3)):
            raise ValueError(
                "There should be 3 special tokens: mask_token, pad_token, and eos_token +"
                f" {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}"
            )

        # 返回序列中每个元素是否为特殊标记的掩码列表
        return [1 if x in all_special_ids else 0 for x in seq]

    def get_special_tokens_mask(
        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        获取特殊标记掩码列表,如果 token 是 [eos] 或 [pad] 则为 [1],否则为 [0]。
        """
        if already_has_special_tokens:
            return self._special_token_mask(token_ids_0)
        elif token_ids_1 is None:
            return self._special_token_mask(token_ids_0) + [1]
        else:
            return self._special_token_mask(token_ids_0 + token_ids_1) + [1]

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
        """
        根据输入序列构建模型输入,末尾添加 eos 标记,不添加 bos 标记到开头。

        - 单个序列: `X </s>`
        - 序列对: `A B </s>`(不是预期的用法)

        Args:
            token_ids_0 (`List[int]`):
                要添加特殊标记的 ID 列表
            token_ids_1 (`List[int]`, *可选*):
                第二个序列 ID 列表(如果是序列对)

        Returns:
            `List[int]`: 带有适当特殊标记的输入 ID 列表。
        """
        if token_ids_1 is None:
            return token_ids_0 + [self.eos_token_id]
        # 虽然不期望处理序列对,但为了 API 一致性保留了对序列对的逻辑
        return token_ids_0 + token_ids_1 + [self.eos_token_id]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        保存词汇表到指定目录下,如果无法保存,则引发异常。

        Args:
            save_directory (str): 保存词汇表的目录路径
            filename_prefix (str, *可选*): 文件名前缀

        Returns:
            Tuple[str]: 保存的词汇表文件路径
        """
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        
        # 拼接输出的词汇表文件路径
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # 如果当前词汇表文件路径与输出路径不同,则复制当前词汇表文件到输出路径
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)
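
当调用方没有传 additional_special_tokens 时,上面的 __init__ 会把 mask_token_sent 与 <unk_2> ... <unk_{offset-1}> 一起作为额外特殊 token,以对齐原始 PEGASUS 词表中仅用于预训练的保留项。下面单独复现这段默认列表的构造逻辑(offset 取默认值 103)。

# 复现 additional_special_tokens 的默认构造(offset=103,mask_token_sent="<mask_1>")
offset = 103
mask_token_sent = "<mask_1>"

additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else []
additional_special_tokens += [f"<unk_{i}>" for i in range(2, offset)]

print(len(additional_special_tokens))                                 # 102 = 1 个 <mask_1> + 101 个 <unk_x>
print(additional_special_tokens[:3], additional_special_tokens[-1])   # ['<mask_1>', '<unk_2>', '<unk_3>'] <unk_102>
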

.\models\pegasus\__init__.py

# 导入必要的模块和函数,包括类型检查相关内容
from typing import TYPE_CHECKING

# 从特定路径导入必要的工具和函数
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_sentencepiece_available,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
)

# 定义一个字典,描述导入的结构
_import_structure = {"configuration_pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig"]}

# 尝试导入 PegasusTokenizer,如果 sentencepiece 不可用则引发异常
try:
    if not is_sentencepiece_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_pegasus"] = ["PegasusTokenizer"]

# 尝试导入 PegasusTokenizerFast,如果 tokenizers 不可用则引发异常
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_pegasus_fast"] = ["PegasusTokenizerFast"]

# 尝试导入 Pegasus 相关的模型,如果 torch 不可用则引发异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_pegasus"] = [
        "PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST",
        "PegasusForCausalLM",
        "PegasusForConditionalGeneration",
        "PegasusModel",
        "PegasusPreTrainedModel",
    ]

# 尝试导入 TF Pegasus 相关的模型,如果 TensorFlow 不可用则引发异常
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_pegasus"] = [
        "TFPegasusForConditionalGeneration",
        "TFPegasusModel",
        "TFPegasusPreTrainedModel",
    ]

# 尝试导入 Flax Pegasus 相关的模型,如果 Flax 不可用则引发异常
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_pegasus"] = [
        "FlaxPegasusForConditionalGeneration",
        "FlaxPegasusModel",
        "FlaxPegasusPreTrainedModel",
    ]

# 如果是类型检查阶段,则进一步导入相应的配置和工具
if TYPE_CHECKING:
    from .configuration_pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig

    # 在类型检查阶段,如果 sentencepiece 可用,则导入 PegasusTokenizer
    try:
        if not is_sentencepiece_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_pegasus import PegasusTokenizer

    # 在类型检查阶段,如果 tokenizers 可用,则导入 PegasusTokenizerFast
    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_pegasus_fast import PegasusTokenizerFast
    try:
        # 检查是否安装了 PyTorch 库,如果未安装则引发 OptionalDependencyNotAvailable 异常
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 如果异常被引发,不做任何处理,继续执行下面的代码
        pass
    else:
        # 如果未引发异常,则从相应模块导入必要的类和变量
        from .modeling_pegasus import (
            PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST,
            PegasusForCausalLM,
            PegasusForConditionalGeneration,
            PegasusModel,
            PegasusPreTrainedModel,
        )

    try:
        # 检查是否安装了 TensorFlow 库,如果未安装则引发 OptionalDependencyNotAvailable 异常
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 如果异常被引发,不做任何处理,继续执行下面的代码
        pass
    else:
        # 如果未引发异常,则从相应模块导入必要的类和变量
        from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel

    try:
        # 检查是否安装了 Flax 库,如果未安装则引发 OptionalDependencyNotAvailable 异常
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 如果异常被引发,不做任何处理,继续执行下面的代码
        pass
    else:
        # 如果未引发异常,则从相应模块导入必要的类和变量
        from .modeling_flax_pegasus import (
            FlaxPegasusForConditionalGeneration,
            FlaxPegasusModel,
            FlaxPegasusPreTrainedModel,
        )
else:
    # 如果不是以上情况,则执行以下代码
    import sys
    # 导入系统模块 sys

    # 将当前模块注册到 sys.modules 中,使用 _LazyModule 封装
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
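
# ---- 注解者补充的示意性示例(非原文件内容)----
# _LazyModule 重写了 __getattr__,只有在访问具体属性时才真正导入对应子模块:
import importlib

pegasus_module = importlib.import_module("transformers.models.pegasus")
config_cls = pegasus_module.PegasusConfig  # 此时才触发 configuration_pegasus 的导入
print(config_cls.model_type)               # "pegasus"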

.\models\pegasus_x\configuration_pegasus_x.py

# coding=utf-8
# Copyright 2022, Google and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PEGASUS-X model configuration"""

# 导入配置基类 PretrainedConfig 和 logging 工具
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# PEGASUS-X 预训练模型配置文件映射表,提供预训练模型的配置文件 URL
PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/pegasus-x-base": "https://huggingface.co/google/pegasus-x-base/resolve/main/config.json",
    "google/pegasus-x-large": "https://huggingface.co/google/pegasus-x-large/resolve/main/config.json",
    # 查看所有 PEGASUS-X 模型的列表,访问 https://huggingface.co/models?filter=pegasus-x
}

# PegasusXConfig 类,用于存储 PEGASUS-X 模型的配置信息,继承自 PretrainedConfig
class PegasusXConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PegasusXModel`]. It is used to instantiate a
    PEGASUS-X model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the PEGASUS-X
    [google/pegasus-x-large](https://huggingface.co/google/pegasus-x-large) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Example:

    ```
    >>> from transformers import PegasusXConfig, PegasusXModel

    >>> # Initializing a PEGASUS google/pegasus-x-large style configuration
    >>> configuration = PegasusXConfig()

    >>> # Initializing a model (with random weights) from the google/pegasus-x-large style configuration
    >>> model = PegasusXModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    # 模型类型标识为 "pegasus_x"
    model_type = "pegasus_x"
    # 推断时要忽略的键列表,这里忽略 "past_key_values"
    keys_to_ignore_at_inference = ["past_key_values"]
    # 属性映射表,将一些通用名称映射到 PEGASUS-X 模型配置中的特定名称
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
    # 初始化函数,用于初始化一个 Transformer 模型的参数和配置
    def __init__(
        self,
        vocab_size=96103,  # 词汇表大小,默认为96103
        max_position_embeddings=16384,  # 最大位置嵌入数,默认为16384
        encoder_layers=16,  # 编码器层数,默认为16层
        encoder_ffn_dim=4096,  # 编码器中 Feed-Forward 层的维度,默认为4096
        encoder_attention_heads=16,  # 编码器中注意力头的数量,默认为16
        decoder_layers=16,  # 解码器层数,默认为16层
        decoder_ffn_dim=4096,  # 解码器中 Feed-Forward 层的维度,默认为4096
        decoder_attention_heads=16,  # 解码器中注意力头的数量,默认为16
        encoder_layerdrop=0.0,  # 编码器层间的随机删除率,默认为0.0
        decoder_layerdrop=0.0,  # 解码器层间的随机删除率,默认为0.0
        use_cache=True,  # 是否使用缓存,默认为True
        is_encoder_decoder=True,  # 是否是编码解码结构,默认为True
        activation_function="gelu",  # 激活函数类型,默认为GELU
        d_model=1024,  # 模型的维度,默认为1024
        dropout=0.1,  # 总体的Dropout概率,默认为0.1
        attention_dropout=0.0,  # 注意力Dropout概率,默认为0.0
        activation_dropout=0.0,  # 激活函数Dropout概率,默认为0.0
        init_std=0.02,  # 初始化参数的标准差,默认为0.02
        decoder_start_token_id=0,  # 解码器的起始标记ID,默认为0
        scale_embedding=True,  # 是否对嵌入进行缩放,默认为True;如果为True,缩放因子为sqrt(d_model)
        pad_token_id=0,  # 填充标记的ID,默认为0
        eos_token_id=1,  # 结束标记的ID,默认为1
        forced_eos_token_id=1,  # 强制结束标记的ID,默认为1
        num_global_tokens=32,  # 全局标记的数量,默认为32
        block_size=512,  # 块大小,默认为512
        stagger_local_blocks=True,  # 是否交错本地块,默认为True
        **kwargs,  # 其他关键字参数,用于传递给父类的初始化函数
    ):
        self.vocab_size = vocab_size  # 设置词汇表大小
        self.max_position_embeddings = max_position_embeddings  # 设置最大位置嵌入数
        self.d_model = d_model  # 设置模型的维度
        self.encoder_ffn_dim = encoder_ffn_dim  # 设置编码器中 Feed-Forward 层的维度
        self.encoder_layers = encoder_layers  # 设置编码器的层数
        self.encoder_attention_heads = encoder_attention_heads  # 设置编码器中注意力头的数量
        self.decoder_ffn_dim = decoder_ffn_dim  # 设置解码器中 Feed-Forward 层的维度
        self.decoder_layers = decoder_layers  # 设置解码器的层数
        self.decoder_attention_heads = decoder_attention_heads  # 设置解码器中注意力头的数量
        self.dropout = dropout  # 设置总体的Dropout概率
        self.attention_dropout = attention_dropout  # 设置注意力Dropout概率
        self.activation_dropout = activation_dropout  # 设置激活函数Dropout概率
        self.activation_function = activation_function  # 设置激活函数类型
        self.init_std = init_std  # 设置初始化参数的标准差
        self.encoder_layerdrop = encoder_layerdrop  # 设置编码器层间的随机删除率
        self.decoder_layerdrop = decoder_layerdrop  # 设置解码器层间的随机删除率
        self.use_cache = use_cache  # 设置是否使用缓存
        self.num_hidden_layers = encoder_layers  # 设置隐藏层的数量(与编码器层数相同)
        self.scale_embedding = scale_embedding  # 设置是否缩放嵌入

        self.num_global_tokens = num_global_tokens  # 设置全局标记的数量
        self.block_size = block_size  # 设置块大小
        self.stagger_local_blocks = stagger_local_blocks  # 设置是否交错本地块

        super().__init__(  # 调用父类的初始化函数
            pad_token_id=pad_token_id,  # 传递填充标记的ID
            eos_token_id=eos_token_id,  # 传递结束标记的ID
            is_encoder_decoder=is_encoder_decoder,  # 传递是否是编码解码结构
            decoder_start_token_id=decoder_start_token_id,  # 传递解码器起始标记ID
            forced_eos_token_id=forced_eos_token_id,  # 传递强制结束标记的ID
            **kwargs,  # 传递其他关键字参数
        )

    @property
    def num_attention_heads(self) -> int:
        return self.encoder_attention_heads  # 返回编码器中注意力头的数量

    @property
    def hidden_size(self) -> int:
        return self.d_model  # 返回模型的维度
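
# ---- 注解者补充的示意性示例(非原文件内容),直接使用上面定义的 PegasusXConfig ----
# attribute_map 与属性方法使通用名称映射到 PEGASUS-X 特有字段,例如 hidden_size -> d_model:
config = PegasusXConfig(block_size=256, num_global_tokens=64)
print(config.hidden_size == config.d_model)                           # True
print(config.num_attention_heads == config.encoder_attention_heads)  # True
print(config.block_size, config.num_global_tokens)                   # 256 64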

.\models\pegasus_x\modeling_pegasus_x.py

# coding=utf-8
# 版权所有 2022 年,Google 和 The HuggingFace Inc. 团队。保留所有权利。
#
# 根据 Apache 许可证 2.0 版本(“许可证”)许可;
# 除非符合许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件是基于“原样”分发的,
# 没有任何明示或暗示的保证或条件。
# 有关详细信息,请参阅许可证。
""" PyTorch PEGASUS-X 模型。"""

import dataclasses  # 导入 dataclasses 模块,用于支持数据类
import math  # 导入 math 模块,提供数学函数
from typing import Optional, Tuple, Union  # 导入类型提示

import numpy as np  # 导入 numpy 库
import torch  # 导入 PyTorch 库
import torch.utils.checkpoint  # 导入 PyTorch 的 checkpoint 模块
from torch import nn  # 从 torch 中导入 nn 模块
from torch.nn import CrossEntropyLoss  # 导入交叉熵损失函数

from ...activations import ACT2FN  # 导入激活函数映射
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask  # 导入注意力掩码处理工具函数
from ...modeling_outputs import (  # 导入模型输出类
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel  # 导入预训练模型基类
from ...utils import (  # 导入工具函数
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_pegasus_x import PegasusXConfig  # 导入 PEGASUS-X 的配置文件

logger = logging.get_logger(__name__)  # 获取 logger 实例

_CHECKPOINT_FOR_DOC = "google/pegasus-x-base"  # 预训练模型的检查点名称
_CONFIG_FOR_DOC = "PegasusXConfig"  # 用于文档的配置文件名称

PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST = [  # 支持的 PEGASUS-X 预训练模型列表
    "google/pegasus-x-base",
    "google/pegasus-x-large",
    # 查看所有 PEGASUS 模型,请访问 https://huggingface.co/models?filter=pegasus-x
]


@dataclasses.dataclass
class DimensionInfo:
    """维度信息的包装器。"""

    batch_size: int  # 批量大小
    seq_len: int  # 标记长度
    block_size: int  # 块大小
    num_heads: int  # 头的数量
    hidden_dim: int  # 隐藏单元维度
    dim_per_head: int  # 每个头的维度
    num_blocks: int  # 块的数量
    global_len: int  # 全局长度
    padded_seq_len: int  # 填充后的标记序列长度

    # 注意:与原始 Flax 实现相比,在编码器层的开始处,我们将标记表示填充到块大小的倍数,因此始终 T=P。


# 从 transformers.models.bart.modeling_bart.shift_tokens_right 复制过来
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    将输入的标记向右移动一个位置。
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)  # 创建与 input_ids 形状相同的全零张量
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()  # 将 input_ids 向右移动一位
    shifted_input_ids[:, 0] = decoder_start_token_id  # 将起始位置的标记设为 decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id 必须定义。")
    # 用 pad_token_id 替换标签中可能存在的 -100 值
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    # 返回变换后的输入标识符列表
    return shifted_input_ids
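
# ---- 注解者补充的示意性示例(非原文件内容)----
# 演示 shift_tokens_right:labels 整体右移一位,首位填入 decoder_start_token_id,
# 移位后残留的 -100(损失忽略位)会被替换为 pad_token_id。
import torch

_labels = torch.tensor([[5, 6, -100, 1]])
print(shift_tokens_right(_labels, pad_token_id=0, decoder_start_token_id=0))
# tensor([[0, 5, 6, 0]]),最后一列的 1(eos)被移出,-100 被替换为 0

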
class PegasusXSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, embed_dim, max_scale: int = 10000.0):
        super().__init__()
        self.embed_dim = embed_dim  # 设置嵌入维度
        self.max_scale = max_scale  # 最大缩放系数

    @torch.no_grad()
    def forward(self, input_embeds: torch.Tensor, past_key_values_length: int = 0) -> torch.Tensor:
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        batch_size, seq_len = input_embeds.shape[:2]  # 获取输入张量的批量大小和序列长度
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=input_embeds.device
        )[:, None]  # 创建位置张量,从past_key_values_length到past_key_values_length + seq_len
        pe = torch.zeros((seq_len, self.embed_dim), device=input_embeds.device, dtype=input_embeds.dtype)  # 初始化位置编码张量
        half_d_feature = self.embed_dim // 2  # 特征维度的一半
        div_term = torch.exp(
            torch.arange(half_d_feature, device=input_embeds.device, dtype=torch.int64).type_as(input_embeds)
            * -(np.log(float(self.max_scale)) / (half_d_feature - 1))
        )  # 计算分割项,用于计算正弦和余弦值
        pe[:, :half_d_feature] = torch.sin(positions * div_term)  # 计算正弦位置编码
        pe[:, half_d_feature:] = torch.cos(positions * div_term)  # 计算余弦位置编码
        return pe[None].expand(batch_size, -1, -1)  # 返回位置编码张量,扩展为与输入张量相同的形状
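
# ---- 注解者补充的示意性示例(非原文件内容)----
# 位置编码只取决于 (seq_len, embed_dim),与输入内容无关;前一半维度为 sin,后一半为 cos。
import torch

_pos_emb = PegasusXSinusoidalPositionalEmbedding(embed_dim=8)
_dummy = torch.zeros(2, 5, 8)   # [batch_size, seq_len, embed_dim]
print(_pos_emb(_dummy).shape)   # torch.Size([2, 5, 8])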


# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PegasusX
class PegasusXAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[PegasusXConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim  # 设置嵌入维度
        self.num_heads = num_heads  # 头数
        self.dropout = dropout  # dropout率
        self.head_dim = embed_dim // num_heads  # 每个头的维度
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )  # 检查嵌入维度是否能被头数整除

        self.scaling = self.head_dim**-0.5  # 缩放因子
        self.is_decoder = is_decoder  # 是否为解码器
        self.is_causal = is_causal  # 是否因果

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # K矩阵的投影
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # V矩阵的投影
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # Q矩阵的投影
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # 输出矩阵的投影

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
    # 定义一个方法用于前向传播计算
    def forward(
        self,
        # 输入参数:当前隐藏状态,作为Transformer模型的输入
        hidden_states: torch.Tensor,
        # 输入参数:键-值状态,用于注意力机制的计算,可选
        key_value_states: Optional[torch.Tensor] = None,
        # 输入参数:过去的键-值状态元组,用于Transformer解码器,可选
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        # 输入参数:注意力掩码,指定哪些位置需要注意,可选
        attention_mask: Optional[torch.Tensor] = None,
        # 输入参数:层级头部掩码,控制层级上的注意力头部,可选
        layer_head_mask: Optional[torch.Tensor] = None,
        # 输入参数:是否输出注意力信息,默认为False
        output_attentions: bool = False,
# 定义一个名为 PegasusXGlobalLocalAttention 的类,继承自 nn.Module
class PegasusXGlobalLocalAttention(nn.Module):
    """Global + Local attention. For use with Encoder only."""
    # 此类实现了全局和局部注意力机制,仅用于编码器

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        block_size: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
    ):
        # 初始化函数,设置类的参数和模块。
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.block_size = block_size
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        # 检查 embed_dim 是否能被 num_heads 整除,如果不能,抛出 ValueError。
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        
        # 缩放因子,用于缩放注意力分数
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        # 线性变换层,用于投影查询、键、值以及输出
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # 重新整形张量的形状,以适应多头注意力的计算
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        token_hidden_states: torch.Tensor,
        global_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        # 前向传播函数,实现注意力机制的计算
        ...

    def compute_global_attention_representations(
        self, global_q, global_k, global_v, local_k, local_v, mask, dim: DimensionInfo
    ):
        # 计算全局注意力表示:输入包括全局查询向量、全局/局部键值对、注意力掩码和维度信息
        """Compute attention representations for global tokens.

        Global tokens will attend to both global tokens as well as all input sequence tokens. Because the input
        sequence tokens are arranged in blocks for local attention, we unblock them and compute attention.

        Args:
            global_q (`torch.FloatTensor`) of shape [batch_size, num_heads, global_len, dim_per_head]:
                query vectors from global tokens
            global_k (`torch.FloatTensor`) of shape [batch_size, num_heads, global_len, dim_per_head]:
                key vectors from global tokens
            global_v (`torch.FloatTensor`) of shape [batch_size, num_heads, global_len, dim_per_head]:
                value vectors from global tokens
            local_k (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]:
                key vectors from local tokens
            local_v (`torch.FloatTensor`) of shape [batch_size, num_heads, padded_seq_len, dim_per_head]:
                value vectors from local tokens
            mask (`torch.FloatTensor`) of shape [batch_size, padded_seq_len]: attention mask
            dim (DimensionInfo): DimensionInfo wrapper for dimensions

        Returns:
            output of shape `[batch_size, length, features]`, where length will be padded to a multiple of block_size
        """
        # Concatenate global and local key vectors along the sequence dimension
        # Shape: [batch_size, num_heads, global_len+padded_seq_len, dim_per_head]
        global_and_local_k = torch.cat([global_k, local_k], dim=2)
        
        # Concatenate global and local value vectors along the sequence dimension
        # Shape: [batch_size, num_heads, global_len+padded_seq_len, dim_per_head]
        global_and_local_v = torch.cat([global_v, local_v], dim=2)

        # Extend the mask to cover both global and local tokens
        # Shape: [batch_size, global_len+padded_seq_len]
        extended_mask = nn.functional.pad(mask, pad=(dim.global_len, 0), value=0)

        # Compute attention weights between global query and concatenated global/local key vectors
        # Shape: [batch_size, num_heads, global_len, global_len+padded_seq_len]
        attn_weights = torch.einsum("BHGF,BHXF->BHGX", global_q, global_and_local_k)
        attn_weights = attn_weights + extended_mask[:, None, None, :]  # Add extended mask
        
        # Apply softmax to compute attention probabilities and apply dropout
        attn_probs = nn.functional.softmax(attn_weights, dim=-1)
        attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training)

        # Compute attention output using attention probabilities and concatenated global/local value vectors
        # Shape: [batch_size, num_heads, global_len, dim_per_head]
        attn_output = torch.einsum("BHGX,BHXF->BHGF", attn_probs, global_and_local_v)
        return attn_output, attn_probs
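
# ---- 注解者补充的示意性示例(非原文件内容)----
# 用小张量说明上面两个 einsum 的形状约定:
# B=batch, H=num_heads, G=global_len, X=global_len+padded_seq_len, F=dim_per_head
import torch

B, H, G, X, F = 1, 2, 4, 20, 8
_q = torch.randn(B, H, G, F)
_kv = torch.randn(B, H, X, F)                                          # 演示中键和值共用同一张量
_scores = torch.einsum("BHGF,BHXF->BHGX", _q, _kv)                     # [1, 2, 4, 20]
_out = torch.einsum("BHGX,BHXF->BHGF", _scores.softmax(dim=-1), _kv)   # [1, 2, 4, 8]
print(_scores.shape, _out.shape)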
# 定义 PegasusXEncoderLayer 类,继承自 nn.Module,用于实现 Pegasus X 模型的编码器层
class PegasusXEncoderLayer(nn.Module):
    # 初始化方法,接受两个参数:stagger_blocks_this_layer 表示是否在此层中交错块,config 表示配置信息对象 PegasusXConfig
    def __init__(self, stagger_blocks_this_layer: bool, config: PegasusXConfig):
        super().__init__()
        # 设置编码器层的 embed_dim 属性为配置中的 d_model
        self.embed_dim = config.d_model
        # 使用 PegasusXGlobalLocalAttention 创建自注意力机制对象 self.self_attn
        self.self_attn = PegasusXGlobalLocalAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            block_size=config.block_size,
            dropout=config.attention_dropout,
        )
        # 初始化自注意力层归一化层 self.self_attn_layer_norm
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # 初始化全局自注意力层归一化层 self.global_self_attn_layer_norm
        self.global_self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # 设置 dropout 概率
        self.dropout = config.dropout
        # 设置激活函数为配置中指定的激活函数
        self.activation_fn = ACT2FN[config.activation_function]
        # 设置激活函数 dropout 概率
        self.activation_dropout = config.activation_dropout
        # 初始化全连接层 fc1,输入维度为 embed_dim,输出维度为配置中的 encoder_ffn_dim
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        # 初始化全连接层 fc2,输入维度为配置中的 encoder_ffn_dim,输出维度为 embed_dim
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        # 初始化最终的归一化层 self.final_layer_norm
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
        # 设置是否在此层中交错块的标志
        self.stagger_blocks_this_layer = stagger_blocks_this_layer
        # 设置块大小为配置中的 block_size
        self.block_size = config.block_size

    # 前向传播方法,接受多个参数:hidden_states 表示输入的隐藏状态张量,global_hidden_states 表示全局隐藏状态张量,
    # attention_mask 表示注意力掩码张量,output_attentions 表示是否输出注意力信息,默认为 False
    def forward(
        self,
        hidden_states: torch.Tensor,
        global_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: bool = False,
    ):
        # 前向传播的具体实现(局部/全局自注意力与前馈网络)在原文件中较长,此处从略
        ...

    # 类方法,用于在本地 tokens 上填充隐藏状态和注意力掩码
    @classmethod
    def pad_local_tokens(cls, hidden_states, attention_mask, block_size):
        # hidden_states: [batch_size, seq_len, hidden_dim]
        # 计算需要填充的大小
        pad_size = block_size // 2
        # 获取张量数据类型的最小值
        mask_min_value = torch.finfo(hidden_states.dtype).min
        # 对隐藏状态进行填充,只在序列长度维度上进行填充
        padded_hidden_states = torch.nn.functional.pad(
            hidden_states,
            pad=(0, 0, pad_size, pad_size),
        )
        # 对注意力掩码进行填充,只在序列长度维度上进行填充,并设置填充值为 mask_min_value
        padded_mask = torch.nn.functional.pad(
            attention_mask,
            pad=(pad_size, pad_size),
            value=mask_min_value,
        )
        return padded_hidden_states, padded_mask

    # 类方法,用于在本地 tokens 上取消填充隐藏状态
    @classmethod
    def unpad_local_tokens(cls, padded_hidden_states, block_size):
        # padded_hidden_states: [batch_size, padded seq_len, hidden_dim]
        # 计算填充的大小
        pad_size = block_size // 2
        # 返回去除填充后的隐藏状态,仅保留有效序列长度的部分
        return padded_hidden_states[:, pad_size:-pad_size, :]
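
# ---- 注解者补充的示意性示例(非原文件内容)----
# 交错(stagger)层会先在序列两侧各补 block_size//2 个位置,使局部块的边界错开;
# unpad_local_tokens 再把补出的部分裁掉,对隐藏状态而言二者互为逆操作。
import torch

_h = torch.randn(1, 20, 16)   # [batch_size, seq_len, hidden_dim]
_m = torch.zeros(1, 20)       # 加性注意力掩码,0 表示可见
_ph, _pm = PegasusXEncoderLayer.pad_local_tokens(_h, _m, block_size=8)
print(_ph.shape, _pm.shape)   # torch.Size([1, 28, 16]) torch.Size([1, 28])
print(PegasusXEncoderLayer.unpad_local_tokens(_ph, block_size=8).shape)  # torch.Size([1, 20, 16])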


class PegasusXDecoderLayer(nn.Module):
    # 解码器层:包含自注意力、编码器-解码器(交叉)注意力和前馈网络,初始化如下
    # 初始化函数,用于初始化一个 PegasusXDecoderLayer 对象
    def __init__(self, config: PegasusXConfig):
        # 调用父类的初始化方法
        super().__init__()
        # 设置嵌入维度为配置中的模型维度
        self.embed_dim = config.d_model

        # 创建自注意力机制对象,用于解码器的自注意力
        self.self_attn = PegasusXAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            bias=False,
        )

        # 设置 dropout 概率
        self.dropout = config.dropout
        # 设置激活函数为配置中指定的激活函数
        self.activation_fn = ACT2FN[config.activation_function]
        # 设置激活函数的 dropout 概率
        self.activation_dropout = config.activation_dropout

        # 创建自注意力层的 LayerNorm 层,用于归一化输入
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # 创建编码器-解码器注意力机制对象,用于解码器的编码器-解码器注意力
        self.encoder_attn = PegasusXAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            bias=False,
        )

        # 创建编码器-解码器注意力层的 LayerNorm 层,用于归一化输入
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # 创建第一个全连接层,将解码器的嵌入维度映射到配置中指定的前馈神经网络维度
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        # 创建第二个全连接层,将前馈神经网络的维度映射回解码器的嵌入维度
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)

        # 创建最终的 LayerNorm 层,用于归一化输出
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
class PegasusXPreTrainedModel(PreTrainedModel):
    # 设置配置类为 PegasusXConfig
    config_class = PegasusXConfig
    # 模型前缀为 "model"
    base_model_prefix = "model"
    # 支持梯度检查点
    supports_gradient_checkpointing = True
    # 不拆分的模块列表,使用正则表达式指定
    _no_split_modules = [r"PegasusXEncoderLayer", r"PegasusXDecoderLayer"]

    def _init_weights(self, module):
        # 初始化权重函数,使用配置中的初始标准差
        std = self.config.init_std
        # 如果是线性层,初始化权重为正态分布,偏置为零
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        # 如果是嵌入层,初始化权重为正态分布
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)


PEGASUS_X_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`PegasusXConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

PEGASUS_X_GENERATION_EXAMPLE = r"""
    Summarization example:

    ```
    >>> from transformers import AutoTokenizer, PegasusXForConditionalGeneration

    >>> model = PegasusXForConditionalGeneration.from_pretrained("google/pegasus-x-base")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-x-large")

    >>> ARTICLE_TO_SUMMARIZE = (
    ...     "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    ...     "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    ...     "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
    ... )
    >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="pt")

    >>> # Generate Summary
    >>> summary_ids = model.generate(inputs["input_ids"])
    >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "California's largest electricity provider has turned off power to hundreds of thousands of customers."
    ```
"""

PEGASUS_X_INPUTS_DOCSTRING = r"""
"""


class PegasusXEncoder(PegasusXPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`PegasusXEncoderLayer`].

    Args:
        config: PegasusXConfig
        embed_tokens (nn.Embedding): output embedding
    """
    def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] = None):
        # 调用父类构造函数初始化模型
        super().__init__(config)

        # 从配置中获取模型的dropout率和encoder层的layerdrop率
        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        # 获取词嵌入的维度,并设置最大源序列位置和词嵌入的缩放因子
        embed_dim = config.d_model
        self.max_source_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        # 如果提供了外部的嵌入词向量,则使用它;否则初始化一个词嵌入层
        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)

        # 初始化全局嵌入层和位置嵌入层
        self.embed_global = nn.Embedding(config.num_global_tokens, embed_dim)
        self.embed_positions = PegasusXSinusoidalPositionalEmbedding(embed_dim)

        # 初始化编码器层的模块列表,根据配置可能会交错局部块
        self.layers = nn.ModuleList(
            [
                PegasusXEncoderLayer(
                    stagger_blocks_this_layer=i % 2 == 1 and config.stagger_local_blocks, config=config
                )
                for i in range(config.encoder_layers)
            ]
        )

        # 初始化层归一化模块
        self.layer_norm = nn.LayerNorm(config.d_model)

        # 初始化梯度检查点标志
        self.gradient_checkpointing = False

        # 调用后续初始化函数,包括权重初始化和最终处理
        self.post_init()

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
        config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embeddings. If position embeddings are learned, increasing the size will add
                newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
                position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
                add correct vectors at the end following the position encoding algorithm, whereas reducing the size
                will remove vectors from the end.
        """
        # 记录日志,设置新的最大位置嵌入数量
        logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
        self.config.max_position_embeddings = new_num_position_embeddings

        # 根据新的配置重新初始化位置嵌入层
        self.embed_positions = PegasusXSinusoidalPositionalEmbedding(self.config.d_model)
        self.embed_positions.to(self.device)

    def get_position_embeddings(self) -> nn.Embedding:
        """
        Returns the position embeddings matrix
        """
        # 返回当前位置嵌入层的引用
        return self.embed_positions

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
class PegasusXDecoder(PegasusXPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PegasusDecoderLayer`]

    Args:
        config: PegasusXConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout  # 设置 dropout 概率
        self.layerdrop = config.decoder_layerdrop  # 设置层级 dropout 概率
        self.max_target_positions = config.max_position_embeddings  # 最大目标位置
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0  # 嵌入尺度,如果配置中指定要缩放则开平方根

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens  # 如果提供了嵌入词汇表,则使用它
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)  # 否则创建一个新的嵌入层

        self.embed_positions = PegasusXSinusoidalPositionalEmbedding(config.d_model)  # 初始化位置编码
        self.layers = nn.ModuleList([PegasusXDecoderLayer(config) for _ in range(config.decoder_layers)])  # 创建多个解码层
        self.layer_norm = nn.LayerNorm(config.d_model)  # 层归一化

        self.gradient_checkpointing = False  # 梯度检查点默认关闭
        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens  # 返回嵌入词汇表

    def set_input_embeddings(self, value):
        self.embed_tokens = value  # 设置新的嵌入词汇表

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,



@add_start_docstrings(
    "The bare PEGASUS-X Model outputting raw hidden-states without any specific head on top.",
    PEGASUS_X_START_DOCSTRING,
)
class PegasusXModel(PegasusXPreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: PegasusXConfig):
        super().__init__(config)

        vocab_size = config.vocab_size
        self.shared = nn.Embedding(vocab_size, config.d_model)  # 创建共享的嵌入层

        self.encoder = PegasusXEncoder(config, self.shared)  # 初始化编码器
        self.decoder = PegasusXDecoder(config, self.shared)  # 初始化解码器

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        return self.shared  # 返回共享的嵌入层

    def set_input_embeddings(self, value):
        self.shared = value
        self.encoder.embed_tokens = self.shared  # 设置编码器的嵌入词汇表
        self.decoder.embed_tokens = self.shared  # 设置解码器的嵌入词汇表

    def get_encoder(self):
        return self.encoder  # 返回编码器实例

    def get_decoder(self):
        return self.decoder  # 返回解码器实例
    # 调整模型的位置嵌入矩阵大小,如果 `new_num_position_embeddings` 不等于 `config.max_position_embeddings`。
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        # 将模型配置中的 `max_position_embeddings` 设置为新的位置嵌入数量
        self.config.max_position_embeddings = new_num_position_embeddings
        # 调整编码器的位置嵌入矩阵大小
        self.encoder.resize_position_embeddings(new_num_position_embeddings)
        # 调整解码器的位置嵌入矩阵大小
        self.decoder.resize_position_embeddings(new_num_position_embeddings)

    # 返回编码器和解码器的位置嵌入矩阵
    def get_position_embeddings(self) -> Tuple[nn.Embedding]:
        return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings())

    # 前向传播函数,通过添加 `add_start_docstrings_to_model_forward` 和 `replace_return_docstrings` 来注释函数用途和返回类型
    @add_start_docstrings_to_model_forward(PEGASUS_X_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
# 使用装饰器为 PegasusXForConditionalGeneration 类添加文档字符串,说明其用途和示例
@add_start_docstrings("The PEGASUS-X for conditional generation (e.g. summarization).", PEGASUS_X_START_DOCSTRING)
# 声明 PegasusXForConditionalGeneration 类,继承自 PegasusXPreTrainedModel
class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
    # 定义模型的参数前缀
    base_model_prefix = "model"
    # 定义共享权重的关键字列表
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    # 初始化方法,接收 PegasusXConfig 类型的配置参数
    def __init__(self, config: PegasusXConfig):
        super().__init__(config)
        # 使用给定配置初始化 PegasusXModel 实例
        self.model = PegasusXModel(config)
        # 定义线性层 lm_head,用于生成输出的逻辑
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # 调用后续初始化方法
        self.post_init()

    # 返回 encoder 的方法
    def get_encoder(self):
        return self.model.get_encoder()

    # 返回 decoder 的方法
    def get_decoder(self):
        return self.model.get_decoder()

    # 返回 lm_head,用于输出的嵌入层
    def get_output_embeddings(self):
        return self.lm_head

    # 设置新的输出嵌入层
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # 调整位置嵌入的方法,根据新的位置嵌入数量调整模型的配置和编码器、解码器的位置嵌入
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
        config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embeddings. If position embeddings are learned, increasing the size will add
                newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
                position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
                add correct vectors at the end following the position encoding algorithm, whereas reducing the size
                will remove vectors from the end.
        """
        # 更新模型配置的最大位置嵌入数量
        self.config.max_position_embeddings = new_num_position_embeddings
        # 调整编码器和解码器的位置嵌入
        self.model.encoder.resize_position_embeddings(new_num_position_embeddings)
        self.model.decoder.resize_position_embeddings(new_num_position_embeddings)

    # 返回编码器和解码器的位置嵌入矩阵
    def get_position_embeddings(self) -> Tuple[nn.Embedding]:
        """
        Returns the position embeddings matrix
        """
        return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings())

    # 使用装饰器为 model_forward 方法添加文档字符串,详细说明输入和输出的文档
    @add_start_docstrings_to_model_forward(PEGASUS_X_INPUTS_DOCSTRING)
    # 替换输出的文档字符串为 Seq2SeqLMOutput 类型和配置类相关的描述
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    # 添加 PEGASUS_X_GENERATION_EXAMPLE 的结束文档字符串
    @add_end_docstrings(PEGASUS_X_GENERATION_EXAMPLE)
    # 定义模型的前向传播函数,接受多个输入参数,所有参数都是可选的
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,  # 输入序列的 token IDs,类型为可选的 PyTorch 张量
        attention_mask: Optional[torch.Tensor] = None,  # 输入序列的注意力掩码,类型为可选的 PyTorch 张量
        decoder_input_ids: Optional[torch.Tensor] = None,  # 解码器输入序列的 token IDs,类型为可选的 PyTorch 张量
        decoder_attention_mask: Optional[torch.Tensor] = None,  # 解码器输入序列的注意力掩码,类型为可选的 PyTorch 张量
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,  # 编码器的输出,类型为可选的 PyTorch 浮点张量元组
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,  # 用于存储解码器过去键值的元组,类型为可选的 PyTorch 浮点张量
        inputs_embeds: Optional[torch.Tensor] = None,  # 输入序列的嵌入表示,类型为可选的 PyTorch 张量
        decoder_inputs_embeds: Optional[torch.Tensor] = None,  # 解码器输入序列的嵌入表示,类型为可选的 PyTorch 张量
        labels: Optional[torch.Tensor] = None,  # 模型的标签,类型为可选的 PyTorch 张量
        use_cache: Optional[bool] = None,  # 是否使用缓存,类型为可选的布尔值
        output_attentions: Optional[bool] = None,  # 是否输出注意力权重,类型为可选的布尔值
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态,类型为可选的布尔值
        return_dict: Optional[bool] = None,  # 是否以字典形式返回结果,类型为可选的布尔值
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
            Depending on `return_dict`:
            - if `return_dict` is `False`: returns a tuple comprising `lm_logits` and additional outputs from the model.
            - if `return_dict` is `True`: returns a `Seq2SeqLMOutput` object containing loss, logits, and other model outputs.

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            # If `labels` are provided, adjust `use_cache` and prepare `decoder_input_ids` if necessary.
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                # Shift labels to the right to create `decoder_input_ids` for training.
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        # Pass inputs to the model for computation.
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Compute logits from the model's output.
        lm_logits = self.lm_head(outputs[0])

        masked_lm_loss = None
        if labels is not None:
            # Compute masked language modeling loss if `labels` are provided.
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # If `return_dict` is `False`, format output as a tuple.
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        # If `return_dict` is `True`, format output using `Seq2SeqLMOutput`.
        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # 如果使用了过去的键值对,则根据过去的长度调整 decoder_input_ids
        if past_key_values is not None:
            # 获取过去键值对的第一个元素的长度
            past_length = past_key_values[0][0].shape[2]

            # 一些生成方法可能已经只传递了最后一个输入 ID
            if decoder_input_ids.shape[1] > past_length:
                # 如果 decoder_input_ids 的长度大于过去的长度,只保留后面的部分
                remove_prefix_length = past_length
            else:
                # 默认行为:只保留最后一个 ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            # 调整 decoder_input_ids,去除前缀部分
            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        # 返回一个包含各种信息的字典
        return {
            "input_ids": None,  # encoder_outputs 已定义,不需要 input_ids
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,  # 更改此项以避免缓存(可能是为了调试目的)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        # 调用 shift_tokens_right 函数,根据标签生成 decoder_input_ids
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            # 缓存的交叉注意力状态无需重新排序 -> 它们始终相同
            # 重新排序每一层的过去状态,以便按照 beam_idx 的顺序重新排列
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        return reordered_past
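
# ---- 注解者补充的示意性示例(非原文件内容)----
# _reorder_cache 的核心是沿 batch*beam 维做 index_select,下面用一个单独的张量演示:
import torch

_past_state = torch.arange(4).view(4, 1, 1, 1).float()   # 假设 batch_size * num_beams = 4
_beam_idx = torch.tensor([2, 2, 0, 1])                    # beam search 选出的新顺序
print(_past_state.index_select(0, _beam_idx).flatten())   # tensor([2., 2., 0., 1.])

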
# 从 transformers.models.bart.modeling_bart.BartDecoderWrapper 复制并修改为 PegasusXDecoderWrapper
# 这个类是一个辅助类,用于在使用因果语言模型与 EncoderDecoderModel 框架结合时,正确加载预训练检查点。

class PegasusXDecoderWrapper(PegasusXPreTrainedModel):
    """
    这个包装类是一个辅助类,用于在因果语言模型与 EncoderDecoderModel 框架结合时正确加载预训练检查点。
    """

    def __init__(self, config):
        # 调用父类构造函数初始化对象
        super().__init__(config)
        # 创建 PegasusXDecoder 对象作为该类的一个属性
        self.decoder = PegasusXDecoder(config)

    def forward(self, *args, **kwargs):
        # 将前向传播调用委托给 self.decoder 对象
        return self.decoder(*args, **kwargs)

.\models\pegasus_x\__init__.py

# 版权声明和版权许可信息,标识本代码版权归 HuggingFace 团队所有,受 Apache License, Version 2.0 许可
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入类型检查模块中的 TYPE_CHECKING 类型
from typing import TYPE_CHECKING

# 从 utils 中导入 OptionalDependencyNotAvailable、_LazyModule 和 is_torch_available 函数
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# 定义模块导入结构字典 _import_structure
_import_structure = {
    "configuration_pegasus_x": ["PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusXConfig"],
}

# 检查是否 Torch 可用,若不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 若 Torch 可用,则扩展 _import_structure 字典,导入 modeling_pegasus_x 模块中的类和变量
    _import_structure["modeling_pegasus_x"] = [
        "PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST",
        "PegasusXForConditionalGeneration",
        "PegasusXModel",
        "PegasusXPreTrainedModel",
    ]

# 如果是类型检查模式
if TYPE_CHECKING:
    # 从 configuration_pegasus_x 模块中导入 PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP 和 PegasusXConfig 类
    from .configuration_pegasus_x import PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusXConfig

    # 检查 Torch 是否可用,若不可用则抛出 OptionalDependencyNotAvailable 异常
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 若 Torch 可用,则从 modeling_pegasus_x 模块中导入相关类和变量
        from .modeling_pegasus_x import (
            PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST,
            PegasusXForConditionalGeneration,
            PegasusXModel,
            PegasusXPreTrainedModel,
        )

# 如果不是类型检查模式
else:
    # 导入 sys 模块
    import sys

    # 将当前模块替换为 LazyModule 对象,延迟加载相关模块
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\perceiver\configuration_perceiver.py

# coding=utf-8
# 声明文件编码格式为 UTF-8

# 导入必要的模块和类
from collections import OrderedDict  # 导入 OrderedDict 类,用于有序字典
from typing import Any, Mapping, Optional, Union  # 导入类型提示相关的类和方法

# 导入配置相关的类和函数
from ...configuration_utils import PretrainedConfig  # 导入预训练配置类
from ...feature_extraction_utils import FeatureExtractionMixin  # 导入特征提取混合类
from ...onnx import OnnxConfig  # 导入 ONNX 配置类
from ...onnx.utils import compute_effective_axis_dimension  # 导入计算有效轴维度的方法
from ...tokenization_utils_base import PreTrainedTokenizerBase  # 导入预训练分词器基类
from ...utils import TensorType, logging  # 导入 TensorType 和 logging 工具

# 获取全局日志记录器
logger = logging.get_logger(__name__)

# 预训练配置文件映射表,指定不同模型的配置文件链接
PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "deepmind/language-perceiver": "https://huggingface.co/deepmind/language-perceiver/resolve/main/config.json",
    # 可查看所有 Perceiver 模型列表:https://huggingface.co/models?filter=perceiver
}


class PerceiverConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PerceiverModel`]. It is used to instantiate an
    Perceiver model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Perceiver
    [deepmind/language-perceiver](https://huggingface.co/deepmind/language-perceiver) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Example:

    ```
    >>> from transformers import PerceiverModel, PerceiverConfig

    >>> # Initializing a Perceiver deepmind/language-perceiver style configuration
    >>> configuration = PerceiverConfig()

    >>> # Initializing a model from the deepmind/language-perceiver style configuration
    >>> model = PerceiverModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    
    model_type = "perceiver"
    # 定义一个初始化方法,用于设置模型的各种参数和属性
    def __init__(
        self,
        num_latents=256,  # Latent space dimensionality
        d_latents=1280,  # Dimensionality of latent vectors
        d_model=768,  # Dimensionality of the model
        num_blocks=1,  # Number of transformer blocks
        num_self_attends_per_block=26,  # Number of self-attention layers per block
        num_self_attention_heads=8,  # Number of self-attention heads
        num_cross_attention_heads=8,  # Number of cross-attention heads
        qk_channels=None,  # Query and key projection dimensionality
        v_channels=None,  # Value projection dimensionality
        cross_attention_shape_for_attention="kv",  # Shape for cross-attention computation
        self_attention_widening_factor=1,  # Self-attention widening factor
        cross_attention_widening_factor=1,  # Cross-attention widening factor
        hidden_act="gelu",  # Activation function for hidden layers
        attention_probs_dropout_prob=0.1,  # Dropout probability for attention weights
        initializer_range=0.02,  # Range for weight initialization
        layer_norm_eps=1e-12,  # Epsilon for layer normalization
        use_query_residual=True,  # Flag indicating whether to use query residual connections
        vocab_size=262,  # Size of vocabulary for masked language modeling
        max_position_embeddings=2048,  # Maximum number of positional embeddings
        image_size=56,  # Size of input images for image classification
        train_size=[368, 496],  # Size of training images
        num_frames=16,  # Number of frames in video input
        audio_samples_per_frame=1920,  # Audio samples per video frame
        samples_per_patch=16,  # Number of samples per image patch
        output_shape=[1, 16, 224, 224],  # Output shape of the model
        output_num_channels=512,  # Number of output channels
        _label_trainable_num_channels=1024,  # Number of channels for trainable labels
        **kwargs,  # Additional keyword arguments
    ):
        # 调用父类的初始化方法,传入额外的关键字参数
        super().__init__(**kwargs)
    
        # 初始化模型的各种参数和属性
        self.num_latents = num_latents
        self.d_latents = d_latents
        self.d_model = d_model
        self.num_blocks = num_blocks
        self.num_self_attends_per_block = num_self_attends_per_block
        self.num_self_attention_heads = num_self_attention_heads
        self.num_cross_attention_heads = num_cross_attention_heads
        self.qk_channels = qk_channels
        self.v_channels = v_channels
        self.cross_attention_shape_for_attention = cross_attention_shape_for_attention
        self.self_attention_widening_factor = self_attention_widening_factor
        self.cross_attention_widening_factor = cross_attention_widening_factor
        self.hidden_act = hidden_act
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_query_residual = use_query_residual
        # 以下是针对不同任务的特定属性
    
        # Masked Language Modeling任务的属性
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
    
        # Image Classification任务的属性
        self.image_size = image_size
    
        # Flow任务的属性
        self.train_size = train_size
    
        # Multimodal Autoencoding任务的属性
        self.num_frames = num_frames
        self.audio_samples_per_frame = audio_samples_per_frame
        self.samples_per_patch = samples_per_patch
    
        # 输出的形状和通道数属性
        self.output_shape = output_shape
        self.output_num_channels = output_num_channels
    
        # 可训练标签的通道数属性
        self._label_trainable_num_channels = _label_trainable_num_channels
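
# ---- 注解者补充的示意性示例(非原文件内容),直接使用上面定义的 PerceiverConfig ----
# 潜变量上的自注意力计算量主要由 num_latents 和 d_latents 决定,与输入长度无关,可按需缩小:
small_config = PerceiverConfig(num_latents=64, d_latents=512, num_self_attends_per_block=8)
print(small_config.num_latents, small_config.d_latents, small_config.num_self_attends_per_block)  # 64 512 8

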
class PerceiverOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task == "multiple-choice":
            # 如果任务为多选题,则定义动态轴的维度
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            # 否则定义动态轴的维度
            dynamic_axis = {0: "batch", 1: "sequence"}
        # 返回有序字典,包含输入名称和对应的动态轴
        return OrderedDict(
            [
                ("inputs", dynamic_axis),
                ("attention_mask", dynamic_axis),
            ]
        )

    @property
    def atol_for_validation(self) -> float:
        # 返回用于验证的绝对容差值
        return 1e-4

    def generate_dummy_inputs(
        self,
        preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
        batch_size: int = -1,
        seq_length: int = -1,
        num_choices: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
        num_channels: int = 3,
        image_width: int = 40,
        image_height: int = 40,
    ) -> Mapping[str, Any]:
        # 从`transformers.onnx.config.OnnxConfig`中复制并稍作修改和简化

        if isinstance(preprocessor, PreTrainedTokenizerBase):
            # 如果预处理器是预训练的分词器,则根据需要设置动态轴的维度
            batch_size = compute_effective_axis_dimension(
                batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
            )
            # 获取要添加的特殊标记的数量
            token_to_add = preprocessor.num_special_tokens_to_add(is_pair)
            # 根据需要设置动态轴的维度
            seq_length = compute_effective_axis_dimension(
                seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
            )
            # 根据计算的批次大小和序列长度生成虚拟输入
            dummy_input = [" ".join(["a"]) * seq_length] * batch_size
            # 使用预处理器生成输入字典,并将输入名称标准化为`input_ids`
            inputs = dict(preprocessor(dummy_input, return_tensors=framework))
            inputs["inputs"] = inputs.pop("input_ids")
            return inputs
        elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
            # 如果预处理器是特征提取混合类,并且模型输入名称为`pixel_values`
            batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
            # 根据指定的批次大小和图像尺寸生成虚拟图像数据
            dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
            # 使用预处理器生成输入字典,并将输入名称标准化为`pixel_values`
            inputs = dict(preprocessor(images=dummy_input, return_tensors=framework))
            inputs["inputs"] = inputs.pop("pixel_values")
            return inputs
        else:
            # 如果无法为模型生成虚拟输入,则抛出值错误异常
            raise ValueError(
                "Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
            )
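
作为补充,下面用一个最小示意说明 compute_effective_axis_dimension 在生成虚拟输入时的作用:传入 -1(动态维度)时会退回到固定的默认值,显式给定的尺寸则原样保留(此处假设 OnnxConfig 的默认固定批大小为 2):

from transformers.onnx.utils import compute_effective_axis_dimension

# 示意代码:动态维度(-1)退回到 fixed_dimension;显式尺寸原样返回
print(compute_effective_axis_dimension(-1, fixed_dimension=2, num_token_to_add=0))  # 2
print(compute_effective_axis_dimension(8, fixed_dimension=2, num_token_to_add=0))   # 8
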

.\models\perceiver\convert_perceiver_haiku_to_pytorch.py

# 设置编码格式为 UTF-8
# 版权声明,声明代码版权及使用许可
# 根据 Apache 许可证版本 2.0 使用本文件,详见指定链接
# 除非适用法律要求或书面同意,本软件是基于"原样"提供的,无任何明示或暗示的保证或条件
# 请参阅许可证,了解详细的法律条款
"""将 Haiku 实现的 Perceiver 检查点转换为 PyTorch 模型。"""


import argparse  # 导入用于解析命令行参数的模块
import json  # 导入处理 JSON 数据的模块
import pickle  # 导入序列化和反序列化 Python 对象的模块
from pathlib import Path  # 导入处理路径的模块

import haiku as hk  # 导入 Haiku 深度学习库
import numpy as np  # 导入处理数组和矩阵的数学库
import requests  # 导入处理 HTTP 请求的库
import torch  # 导入 PyTorch 深度学习库
from huggingface_hub import hf_hub_download  # 导入从 Hugging Face Hub 下载模型的函数
from PIL import Image  # 导入处理图像的 Python 库

from transformers import (  # 导入 Transformers 库中的多个模型和工具类
    PerceiverConfig,
    PerceiverForImageClassificationConvProcessing,
    PerceiverForImageClassificationFourier,
    PerceiverForImageClassificationLearned,
    PerceiverForMaskedLM,
    PerceiverForMultimodalAutoencoding,
    PerceiverForOpticalFlow,
    PerceiverImageProcessor,
    PerceiverTokenizer,
)
from transformers.utils import logging  # 导入 Transformers 中的日志模块


logging.set_verbosity_info()  # 设置日志记录详细程度为信息级别
logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器


def prepare_img():
    # 我们将使用一张狗的图像来验证我们的结果
    url = "https://storage.googleapis.com/perceiver_io/dalmation.jpg"
    im = Image.open(requests.get(url, stream=True).raw)  # 从 URL 加载图像并打开
    return im


def rename_keys(state_dict, architecture):
@torch.no_grad()  # 使用装饰器声明不需要梯度的上下文管理器
def convert_perceiver_checkpoint(pickle_file, pytorch_dump_folder_path, architecture="MLM"):
    """
    将模型的权重复制/粘贴/调整为我们的 Perceiver 结构。
    """

    # 将参数作为 FlatMapping 数据结构加载
    with open(pickle_file, "rb") as f:
        checkpoint = pickle.loads(f.read())

    state = None
    if isinstance(checkpoint, dict) and architecture in [
        "image_classification",
        "image_classification_fourier",
        "image_classification_conv",
    ]:
        # 图像分类 Conv 检查点还包含批归一化状态 (running_mean 和 running_var)
        params = checkpoint["params"]
        state = checkpoint["state"]
    else:
        params = checkpoint

    # 转换为初始状态字典
    state_dict = {}
    for scope_name, parameters in hk.data_structures.to_mutable_dict(params).items():
        for param_name, param in parameters.items():
            state_dict[scope_name + "/" + param_name] = param

    if state is not None:
        # 添加状态变量
        for scope_name, parameters in hk.data_structures.to_mutable_dict(state).items():
            for param_name, param in parameters.items():
                state_dict[scope_name + "/" + param_name] = param
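
下面是一个不依赖 haiku 的最小示意,说明这种“作用域名/参数名”扁平化键名的构造方式(字典内容为虚构示例):

# 示意代码:用普通嵌套字典模拟把 FlatMapping 展平成 state_dict 的过程
params_demo = {"perceiver_encoder/~/cross_attend": {"w": [1.0, 2.0], "b": [0.0]}}

flat = {}
for scope, parameters in params_demo.items():
    for name, param in parameters.items():
        flat[scope + "/" + name] = param

print(flat)  # {'perceiver_encoder/~/cross_attend/w': [1.0, 2.0], 'perceiver_encoder/~/cross_attend/b': [0.0]}
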

    # 重命名键名
    rename_keys(state_dict, architecture=architecture)

    # 加载 HuggingFace 模型
    config = PerceiverConfig()
    # 初始化 subsampling 变量为 None
    subsampling = None
    # 设置 repo_id 变量为 "huggingface/label-files"
    repo_id = "huggingface/label-files"
    # 根据不同的架构设置模型配置和实例化不同的 Perceiver 模型
    if architecture == "MLM":
        # 针对 MLM 架构设置特定的配置参数
        config.qk_channels = 8 * 32
        config.v_channels = 1280
        # 实例化一个 PerceiverForMaskedLM 模型
        model = PerceiverForMaskedLM(config)
    elif "image_classification" in architecture:
        # 针对图像分类相关架构设置特定的配置参数
        config.num_latents = 512
        config.d_latents = 1024
        config.d_model = 512
        config.num_blocks = 8
        config.num_self_attends_per_block = 6
        config.num_cross_attention_heads = 1
        config.num_self_attention_heads = 8
        # 重置 config 中的 qk_channels 和 v_channels 为 None
        config.qk_channels = None
        config.v_channels = None
        # 设置 num_labels 为 1000,并加载对应的类别标签映射文件
        config.num_labels = 1000
        filename = "imagenet-1k-id2label.json"
        # 从指定的 repo_id 中下载并读取 id2label 映射
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
        # 将 id2label 字典的键值转换为整数类型
        id2label = {int(k): v for k, v in id2label.items()}
        # 设置模型配置的 id2label 和 label2id 属性
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
        if architecture == "image_classification":
            # 针对 image_classification 架构设置图像尺寸为 224,并实例化 PerceiverForImageClassificationLearned 模型
            config.image_size = 224
            model = PerceiverForImageClassificationLearned(config)
        elif architecture == "image_classification_fourier":
            # 针对 image_classification_fourier 架构设置特定的 d_model,并实例化 PerceiverForImageClassificationFourier 模型
            config.d_model = 261
            model = PerceiverForImageClassificationFourier(config)
        elif architecture == "image_classification_conv":
            # 针对 image_classification_conv 架构设置特定的 d_model,并实例化 PerceiverForImageClassificationConvProcessing 模型
            config.d_model = 322
            model = PerceiverForImageClassificationConvProcessing(config)
        else:
            # 如果架构不在预期的架构列表中,抛出异常
            raise ValueError(f"Architecture {architecture} not supported")
    elif architecture == "optical_flow":
        # 针对 optical_flow 架构设置特定的配置参数,并实例化 PerceiverForOpticalFlow 模型
        config.num_latents = 2048
        config.d_latents = 512
        config.d_model = 322
        config.num_blocks = 1
        config.num_self_attends_per_block = 24
        config.num_self_attention_heads = 16
        config.num_cross_attention_heads = 1
        model = PerceiverForOpticalFlow(config)
    # 如果架构是多模态自编码
    elif architecture == "multimodal_autoencoding":
        # 设置编码器的输入大小为图像的像素数
        config.num_latents = 28 * 28 * 1
        # 设置潜在空间向量的维度
        config.d_latents = 512
        # 设置模型的维度
        config.d_model = 704
        # 设置模型的块数
        config.num_blocks = 1
        # 每个块的自注意力层数
        config.num_self_attends_per_block = 8
        # 自注意力头数
        config.num_self_attention_heads = 8
        # 交叉注意力头数
        config.num_cross_attention_heads = 1
        # 标签数
        config.num_labels = 700
        
        # 定义虚拟输入和子采样(因为每次前向传播只处理图像和音频数据的一部分)
        images = torch.randn((1, 16, 3, 224, 224))
        audio = torch.randn((1, 30720, 1))
        nchunks = 128
        # 图像块大小
        image_chunk_size = np.prod((16, 224, 224)) // nchunks
        # 音频块大小
        audio_chunk_size = audio.shape[1] // config.samples_per_patch // nchunks
        
        # 处理第一个块
        chunk_idx = 0
        # 设置子采样字典,包含图像和音频的索引
        subsampling = {
            "image": torch.arange(image_chunk_size * chunk_idx, image_chunk_size * (chunk_idx + 1)),
            "audio": torch.arange(audio_chunk_size * chunk_idx, audio_chunk_size * (chunk_idx + 1)),
            "label": None,
        }
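
这里可以粗略验算一下块大小(假设 config.samples_per_patch 取默认值 16):图像共 16 x 224 x 224 = 802816 个位置,分成 128 块后每块 6272 个索引;音频 30720 个采样点折合 1920 个 patch,每块 15 个索引:

import numpy as np

# 示意性验算,假设 samples_per_patch == 16
assert np.prod((16, 224, 224)) // 128 == 6272
assert 30720 // 16 // 128 == 15
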
        
        # 创建多模态自编码器模型
        model = PerceiverForMultimodalAutoencoding(config)
        
        # 设置标签
        filename = "kinetics700-id2label.json"
        # 从数据集库中下载标签文件
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
        # 将标签字典中的键转换为整数类型
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        # 设置标签到ID的映射
        config.label2id = {v: k for k, v in id2label.items()}
    else:
        # 抛出异常,指出不支持的架构类型
        raise ValueError(f"Architecture {architecture} not supported")
    
    # 将模型设置为评估模式
    model.eval()
    
    # 加载模型权重
    model.load_state_dict(state_dict)
    
    # 准备虚拟输入
    input_mask = None
    if architecture == "MLM":
        # 从预训练的分词器创建分词器对象
        tokenizer = PerceiverTokenizer.from_pretrained("/Users/NielsRogge/Documents/Perceiver/Tokenizer files")
        # 文本输入,包含一部分单词缺失的不完整句子
        text = "This is an incomplete sentence where some words are missing."
        # 对文本进行编码,填充到最大长度,并返回PyTorch张量
        encoding = tokenizer(text, padding="max_length", return_tensors="pt")
        # 掩码掉 " missing." 部分的词。注意:被掩码的片段以空格开头时,模型的表现会好得多。
        encoding.input_ids[0, 51:60] = tokenizer.mask_token_id
        inputs = encoding.input_ids
        input_mask = encoding.attention_mask
    elif architecture in ["image_classification", "image_classification_fourier", "image_classification_conv"]:
        # 创建图像处理器对象
        image_processor = PerceiverImageProcessor()
        # 准备图像数据
        image = prepare_img()
        # 对图像进行编码,返回PyTorch张量
        encoding = image_processor(image, return_tensors="pt")
        inputs = encoding.pixel_values
    elif architecture == "optical_flow":
        # 生成随机张量作为输入
        inputs = torch.randn(1, 2, 27, 368, 496)
    elif architecture == "multimodal_autoencoding":
        # 使用虚拟数据设置输入为图像、音频和标签
        images = torch.randn((1, 16, 3, 224, 224))
        audio = torch.randn((1, 30720, 1))
        inputs = {"image": images, "audio": audio, "label": torch.zeros((images.shape[0], 700))}
    
    # 执行前向传播
    if architecture == "multimodal_autoencoding":
        # 使用模型进行前向传播,传入输入数据、注意力掩码和子采样输出点
        outputs = model(inputs=inputs, attention_mask=input_mask, subsampled_output_points=subsampling)
    else:
        # 使用模型进行推理,获取模型输出
        outputs = model(inputs=inputs, attention_mask=input_mask)
    # 获取模型输出中的 logits
    logits = outputs.logits

    # 验证 logits
    if not isinstance(logits, dict):
        # 如果 logits 不是字典,打印其形状
        print("Shape of logits:", logits.shape)
    else:
        # 如果 logits 是字典,逐个打印每个模态的 logits 形状
        for k, v in logits.items():
            print(f"Shape of logits of modality {k}", v.shape)

    if architecture == "MLM":
        # 对于 Masked Language Model (MLM) 架构
        expected_slice = torch.tensor(
            [[-11.8336, -11.6850, -11.8483], [-12.8149, -12.5863, -12.7904], [-12.8440, -12.6410, -12.8646]]
        )
        # 断言切片部分的 logits 与预期的张量接近
        assert torch.allclose(logits[0, :3, :3], expected_slice)
        # 获取被掩码的标记的预测值,并转换为列表
        masked_tokens_predictions = logits[0, 51:60].argmax(dim=-1).tolist()
        # 预期的列表
        expected_list = [38, 115, 111, 121, 121, 111, 116, 109, 52]
        # 断言掩码标记的预测值与预期列表相等
        assert masked_tokens_predictions == expected_list
        # 打印贪婪预测结果
        print("Greedy predictions:")
        print(masked_tokens_predictions)
        print()
        # 打印预测的字符串
        print("Predicted string:")
        print(tokenizer.decode(masked_tokens_predictions))

    elif architecture in ["image_classification", "image_classification_fourier", "image_classification_conv"]:
        # 对于图像分类等架构,打印预测的类别
        print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])

    # 最后,保存文件
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    # 打印保存模型的路径
    print(f"Saving model to {pytorch_dump_folder_path}")
    # 将模型保存到指定路径
    model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
    # 如果当前脚本被直接执行(而不是被导入到其他模块中),则执行以下代码块

    parser = argparse.ArgumentParser()
    # 创建一个参数解析器对象

    # Required parameters
    parser.add_argument(
        "--pickle_file",
        type=str,
        default=None,
        required=True,
        help="Path to local pickle file of a Perceiver checkpoint you'd like to convert.",
    )
    # 添加一个必需的参数:指向本地 Perceiver 检查点 pickle 文件的路径

    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        required=True,
        help="Path to the output PyTorch model directory, provided as a string.",
    )
    # 添加一个必需的参数:指向输出 PyTorch 模型目录的路径,作为一个字符串提供

    parser.add_argument(
        "--architecture",
        default="MLM",
        type=str,
        help="""
        Architecture, provided as a string. One of 'MLM', 'image_classification', 'image_classification_fourier',
        'image_classification_conv', 'optical_flow' or 'multimodal_autoencoding'.
        """,
    )
    # 添加一个可选参数:模型的架构类型,作为字符串提供。可选项包括 'MLM'、'image_classification'、
    # 'image_classification_fourier'、'image_classification_conv'、'optical_flow' 或 'multimodal_autoencoding'

    args = parser.parse_args()
    # 解析命令行参数,并将其存储在 args 变量中

    convert_perceiver_checkpoint(args.pickle_file, args.pytorch_dump_folder_path, args.architecture)
    # 调用函数 convert_perceiver_checkpoint,传递命令行参数中的 pickle_file、pytorch_dump_folder_path 和 architecture

.\models\perceiver\feature_extraction_perceiver.py

# 设置文件编码为 UTF-8
# 版权声明和许可信息,告知代码使用者遵循 Apache 许可证版本 2.0 使用,禁止未经许可的复制和修改
# 获取 Apache 许可证版本 2.0 的详细信息的链接
# 根据适用法律或书面同意,按“现状”分发软件,不提供任何形式的担保或条件
# 引入警告模块,用于将来可能删除的类的警告
# 引入日志模块,用于记录和输出日志信息
# 从 image_processing_perceiver 模块中导入 PerceiverImageProcessor 类
import warnings
from ...utils import logging
from .image_processing_perceiver import PerceiverImageProcessor

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# PerceiverFeatureExtractor 类继承自 PerceiverImageProcessor 类
class PerceiverFeatureExtractor(PerceiverImageProcessor):
    # 初始化方法,接受任意位置参数和关键字参数
    def __init__(self, *args, **kwargs) -> None:
        # 发出未来警告,提醒用户 PerceiverFeatureExtractor 类将在 Transformers 版本 5 中删除,建议使用 PerceiverImageProcessor 替代
        warnings.warn(
            "The class PerceiverFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
            " Please use PerceiverImageProcessor instead.",
            FutureWarning,
        )
        # 调用父类 PerceiverImageProcessor 的初始化方法,传递所有位置参数和关键字参数
        super().__init__(*args, **kwargs)
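
下面是一个示意用法(假设已安装视觉相关依赖),说明这个兼容壳的行为:实例化 PerceiverFeatureExtractor 会触发 FutureWarning,但得到的对象与 PerceiverImageProcessor 完全等价,新代码应直接使用后者:

import warnings
from transformers import PerceiverFeatureExtractor, PerceiverImageProcessor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    extractor = PerceiverFeatureExtractor()  # 触发 FutureWarning

print(any(issubclass(w.category, FutureWarning) for w in caught))  # True
print(isinstance(extractor, PerceiverImageProcessor))              # True
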

.\models\perceiver\image_processing_perceiver.py

# 导入必要的模块和函数
from typing import Dict, List, Optional, Union  # 导入类型提示相关的模块

import numpy as np  # 导入NumPy库,用于数值计算

# 导入与图像处理相关的工具函数和类
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict  
# 从自定义模块中导入基础图像处理器、批量特征和获取大小字典函数

from ...image_transforms import center_crop, resize, to_channel_dimension_format  
# 从自定义模块中导入中心裁剪、调整大小和转换通道维度格式的函数

from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,  # 导入ImageNet图像默认均值
    IMAGENET_DEFAULT_STD,   # 导入ImageNet图像默认标准差
    ChannelDimension,       # 导入通道维度枚举
    ImageInput,             # 导入图像输入类
    PILImageResampling,     # 导入PIL图像重采样枚举
    get_image_size,         # 导入获取图像尺寸的函数
    infer_channel_dimension_format,  # 导入推断通道维度格式的函数
    is_scaled_image,        # 导入判断是否为缩放图像的函数
    make_list_of_images,    # 导入生成图像列表的函数
    to_numpy_array,         # 导入转换为NumPy数组的函数
    valid_images,           # 导入验证图像函数
    validate_kwargs,        # 导入验证关键字参数的函数
    validate_preprocess_arguments,  # 导入验证预处理参数的函数
)

from ...utils import TensorType, is_vision_available, logging  # 导入张量类型、判断视觉库是否可用和日志记录相关的模块和函数

if is_vision_available():
    import PIL  # 如果视觉库可用,则导入PIL库

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象

# 定义 PerceiverImageProcessor 类,其文档字符串说明了控制图像预处理各个步骤的参数和默认值
class PerceiverImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Perceiver image processor.

    Args:
        do_center_crop (`bool`, `optional`, defaults to `True`):
            是否进行中心裁剪图像。如果输入尺寸小于 `crop_size` 的任何边,图像将被填充为零,然后进行中心裁剪。
            可以被 `preprocess` 方法中的 `do_center_crop` 参数覆盖。
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
            应用中心裁剪时的期望输出尺寸。可以被 `preprocess` 方法中的 `crop_size` 参数覆盖。
        do_resize (`bool`, *optional*, defaults to `True`):
            是否调整图像大小为 `(size["height"], size["width"])`。
            可以被 `preprocess` 方法中的 `do_resize` 参数覆盖。
        size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
            调整大小后的图像尺寸。可以被 `preprocess` 方法中的 `size` 参数覆盖。
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            定义在调整图像大小时使用的重采样滤波器。
            可以被 `preprocess` 方法中的 `resample` 参数覆盖。
        do_rescale (`bool`, *optional*, defaults to `True`):
            是否按指定的比例因子 `rescale_factor` 进行重新缩放图像。
            可以被 `preprocess` 方法中的 `do_rescale` 参数覆盖。
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            如果重新缩放图像,则定义要使用的比例因子。
            可以被 `preprocess` 方法中的 `rescale_factor` 参数覆盖。
        do_normalize (`bool`, *optional*, defaults to `True`):
            是否对图像进行归一化。
            可以被 `preprocess` 方法中的 `do_normalize` 参数覆盖。
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            如果归一化图像,则使用的平均值。这是一个浮点数或与图像通道数相同长度的浮点数列表。
            可以被 `preprocess` 方法中的 `image_mean` 参数覆盖。
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            如果归一化图像,则使用的标准差。这是一个浮点数或与图像通道数相同长度的浮点数列表。
            可以被 `preprocess` 方法中的 `image_std` 参数覆盖。
    """
    
    # 定义模型输入的名称列表
    model_input_names = ["pixel_values"]
    # 初始化函数,设置图像预处理参数和默认值
    def __init__(
        self,
        do_center_crop: bool = True,  # 是否进行中心裁剪,默认为True
        crop_size: Dict[str, int] = None,  # 裁剪尺寸字典,可以为空
        do_resize: bool = True,  # 是否进行调整大小,默认为True
        size: Dict[str, int] = None,  # 调整大小的目标尺寸字典,可以为空
        resample: PILImageResampling = PILImageResampling.BICUBIC,  # 调整大小的插值方法,默认为双三次插值
        do_rescale: bool = True,  # 是否进行重新缩放,默认为True
        rescale_factor: Union[int, float] = 1 / 255,  # 重新缩放的因子,默认为1/255
        do_normalize: bool = True,  # 是否进行归一化,默认为True
        image_mean: Optional[Union[float, List[float]]] = None,  # 图像均值,可以为空
        image_std: Optional[Union[float, List[float]]] = None,  # 图像标准差,可以为空
        **kwargs,  # 其他未指定参数
    ) -> None:
        # 调用父类的初始化函数
        super().__init__(**kwargs)
        # 如果未指定裁剪尺寸,则使用默认的裁剪尺寸
        crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
        # 根据指定的裁剪尺寸参数名称获取尺寸字典
        crop_size = get_size_dict(crop_size, param_name="crop_size")
        # 如果未指定调整大小的目标尺寸,则使用默认的目标尺寸
        size = size if size is not None else {"height": 224, "width": 224}
        # 获取调整大小的尺寸字典
        size = get_size_dict(size)

        # 初始化对象的属性
        self.do_center_crop = do_center_crop  # 是否进行中心裁剪
        self.crop_size = crop_size  # 裁剪尺寸字典
        self.do_resize = do_resize  # 是否进行调整大小
        self.size = size  # 调整大小的目标尺寸字典
        self.resample = resample  # 调整大小的插值方法
        self.do_rescale = do_rescale  # 是否进行重新缩放
        self.rescale_factor = rescale_factor  # 重新缩放的因子
        self.do_normalize = do_normalize  # 是否进行归一化
        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN  # 图像均值
        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD  # 图像标准差
        # 有效的处理器关键字列表
        self._valid_processor_keys = [
            "images", "do_center_crop", "crop_size", "do_resize", "size", 
            "resample", "do_rescale", "rescale_factor", "do_normalize", 
            "image_mean", "image_std", "return_tensors", "data_format", 
            "input_data_format"
        ]
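
作为补充,下面是一个示意例子(假设已安装视觉相关依赖),演示用非默认尺寸实例化图像处理器并查看生效的属性:

from transformers import PerceiverImageProcessor

# 示意:自定义中心裁剪和调整大小的目标尺寸,其余参数沿用默认值
processor = PerceiverImageProcessor(
    crop_size={"height": 192, "width": 192},
    size={"height": 160, "width": 160},
)
print(processor.crop_size, processor.size, processor.do_normalize)
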
    def center_crop(
        self,
        image: np.ndarray,
        crop_size: Dict[str, int],
        size: Optional[Dict[str, int]] = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Center crop an image to `(size["height"] / crop_size["height"] * min_dim, size["width"] / crop_size["width"] *
        min_dim)`. Where `min_dim = min(size["height"], size["width"])`.

        If the input size is smaller than `crop_size` along any edge, the image will be padded with zeros and then
        center cropped.

        Args:
            image (`np.ndarray`):
                Image to center crop.
            crop_size (`Dict[str, int]`):
                Desired output size after applying the center crop.
            size (`Dict[str, int]`, *optional*):
                Size of the image after resizing. If not provided, the self.size attribute will be used.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        # 如果没有提供特定的尺寸参数,则使用默认的 self.size 属性
        size = self.size if size is None else size
        # 根据给定的 size 获取一个规范化的尺寸字典
        size = get_size_dict(size)
        # 根据给定的 crop_size 获取一个规范化的尺寸字典
        crop_size = get_size_dict(crop_size, param_name="crop_size")

        # 获取输入图片的高度和宽度,并根据输入数据格式确定通道维度
        height, width = get_image_size(image, channel_dim=input_data_format)
        # 计算输入图片中较小的维度作为 min_dim
        min_dim = min(height, width)
        # 计算裁剪后的高度和宽度,确保按比例缩放
        cropped_height = (size["height"] / crop_size["height"]) * min_dim
        cropped_width = (size["width"] / crop_size["width"]) * min_dim
        # 调用 center_crop 函数进行中心裁剪,并返回裁剪后的图片
        return center_crop(
            image,
            size=(cropped_height, cropped_width),
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
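
以默认参数为例可以验算这里的比例裁剪:size=224、crop_size=256,输入一张 480x640 的图像时 min_dim=480,裁剪边长为 224/256 x 480 = 420,随后再 resize 到 224x224:

# 示意性验算:按 size/crop_size 的比例对较短边做中心裁剪
size_demo, crop_size_demo, height, width = 224, 256, 480, 640
min_dim = min(height, width)
assert (size_demo / crop_size_demo) * min_dim == 420.0
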

    # 从 transformers.models.vit.image_processing_vit.ViTImageProcessor.resize 复制,修改了 resample 参数的默认值为 PILImageResampling.BICUBIC
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.
        """
        # Ensure `size` dictionary contains required keys
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
        
        # Prepare output size tuple
        output_size = (size["height"], size["width"])
        
        # Call the `resize` function to resize the image
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
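
下面是一个示意用法(假设已安装视觉相关依赖),用随机图像演示 resize 方法的输入输出形状;输入按 channels-last 排布,未指定 data_format 时输出保持同一通道格式:

import numpy as np
from transformers import PerceiverImageProcessor

img = np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)
out = PerceiverImageProcessor().resize(img, size={"height": 224, "width": 224})
print(out.shape)  # 预期为 (224, 224, 3)
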

    def preprocess(
        self,
        images: ImageInput,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[Dict[str, int]] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,

.\models\perceiver\modeling_perceiver.py

# 设置文件编码为 UTF-8

# 版权声明,声明代码版权归 Deepmind 和 HuggingFace Inc. 团队所有,保留所有权利

# 根据 Apache 许可证版本 2.0 进行许可
# 在遵守许可证的情况下,您可以使用此文件,详细信息请参见许可证
# 您可以在以下网址获取许可证副本:http://www.apache.org/licenses/LICENSE-2.0

# 除非法律另有规定或书面同意,否则不得以任何方式使用此软件
# 此软件按"原样"提供,不提供任何明示或暗示的保证或条件
# 具体的权限和限制请以许可证条款为准
""" PyTorch Perceiver 模型。"""

# 导入必要的库和模块
import abc  # 抽象基类模块
import math  # 数学函数模块
from dataclasses import dataclass  # 用于定义数据类的装饰器
from functools import reduce  # 函数工具模块中的reduce函数
from operator import __add__  # 运算符模块中的add函数
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union  # 类型提示

import numpy as np  # 导入 NumPy 库,用于数值计算
import torch  # 导入 PyTorch 深度学习库
import torch.utils.checkpoint  # PyTorch 的 checkpoint 模块,用于内存优化
from torch import nn  # PyTorch 的神经网络模块
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss  # PyTorch 的损失函数

from ...activations import ACT2FN  # 导入激活函数映射
from ...modeling_outputs import BaseModelOutputWithCrossAttentions  # 模型输出类,包含交叉注意力
from ...modeling_utils import PreTrainedModel  # 预训练模型基类
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer  # PyTorch 工具函数
from ...utils import (  # 通用工具函数和类
    ModelOutput,  # 模型输出基类
    add_start_docstrings,  # 为函数添加文档字符串的装饰器
    add_start_docstrings_to_model_forward,  # 为模型前向方法添加文档字符串的装饰器
    logging,  # 日志记录模块
    replace_return_docstrings,  # 替换返回文档字符串的工具函数
)
from .configuration_perceiver import PerceiverConfig  # 导入 Perceiver 模型的配置类

# 类型别名定义
ModalitySizeType = Mapping[str, int]  # 模态大小类型别名,映射字符串到整数
PreprocessorOutputType = Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]  # 预处理器输出类型别名
PreprocessorType = Callable[..., PreprocessorOutputType]  # 预处理器类型别名
PostprocessorType = Callable[..., Any]  # 后处理器类型别名

# 获取日志记录器
logger = logging.get_logger(__name__)

# 用于文档的模型检查点和配置常量
_CHECKPOINT_FOR_DOC = "deepmind/language-perceiver"  # 模型检查点用于文档
_CONFIG_FOR_DOC = "PerceiverConfig"  # Perceiver 模型配置用于文档

# 预训练模型的存档列表
PERCEIVER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "deepmind/language-perceiver",  # Deepmind 的语言 Perceiver 模型
    # 更多 Perceiver 模型存档可以在此处查看 https://huggingface.co/models?filter=perceiver
]

@dataclass
class PerceiverModelOutput(ModelOutput):
    """
    Perceiver 模型输出的基类,包含可能的隐藏状态、注意力和交叉注意力。
    
    这个类使用 dataclass 装饰器来定义数据类,它是一个轻量级的数据结构,用于表示简单的值对象。
    """

    # 该类用于描述 Perceiver 模型的输出,包含了可能的隐藏状态、注意力和交叉注意力等信息
    """
    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
            分类(或回归,如果config.num_labels==1)分数,即SoftMax之前的分数。
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
            模型最后一层输出的隐藏状态序列。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
            `torch.FloatTensor`元组(一个用于嵌入输出 + 一个用于每层输出),形状为`(batch_size, sequence_length, hidden_size)`。
            模型每层的隐藏状态,以及初始嵌入输出。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
            `torch.FloatTensor`元组(每层一个),形状为`(batch_size, num_heads, sequence_length, sequence_length)`。
            自注意力头中注意力权重的 softmax 后的结果,用于计算加权平均值。
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
            used to compute the weighted average in the cross-attention heads.
            `torch.FloatTensor`元组(每层一个),形状为`(batch_size, num_heads, sequence_length, sequence_length)`。
            解码器的交叉注意力层中注意力权重的 softmax 后的结果,用于计算加权平均值。
    """

    logits: torch.FloatTensor = None  # 分类(或回归)分数的张量,形状为`(batch_size, num_labels)`
    last_hidden_state: torch.FloatTensor = None  # 模型最后一层输出的隐藏状态张量,形状为`(batch_size, sequence_length, hidden_size)`
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None  # 模型每层的隐藏状态的元组张量,形状为`(batch_size, sequence_length, hidden_size)`
    attentions: Optional[Tuple[torch.FloatTensor]] = None  # 自注意力头中的注意力权重的元组张量,形状为`(batch_size, num_heads, sequence_length, sequence_length)`
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None  # 交叉注意力头中的注意力权重的元组张量,形状为`(batch_size, num_heads, sequence_length, sequence_length)`
    # 定义 PerceiverDecoderOutput 类,继承自 ModelOutput 类
    @dataclass
    class PerceiverDecoderOutput(ModelOutput):
        """
        Base class for Perceiver decoder outputs, with potential cross-attentions.

        Args:
            logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
                Output of the basic decoder.
            cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
                Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
                sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
                used to compute the weighted average in the cross-attention heads.
        """

        # 定义 logits 属性,类型为 torch.FloatTensor,形状为 (batch_size, num_labels),存储基本解码器的输出
        logits: torch.FloatTensor = None
        # 定义 cross_attentions 属性,类型为可选的元组,如果传入参数 output_attentions=True 或者 config.output_attentions=True 则会返回,存储解码器的跨注意力层的注意力权重
        cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


    # 定义 PerceiverMaskedLMOutput 类,继承自 ModelOutput 类
    @dataclass
    class PerceiverMaskedLMOutput(ModelOutput):
        """
        Base class for Perceiver's masked language model outputs.

        Args:
            loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
                Masked language modeling (MLM) loss.
            logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
                Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
                shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
                plus the initial embedding outputs.
            attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
                Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_latents,
                num_latents)`. Attentions weights after the attention softmax, used to compute the weighted average in the
                self-attention heads.
            cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
                Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
                sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
                used to compute the weighted average in the cross-attention heads.
        """

        # 定义 loss 属性,类型为可选的 torch.FloatTensor,形状为 (1,),当提供 labels 参数时返回,存储掩码语言模型(MLM)损失
        loss: Optional[torch.FloatTensor] = None
        # 定义 logits 属性,类型为 torch.FloatTensor,形状为 (batch_size, sequence_length, config.vocab_size),存储语言建模头的预测分数(SoftMax之前的每个词汇标记的分数)
        logits: torch.FloatTensor = None
        # 定义 hidden_states 属性,类型为可选的元组,如果传入参数 output_hidden_states=True 或者 config.output_hidden_states=True 则会返回,存储模型在每一层输出之后的隐藏状态 plus 初始嵌入输出
        hidden_states: Optional[Tuple[torch.FloatTensor]] = None
        # 定义 attentions 属性,类型为可选的元组,如果传入参数 output_attentions=True 或者 config.output_attentions=True 则会返回,存储注意力softmax后的注意权重,用于计算自注意力头中的加权平均值
        attentions: Optional[Tuple[torch.FloatTensor]] = None
        # 定义 cross_attentions 属性,类型为可选的元组,如果传入参数 output_attentions=True 或者 config.output_attentions=True 则会返回,存储解码器的跨注意力层的注意力权重
        cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


# 定义 PerceiverClassifierOutput 类,继承自 ModelOutput 类
@dataclass
class PerceiverClassifierOutput(ModelOutput):
    """
    Perceiver 模型的输出基类,适用于序列/图像分类模型、光流和多模态自编码。
    
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            分类(或回归,如果 `config.num_labels==1`)的损失值。
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            分类(或回归,如果 `config.num_labels==1`)的分数(SoftMax 之前)。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            包含 `torch.FloatTensor` 元组的隐藏状态(如果传递了 `output_hidden_states=True` 或 `config.output_hidden_states=True`)。
            形状为 `(batch_size, sequence_length, hidden_size)`,模型在每一层输出后的隐藏状态以及初始嵌入输出。
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            包含 `torch.FloatTensor` 元组的注意力权重(如果传递了 `output_attentions=True` 或 `config.output_attentions=True`)。
            形状为 `(batch_size, num_heads, sequence_length, sequence_length)`,经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            包含 `torch.FloatTensor` 元组的交叉注意力权重(如果传递了 `output_attentions=True` 或 `config.output_attentions=True`)。
            形状为 `(batch_size, num_heads, sequence_length, sequence_length)`,解码器的交叉注意力层的注意力权重,经过注意力 softmax 后用于计算加权平均值。
    """
    
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    """实现Perceiver模型的自注意力机制模块。可以用于编码器和解码器中。"""

    def __init__(
        self,
        config,
        is_cross_attention=False,
        qk_channels=None,
        v_channels=None,
        num_heads=1,
        q_dim=None,
        kv_dim=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        # Q和K必须具有相同数量的通道。
        # 默认保持Q的输入形状。
        if qk_channels is None:
            qk_channels = q_dim
        # V的通道数确定了QKV-attention输出的形状。
        # 默认使用与键-查询操作中使用的通道数相同。
        if v_channels is None:
            v_channels = qk_channels
        if qk_channels % num_heads != 0:
            raise ValueError(f"qk_channels ({qk_channels})必须能被num_heads ({num_heads})整除。")
        if v_channels % num_heads != 0:
            raise ValueError(f"v_channels ({v_channels})必须能被num_heads ({num_heads})整除。")

        self.qk_channels = qk_channels
        self.v_channels = v_channels
        self.qk_channels_per_head = self.qk_channels // num_heads
        self.v_channels_per_head = self.v_channels // num_heads

        # 层归一化
        self.layernorm1 = nn.LayerNorm(q_dim)
        self.layernorm2 = nn.LayerNorm(kv_dim) if is_cross_attention else nn.Identity()

        # 投影矩阵
        self.query = nn.Linear(q_dim, qk_channels)
        self.key = nn.Linear(kv_dim, qk_channels)
        self.value = nn.Linear(kv_dim, v_channels)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x, channels_per_head):
        """将张量重塑为注意力分数计算所需的形状。"""
        new_x_shape = x.size()[:-1] + (self.num_heads, channels_per_head)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
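
这个变形可以用一个小例子直观验证(示意代码):batch=2、time=5、channels=8、num_heads=2 时,张量从 (2, 5, 8) 变为 (2, 2, 5, 4),即把通道维拆分到各个注意力头上:

import torch

x = torch.randn(2, 5, 8)                    # (batch, time, channels)
x = x.view(2, 5, 2, 4).permute(0, 2, 1, 3)  # (batch, num_heads, time, channels_per_head)
print(x.shape)  # torch.Size([2, 2, 5, 4])
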

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        hidden_states = self.layernorm1(hidden_states)
        # 对隐藏状态进行 Layer Normalization

        inputs = self.layernorm2(inputs)
        # 对输入进行 Layer Normalization

        # Project queries, keys and values to a common feature dimension. If this is instantiated as a cross-attention module,
        # the keys and values come from the inputs; the attention mask needs to be such that the inputs's non-relevant tokens are not attended to.
        is_cross_attention = inputs is not None
        # 判断是否为跨注意力模块

        queries = self.query(hidden_states)
        # 从隐藏状态计算查询

        if is_cross_attention:
            keys = self.key(inputs)
            # 如果是跨注意力模块,从输入计算键
            values = self.value(inputs)
            # 如果是跨注意力模块,从输入计算值
            attention_mask = inputs_mask
            # 如果是跨注意力模块,使用输入的注意力掩码
        else:
            keys = self.key(hidden_states)
            # 如果不是跨注意力模块,从隐藏状态计算键
            values = self.value(hidden_states)
            # 如果不是跨注意力模块,从隐藏状态计算值

        # Reshape channels for multi-head attention.
        # We reshape from (batch_size, time, channels) to (batch_size, num_heads, time, channels per head)
        queries = self.transpose_for_scores(queries, self.qk_channels_per_head)
        # 调整查询张量以进行多头注意力计算
        keys = self.transpose_for_scores(keys, self.qk_channels_per_head)
        # 调整键张量以进行多头注意力计算
        values = self.transpose_for_scores(values, self.v_channels_per_head)
        # 调整值张量以进行多头注意力计算

        # Take the dot product between the queries and keys to get the raw attention scores.
        attention_scores = torch.matmul(queries, keys.transpose(-1, -2))
        # 计算查询和键的点积以获得原始注意力分数

        batch_size, num_heads, seq_len, q_head_dim = queries.shape
        _, _, _, v_head_dim = values.shape
        hiddens = self.num_heads * v_head_dim
        # 计算中间变量

        attention_scores = attention_scores / math.sqrt(q_head_dim)
        # 缩放注意力分数

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in PerceiverModel forward() function)
            attention_scores = attention_scores + attention_mask
            # 应用注意力掩码

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        # 将注意力分数归一化为概率

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        # 使用 dropout 随机丢弃整个 token 的注意力概率,这种做法源自于原始的 Transformer 论文

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask
            # 如果有头部掩码,应用头部掩码

        context_layer = torch.matmul(attention_probs, values)
        # 计算上下文张量

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # 调整上下文张量的维度顺序

        new_context_layer_shape = context_layer.size()[:-2] + (hiddens,)
        context_layer = context_layer.view(*new_context_layer_shape)
        # 调整上下文张量的形状

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        # 准备输出

        return outputs
        # 返回计算结果
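
上面 forward 的核心就是缩放点积注意力。下面给出一个与之等价的最小示意(单头、无掩码、无 dropout),便于对照各张量的形状:

import math
import torch

q = torch.randn(1, 1, 4, 8)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.shape[-1])
probs = scores.softmax(dim=-1)
context = torch.matmul(probs, v)
print(context.shape)  # torch.Size([1, 1, 4, 8])
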
class PerceiverSelfOutput(nn.Module):
    def __init__(self, config, input_channels, output_channels):
        super().__init__()
        # 初始化一个全连接层,输入通道数为input_channels,输出通道数为output_channels
        self.dense = nn.Linear(input_channels, output_channels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 将输入的隐藏状态通过全连接层进行线性变换
        hidden_states = self.dense(hidden_states)
        return hidden_states


class PerceiverAttention(nn.Module):
    """Attention module, including a dense block."""

    def __init__(
        self,
        config,
        is_cross_attention=False,
        qk_channels=None,
        v_channels=None,
        num_heads=1,
        q_dim=None,
        kv_dim=None,
        use_query_residual=True,
    ):
        super().__init__()
        
        # 根据是否是交叉注意力机制和参数配置设置查询键值通道数和值通道数
        if is_cross_attention and qk_channels is None:
            if config.cross_attention_shape_for_attention == "q":
                qk_channels = q_dim
            elif config.cross_attention_shape_for_attention == "kv":
                qk_channels = kv_dim
            else:
                raise ValueError(
                    f"Unknown value {config.cross_attention_shape_for_attention} for "
                    "cross_attention_shape_for_attention."
                )
        else:
            if qk_channels is None:
                qk_channels = q_dim
            if v_channels is None:
                v_channels = qk_channels
        
        # 初始化自注意力层
        self.self = PerceiverSelfAttention(
            config,
            is_cross_attention=is_cross_attention,
            qk_channels=qk_channels,
            v_channels=v_channels,
            num_heads=num_heads,
            q_dim=q_dim,
            kv_dim=kv_dim,
        )
        
        # 设置输出层,根据是否是交叉注意力机制确定输出通道数
        output_channels = None
        if is_cross_attention:
            output_channels = q_dim
        else:
            if output_channels is None:
                output_channels = v_channels
        self.output = PerceiverSelfOutput(config, input_channels=self.self.v_channels, output_channels=output_channels)
        
        self.use_query_residual = use_query_residual
        self.pruned_heads = set()

    def prune_heads(self, heads):
        # 如果没有需要剪枝的头部,直接返回
        if len(heads) == 0:
            return
        
        # 寻找可剪枝的头部及其索引
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # 对线性层进行剪枝
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # 更新超参数并记录已剪枝的头部
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # 使用 self.self 方法处理输入的隐藏状态,返回处理后的输出
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            inputs,
            inputs_mask,
            output_attentions,
        )

        # 将 self_outputs[0] 通过 self.output 进行输出投影
        attention_output = self.output(self_outputs[0])

        # 如果指定使用查询残差连接
        if self.use_query_residual:
            # 将 attention_output 添加到原始隐藏状态 hidden_states 上
            attention_output = attention_output + hidden_states

        # 组装最终输出,包括 attention_output 和可能的其他输出
        outputs = (attention_output,) + self_outputs[1:]  # 如果需要输出注意力权重,也加入到 outputs 中
        return outputs
class PerceiverMLP(nn.Module):
    """A Transformer-style dense module to follow attention."""

    def __init__(self, config, input_size, widening_factor):
        super().__init__()
        # 第一层全连接层,将输入特征大小映射到扩展因子倍数的输入特征大小
        self.dense1 = nn.Linear(input_size, widening_factor * input_size)
        # 根据配置选择激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
        # 第二层全连接层,将扩展后的特征映射回原始输入特征大小
        self.dense2 = nn.Linear(widening_factor * input_size, input_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 前向传播:全连接层1
        hidden_states = self.dense1(hidden_states)
        # 前向传播:激活函数
        hidden_states = self.intermediate_act_fn(hidden_states)
        # 前向传播:全连接层2
        hidden_states = self.dense2(hidden_states)
        return hidden_states
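
例如当 input_size=1280、widening_factor=4 时,dense1 把特征维从 1280 扩到 5120,dense2 再压回 1280。下面是一个示意性的等价小例子(激活函数以 GELU 为例,实际取决于 config.hidden_act):

import torch
from torch import nn

dense1 = nn.Linear(1280, 4 * 1280)
dense2 = nn.Linear(4 * 1280, 1280)

x = torch.randn(2, 256, 1280)  # (batch, num_latents, d_latents)
y = dense2(nn.functional.gelu(dense1(x)))
print(y.shape)  # torch.Size([2, 256, 1280])
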


class PerceiverLayer(nn.Module):
    def __init__(
        self,
        config,
        is_cross_attention=False,
        qk_channels=None,
        v_channels=None,
        num_heads=1,
        q_dim=None,
        kv_dim=None,
        widening_factor=4,
        use_query_residual=True,
    ):
        super().__init__()
        # 分块大小
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # 序列长度维度
        self.seq_len_dim = 1
        # 注意力机制
        self.attention = PerceiverAttention(
            config,
            is_cross_attention=is_cross_attention,
            qk_channels=qk_channels,
            v_channels=v_channels,
            num_heads=num_heads,
            q_dim=q_dim,
            kv_dim=kv_dim,
            use_query_residual=use_query_residual,
        )
        # Layer normalization
        self.layernorm = nn.LayerNorm(q_dim)
        # MLP层
        self.mlp = PerceiverMLP(config, input_size=q_dim, widening_factor=widening_factor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # 调用注意力机制的前向传播
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            inputs,
            inputs_mask,
            output_attentions,
        )
        attention_output = attention_outputs[0]

        outputs = attention_outputs[1:]  # 如果输出注意力权重,则添加

        # 对注意力输出应用分块处理,返回分块后的输出
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )

        # 残差连接
        layer_output = layer_output + attention_output

        outputs = (layer_output,) + outputs

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Layer normalization
        layer_output = self.layernorm(attention_output)
        # MLP层
        layer_output = self.mlp(layer_output)
        return layer_output
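
apply_chunking_to_forward 会沿指定维度把输入切块后分别送入前馈函数,以降低峰值显存,拼接后的结果与整体计算一致。下面是一个最小示意(用一个线性层代替上面的 layernorm + MLP):

import torch
from transformers.pytorch_utils import apply_chunking_to_forward

layer = torch.nn.Linear(8, 8)

def feed_forward(hidden):
    return layer(hidden)

x = torch.randn(1, 6, 8)
# chunk_size=2,沿序列维(dim=1)切成 3 段分别计算
chunked = apply_chunking_to_forward(feed_forward, 2, 1, x)
print(torch.allclose(chunked, layer(x)))  # True
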


class PerceiverEncoder(nn.Module):
    """The Perceiver Encoder: a scalable, fully attentional encoder."""
    def __init__(self, config, kv_dim=None):
        super().__init__()
        self.config = config

        # Check that we can use multihead-attention with these shapes.
        # 检查是否可以使用这些形状进行多头注意力
        if config.d_latents % config.num_self_attention_heads != 0:
            raise ValueError(
                f"num_z_channels ({config.d_latents}) must be divisible by"
                f" num_self_attend_heads ({config.num_self_attention_heads})."
            )
        if config.d_latents % config.num_cross_attention_heads != 0:
            raise ValueError(
                f"num_z_channels ({config.d_latents}) must be divisible by"
                f" num_cross_attend_heads ({config.num_cross_attention_heads})."
            )

        # Construct the cross attention layer.
        # 构建跨注意力层
        self.cross_attention = PerceiverLayer(
            config,
            is_cross_attention=True,
            qk_channels=config.qk_channels,
            v_channels=config.v_channels,
            num_heads=config.num_cross_attention_heads,
            q_dim=config.d_latents,
            kv_dim=kv_dim,
            widening_factor=config.cross_attention_widening_factor,
            use_query_residual=config.use_query_residual,
        )

        # Construct a single block of self-attention layers.
        # 构建一个自注意力层块
        # 通过多次应用这个块,可以得到更深的网络结构
        self_attention_layers = []
        for _ in range(config.num_self_attends_per_block):
            layer = PerceiverLayer(
                config,
                is_cross_attention=False,
                qk_channels=config.qk_channels,
                v_channels=config.v_channels,
                num_heads=config.num_self_attention_heads,
                q_dim=config.d_latents,
                kv_dim=config.d_latents,
                widening_factor=config.self_attention_widening_factor,
            )
            self_attention_layers.append(layer)

        self.self_attends = nn.ModuleList(self_attention_layers)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
        # 如果不需要输出隐藏状态,设置为空元组;否则初始化为 None
        all_hidden_states = () if output_hidden_states else None
        # 如果不需要输出注意力权重,设置为空元组;否则初始化为 None
        all_self_attentions = () if output_attentions else None
        # 如果不需要输出交叉注意力权重,设置为空元组;否则初始化为 None
        all_cross_attentions = () if output_attentions else None

        # 对 latent(hidden_states)和 inputs 之间进行交叉注意力计算:
        layer_outputs = self.cross_attention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=None,
            inputs=inputs,
            inputs_mask=inputs_mask,
            output_attentions=output_attentions,
        )
        # 更新 hidden_states 为交叉注意力计算的输出的第一个元素
        hidden_states = layer_outputs[0]

        # 如果需要输出注意力权重,将本次计算的注意力权重添加到 all_cross_attentions 中
        if output_attentions:
            all_cross_attentions = all_cross_attentions + (layer_outputs[1],)

        # 多次应用自注意力层块:
        for _ in range(self.config.num_blocks):
            for i, layer_module in enumerate(self.self_attends):
                # 如果需要输出隐藏状态,将当前 hidden_states 添加到 all_hidden_states 中
                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)

                # 获取当前层的头部掩码
                layer_head_mask = head_mask[i] if head_mask is not None else None

                # 执行当前自注意力层的前向传播
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=attention_mask,
                    head_mask=layer_head_mask,
                    output_attentions=output_attentions,
                )

                # 更新 hidden_states 为当前自注意力层的输出的第一个元素
                hidden_states = layer_outputs[0]
                # 如果需要输出注意力权重,将本次计算的注意力权重添加到 all_self_attentions 中
                if output_attentions:
                    all_self_attentions = all_self_attentions + (layer_outputs[1],)

            # 如果需要输出隐藏状态,将当前 hidden_states 添加到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果不返回字典形式的结果,将结果以元组形式返回,过滤掉值为 None 的项
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )
        # 返回一个 BaseModelOutputWithCrossAttentions 对象,包含最后隐藏状态、所有隐藏状态、自注意力和交叉注意力
        return BaseModelOutputWithCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
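
值得强调的是这里的权重共享:self_attends 中只有 num_self_attends_per_block 个自注意力层的参数,但前向时整块会被重复 num_blocks 次,有效深度是二者之积。下面是一个示意(假设已安装 transformers,数值仅为演示而非默认配置):

from transformers import PerceiverConfig
from transformers.models.perceiver.modeling_perceiver import PerceiverEncoder

config = PerceiverConfig(num_blocks=2, num_self_attends_per_block=3)
encoder = PerceiverEncoder(config, kv_dim=config.d_model)

# 参数上只有 3 个自注意力层(外加 1 个交叉注意力层),前向时自注意力块会被循环 2 次
print(len(encoder.self_attends))  # 3
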
class PerceiverPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 设置默认的配置类为PerceiverConfig
    config_class = PerceiverConfig
    # 指定基础模型前缀为"perceiver"
    base_model_prefix = "perceiver"
    # 指定主要输入名称为"inputs"
    main_input_name = "inputs"

    def _init_weights(self, module):
        """Initialize the weights"""
        # 如果module是Linear或者Conv2d类型
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # 使用正态分布初始化权重数据,均值为0,标准差为config中指定的initializer_range
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果有偏置项,则将偏置数据初始化为零
            if module.bias is not None:
                module.bias.data.zero_()
        # 如果module具有"latents"属性
        elif hasattr(module, "latents"):
            # 使用正态分布初始化latents属性数据,均值为0,标准差为config中指定的initializer_range
            module.latents.data.normal_(mean=0.0, std=self.config.initializer_range)
        # 如果module具有"position_embeddings"属性并且是PerceiverTrainablePositionEncoding类型
        elif hasattr(module, "position_embeddings") and isinstance(module, PerceiverTrainablePositionEncoding):
            # 使用正态分布初始化position_embeddings数据,均值为0,标准差为config中指定的initializer_range
            module.position_embeddings.data.normal_(mean=0.0, std=self.config.initializer_range)
        # 如果module是ParameterDict类型
        elif isinstance(module, nn.ParameterDict):
            # 对于每个modality,使用正态分布初始化数据,均值为0,标准差为config中指定的initializer_range
            for modality in module.keys():
                module[modality].data.normal_(mean=0.0, std=self.config.initializer_range)
        # 如果module是Embedding类型
        elif isinstance(module, nn.Embedding):
            # 使用正态分布初始化权重数据,均值为0,标准差为config中指定的initializer_range
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果指定了padding_idx,则将该索引位置的权重初始化为零
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # 如果module是LayerNorm类型
        elif isinstance(module, nn.LayerNorm):
            # 将偏置数据初始化为零
            module.bias.data.zero_()
            # 将权重数据初始化为1
            module.weight.data.fill_(1.0)


PERCEIVER_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`PerceiverConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

PERCEIVER_MODEL_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.
    Parameters:
        config ([`PerceiverConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
        decoder (*DecoderType*, *optional*):
            Optional decoder to use to decode the latent representation of the encoder. Examples include
            *transformers.models.perceiver.modeling_perceiver.PerceiverBasicDecoder*,
            *transformers.models.perceiver.modeling_perceiver.PerceiverClassificationDecoder*,
            *transformers.models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder*.
        input_preprocessor (*PreprocessorType*, *optional*):
            Optional input preprocessor to use. Examples include
            *transformers.models.perceiver.modeling_perceiver.PerceiverImagePreprocessor*,
            *transformers.models.perceiver.modeling_perceiver.PerceiverAudioPreprocessor*,
            *transformers.models.perceiver.modeling_perceiver.PerceiverTextPreprocessor*,
            *transformers.models.perceiver.modeling_perceiver.PerceiverMultimodalPreprocessor*.
        output_postprocessor (*PostprocessorType*, *optional*):
            Optional output postprocessor to use. Examples include
            *transformers.models.perceiver.modeling_perceiver.PerceiverImagePostprocessor*,
            *transformers.models.perceiver.modeling_perceiver.PerceiverAudioPostprocessor*,
            *transformers.models.perceiver.modeling_perceiver.PerceiverClassificationPostprocessor*,
            *transformers.models.perceiver.modeling_perceiver.PerceiverProjectionPostprocessor*,
            *transformers.models.perceiver.modeling_perceiver.PerceiverMultimodalPostprocessor*.

    Note that you can define your own decoders, preprocessors and/or postprocessors to fit your use-case.
"""
PERCEIVER_INPUTS_DOCSTRING = r"""
    Args:
        inputs (`torch.FloatTensor`):
            Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
        attention_mask (`torch.FloatTensor` of shape `{0}`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    """The Perceiver: a scalable, fully attentional architecture.""",
    PERCEIVER_MODEL_START_DOCSTRING,
)
class PerceiverModel(PerceiverPreTrainedModel):
    """
    The PerceiverModel class implements a perceiver architecture for various input modalities.

    Args:
        config (PretrainedConfig):
            The model configuration class instance.
        decoder (Optional):
            Optional decoder for the model.
        input_preprocessor (PreprocessorType, Optional):
            Optional input preprocessor for handling input data.
        output_postprocessor (PostprocessorType, Optional):
            Optional output postprocessor for handling model outputs.
    """

    def __init__(
        self,
        config,
        decoder=None,
        input_preprocessor: PreprocessorType = None,
        output_postprocessor: PostprocessorType = None,
    ):
        """
        Initialize the PerceiverModel with given configuration and optional components.

        Args:
            config (PretrainedConfig):
                The model configuration class instance.
            decoder (Optional):
                Optional decoder for the model.
            input_preprocessor (PreprocessorType, Optional):
                Optional input preprocessor for handling input data.
            output_postprocessor (PostprocessorType, Optional):
                Optional output postprocessor for handling model outputs.
        """
        super().__init__(config)
        self.config = config

        self.input_preprocessor = input_preprocessor
        self.output_postprocessor = output_postprocessor
        self.embeddings = PerceiverEmbeddings(config)
        self.encoder = PerceiverEncoder(
            config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model
        )
        self.decoder = decoder

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """
        Returns the latent embeddings used as inputs to the perceiver model.

        Returns:
            torch.Tensor: The latent embeddings tensor.
        """
        return self.embeddings.latents

    def set_input_embeddings(self, value):
        """
        Sets the input embeddings of the perceiver model.

        Args:
            value (torch.Tensor): The new input embeddings tensor.
        """
        self.embeddings.latents = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model.

        Args:
            heads_to_prune (dict):
                Dictionary of {layer_num: list of heads to prune in this layer}. See base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @replace_return_docstrings(output_type=PerceiverModelOutput, config_class=_CONFIG_FOR_DOC)
    # Forward pass of the Perceiver model
    def forward(
        self,
        # Input data tensor, typically a float tensor
        inputs: torch.FloatTensor,
        # Optional attention mask controlling which positions may be attended to
        attention_mask: Optional[torch.FloatTensor] = None,
        # Optional dict of subsampled output points, one tensor per modality
        subsampled_output_points: Optional[Dict[str, torch.Tensor]] = None,
        # Optional head mask used to nullify the attention weights of selected heads
        head_mask: Optional[torch.FloatTensor] = None,
        # Whether to return the attention weights
        output_attentions: Optional[bool] = None,
        # Whether to return the hidden states of all layers
        output_hidden_states: Optional[bool] = None,
        # Whether to return a ModelOutput dict instead of a plain tuple
        return_dict: Optional[bool] = None,
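

# A minimal usage sketch of PerceiverModel wired for text classification. It mirrors the wiring of
# PerceiverForSequenceClassification further below; PerceiverTokenizer is the byte-level tokenizer shipped
# with the transformers library, the other classes are the ones defined in this file, and the weights are
# randomly initialized (illustrative only).
def example_perceiver_model_usage():
    from transformers import PerceiverConfig, PerceiverTokenizer

    config = PerceiverConfig()
    preprocessor = PerceiverTextPreprocessor(config)
    decoder = PerceiverClassificationDecoder(
        config,
        num_channels=config.d_latents,
        trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
        use_query_residual=True,
    )
    model = PerceiverModel(config, input_preprocessor=preprocessor, decoder=decoder)

    tokenizer = PerceiverTokenizer()
    inputs = tokenizer("hello world", return_tensors="pt").input_ids
    with torch.no_grad():
        outputs = model(inputs=inputs)
    return outputs.logits.shape  # (batch_size, config.num_labels)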
@add_start_docstrings("""Example use of Perceiver for masked language modeling.""", PERCEIVER_START_DOCSTRING)
class PerceiverForMaskedLM(PerceiverPreTrainedModel):
    def __init__(self, config: PerceiverConfig):
        super().__init__(config)

        # 实例化文本预处理器
        text_preprocessor = PerceiverTextPreprocessor(config)

        # 定义用于解码器的可训练位置编码参数
        trainable_position_encoding_kwargs_decoder = {
            "num_channels": text_preprocessor.num_channels,
            "index_dims": config.max_position_embeddings,
        }

        # 创建 PerceiverModel 实例,配置输入预处理器和解码器
        self.perceiver = PerceiverModel(
            config,
            input_preprocessor=text_preprocessor,
            decoder=PerceiverBasicDecoder(
                config,
                output_num_channels=config.d_latents,
                output_index_dims=config.max_position_embeddings,  # 需要预先定义输入的序列长度
                num_channels=text_preprocessor.num_channels,
                qk_channels=8 * 32,
                v_channels=text_preprocessor.num_channels,
                num_heads=8,
                use_query_residual=False,
                final_project=False,
                trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
            ),
        )

        # 实例化 PerceiverEmbeddingDecoder
        self.embedding_decoder = PerceiverEmbeddingDecoder(config)

        # 初始化权重并进行最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=PerceiverMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        inputs: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        input_ids: Optional[torch.Tensor] = None,
        # Forward pass of the masked LM model; see PERCEIVER_INPUTS_DOCSTRING for the argument format
        # The output type is PerceiverMaskedLMOutput, documented against _CONFIG_FOR_DOC
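

# A minimal usage sketch for PerceiverForMaskedLM (randomly initialized weights; load the released
# "deepmind/language-perceiver" checkpoint with from_pretrained for meaningful predictions). The Perceiver
# language model works on UTF-8 bytes, so PerceiverTokenizer needs no vocabulary file; the masked span below
# is arbitrary and purely illustrative.
def example_perceiver_masked_lm():
    from transformers import PerceiverConfig, PerceiverTokenizer

    config = PerceiverConfig()
    model = PerceiverForMaskedLM(config)
    tokenizer = PerceiverTokenizer()

    encoding = tokenizer(
        "This is an incomplete sentence.",
        padding="max_length",
        max_length=config.max_position_embeddings,
        return_tensors="pt",
    )
    encoding.input_ids[0, 5:8] = tokenizer.mask_token_id  # mask a small span of bytes

    with torch.no_grad():
        outputs = model(inputs=encoding.input_ids, attention_mask=encoding.attention_mask)
    return outputs.logits.shape  # (batch_size, config.max_position_embeddings, config.vocab_size)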



@add_start_docstrings("""Example use of Perceiver for text classification.""", PERCEIVER_START_DOCSTRING)
class PerceiverForSequenceClassification(PerceiverPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # 定义用于解码器的可训练位置编码参数
        trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}

        # 设置分类数量
        self.num_labels = config.num_labels

        # 创建 PerceiverModel 实例,配置输入预处理器和分类解码器
        self.perceiver = PerceiverModel(
            config,
            input_preprocessor=PerceiverTextPreprocessor(config),
            decoder=PerceiverClassificationDecoder(
                config,
                num_channels=config.d_latents,
                trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
                use_query_residual=True,
            ),
        )

        # 初始化权重并进行最终处理
        self.post_init()
    # Attach the input-format docstring to the forward method (formatted with "batch_size, sequence_length")
    @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    # Replace the return docstring, with PerceiverClassifierOutput as output type and _CONFIG_FOR_DOC as config class
    @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC)
    # Forward pass of the sequence classification model
    def forward(
        self,
        inputs: Optional[torch.Tensor] = None,  # input tensor
        attention_mask: Optional[torch.Tensor] = None,  # attention mask
        head_mask: Optional[torch.Tensor] = None,  # head mask
        output_attentions: Optional[bool] = None,  # whether to return attention weights
        output_hidden_states: Optional[bool] = None,  # whether to return hidden states
        labels: Optional[torch.Tensor] = None,  # classification labels
        return_dict: Optional[bool] = None,  # whether to return a ModelOutput instead of a tuple
        input_ids: Optional[torch.Tensor] = None,  # token ids, an alternative to passing `inputs`
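

# The elided forward of PerceiverForSequenceClassification ultimately scores `num_labels` classes per
# example. Below is a hedged sketch of the standard single-label objective applied to such logits (dummy
# tensors only; not necessarily the exact body of the elided method):
def example_sequence_classification_loss():
    batch_size, num_labels = 4, 2
    logits = torch.randn(batch_size, num_labels)          # what the classification decoder produces
    labels = torch.randint(0, num_labels, (batch_size,))  # gold class per example
    return CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))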
@add_start_docstrings(
    """
Example use of Perceiver for image classification, for tasks such as ImageNet.

This model uses learned position embeddings. In other words, this model is not given any privileged information about
the structure of images. As shown in the paper, this model can achieve a top-1 accuracy of 72.7 on ImageNet.

[`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
(with `prep_type="conv1x1"`) to preprocess the input images, and
[`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
[`PerceiverModel`] into classification logits.
""",
    PERCEIVER_START_DOCSTRING,
)
class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Define kwargs for trainable position encoding in preprocessor and decoder
        trainable_position_encoding_kwargs_preprocessor = {"num_channels": 256, "index_dims": config.image_size ** 2}
        trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}

        # Initialize number of labels from config
        self.num_labels = config.num_labels

        # Initialize the Perceiver model with a conv1x1 preprocessor and trainable position embeddings
        self.perceiver = PerceiverModel(
            config,
            input_preprocessor=PerceiverImagePreprocessor(
                config,
                prep_type="conv1x1",
                spatial_downsample=1,
                out_channels=256,
                position_encoding_type="fourier",
                concat_or_add_pos="add",
                project_pos_dim=256,
                trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_preprocessor,
            ),
            decoder=PerceiverClassificationDecoder(
                config,
                num_channels=config.d_latents,
                trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
                use_query_residual=True,
            ),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        inputs: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, PerceiverClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Returns:
            `PerceiverClassifierOutput` or `tuple`, depending on `return_dict`.
        """
        if inputs is not None and pixel_values is not None:
            raise ValueError("You cannot use both `inputs` and `pixel_values`")
        elif inputs is None and pixel_values is not None:
            inputs = pixel_values
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward pass through the underlying Perceiver model (labels are handled below, not by the base model)
        outputs = self.perceiver(
            inputs=inputs,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]

        loss = None
        if labels is not None:
            # Standard single-label classification loss over the decoded logits
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return PerceiverClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
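

# A minimal usage sketch for the image classification model above (randomly initialized weights and a dummy
# image; `num_labels=10` is illustrative, and real checkpoints are loaded with from_pretrained instead):
def example_image_classification():
    from transformers import PerceiverConfig

    config = PerceiverConfig(num_labels=10)
    model = PerceiverForImageClassificationLearned(config)
    model.eval()

    pixel_values = torch.randn(1, 3, config.image_size, config.image_size)  # dummy RGB image
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
    return outputs.logits.shape  # (1, config.num_labels)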
"""
[`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
(with `prep_type="pixels"`) to preprocess the input images, and
[`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
[`PerceiverModel`] into classification logits.
""",
PERCEIVER_START_DOCSTRING,
)
class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # 设置傅里叶位置编码的预处理器参数
        fourier_position_encoding_kwargs_preprocessor = {
            "concat_pos": True,  # 是否将位置编码与输入数据连接
            "max_resolution": (224, 224),  # 输入图像的最大分辨率
            "num_bands": 64,  # 傅里叶变换中的频带数量
            "sine_only": False,  # 是否只使用正弦函数作为位置编码的基础
        }
        # 可训练位置编码解码器的参数
        trainable_position_encoding_kwargs_decoder = {
            "num_channels": config.d_latents,  # 潜在表示的通道数
            "index_dims": 1,  # 位置索引的维度
        }

        self.num_labels = config.num_labels
        # 创建Perceiver模型,指定输入预处理器和分类解码器
        self.perceiver = PerceiverModel(
            config,
            input_preprocessor=PerceiverImagePreprocessor(
                config,
                prep_type="pixels",  # 使用像素级别的预处理方式
                spatial_downsample=1,  # 空间下采样因子
                fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
            ),
            decoder=PerceiverClassificationDecoder(
                config,
                num_channels=config.d_latents,
                trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
                use_query_residual=True,  # 使用查询残差连接
            ),
        )

        # 初始化权重并应用最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        inputs: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,



@add_start_docstrings(
    """
    Example use of Perceiver for image classification, for tasks such as ImageNet.

    This model uses a 2D conv+maxpool preprocessing network. As shown in the paper, this model can achieve a top-1 accuracy
    of 82.1 on ImageNet.

    [`PerceiverForImageClassificationConvProcessing`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
    (with `prep_type="conv"`) to preprocess the input images, and
    [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
    [`PerceiverModel`] into classification logits.
    """,
    PERCEIVER_START_DOCSTRING,
)
class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):



"""
注释:
- `PerceiverForImageClassificationLearned` 类使用 `PerceiverImagePreprocessor` 来预处理输入图像(使用 `prep_type="pixels"`),并使用 `PerceiverClassificationDecoder` 来将 `PerceiverModel` 的潜在表示解码为分类 logits。
- `PerceiverForImageClassificationConvProcessing` 类示例用于图像分类任务(例如 ImageNet),使用 2D 卷积+最大池化预处理网络,可以在 ImageNet 上达到82.1%的top-1准确率。
"""
    # 初始化函数,接受一个配置对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法,传入配置对象
        super().__init__(config)

        # 定义用于预处理的傅里叶位置编码的参数字典
        fourier_position_encoding_kwargs_preprocessor = {
            "concat_pos": True,  # 是否在输入中连接位置编码
            "max_resolution": (56, 56),  # 最大分辨率
            "num_bands": 64,  # 傅里叶变换中使用的波段数
            "sine_only": False,  # 是否只使用正弦函数
        }
        
        # 定义用于解码器的可训练位置编码的参数字典
        trainable_position_encoding_kwargs_decoder = {"num_channels": config.d_latents, "index_dims": 1}

        # 设置实例变量:标签的数量
        self.num_labels = config.num_labels
        
        # 初始化感知器模型,配置输入预处理器和解码器
        self.perceiver = PerceiverModel(
            config,
            input_preprocessor=PerceiverImagePreprocessor(
                config,
                prep_type="conv",  # 预处理类型为卷积
                spatial_downsample=1,  # 空间下采样因子
                position_encoding_type="fourier",  # 位置编码类型为傅里叶
                fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,  # 傅里叶位置编码参数
            ),
            decoder=PerceiverClassificationDecoder(
                config,
                num_channels=config.d_latents,  # 解码器的通道数
                trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,  # 可训练位置编码参数
                use_query_residual=True,  # 是否使用查询残差
            ),
        )

        # 调用初始化后的处理函数
        self.post_init()

    # 重写的前向传播函数,接受多种输入参数并返回模型输出
    @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        inputs: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
@add_start_docstrings(
    """
    Example use of Perceiver for optical flow, for tasks such as Sintel and KITTI. [`PerceiverForOpticalFlow`] uses
    [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] (with *prep_type="patches"*) to preprocess the
    input images, and [`~models.perceiver.modeling_perceiver.PerceiverOpticalFlowDecoder`] to decode the latent
    representation of [`PerceiverModel`].

    As input, one concatenates 2 subsequent frames along the channel dimension and extract a 3 x 3 patch around each pixel
    (leading to 3 x 3 x 3 x 2 = 54 values for each pixel). Fixed Fourier position encodings are used to encode the position
    of each pixel in the patch. Next, one applies the Perceiver encoder. To decode, one queries the latent representation
    using the same encoding used for the input.
    """,
    PERCEIVER_START_DOCSTRING,
)
class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        fourier_position_encoding_kwargs_preprocessor = {
            "num_bands": 64,
            "max_resolution": config.train_size,
            "sine_only": False,
            "concat_pos": True,
        }
        fourier_position_encoding_kwargs_decoder = {
            "concat_pos": True,
            "max_resolution": config.train_size,
            "num_bands": 64,
            "sine_only": False,
        }

        # Initialize the image preprocessor for the Perceiver model
        image_preprocessor = PerceiverImagePreprocessor(
            config,
            prep_type="patches",
            spatial_downsample=1,
            conv_after_patching=True,
            conv_after_patching_in_channels=54,
            temporal_downsample=2,
            position_encoding_type="fourier",
            # Set Fourier position encoding parameters for preprocessor
            fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
        )

        # Initialize the Perceiver model with image preprocessor and optical flow decoder
        self.perceiver = PerceiverModel(
            config,
            input_preprocessor=image_preprocessor,
            decoder=PerceiverOpticalFlowDecoder(
                config,
                num_channels=image_preprocessor.num_channels,
                output_image_shape=config.train_size,
                rescale_factor=100.0,
                # Set decoder parameters including position encoding
                use_query_residual=False,
                output_num_channels=2,
                # Specify using Fourier position encoding for decoder
                position_encoding_type="fourier",
                fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_decoder,
            ),
        )

        # Initialize weights and perform post-initialization steps
        self.post_init()

    @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        inputs: Optional[torch.Tensor] = None,  # 输入张量,用于模型的前向传播
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码张量,用于控制模型关注的位置
        head_mask: Optional[torch.Tensor] = None,  # 头部掩码张量,用于屏蔽特定的注意力头部
        output_attentions: Optional[bool] = None,  # 是否输出注意力权重
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态
        labels: Optional[torch.Tensor] = None,  # 目标标签张量,用于光流损失的计算
        return_dict: Optional[bool] = None,  # 是否返回字典形式的输出结果
    ) -> Union[Tuple, PerceiverClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the optical flow loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Returns:
            `PerceiverClassifierOutput` or `tuple`, depending on `return_dict`.

        Examples:
            Example showing how to use the Perceiver model for optical flow.

        ```
        >>> from transformers import PerceiverForOpticalFlow
        >>> import torch

        >>> model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver")

        >>> # in the Perceiver IO paper, the authors extract a 3 x 3 patch around each pixel,
        >>> # leading to 3 x 3 x 3 = 27 values for each pixel (as each pixel also has 3 color channels)
        >>> # patches have shape (batch_size, num_frames, num_channels, height, width)
        >>> # the authors train on resolutions of 368 x 496
        >>> patches = torch.randn(1, 2, 27, 368, 496)
        >>> outputs = model(inputs=patches)
        >>> logits = outputs.logits
        >>> list(logits.shape)
        [1, 368, 496, 2]
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict  # 根据配置确定是否使用字典形式的返回结果

        outputs = self.perceiver(
            inputs=inputs,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )  # 调用Perceiver模型进行前向传播,获取模型输出

        logits = outputs.logits if return_dict else outputs[0]  # 根据是否返回字典形式选择相应的logits输出方式

        loss = None  # 初始化损失为None
        if labels is not None:
            raise NotImplementedError("Optical flow training is not yet supported")  # 如果标签不为空,抛出未实现错误,暂不支持光流训练

        if not return_dict:
            output = (logits,) + outputs[2:]  # 如果不返回字典形式,组合输出为(logits, hidden_states, attentions, cross_attentions)
            return ((loss,) + output) if loss is not None else output  # 返回输出结果,包括损失信息

        return PerceiverClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )  # 返回以PerceiverClassifierOutput形式封装的输出结果
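

# A sketch of the patch extraction described in the docstring above: every pixel of each frame is expanded
# into its 3 x 3 neighbourhood (3 channels * 9 = 27 values), and the two frames are stacked, which matches
# the (batch, 2, 27, H, W) input of the doctest example. The released checkpoint may order the patch
# dimensions differently; this is illustrative only.
def example_optical_flow_patches(frame1: torch.Tensor, frame2: torch.Tensor) -> torch.Tensor:
    import torch.nn.functional as F

    def extract_3x3_patches(frame):
        batch_size, num_channels, height, width = frame.shape
        patches = F.unfold(frame, kernel_size=3, padding=1)  # (batch, 27, H*W)
        return patches.view(batch_size, num_channels * 9, height, width)

    return torch.stack([extract_3x3_patches(frame1), extract_3x3_patches(frame2)], dim=1)


# example_optical_flow_patches(torch.randn(1, 3, 368, 496), torch.randn(1, 3, 368, 496)).shape
# -> torch.Size([1, 2, 27, 368, 496])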
@add_start_docstrings(
    """
Example use of Perceiver for multimodal (video) autoencoding, for tasks such as Kinetics-700.

[`PerceiverForMultimodalAutoencoding`] uses [`~models.perceiver.modeling_perceiver.PerceiverMultimodalPreprocessor`] to
preprocess the 3 modalities: images, audio and class labels. This preprocessor uses modality-specific preprocessors to
preprocess every modality separately, after which they are concatenated. Trainable position embeddings are used to pad
each modality to the same number of channels so that they can be concatenated along the time dimension. Next, the
Perceiver encoder is applied.

[`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] is used to decode the latent representation of
[`PerceiverModel`]. This decoder uses each modality-specific decoder to construct queries, which are built from the
preprocessed inputs. However, autoencoding an entire video in a single forward pass is computationally infeasible,
hence only parts of the decoder queries are used to cross-attend to the latent representation. This is determined by
the subsampled indices for each modality, which can be provided as an additional input to the forward pass of
[`PerceiverForMultimodalAutoencoding`].

[`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] also pads the decoder queries of the different
modalities to the same number of channels, in order to concatenate them along the time dimension. Next,
cross-attention is performed with the latent representation of [`PerceiverModel`].

Finally, [`~models.perceiver.modeling_perceiver.PerceiverMultiModalPostprocessor`] is used to turn this tensor into an
actual video. It first splits up the output into the different modalities, and then applies the respective
postprocessor for each modality.

Note that, by masking the classification label during evaluation (i.e. simply providing a tensor of zeros for the
"label" modality), this auto-encoding model becomes a Kinetics 700 video classifier.
""",
    PERCEIVER_START_DOCSTRING,
)
# PerceiverForMultimodalAutoencoding, defined on top of PerceiverPreTrainedModel
class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
    @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC)
    # Forward pass of the multimodal autoencoder
    def forward(
        self,
        inputs: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        subsampled_output_points: Optional[Dict[str, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
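

# A sketch of how the per-modality `subsampled_output_points` argument can be built (hypothetical sizes:
# 16 frames of 224 x 224 video and an illustrative number of audio output points; the real sizes depend on
# the preprocessors and the config in use):
def example_subsampled_output_points(chunk_idx: int = 0, n_chunks: int = 128):
    n_image_points = 16 * 224 * 224  # one decoder query per output video pixel
    n_audio_points = 30720           # hypothetical number of audio output points

    image_chunk = n_image_points // n_chunks
    audio_chunk = n_audio_points // n_chunks
    return {
        "image": torch.arange(image_chunk * chunk_idx, image_chunk * (chunk_idx + 1)),
        "audio": torch.arange(audio_chunk * chunk_idx, audio_chunk * (chunk_idx + 1)),
        "label": None,  # the label modality is small enough to decode in full
    }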


# Below: position encodings


# Builds a position encoding of the requested type ("trainable" or "fourier") from the given kwargs

def build_position_encoding(
    out_channels=None,  # number of channels of the position encodings
    project_pos_dim=-1,  # if > 0, project the position encodings to this dimension
    position_encoding_type="trainable",  # type of position encoding to use, "trainable" by default
    trainable_position_encoding_kwargs=None,  # additional kwargs for the trainable position encoding
    fourier_position_encoding_kwargs=None,  # additional kwargs for the Fourier position encoding
):
    """
    Builds the position encoding.

    Args:
    - out_channels: refers to the number of channels of the position encodings.
    - project_pos_dim: if specified, will project the position encodings to this dimension.
    - position_encoding_type: specifies the type of position encoding to use.
    - trainable_position_encoding_kwargs: additional kwargs for trainable position encoding.
    - fourier_position_encoding_kwargs: additional kwargs for Fourier position encoding.

    Returns:
    - output_pos_enc: the constructed position encoding object.
    - positions_projection: optional projection layer for position encoding.
    """

    if position_encoding_type == "trainable":
        if not trainable_position_encoding_kwargs:
            raise ValueError("Make sure to pass trainable_position_encoding_kwargs")
        output_pos_enc = PerceiverTrainablePositionEncoding(**trainable_position_encoding_kwargs)
    elif position_encoding_type == "fourier":
        # We don't use the index_dims argument, as this is only known during the forward pass
        if not fourier_position_encoding_kwargs:
            raise ValueError("Make sure to pass fourier_position_encoding_kwargs")
        output_pos_enc = PerceiverFourierPositionEncoding(**fourier_position_encoding_kwargs)
    else:
        raise ValueError(f"Unknown position encoding type: {position_encoding_type}.")

    # Optionally, project the position encoding to a target dimension:
    positions_projection = nn.Linear(out_channels, project_pos_dim) if project_pos_dim > 0 else nn.Identity()

    return output_pos_enc, positions_projection
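

# A minimal sketch of calling the helper above for a 2D Fourier encoding projected down to 256 channels.
# With num_bands=64, concat_pos=True and sine_only=False, a 2D position yields 2 * (2 * 64) + 2 = 258
# channels, which is what the projection layer consumes (the channel counts are illustrative):
def example_build_fourier_position_encoding():
    return build_position_encoding(
        out_channels=258,
        project_pos_dim=256,
        position_encoding_type="fourier",
        fourier_position_encoding_kwargs=dict(
            num_bands=64, max_resolution=(224, 224), concat_pos=True, sine_only=False
        ),
    )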


# Below: Perceiver decoders


class PerceiverAbstractDecoder(nn.Module, metaclass=abc.ABCMeta):
    """Perceiver abstract decoder."""

    @abc.abstractmethod
    def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def num_query_channels(self):
        raise NotImplementedError

    @abc.abstractmethod
    def forward(self, query, z, query_mask=None):
        raise NotImplementedError


class PerceiverProjectionDecoder(PerceiverAbstractDecoder):
    """
    Baseline projection decoder (no cross-attention).

    Args:
        config ([`PerceiverConfig`]):
            Model configuration.
    """

    def __init__(self, config):
        super().__init__()
        self.classifier = nn.Linear(config.d_latents, config.num_labels)

    def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
        return None

    def forward(
        self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None
    ) -> torch.FloatTensor:
        # (batch_size, num_latents, d_latents) -> (batch_size, d_latents)
        z = torch.mean(z, dim=1)
        # (batch_size, d_latents) -> (batch_size, config.num_labels)
        logits = self.classifier(z)
        return logits


class PerceiverBasicDecoder(PerceiverAbstractDecoder):
    """
    Cross-attention-based decoder. This class can be used to decode the final hidden states of the latents using a
    cross-attention operation, in which the latents produce keys and values.

    The shape of the output of this class depends on how one defines the output queries (also called decoder queries).
    """
    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        output_num_channels (`int`, *optional*):
            The number of channels in the output. Will only be used in case *final_project* is set to `True`.
        position_encoding_type (`str`, *optional*, defaults to "trainable"):
            The type of position encoding to use. Can be either "trainable", "fourier", or "none".
        output_index_dims (`int`, *optional*):
            The number of dimensions of the output queries. Ignored if 'position_encoding_type' == 'none'.
        num_channels (`int`, *optional*, defaults to 128):
            The number of channels of the decoder queries. Ignored if 'position_encoding_type' == 'none'.
        subsampled_index_dims (`int`, *optional*):
            The number of dimensions of the subsampled indices. Ignored if 'position_encoding_type' == 'none'.
        qk_channels (`int`, *optional*):
            The number of channels of the queries and keys in the cross-attention layer.
        v_channels (`int`, *optional*):
            The number of channels of the values in the cross-attention layer.
        num_heads (`int`, *optional*, defaults to 1):
            The number of attention heads in the cross-attention layer.
        widening_factor (`int`, *optional*, defaults to 1):
            The widening factor of the cross-attention layer.
        use_query_residual (`bool`, *optional*, defaults to `False`):
            Whether to use a residual connection between the query and the output of the cross-attention layer.
        concat_preprocessed_input (`bool`, *optional*, defaults to `False`):
            Whether to concatenate the preprocessed input to the query.
        final_project (`bool`, *optional*, defaults to `True`):
            Whether to project the output of the cross-attention layer to a target dimension.
        position_encoding_only (`bool`, *optional*, defaults to `False`):
            Whether to only use this class to define output queries.
    """

    def __init__(
        self,
        config: PerceiverConfig,
        output_num_channels: int,
        position_encoding_type: Optional[str] = "trainable",
        # The following 2 arguments are ignored if position_encoding_type == "none":
        output_index_dims: Optional[int] = None,
        num_channels: Optional[int] = 128,
        subsampled_index_dims: Optional[int] = None,
        qk_channels: Optional[int] = None,
        v_channels: Optional[int] = None,
        num_heads: Optional[int] = 1,
        widening_factor: Optional[int] = 1,
        use_query_residual: Optional[bool] = False,
        concat_preprocessed_input: Optional[bool] = False,
        final_project: Optional[bool] = True,
        position_encoding_only: Optional[bool] = False,
        **position_encoding_kwargs,
    ) -> None:
        super().__init__()
        
        self.output_num_channels = output_num_channels
        # If `none`, the decoder will not construct any position encodings.
        # In that case, you should construct your own decoder queries when querying the decoder.
        self.output_position_encodings = None
        self.position_encoding_type = position_encoding_type
        self.position_encoding_kwargs = position_encoding_kwargs
        if position_encoding_type != "none":
            self.output_position_encodings, self.positions_projection = build_position_encoding(
                position_encoding_type=position_encoding_type, **position_encoding_kwargs
            )
        
        self.output_index_dims = output_index_dims
        self.num_channels = num_channels
        if subsampled_index_dims is None:
            subsampled_index_dims = output_index_dims
        self.subsampled_index_dims = subsampled_index_dims
        self.concat_preprocessed_input = concat_preprocessed_input
        self.final_project = final_project
        self.position_encoding_only = position_encoding_only
        
        # For multimodal autoencoding, we don't need the decoder cross-attention and final layer,
        # so position_encoding_only is set to True in that case
        if not self.position_encoding_only:
            self.decoding_cross_attention = PerceiverLayer(
                config,
                is_cross_attention=True,
                qk_channels=qk_channels,
                v_channels=v_channels,
                num_heads=num_heads,
                q_dim=num_channels,
                kv_dim=config.d_latents,
                widening_factor=widening_factor,
                use_query_residual=use_query_residual,
            )
            self.final_layer = nn.Linear(num_channels, output_num_channels) if final_project else nn.Identity()

    @property
    def num_query_channels(self) -> int:
        if self.position_encoding_type == "none":  # 查询来自其他地方
            raise ValueError(
                "You cannot calculate number of decoder query channels when position_encoding_type is set to none"
            )
        if self.position_encoding_only:
            if "project_pos_dim" in self.position_encoding_kwargs:
                return self.position_encoding_kwargs["project_pos_dim"]
            return self.output_position_encodings.output_size()
        if self.final_project:
            return self.output_num_channels
        return self.num_channels
    # 定义一个方法用于解码查询,接受多个输入参数
    def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
        # 如果位置编码类型为"none",则抛出数值错误,不允许构建解码查询
        if self.position_encoding_type == "none":
            raise ValueError("You cannot construct decoder queries when position_encoding_type is set to none")
        
        # 如果给定了子采样点(subsampled_points)
        if subsampled_points is not None:
            # subsampled_points 是输入在扁平化后的索引,使用unravel_index获取非扁平化后的数组索引
            indices = [torch.from_numpy(x) for x in np.unravel_index(subsampled_points.cpu(), self.output_index_dims)]
            # 将索引堆叠成 [n, d] 的坐标张量
            pos = torch.stack(indices, dim=1)
            batch_size = inputs.shape[0]
            # 将这些坐标映射到 [-1, 1] 的范围
            pos = -1 + 2 * pos / torch.tensor(self.output_index_dims)[None, :]
            # 广播位置张量,使其与输入数据形状相匹配
            pos = torch.broadcast_to(pos[None], [batch_size, pos.shape[0], pos.shape[1]])
            
            # 构建位置编码
            if self.position_encoding_type == "trainable":
                pos_emb = self.output_position_encodings(batch_size)
            elif self.position_encoding_type == "fourier":
                pos_emb = self.output_position_encodings(
                    self.output_index_dims, batch_size=batch_size, device=inputs.device, dtype=inputs.dtype, pos=pos
                )

            # 可选地将位置编码投影到目标维度
            pos_emb = self.positions_projection(pos_emb)
            pos_emb = torch.reshape(pos_emb, [pos_emb.shape[0], -1, pos_emb.shape[-1]])
        else:
            # 如果没有提供子采样点,获取输入的批次大小和索引维度
            batch_size = inputs.shape[0]
            index_dims = inputs.shape[2:]

            # 构建位置编码
            if self.position_encoding_type == "trainable":
                pos_emb = self.output_position_encodings(batch_size)
            elif self.position_encoding_type == "fourier":
                pos_emb = self.output_position_encodings(
                    index_dims, batch_size, device=inputs.device, dtype=inputs.dtype
                )

            # 可选地将位置编码投影到目标维度
            pos_emb = self.positions_projection(pos_emb)

        # 如果设置了 concat_preprocessed_input 标志,则将预处理的输入与位置编码连接起来
        if self.concat_preprocessed_input:
            if inputs_without_pos is None:
                raise ValueError("Value is required for inputs_without_pos if concat_preprocessed_input is True")
            pos_emb = torch.cat([inputs_without_pos, pos_emb], dim=-1)

        # 返回位置编码张量作为方法的输出
        return pos_emb
    def forward(
        self,
        query: torch.Tensor,
        z: torch.FloatTensor,
        query_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> PerceiverDecoderOutput:
        # Cross-attention decoding.
        # key, value: B x N x K; query: B x M x K
        # Attention maps -> B x N x M
        # Output -> B x M x K
        # Only collect cross-attentions if they were requested
        cross_attentions = () if output_attentions else None

        # Run the decoder's cross-attention layer
        layer_outputs = self.decoding_cross_attention(
            query,
            attention_mask=query_mask,
            head_mask=None,
            inputs=z,
            inputs_mask=None,
            output_attentions=output_attentions,
        )
        # The first element of the layer outputs is the decoder output itself
        output = layer_outputs[0]

        # If attentions were requested, append this layer's cross-attention weights
        if output_attentions:
            cross_attentions = cross_attentions + (layer_outputs[1],)

        # Project the decoder output through the final layer to obtain the logits
        logits = self.final_layer(output)

        # Return a PerceiverDecoderOutput containing the logits and the optional cross-attentions
        return PerceiverDecoderOutput(logits=logits, cross_attentions=cross_attentions)
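

# A small worked example of the index -> coordinate mapping used in decoder_query above
# (hypothetical 4 x 6 output grid; `flat_points` plays the role of `subsampled_points`):
def example_subsampled_coordinates():
    output_index_dims = (4, 6)
    flat_points = torch.tensor([0, 7, 23])  # flattened indices of the queries to decode
    indices = [torch.from_numpy(x) for x in np.unravel_index(flat_points.cpu(), output_index_dims)]
    pos = torch.stack(indices, dim=1)       # [[0, 0], [1, 1], [3, 5]]
    # scale to roughly the [-1, 1] range, exactly as done before building the Fourier features
    return -1 + 2 * pos / torch.tensor(output_index_dims)[None, :]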
class PerceiverClassificationDecoder(PerceiverAbstractDecoder):
    """
    Cross-attention based classification decoder. Light-weight wrapper of [`PerceiverBasicDecoder`] for logit output.
    Will turn the output of the Perceiver encoder which is of shape (batch_size, num_latents, d_latents) to a tensor of
    shape (batch_size, num_labels). The queries are of shape (batch_size, 1, num_labels).

    Args:
        config ([`PerceiverConfig`]):
            Model configuration.
    """

    def __init__(self, config, **decoder_kwargs):
        super().__init__()

        self.num_labels = config.num_labels  # 设置分类标签的数量
        self.decoder = PerceiverBasicDecoder(
            config,
            output_num_channels=self.num_labels,  # 输出通道数设置为分类标签的数量
            output_index_dims=1,  # 预测单一logit数组
            **decoder_kwargs,
        )

    @property
    def num_query_channels(self) -> int:
        return self.decoder.num_query_channels  # 返回解码器的查询通道数量

    def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
        return self.decoder.decoder_query(
            inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_points  # 返回解码器的查询结果
        )

    def forward(
        self,
        query: torch.Tensor,
        z: torch.FloatTensor,
        query_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> PerceiverDecoderOutput:
        decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)

        # B x 1 x num_classes -> B x num_classes
        logits = decoder_outputs.logits[:, 0, :]  # 从解码器输出中提取logits

        return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions)  # 返回解码器的输出结果


class PerceiverOpticalFlowDecoder(PerceiverAbstractDecoder):
    """Cross-attention based optical flow decoder."""

    def __init__(self, config, output_image_shape, output_num_channels=2, rescale_factor=100.0, **decoder_kwargs):
        super().__init__()

        self.output_image_shape = output_image_shape  # 设置输出图像的形状
        self.output_num_channels = output_num_channels  # 设置输出图像的通道数
        self.rescale_factor = rescale_factor  # 设置光流的重新缩放因子
        self.decoder = PerceiverBasicDecoder(config, output_num_channels=output_num_channels, **decoder_kwargs)

    @property
    def num_query_channels(self) -> int:
        return self.decoder.num_query_channels  # 返回解码器的查询通道数量

    def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
        if subsampled_points is not None:
            raise ValueError("FlowDecoder doesn't support subsampling yet.")  # 如果有子采样点,则引发错误
        return inputs  # 返回输入数据,用于光流解码器的查询

    def forward(
        self,
        query: torch.Tensor,
        z: torch.FloatTensor,
        query_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> PerceiverDecoderOutput:
        # Run the decoder on the query and the latent representation z, optionally returning attentions
        decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
        # Extract the predicted flow logits from the decoder output
        preds = decoder_outputs.logits
        # Rescale the predictions by the predefined rescale factor
        preds /= self.rescale_factor
        # Reshape the predictions to (batch_size, height, width, output_num_channels)
        preds = preds.reshape([preds.shape[0]] + list(self.output_image_shape) + [preds.shape[-1]])
        # Return the decoded output, including the logits and the optional cross-attentions
        return PerceiverDecoderOutput(logits=preds, cross_attentions=decoder_outputs.cross_attentions)
class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder):
    """
    Cross-attention based video-autoencoding decoder. Light-weight wrapper of [*PerceiverBasicDecoder*] with video
    reshaping logic.

    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        output_shape (`List[int]`):
            Shape of the output as (batch_size, num_frames, height, width), excluding the channel dimension.
        position_encoding_type (`str`):
            The type of position encoding to use. Can be either "trainable", "fourier", or "none".
    """

    def __init__(
        self, config: PerceiverConfig, output_shape: List[int], position_encoding_type: str, **decoder_kwargs
    ) -> None:
        super().__init__()
        # Validate the shape of output_shape to ensure it's rank 4 (batch_size, num_frames, height, width)
        if len(output_shape) != 4:  # B, T, H, W
            raise ValueError(f"Expected rank 4 output_shape, got {output_shape}.")
        # Initialize the decoder components:
        self.output_shape = output_shape
        self.output_num_channels = decoder_kwargs["output_num_channels"]

        # Create an instance of PerceiverBasicDecoder tailored for video decoding:
        self.decoder = PerceiverBasicDecoder(
            config,
            output_index_dims=self.output_shape[1:4],  # T*H*W
            position_encoding_type=position_encoding_type,
            **decoder_kwargs,
        )

    @property
    def num_query_channels(self) -> int:
        # Return the number of query channels from the decoder:
        return self.decoder.num_query_channels

    def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
        # Delegate the decoder_query method to the underlying PerceiverBasicDecoder instance:
        return self.decoder.decoder_query(
            inputs,
            modality_sizes=modality_sizes,
            inputs_without_pos=inputs_without_pos,
            subsampled_points=subsampled_points,
        )

    def forward(
        self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None
    ) -> PerceiverDecoderOutput:
        # Forward pass through the decoder:
        decoder_outputs = self.decoder(query, z)
        logits = decoder_outputs.logits

        # Reshape logits to match the specified output shape:
        logits = torch.reshape(logits, self.output_shape + [logits.shape[-1]])
        return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions)


def restructure(modality_sizes: ModalitySizeType, inputs: torch.Tensor) -> Mapping[str, torch.Tensor]:
    """
    Partitions a [B, N, C] tensor into tensors for each modality.

    Args:
        modality_sizes
            dict specifying the size of the modality
        inputs:
            input tensor

    Returns:
        dict mapping name of modality to its associated tensor.
    """
    outputs = {}
    index = 0
    # Apply a predictable ordering to the modalities by iterating over sorted keys
    for modality in sorted(modality_sizes.keys()):
        size = modality_sizes[modality]
        # Slice the input tensor to extract the portion corresponding to the current modality
        inp = inputs[:, index : index + size]
        index += size
        outputs[modality] = inp
    return outputs
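

# A tiny worked example of restructure (hypothetical modality sizes; keys are processed in sorted order,
# so "audio" occupies the first 3 positions and "image" the following 5):
def example_restructure():
    modality_sizes = {"image": 5, "audio": 3}
    inputs = torch.randn(2, 8, 16)  # batch of 2, 3 + 5 = 8 positions, 16 channels
    split = restructure(modality_sizes, inputs)
    return split["audio"].shape, split["image"].shape  # (2, 3, 16), (2, 5, 16)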


class PerceiverMultimodalDecoder(PerceiverAbstractDecoder):
    """
    Placeholder class for a multimodal decoder based on the Perceiver architecture.
    """
    """
    Multimodal decoding by composing uni-modal decoders. The *modalities* argument of the constructor is a dictionary
    mapping modality name to the decoder of that modality. That decoder will be used to construct queries for that
    modality. Modality-specific queries are padded with trainable modality-specific parameters, after which they are
    concatenated along the time dimension.

    Next, there is a shared cross attention operation across all modalities.

    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        modalities (`Dict[str, PerceiverAbstractDecoder]`):
            Dictionary mapping modality name to the decoder of that modality.
        num_outputs (`int`):
            The number of outputs of the decoder.
        output_num_channels (`int`):
            The number of channels in the output.
        min_padding_size (`int`, *optional*, defaults to 2):
            The minimum padding size for all modalities. The final output will have num_channels equal to the maximum
            channels across all modalities plus min_padding_size.
        subsampled_index_dims (`Dict[str, PerceiverAbstractDecoder]`, *optional*):
            Dictionary mapping modality name to the subsampled index dimensions to use for the decoder query of that
            modality.
    """

    def __init__(
        self,
        config: PerceiverConfig,
        modalities: Dict[str, PerceiverAbstractDecoder],
        num_outputs: int,
        output_num_channels: int,
        min_padding_size: Optional[int] = 2,
        subsampled_index_dims: Optional[Dict[str, PerceiverAbstractDecoder]] = None,
        **decoder_kwargs,
    ) -> None:
        """
        Constructor method for the MultimodalPerceiverDecoder class.

        Args:
            config (PerceiverConfig): Model configuration.
            modalities (Dict[str, PerceiverAbstractDecoder]): Dictionary mapping modality name to the decoder.
            num_outputs (int): The number of outputs of the decoder.
            output_num_channels (int): The number of channels in the output.
            min_padding_size (int, optional): The minimum padding size for all modalities.
            subsampled_index_dims (Dict[str, PerceiverAbstractDecoder], optional): Dictionary mapping modality name to
                subsampled index dimensions for the decoder query.
            **decoder_kwargs: Additional keyword arguments for the decoder.
        """
        super().__init__()
        # Initialize the modalities as a ModuleDict
        self.modalities = nn.ModuleDict(modalities)
        # Store the subsampled index dimensions
        self.subsampled_index_dims = subsampled_index_dims
        # Store the minimum padding size
        self.min_padding_size = min_padding_size
        # Store the number of output channels
        self.output_num_channels = output_num_channels
        # Store the number of outputs
        self.num_outputs = num_outputs
        # Initialize the decoder with given configuration and arguments
        self.decoder = PerceiverBasicDecoder(
            config,
            output_index_dims=(num_outputs,),
            output_num_channels=output_num_channels,
            position_encoding_type="none",
            num_channels=self.num_query_channels,
            **decoder_kwargs,
        )
        # Initialize padding parameters for each modality
        self.padding = nn.ParameterDict(
            {
                modality: nn.Parameter(torch.randn(1, self.num_query_channels - decoder.num_query_channels))
                for modality, decoder in modalities.items()
            }
        )

    @property
    def num_query_channels(self) -> int:
        """
        Calculate the number of query channels based on the modalities.

        Returns:
            int: Number of query channels.
        """
        # Determine the maximum number of query channels among modalities
        max_channel_size = max(decoder.num_query_channels for _, decoder in self.modalities.items())
        # Ensure common channel size includes minimum padding size
        common_channel_size = max_channel_size + self.min_padding_size
        return common_channel_size
    def decoder_query(self, inputs, modality_sizes, inputs_without_pos=None, subsampled_points=None):
        # 将扁平化的输入数据按照不同的感知模态进行分割重组
        inputs = restructure(modality_sizes, inputs)

        # 获取各个感知模态的解码器查询
        subsampled_points = subsampled_points or {}

        # 存储每个模态的解码器查询结果
        decoder_queries = {}
        for modality, decoder in self.modalities.items():
            # 如果存在输入数据不包含位置信息,则获取当前模态的无位置信息的输入数据
            input_without_pos = None
            if inputs_without_pos is not None:
                input_without_pos = inputs_without_pos.get(modality, None)
            # 调用解码器的查询函数,获取当前模态的查询结果
            query = decoder.decoder_query(
                inputs=inputs[modality],
                modality_sizes=None,  # 此处未使用 modality_sizes 参数,可能为函数签名未更新的遗留
                inputs_without_pos=input_without_pos,
                subsampled_points=subsampled_points.get(modality, None),
            )
            decoder_queries[modality] = query

        # 使用可训练的位置编码填充所有查询结果,以保证它们具有相同的通道数

        def embed(modality, x):
            # 将输入张量 x 重塑为 [batch_size, 总特征数, 通道数] 的形状
            x = torch.reshape(x, [x.shape[0], np.prod(x.shape[1:-1]), x.shape[-1]])
            # 获取当前模态的填充位置编码
            pos = self.padding[modality]
            # 将位置编码广播到与 x 相同的形状
            pos = torch.broadcast_to(pos, [x.shape[0], x.shape[1], self.num_query_channels - x.shape[2]])
            # 在通道维度上连接 x 和位置编码
            return torch.cat([x, pos], dim=2)

        # 对模态按照可预测的顺序进行排序,并连接它们的查询结果
        return torch.cat(
            [embed(modality, decoder_queries[modality]) for modality in sorted(self.modalities.keys())], dim=1
        )

    def forward(
        self,
        query: torch.Tensor,
        z: torch.FloatTensor,
        query_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        # Run the shared cross-attention decoder on the concatenated multimodal queries
        decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)

        return decoder_outputs
# Below: IO pre- and post-processor classes for Perceiver.

# Space-to-depth helper that rearranges blocks of spatial data into the channel dimension
def space_to_depth(frames: torch.Tensor, temporal_block_size: int = 1, spatial_block_size: int = 1) -> torch.Tensor:
    """
    Space to depth transform. Rearranges blocks of spatial data, into depth.

    This function assumes the channels to be first, but will place the channels last after transformation.

    Based on https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/15.
    """
    # 检查输入张量的维度是否为4
    if len(frames.shape) == 4:
        batch_size, num_channels, height, width = frames.shape
        # 将空间数据块按照指定的空间块大小进行分割
        frames = frames.view(
            batch_size,
            num_channels,
            height // spatial_block_size,
            spatial_block_size,
            width // spatial_block_size,
            spatial_block_size,
        )
        # 将分割后的块移动到最后一个维度:(batch_size, H//bs, W//bs, bs, bs, C)
        frames = frames.permute(0, 2, 4, 3, 5, 1).contiguous()
        # 沿着通道维度连接块:(batch_size, H//bs, W//bs, bs*bs*C)
        frames = frames.view(
            batch_size,
            height // spatial_block_size,
            width // spatial_block_size,
            (spatial_block_size**2) * num_channels,
        )
        return frames
    # 检查输入张量的维度是否为5
    elif len(frames.shape) == 5:
        batch_size, time, num_channels, height, width = frames.shape
        # 将时间维度和空间维度按照指定的块大小进行分割
        frames = frames.view(
            batch_size,
            time // temporal_block_size,
            temporal_block_size,
            num_channels,
            height // spatial_block_size,
            spatial_block_size,
            width // spatial_block_size,
            spatial_block_size,
        )
        # 将分割后的块移动到最后一个维度:(batch_size, T//ts, H//bs, W//bs, ts, bs, bs, C)
        frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
        # 沿着通道维度连接块:(batch_size, T//ts, H//bs, W//bs, ts*bs*bs*C)
        frames = frames.view(
            batch_size,
            time // temporal_block_size,
            height // spatial_block_size,
            width // spatial_block_size,
            temporal_block_size * (spatial_block_size**2) * num_channels,
        )
        return frames
    else:
        # 抛出异常,如果输入张量的维度既不是4也不是5
        raise ValueError(
            "Frames should be of rank 4 (batch, channels, height, width)"
            " or rank 5 (batch, time, channels, height, width)"
        )
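

# Quick shape checks for space_to_depth on dummy data (the channels move to the last dimension and are
# multiplied by the block sizes):
def example_space_to_depth_shapes():
    frames_2d = torch.randn(1, 3, 224, 224)
    shape_2d = space_to_depth(frames_2d, spatial_block_size=2).shape  # (1, 112, 112, 12)

    frames_3d = torch.randn(1, 4, 3, 224, 224)
    shape_3d = space_to_depth(frames_3d, temporal_block_size=2, spatial_block_size=2).shape  # (1, 2, 112, 112, 24)
    return shape_2d, shape_3d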


# 定义一个继承自 nn.Conv2d 的类,支持 padding="same"
class Conv2dSamePadding(nn.Conv2d):
    """
    Conv2d layer with padding="same" support. Source:
    https://gist.github.com/sumanmichael/4de9dee93f972d47c80c4ade8e149ea6
    """
    # 初始化方法,继承父类 Conv2dSamePadding
    def __init__(self, *args, **kwargs):
        # 调用父类的初始化方法
        super(Conv2dSamePadding, self).__init__(*args, **kwargs)
        # 创建 ZeroPad2d 层,用于实现“same” padding
        self.zero_pad_2d = nn.ZeroPad2d(
            # 计算每个维度的 padding 数量,使得卷积操作后大小不变
            reduce(__add__, [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in self.kernel_size[::-1]])
        )

    # 前向传播方法
    def forward(self, input):
        # 对输入进行 zero padding,保证卷积输出大小与输入相同
        padded_input = self.zero_pad_2d(input)
        # 执行卷积操作,使用权重 self.weight 和偏置 self.bias
        return self._conv_forward(padded_input, self.weight, self.bias)
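
A minimal sketch of how Conv2dSamePadding behaves (it assumes the class above is in scope and that `from functools import reduce` and `from operator import __add__` are imported at the top of the module, as in the full file). The total padding per spatial dimension is kernel_size - 1, so the output size is ceil(input_size / stride), matching TensorFlow's "SAME" padding:

import torch

conv = Conv2dSamePadding(in_channels=3, out_channels=64, kernel_size=7, stride=2, bias=False)
x = torch.randn(1, 3, 224, 224)
print(conv(x).shape)   # torch.Size([1, 64, 112, 112]) -> 112 = ceil(224 / 2)
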
class Conv2DDownsample(nn.Module):
    """Downsamples 4x by applying a 2D convolution and doing max pooling."""

    def __init__(
        self,
        num_layers: int = 1,
        in_channels: int = 3,
        out_channels: int = 64,
        use_batchnorm: bool = True,
    ):
        """
        Constructs a Conv2DDownsample model.

        Args:
          in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
          out_channels (`int`, *optional*, defaults to 64):
            The number of conv output channels.
          use_batchnorm (`bool`, *optional*, defaults to `True`):
            Whether to use batchnorm.
        """
        super().__init__()

        # Define a 2D convolution layer with same padding
        self.conv = Conv2dSamePadding(
            in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=2, bias=False
        )
        
        # Batch normalization layer if `use_batchnorm` is True, otherwise an identity layer
        self.batchnorm = nn.BatchNorm2d(num_features=out_channels) if use_batchnorm else nn.Identity()
        
        # ReLU activation function
        self.relu = nn.ReLU()
        
        # Max pooling layer with kernel size 3 and stride 2
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Forward pass through the layers
        out = self.conv(inputs)  # Apply convolution
        out = self.batchnorm(out)  # Apply batch normalization or identity
        out = self.relu(out)  # Apply ReLU activation
        out = self.max_pool(out)  # Apply max pooling
        return out
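
A quick shape check for Conv2DDownsample (a sketch, assuming the classes above are in scope; the roughly 4x reduction comes from the stride-2 convolution followed by the stride-2 max pooling):

import torch

downsample = Conv2DDownsample(in_channels=3, out_channels=64)
x = torch.randn(1, 3, 224, 224)
print(downsample(x).shape)   # torch.Size([1, 64, 55, 55]): conv stride 2 -> 112, then MaxPool2d(3, stride=2) -> 55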


def generate_fourier_features(pos, num_bands, max_resolution=(224, 224), concat_pos=True, sine_only=False):
    """
    Generate a Fourier frequency position encoding with linear spacing.

    Args:
      pos (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`):
        The Tensor containing the position of n points in d dimensional space.
      num_bands (`int`):
        The number of frequency bands (K) to use.
      max_resolution (`Tuple[int]`, *optional*, defaults to (224, 224)):
        The maximum resolution (i.e. the number of pixels per dim). A tuple representing resolution for each dimension.
      concat_pos (`bool`, *optional*, defaults to `True`):
        Whether to concatenate the input position encoding to the Fourier features.
      sine_only (`bool`, *optional*, defaults to `False`):
        Whether to use a single phase (sin) or two (sin/cos) for each frequency band.

    Returns:
      `torch.FloatTensor` of shape `(batch_size, sequence_length, n_channels)`: The Fourier position embeddings. If
      `concat_pos` is `True` and `sine_only` is `False`, output dimensions are ordered as: [dim_1, dim_2, ..., dim_d,
      sin(pi*f_1*dim_1), ..., sin(pi*f_K*dim_1), ..., sin(pi*f_1*dim_d), ..., sin(pi*f_K*dim_d), cos(pi*f_1*dim_1),
      ..., cos(pi*f_K*dim_1), ..., cos(pi*f_1*dim_d), ..., cos(pi*f_K*dim_d)], where dim_i is pos[:, i] and f_k is the
      kth frequency band.
    """

    batch_size = pos.shape[0]

    min_freq = 1.0
    # Nyquist frequency at the target resolution:
    freq_bands = torch.stack(
        [torch.linspace(start=min_freq, end=res / 2, steps=num_bands) for res in max_resolution], dim=0
    )

    # Get frequency bands for each spatial dimension.
    # (This part of the function calculates frequency bands based on the given maximum resolution and number of bands)
    # Output is size [n, d * num_bands]
    per_pos_features = pos[0, :, :][:, :, None] * freq_bands[None, :, :]
    # Reshape per_pos_features into a flattened shape
    per_pos_features = torch.reshape(per_pos_features, [-1, np.prod(per_pos_features.shape[1:])])

    if sine_only:
        # Output is size [n, d * num_bands]
        # Apply sine transformation to per_pos_features
        per_pos_features = torch.sin(np.pi * (per_pos_features))
    else:
        # Output is size [n, 2 * d * num_bands]
        # Apply both sine and cosine transformations to per_pos_features
        per_pos_features = torch.cat(
            [torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)], dim=-1
        )
    
    # Concatenate the raw input positions.
    if concat_pos:
        # Adds d bands to the encoding.
        # Concatenate pos and per_pos_features along the last dimension
        per_pos_features = torch.cat([pos, per_pos_features.expand(batch_size, -1, -1)], dim=-1)
    
    # Return the final per_pos_features tensor
    return per_pos_features
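
The number of output channels of generate_fourier_features follows directly from the arguments: with d input dimensions it is d * num_bands when sine_only, doubled for sin/cos pairs, plus d more when concat_pos is True. A minimal sketch (positions are assumed to lie in [-1, 1], as produced by build_linear_positions below):

import torch

pos = torch.rand(1, 10, 2) * 2 - 1                                        # 10 points in 2-D space, coordinates in [-1, 1]
features = generate_fourier_features(pos, num_bands=64, max_resolution=(224, 224))
# channels = 2 dims * 64 bands * 2 (sin and cos) + 2 raw coordinates = 258
print(features.shape)                                                      # torch.Size([1, 10, 258])
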
# Generates an array of linear position indices for an N-D input array.

def build_linear_positions(index_dims, output_range=(-1.0, 1.0)):
    """
    Generate an array of position indices for an N-D input array.

    Args:
      index_dims (`List[int]`):
        The shape of the index dimensions of the input array.
      output_range (`Tuple[float]`, *optional*, defaults to `(-1.0, 1.0)`):
        The min and max values taken by each input index dimension.

    Returns:
      `torch.FloatTensor` of shape `(index_dims[0], index_dims[1], .., index_dims[-1], N)`.
    """

    def _linspace(n_xels_per_dim):
        # 使用 torch.linspace 生成指定范围和步长的一维张量
        return torch.linspace(start=output_range[0], end=output_range[1], steps=n_xels_per_dim, dtype=torch.float32)

    # 生成每个维度的线性分布的张量数组
    dim_ranges = [_linspace(n_xels_per_dim) for n_xels_per_dim in index_dims]
    # 使用 meshgrid 函数创建多维网格,表示每个位置的坐标
    array_index_grid = meshgrid(*dim_ranges, indexing="ij")

    return torch.stack(array_index_grid, dim=-1)
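
A small sketch of build_linear_positions (it relies on the module-level meshgrid helper, a thin wrapper around torch.meshgrid imported at the top of the full file):

grid = build_linear_positions([4, 4])
print(grid.shape)                 # torch.Size([4, 4, 2]): one (x, y) coordinate in [-1, 1] per grid position
print(grid[0, 0], grid[-1, -1])   # tensor([-1., -1.]) tensor([1., 1.])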


class PerceiverAbstractPositionEncoding(nn.Module, metaclass=abc.ABCMeta):
    """Perceiver abstract position encoding."""

    @property
    @abc.abstractmethod
    def num_dimensions(self) -> int:
        raise NotImplementedError

    @abc.abstractmethod
    def output_size(self, *args, **kwargs) -> int:
        raise NotImplementedError

    @abc.abstractmethod
    def forward(self, batch_size, pos):
        raise NotImplementedError


class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding):
    """Trainable position encoding."""

    def __init__(self, index_dims, num_channels=128):
        super().__init__()
        self._num_channels = num_channels
        self._index_dims = index_dims
        index_dim = np.prod(index_dims)
        # 创建一个形状为 (index_dim, num_channels) 的可训练的位置嵌入参数
        self.position_embeddings = nn.Parameter(torch.randn(index_dim, num_channels))

    @property
    def num_dimensions(self) -> int:
        if isinstance(self._index_dims, int):
            return 1
        return len(self._index_dims)

    def output_size(self, *args, **kwargs) -> int:
        # 返回位置编码器的输出大小,即 num_channels
        return self._num_channels

    def forward(self, batch_size: int) -> torch.Tensor:
        position_embeddings = self.position_embeddings

        if batch_size is not None:
            # 如果指定了批量大小,扩展位置嵌入参数的第一维度为 batch_size
            position_embeddings = position_embeddings.expand(batch_size, -1, -1)
        return position_embeddings
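
A minimal sketch of the trainable position encoding: one learned vector per flattened index position, broadcast over the batch.

pos_enc = PerceiverTrainablePositionEncoding(index_dims=(14, 14), num_channels=128)
print(pos_enc(batch_size=2).shape)   # torch.Size([2, 196, 128]): 14*14 = 196 learned position embeddings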


def _check_or_build_spatial_positions(pos, index_dims, batch_size):
    """
    Checks or builds spatial position features (x, y, ...).

    Args:
      pos (`torch.FloatTensor`):
        None, or an array of position features. If None, position features are built. Otherwise, their size is checked.
      index_dims (`List[int]`):
        An iterable giving the spatial/index size of the data to be featurized.
      batch_size (`int`):
        The batch size of the data to be featurized.

    Returns:
        `torch.FloatTensor` of shape `(batch_size, prod(index_dims))` an array of position features.
    """
    # 如果 pos 参数为 None,则根据 index_dims 构建线性位置信息
    if pos is None:
        pos = build_linear_positions(index_dims)
        # 相当于 `torch.broadcast_to(pos[None], (batch_size,) + pos.shape)`
        # 但是 `torch.broadcast_to` 不能转换为 ONNX 格式
        pos = pos[None].expand((batch_size,) + pos.shape)
        pos = torch.reshape(pos, [batch_size, np.prod(index_dims), -1])
    else:
        # 警告:你可能不希望你的空间特征与 pos 坐标系的空间布局不同。
        # 如果你认为可以,请随意覆盖这一段代码!
        
        # 检查 pos 的最后一个维度是否与 index_dims 的长度相同
        if pos.shape[-1] != len(index_dims):
            raise ValueError("Spatial features have the wrong number of dimensions.")
    # 返回 pos 变量,其中包含位置信息
    return pos
class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding):
    """Fourier (Sinusoidal) position encoding."""

    def __init__(self, num_bands, max_resolution, concat_pos=True, sine_only=False):
        super().__init__()
        self.num_bands = num_bands  # 设置频带数量
        self.max_resolution = max_resolution  # 设置最大分辨率
        self.concat_pos = concat_pos  # 是否连接位置编码
        self.sine_only = sine_only  # 是否只使用正弦编码

    @property
    def num_dimensions(self) -> int:
        return len(self.max_resolution)  # 返回最大分辨率的维度数

    def output_size(self):
        """Returns size of positional encodings last dimension."""
        num_dims = len(self.max_resolution)  # 获取最大分辨率的维度数
        encoding_size = self.num_bands * num_dims  # 计算编码的大小
        if not self.sine_only:
            encoding_size *= 2  # 如果不仅使用正弦编码,则大小加倍
        if self.concat_pos:
            encoding_size += self.num_dimensions  # 如果连接位置编码,则增加维度数

        return encoding_size  # 返回编码的最后一个维度大小

    def forward(
        self,
        index_dims: List[int],
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
        pos: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        pos = _check_or_build_spatial_positions(pos, index_dims, batch_size)  # 检查或构建空间位置
        fourier_pos_enc = generate_fourier_features(
            pos,
            num_bands=self.num_bands,
            max_resolution=self.max_resolution,
            concat_pos=self.concat_pos,
            sine_only=self.sine_only,
        ).to(device=device, dtype=dtype)  # 生成傅里叶特征编码,并将其转移到指定设备和数据类型
        return fourier_pos_enc


class AbstractPreprocessor(nn.Module):
    @property
    def num_channels(self) -> int:
        """Returns size of preprocessor output."""
        raise NotImplementedError()


class PerceiverTextPreprocessor(AbstractPreprocessor):
    """
    Text preprocessing for Perceiver Encoder. Can be used to embed `inputs` and add positional encodings.

    The dimensionality of the embeddings is determined by the `d_model` attribute of the configuration.

    Args:
        config ([`PerceiverConfig`]):
            Model configuration.
    """

    def __init__(self, config: PerceiverConfig) -> None:
        super().__init__()
        self.config = config  # 设置模型配置
        self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model)  # 创建词嵌入层
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)  # 创建位置编码层

    @property
    def num_channels(self) -> int:
        return self.config.d_model  # 返回模型配置中的 d_model 大小

    def forward(self, inputs: torch.LongTensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
        embeddings_without_pos = self.embeddings(inputs)  # 获取不包含位置编码的词嵌入

        seq_length = inputs.shape[1]  # 获取序列长度
        position_ids = torch.arange(0, seq_length, device=inputs.device)  # 在指定设备上创建位置索引
        embeddings = embeddings_without_pos + self.position_embeddings(position_ids)  # 添加位置编码到词嵌入

        return embeddings, None, embeddings_without_pos


class PerceiverEmbeddingDecoder(nn.Module):
    """
    Module to decode embeddings (for masked language modeling).
    """
    Args:
        config ([`PerceiverConfig`]):
            Model configuration.
    """
    # 定义 Perceiver 模型类,继承自 nn.Module
    def __init__(self, config: PerceiverConfig) -> None:
        super().__init__()
        # 保存模型配置
        self.config = config
        # 获取词汇表大小
        self.vocab_size = config.vocab_size
        # 初始化偏置项,维度为词汇表大小,作为可学习参数
        self.bias = nn.Parameter(torch.zeros(self.vocab_size))

    def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor:
        # 获取输入的张量维度信息
        batch_size, seq_len, d_model = hidden_states.shape
        # 将隐藏状态张量展平(flatten)为二维张量,进行矩阵乘法
        output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.transpose(0, 1))
        # 添加偏置项到输出张量
        output = output + self.bias
        # 将输出张量重新形状为原始的三维张量形状
        return output.reshape([batch_size, seq_len, self.vocab_size])
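
A minimal sketch of the tied-embedding decode step above, using a hypothetical TinyConfig in place of PerceiverConfig: the hidden states are multiplied by the transposed input-embedding matrix and a learned bias is added.

import torch
from torch import nn

class TinyConfig:                                        # hypothetical stand-in for PerceiverConfig
    vocab_size = 100
    d_model = 32

embedding_layer = nn.Embedding(TinyConfig.vocab_size, TinyConfig.d_model)
decoder = PerceiverEmbeddingDecoder(TinyConfig())
hidden_states = torch.randn(2, 5, TinyConfig.d_model)    # (batch, seq_len, d_model)
logits = decoder(hidden_states, embedding_layer=embedding_layer)
print(logits.shape)                                       # torch.Size([2, 5, 100])
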
class PerceiverMultimodalPostprocessor(nn.Module):
    """
    Multimodal postprocessing for Perceiver. Can be used to combine modality-specific postprocessors into a single
    postprocessor.

    Args:
          modalities (`Mapping[str, PostprocessorType]`):
            Dictionary mapping modality name to postprocessor class for that modality.
          input_is_dict (`bool`, *optional*, defaults to `False`):
            If True, input is assumed to be dictionary structured, and outputs keep the same dictionary shape. If
            False, input is a tensor which is sliced up during postprocessing by *modality_sizes*.
    """

    def __init__(self, modalities: Mapping[str, PostprocessorType], input_is_dict: bool = False):
        super().__init__()
        # 初始化时将各个模态的后处理器组成一个模块字典
        self.modalities = nn.ModuleDict(modalities)
        # 标记输入是否为字典形式
        self.input_is_dict = input_is_dict

    def forward(
        self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None
    ) -> Mapping[str, torch.Tensor]:
        if not self.input_is_dict:
            # 如果输入不是字典形式,根据模态大小重新组织输入数据
            if modality_sizes is None:
                raise ValueError("Modality sizes should be specified if input is not a dictionary.")
            inputs = restructure(modality_sizes=modality_sizes, inputs=inputs)

        # 对每个模态使用对应的后处理器进行处理,并输出结果字典
        outputs = {
            modality: postprocessor(inputs[modality], pos=pos, modality_sizes=None)
            for modality, postprocessor in self.modalities.items()
        }
        return outputs


class PerceiverClassificationPostprocessor(nn.Module):
    """
    Classification postprocessing for Perceiver. Can be used to convert the decoder output to classification logits.

    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        in_channels (`int`):
            Number of channels in the input.
    """

    def __init__(self, config: PerceiverConfig, in_channels: int) -> None:
        super().__init__()
        # 使用线性层将输入通道数映射为分类标签数
        self.classifier = nn.Linear(in_channels, config.num_labels)

    def forward(self, inputs, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor:
        # 使用分类器线性层计算分类 logits
        logits = self.classifier(inputs)
        return logits[:, 0, :]


class PerceiverAudioPostprocessor(nn.Module):
    """
    Audio postprocessing for Perceiver. Can be used to convert the decoder output to audio features.

    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        in_channels (`int`):
            Number of channels in the input.
        postproc_type (`str`, *optional*, defaults to `"patches"`):
            Postprocessor type to use. Currently, only "patches" is supported.
    """
    # 使用给定的配置和输入通道数初始化模型
    def __init__(self, config: PerceiverConfig, in_channels: int, postproc_type: str = "patches") -> None:
        # 调用父类的初始化方法
        super().__init__()

        # 检查后处理类型是否在支持的范围内,目前支持 'patches' 类型
        if postproc_type not in ("patches",):  # to be supported: 'conv', 'patches', 'pixels'
            # 如果不在支持的类型中,则抛出数值错误异常
            raise ValueError("Invalid postproc_type!")

        # 架构参数:
        # 创建一个线性分类器,输入通道数为 in_channels,输出通道数为 config.samples_per_patch
        self.classifier = nn.Linear(in_channels, config.samples_per_patch)

    # 前向传播函数,接收输入张量 inputs,可选的位置张量 pos 和模态大小 modality_sizes
    def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor:
        # 使用分类器进行前向计算,得到 logits
        logits = self.classifier(inputs)
        # 对 logits 进行形状变换,将其变为 [batch_size, -1] 的形状
        return torch.reshape(logits, [inputs.shape[0], -1])
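
A minimal sketch of the audio postprocessor with a hypothetical config: each decoder output position is projected back to samples_per_patch raw samples, and the result is flattened into a waveform.

import torch

class TinyAudioConfig:                                    # hypothetical stand-in for PerceiverConfig
    samples_per_patch = 16

postproc = PerceiverAudioPostprocessor(TinyAudioConfig(), in_channels=64)
decoder_output = torch.randn(2, 10, 64)                   # (batch, num_patches, channels)
print(postproc(decoder_output).shape)                     # torch.Size([2, 160]): 10 patches * 16 samples each
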
# PerceiverProjectionPostprocessor: projects the channels of the decoder output to a lower dimension.

class PerceiverProjectionPostprocessor(nn.Module):
    """
    Projection postprocessing for Perceiver. Can be used to project the channels of the decoder output to a lower
    dimension.

    Args:
        in_channels (`int`):
            Number of channels in the input.
        out_channels (`int`):
            Number of channels in the output.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # 使用线性层进行投影,将输入通道数投影到输出通道数
        self.classifier = nn.Linear(in_channels, out_channels)

    def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor:
        # 将输入数据通过线性层进行投影
        logits = self.classifier(inputs)
        # 返回投影后的结果
        return logits


class PerceiverImagePreprocessor(AbstractPreprocessor):
    """
    Image preprocessing for Perceiver Encoder.

    Note: the *out_channels* argument refers to the output channels of a convolutional layer, if *prep_type* is set to
    "conv1x1" or "conv". If one adds absolute position embeddings, one must make sure the *num_channels* of the
    position encoding kwargs are set equal to the *out_channels*.

    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        prep_type (`str`, *optional*, defaults to `"conv"`):
            Preprocessing type. Can be "conv1x1", "conv", "patches", "pixels".
        spatial_downsample (`int`, *optional*, defaults to 4):
            Spatial downsampling factor.
        temporal_downsample (`int`, *optional*, defaults to 1):
            Temporal downsampling factor (only relevant in case a time dimension is present).
        position_encoding_type (`str`, *optional*, defaults to `"fourier"`):
            Position encoding type. Can be "fourier" or "trainable".
        in_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input.
        out_channels (`int`, *optional*, defaults to 64):
            Number of channels in the output.
        conv_after_patching (`bool`, *optional*, defaults to `False`):
            Whether to apply a convolutional layer after patching.
        conv_after_patching_in_channels (`int`, *optional*, defaults to 54):
            Number of channels in the input of the convolutional layer after patching.
        conv2d_use_batchnorm (`bool`, *optional*, defaults to `True`):
            Whether to use batch normalization in the convolutional layer.
        concat_or_add_pos (`str`, *optional*, defaults to `"concat"`):
            How to concatenate the position encoding to the input. Can be "concat" or "add".
        project_pos_dim (`int`, *optional*, defaults to -1):
            Dimension of the position encoding to project to. If -1, no projection is applied.
        **position_encoding_kwargs (`Dict`, *optional*):
            Keyword arguments for the position encoding.
    """
    # 初始化函数,用于创建一个新的对象实例
    def __init__(
        self,
        config,
        prep_type="conv",  # 预处理类型,默认为卷积
        spatial_downsample: int = 4,  # 空间下采样因子,默认为4
        temporal_downsample: int = 1,  # 时间下采样因子,默认为1
        position_encoding_type: str = "fourier",  # 位置编码类型,默认为傅里叶
        in_channels: int = 3,  # 输入通道数,默认为3
        out_channels: int = 64,  # 输出通道数,默认为64
        conv_after_patching: bool = False,  # 是否在打补丁后进行卷积,默认为False
        conv_after_patching_in_channels: int = 54,  # number of input channels, only relevant when conv_after_patching is True
        conv2d_use_batchnorm: bool = True,  # 是否在卷积层后使用批量归一化,默认为True
        concat_or_add_pos: str = "concat",  # 位置编码添加方式,默认为拼接
        project_pos_dim: int = -1,  # 位置维度投影,默认为-1
        **position_encoding_kwargs,  # 其他位置编码的关键字参数
        ):
        # 调用父类的构造函数
        super().__init__()
        # 将配置参数保存到实例变量中
        self.config = config

        # 检查预处理类型是否合法
        if prep_type not in ("conv", "patches", "pixels", "conv1x1"):
            raise ValueError(f"Prep_type {prep_type} is invalid")

        # 检查拼接或添加位置的选项是否合法
        if concat_or_add_pos not in ["concat", "add"]:
            raise ValueError(f"Invalid value {concat_or_add_pos} for concat_or_add_pos.")

        # 初始化实例变量
        self.in_channels = in_channels
        self.prep_type = prep_type
        self.spatial_downsample = spatial_downsample
        self.temporal_downsample = temporal_downsample
        self.position_encoding_type = position_encoding_type
        self.concat_or_add_pos = concat_or_add_pos
        self.conv_after_patching = conv_after_patching
        self.out_channels = out_channels

        # 如果预处理类型为 "conv"
        if self.prep_type == "conv":
            # 使用对数函数计算需要的卷积层数,要求空间下采样为4的幂次方
            convnet_num_layers = math.log(spatial_downsample, 4)
            convnet_num_layers_is_int = convnet_num_layers == np.round(convnet_num_layers)
            # 检查空间和时间下采样是否符合要求
            if not convnet_num_layers_is_int or temporal_downsample != 1:
                raise ValueError(
                    "Only powers of 4 expected for spatial and 1 expected for temporal downsampling with conv."
                )
            # 创建卷积下采样网络
            self.convnet = Conv2DDownsample(
                in_channels=in_channels,
                num_layers=int(convnet_num_layers),
                out_channels=out_channels,
                use_batchnorm=conv2d_use_batchnorm,
            )

        # 如果预处理类型为 "conv1x1"
        elif self.prep_type == "conv1x1":
            # 对于 conv1x1,只允许空间下采样,不允许时间下采样
            if temporal_downsample != 1:
                raise ValueError("Conv1x1 does not downsample in time.")
            # 创建 1x1 卷积层
            self.convnet_1x1 = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(spatial_downsample, spatial_downsample),  # 空间下采样步幅设置
            )

        # 构建位置编码
        self.project_pos_dim = project_pos_dim
        self.position_embeddings, self.positions_projection = build_position_encoding(
            position_encoding_type=position_encoding_type,
            out_channels=out_channels,
            project_pos_dim=project_pos_dim,
            **position_encoding_kwargs,
        )

        # 可选的卷积层,用于在提取补丁之后进行处理
        self.conv_after_patches = (
            nn.Linear(conv_after_patching_in_channels, self.out_channels) if conv_after_patching else nn.Identity()
        )

    @property
    def num_channels(self) -> int:
        # The number of index dimensions is 2 or 3 in the image-preprocessing context, depending on
        # whether we are processing an image or a video. For convenience, declare an is_temporal flag
        # indicating whether the data has a temporal dimension.
        is_temporal = self.position_embeddings.num_dimensions > 2

        # 位置嵌入
        if self.project_pos_dim > 0:
            pos_dim = self.project_pos_dim
        else:
            pos_dim = self.position_embeddings.output_size()

        # 如果使用“add”模式连接位置编码,则返回位置维度
        if self.concat_or_add_pos == "add":
            return pos_dim

        # 输入维度
        if self.conv_after_patching or self.prep_type in ("conv1x1", "conv"):
            inp_dim = self.out_channels
        elif self.prep_type == "pixels":
            inp_dim = self.in_channels
            if not is_temporal:
                inp_dim = math.ceil(inp_dim / self.spatial_downsample)
        elif self.prep_type == "patches":
            if self.conv_after_patching:
                inp_dim = self.out_channels
            else:
                inp_dim = self.in_channels * self.spatial_downsample**2
                if is_temporal:
                    inp_dim *= self.temporal_downsample

        # 返回输入维度加上位置维度的结果
        return inp_dim + pos_dim

    def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool = True):
        """
        Construct the final input, including position encoding.

        This method expects the inputs to always have channels as the last dimension.

        """
        batch_size = inputs.shape[0]
        index_dims = inputs.shape[1:-1]
        indices = np.prod(index_dims)

        # 如果输入维度大于3且网络输入是1维,则将输入特征展平为1维索引维度。
        if len(inputs.shape) > 3 and network_input_is_1d:
            inputs = torch.reshape(inputs, [batch_size, indices, -1])

        # 构建位置编码
        if self.position_encoding_type == "trainable":
            pos_enc = self.position_embeddings(batch_size)
        elif self.position_encoding_type == "fourier":
            pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)

        # 可选择将位置编码投影到目标维度。
        pos_enc = self.positions_projection(pos_enc)

        if not network_input_is_1d:
            # 如果网络接受非1维输入,则重新整形位置编码以匹配输入特征形状。
            sh = inputs.shape
            pos_enc = torch.reshape(pos_enc, list(sh)[:-1] + [-1])

        # 根据连接或加法模式将位置编码与输入合并或相加,并返回结果。
        if self.concat_or_add_pos == "concat":
            inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1)
        elif self.concat_or_add_pos == "add":
            inputs_with_pos = inputs + pos_enc

        return inputs_with_pos, inputs
    def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
        # 根据 self.prep_type 的不同进行不同的数据预处理
        if self.prep_type == "conv":
            # 如果预处理类型为 "conv",则使用卷积神经网络进行图像特征提取
            # 空间下采样因子为4
            inputs = self.convnet(inputs)

        elif self.prep_type == "conv1x1":
            # 如果预处理类型为 "conv1x1",则将输入映射到 self.out_channels 维度
            inputs = self.convnet_1x1(inputs)

        elif self.prep_type == "pixels":
            # 如果预处理类型为 "pixels",根据输入的维度进行最简单的下采样处理
            if inputs.ndim == 4:
                inputs = inputs[:: self.spatial_downsample, :: self.spatial_downsample]
            elif inputs.ndim == 5:
                inputs = inputs[
                    :, :: self.temporal_downsample, :, :: self.spatial_downsample, :: self.spatial_downsample
                ]
            else:
                raise ValueError("Unsupported data format for pixels.")

        elif self.prep_type == "patches":
            # 如果预处理类型为 "patches",进行 Space2depth 特征化处理
            # 视频数据格式为 B x T x C x H x W
            inputs = space_to_depth(
                inputs, temporal_block_size=self.temporal_downsample, spatial_block_size=self.spatial_downsample
            )

            # 如果数据维度为5且第二个维度为1,则为光流数据,进行压缩处理
            if inputs.ndim == 5 and inputs.shape[1] == 1:
                inputs = inputs.squeeze(dim=1)

            # 可选择应用卷积层
            inputs = self.conv_after_patches(inputs)

        if self.prep_type != "patches":
            # 将通道移动到最后一个维度,因为下面的 _build_network_inputs 方法需要这种格式
            if inputs.ndim == 4:
                inputs = inputs.permute(0, 2, 3, 1)
            elif inputs.ndim == 5:
                inputs = inputs.permute(0, 1, 3, 4, 2)
            else:
                raise ValueError("Unsupported data format for conv1x1.")

        # 调用 _build_network_inputs 方法构建网络输入
        inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d)
        modality_sizes = None  # 每种模态的大小,仅在多模态情况下需要

        return inputs, modality_sizes, inputs_without_pos
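
Two small back-of-the-envelope sketches for the image preprocessor above (plain torch, mirroring the reshapes rather than instantiating the class): first the channel count fed to the encoder for the default "conv" + 2-D Fourier setup, then the channels-last flattening done before position encodings are concatenated.

import torch

# num_channels for prep_type="conv" with a 2-D Fourier encoding (hypothetical hyper-parameters)
num_bands, num_dims, out_channels = 64, 2, 64
pos_dim = num_bands * num_dims * 2 + num_dims             # sin + cos bands plus the raw (x, y) coordinates = 258
print(out_channels + pos_dim)                             # 322 channels when concat_or_add_pos == "concat"

# channels-last flattening performed before _build_network_inputs concatenates the position encoding
features = torch.randn(2, 64, 56, 56)                     # output of Conv2DDownsample: (B, C, H, W)
features = features.permute(0, 2, 3, 1)                   # (B, H, W, C)
flat = features.reshape(2, 56 * 56, 64)                   # (B, H*W, C)
print(flat.shape)                                          # torch.Size([2, 3136, 64])
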
# One-hot preprocessor for the Perceiver Encoder: adds a dummy index dimension to the input.
class PerceiverOneHotPreprocessor(AbstractPreprocessor):
    """
    One-hot preprocessor for Perceiver Encoder. Can be used to add a dummy index dimension to the input.

    Args:
        config ([`PerceiverConfig`]):
            Model configuration.
    """

    def __init__(self, config: PerceiverConfig) -> None:
        super().__init__()
        self.config: PerceiverConfig = config

    @property
    def num_channels(self) -> int:
        # 返回配置中定义的标签数,作为通道数
        return self.config.num_labels

    def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
        # 添加一个虚拟的索引维度到输入张量中
        inputs = inputs[:, None, :]

        # 由于没有位置编码,因此第一个(输入)和第三个(没有位置编码的输入)输出是相同的
        return inputs, None, inputs


class PerceiverAudioPreprocessor(AbstractPreprocessor):
    """
    Audio preprocessing for Perceiver Encoder.

    Args:
        config ([*PerceiverConfig*]):
            Model configuration.
        prep_type (`str`, *optional*, defaults to `"patches"`):
            Preprocessor type to use. Only "patches" is supported.
        samples_per_patch (`int`, *optional*, defaults to 96):
            Number of samples per patch.
        position_encoding_type (`str`, *optional*, defaults to `"fourier"`):
            Type of position encoding to use. Can be "trainable" or "fourier".
        concat_or_add_pos (`str`, *optional*, defaults to `"concat"`):
            How to concatenate the position encoding to the input. Can be "concat" or "add".
        out_channels (`int`, *optional*, defaults to 64):
            Number of channels in the output.
        project_pos_dim (`int`, *optional*, defaults to -1):
            Dimension of the position encoding to project to. If -1, no projection is applied.
        **position_encoding_kwargs (`Dict`, *optional*):
            Keyword arguments for the position encoding.
    """

    def __init__(
        self,
        config,
        prep_type: str = "patches",
        samples_per_patch: int = 96,
        position_encoding_type: str = "fourier",
        concat_or_add_pos: str = "concat",
        out_channels=64,
        project_pos_dim=-1,
        **position_encoding_kwargs,
    ):
        super().__init__()
        self.config = config

        # 检查预处理类型是否合法,只能是 "patches"
        if prep_type not in ("patches",):
            raise ValueError(f"Prep_type {prep_type} is invalid, can only be 'patches'.")

        # 检查连接或添加位置编码的方式是否合法,只能是 "concat" 或 "add"
        if concat_or_add_pos not in ["concat", "add"]:
            raise ValueError(f"Concat_or_pos {concat_or_add_pos} is invalid, can only be 'concat' or 'add'.")

        # 设置样本每个补丁的数量
        self.samples_per_patch = samples_per_patch
        # 设置位置编码类型
        self.position_encoding_type = position_encoding_type
        # 设置连接或添加位置编码的方式
        self.concat_or_add_pos = concat_or_add_pos
        # 设置位置编码的投影维度
        self.project_pos_dim = project_pos_dim

        # 构建位置编码和位置投影
        self.position_embeddings, self.positions_projection = build_position_encoding(
            position_encoding_type=position_encoding_type,
            out_channels=out_channels,
            project_pos_dim=project_pos_dim,
            **position_encoding_kwargs,
        )

    @property
    def num_channels(self) -> int:
        # 位置编码维度
        if self.project_pos_dim > 0:
            pos_dim = self.project_pos_dim
        else:
            pos_dim = self.position_embeddings.output_size()
        # 根据连接或添加位置编码的方式确定通道数
        if self.concat_or_add_pos == "add":
            return pos_dim
        return self.samples_per_patch + pos_dim

    def _build_network_inputs(self, inputs):
        """Construct the final input, including position encoding."""
        batch_size = inputs.shape[0]
        index_dims = inputs.shape[1:-1]

        # 构建位置编码
        if self.position_encoding_type == "trainable":
            pos_enc = self.position_embeddings(batch_size)
        elif self.position_encoding_type == "fourier":
            pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)

        # 可选择性地将位置编码投影到目标维度
        pos_enc = self.positions_projection(pos_enc)

        # 根据连接或添加位置编码的方式,合并输入数据和位置编码
        if self.concat_or_add_pos == "concat":
            inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1)
        elif self.concat_or_add_pos == "add":
            inputs_with_pos = inputs + pos_enc

        return inputs_with_pos, inputs  # 返回带位置编码和原始输入的数据

    def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
        inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch])

        # 构建网络的输入,包括位置编码
        inputs, inputs_without_pos = self._build_network_inputs(inputs)
        modality_sizes = None  # 用于多模态的每个模态的大小

        return inputs, modality_sizes, inputs_without_pos
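
A minimal sketch of the audio patching step above (plain torch; 96 samples per patch as in the default):

import torch

raw_audio = torch.randn(2, 1920, 1)                       # (batch, num_samples, 1)
patches = torch.reshape(raw_audio, [2, -1, 96])           # (batch, 20, 96): each row is one patch of 96 raw samples
print(patches.shape)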
    """
    Multimodal preprocessing for Perceiver Encoder.

    Inputs for each modality are preprocessed, then padded with trainable position embeddings to have the same number
    of channels.

    Args:
        modalities (`Mapping[str, PreprocessorType]`):
            Dict mapping modality name to preprocessor.
        mask_probs (`Dict[str, float]`):
            Dict mapping modality name to masking probability of that modality.
        min_padding_size (`int`, *optional*, defaults to 2):
            The minimum padding size for all modalities. The final output will have num_channels equal to the maximum
            channels across all modalities plus min_padding_size.
    """

    def __init__(
        self,
        modalities: Mapping[str, PreprocessorType],
        mask_probs: Optional[Mapping[str, float]] = None,
        min_padding_size: int = 2,
    ):
        super().__init__()
        # 使用 nn.ModuleDict 封装各个模态的预处理器
        self.modalities = nn.ModuleDict(modalities)
        # 设置最小填充大小
        self.min_padding_size = min_padding_size
        # 如果提供了遮罩概率字典,则使用该字典;否则为空字典
        self.mask_probs = mask_probs if mask_probs is not None else {}
        # 初始化填充参数,为每个模态创建一个可训练的位置填充向量
        self.padding = nn.ParameterDict(
            {
                modality: nn.Parameter(torch.randn(1, self.num_channels - preprocessor.num_channels))
                for modality, preprocessor in modalities.items()
            }
        )
        # 初始化遮罩参数,为每个模态创建一个可训练的遮罩向量
        self.mask = nn.ParameterDict(
            {modality: nn.Parameter(torch.randn(1, self.num_channels)) for modality, _ in self.mask_probs.items()}
        )

    @property
    def num_channels(self) -> int:
        # 计算所有模态中最大通道数,并加上最小填充大小,得到公共通道数
        max_channel_size = max(processor.num_channels for _, processor in self.modalities.items())
        common_channel_size = max_channel_size + self.min_padding_size
        return common_channel_size

    def forward(
        self, inputs: Mapping[str, torch.Tensor], pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True
    ) -> PreprocessorOutputType:
        # 初始化空字典用于存储填充后的输出
        padded = {}
        # 初始化空字典用于存储每个模态的输出大小
        modality_sizes = {}
        # 初始化空字典用于存储没有位置编码的输入
        inputs_without_pos = {}

        # 遍历每个模态和其对应的预处理器
        for modality, preprocessor in self.modalities.items():
            # 使用对应的预处理器处理每个模态的输入
            # 获取预处理后的输出、位置编码和没有位置编码的输入
            output, _, inputs_without_pos[modality] = preprocessor(
                inputs[modality], pos=pos, network_input_is_1d=network_input_is_1d
            )

            # 对输出进行填充到相同的 common_channel_size
            batch_size, num_samples, num_channels = output.shape
            # 扩展位置编码以匹配输出的形状
            pos_enc = self.padding[modality].expand(batch_size, -1, -1)

            # 使用广播方式创建填充张量,使其与输出的通道数匹配
            padding = torch.broadcast_to(
                pos_enc,
                [batch_size, num_samples, self.num_channels - num_channels],
            )
            # 在通道维度上连接输出和填充部分
            output_padded = torch.cat([output, padding], dim=2)

            # 如果需要,进行掩码操作
            if modality in self.mask_probs:
                # 获取模态对应的掩码标记并扩展以匹配输出形状
                mask_token = self.mask[modality].expand(batch_size, -1, -1)
                mask_prob = self.mask_probs[modality]
                # 使用伯努利分布生成掩码
                mask = torch.bernoulli(torch.full([batch_size, num_samples], mask_prob))
                mask = torch.unsqueeze(mask, dim=2).to(mask_token.device)
                # 应用掩码到填充后的输出
                output_padded = (1 - mask) * output_padded + mask * mask_token

            # 将填充后的输出存储到对应的模态键下
            padded[modality] = output_padded
            # 记录每个模态填充后的输出大小
            modality_sizes[modality] = output_padded.shape[1]

        # 将填充后的输出按照模态键排序形成列表
        padded_ls = [padded[k] for k in sorted(padded.keys())]

        # 最终将所有模态的填充输出沿时间维度连接起来
        final_inputs = torch.cat(padded_ls, dim=1)

        # 返回最终的填充后的输入、每个模态的输出大小和没有位置编码的输入
        return final_inputs, modality_sizes, inputs_without_pos
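
A minimal sketch of the per-modality padding performed above (plain torch, hypothetical sizes): every modality is padded on the channel dimension up to the common size, then all modalities are concatenated along the sequence dimension in sorted-key order.

import torch

common_channels = 704                                      # max modality channels + min_padding_size
image = torch.randn(2, 3136, 322)                          # image tokens: (batch, num_samples, channels)
audio = torch.randn(2, 20, 258)                            # audio tokens
pad_image = torch.randn(1, common_channels - 322).expand(2, 3136, -1)
pad_audio = torch.randn(1, common_channels - 258).expand(2, 20, -1)
final_inputs = torch.cat(
    [torch.cat([audio, pad_audio], dim=2),                 # "audio" sorts before "image"
     torch.cat([image, pad_image], dim=2)],
    dim=1,
)
print(final_inputs.shape)                                   # torch.Size([2, 3156, 704])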