Transformers Source Code Walkthrough (86)

.\models\owlvit\processing_owlvit.py

# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image/Text processor class for OWL-ViT
"""

import warnings
from typing import List

import numpy as np

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import is_flax_available, is_tf_available, is_torch_available

class OwlViTProcessor(ProcessorMixin):
    r"""
    Constructs an OWL-ViT processor which wraps [`OwlViTImageProcessor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`]
    into a single processor that inherits both the image processor and tokenizer functionalities. See the
    [`~OwlViTProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more information.

    Args:
        image_processor ([`OwlViTImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`], *optional*):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]  # names of the attributes the processor is initialized with
    image_processor_class = "OwlViTImageProcessor"  # class name of the image processor
    tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")  # the two accepted tokenizer classes

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        feature_extractor = None
        # Warn if the deprecated `feature_extractor` argument is used and pop it from kwargs
        if "feature_extractor" in kwargs:
            warnings.warn(
                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
                " instead.",
                FutureWarning,
            )
            feature_extractor = kwargs.pop("feature_extractor")

        # Fall back to the deprecated feature extractor if no image processor was given
        image_processor = image_processor if image_processor is not None else feature_extractor
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        # Initialize the parent class with the image processor and tokenizer
        super().__init__(image_processor, tokenizer)

    def post_process(self, *args, **kwargs):
        """
        This method forwards all its arguments to [`OwlViTImageProcessor.post_process`]. Please refer to the docstring
        of this method for more information.
        """
        return self.image_processor.post_process(*args, **kwargs)

    def post_process_object_detection(self, *args, **kwargs):
        """
        This method forwards all its arguments to [`OwlViTImageProcessor.post_process_object_detection`]. Please refer
        to the docstring of this method for more information.
        """
        return self.image_processor.post_process_object_detection(*args, **kwargs)

    def post_process_image_guided_detection(self, *args, **kwargs):
        """
        This method forwards all its arguments to [`OwlViTImageProcessor.post_process_image_guided_detection`]. Please
        refer to the docstring of this method for more information.
        """
        return self.image_processor.post_process_image_guided_detection(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def feature_extractor_class(self):
        """
        Deprecated: `feature_extractor_class` will be removed in v5. Use `image_processor_class` instead.
        Returns `image_processor_class`.
        """
        warnings.warn(
            "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
            FutureWarning,
        )
        return self.image_processor_class

    @property
    def feature_extractor(self):
        """
        Deprecated: `feature_extractor` will be removed in v5. Use `image_processor` instead.
        Returns `image_processor`.
        """
        warnings.warn(
            "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
            FutureWarning,
        )
        return self.image_processor
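
For context, a minimal usage sketch of this processor (not part of the file above; `google/owlvit-base-patch32` is a public OWL-ViT checkpoint, and the `__call__` method elided from this excerpt accepts both `text` and `images`):

```python
import requests
from PIL import Image
from transformers import OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# Text queries are tokenized by the CLIP tokenizer; the image is resized and normalized
inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
print(sorted(inputs.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']
```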

.\models\owlvit\__init__.py

# Copyright and license notice: code copyright the HuggingFace Team, licensed under the Apache License 2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import TYPE_CHECKING for imports that only matter to static type checkers
from typing import TYPE_CHECKING

# Import the dependency-availability checks and the lazy-module helper
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
    is_vision_available,
)

# Define the module's import structure
_import_structure = {
    "configuration_owlvit": [
        "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "OwlViTConfig",
        "OwlViTOnnxConfig",
        "OwlViTTextConfig",
        "OwlViTVisionConfig",
    ],
    "processing_owlvit": ["OwlViTProcessor"],
}

# If the vision dependencies are missing, OptionalDependencyNotAvailable is raised
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # Vision dependencies available: register the feature extractor and image processor
    _import_structure["feature_extraction_owlvit"] = ["OwlViTFeatureExtractor"]
    _import_structure["image_processing_owlvit"] = ["OwlViTImageProcessor"]

# If torch is missing, OptionalDependencyNotAvailable is raised
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # Torch available: register the modeling classes
    _import_structure["modeling_owlvit"] = [
        "OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "OwlViTModel",
        "OwlViTPreTrainedModel",
        "OwlViTTextModel",
        "OwlViTVisionModel",
        "OwlViTForObjectDetection",
    ]

# During type checking, import the configuration and processing classes directly
if TYPE_CHECKING:
    from .configuration_owlvit import (
        OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        OwlViTConfig,
        OwlViTOnnxConfig,
        OwlViTTextConfig,
        OwlViTVisionConfig,
    )
    from .processing_owlvit import OwlViTProcessor

    # Under type checking, import the vision classes if the dependency is available
    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .feature_extraction_owlvit import OwlViTFeatureExtractor
        from .image_processing_owlvit import OwlViTImageProcessor

    # Under type checking, import the modeling classes if torch is available
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_owlvit import (
            OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
            OwlViTForObjectDetection,
            OwlViTModel,
            OwlViTPreTrainedModel,
            OwlViTTextModel,
            OwlViTVisionModel,
        )

# At runtime, replace this module with a lazy module so submodules are imported on first attribute access
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
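
A small sketch of what this lazy pattern buys, assuming a standard `transformers` installation: importing the package does not pull in the heavy submodules, and attribute access triggers the real import on demand.

```python
import sys
import transformers.models.owlvit as owlvit

# The package module object was replaced by _LazyModule during import (see above)
print(type(sys.modules["transformers.models.owlvit"]).__name__)  # _LazyModule
# Attribute access resolves the name via _import_structure and imports the submodule now
processor_cls = owlvit.OwlViTProcessor
print(processor_cls.__module__)  # transformers.models.owlvit.processing_owlvit
```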

.\models\patchtsmixer\configuration_patchtsmixer.py

# File encoding set to UTF-8
# Copyright and license notice (Apache License 2.0)
# Import the required modules and functions
from typing import List, Optional, Union

# Import the PretrainedConfig base class
from ...configuration_utils import PretrainedConfig
# Import the logging utility
from ...utils import logging

# Get the module logger
logger = logging.get_logger(__name__)

# Mapping from pretrained model names to their config files
PATCHTSMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "ibm/patchtsmixer-etth1-pretrain": "https://huggingface.co/ibm/patchtsmixer-etth1-pretrain/resolve/main/config.json",
}

# PatchTSMixerConfig stores the configuration of a PatchTSMixer model
class PatchTSMixerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PatchTSMixerModel`]. It is used to instantiate a
    PatchTSMixer model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the PatchTSMixer
    [ibm/patchtsmixer-etth1-pretrain](https://huggingface.co/ibm/patchtsmixer-etth1-pretrain) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Example:

    ```
    >>> from transformers import PatchTSMixerConfig, PatchTSMixerModel

    >>> # Initializing a default PatchTSMixer configuration
    >>> configuration = PatchTSMixerConfig()

    >>> # Randomly initializing a model (with random weights) from the configuration
    >>> model = PatchTSMixerModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    # Model type identifier
    model_type = "patchtsmixer"

    # Attribute map translating PatchTSMixerConfig attribute names to standard names
    attribute_map = {
        "hidden_size": "d_model",               # hidden size maps to d_model
        "num_hidden_layers": "num_layers",      # number of hidden layers maps to num_layers
    }

    def __init__(
        self,
        # Time series specific configuration
        context_length: int = 32,  # context window length of the time series, default 32
        patch_length: int = 8,  # length of each patch, default 8
        num_input_channels: int = 1,  # number of input channels, default 1
        patch_stride: int = 8,  # stride between patches, default 8
        num_parallel_samples: int = 100,  # samples drawn in parallel for probabilistic forecasts, default 100
        # General model configuration
        d_model: int = 8,  # hidden feature dimension of the model, default 8
        expansion_factor: int = 2,  # MLP expansion factor, default 2
        num_layers: int = 3,  # number of mixer layers, default 3
        dropout: float = 0.2,  # dropout rate, default 0.2
        mode: str = "common_channel",  # mixer mode ("common_channel" or "mix_channel"), default "common_channel"
        gated_attn: bool = True,  # whether to use gated attention, default True
        norm_mlp: str = "LayerNorm",  # normalization type for the mixer blocks, default "LayerNorm"
        self_attn: bool = False,  # whether to use self-attention, default False
        self_attn_heads: int = 1,  # number of self-attention heads, default 1
        use_positional_encoding: bool = False,  # whether to use positional encoding, default False
        positional_encoding_type: str = "sincos",  # positional encoding type, default "sincos"
        scaling: Optional[Union[str, bool]] = "std",  # input scaling method, default "std"
        loss: str = "mse",  # loss function, default "mse"
        init_std: float = 0.02,  # std of the weight initialization, default 0.02
        post_init: bool = False,  # whether to use the library's custom post-initialization, default False
        norm_eps: float = 1e-5,  # epsilon used in normalization layers, default 1e-5
        # Pretrain model configuration
        mask_type: str = "random",  # masking type, default "random"
        random_mask_ratio: float = 0.5,  # ratio of randomly masked patches, default 0.5
        num_forecast_mask_patches: Optional[Union[List[int], int]] = [2],  # number of forecast-masked patches, default [2]
        mask_value: int = 0,  # value used to fill masked positions, default 0
        masked_loss: bool = True,  # whether to compute the loss only on masked patches, default True
        channel_consistent_masking: bool = True,  # whether the mask is shared across channels, default True
        unmasked_channel_indices: Optional[List[int]] = None,  # channels that are never masked, default None
        # General head configuration
        head_dropout: float = 0.2,  # dropout rate of the output head, default 0.2
        distribution_output: str = "student_t",  # output distribution type, default "student_t"
        # Prediction head configuration
        prediction_length: int = 16,  # forecast horizon, default 16
        prediction_channel_indices: list = None,  # channel indices to forecast, default None
        # Classification/Regression configuration
        num_targets: int = 3,  # number of targets, default 3
        output_range: list = None,  # output range, default None
        head_aggregation: str = "max_pool",  # head aggregation method, default "max_pool"
        **kwargs,
        ):
        self.num_input_channels = num_input_channels  # number of input channels
        self.context_length = context_length  # context window length
        self.patch_length = patch_length  # length of each patch
        self.patch_stride = patch_stride  # stride between patches
        self.d_model = d_model  # hidden feature dimension of the model
        self.expansion_factor = expansion_factor  # MLP expansion factor
        self.num_layers = num_layers  # number of stacked mixer layers
        self.dropout = dropout  # dropout rate
        self.mode = mode  # mixer mode: how channels are processed ("common_channel" or "mix_channel")
        self.gated_attn = gated_attn  # whether gated attention is enabled
        self.norm_mlp = norm_mlp  # normalization type used in the mixer blocks
        self.scaling = scaling  # input scaling method ("std", "mean", or None)
        self.head_dropout = head_dropout  # dropout rate of the output head
        # Number of patches derived from the context length, patch length and stride
        self.num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
        self.mask_type = mask_type  # masking type ("random" or "forecast")
        self.random_mask_ratio = random_mask_ratio  # ratio of randomly masked patches
        self.num_forecast_mask_patches = num_forecast_mask_patches  # number of forecast-masked patches
        self.mask_value = mask_value  # value used to fill masked positions
        self.channel_consistent_masking = channel_consistent_masking  # whether the mask is shared across channels
        self.masked_loss = masked_loss  # whether the loss is computed only on masked patches
        self.patch_last = True
        self.use_positional_encoding = use_positional_encoding  # whether positional encoding is used
        self.positional_encoding_type = positional_encoding_type  # positional encoding type
        self.prediction_length = prediction_length  # forecast horizon
        self.prediction_channel_indices = prediction_channel_indices  # channel indices to forecast
        self.num_targets = num_targets  # number of classification/regression targets
        self.output_range = output_range  # output range for regression
        self.head_aggregation = head_aggregation  # aggregation method of the classification/regression head
        self.self_attn = self_attn  # whether self-attention is enabled
        self.self_attn_heads = self_attn_heads  # number of self-attention heads
        self.init_std = init_std  # std of the weight initialization
        self.post_init = post_init  # whether to use the library's custom post-initialization
        self.distribution_output = distribution_output  # output distribution type
        self.loss = loss  # loss function
        self.num_parallel_samples = num_parallel_samples  # parallel samples for probabilistic forecasting
        self.unmasked_channel_indices = unmasked_channel_indices  # channels that are never masked
        self.norm_eps = norm_eps  # epsilon used in normalization layers
        # Initialize the parent class with any remaining keyword arguments
        super().__init__(**kwargs)
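
A quick sanity check of the `num_patches` formula above with the default values (`context_length=32`, `patch_length=8`, `patch_stride=8`), assuming a `transformers` version that ships PatchTSMixer:

```python
from transformers import PatchTSMixerConfig

config = PatchTSMixerConfig()  # defaults: context_length=32, patch_length=8, patch_stride=8
# num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
assert config.num_patches == (32 - 8) // 8 + 1 == 4
print(config.num_patches)  # 4
```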

.\models\patchtsmixer\modeling_patchtsmixer.py

# math module for mathematical functions
import math
# dataclass for defining data classes
from dataclasses import dataclass
# Optional, Tuple and Union for type hints
from typing import Optional, Tuple, Union

# PyTorch
import torch
# PyTorch neural network module
import torch.nn as nn

# Import the PreTrainedModel base class
from transformers.modeling_utils import PreTrainedModel
# Import the generic ModelOutput class
from transformers.utils import ModelOutput

# Import logging and docstring utilities
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
# Import the PatchTSMixer configuration class
from .configuration_patchtsmixer import PatchTSMixerConfig

# Get the module logger
logger = logging.get_logger(__name__)

# List of pretrained PatchTSMixer checkpoints
PATCHTSMIXER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "ibm/patchtsmixer-etth1-pretrain",
    # See all PatchTSMixer models at https://huggingface.co/models?filter=patchtsmixer
]

# Docstring prepended to the PatchTSMixer model classes
PATCHTSMIXER_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`PatchTSMixerConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
        mask_input (`bool`, *optional*, defaults to `False`):
            If True, Masking will be enabled. False otherwise.
"""

# Docstring describing the inputs of the PatchTSMixer forward pass
PATCHTSMIXER_INPUTS_DOCSTRING = r"""
    Args:
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series. For
            univariate time series, the `num_input_channels` dimension should be 1. For multivariate time series, it
            is greater than 1.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class PatchTSMixerGatedAttention(nn.Module):
    """
    Module that applies gated attention to input data.

    Args:
        in_size (`int`): The input size.
        out_size (`int`): The output size.
    """

    def __init__(self, in_size: int, out_size: int):
        super().__init__()
        # Linear layer for computing attention weights
        self.attn_layer = nn.Linear(in_size, out_size)
        # Softmax activation to normalize attention weights across input dimensions
        self.attn_softmax = nn.Softmax(dim=-1)

    def forward(self, inputs):
        # Calculate attention weights using linear layer and apply softmax
        attn_weight = self.attn_softmax(self.attn_layer(inputs))
        # Apply gated attention mechanism to input data
        inputs = inputs * attn_weight
        return inputs


# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTBatchNorm with PatchTST->PatchTSMixer
class PatchTSMixerBatchNorm(nn.Module):
    """
    Compute batch normalization over the sequence length (time) dimension.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        # Batch normalization across the d_model dimension
        self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps)

    def forward(self, inputs: torch.Tensor):
        """
        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                input for Batch norm calculation
        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
        """
        # Transpose input tensor to match expected shape for batch normalization
        output = inputs.transpose(1, 2)  # output: (batch_size, d_model, sequence_length)
        # Apply batch normalization along the d_model dimension
        output = self.batchnorm(output)
        # Transpose output back to original shape
        return output.transpose(1, 2)
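
Why the transpose pair: `nn.BatchNorm1d` normalizes dimension 1, so `d_model` has to be moved there and back. A standalone shape-level sketch of the same trick (toy sizes, independent of the config class):

```python
import torch
import torch.nn as nn

batch_size, seq_len, d_model = 4, 16, 8
x = torch.randn(batch_size, seq_len, d_model)
bn = nn.BatchNorm1d(d_model, eps=1e-5)
# BatchNorm1d expects (N, C, L): move d_model to the channel position and back
y = bn(x.transpose(1, 2)).transpose(1, 2)
print(y.shape)  # torch.Size([4, 16, 8])
```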


class PatchTSMixerPositionalEncoding(nn.Module):
    """
    Class for positional encoding
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        # positional encoding initialization based on config settings
        if config.use_positional_encoding:
            self.position_enc = self._init_pe(config)
        else:
            # Initialize positional encoding as a parameter tensor filled with zeros
            self.position_enc = nn.Parameter(torch.zeros(config.num_patches, config.d_model))

    @staticmethod
    def _init_pe(config: PatchTSMixerConfig) -> nn.Parameter:
        # Initialize the positional encoding according to the configuration

        # 'random': learnable encoding initialized from a standard normal distribution
        if config.positional_encoding_type == "random":
            position_enc = nn.Parameter(torch.randn(config.num_patches, config.d_model), requires_grad=True)

        # 'sincos': fixed sine/cosine encoding, rescaled and frozen (requires_grad=False)
        elif config.positional_encoding_type == "sincos":
            position_enc = torch.zeros(config.num_patches, config.d_model)
            position = torch.arange(0, config.num_patches).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
            position_enc[:, 0::2] = torch.sin(position * div_term)
            position_enc[:, 1::2] = torch.cos(position * div_term)
            position_enc = position_enc - position_enc.mean()
            position_enc = position_enc / (position_enc.std() * 10)
            position_enc = nn.Parameter(position_enc, requires_grad=False)

        else:
            # Raise an error for unsupported positional encoding types
            raise ValueError(
                f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
            )

        # Return the positional encoding as a model parameter
        return position_enc

    def forward(self, patch_input: torch.Tensor):
        # Add the positional encoding to the patch embeddings
        hidden_state = patch_input + self.position_enc
        return hidden_state
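
For reference, the `'sincos'` branch implements the standard sinusoidal encoding from "Attention Is All You Need" (with pos the patch index and i the feature-pair index), before the mean/std rescaling applied above:

```latex
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right),
\qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
```
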
class PatchTSMixerNormLayer(nn.Module):
    """Normalization block

    Args:
        config (`PatchTSMixerConfig`, *required*):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()

        self.norm_mlp = config.norm_mlp

        # Choose the normalization layer based on the configuration
        if "batch" in config.norm_mlp.lower():
            # "batch" in the config value: use batch normalization over the time dimension
            self.norm = PatchTSMixerBatchNorm(config)
        else:
            # otherwise use LayerNorm with epsilon config.norm_eps
            self.norm = nn.LayerNorm(config.d_model, eps=config.norm_eps)

    def forward(self, inputs: torch.Tensor):
        """
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
                Input to the normalization layer.
        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`
        """
        if "batch" in self.norm_mlp.lower():
            # Reshape to [batch_size * num_channels, num_patches, d_model] for batch norm
            inputs_reshaped = torch.reshape(
                inputs,
                (
                    inputs.shape[0] * inputs.shape[1],
                    inputs.shape[2],
                    inputs.shape[3],
                ),
            )

            # Normalize the reshaped tensor
            inputs_reshaped = self.norm(inputs_reshaped)

            # Restore the original shape
            inputs = torch.reshape(inputs_reshaped, inputs.shape)

        else:
            # Apply the selected normalization layer directly
            inputs = self.norm(inputs)

        return inputs


class PatchTSMixerMLP(nn.Module):
    def __init__(self, in_features, out_features, config):
        super().__init__()
        num_hidden = in_features * config.expansion_factor
        self.fc1 = nn.Linear(in_features, num_hidden)
        self.dropout1 = nn.Dropout(config.dropout)
        self.fc2 = nn.Linear(num_hidden, out_features)
        self.dropout2 = nn.Dropout(config.dropout)

    def forward(self, inputs: torch.Tensor):
        """
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
                Input to the MLP layer.
        Returns:
            `torch.Tensor` of the same shape as `inputs`
        """
        # First linear layer + GELU activation + dropout
        inputs = self.dropout1(nn.functional.gelu(self.fc1(inputs)))
        # Second linear layer + dropout
        inputs = self.fc2(inputs)
        inputs = self.dropout2(inputs)
        return inputs


class PatchTSMixerChannelFeatureMixerBlock(nn.Module):
    """This module mixes the features in the channel dimension.

    Args:
        config (`PatchTSMixerConfig`, *required*):
            Configuration.
    """
    # Initializer taking a `PatchTSMixerConfig` configuration object
    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()

        self.norm = PatchTSMixerNormLayer(config)  # normalization layer
        self.gated_attn = config.gated_attn  # whether gated attention is enabled

        # MLP that mixes the channel dimension
        self.mlp = PatchTSMixerMLP(
            in_features=config.num_input_channels,
            out_features=config.num_input_channels,
            config=config,
        )

        # Optional gated attention over the channel dimension
        if config.gated_attn:
            self.gating_block = PatchTSMixerGatedAttention(
                in_size=config.num_input_channels, out_size=config.num_input_channels
            )

    def forward(self, inputs: torch.Tensor):
        """
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
                input to the MLP layer
        Returns:
            `torch.Tensor` of the same shape as `inputs`
        """
        # Keep the input for the residual connection
        residual = inputs

        # Normalize the input
        inputs = self.norm(inputs)

        # Permute to (batch_size, d_model, num_patches, num_channels) so the MLP mixes channels
        inputs = inputs.permute(0, 3, 2, 1)

        # Optionally apply gated attention over the channel dimension
        if self.gated_attn:
            inputs = self.gating_block(inputs)

        # Mix the channel dimension with the MLP
        inputs = self.mlp(inputs)

        # Permute back to (batch_size, num_channels, num_patches, d_model)
        inputs = inputs.permute(0, 3, 2, 1)

        # Add the residual connection to form the output
        out = inputs + residual
        return out
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTSMixer
class PatchTSMixerAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[PatchTSMixerConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim  # embedding dimension of the attention module
        self.num_heads = num_heads  # number of attention heads
        self.dropout = dropout  # dropout probability
        self.head_dim = embed_dim // num_heads  # dimension of each attention head
        self.config = config  # PatchTSMixer configuration object

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5  # scaling factor for the dot-product attention
        self.is_decoder = is_decoder  # whether the module is used in a decoder
        self.is_causal = is_causal  # whether the attention is causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # key projection
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # value projection
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # query projection
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # output projection

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape to (bsz, num_heads, seq_len, head_dim) for multi-head attention
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        # Forward pass implementing the multi-head attention computation



class PatchMixerBlock(nn.Module):
    """This module mixes the patch dimension.

    Args:
        config (`PatchTSMixerConfig`, *required*):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()

        self.norm = PatchTSMixerNormLayer(config)  # normalization layer

        self.self_attn = config.self_attn  # whether to use self-attention
        self.gated_attn = config.gated_attn  # whether to use gated attention

        # MLP that mixes the patch dimension
        self.mlp = PatchTSMixerMLP(
            in_features=config.num_patches,
            out_features=config.num_patches,
            config=config,
        )

        # Optional gated attention over the patch dimension
        if config.gated_attn:
            self.gating_block = PatchTSMixerGatedAttention(in_size=config.num_patches, out_size=config.num_patches)

        # Optional self-attention over patches, with its own normalization layer
        if config.self_attn:
            self.self_attn_layer = PatchTSMixerAttention(
                embed_dim=config.d_model,
                num_heads=config.self_attn_heads,
                dropout=config.dropout,
            )
            self.norm_attn = PatchTSMixerNormLayer(config)
    def forward(self, hidden_state):
        """
        Args:
            hidden_state (`torch.Tensor`): Input tensor.

        Returns:
            `torch.Tensor`: Transformed tensor.
        """
        # Keep the input for the residual connection
        residual = hidden_state

        # Normalize the input
        hidden_state = self.norm(hidden_state)

        if self.self_attn:
            # Flatten batch and channel dimensions so attention runs over the patches
            batch_size, n_vars, num_patches, d_model = hidden_state.shape
            hidden_state_reshaped = hidden_state.reshape(batch_size * n_vars, num_patches, d_model)

            # Self-attention without returning the attention weights
            x_attn, _, _ = self.self_attn_layer(hidden_state_reshaped, output_attentions=False)
            x_attn = x_attn.reshape(batch_size, n_vars, num_patches, d_model)

        # Transpose so that num_patches is the last dimension and mix it with the MLP
        hidden_state = hidden_state.transpose(2, 3)
        hidden_state = self.mlp(hidden_state)

        # Optionally apply gated attention
        if self.gated_attn:
            hidden_state = self.gating_block(hidden_state)

        # Transpose back to the original layout
        hidden_state = hidden_state.transpose(2, 3)

        # If self-attention is used, normalize the sum of the attention output and the mixed tensor
        if self.self_attn:
            hidden_state = self.norm_attn(hidden_state + x_attn)

        # Add the residual connection to form the output
        out = hidden_state + residual
        return out
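
The essential move in this block is the transpose-MLP-transpose sandwich: a shared linear map mixes information across patches rather than features. A standalone shape-level sketch (toy sizes, with a plain `nn.Linear` standing in for `PatchTSMixerMLP`):

```python
import torch
import torch.nn as nn

batch_size, n_vars, num_patches, d_model = 2, 3, 4, 8
hidden = torch.randn(batch_size, n_vars, num_patches, d_model)
mlp = nn.Linear(num_patches, num_patches)  # stand-in for PatchTSMixerMLP
# Move num_patches to the last dimension, mix it, and move it back
mixed = mlp(hidden.transpose(2, 3)).transpose(2, 3)
print(mixed.shape)  # torch.Size([2, 3, 4, 8])
```
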
class FeatureMixerBlock(nn.Module):
    """This module mixes the hidden feature dimension.

    Args:
        config (`PatchTSMixerConfig`, *required*):
            Configuration.

    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()

        # Normalization layer built from the configuration
        self.norm = PatchTSMixerNormLayer(config)

        # Whether gated attention is enabled
        self.gated_attn = config.gated_attn

        # MLP that mixes the hidden feature (d_model) dimension
        self.mlp = PatchTSMixerMLP(
            in_features=config.d_model,
            out_features=config.d_model,
            config=config,
        )

        # Optional gated attention block
        if config.gated_attn:
            self.gating_block = PatchTSMixerGatedAttention(in_size=config.d_model, out_size=config.d_model)

    def forward(self, hidden: torch.Tensor):
        """
        Args:
            hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
                Input tensor to the layer.

        Returns:
            `torch.Tensor`: Transformed tensor.
        """
        # Keep the input for the residual connection
        residual = hidden

        # Normalize the input
        hidden = self.norm(hidden)

        # Mix the feature dimension with the MLP
        hidden = self.mlp(hidden)

        # Optionally apply gated attention
        if self.gated_attn:
            hidden = self.gating_block(hidden)

        # Add the residual connection to form the output
        out = hidden + residual
        return out


class PatchTSMixerLayer(nn.Module):
    """
    The `PatchTSMixer` layer that does all three kinds of mixing.

    Args:
        config (`PatchTSMixerConfig`, *required*):
            Configuration.

    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()

        # Patch-mixing block
        self.patch_mixer = PatchMixerBlock(config=config)

        # Feature-mixing block
        self.feature_mixer = FeatureMixerBlock(config=config)

        # Operating mode from the configuration
        self.mode = config.mode

        # In "mix_channel" mode, also mix across the channel dimension
        if config.mode == "mix_channel":
            self.channel_feature_mixer = PatchTSMixerChannelFeatureMixerBlock(config=config)

    def forward(self, hidden: torch.Tensor):
        """
        Args:
            hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
                Input tensor to the layer.

        Returns:
            `torch.Tensor`: Transformed tensor.
        """
        # In "mix_channel" mode, mix across channels first
        if self.mode == "mix_channel":
            hidden = self.channel_feature_mixer(hidden)

        # Mix across patches, then across hidden features
        hidden = self.patch_mixer(hidden)
        hidden = self.feature_mixer(hidden)  # hidden: (batch_size x num_patches x d_model)
        return hidden


class PatchTSMixerBlock(nn.Module):
    """The main computing framework of the `PatchTSMixer` model.

    Args:
        config (`PatchTSMixerConfig`, *required*):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()

        # Number of mixer layers from the configuration
        num_layers = config.num_layers

        # Stack num_layers PatchTSMixerLayer modules
        self.mixers = nn.ModuleList([PatchTSMixerLayer(config=config) for _ in range(num_layers)])

    def forward(self, hidden_state, output_hidden_states: bool = False):
        """
        Args:
            hidden_state (`torch.Tensor`): The input tensor.
            output_hidden_states (`bool`, *optional*, defaults to False):
                Whether to output all hidden states.

        Returns:
            `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to
            `True`.
        """
        # Collect the hidden state after each mixer layer if requested
        all_hidden_states = []

        # Start from the input hidden state
        embedding = hidden_state

        # Pass the embedding through every mixer layer
        for mod in self.mixers:
            embedding = mod(embedding)
            if output_hidden_states:
                all_hidden_states.append(embedding)

        # Return the final embedding, plus all hidden states if requested
        if output_hidden_states:
            return embedding, all_hidden_states
        else:
            return embedding, None
class PatchTSMixerForPredictionHead(nn.Module):
    """Prediction Head for Forecasting

    Args:
        config (`PatchTSMixerConfig`, *required*): Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig, distribution_output=None):
        super().__init__()

        self.prediction_channel_indices = config.prediction_channel_indices  # channel indices to forecast

        if self.prediction_channel_indices is not None:
            self.prediction_channel_indices.sort()  # keep the indices sorted

        self.dropout_layer = nn.Dropout(config.head_dropout)  # dropout applied before the projection
        if distribution_output is None:
            # Point forecast: a single linear projection to the prediction horizon
            self.base_forecast_block = nn.Linear((config.num_patches * config.d_model), config.prediction_length)
        else:
            # Distributional forecast: projection to the distribution parameters
            self.base_forecast_block = distribution_output.get_parameter_projection(
                config.num_patches * config.d_model
            )

        self.flatten = nn.Flatten(start_dim=-2)  # flattens (num_patches, d_model) into one dimension

    def forward(self, hidden_features):
        """
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size, num_patch, d_model)` in `flatten` mode
                or `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size, prediction_length, nvars)`.

        """

        hidden_features = self.flatten(hidden_features)  # (batch_size x n_vars x num_patch * d_model)
        hidden_features = self.dropout_layer(hidden_features)  # apply dropout
        forecast = self.base_forecast_block(hidden_features)  # project to the forecast
        if isinstance(forecast, tuple):
            # Distribution parameters: transpose each tensor to (batch_size x prediction_length x n_vars)
            forecast = tuple(z.transpose(-1, -2) for z in forecast)
        else:
            forecast = forecast.transpose(-1, -2)  # (batch_size x prediction_length x n_vars)

        # Keep only the requested forecast channels, if specified
        if self.prediction_channel_indices is not None:
            if isinstance(forecast, tuple):
                forecast = tuple(z[..., self.prediction_channel_indices] for z in forecast)
            else:
                forecast = forecast[..., self.prediction_channel_indices]

        return forecast


class PatchTSMixerLinearHead(nn.Module):
    """Linear head for Classification and Regression.

    Args:
        config (`PatchTSMixerConfig`, *required*): Configuration.
    """
    def __init__(self, config: PatchTSMixerConfig, distribution_output=None):
        super().__init__()

        self.head_aggregation = config.head_aggregation  # how to aggregate over the patch dimension
        self.output_range = config.output_range  # optional output range for regression

        if config.head_aggregation is None:
            # No aggregation: the projection sees all patches, so scale the input size accordingly
            mul_factor = config.num_patches
        else:
            mul_factor = 1
        self.distribution_output = distribution_output
        if distribution_output is None:
            # Plain linear projection to the targets
            self.projection = nn.Linear(
                config.d_model * config.num_input_channels * mul_factor,
                config.num_targets,
            )
        else:
            # Projection to the parameters of the output distribution
            self.projection = distribution_output.get_parameter_projection(
                config.d_model * config.num_input_channels * mul_factor
            )

        if config.head_aggregation is None:
            self.flatten = nn.Flatten(start_dim=-3)  # flatten channels, patches and features together
        else:
            self.flatten = nn.Flatten(start_dim=-2)  # flatten channels and features after aggregation

        self.dropout = nn.Dropout(config.head_dropout)  # dropout layer

    def forward(self, hidden_features):
        """
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
                or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size x num_targets)`.
        """

        # 调整hidden_features的维度顺序,将最后两个维度进行交换
        hidden_features = hidden_features.transpose(-1, -2)

        if self.head_aggregation == "use_last":
            # 如果头部聚合方式为"use_last",选择最后一个位置的特征
            hidden_features = hidden_features[..., -1]
        elif self.head_aggregation == "max_pool":
            # 如果头部聚合方式为"max_pool",对最后一个维度进行最大池化操作
            hidden_features = hidden_features.max(dim=-1).values
        elif self.head_aggregation == "avg_pool":
            # 如果头部聚合方式为"avg_pool",对最后一个维度进行平均池化操作
            hidden_features = hidden_features.mean(dim=-1)

        if self.flatten:
            hidden_features = self.flatten(hidden_features)  # 如果需要展平,则进行展平操作
        hidden_features = self.dropout(hidden_features)  # 对特征进行dropout处理
        hidden_features = self.projection(hidden_features)  # 使用投影层进行特征投影

        if (self.distribution_output is None) and (self.output_range is not None):
            # 如果分布输出为None且输出范围不为None,则对输出进行sigmoid归一化处理
            hidden_features = (
                torch.sigmoid(hidden_features) * (self.output_range[1] - self.output_range[0]) + self.output_range[0]
            )

        return hidden_features  # 返回处理后的特征
class PatchTSMixerPreTrainedModel(PreTrainedModel):
    # Weight initialization
    config_class = PatchTSMixerConfig  # configuration class
    base_model_prefix = "model"  # prefix of the base model
    main_input_name = "past_values"  # name of the main input
    supports_gradient_checkpointing = False  # gradient checkpointing is not supported

    def _init_weights(self, module):
        """Initialize weights"""
        if isinstance(module, PatchTSMixerPositionalEncoding):
            # initialize positional encoding
            if self.config.positional_encoding_type == "random":
                nn.init.normal_(module.position_enc, mean=0.0, std=0.1)
        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
            module.bias.data.zero_()  # zero-initialize the bias
            module.weight.data.fill_(1.0)  # initialize the weight to 1
        elif isinstance(module, PatchTSMixerBatchNorm):
            module.batchnorm.bias.data.zero_()  # zero-initialize the batch-norm bias
            module.batchnorm.weight.data.fill_(1.0)  # initialize the batch-norm weight to 1
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)  # normal initialization of the weight
            if module.bias is not None:
                module.bias.data.zero_()  # zero-initialize the bias if present

class PatchTSMixerPretrainHead(nn.Module):
    """Pretraining head.

    Args:
        config (`PatchTSMixerConfig`, *required*):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()

        self.dropout_layer = nn.Dropout(config.head_dropout)  # dropout with the configured probability
        self.base_pt_block = nn.Linear(config.d_model, config.patch_length)  # project d_model back to patch_length

    def forward(self, hidden_features):
        """
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
                or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size x n_vars x num_patch x patch_length)`.
        """

        hidden_features = self.dropout_layer(hidden_features)  # apply dropout
        forecast = self.base_pt_block(hidden_features)  # reconstruct each patch with the linear layer
        return forecast


# Copied from transformers.models.patchtst.modeling_patchtst.random_masking
def random_masking(
    inputs: torch.Tensor,
    mask_ratio: float,
    unmasked_channel_indices: list = None,
    channel_consistent_masking: bool = False,
    mask_value: int = 0,
):
    """random_masking: Mask the input considering the control variables.
    
    Args:
        inputs (torch.Tensor): Input tensor to be masked.
        mask_ratio (float): Ratio of elements to be masked.
        unmasked_channel_indices (list, optional): List of unmasked channel indices.
        channel_consistent_masking (bool, optional): Whether to mask consistently across channels.
        mask_value (int, optional): Value to fill in for masked elements.

    Returns:
        torch.Tensor: Masked input tensor.
    """
    # 检查掩码比例是否在有效范围内
    if mask_ratio < 0 or mask_ratio >= 1:
        raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.")

    # 获取输入张量的形状信息
    batch_size, num_channels, sequence_length, num_features = inputs.shape
    device = inputs.device

    # 计算不被掩码的数据长度
    len_keep = int(sequence_length * (1 - mask_ratio))

    # 根据channel_consistent_masking的设置生成随机噪声张量
    if channel_consistent_masking:
        noise = torch.rand(batch_size, 1, sequence_length, device=device)  # noise in [0, 1], bs x 1 x L
        noise = noise.repeat(1, num_channels, 1)  # bs x num_channels x time
    else:
        noise = torch.rand(batch_size, num_channels, sequence_length, device=device)  # noise in [0, 1], bs x num_channels x L

    # 创建掩码张量,初始化为全1
    mask = torch.ones(batch_size, num_channels, sequence_length, device=device)

    # 将前len_keep个位置置为0,即进行掩码操作
    mask[:, :, :len_keep] = 0

    # 对噪声进行排序,以便后续恢复掩码位置
    ids_shuffle = torch.argsort(noise, dim=-1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=-1)  # ids_restore: [bs x num_channels x L]

    # 根据排序后的索引恢复掩码的顺序
    mask = torch.gather(mask, dim=-1, index=ids_restore)

    # 将掩码张量的形状调整为与输入张量相同,且每个掩码值重复num_features次
    mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features)  # mask: [bs x num_channels x num_patches x patch_length]

    # 如果有指定不被掩码的通道,将这些通道的掩码值置为0
    if unmasked_channel_indices is not None:
        mask[:, unmasked_channel_indices, :, :] = 0

    # 使用掩码值进行输入张量的掩码操作,将掩码后的结果作为inputs_mask返回
    inputs_mask = inputs.masked_fill(mask.bool(), mask_value)

    # 返回掩码后的输入张量和掩码张量的第一个维度
    return inputs_mask, mask[..., 0]
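
A toy invocation of `random_masking` as defined above (the function must be in scope), with 2 samples, 3 channels and 4 patches of length 5; `mask_ratio=0.5` keeps `int(4 * 0.5) = 2` patches per channel:

```python
import torch

inputs = torch.randn(2, 3, 4, 5)  # (bs, num_channels, num_patches, patch_length)
masked, mask = random_masking(inputs, mask_ratio=0.5, channel_consistent_masking=True)
print(masked.shape)      # torch.Size([2, 3, 4, 5])
print(mask.shape)        # torch.Size([2, 3, 4]); 1.0 marks masked patches
print(mask.sum(dim=-1))  # tensor of 2.0s: half of the 4 patches are masked
```
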
# Copied from transformers.models.patchtst.modeling_patchtst.forecast_masking
def forecast_masking(
    inputs: torch.Tensor,
    num_forecast_mask_patches: Union[list, int],
    unmasked_channel_indices: list = None,
    mask_value: int = 0,
):
    """Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
    If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.

    Parameters:
        inputs (`torch.Tensor`):
            Input of shape `(bs, num_channels, num_patch, patch_length)`
        num_forecast_mask_patches (`list`):
            Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
        unmasked_channel_indices (`list`, *optional*):
            Indices of channels that are not masked.
        mask_value (`int`, *optional*, defaults to 0):
            Values in the masked patches will be filled by `mask_value`.

    Returns:
        `tuple(torch.Tensor)`: The masked input (same shape as `inputs`) and the mask tensor of shape `(bs,
        num_channels, num_patch)`.
    """

    # If num_forecast_mask_patches is an integer, convert it to a list for consistency
    if isinstance(num_forecast_mask_patches, int):
        num_forecast_mask_patches = [num_forecast_mask_patches]

    # Initialize forecast_mask_ratios with a list of 1s for each num_forecast_mask_patches
    forecast_mask_ratios = [1 for _ in num_forecast_mask_patches]

    # Extract dimensions from inputs tensor
    batch_size, num_channels, sequence_length, num_features = inputs.shape

    # Initialize mask tensor with zeros
    mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device)

    # Initialize an empty list to store temporary computations
    t_list = []
    total_length = 0
    total_ratio = sum(forecast_mask_ratios)

    # Iterate over num_forecast_mask_patches and forecast_mask_ratios to compute temporary lengths
    for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios):
        # Validate patch_length to ensure it is within valid range
        if patch_length <= 0 or patch_length >= sequence_length:
            raise ValueError(
                f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches."
            )
        # Compute temporary length based on batch size and ratio
        temp_len = int(batch_size * ratio / total_ratio)
        t_list.append([patch_length, ratio, temp_len])
        total_length += temp_len

    # Sort t_list based on the third element (temp_len)
    t_list = sorted(t_list, key=lambda x: x[2])

    # Adjust the last element in t_list to match batch size
    if total_length < batch_size:
        t_list[0][2] = t_list[0][2] + (batch_size - total_length)
    elif total_length > batch_size:
        t_list[-1][2] = t_list[-1][2] + (total_length - batch_size)

    # Initialize batch indices
    batch1 = 0

    # Iterate over t_list to populate mask tensor
    for patch_len, _, temp_len in t_list:
        batch2 = batch1 + temp_len
        mask[batch1:batch2, :, -patch_len:] = 1
        batch1 = batch2

    # Randomly permute the batch indices of mask tensor
    perm = torch.randperm(mask.shape[0])
    mask = mask[perm]

    # Expand mask tensor dimensions to match inputs tensor
    mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features)

    # If unmasked_channel_indices is provided, zero out corresponding channels in mask tensor
    if unmasked_channel_indices is not None:
        mask[:, unmasked_channel_indices, :, :] = 0

    # Apply masking to inputs tensor using mask tensor and return masked inputs and mask tensor
    inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
    return inputs_mask, mask[..., 0]
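
Similarly for `forecast_masking` as defined above: with a scalar `num_forecast_mask_patches`, every sample gets its last K patches masked.

```python
import torch

inputs = torch.ones(4, 2, 6, 3)  # (bs, num_channels, num_patch, patch_length)
masked, mask = forecast_masking(inputs, num_forecast_mask_patches=2)
print(mask[0, 0])                 # tensor([0., 0., 0., 0., 1., 1.]): last 2 patches masked
print(masked[0, 0, -1].tolist())  # [0.0, 0.0, 0.0]: masked values filled with mask_value=0
```
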
class PatchTSMixerPatchify(nn.Module):
    """
    A class to patchify the time series sequence into different patches

    Returns:
        `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()

        self.sequence_length = config.context_length  # input sequence length
        self.patch_length = config.patch_length  # length of each patch
        self.patch_stride = config.patch_stride  # stride between patches

        if self.sequence_length <= self.patch_length:
            raise ValueError(
                f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})"
            )

        # Number of patches produced by the sliding window
        self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1
        new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1)
        self.sequence_start = self.sequence_length - new_sequence_length  # offset of the first used time step

    def forward(self, past_values: torch.Tensor):
        """
        Parameters:
            past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
                Input for patchification

        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
        """
        sequence_length = past_values.shape[-2]  # length of the input sequence
        if sequence_length != self.sequence_length:
            raise ValueError(
                f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})."
            )
        # output: [bs x new_sequence_length x num_channels]
        output = past_values[:, self.sequence_start :, :]  # drop the leading time steps that don't fit a patch
        # output: [bs x num_patches x num_input_channels x patch_length]
        output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride)
        # output: [bs x num_input_channels x num_patches x patch_length]
        output = output.transpose(-2, -3).contiguous()
        return output
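
A shape check of the same unfold logic outside the class, with the defaults (`context_length=32`, `patch_length=8`, `patch_stride=8`, hence 4 patches):

```python
import torch

past_values = torch.randn(2, 32, 3)  # (bs, sequence_length, num_channels)
# unfold produces (bs, num_patches, num_channels, patch_length)
patches = past_values.unfold(dimension=-2, size=8, step=8)
patches = patches.transpose(-2, -3).contiguous()  # (bs, num_channels, num_patches, patch_length)
print(patches.shape)  # torch.Size([2, 3, 4, 8])
```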


# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMasking with PatchTST->PatchTSMixer
class PatchTSMixerMasking(nn.Module):
    """
    Class to perform random or forecast masking.

    Parameters:
        config (`PatchTSMixerConfig`): model config
    Returns:
        x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
            Masked patched input
        mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
            Bool tensor indicating True on masked points
    """
    # Initializer taking a configuration object
    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        self.random_mask_ratio = config.random_mask_ratio  # ratio used for random masking
        self.channel_consistent_masking = config.channel_consistent_masking  # share the mask across channels
        self.mask_type = config.mask_type  # "random" or "forecast"
        self.num_forecast_mask_patches = config.num_forecast_mask_patches  # patches masked at the end for forecasting
        self.unmasked_channel_indices = config.unmasked_channel_indices  # channels that are never masked
        self.mask_value = config.mask_value  # fill value for masked positions
        # If given, keep the unmasked channel indices sorted
        if self.unmasked_channel_indices is not None:
            self.unmasked_channel_indices = sorted(self.unmasked_channel_indices)

    def forward(self, patch_input: torch.Tensor):
        """
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input

        Return:
            masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
                Masked patched input
            mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
                Bool tensor indicating True on masked points

        """
        # Dispatch to the configured masking strategy
        if self.mask_type == "random":
            # Random masking: mask a random subset of patches
            masked_input, mask = random_masking(
                inputs=patch_input,
                mask_ratio=self.random_mask_ratio,
                unmasked_channel_indices=self.unmasked_channel_indices,
                channel_consistent_masking=self.channel_consistent_masking,
                mask_value=self.mask_value,
            )
        elif self.mask_type == "forecast":
            # Forecast masking: mask the last patches of each sample
            masked_input, mask = forecast_masking(
                inputs=patch_input,
                num_forecast_mask_patches=self.num_forecast_mask_patches,
                unmasked_channel_indices=self.unmasked_channel_indices,
                mask_value=self.mask_value,
            )
        else:
            # Invalid masking type
            raise ValueError(f"Invalid mask type {self.mask_type}.")

        # Convert the mask to a boolean tensor and return it with the masked input
        mask = mask.bool()
        return masked_input, mask
# 从 transformers.models.patchtst.modeling_patchtst.PatchTSTStdScaler 复制的代码,将 PatchTST 替换为 PatchTSMixer
class PatchTSMixerStdScaler(nn.Module):
    """
    标准化特征,通过计算均值并沿第一个维度进行缩放,然后通过减去均值并除以标准差进行归一化。
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        # 如果 config 中有 scaling_dim 属性,则使用其值作为 dim;否则默认为 1
        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
        # 如果 config 中有 keepdim 属性,则使用其值作为 keepdim;否则默认为 True
        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
        # 如果 config 中有 minimum_scale 属性,则使用其值作为 minimum_scale;否则默认为 1e-5
        self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,
                 `(batch_size, 1, num_input_channels)`,
                 `(batch_size, 1, num_input_channels)`)
        """
        # Denominator: number of observed entries along the reduction dimension
        denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim)
        # Clamp the denominator to at least 1.0 to avoid division by zero
        denominator = denominator.clamp_min(1.0)
        # Mean over observed entries: sum of (data * indicator) divided by the denominator
        loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator

        # Variance over observed entries
        variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
        # Standard deviation, with `minimum_scale` added for numerical stability
        scale = torch.sqrt(variance + self.minimum_scale)
        # Return the standardized data together with the mean `loc` and std `scale`
        return (data - loc) / scale, loc, scale
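
A small numeric sketch (assuming the default `dim=1`, `keepdim=True`) of the masked mean/std computation above; the unobserved value is excluded:

import torch

data = torch.tensor([[[1.0], [2.0], [3.0], [100.0]]])    # (1, 4, 1)
observed = torch.tensor([[[1.0], [1.0], [1.0], [0.0]]])  # last step missing

denom = observed.sum(1, keepdim=True).clamp_min(1.0)     # 3 observed steps
loc = (data * observed).sum(1, keepdim=True) / denom     # mean = 2.0
var = (((data - loc) * observed) ** 2).sum(1, keepdim=True) / denom
scale = torch.sqrt(var + 1e-5)
print(loc.item(), scale.item())  # 2.0 and ~0.816 — the unobserved 100.0 is ignored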


# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMeanScaler with PatchTST->PatchTSMixer
class PatchTSMixerMeanScaler(nn.Module):
    """
    Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
    accordingly.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        # Use `config.scaling_dim` as the reduction dimension if present; default to 1
        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
        # Use `config.keepdim` if present; default to True
        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
        # Use `config.minimum_scale` if present; default to 1e-10
        self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10
        # Use `config.default_scale` if present; default to None
        self.default_scale = config.default_scale if hasattr(config, "default_scale") else None

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`, `(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        """
        # Sum of absolute values over observed entries, per channel
        ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
        # Number of observed entries per channel
        num_observed = observed_indicator.sum(self.dim, keepdim=True)

        # Per-channel scale: mean absolute value over observed entries
        scale = ts_sum / torch.clamp(num_observed, min=1)

        # If `default_scale` is provided, use it; otherwise fall back to the batch-level scale
        if self.default_scale is None:
            # Sum of absolute values over the whole batch
            batch_sum = ts_sum.sum(dim=0)
            # Number of observations over the whole batch, clamped to at least 1
            batch_observations = torch.clamp(num_observed.sum(0), min=1)
            # Batch-level default scale
            default_scale = torch.squeeze(batch_sum / batch_observations)
        else:
            # Initialize the default scale from the given `default_scale`
            default_scale = self.default_scale * torch.ones_like(scale)

        # Apply the default scale wherever no observations are available
        scale = torch.where(num_observed > 0, scale, default_scale)

        # Ensure the scale is at least `self.minimum_scale`
        scale = torch.clamp(scale, min=self.minimum_scale)
        # Scale the data
        scaled_data = data / scale

        # Squeeze the scale dimension if `keepdim` is False
        if not self.keepdim:
            scale = scale.squeeze(dim=self.dim)

        return scaled_data, torch.zeros_like(scale), scale
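
A minimal sketch (default `dim=1`) of mean-absolute scaling: each channel is divided by the mean |value| of its observed entries, while `loc` stays zero:

import torch

data = torch.tensor([[[2.0], [-4.0], [6.0]]])  # (1, 3, 1)
observed = torch.ones_like(data)

ts_sum = (data * observed).abs().sum(1, keepdim=True)  # 12.0
num_observed = observed.sum(1, keepdim=True)           # 3.0
scale = ts_sum / num_observed.clamp(min=1)             # 4.0
print((data / scale).flatten().tolist())               # [0.5, -1.0, 1.5]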
# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTNOPScaler with PatchTST->PatchTSMixer
class PatchTSMixerNOPScaler(nn.Module):
    """
    Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        # 设置缩放维度为配置中的 scaling_dim,如果不存在则默认为 1
        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
        # 设置是否保持维度的配置参数,默认为 True
        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        """
        # 计算沿着指定维度 dim 的数据均值,得到缩放因子 scale
        scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
        # 初始化位置信息为零向量,用于模型输出的位置参数
        loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
        # 返回原始数据、位置信息 loc 和缩放因子 scale
        return data, loc, scale


@dataclass
class PatchTSMixerEncoderOutput(ModelOutput):
    """
    Base class for `PatchTSMixerEncoderOutput`, with potential hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
            Hidden-state at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class PatchTSMixerEncoder(PatchTSMixerPreTrainedModel):
    """
    Encoder for PatchTSMixer which inputs patched time-series and outputs patched embeddings.

    Args:
        config (`PatchTSMixerConfig`, *required*):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__(config)

        self.use_return_dict = config.use_return_dict

        # Linear layer projecting each patch of length `patch_length` to `d_model` dimensions
        self.patcher = nn.Linear(config.patch_length, config.d_model)
        # Initialize the positional encoder only when positional encoding is enabled
        if config.use_positional_encoding:
            self.positional_encoder = PatchTSMixerPositionalEncoding(config=config)
        else:
            self.positional_encoder = None
        # MLP-Mixer encoder block
        self.mlp_mixer_encoder = PatchTSMixerBlock(config=config)

        # Run weight initialization if the `post_init` flag is set
        if config.post_init:
            self.post_init()

    @replace_return_docstrings(output_type=PatchTSMixerEncoderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        past_values: torch.Tensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, PatchTSMixerEncoderOutput]:
        r"""
        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
                Context values of the time series. For a pretraining task, this denotes the input time series to
                predict the masked portion. For a forecasting task, this denotes the history/past time series values.
                Similarly, for classification or regression tasks, it denotes the appropriate context values of the
                time series.

                For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series,
                it is greater than 1.

            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.

            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, n_vars, num_patches, d_model)`
        """

        # Determine the final return format based on `return_dict` or `self.use_return_dict`
        return_dict = return_dict if return_dict is not None else self.use_return_dict

        # Project each patch to the embedding space: [bs x num_patch x d_model]
        # For multivariate time series, patches will be [bs x n_vars x num_patch x d_model]
        patches = self.patcher(past_values)

        # Add positional encoding to the patches if a positional encoder is provided
        if self.positional_encoder is not None:
            patches = self.positional_encoder(patches)

        # Apply the MLP-Mixer encoder to obtain the last hidden state and potentially all hidden states
        last_hidden_state, hidden_states = self.mlp_mixer_encoder(patches, output_hidden_states=output_hidden_states)

        # If `return_dict` is False, return the outputs as a tuple
        if not return_dict:
            return tuple(
                v
                for v in [
                    last_hidden_state,
                    hidden_states,
                ]
            )

        # If `return_dict` is True, return the outputs wrapped in PatchTSMixerEncoderOutput
        return PatchTSMixerEncoderOutput(last_hidden_state=last_hidden_state, hidden_states=hidden_states)
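
An illustrative usage sketch (not from the source; the config values are arbitrary and assume a working `transformers` install with this model):

import torch
from transformers import PatchTSMixerConfig
from transformers.models.patchtsmixer.modeling_patchtsmixer import PatchTSMixerEncoder

config = PatchTSMixerConfig(
    context_length=32, patch_length=8, patch_stride=8, num_input_channels=3, d_model=16
)
encoder = PatchTSMixerEncoder(config)

# (batch_size, num_channels, num_patches, patch_length)
patched = torch.randn(2, 3, config.num_patches, config.patch_length)
out = encoder(patched, return_dict=True)
print(out.last_hidden_state.shape)  # torch.Size([2, 3, num_patches, 16])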
@dataclass
class PatchTSMixerModelOutput(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
            Hidden-state at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer.
        patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
            Patched input data to the model.
        mask: (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`,*optional*):
            Bool Tensor indicating True in masked patches and False otherwise.
        loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*):
            Gives the mean of the context window per channel. Used for revin denorm outside the model, if revin
            enabled.
        scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*):
            Gives the std dev of the context window per channel. Used for revin denorm outside the model, if revin
            enabled.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    patch_input: torch.FloatTensor = None
    mask: Optional[torch.FloatTensor] = None
    loc: Optional[torch.FloatTensor] = None
    scale: Optional[torch.FloatTensor] = None


@add_start_docstrings(
    "The PatchTSMixer Model for time-series forecasting.",
    PATCHTSMIXER_START_DOCSTRING,
)
class PatchTSMixerModel(PatchTSMixerPreTrainedModel):
    def __init__(self, config: PatchTSMixerConfig, mask_input: bool = False):
        super().__init__(config)

        # Whether outputs are returned as a dict by default
        self.use_return_dict = config.use_return_dict
        # Initialize the encoder
        self.encoder = PatchTSMixerEncoder(config)
        # Initialize the patching module
        self.patching = PatchTSMixerPatchify(config)

        # Initialize the masking module only when the input should be masked
        if mask_input is True:
            self.masking = PatchTSMixerMasking(config)
        else:
            self.masking = None

        # Choose the scaler (mean, std, or no-op) according to the config
        if config.scaling == "mean":
            self.scaler = PatchTSMixerMeanScaler(config)
        elif config.scaling == "std" or config.scaling is True:
            self.scaler = PatchTSMixerStdScaler(config)
        else:
            self.scaler = PatchTSMixerNOPScaler(config)

        # Run weight initialization if the config requests it
        if config.post_init:
            self.post_init()

    @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=PatchTSMixerModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        past_values: torch.Tensor,
        observed_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = None,
    ) -> PatchTSMixerModelOutput:
        r"""
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Returns:
            `PatchTSMixerModelOutput`: An object containing encoder outputs and other processed inputs.

        """
        # Determine if the return_dict should be used or not
        return_dict = return_dict if return_dict is not None else self.use_return_dict

        # Initialize mask to None
        mask = None
        # If observed_mask is not provided, initialize it as a tensor of ones with the same shape as past_values
        if observed_mask is None:
            observed_mask = torch.ones_like(past_values)
        
        # Scale the observed values using a scaler function, and get location and scale parameters
        scaled_past_values, loc, scale = self.scaler(past_values, observed_mask)

        # Patch the scaled past values using a patching function
        patched_x = self.patching(scaled_past_values)  # [batch_size x num_input_channels x num_patch x patch_length]

        # Prepare encoder input; apply masking if masking function is defined
        enc_input = patched_x
        if self.masking is not None:
            enc_input, mask = self.masking(patched_x)
            # enc_input: [batch_size x num_input_channels x num_patch x patch_length]
            # mask: [batch_size x num_input_channels x num_patch]

        # Pass the encoder input to the encoder module
        encoder_output = self.encoder(
            enc_input,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Ensure encoder_output is of type PatchTSMixerEncoderOutput if it is a tuple
        if isinstance(encoder_output, tuple):
            encoder_output = PatchTSMixerEncoderOutput(*encoder_output)

        # If return_dict is False, return a tuple of selected encoder outputs and inputs
        if not return_dict:
            return tuple(
                v
                for v in [
                    encoder_output.last_hidden_state,
                    encoder_output.hidden_states,
                    patched_x,
                    mask,
                    loc,
                    scale,
                ]
            )

        # If return_dict is True, return a PatchTSMixerModelOutput object with specified attributes
        return PatchTSMixerModelOutput(
            last_hidden_state=encoder_output.last_hidden_state,
            hidden_states=encoder_output.hidden_states,
            patch_input=patched_x,
            mask=mask,
            loc=loc,
            scale=scale,
        )
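
An illustrative sketch of running the backbone on a raw multivariate series (values arbitrary; shapes follow the docstrings above):

import torch
from transformers import PatchTSMixerConfig, PatchTSMixerModel

config = PatchTSMixerConfig(context_length=32, patch_length=8, num_input_channels=3)
model = PatchTSMixerModel(config)

past_values = torch.randn(2, config.context_length, config.num_input_channels)
output = model(past_values, return_dict=True)
print(output.last_hidden_state.shape)    # (2, 3, num_patches, d_model)
print(output.loc.shape, output.scale.shape)  # (2, 1, 3) each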
@dataclass
class PatchTSMixerForPreTrainingOutput(ModelOutput):
    """
    Output type of [`PatchTSMixerForPreTrainingOutput`].

    Args:
        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, patch_length)`):
            Prediction output from the pretrain head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Backbone embeddings before passing through the head.
        loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
            Total loss
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_outputs: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class PatchTSMixerForPretraining(PatchTSMixerPreTrainedModel):
    r"""
    `PatchTSMixer` for mask pretraining.

    Args:
        config (`PatchTSMixerConfig`, *required*):
            Configuration.

    Returns:
        `None`.
    """

    def __init__(self, config: PatchTSMixerConfig):
        # Initialize via the parent constructor
        super().__init__(config)
        # Backbone with input masking enabled
        self.model = PatchTSMixerModel(config, mask_input=True)
        # Pretraining head
        self.head = PatchTSMixerPretrainHead(config=config)
        # Whether the loss is computed only on masked patches
        self.masked_loss = config.masked_loss
        # Whether outputs are returned as a dict by default
        self.use_return_dict = config.use_return_dict

        # Run weight initialization if the config requests it
        if config.post_init:
            self.post_init()

    @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=PatchTSMixerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        past_values: torch.Tensor,
        observed_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ) -> PatchTSMixerForPreTrainingOutput:
        r"""
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        return_loss (`bool`,  *optional*):
            Whether to return the loss in the `forward` call.

        Returns:
            PatchTSMixerForPreTrainingOutput: An instance of the output class containing various outputs based on the model's forward pass.
        """
        # Determine whether to use the provided return_dict or the default one from the class
        return_dict = return_dict if return_dict is not None else self.use_return_dict

        # Define the type of loss function based on whether masked loss is enabled
        if self.masked_loss is True:
            loss = torch.nn.MSELoss(reduction="none")
        else:
            loss = torch.nn.MSELoss(reduction="mean")

        # Perform forward pass through the model with specified arguments
        model_output = self.model(
            past_values,
            observed_mask=observed_mask,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )  # x.last_hidden_state: [batch_size x nvars x num_patch x d_model]

        # Ensure model_output is of type PatchTSMixerModelOutput
        if isinstance(model_output, tuple):
            model_output = PatchTSMixerModelOutput(*model_output)

        # Generate predictions using the head module
        x_hat = self.head(model_output.last_hidden_state)  # tensor [batch_size x nvars x num_patch x patch_length]

        # Compute loss if return_loss flag is set to True
        if return_loss is True:
            loss_val = loss(x_hat, model_output.patch_input)
        else:
            loss_val = None

        # Calculate masked loss if enabled and loss_val is not None
        if self.masked_loss is True and loss_val is not None:
            loss_val = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10)

        # Return outputs based on whether return_dict is False
        if not return_dict:
            return tuple(
                v
                for v in [
                    loss_val,
                    x_hat,
                    model_output.last_hidden_state,
                    model_output.hidden_states,
                ]
            )

        # Return outputs wrapped in PatchTSMixerForPreTrainingOutput object if return_dict is True
        return PatchTSMixerForPreTrainingOutput(
            loss=loss_val,
            prediction_outputs=x_hat,  # tensor [batch_size x nvars x num_patch x patch_length]
            last_hidden_state=model_output.last_hidden_state,  # x: [batch_size x nvars x num_patch x d_model]
            hidden_states=model_output.hidden_states,
        )
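
The masked-loss branch above can be summarized with a short standalone sketch (arbitrary shapes): per-element MSE is averaged within each patch, then only masked patches contribute to the final loss:

import torch

x_hat = torch.randn(2, 3, 4, 8)   # predictions (bs, n_vars, patches, patch_len)
target = torch.randn(2, 3, 4, 8)  # patch_input
mask = (torch.rand(2, 3, 4) < 0.5).float()

per_elem = torch.nn.MSELoss(reduction="none")(x_hat, target)
loss = (per_elem.mean(dim=-1) * mask).sum() / (mask.sum() + 1e-10)
print(loss)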
@dataclass
class PatchTSMixerForPredictionOutput(ModelOutput):
    """
    Output type of [`PatchTSMixerForPredictionOutput`].

    Args:
        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`):
            Prediction output from the forecast head.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Backbone embeddings before passing through the head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
            Total loss.
        loc (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
            Input mean
        scale (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
            Input std dev

    """

    loss: Optional[torch.FloatTensor] = None  # optional total loss
    prediction_outputs: torch.FloatTensor = None  # prediction output from the forecast head
    last_hidden_state: torch.FloatTensor = None  # backbone embeddings before the head
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None  # per-layer hidden states
    loc: torch.FloatTensor = None  # input mean
    scale: torch.FloatTensor = None  # input std dev


@dataclass
class SamplePatchTSMixerPredictionOutput(ModelOutput):
    """
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.

    Args:
        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, number_channels)`):
            Sampled values from the chosen distribution.
    """

    sequences: torch.FloatTensor = None  # values sampled from the chosen distribution


@dataclass
class SamplePatchTSMixerRegressionOutput(ModelOutput):
    """
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.

    Args:
        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, num_targets)`):
                Sampled values from the chosen distribution.
    """

    sequences: torch.FloatTensor = None  # values sampled from the chosen distribution


# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll
def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
    """
    Computes the negative log likelihood loss from input distribution with respect to target.
    """
    return -input.log_prob(target)  # negative log likelihood of the target under the input distribution


# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average
def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
    """
    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
    """
    # If weights are provided, mask zero-weight positions with `torch.where` so
    # that `nan * 0` becomes `0 * 0 = 0` instead of propagating NaNs
    if weights is not None:
        weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
        # Total weight along `dim` (or over the whole tensor), clamped to at least 1.0
        sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
        # Weighted sum divided by the total weight
        return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
    else:
        # Without weights, fall back to a plain mean along `dim`
        return input_tensor.mean(dim=dim)
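
A quick check of the behavior described in the docstring, using the function above (illustrative):

import torch

x = torch.tensor([1.0, float("nan"), 3.0])
w = torch.tensor([1.0, 0.0, 1.0])
print(weighted_average(x, weights=w))  # tensor(2.) — the NaN is masked out, not propagated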
class PatchTSMixerForPrediction(PatchTSMixerPreTrainedModel):
    r"""
    `PatchTSMixer` for forecasting application.

    Args:
        config (`PatchTSMixerConfig`, *required*):
            Configuration.

    Returns:
        `None`.
    """

    def __init__(self, config: PatchTSMixerConfig):
        # Initialize via the parent constructor
        super().__init__(config)
        # Loss type from the config ("mse" or "nll")
        self.loss = config.loss
        # Whether outputs are returned as a dict by default
        self.use_return_dict = config.use_return_dict
        # Indices of the channels to forecast
        self.prediction_channel_indices = config.prediction_channel_indices
        # Number of parallel samples for probabilistic generation
        self.num_parallel_samples = config.num_parallel_samples

        # Choose the distribution output according to the loss type
        if config.loss == "mse":
            self.distribution_output = None
        else:
            dim = config.prediction_length
            distribution_output_map = {
                "student_t": StudentTOutput,
                "normal": NormalOutput,
                "negative_binomial": NegativeBinomialOutput,
            }
            # Look up the output class matching `config.distribution_output`
            output_class = distribution_output_map.get(config.distribution_output, None)
            if output_class is not None:
                self.distribution_output = output_class(dim=dim)
            else:
                raise ValueError(f"Unknown distribution output {config.distribution_output}")

        # Backbone model
        self.model = PatchTSMixerModel(config)
        # Forecasting head
        self.head = PatchTSMixerForPredictionHead(
            config=config,
            distribution_output=self.distribution_output,
        )

        # Run weight initialization if the config requests it
        if config.post_init:
            self.post_init()

    @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=PatchTSMixerForPredictionOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        past_values: torch.Tensor,
        observed_mask: Optional[torch.Tensor] = None,
        future_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ):
        # Forward pass; body omitted in this excerpt
        ...

    def generate(
        self,
        past_values: torch.Tensor,
        observed_mask: Optional[torch.Tensor] = None,
    ) -> SamplePatchTSMixerPredictionOutput:
        """
        Generate sequences of sample predictions from a model with a probability distribution head.

        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.

            observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSMixerPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
            number of samples, prediction_length, num_input_channels)`.
        """
        # Number of parallel samples to draw
        num_parallel_samples = self.num_parallel_samples

        # Get the model output
        outputs = self(
            past_values=past_values,
            future_values=None,
            observed_mask=observed_mask,
            output_hidden_states=False,
        )

        # Build the output distribution from the head's parameters
        distribution = self.distribution_output.distribution(
            outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
        )

        # Draw samples: list of tensors, each [batch_size x prediction_length x num_channels]
        samples = [distribution.sample() for _ in range(num_parallel_samples)]

        # Stack along a new sample dimension
        samples = torch.stack(samples, dim=1)  # [batch_size x num_samples x prediction_length x num_channels]
        return SamplePatchTSMixerPredictionOutput(sequences=samples)
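
An illustrative sketch of probabilistic forecasting (assumes `loss="nll"` so that a distribution head is built, per the `__init__` branches above; all other values are arbitrary):

import torch
from transformers import PatchTSMixerConfig, PatchTSMixerForPrediction

config = PatchTSMixerConfig(
    context_length=32, patch_length=8, num_input_channels=3,
    prediction_length=16, loss="nll", distribution_output="student_t",
    num_parallel_samples=20,
)
model = PatchTSMixerForPrediction(config)

past_values = torch.randn(2, config.context_length, config.num_input_channels)
samples = model.generate(past_values=past_values)
print(samples.sequences.shape)  # (2, 20, 16, 3)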
@dataclass
class PatchTSMixerForTimeSeriesClassificationOutput(ModelOutput):
    """
    Output type of [`PatchTSMixerForTimeSeriesClassificationOutput`].

    Args:
        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
            Prediction output from the classification head.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Backbone embeddings before passing through the head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
            Total loss.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_outputs: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
    r"""
    `PatchTSMixer` for classification application.

    Args:
        config (`PatchTSMixerConfig`, *required*):
            Configuration.

    Returns:
        `None`.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__(config)

        # Initialize the main model backbone
        self.model = PatchTSMixerModel(config)

        # Initialize the classification head
        self.head = PatchTSMixerLinearHead(
            config=config,
        )

        # Determine if statistical scaling should be applied
        self.use_return_dict = config.use_return_dict
        if config.scaling in ["std", "mean", True]:
            self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches)
        else:
            self.inject_scale = None

        # Apply post-initialization steps if specified in the configuration
        if config.post_init:
            self.post_init()

    @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=PatchTSMixerForTimeSeriesClassificationOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        past_values: torch.Tensor,
        target_values: torch.Tensor = None,
        output_hidden_states: Optional[bool] = False,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ):
        """
        Perform forward pass of the PatchTSMixerForTimeSeriesClassification model.

        Args:
            past_values (`torch.Tensor`):
                Tensor of past input values.
            target_values (`torch.Tensor`, *optional*):
                Tensor of target values for training.
            output_hidden_states (`bool`, *optional*):
                Whether to output hidden states.
            return_loss (`bool`, *optional*):
                Whether to return the loss.
            return_dict (`bool`, *optional*):
                Whether to return a dictionary of outputs.

        Returns:
            Depending on the `return_dict` setting, returns either a dictionary of outputs or directly the outputs.
        """
        # Forward pass through the main model backbone
        # and the classification head
        pass  # Actual computation details are omitted for brevity
        r"""
        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
            values of the time series, that serve as labels for the model. The `target_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
            required for a pretraining task.

            For a forecasting task, the shape is `(batch_size, target_len, num_input_channels)`. Even if we want
            to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
            pass the target data with all channels, as channel filtering for both prediction and target will be
            manually applied before the loss computation.

            For a classification task, it has a shape of `(batch_size,)`.

            For a regression task, it has a shape of `(batch_size, num_targets)`.
        return_loss (`bool`, *optional*):
            Whether to return the loss in the `forward` call.

        Returns:

        """

        # Cross-entropy loss for classification
        loss = torch.nn.CrossEntropyLoss()

        # Use the provided `return_dict` or fall back to the class default
        return_dict = return_dict if return_dict is not None else self.use_return_dict

        # Forward pass through the backbone
        model_output = self.model(
            past_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )  # x: [batch_size x nvars x num_patch x d_model]

        # Convert a tuple output into a PatchTSMixerModelOutput
        if isinstance(model_output, tuple):
            model_output = PatchTSMixerModelOutput(*model_output)

        # If scale injection is enabled, apply it to the last hidden state
        if self.inject_scale is not None:
            model_output.last_hidden_state = self.inject_scale(
                model_output.last_hidden_state,
                loc=model_output.loc,
                scale=model_output.scale,
            )  # x: [batch_size x nvars x num_patch x d_model]

        # Obtain predictions from the classification head
        y_hat = self.head(model_output.last_hidden_state)  # tensor [batch_size x n_labels]

        # Compute the cross-entropy loss when targets are given and requested
        if target_values is not None and return_loss is True:
            loss_val = loss(y_hat, target_values)
        else:
            loss_val = None

        # If a dict output is not requested, return a tuple
        if not return_dict:
            return tuple(
                v
                for v in [
                    loss_val,
                    y_hat,
                    model_output.last_hidden_state,
                    model_output.hidden_states,
                ]
            )

        # Otherwise, return a PatchTSMixerForTimeSeriesClassificationOutput
        return PatchTSMixerForTimeSeriesClassificationOutput(
            loss=loss_val,
            prediction_outputs=y_hat,  # tensor [batch_size x n_labels]
            last_hidden_state=model_output.last_hidden_state,  # x: [batch_size x nvars x num_patch x d_model]
            hidden_states=model_output.hidden_states,
        )
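
An illustrative sketch of classifying whole series (assumption: `num_targets` serves as the number of classes for the linear head here; values otherwise arbitrary):

import torch
from transformers import PatchTSMixerConfig, PatchTSMixerForTimeSeriesClassification

config = PatchTSMixerConfig(
    context_length=32, patch_length=8, num_input_channels=3, num_targets=5
)
model = PatchTSMixerForTimeSeriesClassification(config)

past_values = torch.randn(2, config.context_length, config.num_input_channels)
out = model(past_values, return_dict=True)
print(out.prediction_outputs.shape)  # (2, 5) — one logit per class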
@dataclass
class PatchTSMixerForRegressionOutput(ModelOutput):
    """
    Output type of [`PatchTSMixerForRegressionOutput`].

    Args:
        regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
            Prediction output from the regression head.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Backbone embeddings before passing through the head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
            Total loss.
    """

    loss: Optional[torch.FloatTensor] = None  # optional total loss, returned when `y` is provided
    regression_outputs: torch.FloatTensor = None  # prediction output from the regression head
    last_hidden_state: torch.FloatTensor = None  # backbone embeddings before the head
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None  # per-layer hidden states

class InjectScalerStatistics4D(nn.Module):
    def __init__(self, d_model: int, num_patches: int, expansion: int = 2):
        super().__init__()

        # Expansion layer: maps `d_model + 2` features to `expansion * d_model`
        self.inverse_trans_expansion = nn.Linear(d_model + 2, expansion * d_model)
        # Compression layer: maps `expansion * d_model` features back to `d_model`
        self.inverse_trans_compression = nn.Linear(expansion * d_model, d_model)
        # Expansion layer for the (mean, std) statistics: 2 -> 2 * expansion
        self.map_scale_expansion = nn.Linear(2, 2 * expansion)
        # Compression layer for the statistics: 2 * expansion -> 2
        self.map_scale_compression = nn.Linear(2 * expansion, 2)
        # Number of patches the statistics are broadcast over
        self.num_patches = num_patches
    def forward(self, inputs: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
        """
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`)
            loc (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
            scale (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
        Returns:
            `torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`
        """

        mean = loc.transpose(-1, -2)  # swap the last two dims: [batch_size x n_channels x 1]
        mean = mean.unsqueeze(-2)  # add a patch dim: [batch_size x n_channels x 1 x 1]
        mean = mean.repeat(1, 1, self.num_patches, 1)  # broadcast over patches: [batch_size x n_channels x num_patch x 1]

        stdev = scale.transpose(-1, -2)  # swap the last two dims: [batch_size x n_channels x 1]
        stdev = stdev.unsqueeze(-2)  # add a patch dim: [batch_size x n_channels x 1 x 1]
        stdev = stdev.repeat(1, 1, self.num_patches, 1)  # broadcast over patches: [batch_size x n_channels x num_patch x 1]

        concat_stats = torch.cat([mean, stdev], dim=-1)  # [batch_size x n_channels x num_patch x 2]

        concat_stats = self.map_scale_expansion(concat_stats)  # [batch_size x n_channels x num_patch x (2*expansion)]
        concat_stats = self.map_scale_compression(concat_stats)  # [batch_size x n_channels x num_patch x 2]

        inputs = torch.cat([inputs, concat_stats], dim=-1)  # [batch_size x channels x num_patch x d_model+2]
        inputs = self.inverse_trans_expansion(inputs)  # [batch_size x channels x num_patch x (expansion*d_model)]
        inputs = self.inverse_trans_compression(inputs)  # [batch_size x channels x num_patch x d_model]

        return inputs
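
A quick shape walkthrough (illustrative; arbitrary sizes) of the module above — the statistics are mixed in and `d_model` is preserved:

import torch

inject = InjectScalerStatistics4D(d_model=16, num_patches=4)
hidden = torch.randn(2, 3, 4, 16)  # (bs, channels, patches, d_model)
loc = torch.zeros(2, 1, 3)         # per-channel mean
scale = torch.ones(2, 1, 3)        # per-channel std
out = inject(hidden, loc=loc, scale=scale)
print(out.shape)                   # torch.Size([2, 3, 4, 16])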
class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
    r"""
    `PatchTSMixer` for regression application.

    Args:
        config (`PatchTSMixerConfig`, *required*):
            Configuration.

    Returns:
        `None`.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__(config)

        # Backbone model
        self.model = PatchTSMixerModel(config)

        # Loss type and distribution output from the config
        self.loss = config.loss
        self.distribution_output = config.distribution_output

        # Whether outputs are returned as a dict by default
        self.use_return_dict = config.use_return_dict
        # Number of parallel samples for probabilistic generation
        self.num_parallel_samples = config.num_parallel_samples

        # Choose the distribution output class according to the loss type
        if config.loss == "mse":
            self.distribution_output = None
        else:
            distribution_output_map = {
                "student_t": StudentTOutput,
                "normal": NormalOutput,
                "negative_binomial": NegativeBinomialOutput,
            }
            output_class = distribution_output_map.get(config.distribution_output)
            if output_class is not None:
                self.distribution_output = output_class(dim=config.num_targets)
            else:
                raise ValueError(f"Unknown distribution output {config.distribution_output}")

        # Inject scale statistics only when scaling is enabled
        if config.scaling in ["std", "mean", True]:
            self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches)
        else:
            self.inject_scale = None

        # Linear regression head
        self.head = PatchTSMixerLinearHead(
            config=config,
            distribution_output=self.distribution_output,
        )

        # Run weight initialization if the config requests it
        if config.post_init:
            self.post_init()

    @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=PatchTSMixerForRegressionOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        past_values: torch.Tensor,
        target_values: torch.Tensor = None,
        output_hidden_states: Optional[bool] = False,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass of the model.

        Args:
            past_values (torch.Tensor):
                Past values serving as the model input.
            target_values (torch.Tensor, optional):
                Target values used to compute the loss. Defaults to None.
            output_hidden_states (bool, optional):
                Whether to output hidden states. Defaults to False.
            return_loss (bool):
                Whether to return the loss. Defaults to True.
            return_dict (bool, optional):
                Whether to return a dict-style output. Defaults to None.

        Returns:
            Output format determined by the `return_dict` argument.
        """
        # Body omitted in this excerpt
        ...

    def generate(
        self,
        past_values: torch.Tensor,
    ) -> SamplePatchTSMixerRegressionOutput:
        """
        Generate sequences of sample predictions from a model with a probability distribution head.

        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the target values.

        Return:
            [`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
            number of samples, num_targets)`.
        """
        # Number of parallel samples to draw
        num_parallel_samples = self.num_parallel_samples

        # Get the model output
        outputs = self(
            past_values=past_values,
            target_values=None,
            output_hidden_states=False,
        )

        # Build the output distribution from the regression head's parameters
        distribution = self.distribution_output.distribution(outputs.regression_outputs)

        # Draw samples
        samples = [
            distribution.sample() for _ in range(num_parallel_samples)
        ]  # samples: list of [batch_size x num_targets]
        # Stack along a new sample dimension
        # [batch_size x num_samples x num_targets]
        samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
        return SamplePatchTSMixerRegressionOutput(sequences=samples)

.\models\patchtsmixer\__init__.py

#
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# NOTE: this __init__ module defines the package's import structure and metadata, and
# gates torch-dependent components on whether torch is available (lazy loading).

# Imports needed for the lazy-loading machinery
from typing import TYPE_CHECKING

# Utilities for optional-dependency checks and lazy module loading
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# Import structure of the package: module name -> list of exported names
_import_structure = {
    "configuration_patchtsmixer": [
        "PATCHTSMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP",  # map of pretrained config archive URLs
        "PatchTSMixerConfig",  # the model configuration class
    ]
}

try:
    # Raise if torch is not available
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Without torch, the modeling module is simply not registered
    pass
else:
    # Torch is available: register the modeling classes in the import structure
    _import_structure["modeling_patchtsmixer"] = [
        "PATCHTSMIXER_PRETRAINED_MODEL_ARCHIVE_LIST",  # list of pretrained checkpoints
        "PatchTSMixerPreTrainedModel",  # base class for pretrained PatchTSMixer models
        "PatchTSMixerModel",  # the main backbone model
        "PatchTSMixerForPretraining",  # model for mask pretraining
        "PatchTSMixerForPrediction",  # model for forecasting
        "PatchTSMixerForTimeSeriesClassification",  # model for time-series classification
        "PatchTSMixerForRegression",  # model for regression
    ]

# Static imports for type checkers and IDEs
if TYPE_CHECKING:
    from .configuration_patchtsmixer import (
        PATCHTSMIXER_PRETRAINED_CONFIG_ARCHIVE_MAP,
        PatchTSMixerConfig,
    )

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Torch is available: import the modeling classes for type checking
        from .modeling_patchtsmixer import (
            PATCHTSMIXER_PRETRAINED_MODEL_ARCHIVE_LIST,
            PatchTSMixerForPrediction,
            PatchTSMixerForPretraining,
            PatchTSMixerForRegression,
            PatchTSMixerForTimeSeriesClassification,
            PatchTSMixerModel,
            PatchTSMixerPreTrainedModel,
        )

else:
    # At runtime, replace this module with a lazy module: names listed in
    # `_import_structure` are resolved only on first attribute access.
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
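
The practical effect of the lazy module, illustrated (importing the subpackage is cheap; classes resolve on first access):

from transformers.models.patchtsmixer import PatchTSMixerConfig  # resolved lazily

config = PatchTSMixerConfig()
print(type(config).__name__)  # PatchTSMixerConfig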

.\models\patchtst\configuration_patchtst.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PatchTST model configuration"""

from typing import List, Optional, Union

# Required configuration base class and logging utilities
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

# Module-level logger
logger = logging.get_logger(__name__)

# Map from pretrained model names to their config file URLs
PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "ibm/patchtst-base": "https://huggingface.co/ibm/patchtst-base/resolve/main/config.json",
    # See all PatchTST models at https://huggingface.co/ibm/models?filter=patchtst
}


class PatchTSTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PatchTSTModel`]. It is used to instantiate a
    PatchTST model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the PatchTST
    [ibm/patchtst](https://huggingface.co/ibm/patchtst) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    ```
    >>> from transformers import PatchTSTConfig, PatchTSTModel

    >>> # Initializing an PatchTST configuration with 12 time steps for prediction
    >>> configuration = PatchTSTConfig(prediction_length=12)

    >>> # Randomly initializing a model (with random weights) from the configuration
    >>> model = PatchTSTModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    # Model type identifier
    model_type = "patchtst"
    # Maps standard config attribute names to the names used by this model
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "num_attention_heads",
        "num_hidden_layers": "num_hidden_layers",
    }
    # Initializes the time-series-specific configuration and the Transformer parameters
    def __init__(
        self,
        # Number of input channels, i.e. variables of the time series (default 1)
        num_input_channels: int = 1,
        # Context length: number of time steps the model sees at once (default 32)
        context_length: int = 32,
        # Type of output distribution (default "student_t")
        distribution_output: str = "student_t",
        # Training loss (default "mse")
        loss: str = "mse",
        # PatchTST parameters
        patch_length: int = 1,
        patch_stride: int = 1,
        # Transformer architecture configuration
        num_hidden_layers: int = 3,
        d_model: int = 128,
        num_attention_heads: int = 4,
        share_embedding: bool = True,
        channel_attention: bool = False,
        ffn_dim: int = 512,
        norm_type: str = "batchnorm",
        norm_eps: float = 1e-05,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
        positional_dropout: float = 0.0,
        path_dropout: float = 0.0,
        ff_dropout: float = 0.0,
        bias: bool = True,
        activation_function: str = "gelu",
        pre_norm: bool = True,
        positional_encoding_type: str = "sincos",
        use_cls_token: bool = False,
        init_std: float = 0.02,
        share_projection: bool = True,
        scaling: Optional[Union[str, bool]] = "std",
        # Mask-pretraining parameters
        do_mask_input: Optional[bool] = None,
        mask_type: str = "random",
        random_mask_ratio: float = 0.5,
        num_forecast_mask_patches: Optional[Union[List[int], int]] = [2],
        channel_consistent_masking: Optional[bool] = False,
        unmasked_channel_indices: Optional[List[int]] = None,
        mask_value: int = 0,
        # Head parameters
        pooling_type: str = "mean",
        head_dropout: float = 0.0,
        prediction_length: int = 24,
        num_targets: int = 1,
        output_range: Optional[List] = None,
        # Distribution-head parameters
        num_parallel_samples: int = 100,
        **kwargs,
    ):
        # Time-series-specific configuration
        self.context_length = context_length
        self.num_input_channels = num_input_channels  # n_vars
        self.loss = loss
        self.distribution_output = distribution_output
        self.num_parallel_samples = num_parallel_samples

        # Transformer architecture configuration
        self.d_model = d_model
        self.num_attention_heads = num_attention_heads
        self.ffn_dim = ffn_dim
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.share_embedding = share_embedding
        self.channel_attention = channel_attention
        self.norm_type = norm_type
        self.norm_eps = norm_eps
        self.positional_dropout = positional_dropout
        self.path_dropout = path_dropout
        self.ff_dropout = ff_dropout
        self.bias = bias
        self.activation_function = activation_function
        self.pre_norm = pre_norm
        self.positional_encoding_type = positional_encoding_type
        self.use_cls_token = use_cls_token
        self.init_std = init_std
        self.scaling = scaling

        # PatchTST parameters
        self.patch_length = patch_length
        self.patch_stride = patch_stride

        # Mask pretraining
        self.do_mask_input = do_mask_input
        self.mask_type = mask_type
        self.random_mask_ratio = random_mask_ratio  # for random masking
        self.num_forecast_mask_patches = num_forecast_mask_patches  # for forecast masking
        self.channel_consistent_masking = channel_consistent_masking
        self.unmasked_channel_indices = unmasked_channel_indices
        self.mask_value = mask_value

        # General head parameters
        self.pooling_type = pooling_type
        self.head_dropout = head_dropout

        # For the prediction head
        self.share_projection = share_projection
        self.prediction_length = prediction_length

        # For the prediction and regression heads (set again for clarity)
        self.num_parallel_samples = num_parallel_samples

        # Regression
        self.num_targets = num_targets
        self.output_range = output_range

        super().__init__(**kwargs)
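
An illustrative sketch of building a config for mask-based pretraining with the parameters above (values arbitrary):

from transformers import PatchTSTConfig

config = PatchTSTConfig(
    num_input_channels=7,
    context_length=96,
    patch_length=16,
    patch_stride=16,
    do_mask_input=True,
    mask_type="random",
    random_mask_ratio=0.4,
)
print(config.mask_type, config.random_mask_ratio)  # random 0.4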

.\models\patchtst\modeling_patchtst.py

# coding=utf-8
# Copyright 2023 IBM & Hugging Face. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch PatchTST model."""

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn

from ...activations import ACT2CLS
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import ModelOutput, add_start_docstrings, logging
from .configuration_patchtst import PatchTSTConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "PatchTSTConfig"

PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "ibm/patchtst-etth1-pretrain",
    # See all PatchTST models at https://huggingface.co/models?filter=patchtst
]


# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTST
class PatchTSTAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[PatchTSTConfig] = None,
    ):
        super().__init__()
        # 初始化注意力机制模块
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads  # 计算每个头的维度
        self.config = config

        # 检查embed_dim必须能被num_heads整除
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 缩放因子
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # 线性变换层,用于计算查询、键、值、输出
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # 将张量重塑成适合多头注意力计算的形状
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
    def forward(
        self,
        # Hidden states from the previous layer
        hidden_states: torch.Tensor,
        # Key/value states for cross-attention (optional)
        key_value_states: Optional[torch.Tensor] = None,
        # Cached key/value states from previous steps (optional)
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        # Attention mask controlling which positions may attend to which (optional)
        attention_mask: Optional[torch.Tensor] = None,
        # Per-head mask for this layer (optional)
        layer_head_mask: Optional[torch.Tensor] = None,
        # Whether to return the attention weights (default False)
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Attention body omitted in this excerpt
        ...

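The reshaping done by `_shape` above, illustrated standalone (arbitrary sizes):

import torch

bsz, seq_len, embed_dim, num_heads = 2, 5, 16, 4
head_dim = embed_dim // num_heads
x = torch.randn(bsz, seq_len, embed_dim)
heads = x.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()
print(heads.shape)  # torch.Size([2, 4, 5, 4]) — (batch, heads, seq, head_dim)
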
class PatchTSTBatchNorm(nn.Module):
    """
    Compute batch normalization over the sequence length (time) dimension.
    """

    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        # BatchNorm1d over the d_model features, with eps=config.norm_eps
        self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps)

    def forward(self, inputs: torch.Tensor):
        """
        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                input for Batch norm calculation
        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
        """
        # transpose to (batch_size, d_model, sequence_length) for BatchNorm1d
        output = inputs.transpose(1, 2)
        # apply batch normalization
        output = self.batchnorm(output)
        # transpose back to (batch_size, sequence_length, d_model)
        return output.transpose(1, 2)
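
# A minimal illustrative sketch (not part of the original file): BatchNorm1d
# normalizes over dim 1, so the transpose/normalize/transpose above is equivalent
# to this standalone helper; the helper name is ours.
def _demo_time_batchnorm(inputs: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # inputs: (batch_size, sequence_length, d_model)
    bn = nn.BatchNorm1d(inputs.shape[-1], eps=eps)
    # normalize each of the d_model features over batch and time, then restore the layout
    return bn(inputs.transpose(1, 2)).transpose(1, 2)
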


def random_masking(
    inputs: torch.Tensor,
    mask_ratio: float,
    unmasked_channel_indices: list = None,
    channel_consistent_masking: bool = False,
    mask_value: int = 0,
):
    """random_masking: Mask the input considering the control variables.

    Args:
        inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
            The input tensor to mask.
        mask_ratio (`float`):
            Masking ratio applied to mask the input data during random pretraining. It is a number between 0 and 1.
        unmasked_channel_indices (list, *optional*):
            Indices of channels that will not be masked.
        channel_consistent_masking (bool, *optional*, defaults to `False`):
            When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary
            across channels.
        mask_value (int, *optional*, defaults to 0):
            Define the value of masked patches for pretraining.

    Returns:
        `tuple(torch.Tensor)`: the masked input (same shape as the input tensor) and the mask tensor of shape
        `(batch_size, num_channels, sequence_length)`
    """
    # validate mask_ratio
    if mask_ratio < 0 or mask_ratio >= 1:
        raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.")

    batch_size, num_channels, sequence_length, num_features = inputs.shape
    device = inputs.device

    # number of patches to keep unmasked
    len_keep = int(sequence_length * (1 - mask_ratio))

    if channel_consistent_masking:
        # noise of shape bs x 1 x L, shared across channels
        noise = torch.rand(batch_size, 1, sequence_length, device=device)
        # repeat over the channel dimension: bs x num_channels x L
        noise = noise.repeat(1, num_channels, 1)
    else:
        # independent noise per channel: bs x num_channels x L
        noise = torch.rand(batch_size, num_channels, sequence_length, device=device)

    # mask of shape bs x num_channels x L, initialized to all ones (masked)
    mask = torch.ones(batch_size, num_channels, sequence_length, device=device)
    # zero out the first len_keep positions (kept unmasked)
    mask[:, :, :len_keep] = 0

    # sort the noise to obtain a random shuffling of positions
    ids_shuffle = torch.argsort(noise, dim=-1)
    # invert the shuffle to map back to the original order
    ids_restore = torch.argsort(ids_shuffle, dim=-1)
    # scatter the mask back according to the random ordering
    mask = torch.gather(mask, dim=-1, index=ids_restore)
    # broadcast over the feature dimension; mask: [bs x num_channels x num_patches x patch_length]
    mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features)
    # never mask the explicitly excluded channels
    if unmasked_channel_indices is not None:
        mask[:, unmasked_channel_indices, :, :] = 0

    # fill masked positions of the input with mask_value
    inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
    # return the masked input and the per-patch mask (first feature slice)
    return inputs_mask, mask[..., 0]
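
# Illustrative sketch (not part of the original file): a toy call to
# `random_masking` showing the returned shapes; the helper name is ours.
def _demo_random_masking():
    x = torch.arange(2 * 3 * 8 * 4, dtype=torch.float).reshape(2, 3, 8, 4)
    x_masked, mask = random_masking(x, mask_ratio=0.5)
    # x_masked: (2, 3, 8, 4), masked patches filled with 0
    # mask: (2, 3, 8), 1.0 marks masked patches, 0.0 marks kept ones
    return x_masked, mask
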
def forecast_masking(
    inputs: torch.Tensor,
    num_forecast_mask_patches: Union[list, int],
    unmasked_channel_indices: list = None,
    mask_value: int = 0,
):
    """Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
    If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.

    Parameters:
        inputs (`torch.Tensor`):
            Input of shape `(bs, num_channels, num_patch, patch_length)`
        num_forecast_mask_patches (`list`):
            Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
        unmasked_channel_indices (`list`, *optional*):
            Indices of channels that are not masked.
        mask_value (`int`, *optional*, defaults to 0):
            Values in the masked patches will be filled by `mask_value`.

    Returns:
        `tuple(torch.Tensor)`: the masked input (same shape as the inputs tensor) and the mask tensor of shape
        `(bs, num_channels, num_patch)`
    """

    # normalize num_forecast_mask_patches to a list
    if isinstance(num_forecast_mask_patches, int):
        num_forecast_mask_patches = [num_forecast_mask_patches]

    # each mask length gets an equal sampling ratio of 1
    forecast_mask_ratios = [1 for _ in num_forecast_mask_patches]

    batch_size, num_channels, sequence_length, num_features = inputs.shape

    # mask of shape bs x num_channels x num_patch, initialized to all zeros (unmasked)
    mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device)

    # [patch_length, ratio, allocated batch count] triples
    t_list = []
    total_length = 0
    total_ratio = sum(forecast_mask_ratios)

    # allocate a share of the batch to each mask length according to its ratio
    for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios):
        # validate the mask length
        if patch_length <= 0 or patch_length >= sequence_length:
            raise ValueError(
                f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches."
            )
        temp_len = int(batch_size * ratio / total_ratio)
        t_list.append([patch_length, ratio, temp_len])
        total_length += temp_len

    # sort by allocated batch count
    t_list = sorted(t_list, key=lambda x: x[2])

    # adjust the allocations so they sum exactly to batch_size
    if total_length < batch_size:
        t_list[0][2] = t_list[0][2] + (batch_size - total_length)
    elif total_length > batch_size:
        t_list[-1][2] = t_list[-1][2] + (total_length - batch_size)

    # assign each batch slice its mask: set the last patch_len patches to 1 (masked)
    batch1 = 0
    for patch_len, _, temp_len in t_list:
        batch2 = batch1 + temp_len
        mask[batch1:batch2, :, -patch_len:] = 1
        batch1 = batch2

    # shuffle the batch so mask lengths are randomly assigned to samples
    perm = torch.randperm(mask.shape[0])
    mask = mask[perm]

    # broadcast over the feature dimension; mask: [bs x num_channels x num_patch x patch_len]
    mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features)

    # never mask the explicitly excluded channels
    if unmasked_channel_indices is not None:
        mask[:, unmasked_channel_indices, :, :] = 0

    # fill masked positions of the input with mask_value
    inputs_mask = inputs.masked_fill(mask.bool(), mask_value)

    # return the masked input and the per-patch mask (first feature slice)
    return inputs_mask, mask[..., 0]
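
# Illustrative sketch (not part of the original file): a toy call to
# `forecast_masking`; the helper name is ours. With num_forecast_mask_patches=[2, 3],
# half of the batch gets its last 2 patches masked and the other half its last 3.
def _demo_forecast_masking():
    x = torch.ones(4, 2, 8, 5)  # (bs, num_channels, num_patches, patch_length)
    x_masked, mask = forecast_masking(x, num_forecast_mask_patches=[2, 3])
    # mask: (4, 2, 8); each row ends with a run of 1.0 of length 2 or 3
    return x_masked, mask
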

class PatchTSTPatchify(nn.Module):
    """
    A class to patchify the time series sequence into patches of length `patch_length`,
    strided by `patch_stride`.
    """
    def __init__(self, config: PatchTSTConfig):
        super().__init__()

        self.sequence_length = config.context_length
        self.patch_length = config.patch_length
        self.patch_stride = config.patch_stride

        # the sequence must be longer than a single patch
        if self.sequence_length <= self.patch_length:
            raise ValueError(
                f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})"
            )

        # number of patches produced by the unfold below
        self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1
        # effective length actually covered by the patches
        new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1)
        # start offset so that the patches cover the end of the sequence
        self.sequence_start = self.sequence_length - new_sequence_length

    def forward(self, past_values: torch.Tensor):
        """
        Parameters:
            past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
                Input for patchification

        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
        """
        # the input sequence length must match the configured sequence length
        sequence_length = past_values.shape[-2]
        if sequence_length != self.sequence_length:
            raise ValueError(
                f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})."
            )

        # output: [batch_size x new_sequence_length x num_channels]
        output = past_values[:, self.sequence_start :, :]
        # unfold the time dimension into patches
        # output: [batch_size x num_patches x num_input_channels x patch_length]
        output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride)
        # transpose to [batch_size x num_input_channels x num_patches x patch_length]
        output = output.transpose(-2, -3).contiguous()
        return output
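
# Illustrative sketch (not part of the original file): the core of the patchifier is
# `torch.Tensor.unfold` over the time dimension; the helper name is ours. For
# sequence_length=10, patch_length=4, patch_stride=2 this yields (10 - 4) // 2 + 1 = 4 patches.
def _demo_unfold_patchify():
    series = torch.arange(10, dtype=torch.float).reshape(1, 10, 1)  # (bs, seq_len, num_channels)
    patches = series.unfold(dimension=-2, size=4, step=2)  # (1, 4, 1, 4)
    # rearrange to (bs, num_channels, num_patches, patch_length)
    return patches.transpose(-2, -3).contiguous()
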
class PatchTSTMasking(nn.Module):
    """
    Class to perform random or forecast masking.

    Parameters:
        config (`PatchTSTConfig`): model config
    Returns:
        x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
            Masked patched input
        mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
            Bool tensor indicating True on masked points
    """

    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        self.random_mask_ratio = config.random_mask_ratio  # ratio used for random masking
        self.channel_consistent_masking = config.channel_consistent_masking  # same mask across channels?
        self.mask_type = config.mask_type  # "random" or "forecast"
        self.num_forecast_mask_patches = config.num_forecast_mask_patches  # patches masked in forecast mode
        self.unmasked_channel_indices = config.unmasked_channel_indices  # channels that are never masked
        self.mask_value = config.mask_value  # fill value for masked patches
        if self.unmasked_channel_indices is not None:
            self.unmasked_channel_indices = sorted(self.unmasked_channel_indices)

    def forward(self, patch_input: torch.Tensor):
        """
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input

        Return:
            masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
                Masked patched input
            mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
                Bool tensor indicating True on masked points

        """
        if self.mask_type == "random":
            # random masking
            masked_input, mask = random_masking(
                inputs=patch_input,
                mask_ratio=self.random_mask_ratio,
                unmasked_channel_indices=self.unmasked_channel_indices,
                channel_consistent_masking=self.channel_consistent_masking,
                mask_value=self.mask_value,
            )
        elif self.mask_type == "forecast":
            # forecast masking
            masked_input, mask = forecast_masking(
                inputs=patch_input,
                num_forecast_mask_patches=self.num_forecast_mask_patches,
                unmasked_channel_indices=self.unmasked_channel_indices,
                mask_value=self.mask_value,
            )
        else:
            raise ValueError(f"Invalid mask type {self.mask_type}.")

        # cast the mask to bool
        mask = mask.bool()
        return masked_input, mask
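
# Illustrative sketch (not part of the original file): PatchTSTMasking dispatches on
# the config's mask_type; the helper name is ours, and the config values shown are
# assumptions about `PatchTSTConfig` defaults.
def _demo_patchtst_masking():
    config = PatchTSTConfig(mask_type="random", random_mask_ratio=0.4)
    masking = PatchTSTMasking(config)
    patches = torch.randn(2, 3, 10, 8)  # (bs, num_channels, num_patches, patch_length)
    masked, mask = masking(patches)
    return masked, mask  # mask is bool, True on masked patches
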


class PatchTSTEncoderLayer(nn.Module):
    """
    PatchTST encoder layer
    """
    def __init__(self, config: PatchTSTConfig):
        super().__init__()

        self.channel_attention = config.channel_attention
        # Multi-Head attention
        self.self_attn = PatchTSTAttention(
            embed_dim=config.d_model,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
        )

        # Add & Norm of the sublayer 1
        self.dropout_path1 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity()
        # choose the normalization layer (batch norm or layer norm) from the config
        if config.norm_type == "batchnorm":
            self.norm_sublayer1 = PatchTSTBatchNorm(config)
        elif config.norm_type == "layernorm":
            self.norm_sublayer1 = nn.LayerNorm(config.d_model, eps=config.norm_eps)
        else:
            raise ValueError(f"{config.norm_type} is not a supported norm layer type.")

        # Add & Norm of the sublayer 2, conditionally based on self.channel_attention
        if self.channel_attention:
            self.dropout_path2 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity()
            # choose the normalization layer (batch norm or layer norm) from the config
            if config.norm_type == "batchnorm":
                self.norm_sublayer2 = PatchTSTBatchNorm(config)
            elif config.norm_type == "layernorm":
                self.norm_sublayer2 = nn.LayerNorm(config.d_model, eps=config.norm_eps)
            else:
                raise ValueError(f"{config.norm_type} is not a supported norm layer type.")

        # Position-wise Feed-Forward
        self.ff = nn.Sequential(
            nn.Linear(config.d_model, config.ffn_dim, bias=config.bias),
            ACT2CLS[config.activation_function](),  # activation function chosen from the config
            nn.Dropout(config.ff_dropout) if config.ff_dropout > 0 else nn.Identity(),
            nn.Linear(config.ffn_dim, config.d_model, bias=config.bias),
        )

        # Add & Norm of sublayer 3
        self.dropout_path3 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity()
        # choose the normalization layer (batch norm or layer norm) from the config
        if config.norm_type == "batchnorm":
            self.norm_sublayer3 = PatchTSTBatchNorm(config)
        elif config.norm_type == "layernorm":
            self.norm_sublayer3 = nn.LayerNorm(config.d_model, eps=config.norm_eps)
        else:
            raise ValueError(f"{config.norm_type} is not a supported norm layer type.")

        self.pre_norm = config.pre_norm


class PatchTSTPreTrainedModel(PreTrainedModel):
    # configuration class
    config_class = PatchTSTConfig
    # base model prefix
    base_model_prefix = "model"
    # main input name
    main_input_name = "past_values"
    # gradient checkpointing is not supported
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """
        初始化权重
        """
        if isinstance(module, PatchTSTPositionalEncoding):
            # 初始化 cls_token
            if self.config.use_cls_token:
                nn.init.normal_(module.cls_token, std=0.02)
            # 初始化位置编码
            if self.config.positional_encoding_type == "random":
                nn.init.normal_(module.position_enc, mean=0.0, std=0.1)
        elif isinstance(module, nn.LayerNorm):
            # 将偏置项初始化为零
            module.bias.data.zero_()
            # 将权重初始化为1.0
            module.weight.data.fill_(1.0)
        elif isinstance(module, PatchTSTBatchNorm):
            # 将批归一化层的偏置项初始化为零
            module.batchnorm.bias.data.zero_()
            # 将批归一化层的权重初始化为1.0
            module.batchnorm.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Conv1d)):
            # 将权重初始化为正态分布随机值
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            # 如果存在偏置项,则初始化为零
            if module.bias is not None:
                module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # enable/disable gradient checkpointing on PatchTSTEncoder modules
        if isinstance(module, (PatchTSTEncoder)):
            module.gradient_checkpointing = value


class PatchTSTEmbedding(nn.Module):
    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        self.num_input_channels = config.num_input_channels
        self.share_embedding = config.share_embedding
        # input encoding: project the feature vectors onto a d-dim vector space
        if self.share_embedding:
            # one shared linear projection for all channels
            self.input_embedding = nn.Linear(config.patch_length, config.d_model)
        else:
            # one linear projection per channel
            self.input_embedding = nn.ModuleList()
            for _ in range(config.num_input_channels):
                self.input_embedding.append(nn.Linear(config.patch_length, config.d_model))
    def forward(self, patch_input: torch.Tensor):
        """
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input for embedding
        return:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`
        """
        # Input encoding

        num_input_channels = patch_input.shape[1]

        # the number of input channels must match the config
        if num_input_channels != self.num_input_channels:
            raise ValueError(
                f"The defined number of input channels ({self.num_input_channels}) in the config "
                f"has to be the same as the number of channels in the batch input ({num_input_channels})"
            )

        if self.share_embedding:
            # one shared embedding for all channels
            embeddings = self.input_embedding(patch_input)  # x: [bs x num_channels x num_patches x d_model]
        else:
            # embed each channel with its own projection, then restack
            embeddings = [self.input_embedding[i](patch_input[:, i, :, :]) for i in range(num_input_channels)]
            embeddings = torch.stack(embeddings, dim=1)

        return embeddings
class PatchTSTPositionalEncoding(nn.Module):
    """
    Class for positional encoding
    """

    def __init__(self, config: PatchTSTConfig, num_patches: int):
        super().__init__()
        self.use_cls_token = config.use_cls_token  # whether to prepend a CLS token
        self.num_input_channels = config.num_input_channels
        if config.use_cls_token:
            # cls_token: [1 x num_input_channels x 1 x d_model]
            self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, config.d_model))
            num_patches += 1  # the CLS token occupies one extra position

        # positional encoding: [num_patches x d_model]
        self.position_enc = self._init_pe(config, num_patches)

        # positional dropout (identity when disabled)
        self.positional_dropout = (
            nn.Dropout(config.positional_dropout) if config.positional_dropout > 0 else nn.Identity()
        )

    @staticmethod
    def _init_pe(config: PatchTSTConfig, num_patches: int) -> nn.Parameter:
        # Positional encoding
        if config.positional_encoding_type == "random":
            position_enc = nn.Parameter(torch.randn(num_patches, config.d_model), requires_grad=True)
        elif config.positional_encoding_type == "sincos":
            position_enc = torch.zeros(num_patches, config.d_model)
            position = torch.arange(0, num_patches).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
            position_enc[:, 0::2] = torch.sin(position * div_term)
            position_enc[:, 1::2] = torch.cos(position * div_term)
            position_enc = position_enc - position_enc.mean()
            position_enc = position_enc / (position_enc.std() * 10)
            position_enc = nn.Parameter(position_enc, requires_grad=False)
        else:
            raise ValueError(
                f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
            )
        return position_enc  # positional encoding parameter

    def forward(self, patch_input: torch.Tensor):
        if self.use_cls_token:
            # patch_input: [bs x num_channels x num_patches x d_model]
            patch_input = self.positional_dropout(patch_input + self.position_enc[1:, :])
            # append cls token where cls_token: [1 x num_channels x 1 x d_model]
            cls_token = self.cls_token + self.position_enc[:1, :]
            # get the same copy of cls_token for all the samples in batch: [bs x num_channels x 1 x d_model]
            cls_tokens = cls_token.expand(patch_input.shape[0], self.num_input_channels, -1, -1)
            # hidden_state: [bs x num_channels x (num_patches+1) x d_model]
            hidden_state = torch.cat((cls_tokens, patch_input), dim=2)
        else:
            # hidden_state: [bs x num_channels x num_patches x d_model]
            hidden_state = self.positional_dropout(patch_input + self.position_enc)
        return hidden_state


class PatchTSTEncoder(PatchTSTPreTrainedModel):
    """
    PatchTST Encoder
    """
    def __init__(self, config: PatchTSTConfig, num_patches: int):
        super().__init__(config)
        self.gradient_checkpointing = False

        # Input embedding: projection of feature vectors onto a d-dim vector space
        self.embedder = PatchTSTEmbedding(config)
        # Positional encoding
        self.positional_encoder = PatchTSTPositionalEncoding(config, num_patches)
        # Encoder
        self.layers = nn.ModuleList([PatchTSTEncoderLayer(config) for _ in range(config.num_hidden_layers)])

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        patch_input: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
    ) -> BaseModelOutput:
        """
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Past values of the time series
            output_hidden_states (bool, optional): Whether to return the hidden states of all layers.
            output_attentions (bool, optional): Whether to return the attention weights of all layers.

        return:
            `BaseModelOutput`
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # Input embedding
        patch_input = self.embedder(patch_input)
        # Positional encoding
        hidden_state = self.positional_encoder(patch_input)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for encoder_layer in self.layers:
            if output_hidden_states:
                # Collect hidden states if requested
                encoder_states = encoder_states + (hidden_state,)

            # Process each encoder layer
            layer_outputs = encoder_layer(hidden_state=hidden_state, output_attentions=output_attentions)
            # Update hidden state to the output of the current layer
            hidden_state = layer_outputs[0]
            if output_attentions:
                # Collect attention matrices if requested
                all_attentions = all_attentions + (layer_outputs[1],)

        # Return model output including final hidden states and attentions
        return BaseModelOutput(last_hidden_state=hidden_state, hidden_states=encoder_states, attentions=all_attentions)
# Shared docstring: the model inherits the generic `PreTrainedModel` utilities
# (downloading, saving, resizing input embeddings, ...) and behaves as a regular
# PyTorch `torch.nn.Module`.
PATCHTST_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`PatchTSTConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

@dataclass
class PatchTSTModelOutput(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states.

    Parameters:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of
            the model at the output of each layer plus the optional initial embedding outputs.
        mask: (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*)
            Bool masked tensor indicating which patches are masked
        loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
            Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
        scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
            Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
        patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
            Patched input to the Transformer
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    mask: torch.FloatTensor = None
    loc: torch.FloatTensor = None
    scale: torch.FloatTensor = None
    patch_input: torch.FloatTensor = None


@dataclass
class PatchTSTForPretrainingOutput(ModelOutput):
    """
    Output type of [`PatchTSTForPretraining`].

    Parameters:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            MSE loss.
        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction outputs of the time series modeling heads.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_outputs: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
# Output dataclass for the PatchTST regression model.
@dataclass
class PatchTSTForRegressionOutput(ModelOutput):
    """
    Output type of [`PatchTSTForRegression`].

    Parameters:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            MSE loss.
        regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
            Regression outputs of the time series modeling heads.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    regression_outputs: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


# Output dataclass for the PatchTST prediction model.
@dataclass
class PatchTSTForPredictionOutput(ModelOutput):
    """
    Output type of [`PatchTSTForPrediction`].

    Parameters:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            MSE loss.
        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, -1)`):
            Prediction outputs of the time series modeling heads.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
            Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
        scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
            Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_outputs: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    loc: torch.FloatTensor = None
    scale: torch.FloatTensor = None
# Output dataclass for the PatchTST classification model.
@dataclass
class PatchTSTForClassificationOutput(ModelOutput):
    """
    Output type of [`PatchTSTForClassification`].

    Parameters:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
            Prediction scores of the PatchTST modeling head (scores before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


# Output dataclass for sampled PatchTST predictions.
@dataclass
class SamplePatchTSTOutput(ModelOutput):
    """
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.

    Parameters:
        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, num_targets)`):
            Sampled values from the chosen distribution.
    """

    sequences: torch.FloatTensor = None  # values sampled from the chosen distribution


# Borrowed from the time series transformer model: negative log-likelihood of the target under the given distribution.
def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
    """
    Computes the negative log likelihood loss from input distribution with respect to target.
    """
    return -input.log_prob(target)
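
# Illustrative sketch (not part of the original file): `nll` with a toy Normal
# distribution; `log_prob` broadcasts over the target tensor. The helper name is ours.
def _demo_nll():
    dist = torch.distributions.Normal(loc=torch.zeros(3), scale=torch.ones(3))
    target = torch.tensor([0.0, 1.0, -1.0])
    return nll(dist, target)  # elementwise -log p(target), shape (3,)
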


# Borrowed from the time series transformer model: weighted average along a dimension
# that treats zero-weight entries as zero instead of propagating NaN.
def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
    """
    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
    """
    if weights is not None:
        # multiply by the weights where they are non-zero; zero elsewhere
        weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
        # sum of weights along `dim`, clamped to at least 1.0 to avoid division by zero
        sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
        # weighted mean along `dim`
        return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
    else:
        # unweighted mean along `dim`
        return input_tensor.mean(dim=dim)
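
# Illustrative sketch (not part of the original file): a zero weight masks the NaN
# entry instead of propagating it. The helper name is ours.
def _demo_weighted_average():
    values = torch.tensor([1.0, float("nan"), 3.0])
    weights = torch.tensor([1.0, 0.0, 1.0])
    return weighted_average(values, weights, dim=0)  # tensor(2.)
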
# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST
class PatchTSTStdScaler(nn.Module):
    """
    Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by
    subtracting from the mean and dividing by the standard deviation.
    """

    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        # dimension to scale over; defaults to 1 (the time dimension)
        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
        # whether to keep the reduced dimension; defaults to True
        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
        # minimum scale to avoid division by zero; defaults to 1e-5
        self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        """
        # number of observed entries along `dim`, clamped to at least 1.0
        denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim)
        denominator = denominator.clamp_min(1.0)
        # mean over the observed entries
        loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator

        # variance over the observed entries
        variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
        # standard deviation, floored by minimum_scale
        scale = torch.sqrt(variance + self.minimum_scale)
        # return the standardized data together with loc and scale
        return (data - loc) / scale, loc, scale
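
# Illustrative sketch (not part of the original file): PatchTSTStdScaler on a toy
# batch where the last time step is unobserved; loc and scale are computed over the
# observed steps only. The helper name and config values are ours.
def _demo_std_scaler():
    config = PatchTSTConfig(num_input_channels=1)
    scaler = PatchTSTStdScaler(config)
    data = torch.tensor([[[1.0], [2.0], [3.0], [100.0]]])    # (1, 4, 1)
    observed = torch.tensor([[[1.0], [1.0], [1.0], [0.0]]])  # last step unobserved
    scaled, loc, scale = scaler(data, observed)
    return scaled, loc, scale  # loc == 2.0, from the three observed steps only
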


# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST
class PatchTSTMeanScaler(nn.Module):
    """
    Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
    accordingly.
    """

    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        # dimension to scale over; defaults to 1 (the time dimension)
        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
        # whether to keep the reduced dimension; defaults to True
        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
        # minimum scale; defaults to 1e-10
        self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10
        # optional fixed default scale; `None` means it is estimated from the batch
        self.default_scale = config.default_scale if hasattr(config, "default_scale") else None

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        """
        # Calculate the sum of absolute values of `data` multiplied by `observed_indicator`
        # along the specified dimension `self.dim`, maintaining the dimensionality.
        ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
        
        # Count the number of observed elements (True values) in `observed_indicator`
        # along the specified dimension `self.dim`, maintaining the dimensionality.
        num_observed = observed_indicator.sum(self.dim, keepdim=True)

        # Compute the scale as the ratio of `ts_sum` to `num_observed`, clamping
        # `num_observed` to a minimum value of 1 to avoid division by zero.
        scale = ts_sum / torch.clamp(num_observed, min=1)

        # If `default_scale` is not provided, calculate it based on the sum of `ts_sum`
        # across the batch and the sum of `num_observed` across the batch, clamped to
        # ensure no division by zero.
        if self.default_scale is None:
            batch_sum = ts_sum.sum(dim=0)
            batch_observations = torch.clamp(num_observed.sum(0), min=1)
            default_scale = torch.squeeze(batch_sum / batch_observations)
        else:
            # Use the provided `default_scale` multiplied element-wise by a tensor of ones
            # with the same shape as `scale`.
            default_scale = self.default_scale * torch.ones_like(scale)

        # Apply `default_scale` where `num_observed` is greater than zero, otherwise use `scale`.
        scale = torch.where(num_observed > 0, scale, default_scale)

        # Ensure that `scale` is not less than `self.minimum_scale`.
        scale = torch.clamp(scale, min=self.minimum_scale)

        # Scale `data` by dividing each element by the corresponding element in `scale`.
        scaled_data = data / scale

        # If `self.keepdim` is False, squeeze `scale` along the specified dimension `self.dim`.
        if not self.keepdim:
            scale = scale.squeeze(dim=self.dim)

        # Return the scaled data, a tensor of zeros with the same shape as `scale`,
        # and the computed `scale`.
        return scaled_data, torch.zeros_like(scale), scale
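
# Illustrative sketch (not part of the original file): PatchTSTMeanScaler divides by
# the mean absolute value per channel and returns zeros for loc. The helper name is ours.
def _demo_mean_scaler():
    config = PatchTSTConfig(num_input_channels=1)
    scaler = PatchTSTMeanScaler(config)
    data = torch.tensor([[[2.0], [4.0], [6.0], [8.0]]])  # (1, 4, 1)
    observed = torch.ones_like(data)
    scaled, loc, scale = scaler(data, observed)
    return scaled, loc, scale  # scale == 5.0 (mean |x|), loc == 0
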
# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST
# A no-op scaler: keeps the input data unchanged, with unit scale and zero loc.
class PatchTSTNOPScaler(nn.Module):
    """
    Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
    """
    
    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        # dimension to (not) scale over; defaults to 1 (the time dimension)
        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
        # whether to keep the reduced dimension; defaults to True
        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        """
        # unit scale and zero loc, reduced over the scaling dimension (i.e., no scaling)
        scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
        loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
        # return the input data unchanged, together with loc and scale
        return data, loc, scale


# A module that selects the scaling strategy from the config.
class PatchTSTScaler(nn.Module):
    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        # select the scaler according to config.scaling
        if config.scaling == "mean" or config.scaling is True:
            self.scaler = PatchTSTMeanScaler(config)
        elif config.scaling == "std":
            self.scaler = PatchTSTStdScaler(config)
        else:
            self.scaler = PatchTSTNOPScaler(config)

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Input for scaler calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        """
        # delegate to the selected scaler
        data, loc, scale = self.scaler(data, observed_indicator)
        # return the scaled data together with loc and scale
        return data, loc, scale


# Docstring: the bare PatchTST model outputting raw hidden-states without any specific head.
@add_start_docstrings(
    "The bare PatchTST Model outputting raw hidden-states without any specific head.",
    PATCHTST_START_DOCSTRING,
)
class PatchTSTModel(PatchTSTPreTrainedModel):
    def __init__(self, config: PatchTSTConfig):
        super().__init__(config)

        self.scaler = PatchTSTScaler(config)
        self.patchifier = PatchTSTPatchify(config)
        self.do_mask_input = config.do_mask_input  # whether the patched input is masked (pretraining)
        # get the num_patches information from PatchTSTPatchify
        num_patches = self.patchifier.num_patches

        # mask the patched input during pretraining, otherwise pass it through unchanged
        if self.do_mask_input:
            self.masking = PatchTSTMasking(config)
        else:
            self.masking = nn.Identity()

        self.encoder = PatchTSTEncoder(config, num_patches=num_patches)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        past_values: torch.Tensor,
        past_observed_mask: Optional[torch.Tensor] = None,
        future_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, PatchTSTModelOutput]:
        ...  # forward implementation omitted in this excerpt


class PatchTSTMaskPretrainHead(nn.Module):
    """
    Pretraining head for mask modelling
    """

    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout)
        # project d_model back to patch_length to reconstruct the masked patches
        self.linear = nn.Linear(config.d_model, config.patch_length)
        self.use_cls_token = config.use_cls_token  # whether a CLS token is prepended

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                    `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, num_channels, num_patches, patch_length)`; the CLS position,
            if present, is removed before returning

        """
        # project each d_model vector back to a patch: [bs x num_channels x num_patches x patch_length]
        embedding = self.linear(self.dropout(embedding))
        if self.use_cls_token:
            embedding = embedding[:, :, 1:, :]  # drop the CLS position
        return embedding


@add_start_docstrings(
    "The PatchTST for pretrain model.",
    PATCHTST_START_DOCSTRING,
)
class PatchTSTForPretraining(PatchTSTPreTrainedModel):
    def __init__(self, config: PatchTSTConfig):
        super().__init__(config)

        config.do_mask_input = True  # pretraining always masks the input
        self.model = PatchTSTModel(config=config)
        self.head = PatchTSTMaskPretrainHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        past_values: torch.Tensor,
        past_observed_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> PatchTSTForPretrainingOutput:
        """
        Parameters:
            past_values (`torch.Tensor`): Past values of the time series.
            past_observed_mask (`torch.Tensor`, *optional*): Boolean mask for observed values.
            output_hidden_states (`bool`, *optional*): Whether to return the hidden states of all layers.
            output_attentions (`bool`, *optional*): Whether to return the attention weights of all layers.
            return_dict (`bool`, *optional*): Whether to return a [`PatchTSTForPretrainingOutput`] instead of a plain tuple.

        Returns:
            [`PatchTSTForPretrainingOutput`] or a tuple, with the MSE loss and the reconstruction outputs.
        """
        ...  # implementation omitted in this excerpt

class PatchTSTClassificationHead(nn.Module):
    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        self.use_cls_token = config.use_cls_token  # whether a CLS token is prepended
        self.pooling_type = config.pooling_type  # pooling over the patch dimension
        self.flatten = nn.Flatten(start_dim=1)
        self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
        self.linear = nn.Linear(config.num_input_channels * config.d_model, config.num_targets)
    def forward(self, embedding: torch.Tensor):
        """
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                     `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, num_targets)`

        """
        if self.use_cls_token:
            # use the first (CLS) token as the pooled embedding: [bs x num_channels x d_model]
            pooled_embedding = embedding[:, :, 0, :]
        elif self.pooling_type == "mean":
            # mean pooling over the patch dimension: [bs x num_channels x d_model]
            pooled_embedding = embedding.mean(dim=2)
        elif self.pooling_type == "max":
            # max pooling over the patch dimension: [bs x num_channels x d_model]
            pooled_embedding = embedding.max(dim=2).values
        else:
            raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")

        # flatten channels and features: [bs x num_channels * d_model]
        pooled_embedding = self.flatten(pooled_embedding)

        # dropout + linear classifier: [bs x num_targets]
        output = self.linear(self.dropout(pooled_embedding))
        
        return output
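
# Illustrative sketch (not part of the original file): shape flow through the
# classification head with the default mean pooling. The helper name is ours and
# the config values shown are assumptions about `PatchTSTConfig` defaults.
def _demo_classification_head_shapes():
    config = PatchTSTConfig(num_input_channels=3, num_targets=5)
    head = PatchTSTClassificationHead(config)
    embedding = torch.randn(2, 3, 12, config.d_model)  # (bs, num_channels, num_patches, d_model)
    return head(embedding).shape  # torch.Size([2, 5])
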
@add_start_docstrings(
    "The PatchTST for classification model.",
    PATCHTST_START_DOCSTRING,
)
class PatchTSTForClassification(PatchTSTPreTrainedModel):
    def __init__(self, config: PatchTSTConfig):
        super().__init__(config)

        # classification does not mask the input; turn masking off if it was enabled
        if config.do_mask_input:
            logger.warning("Setting `do_mask_input` parameter to False.")
            config.do_mask_input = False

        # Initialize PatchTSTModel and PatchTSTClassificationHead
        self.model = PatchTSTModel(config)
        self.head = PatchTSTClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        past_values: torch.Tensor,
        target_values: torch.Tensor = None,
        past_observed_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass of the PatchTSTForClassification model.

        Parameters:
        - past_values: Tensor of past input values.
        - target_values: Optional tensor of target values.
        - past_observed_mask: Optional boolean mask for observed values.
        - output_hidden_states: Optional boolean to output hidden states.
        - output_attentions: Optional boolean to output attentions.
        - return_dict: Optional boolean to return a dictionary.

        Returns:
        - Depending on configurations, returns classification predictions.
        """
        # Forward pass implementation details are defined elsewhere.
        pass


class PatchTSTPredictionHead(nn.Module):
    def __init__(self, config: PatchTSTConfig, num_patches, distribution_output=None):
        super().__init__()

        self.share_projection = config.share_projection
        self.num_input_channels = config.num_input_channels
        self.use_cls_token = config.use_cls_token
        self.pooling_type = config.pooling_type

        # Determine head dimension based on configuration
        if self.pooling_type or self.use_cls_token:
            head_dim = config.d_model
        else:
            head_dim = config.d_model * num_patches

        if not self.share_projection:
            # If each channel has its own head, initialize projections, dropouts, and flattens
            self.projections = nn.ModuleList()
            self.dropouts = nn.ModuleList()
            self.flattens = nn.ModuleList()
            for i in range(self.num_input_channels):
                self.flattens.append(nn.Flatten(start_dim=2))
                if distribution_output is None:
                    # Use linear head projection
                    self.projections.append(nn.Linear(head_dim, config.prediction_length))
                else:
                    # Use distribution head projection
                    self.projections.append(distribution_output.get_parameter_projection(head_dim))
                self.dropouts.append(nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity())
        else:
            # All channels share the same head, initialize flatten, projection, and dropout
            self.flatten = nn.Flatten(start_dim=2)
            if distribution_output is None:
                # Use linear head projection
                self.projection = nn.Linear(head_dim, config.prediction_length)
            else:
                # Use distribution head projection
                self.projection = distribution_output.get_parameter_projection(head_dim)
            self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()

    def forward(self, embedding: torch.Tensor):
        """
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                     `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, forecast_len, num_channels)`

        """
        if self.use_cls_token:
            # When a cls token is used, take the first patch's embedding
            # pooled_embedding: [bs x num_channels x d_model]
            pooled_embedding = embedding[:, :, 0, :]
        else:
            if self.pooling_type == "mean":
                # Mean pooling over the patch dimension (dim=2)
                # pooled_embedding: [bs x num_channels x d_model]
                pooled_embedding = embedding.mean(dim=2)
            elif self.pooling_type == "max":
                # Max pooling over the patch dimension, keeping the max values
                # pooled_embedding: [bs x num_channels x d_model]
                pooled_embedding = embedding.max(dim=2).values
            else:
                # No pooling configured: use the embedding as-is
                # pooled_embedding: [bs x num_channels x num_patches x d_model]
                pooled_embedding = embedding

        if not self.share_projection:
            output = []
            for i in range(self.num_input_channels):
                # Flatten channel i so it can be projected (a separate variable is used
                # here so that `pooled_embedding` is not overwritten inside the loop)
                # channel_embedding: [bs x (d_model * num_patches)] or [bs x d_model]
                channel_embedding = self.flattens[i](pooled_embedding[:, i, :])
                channel_embedding = self.dropouts[i](channel_embedding)
                # Apply the projection; a distribution head may return a tuple of tensors
                # channel_embedding: [bs x forecast_len]
                #  or tuple ([bs x forecast_len], [bs x forecast_len]) if using distribution head
                channel_embedding = self.projections[i](channel_embedding)
                output.append(channel_embedding)
            # Stack the per-channel outputs
            # output: [bs x num_channels x forecast_len]
            output = torch.stack(output, dim=1)
        else:
            # Shared projection: flatten all channels at once
            # pooled_embedding: [bs x num_channels x (d_model * num_patches)] or [bs x num_channels x d_model]
            pooled_embedding = self.flatten(pooled_embedding)
            pooled_embedding = self.dropout(pooled_embedding)
            # Apply the projection; a distribution head may return a tuple of tensors
            # output: [bs x num_channels x forecast_len] or
            # tuple ([bs x num_channels x forecast_len], [bs x num_channels x forecast_len]) if using distribution head
            output = self.projection(pooled_embedding)

        if isinstance(output, tuple):
            # For tuple outputs, swap the second and third dimensions of each tensor
            # output: ([bs x forecast_len x num_channels], [bs x forecast_len x num_channels])
            output = tuple(z.transpose(2, 1) for z in output)
        else:
            # Otherwise swap the second and third dimensions directly
            output = output.transpose(2, 1)  # [bs x forecast_len x num_channels]
        return output
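
# Quick shape walkthrough (illustrative, assuming a shared projection and mean
# pooling): pool over patches, flatten, project to `prediction_length`, then
# transpose to [bs x forecast_len x num_channels] as in the forward pass above.
import torch
import torch.nn as nn

bs, num_channels, num_patches, d_model, forecast_len = 2, 3, 10, 16, 24
embedding = torch.randn(bs, num_channels, num_patches, d_model)
pooled = embedding.mean(dim=2)                        # [bs x num_channels x d_model]
pooled = nn.Flatten(start_dim=2)(pooled)              # no-op here: pooling already removed the patch dim
projected = nn.Linear(d_model, forecast_len)(pooled)  # [bs x num_channels x forecast_len]
result = projected.transpose(2, 1)                    # [bs x forecast_len x num_channels]
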
@add_start_docstrings(
    "The PatchTST for prediction model.",
    PATCHTST_START_DOCSTRING,
)
class PatchTSTForPrediction(PatchTSTPreTrainedModel):
    def __init__(self, config: PatchTSTConfig):
        super().__init__(config)

        # Turn off masking if specified in the configuration
        if config.do_mask_input:
            logger.warning("Setting `do_mask_input` parameter to False.")
            config.do_mask_input = False

        # Instantiate the PatchTSTModel with the provided configuration
        self.model = PatchTSTModel(config)

        # Determine the type of distribution output based on the configuration
        if config.loss == "mse":
            self.distribution_output = None
        else:
            if config.distribution_output == "student_t":
                self.distribution_output = StudentTOutput(dim=config.prediction_length)
            elif config.distribution_output == "normal":
                self.distribution_output = NormalOutput(dim=config.prediction_length)
            elif config.distribution_output == "negative_binomial":
                self.distribution_output = NegativeBinomialOutput(dim=config.prediction_length)
            else:
                raise ValueError(f"Unknown distribution output {config.distribution_output}")

        # Initialize PatchTSTPredictionHead with necessary configurations and distribution output
        self.head = PatchTSTPredictionHead(
            config, self.model.patchifier.num_patches, distribution_output=self.distribution_output
        )

        # Initialize weights and apply final processing for the model
        self.post_init()

    def forward(
        self,
        past_values: torch.Tensor,
        past_observed_mask: Optional[torch.Tensor] = None,
        future_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Forward pass method for the model, computes predictions based on input
        ...

    def generate(
        self,
        past_values: torch.Tensor,
        past_observed_mask: Optional[torch.Tensor] = None,
        ...
    ) -> SamplePatchTSTOutput:
        """
        Generate sequences of sample predictions from a model with a probability distribution head.

        Parameters:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
            samples, prediction_length, 1)` or `(batch_size, number of samples, prediction_length, num_input_channels)`
            for multivariate predictions.
        """
        # Number of parallel samples to draw
        num_parallel_samples = self.config.num_parallel_samples

        # Get the model outputs
        outputs = self(
            past_values=past_values,
            future_values=None,
            past_observed_mask=past_observed_mask,
            output_hidden_states=False,
        )

        if self.distribution_output:
            # Build the distribution object
            distribution = self.distribution_output.distribution(
                outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
            )
            # Draw samples: list of [bs x forecast_len x num_channels]
            samples = [distribution.sample() for _ in range(num_parallel_samples)]
            # Stack the samples: [bs x num_samples x forecast_len x num_channels]
            samples = torch.stack(samples, dim=1)
        else:
            # Without a distribution head, use the point predictions and add a sample dimension
            samples = outputs.prediction_outputs.unsqueeze(1)

        # Return a SamplePatchTSTOutput containing the sampled prediction sequences
        return SamplePatchTSTOutput(sequences=samples)
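
# Illustrative demo of the sampling step above: draw `num_parallel_samples` draws
# from a distribution over [bs x forecast_len x num_channels] outputs and stack
# them on dimension 1, giving [bs x num_samples x forecast_len x num_channels].
# All shapes are hypothetical.
import torch

bs, forecast_len, num_channels, num_parallel_samples = 2, 24, 3, 5
dist = torch.distributions.Normal(
    loc=torch.zeros(bs, forecast_len, num_channels), scale=torch.ones(bs, forecast_len, num_channels)
)
samples = torch.stack([dist.sample() for _ in range(num_parallel_samples)], dim=1)
# samples.shape == torch.Size([2, 5, 24, 3])
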
class PatchTSTRegressionHead(nn.Module):
    """
    Regression head
    """

    def __init__(self, config: PatchTSTConfig, distribution_output=None):
        super().__init__()
        # Output range for the predictions
        self.y_range = config.output_range
        # Whether a cls token is used
        self.use_cls_token = config.use_cls_token
        # Pooling type
        self.pooling_type = config.pooling_type
        # Distribution output head, if any
        self.distribution_output = distribution_output

        # Compute the head dimension
        head_dim = config.num_input_channels * config.d_model

        # Flatten layer for the pooled embedding
        self.flatten = nn.Flatten(start_dim=1)
        # Dropout layer if head_dropout is configured, otherwise an identity mapping
        self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()

        # Without a distribution output, project with a plain linear layer
        if distribution_output is None:
            self.projection = nn.Linear(head_dim, config.num_targets)
        else:
            # Otherwise, use the projection provided by the distribution output object
            self.projection = distribution_output.get_parameter_projection(head_dim)

    def forward(self, embedding: torch.Tensor):
        """
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                    `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model
        Returns:
            `torch.Tensor` of shape `(bs, output_dim)`

        """
        if self.use_cls_token:
            # With a cls token, select the first output token; pooled_embedding: [bs x num_channels x d_model]
            pooled_embedding = embedding[:, :, 0, :]
        elif self.pooling_type == "mean":
            # Mean pooling; pooled_embedding: [bs x num_channels x d_model]
            pooled_embedding = embedding.mean(dim=2)
        elif self.pooling_type == "max":
            # Max pooling; pooled_embedding: [bs x num_channels x d_model]
            pooled_embedding = embedding.max(dim=2).values
        else:
            # The requested pooling type is not implemented
            raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")

        # Flatten the input
        # pooled_embedding: bs x (num_channels * d_model)
        pooled_embedding = self.dropout(self.flatten(pooled_embedding))

        # Projection
        # output: bs x output_dim, or a tuple of that shape for a distribution head
        output = self.projection(pooled_embedding)

        # If needed, apply a sigmoid to constrain the output range (linear head only)
        if (self.distribution_output is None) and (self.y_range is not None):
            output = torch.sigmoid(output) * (self.y_range[1] - self.y_range[0]) + self.y_range[0]

        return output
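
# Illustrative check of the `output_range` rescaling above: sigmoid squashes the
# linear-head output into (0, 1), which is then mapped affinely into
# (y_range[0], y_range[1]). The range below is a hypothetical config.output_range.
import torch

y_range = (-10.0, 10.0)
raw = torch.tensor([-5.0, 0.0, 5.0])
bounded = torch.sigmoid(raw) * (y_range[1] - y_range[0]) + y_range[0]
# bounded stays strictly inside (-10, 10); sigmoid(0) = 0.5 maps to the midpoint 0.0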


@add_start_docstrings(
    "The PatchTST for regression model.",
    PATCHTST_START_DOCSTRING,
)
class PatchTSTForRegression(PatchTSTPreTrainedModel):
    """
    PatchTST for regression model.
    
    Inherits from PatchTSTPreTrainedModel.
    """
    def __init__(self, config: PatchTSTConfig):
        # Call the parent class initializer with the config object
        super().__init__(config)

        # Turn off input masking
        if config.do_mask_input:
            # If input masking was requested, warn and set the parameter to False
            logger.warning("Setting `do_mask_input` parameter to False.")
            config.do_mask_input = False

        # Instantiate the PatchTSTModel with the provided configuration
        self.model = PatchTSTModel(config)

        # Determine the output distribution based on the loss function
        if config.loss == "mse":
            # MSE loss needs no distribution output
            self.distribution_output = None
        else:
            # Choose the output object matching the configured distribution type
            if config.distribution_output == "student_t":
                self.distribution_output = StudentTOutput(dim=config.num_targets)
            elif config.distribution_output == "normal":
                self.distribution_output = NormalOutput(dim=config.num_targets)
            elif config.distribution_output == "negative_binomial":
                self.distribution_output = NegativeBinomialOutput(dim=config.num_targets)
            else:
                # Unknown distribution output type in the configuration
                raise ValueError(f"Unknown distribution output {config.distribution_output}")

        # Initialize the model head with PatchTSTRegressionHead
        self.head = PatchTSTRegressionHead(config, self.distribution_output)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        past_values: torch.Tensor,
        target_values: torch.Tensor = None,
        past_observed_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Forward pass: runs inference/prediction from the given inputs
        ...

    def generate(
        self,
        past_values: torch.Tensor,
        past_observed_mask: Optional[torch.Tensor] = None,
    ) -> SamplePatchTSTOutput:
        """
        Generate sequences of sample predictions from a model with a probability distribution output head.

        Parameters:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serve as context in order to predict the future.
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
            samples, num_targets)`.
        """
        # Number of parallel samples to draw
        num_parallel_samples = self.config.num_parallel_samples

        # Get the model outputs
        outputs = self(
            past_values=past_values,
            target_values=None,
            past_observed_mask=past_observed_mask,
            output_hidden_states=False,
        )

        # Build the distribution
        distribution = self.distribution_output.distribution(outputs.regression_outputs)
        # Draw samples: list of [bs x num_targets]
        samples = [distribution.sample() for _ in range(num_parallel_samples)]
        # samples: [bs x num_samples x num_targets]
        samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
        return SamplePatchTSTOutput(sequences=samples)

.\models\patchtst\__init__.py

# Copyright notice: this code is copyrighted by the HuggingFace team.
#
# Licensed under the Apache License, Version 2.0; you may not use this file
# except in compliance with the License. You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software is
# distributed on an "AS IS" basis, without warranties or conditions of any kind,
# either express or implied. See the License for the specific language governing
# permissions and limitations.
from typing import TYPE_CHECKING

# Import the required exception and helpers
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# Define the module's import structure
_import_structure = {
    "configuration_patchtst": [
        "PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "PatchTSTConfig",
    ],
}

# Check whether torch is available
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If torch is available, add the modeling classes to the import structure
    _import_structure["modeling_patchtst"] = [
        "PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST",
        "PatchTSTModel",
        "PatchTSTPreTrainedModel",
        "PatchTSTForPrediction",
        "PatchTSTForPretraining",
        "PatchTSTForRegression",
        "PatchTSTForClassification",
    ]

# During type checking, import what the type checker needs
if TYPE_CHECKING:
    from .configuration_patchtst import PATCHTST_PRETRAINED_CONFIG_ARCHIVE_MAP, PatchTSTConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_patchtst import (
            PATCHTST_PRETRAINED_MODEL_ARCHIVE_LIST,
            PatchTSTForClassification,
            PatchTSTForPrediction,
            PatchTSTForPretraining,
            PatchTSTForRegression,
            PatchTSTModel,
            PatchTSTPreTrainedModel,
        )

# Otherwise, register the current module as a LazyModule
else:
    import sys

    # Register this module in sys.modules via _LazyModule
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
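
# Illustrative note on the lazy-import pattern above (hypothetical usage): the
# heavy `modeling_patchtst` module (and its torch dependency) is only imported
# when one of its attributes is first accessed through the lazy module.
from transformers.models import patchtst

config_cls = patchtst.PatchTSTConfig  # loads configuration_patchtst only
model_cls = patchtst.PatchTSTModel    # first access triggers the modeling_patchtst import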

.\models\pegasus\configuration_pegasus.py

# File encoding: UTF-8
# Copyright notice (holder and year)
# Licensed under the Apache License, Version 2.0; you may use this file only in
# compliance with the License. A copy of the License is available at:
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, this software is
# distributed on an "AS IS" basis, without warranties or conditions of any kind,
# either express or implied. See the License for more information.
""" PEGASUS model configuration"""

# Import the PretrainedConfig base class via a relative path
from ...configuration_utils import PretrainedConfig
# Import the logging utilities
from ...utils import logging

# Get the logger for this module
logger = logging.get_logger(__name__)

# Config archive map for pretrained PEGASUS models: model name -> config JSON URL
PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/pegasus-large": "https://huggingface.co/google/pegasus-large/resolve/main/config.json",
    # See all PEGASUS models at https://huggingface.co/models?filter=pegasus
}

# PegasusConfig stores the configuration of a PEGASUS model and inherits from PretrainedConfig
class PegasusConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PegasusModel`]. It is used to instantiate an
    PEGASUS model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the PEGASUS
    [google/pegasus-large](https://huggingface.co/google/pegasus-large) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Example:

    ```
    >>> from transformers import PegasusConfig, PegasusModel

    >>> # Initializing a PEGASUS google/pegasus-large style configuration
    >>> configuration = PegasusConfig()

    >>> # Initializing a model (with random weights) from the google/pegasus-large style configuration
    >>> model = PegasusModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    # Model type identifier
    model_type = "pegasus"
    # Keys to ignore at inference time
    keys_to_ignore_at_inference = ["past_key_values"]
    # Attribute map: generic config names mapped to the actual model attribute names
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
    # Initializer: creates a new Pegasus configuration instance
    def __init__(
        self,
        vocab_size=50265,                            # vocabulary size, default 50265
        max_position_embeddings=1024,                # maximum number of position embeddings, default 1024
        encoder_layers=12,                           # number of encoder layers, default 12
        encoder_ffn_dim=4096,                        # encoder feed-forward dimension, default 4096
        encoder_attention_heads=16,                  # number of encoder attention heads, default 16
        decoder_layers=12,                           # number of decoder layers, default 12
        decoder_ffn_dim=4096,                        # decoder feed-forward dimension, default 4096
        decoder_attention_heads=16,                  # number of decoder attention heads, default 16
        encoder_layerdrop=0.0,                       # encoder LayerDrop probability, default 0.0 (disabled)
        decoder_layerdrop=0.0,                       # decoder LayerDrop probability, default 0.0 (disabled)
        use_cache=True,                              # whether to use the cache, default True
        is_encoder_decoder=True,                     # whether this is an encoder-decoder architecture, default True
        activation_function="gelu",                  # activation function, default "gelu"
        d_model=1024,                                # model dimension, default 1024
        dropout=0.1,                                 # global dropout rate, default 0.1
        attention_dropout=0.0,                       # attention dropout rate, default 0.0 (disabled)
        activation_dropout=0.0,                      # activation dropout rate, default 0.0 (disabled)
        init_std=0.02,                               # standard deviation for weight initialization, default 0.02
        decoder_start_token_id=0,                    # decoder start token id, default 0
        scale_embedding=False,                       # whether to scale embeddings, default False
        pad_token_id=0,                              # padding token id, default 0
        eos_token_id=1,                              # end-of-sequence token id, default 1
        forced_eos_token_id=1,                       # forced end-of-sequence token id, default 1
        **kwargs,                                    # extra keyword arguments forwarded to the parent initializer
    ):
        self.vocab_size = vocab_size                  # vocabulary size
        self.max_position_embeddings = max_position_embeddings  # maximum number of position embeddings
        self.d_model = d_model                        # model dimension
        self.encoder_ffn_dim = encoder_ffn_dim        # encoder feed-forward dimension
        self.encoder_layers = encoder_layers          # number of encoder layers
        self.encoder_attention_heads = encoder_attention_heads  # number of encoder attention heads
        self.decoder_ffn_dim = decoder_ffn_dim        # decoder feed-forward dimension
        self.decoder_layers = decoder_layers          # number of decoder layers
        self.decoder_attention_heads = decoder_attention_heads  # number of decoder attention heads
        self.dropout = dropout                        # global dropout rate
        self.attention_dropout = attention_dropout    # attention dropout rate
        self.activation_dropout = activation_dropout  # activation dropout rate
        self.activation_function = activation_function  # activation function
        self.init_std = init_std                      # weight-initialization standard deviation
        self.encoder_layerdrop = encoder_layerdrop    # encoder LayerDrop probability
        self.decoder_layerdrop = decoder_layerdrop    # decoder LayerDrop probability
        self.use_cache = use_cache                    # whether to use the cache
        self.num_hidden_layers = encoder_layers       # number of hidden layers, set to the encoder layer count
        self.scale_embedding = scale_embedding        # if True, embeddings are scaled by sqrt(d_model)
        super().__init__(                             # forward the remaining arguments to the parent initializer
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )

    @property
    def num_attention_heads(self) -> int:
        return self.encoder_attention_heads          # number of encoder attention heads

    @property
    def hidden_size(self) -> int:
        return self.d_model                          # model dimension
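
# Quick demonstration of the `attribute_map` and the properties above: the generic
# attribute names resolve to the Pegasus-specific configuration fields.
from transformers import PegasusConfig

config = PegasusConfig()
assert config.hidden_size == config.d_model  # 1024 by default
assert config.num_attention_heads == config.encoder_attention_heads  # 16 by default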

.\models\pegasus\convert_pegasus_tf_to_pytorch.py

# Global patterns used to convert TF state-dict keys to the PyTorch equivalents
PATTERNS = [
    # Replace the left string with the right one to obtain the same state-dict keys as BART (shared with Pegasus)
    ["memory_attention", "encoder_attn"],
    ["attention", "attn"],
    ["/", "."],
    [".LayerNorm.gamma", "_layer_norm.weight"],
    [".LayerNorm.beta", "_layer_norm.bias"],
    ["r.layer_", "r.layers."],
    ["output_proj", "out_proj"],
    ["ffn.dense_1.", "fc2."],
    ["ffn.dense.", "fc1."],
    ["ffn_layer_norm", "final_layer_norm"],
    ["kernel", "weight"],
    ["encoder_layer_norm.", "encoder.layer_norm."],
    ["decoder_layer_norm.", "decoder.layer_norm."],
    ["embeddings.weights", "shared.weight"],
]

# Rename a state-dict key by applying the conversion patterns in order
def rename_state_dict_key(k):
    for pegasus_name, hf_name in PATTERNS:
        k = k.replace(pegasus_name, hf_name)
    return k
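
# Worked example (hypothetical TF key) of applying PATTERNS in order:
#   "encoder/memory_attention/output_proj/kernel"
#   -> "encoder/encoder_attn/output_proj/kernel"  (memory_attention -> encoder_attn)
#   -> "encoder.encoder_attn.output_proj.kernel"  ("/" -> ".")
#   -> "encoder.encoder_attn.out_proj.kernel"     (output_proj -> out_proj)
#   -> "encoder.encoder_attn.out_proj.weight"     (kernel -> weight)
assert rename_state_dict_key("encoder/memory_attention/output_proj/kernel") == "encoder.encoder_attn.out_proj.weight"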


# Convert TF weights into a Pegasus PyTorch model
def convert_pegasus(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
    # Copy the default config and apply the requested updates
    cfg_kwargs = DEFAULTS.copy()
    cfg_kwargs.update(cfg_updates)
    # Build the Pegasus configuration
    cfg = PegasusConfig(**cfg_kwargs)
    # Build the PegasusForConditionalGeneration model
    torch_model = PegasusForConditionalGeneration(cfg)
    # Grab the PyTorch model's state dict
    sd = torch_model.model.state_dict()
    # Dictionary holding the key-name mapping
    mapping = {}
    # Iterate over every item in the TF weight dictionary
    for k, v in tf_weights.items():
        # Rename the key according to the conversion patterns
        new_k = rename_state_dict_key(k)
        # If the renamed key is missing from the PyTorch state dict, raise an error
        if new_k not in sd:
            raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")

        # Transpose weight matrices whose key contains "dense" or "proj"
        if "dense" in k or "proj" in new_k:
            v = v.T
        # Convert the TF weight into a PyTorch tensor and store it in the mapping
        mapping[new_k] = torch.tensor(v, dtype=sd[new_k].dtype)
        # Assert that the TF weight shape matches the corresponding PyTorch entry
        assert v.shape == sd[new_k].shape, f"{new_k}, {k}, {v.shape}, {sd[new_k].shape}"

    # Make sure the embedding's padding_idx row is zeroed out
    mapping["shared.weight"][cfg.pad_token_id] = torch.zeros_like(mapping["shared.weight"][cfg.pad_token_id + 1])
    # Share the mapped embedding weights with the encoder and decoder embedding layers
    mapping["encoder.embed_tokens.weight"] = mapping["shared.weight"]
    mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"]
    # Create zero biases for any bias keys not yet mapped and add them to the mapping
    empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping}
    mapping.update(**empty_biases)
    # Load the state dict into the model, allowing partial matches
    missing, extra = torch_model.model.load_state_dict(mapping, strict=False)

    # Collect missing keys that are not in the expected-exceptions list
    unexpected_missing = [
        k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
    ]

    # If any unexpected keys are missing, fail with the unmatched torch keys
    assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"

    # If there are extra keys, fail with the unmatched tf keys
    assert extra == [], f"no matches found for the following tf keys {extra}"

    # Return the populated torch model
    return torch_model


# Required libraries and modules (per the original annotation, the two imports at
# the bottom are assumed to come from elsewhere; note that `convert_pegasus` is in
# fact defined earlier in this file)
import argparse
import os  # used below for os.path.join
from pathlib import Path
from typing import Dict

import tensorflow as tf
import torch  # used below by torch.save
from tqdm import tqdm

from transformers import PegasusTokenizer
from utils import task_specific_params  # assumed: external source of task-specific parameters
from convert_pegasus import convert_pegasus  # assumed: external model-conversion function

# Read the weights from a TensorFlow checkpoint and return them as a dict
def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict:
    # List the variable names and shapes in the TF checkpoint
    init_vars = tf.train.list_variables(path)
    # Empty dict to hold the TF weights
    tf_weights = {}
    # Variable names to ignore
    ignore_name = ["Adafactor", "global_step"]
    # Iterate over the (name, shape) pairs
    for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
        # Skip any variable whose name contains an ignored keyword
        skip_key = any(pat in name for pat in ignore_name)
        if skip_key:
            continue
        # Load the variable's value from the TF checkpoint
        array = tf.train.load_variable(path, name)
        # Store the name/value pair in the dict
        tf_weights[name] = array
    # Return the dict of TF weights
    return tf_weights


# Convert a Pegasus model from TensorFlow to PyTorch and save it
def convert_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str):
    # Save the tokenizer first
    dataset = Path(ckpt_path).parent.name
    # Look up the maximum model length for this task
    desired_max_model_length = task_specific_params[f"summarization_{dataset}"]["max_position_embeddings"]
    # Build the tokenizer from the pretrained model with the desired maximum length
    tok = PegasusTokenizer.from_pretrained("sshleifer/pegasus", model_max_length=desired_max_model_length)
    # Confirm the tokenizer's maximum length is as expected
    assert tok.model_max_length == desired_max_model_length
    # Save the tokenizer to the target directory
    tok.save_pretrained(save_dir)

    # Convert the model
    tf_weights = get_tf_weights_as_numpy(ckpt_path)
    # Fetch the task-specific config updates
    cfg_updates = task_specific_params[f"summarization_{dataset}"]
    # For the "large" dataset, attach the task-specific parameters to the config updates
    if dataset == "large":
        cfg_updates["task_specific_params"] = task_specific_params
    # Convert the TF model into a PyTorch model
    torch_model = convert_pegasus(tf_weights, cfg_updates)
    # Save the PyTorch model to the target directory
    torch_model.save_pretrained(save_dir)
    # Get the PyTorch model's state dict
    sd = torch_model.state_dict()
    # Drop the position-embedding weights from the state dict
    sd.pop("model.decoder.embed_positions.weight")
    sd.pop("model.encoder.embed_positions.weight")
    # Save the trimmed state dict as a binary file
    torch.save(sd, Path(save_dir) / "pytorch_model.bin")


if __name__ == "__main__":
    # Parse the command-line arguments
    parser = argparse.ArgumentParser()
    # Required argument: path to the TensorFlow checkpoint
    parser.add_argument("tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
    # Optional argument: output path for the PyTorch model (nargs="?" makes the
    # positional optional, so the default-path branch below is reachable)
    parser.add_argument("save_dir", nargs="?", default=None, type=str, help="Path to the output PyTorch model.")
    args = parser.parse_args()

    # If no save path was given, derive a default from the TF checkpoint path
    if args.save_dir is None:
        dataset = Path(args.tf_ckpt_path).parent.name
        args.save_dir = os.path.join("pegasus", dataset)

    # Convert the Pegasus model from TensorFlow to PyTorch and save it
    convert_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir)
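
# Example invocation (hypothetical paths):
#   python convert_pegasus_tf_to_pytorch.py ./ckpt/aeslc/model.ckpt-32000 ./pegasus/aeslc
# With save_dir omitted, the output defaults to pegasus/<dataset>, where <dataset>
# is the name of the checkpoint's parent directory ("aeslc" in this example).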

.\models\pegasus\modeling_flax_pegasus.py

# Import the required libraries and modules
import math  # math utilities
import random  # random number generation
from functools import partial  # partial function application
from typing import Callable, Optional, Tuple  # type hints

import flax.linen as nn  # Flax Linen module for defining network layers
import jax  # JAX, for autodiff and parallel computation
import jax.numpy as jnp  # JAX's NumPy-compatible API
import numpy as np  # NumPy
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # frozen-dict helpers
from flax.linen import combine_masks, make_causal_mask  # mask combination and causal-mask creation
from flax.linen.attention import dot_product_attention_weights  # dot-product attention weights
from flax.traverse_util import flatten_dict, unflatten_dict  # dict flatten/unflatten utilities
from jax import lax  # low-level JAX API for control flow and parallelism
from jax.random import PRNGKey  # JAX PRNG key

from ...modeling_flax_outputs import (  # Flax output classes
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
    FlaxSeq2SeqModelOutput,
)
from ...modeling_flax_utils import (  # Flax model utilities
    ACT2FN,
    FlaxPreTrainedModel,
    add_start_docstrings_to_model_forward,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import add_start_docstrings, logging, replace_return_docstrings  # utilities and logging
from .configuration_pegasus import PegasusConfig  # Pegasus configuration class

logger = logging.get_logger(__name__)  # logger for this module

_CHECKPOINT_FOR_DOC = "google/pegasus-large"  # pretrained checkpoint used in the docs
_CONFIG_FOR_DOC = "PegasusConfig"  # configuration class used in the docs

PEGASUS_START_DOCSTRING = r"""
    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
"""
    # 定义函数参数说明
    Parameters:
        config ([`PegasusConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

PEGASUS_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)
        decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
            paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
            range `[0, config.max_position_embeddings - 1]`.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


PEGASUS_ENCODE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens to be encoded. These indices are obtained using a tokenizer, typically
            from a list of input strings. Each index corresponds to a token in the vocabulary.
        
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. This mask tensor has shape `(batch_size, 
            sequence_length)`, where each value is either 1 (token is not masked) or 0 (token is masked, typically 
            because it's padding).
        
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings matrix. These indices range 
            from 0 to `config.max_position_embeddings - 1`.
        
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. If `True`, the returned output 
            will include attention tensors from all layers of the model.
        
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. If `True`, the returned output will include 
            hidden states from all layers of the model.
        
        return_dict (`bool`, *optional*):
            Whether or not to return a `utils.ModelOutput` object instead of a plain tuple. If `True`, the output 
            will be encapsulated in a structured object that includes additional metadata.
"""
# PEGASUS_DECODE_INPUTS_DOCSTRING is a raw string documenting the inputs of the Pegasus decode function
PEGASUS_DECODE_INPUTS_DOCSTRING = r"""
    Args:
        decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)
        encoder_outputs (`tuple(tuple(jnp.ndarray)`):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
            paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
        decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
            range `[0, config.max_position_embeddings - 1]`.
        past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for
            fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Shift input token IDs one position to the right
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.
    """
    # Create an all-zero array with the same shape as input_ids
    shifted_input_ids = jnp.zeros_like(input_ids)
    # Copy each row of input_ids, minus the last column, one slot to the right
    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
    # Set the first column of every row to decoder_start_token_id
    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

    # Replace any -100 sentinel positions with pad_token_id
    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids
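
# Worked example (hypothetical ids, pad_token_id=0, decoder_start_token_id=2):
# labels [[5, 6, -100, 1]] become decoder inputs [[2, 5, 6, 0]]: everything moves
# one slot right, the start token fills position 0, and the -100 sentinel becomes the pad id.
shifted = shift_tokens_right(jnp.array([[5, 6, -100, 1]]), pad_token_id=0, decoder_start_token_id=2)
# shifted == [[2, 5, 6, 0]]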


# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions
# Build a sinusoidal position-encoding matrix
def create_sinusoidal_positions(n_pos, dim):
    # Generate the raw position/frequency grid
    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
    sentinel = dim // 2 + dim % 2
    out = np.zeros_like(position_enc)
    # First half of the columns: sine of the even-indexed frequencies
    out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
    # Second half of the columns: cosine of the odd-indexed frequencies
    out[:, sentinel:] = np.cos(position_enc[:, 1::2])

    return jnp.array(out)
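
# Quick shape/content check of the sinusoidal table above: for n_pos=4, dim=6 the
# result is a (4, 6) array whose first 3 columns hold sines and last 3 hold cosines
# of the scaled positions, so row 0 is [0, 0, 0, 1, 1, 1].
table = create_sinusoidal_positions(4, 6)
assert table.shape == (4, 6)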


# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Pegasus
# Pegasus attention module
class FlaxPegasusAttention(nn.Module):
    config: PegasusConfig
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    causal: bool = False
    bias: bool = True
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self) -> None:
        # Per-head dimension
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Dense-layer factory with normal initialization
        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        # Separate dense layers for query, key, value, and output
        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        # Dropout layer
        self.dropout_layer = nn.Dropout(rate=self.dropout)

        # If causal, build a causal mask
        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    # Split hidden states into (num_heads, head_dim)
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    # Merge the heads back into the original embedding shape
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # Detect initialization by checking for existing cache data
        is_initialized = self.has_variable("cache", "cached_key")
        # Get or initialize the cached key and value
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        # Get or initialize the cache index tracking the current cache position
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            # Extract the cache shape: batch dims, max length, number of heads, depth per head
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # Update the key and value caches with the new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            # Advance the cache index by the number of updated cache vectors
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # Causal mask for the cached decoder self-attention: the current query position may
            # only attend to already generated and cached key positions, not the remaining zeros.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)

        # Return the updated key, value, and attention mask
        return key, value, attention_mask
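
# Standalone illustration (hypothetical shapes) of the cache update performed above
# with lax.dynamic_update_slice: a one-token key slice is written into the
# preallocated cache buffer at the current decoding index.
import jax.numpy as jnp
from jax import lax

max_length, num_heads, head_dim = 8, 2, 4
cache = jnp.zeros((1, max_length, num_heads, head_dim))  # [batch, max_len, heads, depth]
new_key = jnp.ones((1, 1, num_heads, head_dim))          # projected key for one decoding step
cache = lax.dynamic_update_slice(cache, new_key, (0, 3, 0, 0))  # write at index 3
# cache[0, 3] is now all ones; every other position stays zero.
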
# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Pegasus
class FlaxPegasusEncoderLayer(nn.Module):
    # Pegasus model configuration
    config: PegasusConfig
    # dtype of the computation, defaults to 32-bit floats
    dtype: jnp.dtype = jnp.float32

    # Initialize the encoder layer's components
    def setup(self) -> None:
        # Embedding dimension equals d_model in the configuration
        self.embed_dim = self.config.d_model
        # Pegasus self-attention
        self.self_attn = FlaxPegasusAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        # LayerNorm applied before self-attention
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # Dropout applied to the self-attention output
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        # Activation function selected from the configuration
        self.activation_fn = ACT2FN[self.config.activation_function]
        # Dropout applied after the activation
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
        # First dense layer of the feed-forward network
        self.fc1 = nn.Dense(
            self.config.encoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # Second dense layer, projecting back to the embedding dimension
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        # Final LayerNorm, applied before the feed-forward network
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    # Forward computation of the encoder layer
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        # Keep the input for the residual connection
        residual = hidden_states
        # Pre-LayerNorm before self-attention
        hidden_states = self.self_attn_layer_norm(hidden_states)
        # Run self-attention, returning hidden states and attention weights
        hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
        # Apply dropout to reduce overfitting
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # Add the residual connection
        hidden_states = residual + hidden_states

        # Keep the input for the residual connection
        residual = hidden_states
        # Pre-LayerNorm before the feed-forward network
        hidden_states = self.final_layer_norm(hidden_states)
        # First dense layer followed by the activation function
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        # Dropout after the activation
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        # Second dense layer
        hidden_states = self.fc2(hidden_states)
        # Apply dropout to reduce overfitting
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # Add the residual connection
        hidden_states = residual + hidden_states

        # Build the output tuple with the final hidden states
        outputs = (hidden_states,)

        # If requested, also return the attention weights
        if output_attentions:
            outputs += (attn_weights,)

        return outputs


# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->Pegasus
class FlaxPegasusEncoderLayerCollection(nn.Module):
    # Pegasus model configuration
    config: PegasusConfig
    # dtype of the computation, defaults to 32-bit floats
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    # Initialize the collection of encoder layers
    def setup(self):
        # One FlaxPegasusEncoderLayer instance per encoder layer
        self.layers = [
            FlaxPegasusEncoderLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.encoder_layers)
        ]
        # Encoder LayerDrop probability, from encoder_layerdrop in the config
        self.layerdrop = self.config.encoder_layerdrop
    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # If attention weights are requested, initialize an empty tuple to collect them
        all_attentions = () if output_attentions else None
        # If hidden states are requested, initialize an empty tuple to collect them
        all_hidden_states = () if output_hidden_states else None

        # Iterate over every encoder layer
        for encoder_layer in self.layers:
            # If hidden states are requested, record the current hidden states
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # LayerDrop: decide whether to skip this layer
            dropout_probability = random.uniform(0, 1)
            if not deterministic and (dropout_probability < self.layerdrop):  # skip the layer
                layer_outputs = (None, None)
            else:
                # Run the encoder layer's forward pass
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                    deterministic,
                )
            # The new hidden states are the first element of the layer outputs
            hidden_states = layer_outputs[0]
            # If attention weights are requested, record this layer's attention weights
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # If hidden states are requested, record the final hidden states
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Assemble the outputs; the return format depends on return_dict
        outputs = (hidden_states, all_hidden_states, all_attentions)

        # If return_dict is False, return the non-None entries as a tuple
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        # Otherwise, return a FlaxBaseModelOutput
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Pegasus
class FlaxPegasusDecoderLayer(nn.Module):
    # Pegasus model configuration
    config: PegasusConfig
    # dtype of the computation
    dtype: jnp.dtype = jnp.float32

    # Initialize the decoder layer's components
    def setup(self) -> None:
        # Embedding dimension equals d_model in the configuration
        self.embed_dim = self.config.d_model
        # Causal self-attention
        self.self_attn = FlaxPegasusAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            dtype=self.dtype,
        )
        # Dropout layer
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        # Activation function from the configuration, plus its dropout layer
        self.activation_fn = ACT2FN[self.config.activation_function]
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)

        # LayerNorm for the self-attention block
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # Cross-attention over the encoder outputs
        self.encoder_attn = FlaxPegasusAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        # LayerNorm for the cross-attention block
        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # First dense layer of the feed-forward network
        self.fc1 = nn.Dense(
            self.config.decoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # Second dense layer of the feed-forward network
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        # Final LayerNorm
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    # Forward computation of the decoder layer
    def __call__(
        self,
        hidden_states: jnp.ndarray,  # input hidden states
        attention_mask: jnp.ndarray,  # attention mask
        encoder_hidden_states: Optional[jnp.ndarray] = None,  # encoder hidden states (optional)
        encoder_attention_mask: Optional[jnp.ndarray] = None,  # encoder attention mask (optional)
        init_cache: bool = False,  # whether to initialize the cache
        output_attentions: bool = True,  # whether to return attention weights
        deterministic: bool = True,  # whether to run deterministically (no dropout)
    ) -> Tuple[jnp.ndarray]:
        # Keep the input hidden_states for the residual connection
        residual = hidden_states
        # Layer-normalize the hidden states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self attention
        # Run self-attention, including attention computation and optional cache initialization
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
        )
        # Apply dropout
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # Add the residual connection
        hidden_states = residual + hidden_states

        # Cross-attention block
        cross_attn_weights = None
        # Only runs when encoder_hidden_states are provided
        if encoder_hidden_states is not None:
            # Keep the input hidden_states for the residual connection
            residual = hidden_states

            # Layer-normalize the hidden states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)
            # Attend over the encoder hidden states, applying the encoder attention mask
            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            # Apply dropout
            hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
            # Add the residual connection
            hidden_states = residual + hidden_states

        # Fully connected block
        # Keep the input hidden_states for the residual connection
        residual = hidden_states
        # Layer-normalize the hidden states
        hidden_states = self.final_layer_norm(hidden_states)
        # First dense layer followed by the activation function
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        # Dropout after the activation
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        # Second dense layer
        hidden_states = self.fc2(hidden_states)
        # Apply dropout
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # Add the residual connection
        hidden_states = residual + hidden_states

        # Prepare the outputs, starting with the hidden states
        outputs = (hidden_states,)

        # If requested, also return the self- and cross-attention weights
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->Pegasus
class FlaxPegasusDecoderLayerCollection(nn.Module):
    # Pegasus model configuration
    config: PegasusConfig
    # dtype of the computation
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # One FlaxPegasusDecoderLayer instance per decoder layer
        self.layers = [
            FlaxPegasusDecoderLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.decoder_layers)
        ]
        # Decoder LayerDrop probability
        self.layerdrop = self.config.decoder_layerdrop

    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 存储所有隐藏状态(如果需要返回)
        all_hidden_states = () if output_hidden_states else None
        # 存储所有自注意力权重(如果需要返回)
        all_self_attns = () if output_attentions else None
        # 存储所有跨注意力权重(如果需要返回),仅在同时输出注意力并且存在编码器隐藏状态时才存储
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # 对每个解码器层进行迭代
        for decoder_layer in self.layers:
            if output_hidden_states:
                # 如果需要输出隐藏状态,则将当前隐藏状态添加到存储中
                all_hidden_states += (hidden_states,)
                # 添加LayerDrop(参见https://arxiv.org/abs/1909.11556进行描述)

            # 随机采样一个Dropout概率
            dropout_probability = random.uniform(0, 1)
            # 如果是非确定性计算并且随机概率小于LayerDrop阈值,则将输出设为None
            if not deterministic and (dropout_probability < self.layerdrop):
                layer_outputs = (None, None, None)
            else:
                # 否则,调用当前解码器层进行前向传播计算
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    init_cache=init_cache,
                    output_attentions=output_attentions,
                    deterministic=deterministic,
                )

            # 更新隐藏状态为当前解码器层的输出的第一个元素
            hidden_states = layer_outputs[0]
            # 如果需要输出注意力权重,则将当前解码器层的自注意力权重添加到存储中
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                # 如果存在编码器隐藏状态,则将当前解码器层的跨注意力权重添加到存储中
                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # 如果需要输出最终隐藏状态,则将最终隐藏状态添加到存储中
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 将所有需要输出的结果存储在outputs列表中
        outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]

        # 如果不需要以字典形式返回结果,则返回输出列表中的非None元素
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        # 以包含过去和跨注意力权重的形式返回FlaxBaseModelOutputWithPastAndCrossAttentions对象
        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
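LayerDrop (Fan et al., 2019) stochastically skips whole decoder layers during training. Note that the collection samples with Python's `random.uniform`, so the skip decision happens at trace time rather than inside the compiled graph. A minimal sketch of the sampling logic, with hypothetical layer count and rate (the real values come from `PegasusConfig`):

```python
import random

n_layers, layerdrop = 12, 0.1  # hypothetical values for illustration

def run_layers(hidden_states, deterministic):
    kept = 0
    for i in range(n_layers):
        # Skip the whole layer with probability `layerdrop`, during training only
        if not deterministic and random.uniform(0, 1) < layerdrop:
            continue  # hidden_states flows through unchanged
        hidden_states = hidden_states + 1  # stand-in for decoder_layer(hidden_states)
        kept += 1
    return hidden_states, kept

out, kept = run_layers(0, deterministic=False)
print(f"ran {kept}/{n_layers} layers")  # on average ~10.8 of 12 layers
```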


class FlaxPegasusEncoder(nn.Module):
    # Pegasus model configuration
    config: PegasusConfig
    # token embedding layer
    embed_tokens: nn.Embed
    # the dtype of the computation
    dtype: jnp.dtype = jnp.float32  # computation dtype (32-bit float)

    def setup(self):
        # Create the dropout layer from the dropout rate in the config
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        # Embedding dimension from the config
        embed_dim = self.config.d_model
        # Padding token index from the config
        self.padding_idx = self.config.pad_token_id
        # Maximum number of source positions from the config
        self.max_source_positions = self.config.max_position_embeddings
        # If the config enables embedding scaling, use sqrt(embed_dim) as the scale factor, otherwise 1.0
        self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

        # Create the sinusoidal position table and store it in self.embed_positions
        self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
        # Create the FlaxPegasusEncoderLayerCollection that runs the encoder layers
        self.layers = FlaxPegasusEncoderLayerCollection(self.config, self.dtype)
        # Create the LayerNorm used to normalize the final hidden states
        self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    # Run the encoder over the inputs when the module is called
    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        # Record the input shape
        input_shape = input_ids.shape
        # Reshape input_ids into a 2D tensor
        input_ids = input_ids.reshape(-1, input_shape[-1])

        # Look up the token embeddings and multiply by the scale factor
        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        # Look up the position embeddings
        embed_pos = jnp.take(self.embed_positions, position_ids, axis=0)
        # Explicitly cast the position embeddings to the dtype of inputs_embeds
        embed_pos = embed_pos.astype(inputs_embeds.dtype)

        # Sum token and position embeddings to form the hidden states
        hidden_states = inputs_embeds + embed_pos
        # Apply dropout to the hidden states
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # Run the hidden states through the encoder layers
        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Take the last hidden state from the encoder layer outputs
        last_hidden_state = outputs[0]
        # Apply LayerNorm to the last hidden state
        last_hidden_state = self.layer_norm(last_hidden_state)

        # When all hidden states are requested, update the stored tuple so that
        # its final entry reflects the normalized output
        hidden_states = None
        if output_hidden_states:
            hidden_states = outputs[1]
            hidden_states = hidden_states[:-1] + (last_hidden_state,)

        # When not returning a dict, reassemble the output tuple accordingly
        if not return_dict:
            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
            return tuple(v for v in outputs if v is not None)

        # Return a FlaxBaseModelOutput with the last hidden state, all hidden
        # states, and the attention weights (if any)
        return FlaxBaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
            attentions=outputs.attentions,
        )
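Pegasus uses fixed (non-learned) sinusoidal position embeddings: the table is built once in `setup` and indexed with `jnp.take`. Below is a hedged sketch of a `create_sinusoidal_positions` helper consistent with how it is used above, placing sines in the first half of the feature dimension and cosines in the second; the library's exact layout may differ:

```python
import numpy as np
import jax.numpy as jnp

def create_sinusoidal_positions(n_pos: int, dim: int) -> jnp.ndarray:
    """Build a fixed (n_pos, dim) position table."""
    # angle(pos, j) = pos / 10000^(2*(j//2)/dim), the standard Transformer schedule
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    sentinel = dim // 2 + dim % 2
    out = np.zeros_like(position_enc)
    out[:, :sentinel] = np.sin(position_enc[:, 0::2])  # sines in the first half
    out[:, sentinel:] = np.cos(position_enc[:, 1::2])  # cosines in the second half
    return jnp.array(out)

table = create_sinusoidal_positions(64, 16)
print(table.shape)  # (64, 16)
```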
class FlaxPegasusDecoder(nn.Module):
    config: PegasusConfig  # Pegasus model configuration
    embed_tokens: nn.Embed  # Embedding tokens for input sequence
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)  # Dropout layer initialization

        embed_dim = self.config.d_model  # Dimension of the embedding
        self.padding_idx = self.config.pad_token_id  # Padding token index from configuration
        self.max_target_positions = self.config.max_position_embeddings  # Maximum target positions
        self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0  # Embedding scale factor

        self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
        # Create sinusoidal positional embeddings

        self.layers = FlaxPegasusDecoderLayerCollection(self.config, self.dtype)
        # Layers of the Pegasus decoder
        self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # Layer normalization initialization

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])
        # Reshape input_ids to flatten the sequence dimensions

        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        # Embed input tokens and scale embeddings

        # embed positions
        positions = jnp.take(self.embed_positions, position_ids, axis=0)
        # Retrieve positional embeddings based on position_ids
        positions = positions.astype(inputs_embeds.dtype)
        # Explicitly cast positions to match inputs_embeds dtype

        hidden_states = inputs_embeds + positions
        # Combine token embeddings with positional embeddings
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # Apply dropout to hidden states

        outputs = self.layers(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Pass hidden states through decoder layers

        last_hidden_state = outputs[0]
        # Retrieve the last hidden state from the outputs
        last_hidden_state = self.layer_norm(last_hidden_state)
        # Apply layer normalization to the last hidden state

        hidden_states = None
        if output_hidden_states:
            hidden_states = outputs[1]
            hidden_states = hidden_states[:-1] + (last_hidden_state,)
            # Concatenate previous hidden states with the current one

        if not return_dict:
            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
            return tuple(v for v in outputs if v is not None)
            # Return outputs as a tuple without the return_dict format

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
        # Return outputs as FlaxBaseModelOutputWithPastAndCrossAttentions object
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->Pegasus
class FlaxPegasusModule(nn.Module):
    config: PegasusConfig  # Pegasus model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Create the embedding layer shared between the encoder and the decoder
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),  # initialize the embeddings from a normal distribution
            dtype=self.dtype,
        )

        # Instantiate the encoder and the decoder
        self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
        self.decoder = FlaxPegasusDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

    def _get_encoder_module(self):
        return self.encoder  # return the encoder module

    def _get_decoder_module(self):
        return self.decoder  # return the decoder module

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        # Run the encoder to obtain the encoder outputs
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        # Run the decoder, feeding it the encoder hidden states and the encoder attention mask
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs  # without a dict, concatenate all outputs

        # Return the sequence-to-sequence model output
        return FlaxSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
    config_class = PegasusConfig  # configuration class for the Pegasus pretrained model
    base_model_prefix: str = "model"  # prefix under which the base model is stored
    module_class: nn.Module = None

    def __init__(
        self,
        config: PegasusConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # Instantiate the module with the config and dtype
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # Call the parent constructor with the config, module, input shape, seed, dtype, and init flag
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # Initialize the input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # Build an all-ones attention mask with the same shape as input_ids
        attention_mask = jnp.ones_like(input_ids)
        # Initialize decoder_input_ids to the same tensor as input_ids
        decoder_input_ids = input_ids
        # Build an all-ones decoder attention mask with the same shape as input_ids
        decoder_attention_mask = jnp.ones_like(input_ids)

        # Read off batch_size and sequence_length
        batch_size, sequence_length = input_ids.shape
        # Build position ids by broadcasting arange(sequence_length) over the batch
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
        # The decoder position ids are built the same way as the encoder's
        decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Split the rng into params_rng and dropout_rng
        params_rng, dropout_rng = jax.random.split(rng)
        # Assemble the rng dict containing params_rng and dropout_rng
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # Initialize the model parameters with the module's init method
        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            position_ids,
            decoder_position_ids,
        )["params"]

        # If initial params were provided, use them and fill in any missing
        # entries from the randomly initialized parameters
        if params is not None:
            # Flatten and unfreeze both parameter dicts
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            # Copy the missing keys over from the random parameters
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()  # clear the set of missing keys
            # Freeze and return the updated parameter dict
            return freeze(unflatten_dict(params))
        else:
            # Without provided params, return the randomly initialized parameters directly
            return random_params
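The missing-key handling above flattens both parameter trees into tuple-path dictionaries, copies over whatever the checkpoint lacks, and refreezes. A self-contained toy version of that merge (the parameter names here are illustrative stand-ins):

```python
import jax.numpy as jnp
from flax.core.frozen_dict import freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict

# Stand-ins for the fresh init and a checkpoint that is missing one leaf.
random_init = freeze({"lm_head": {"kernel": jnp.ones((2, 2))}, "final_logits_bias": jnp.zeros((1, 2))})
checkpoint = freeze({"lm_head": {"kernel": jnp.full((2, 2), 5.0)}})

random_flat = flatten_dict(unfreeze(random_init))
ckpt_flat = flatten_dict(unfreeze(checkpoint))
missing_keys = set(random_flat) - set(ckpt_flat)  # {("final_logits_bias",)}
for key in missing_keys:
    ckpt_flat[key] = random_flat[key]  # fill the gap with the freshly initialized value
merged = freeze(unflatten_dict(ckpt_flat))
print(merged["final_logits_bias"])  # comes from the random init
```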
    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the
                initialized cache.
            encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
                `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
                `attentions`). `last_hidden_state` has shape `(batch_size, sequence_length, hidden_size)` and is
                the hidden-state sequence from the last layer of the encoder, used in the cross-attention of the
                decoder.

        Initializes the cache, pre-allocating the decoder's cached states.
        """
        # Initialize the decoder input ids as an all-ones array of shape (batch_size, max_length)
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        # Initialize the decoder attention mask as an all-ones array shaped like decoder_input_ids
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        # Initialize the decoder position ids by broadcasting an arange to the shape of decoder_input_ids
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
        )

        # Inner function _decoder_forward that calls the decoder module and returns its result
        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )

        # Initialize the variables with the module's init method, passing the decoder inputs
        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # we only need to call the decoder to initialize the cache
        )
        # Return the cache part of the unfrozen variables
        return unfreeze(init_variables["cache"])
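A hedged usage sketch of the cache handshake: encode once, allocate a cache sized to the generation budget, then pass it to `decode` as `past_key_values`. This mirrors what `generate` does internally; the checkpoint name follows the example docstrings below:

```python
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration

model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")

inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="np")
encoder_outputs = model.encode(**inputs)

# Pre-allocate cache slots for up to 32 decoder positions.
past_key_values = model.init_cache(batch_size=1, max_length=32, encoder_outputs=encoder_outputs)
start = jnp.array([[model.config.decoder_start_token_id]], dtype="i4")
outputs = model.decode(start, encoder_outputs, past_key_values=past_key_values)
print(outputs.logits.shape)  # (1, 1, vocab_size)
```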
    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```
        >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration

        >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, max_length=1024, return_tensors="np")
        >>> encoder_outputs = model.encode(**inputs)
        ```"""
        # Resolve output_attentions from the argument or fall back to the config
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # Resolve output_hidden_states from the argument or fall back to the config
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Resolve return_dict from the argument or fall back to the config
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # If attention_mask is None, build an all-ones mask shaped like input_ids
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        # If position_ids is None, build position ids from the shape of input_ids
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # If dropout_rng was given, register it as the "dropout" rng
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # Inner function _encoder_forward that calls the encoder module's forward method
        def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(input_ids, attention_mask, position_ids, **kwargs)

        # Run the forward pass via self.module.apply
        return self.module.apply(
            {"params": params or self.params},  # use the given params, or the default ones
            input_ids=jnp.array(input_ids, dtype="i4"),  # convert input_ids to a JAX array
            attention_mask=jnp.array(attention_mask, dtype="i4"),  # convert attention_mask to a JAX array
            position_ids=jnp.array(position_ids, dtype="i4"),  # convert position_ids to a JAX array
            output_attentions=output_attentions,  # whether to return attention weights
            output_hidden_states=output_hidden_states,  # whether to return hidden states
            return_dict=return_dict,  # whether to return a dict
            deterministic=not train,  # disable dropout outside of training
            rngs=rngs,  # dict of random number generators
            method=_encoder_forward,  # the method to run
        )

    @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=PegasusConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        ...  # body omitted here
    @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # Resolve output_attentions from the argument or fall back to the config
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        # Resolve output_hidden_states from the argument or fall back to the config
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # Resolve return_dict from the argument or fall back to the config
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Prepare the encoder inputs:
        # if attention_mask is None, build an all-ones mask shaped like input_ids
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # If position_ids is None, build position ids from the shape of input_ids
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Prepare the decoder inputs:
        # if decoder_input_ids is None, shift input_ids one position to the right,
        # filling with the special tokens from the config
        if decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
            )

        # If decoder_attention_mask is None, build an all-ones mask shaped like decoder_input_ids
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        # If decoder_position_ids is None, build position ids from the shape of decoder_input_ids
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any rngs that might be needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
        # If dropout_rng was given, build an rng dict containing it; otherwise an empty dict

        # Run the module's forward pass:
        #   - with the given params, or self.params by default
        #   - converting the inputs to jnp.ndarray ("i4") before passing them in
        #   - forwarding the output_attentions / output_hidden_states / return_dict flags
        #   - running deterministically unless in training mode
        #   - passing along the rngs dict
        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
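`shift_tokens_right` builds the decoder inputs for teacher forcing: every target token moves one slot to the right and the first slot receives `decoder_start_token_id`. A hedged sketch consistent with that behavior (the library version also maps the `-100` label sentinel to the pad id, as mirrored here):

```python
import jax.numpy as jnp

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    # Move every token one position to the right; position 0 becomes the start token.
    shifted = jnp.zeros_like(input_ids)
    shifted = shifted.at[:, 1:].set(input_ids[:, :-1])
    shifted = shifted.at[:, 0].set(decoder_start_token_id)
    # Replace label-ignore markers (-100) with the pad token so they embed cleanly.
    return jnp.where(shifted == -100, pad_token_id, shifted)

labels = jnp.array([[45, 72, 9, 1]])  # 1 = eos
print(shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2))
# [[ 2 45 72  9]]
```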
# The decorator below attaches a docstring to FlaxPegasusModel describing it as the bare
# Pegasus transformer that outputs raw hidden states without any task-specific head on top.
# PEGASUS_START_DOCSTRING holds the shared opening docstring for Pegasus models.
@add_start_docstrings(
    "The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.",
    PEGASUS_START_DOCSTRING,
)
# FlaxPegasusModel inherits from FlaxPegasusPreTrainedModel and is configured by PegasusConfig.
class FlaxPegasusModel(FlaxPegasusPreTrainedModel):
    config: PegasusConfig
    # the dtype of the computation, jnp.float32 by default
    dtype: jnp.dtype = jnp.float32
    # the module class is FlaxPegasusModule
    module_class = FlaxPegasusModule

# Attach an example call docstring to FlaxPegasusModel, using _CHECKPOINT_FOR_DOC and FlaxSeq2SeqModelOutput.
append_call_sample_docstring(FlaxPegasusModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
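A short, hedged usage sketch for the bare model: given only `input_ids`, the `__call__` above derives the decoder inputs via `shift_tokens_right` and returns raw decoder hidden states:

```python
from transformers import AutoTokenizer, FlaxPegasusModel

model = FlaxPegasusModel.from_pretrained("google/pegasus-large")
tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")

inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(input_ids=inputs["input_ids"])  # decoder inputs are derived automatically
print(outputs.last_hidden_state.shape)  # (1, seq_len, d_model)
```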


# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule with Bart->Pegasus
class FlaxPegasusForConditionalGenerationModule(nn.Module):
    config: PegasusConfig
    dtype: jnp.dtype = jnp.float32
    # bias initializer, defaulting to jax.nn.initializers.zeros
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    # Module setup
    def setup(self):
        # Build the FlaxPegasusModule backbone from the config and dtype
        self.model = FlaxPegasusModule(config=self.config, dtype=self.dtype)
        # Build the lm_head as an nn.Dense layer projecting to self.model.shared.num_embeddings
        self.lm_head = nn.Dense(
            self.model.shared.num_embeddings,
            use_bias=False,
            dtype=self.dtype,
            # initialize the kernel from a normal distribution with std self.config.init_std
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # Define the final_logits_bias parameter of shape (1, num_embeddings), initialized with bias_init
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))

    # Return the encoder module
    def _get_encoder_module(self):
        return self.model.encoder

    # Return the decoder module
    def _get_decoder_module(self):
        return self.model.decoder

    # The forward pass of the module
    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        # Run the backbone model
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        # Take the hidden states from the model outputs
        hidden_states = outputs[0]

        # When the config ties the word embeddings, reuse the shared embedding matrix
        if self.config.tie_word_embeddings:
            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
            # Apply the transposed shared embedding as the lm_head kernel to get the LM logits
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            # Otherwise compute the logits with the lm_head directly
            lm_logits = self.lm_head(hidden_states)

        # Add the final logits bias (kept out of the gradient)
        lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))

        # Without a dict, return the full output tuple
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return output

        # Otherwise return a FlaxSeq2SeqLMOutput carrying the full set of outputs
        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
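The weight-tying branch above reuses the shared embedding table as the output projection: an `nn.Dense` kernel has shape `(in_features, out_features)`, so the `(vocab_size, d_model)` embedding is applied transposed. A toy demonstration with hypothetical sizes:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

vocab_size, d_model = 10, 4
embedding = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, d_model))  # stands in for model.shared

lm_head = nn.Dense(vocab_size, use_bias=False)
hidden_states = jnp.ones((1, 3, d_model))

# Tied projection: feed the transposed embedding in as the Dense kernel.
logits = lm_head.apply({"params": {"kernel": embedding.T}}, hidden_states)
print(logits.shape)  # (1, 3, 10)
```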
@add_start_docstrings(
    "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING
)
class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel):
    module_class = FlaxPegasusForConditionalGenerationModule
    dtype: jnp.dtype = jnp.float32

    @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=PegasusConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        deterministic: bool = True,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """
        Decode function for PEGASUS model, generating outputs based on decoder inputs and encoder outputs.

        Args:
            decoder_input_ids: Input IDs for the decoder.
            encoder_outputs: Outputs from the encoder.
            encoder_attention_mask: Optional attention mask for encoder outputs.
            decoder_attention_mask: Optional attention mask for decoder inputs.
            decoder_position_ids: Optional position IDs for the decoder inputs.
            past_key_values: Cached key values from previous decoding steps.
            output_attentions: Whether to output attention weights.
            output_hidden_states: Whether to output hidden states.
            return_dict: Whether to return outputs as a dictionary.
            deterministic: Whether to use deterministic behavior.
            params: Optional parameters for decoding.
            dropout_rng: Random number generator for dropout.

        Returns:
            Model outputs with cross attentions, conforming to PEGASUS configuration.
        """

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        Prepare inputs for generation based on decoder inputs and optional masks.

        Args:
            decoder_input_ids: Input IDs for the decoder.
            max_length: Maximum length of generated outputs.
            attention_mask: Optional attention mask for encoder outputs.
            decoder_attention_mask: Optional attention mask for decoder inputs.
            encoder_outputs: Optional outputs from the encoder.
            **kwargs: Additional keyword arguments.

        Returns:
            Dictionary of inputs formatted for generation.
        """
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyways.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """
        Update inputs for generation based on model outputs and current model arguments.

        Args:
            model_outputs: Outputs from the model.
            model_kwargs: Current model keyword arguments.

        Returns:
            Updated model keyword arguments for generation.
        """
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs
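The two generation hooks above avoid recompilation by fixing the attention mask's shape at `max_length` up front and deriving position ids from the prompt mask's cumulative sum (so padded slots repeat the previous position). A small demonstration of both tricks:

```python
import jax.numpy as jnp
from jax import lax

batch_size, max_length = 1, 8
prompt_mask = jnp.array([[1, 1, 0]], dtype="i4")  # 3-token prompt, last slot padded

# Static all-ones mask, allocated once for the whole generation window...
extended = jnp.ones((batch_size, max_length), dtype="i4")
# ...with the real prompt mask written into its first columns.
extended = lax.dynamic_update_slice(extended, prompt_mask, (0, 0))
print(extended)  # [[1 1 0 1 1 1 1 1]]

position_ids = prompt_mask.cumsum(axis=-1) - 1
print(position_ids)  # [[0 1 1]] -- the padded slot repeats position 1
```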
FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING = """
    Returns:

    Summarization example:

    >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration

    >>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large')
    >>> tokenizer = AutoTokenizer.from_pretrained('google/pegasus-large')

    >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np')

    >>> # generate the summary
    >>> summary_ids = model.generate(inputs['input_ids']).sequences
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))


    Mask filling example:


    >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration

    >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
    >>> TXT = "My friends are <mask> but they eat too many carbs."

    >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
    >>> input_ids = tokenizer([TXT], return_tensors="np")["input_ids"]
    >>> logits = model(input_ids).logits

    >>> # retrieve the index of the masked position
    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
    >>> # softmax over the predicted distribution at that position
    >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
    >>> # take the top-k values and their predicted token ids (k=5 is an illustrative choice)
    >>> values, predictions = jax.lax.top_k(probs, k=5)

    >>> # decode the predicted tokens and return them as a list
    >>> tokenizer.decode(predictions).split()
"""
# Overwrite the __call__ docstring of FlaxPegasusForConditionalGeneration,
# using PEGASUS_INPUTS_DOCSTRING combined with FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING.
overwrite_call_docstring(
    FlaxPegasusForConditionalGeneration, PEGASUS_INPUTS_DOCSTRING + FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING
)

"""
为FlaxPegasusForConditionalGeneration类的返回文档字符串追加内容
使用 FlaxSeq2SeqLMOutput 作为输出类型,使用 _CONFIG_FOR_DOC 作为配置类
"""
append_replace_return_docstrings(
    FlaxPegasusForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)
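To tie the pieces together, here is a hedged, simplified greedy loop showing how `generate` is expected to chain `encode`, `prepare_inputs_for_generation`, `decode`, and `update_inputs_for_generation` (the real loop in `FlaxGenerationMixin` is jit-compiled and batched; the helper name here is illustrative):

```python
import jax.numpy as jnp

def greedy_decode(model, input_ids, max_length=16):
    encoder_outputs = model.encode(input_ids=input_ids)
    batch_size = input_ids.shape[0]
    sequences = jnp.full((batch_size, 1), model.config.decoder_start_token_id, dtype="i4")

    # One-time setup: allocate the cache and the static attention mask.
    model_kwargs = model.prepare_inputs_for_generation(
        sequences, max_length=max_length, encoder_outputs=encoder_outputs
    )
    for _ in range(max_length - 1):
        outputs = model.decode(sequences[:, -1:], **model_kwargs)  # one cached step
        next_token = jnp.argmax(outputs.logits[:, -1], axis=-1).astype("i4")[:, None]
        sequences = jnp.concatenate([sequences, next_token], axis=-1)
        # Per-step bookkeeping: carry the cache forward and advance the position ids.
        model_kwargs = model.update_inputs_for_generation(outputs, model_kwargs)
    return sequences
```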