Transformers Source Code Analysis (89)

Transformers Source Code Analysis (89)

.\models\pix2struct\processing_pix2struct.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Pix2Struct.
"""

from typing import List, Optional, Union

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType


class Pix2StructProcessor(ProcessorMixin):
    r"""
    Constructs a PIX2STRUCT processor which wraps a T5 tokenizer and a PIX2STRUCT image processor into a single
    processor.

    [`Pix2StructProcessor`] offers all the functionalities of [`Pix2StructImageProcessor`] and [`T5TokenizerFast`]. See
    the docstring of [`~Pix2StructProcessor.__call__`] and [`~Pix2StructProcessor.decode`] for more information.

    Args:
        image_processor (`Pix2StructImageProcessor`):
            An instance of [`Pix2StructImageProcessor`]. The image processor is a required input.
        tokenizer (Union[`T5TokenizerFast`, `T5Tokenizer`]):
            An instance of [`T5TokenizerFast`] or [`T5Tokenizer`]. The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "Pix2StructImageProcessor"
    tokenizer_class = ("T5Tokenizer", "T5TokenizerFast")

    def __init__(self, image_processor, tokenizer):
        # Disable token type IDs as they are not used in this processor
        tokenizer.return_token_type_ids = False
        # Initialize the processor with the provided image processor and tokenizer
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        images=None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        max_patches: Optional[int] = 2048,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_token_type_ids: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ):
        """
        Process input images and text into a format suitable for PIX2STRUCT tasks.

        Args:
            images (optional): Input images to process.
            text (Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]): Input text data.
            add_special_tokens (bool): Whether to add special tokens (such as the end-of-sequence token) or not.
            padding (Union[bool, str, PaddingStrategy]): Padding strategy for text sequences.
            truncation (Union[bool, str, TruncationStrategy]): Truncation strategy for text sequences.
            max_length (Optional[int]): Maximum sequence length to enforce.
            max_patches (Optional[int]): Maximum number of patches to consider.
            stride (int): Stride length for patch extraction.
            pad_to_multiple_of (Optional[int]): Pad the sequence length to a multiple of this value.
            return_attention_mask (Optional[bool]): Whether to return attention masks.
            return_overflowing_tokens (bool): Whether to return overflowing tokens.
            return_special_tokens_mask (bool): Whether to return special tokens mask.
            return_offsets_mapping (bool): Whether to return offsets mapping.
            return_token_type_ids (bool): Whether to return token type IDs (not used in this processor).
            return_length (bool): Whether to return sequence length.
            verbose (bool): Whether to print verbose information.
            return_tensors (Optional[Union[str, TensorType]]): Desired tensor type for returned tensors.

        Returns:
            BatchEncoding: Processed inputs in a batch encoding format.

        Notes:
            This method processes both images and text to prepare them for PIX2STRUCT tasks.
            It incorporates functionality from both `Pix2StructImageProcessor` and `T5TokenizerFast`.
        """
        # Implementation of input processing logic goes here
        pass
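        # NOTE: a rough, hypothetical sketch of what the omitted body typically does (not the
        # verbatim upstream implementation):
        #   - raise a ValueError when both `images` and `text` are None;
        #   - when only `text` is given, fall back to `self.tokenizer(...)` and return its encoding;
        #   - otherwise call `self.image_processor(images, max_patches=max_patches, ...)` and, if
        #     `text` is also provided, tokenize it and store the resulting ids/mask as decoder-side
        #     inputs (e.g. "decoder_input_ids" / "decoder_attention_mask") in the returned BatchEncoding.
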
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Pix2StructTokenizerFast's [`~PreTrainedTokenizer.batch_decode`].
        Please refer to the docstring of this method for more information.
        """
        # Forward all arguments to the tokenizer's `batch_decode` method
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Pix2StructTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please
        refer to the docstring of this method for more information.
        """
        # Forward all arguments to the tokenizer's `decode` method
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """
        This property returns a list of unique model input names by combining tokenizer's and image_processor's input names.
        """
        # Get the model input names from the tokenizer and the image processor
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        # Remove duplicates while preserving order, then return as a list
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
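
To make the division of labor concrete, here is a short, hypothetical usage sketch (the checkpoint name and image URL are illustrative placeholders; the exact keys returned by `__call__` depend on the image processor and tokenizer):

```
>>> import requests
>>> from PIL import Image
>>> from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

>>> processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
>>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> inputs = processor(images=image, return_tensors="pt")        # image -> flattened patches
>>> generated_ids = model.generate(**inputs, max_new_tokens=20)  # autoregressive decoding
>>> print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```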

.\models\pix2struct\__init__.py

# Copyright and license notice for this code
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import the TYPE_CHECKING flag
from typing import TYPE_CHECKING

# Import helpers from utils, plus the functions that check whether Torch and Vision are available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available

# Define the module import structure dict _import_structure
_import_structure = {
    "configuration_pix2struct": [
        "PIX2STRUCT_PRETRAINED_CONFIG_ARCHIVE_MAP",  # Map of pretrained configuration files
        "Pix2StructConfig",  # Pix2Struct model configuration
        "Pix2StructTextConfig",  # Pix2Struct text model configuration
        "Pix2StructVisionConfig",  # Pix2Struct vision model configuration
    ],
    "processing_pix2struct": ["Pix2StructProcessor"],  # Pix2Struct processor
}

# Check whether Vision is available; if not, raise OptionalDependencyNotAvailable
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If Vision is available, add the Pix2Struct image processor to _import_structure
    _import_structure["image_processing_pix2struct"] = ["Pix2StructImageProcessor"]

# Check whether Torch is available; if not, raise OptionalDependencyNotAvailable
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If Torch is available, add the Pix2Struct modeling objects to _import_structure
    _import_structure["modeling_pix2struct"] = [
        "PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST",  # List of pretrained model archives
        "Pix2StructPreTrainedModel",  # Pix2Struct pre-trained model base class
        "Pix2StructForConditionalGeneration",  # Pix2Struct model for conditional generation
        "Pix2StructVisionModel",  # Pix2Struct vision model
        "Pix2StructTextModel",  # Pix2Struct text model
    ]

# During type checking, import the concrete modules and classes
if TYPE_CHECKING:
    from .configuration_pix2struct import (
        PIX2STRUCT_PRETRAINED_CONFIG_ARCHIVE_MAP,  # Map of pretrained configuration files
        Pix2StructConfig,  # Pix2Struct model configuration
        Pix2StructTextConfig,  # Pix2Struct text model configuration
        Pix2StructVisionConfig,  # Pix2Struct vision model configuration
    )
    from .processing_pix2struct import Pix2StructProcessor  # Pix2Struct processor

    # If Vision is available, import the Pix2Struct image processor
    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_pix2struct import Pix2StructImageProcessor  # Pix2Struct image processor

    # If Torch is available, import the Pix2Struct modeling classes
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_pix2struct import (
            PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST,  # List of pretrained model archives
            Pix2StructForConditionalGeneration,  # Pix2Struct model for conditional generation
            Pix2StructPreTrainedModel,  # Pix2Struct pre-trained model base class
            Pix2StructTextModel,  # Pix2Struct text model
            Pix2StructVisionModel,  # Pix2Struct vision model
        )

# Otherwise, replace the current module with the lazily loaded _LazyModule
else:
    import sys

    # Substitute _LazyModule for this module, passing the module name, file, import structure and module spec
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
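
As a small illustration of what the `_LazyModule` replacement buys (a sketch; the concrete classes are only imported on first attribute access, so heavy backends such as torch are not pulled in just by importing the package):

```
>>> import transformers.models.pix2struct as pix2struct_module
>>> pix2struct_module.Pix2StructConfig  # first access triggers the lazy import of configuration_pix2struct
<class 'transformers.models.pix2struct.configuration_pix2struct.Pix2StructConfig'>
```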

.\models\plbart\configuration_plbart.py

# coding=utf-8
# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PLBART model configuration"""
# Import the required modules and classes
from collections import OrderedDict
from typing import Mapping

# Configuration utilities and the ONNX configuration
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast
from ...utils import logging

# Get a logger for this module
logger = logging.get_logger(__name__)

# Map from pretrained model name to its configuration file URL
PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "uclanlp/plbart-base": "https://huggingface.co/uclanlp/plbart-base/resolve/main/config.json",
    # See all PLBART models at https://huggingface.co/models?filter=plbart
}

# PLBartConfig class, inheriting from PretrainedConfig
class PLBartConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PLBartModel`]. It is used to instantiate an
    PLBART model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the PLBART
    [uclanlp/plbart-base](https://huggingface.co/uclanlp/plbart-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Example:

    ```
    >>> from transformers import PLBartConfig, PLBartModel

    >>> # Initializing a PLBART uclanlp/plbart-base style configuration
    >>> configuration = PLBartConfig()

    >>> # Initializing a model (with random weights) from the uclanlp/plbart-base style configuration
    >>> model = PLBartModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    
    # Model type identifier
    model_type = "plbart"
    # Keys to ignore at inference time
    keys_to_ignore_at_inference = ["past_key_values"]
    # Attribute map from common config attribute names to the names used by this architecture
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
    # Constructor: initializes the various parameters of the Transformer model
    def __init__(
        self,
        vocab_size=50005,  # Vocabulary size, default 50005
        max_position_embeddings=1024,  # Maximum sequence length for position embeddings, default 1024
        encoder_layers=6,  # Number of encoder layers, default 6
        encoder_ffn_dim=3072,  # Dimension of the encoder feed-forward (FFN) layers, default 3072
        encoder_attention_heads=12,  # Number of encoder attention heads, default 12
        decoder_layers=6,  # Number of decoder layers, default 6
        decoder_ffn_dim=3072,  # Dimension of the decoder FFN layers, default 3072
        decoder_attention_heads=12,  # Number of decoder attention heads, default 12
        encoder_layerdrop=0.0,  # LayerDrop probability for the encoder, default 0.0
        decoder_layerdrop=0.0,  # LayerDrop probability for the decoder, default 0.0
        use_cache=True,  # Whether to use the cache, default True
        is_encoder_decoder=True,  # Whether this is an encoder-decoder architecture, default True
        activation_function="gelu",  # Activation function, default GELU
        d_model=768,  # Model dimension, default 768
        dropout=0.1,  # Global dropout probability, default 0.1
        attention_dropout=0.1,  # Attention dropout probability, default 0.1
        activation_dropout=0.0,  # Activation dropout probability, default 0.0
        init_std=0.02,  # Standard deviation for weight initialization, default 0.02
        classifier_dropout=0.0,  # Dropout probability of the classifier head, default 0.0
        scale_embedding=True,  # Whether to scale embeddings, default True; if True the scale factor is sqrt(d_model)
        pad_token_id=1,  # ID of the padding token, default 1
        bos_token_id=0,  # ID of the beginning-of-sequence token, default 0
        eos_token_id=2,  # ID of the end-of-sequence token, default 2
        forced_eos_token_id=2,  # ID of the forced end-of-sequence token, default 2
        **kwargs,  # Any additional keyword arguments
    ):
        self.vocab_size = vocab_size  # Vocabulary size
        self.max_position_embeddings = max_position_embeddings  # Maximum position embedding length
        self.d_model = d_model  # Model dimension
        self.encoder_ffn_dim = encoder_ffn_dim  # Encoder FFN dimension
        self.encoder_layers = encoder_layers  # Number of encoder layers
        self.encoder_attention_heads = encoder_attention_heads  # Number of encoder attention heads
        self.decoder_ffn_dim = decoder_ffn_dim  # Decoder FFN dimension
        self.decoder_layers = decoder_layers  # Number of decoder layers
        self.decoder_attention_heads = decoder_attention_heads  # Number of decoder attention heads
        self.dropout = dropout  # Global dropout probability
        self.attention_dropout = attention_dropout  # Attention dropout probability
        self.activation_dropout = activation_dropout  # Activation dropout probability
        self.activation_function = activation_function  # Activation function
        self.init_std = init_std  # Initialization standard deviation
        self.encoder_layerdrop = encoder_layerdrop  # Encoder LayerDrop probability
        self.decoder_layerdrop = decoder_layerdrop  # Decoder LayerDrop probability
        self.classifier_dropout = classifier_dropout  # Classifier dropout probability
        self.use_cache = use_cache  # Whether to use the cache
        self.num_hidden_layers = encoder_layers  # Number of hidden layers, set to the number of encoder layers
        self.scale_embedding = scale_embedding  # Whether to scale embeddings

        super().__init__(  # Call the parent constructor
            pad_token_id=pad_token_id,  # Padding token ID
            bos_token_id=bos_token_id,  # Beginning-of-sequence token ID
            eos_token_id=eos_token_id,  # End-of-sequence token ID
            is_encoder_decoder=is_encoder_decoder,  # Whether this is an encoder-decoder model
            forced_eos_token_id=forced_eos_token_id,  # Forced end-of-sequence token ID
            **kwargs,  # Forward remaining keyword arguments
        )
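
Because of the `attribute_map` defined above, the generic configuration attribute names resolve to the PLBart-specific ones; a short sketch:

```
>>> from transformers import PLBartConfig
>>> config = PLBartConfig()
>>> config.hidden_size, config.d_model  # "hidden_size" is mapped onto "d_model"
(768, 768)
>>> config.num_attention_heads  # mapped onto "encoder_attention_heads"
12
```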
# ONNX export configuration for the PLBart model, extending OnnxConfigWithPast
class PLBartOnnxConfig(OnnxConfigWithPast):

    # The `inputs` property returns an ordered dict describing the dynamic axes of the model inputs
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),  # Dynamic axes of input_ids
                ("attention_mask", {0: "batch", 1: "sequence"}),  # Dynamic axes of attention_mask
            ]
        )

    # The `outputs` property returns an ordered dict describing the dynamic axes of the model outputs
    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.use_past:
            return OrderedDict(
                [
                    ("last_hidden_state", {0: "batch", 1: "sequence"}),  # Dynamic axes of last_hidden_state
                    ("past_keys", {0: "batch", 2: "sequence"}),  # Dynamic axes of past_keys
                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),  # Dynamic axes of encoder_last_hidden_state
                ]
            )
        else:
            return OrderedDict(
                [
                    ("last_hidden_state", {0: "batch", 1: "sequence"}),  # Dynamic axes of last_hidden_state
                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),  # Dynamic axes of encoder_last_hidden_state
                ]
            )
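
For reference, a minimal sketch of how this export configuration could be inspected (assuming the base `OnnxConfigWithPast` constructor accepts the model config directly, as in the other ONNX configs):

```
>>> from transformers import PLBartConfig
>>> onnx_config = PLBartOnnxConfig(PLBartConfig())
>>> dict(onnx_config.inputs)
{'input_ids': {0: 'batch', 1: 'sequence'}, 'attention_mask': {0: 'batch', 1: 'sequence'}}
```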

.\models\plbart\convert_plbart_original_checkpoint_to_torch.py

# Import the required libraries and modules
import argparse  # Command-line argument parsing

import torch  # PyTorch
from torch import nn  # PyTorch neural network module

from transformers import PLBartConfig, PLBartForConditionalGeneration, PLBartForSequenceClassification  # PLBart config and model classes


def remove_ignore_keys_(state_dict):
    # Keys to remove from the state_dict
    ignore_keys = [
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "_float_tensor",
        "decoder.output_projection.weight",
    ]
    # Iterate over these keys and drop them from the state_dict
    for k in ignore_keys:
        state_dict.pop(k, None)


def make_linear_from_emb(emb):
    # Get the vocabulary size and embedding dimension of the embedding layer
    vocab_size, emb_size = emb.weight.shape
    # Create a linear layer with in_features=vocab_size, out_features=emb_size and no bias
    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
    # Initialize the linear layer's weight with the embedding weight
    lin_layer.weight.data = emb.weight.data
    return lin_layer


def convert_fairseq_plbart_checkpoint_from_disk(
    checkpoint_path, hf_config_path="uclanlp/plbart-base", finetuned=False, classification=False
):
    # Load the model state_dict from disk, mapping tensors to CPU
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    # Drop the ignored keys from the state_dict
    remove_ignore_keys_(state_dict)
    # Read the vocabulary size from the encoder embedding weight
    vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]

    # Load the PLBart configuration from the given hf_config_path
    plbart_config = PLBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size)

    # Copy "decoder.embed_tokens.weight" in the state_dict into "shared.weight"
    state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]

    if not classification:
        # Not a classification task: build the conditional-generation PLBart model
        model = PLBartForConditionalGeneration(plbart_config)
        # Load the state_dict into the underlying model
        model.model.load_state_dict(state_dict)
        if finetuned:
            # For a fine-tuned checkpoint, replace lm_head with a linear layer built from the shared embedding
            model.lm_head = make_linear_from_emb(model.model.shared)
    else:
        # Classification task: collect the classification head weights
        classification_head = {}
        # Move the classification-head entries from the state_dict into classification_head
        for key, value in state_dict.copy().items():
            if key.startswith("classification_heads.sentence_classification_head"):
                classification_head[key.replace("classification_heads.sentence_classification_head.", "")] = value
                state_dict.pop(key)
        # Build the sequence-classification PLBart model
        model = PLBartForSequenceClassification(plbart_config)
        # Load the state_dict of the base model and of the classification head
        model.model.load_state_dict(state_dict)
        model.classification_head.load_state_dict(classification_head)

    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required argument: fairseq_path, the model.pt file on the local filesystem
    parser.add_argument("fairseq_path", type=str, help="model.pt on local filesystem.")
    # Argument pytorch_dump_folder_path: path of the output PyTorch model
    parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument(
        "--hf_config",
        default="uclanlp/plbart-base",
        type=str,
        help="Which huggingface architecture to use: plbart-base",
    )
    # The --hf_config argument selects which Hugging Face configuration to use; defaults to uclanlp/plbart-base

    parser.add_argument("--finetuned", action="store_true", help="whether the model is a fine-tuned checkpoint")
    # --finetuned indicates whether the checkpoint is a fine-tuned one

    parser.add_argument(
        "--classification", action="store_true", help="whether the model is a classification checkpoint"
    )
    # --classification indicates whether the checkpoint is a classification one

    args = parser.parse_args()
    # Parse the command-line arguments into args

    model = convert_fairseq_plbart_checkpoint_from_disk(
        args.fairseq_path,
        hf_config_path=args.hf_config,
        finetuned=args.finetuned,
        classification=args.classification,
    )
    # Load the fairseq PLBART checkpoint from disk and convert it into a Hugging Face model
    # using the given arguments

    model.save_pretrained(args.pytorch_dump_folder_path)
    # Save the converted Hugging Face model to the given PyTorch dump folder
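
A hypothetical Python-level use of the conversion function defined above (both paths are placeholders):

```
>>> model = convert_fairseq_plbart_checkpoint_from_disk(
...     "/path/to/fairseq_checkpoint.pt",  # placeholder: local fairseq model.pt
...     hf_config_path="uclanlp/plbart-base",
...     finetuned=False,
...     classification=False,
... )
>>> model.save_pretrained("/path/to/plbart-base-hf")  # placeholder output folder
```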

.\models\plbart\modeling_plbart.py

# coding=utf-8
# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch PLBART model."""
import copy
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_attention_mask_for_sdpa,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_plbart import PLBartConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "uclanlp/plbart-base"
_CONFIG_FOR_DOC = "PLBartConfig"

PLBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "uclanlp/plbart-base",
    "uclanlp/plbart-cs-java",
    "uclanlp/plbart-multi_task-all",
    # See all PLBART models at https://huggingface.co/models?filter=plbart
]


# Copied from transformers.models.mbart.modeling_mbart.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
    """
    Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
    have a single `decoder_start_token_id` in contrast to other Bart-like models.
    """
    # Clone the input_ids tensor
    prev_output_tokens = input_ids.clone()

    # If pad_token_id is None, raise a ValueError
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    # Replace possible -100 values in the labels by pad_token_id
    prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)

    # Index of the last non-pad token in each sample
    index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)

    # Gather the decoder start tokens (the <LID> language-id tokens)
    decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()

    # Shift input_ids one token to the right and wrap the last non-pad token to position 0
    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
    prev_output_tokens[:, 0] = decoder_start_tokens

    return prev_output_tokens
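
A small worked example of the shifting behaviour (a sketch; here 1 is the pad id and 50003 stands in for a language-id token):

```
>>> import torch
>>> input_ids = torch.tensor([[100, 101, 102, 2, 50003, 1, 1]])  # tokens, </s>, <LID>, padding
>>> shift_tokens_right(input_ids, pad_token_id=1)
tensor([[50003,   100,   101,   102,     2, 50003,     1]])
```
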
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart
class PLBartLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum length.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # For PLBart, if padding_idx is specified, the embedding ids are offset by 2 and
        # num_embeddings is adjusted accordingly. Other models don't have this hack.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids` is expected to be of shape [bsz x seqlen]."""

        bsz, seq_len = input_ids.shape[:2]
        # Build the position ids on the embedding's device, starting after the cached past length
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        ).expand(bsz, -1)

        return super().forward(positions + self.offset)
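
To see the effect of the offset of 2: with `max_position_embeddings=1024` the embedding table holds 1026 rows, and every position is looked up shifted by two (a quick sketch):

```
>>> import torch
>>> emb = PLBartLearnedPositionalEmbedding(1024, 16)
>>> emb.weight.shape
torch.Size([1026, 16])
>>> emb(torch.zeros(1, 4, dtype=torch.long)).shape  # positions 0..3 use rows 2..5 of the table
torch.Size([1, 4, 16])
```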


# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PLBart
class PLBartAttention(nn.Module):
    """Multi-headed attention from the 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[PLBartConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # Linear projections for queries (q_proj), keys (k_proj), values (v_proj) and the output (out_proj)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape the tensor for multi-head attention computation
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        ):
        # The detailed forward pass is omitted here
        pass


# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->PLBart, BART->PLBART
class PLBartEncoderLayer(nn.Module):
    # Build a new encoder layer from the configuration
    def __init__(self, config: PLBartConfig):
        # Call the parent constructor
        super().__init__()
        # The embedding dimension is the model dimension from the config
        self.embed_dim = config.d_model

        # Self-attention; the implementation class is selected from the config
        self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )

        # LayerNorm applied after self-attention
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # Dropout probability from the config
        self.dropout = config.dropout

        # Activation function selected from the config
        self.activation_fn = ACT2FN[config.activation_function]

        # Dropout applied to the activation
        self.activation_dropout = config.activation_dropout

        # First fully connected layer: embedding dim -> encoder FFN dim
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)

        # Second fully connected layer: encoder FFN dim -> embedding dim
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)

        # Final LayerNorm applied to the layer output
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        # 拷贝输入的隐藏状态作为残差连接的基础
        residual = hidden_states
        # 使用自注意力机制处理隐藏状态,获取处理后的隐藏状态、注意力权重及额外输出
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        # 对处理后的隐藏状态应用丢弃机制
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        # 残差连接:将原始隐藏状态与处理后的隐藏状态相加
        hidden_states = residual + hidden_states
        # 对残差连接后的隐藏状态进行层归一化
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # 再次使用残差连接的方法处理隐藏状态
        residual = hidden_states
        # 对处理后的隐藏状态应用激活函数和线性变换 fc1
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        # 对处理后的隐藏状态应用激活函数的丢弃机制
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        # 应用第二个线性变换 fc2
        hidden_states = self.fc2(hidden_states)
        # 对处理后的隐藏状态应用丢弃机制
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        # 残差连接:将原始隐藏状态与处理后的隐藏状态相加
        hidden_states = residual + hidden_states
        # 对残差连接后的隐藏状态进行层归一化
        hidden_states = self.final_layer_norm(hidden_states)

        # 如果隐藏状态的数据类型为 float16 并且存在无穷大或 NaN 值,则进行值的截断处理
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        # 构建输出元组
        outputs = (hidden_states,)

        # 如果需要输出注意力权重,则将注意力权重添加到输出元组中
        if output_attentions:
            outputs += (attn_weights,)

        # 返回最终输出元组
        return outputs
# TODO: Implement attention with SDPA for PLBart.
# Mapping from attention implementation name to class; "eager" maps to PLBartAttention
PLBART_ATTENTION_CLASSES = {"eager": PLBartAttention}


# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART
# A single layer of the PLBart decoder
class PLBartDecoderLayer(nn.Module):
    def __init__(self, config: PLBartConfig):
        super().__init__()
        self.embed_dim = config.d_model

        # Self-attention, built from the attention implementation selected in the config
        self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        # LayerNorm applied to the self-attention output
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # Cross-attention over the encoder output, using the same attention implementation class
        self.encoder_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        # LayerNorm applied to the cross-attention output
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # First and second fully connected layers of the feed-forward network
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)

        # Final LayerNorm on the layer output
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    # Forward pass of the decoder layer
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ):
        # The detailed forward logic (self-attention, cross-attention and the feed-forward block) is omitted here
        pass


# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->PLBart
# Head used for sentence-level classification tasks
class PLBartClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        num_classes: int,
        pooler_dropout: float,
    ):
        super().__init__()
        # Dense layer mapping the input dimension to the inner dimension
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        # Output projection mapping the inner dimension to the number of classes
        self.out_proj = nn.Linear(inner_dim, num_classes)

    # Forward pass of the classification head
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


# Base class for PLBart pre-trained models, inheriting from PreTrainedModel
class PLBartPreTrainedModel(PreTrainedModel):
    config_class = PLBartConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Modules that should not be split across devices (e.g. when a device_map is used)
    _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
    
    # Initialize the weights of a module
    def _init_weights(self, module):
        # Standard deviation for initialization, taken from the config
        std = self.config.init_std
        # Linear layers
        if isinstance(module, nn.Linear):
            # Initialize weights from a normal distribution with mean 0 and the configured std
            module.weight.data.normal_(mean=0.0, std=std)
            # Zero the bias if present
            if module.bias is not None:
                module.bias.data.zero_()
        # Embedding layers
        elif isinstance(module, nn.Embedding):
            # Initialize weights from a normal distribution with mean 0 and the configured std
            module.weight.data.normal_(mean=0.0, std=std)
            # Zero the embedding vector at the padding index if present
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
PLBART_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`PLBartConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

PLBART_GENERATION_EXAMPLE = r"""
    Mask-filling example:

    ```
    >>> from transformers import AutoTokenizer, PLBartForConditionalGeneration

    >>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")
    >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")

    >>> # en_XX is the language symbol id <LID> for English
    >>> TXT = "<s> Is 0 the <mask> Fibonacci number ? </s> en_XX"
    >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt").input_ids

    >>> logits = model(input_ids).logits
    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
    >>> probs = logits[0, masked_index].softmax(dim=0)
    >>> values, predictions = probs.topk(5)

    >>> tokenizer.decode(predictions).split()
    ['first', 'same', 'highest', 'result', 'number']
    ```
"""

PLBART_INPUTS_DOCSTRING = r"""
"""


# Copied from transformers.models.bart.modeling_bart.BartEncoder with Bart->PLBart
class PLBartEncoder(PLBartPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`PLBartEncoderLayer`].

    Args:
        config: PLBartConfig
        embed_tokens (nn.Embedding): output embedding
    """
    # Initialize the encoder components
    def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None):
        # Call the parent constructor with the config
        super().__init__(config)

        # Dropout probability from the config
        self.dropout = config.dropout
        # LayerDrop probability for the encoder
        self.layerdrop = config.encoder_layerdrop

        # Embedding dimension
        embed_dim = config.d_model
        # Index of the padding token
        self.padding_idx = config.pad_token_id
        # Maximum source sequence length
        self.max_source_positions = config.max_position_embeddings
        # Optionally scale the embeddings by sqrt(embed_dim)
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        # Token embedding: vocab_size x embed_dim, with the padding index
        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

        # If pre-built embed_tokens were passed in, reuse their weights
        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        # Learned positional embeddings up to max_position_embeddings
        self.embed_positions = PLBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )

        # Stack of encoder layers, all built from the same config
        self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)])

        # Flags for the attention implementation selected in the config
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self._use_sdpa = config._attn_implementation == "sdpa"

        # LayerNorm applied to the embedding output
        self.layernorm_embedding = nn.LayerNorm(embed_dim)

        # Gradient checkpointing is disabled by default
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    # Return the input embedding layer
    def get_input_embeddings(self):
        return self.embed_tokens

    # Replace the input embedding layer with the given one
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Forward pass; takes the usual encoder inputs and returns the model outputs
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # The encoder forward body is omitted in this excerpt
        ...


# Copied from transformers.models.bart.modeling_bart.BartDecoder with Bart->PLBart
class PLBartDecoder(PLBartPreTrainedModel):
    """
    Transformer解码器,由config.decoder_layers层组成。每一层是一个[`PLBartDecoderLayer`]

    Args:
        config: PLBartConfig
        embed_tokens (nn.Embedding): 输出的嵌入层
    """

    def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout  # 获取config中的dropout值
        self.layerdrop = config.decoder_layerdrop  # 获取config中的decoder_layerdrop值
        self.padding_idx = config.pad_token_id  # 获取config中的pad_token_id值
        self.max_target_positions = config.max_position_embeddings  # 获取config中的max_position_embeddings值
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0  # 根据config中的scale_embedding决定是否使用嵌入缩放

        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)  # 创建词嵌入层

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight  # 如果提供了embed_tokens,则使用提供的权重初始化嵌入层

        self.embed_positions = PLBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )  # 创建位置编码层

        self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)])  # 创建多层解码器层

        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"  # 检查是否使用FlashAttention 2.0实现
        self._use_sdpa = config._attn_implementation == "sdpa"  # 检查是否使用Scaled Dot-Product Attention (SDPA)

        self.layernorm_embedding = nn.LayerNorm(config.d_model)  # 创建层归一化层

        self.gradient_checkpointing = False  # 是否启用梯度检查点

        # 初始化权重并进行最终处理
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens  # 返回输入的嵌入层

    def set_input_embeddings(self, value):
        self.embed_tokens = value  # 设置输入的嵌入层为给定的值

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # The decoder forward body is omitted in this excerpt
        ...


# The bare PLBART encoder-decoder model with a shared token embedding
class PLBartModel(PLBartPreTrainedModel):
    def __init__(self, config: PLBartConfig):
        # Call the parent constructor with the config
        super().__init__(config)

        # Read the padding index and vocabulary size from the config
        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        # Shared token embedding mapping the vocabulary to the model dimension, with the padding index
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        # Encoder and decoder share the token embedding
        self.encoder = PLBartEncoder(config, self.shared)
        self.decoder = PLBartDecoder(config, self.shared)

        # Initialize the model weights
        self.init_weights()

    def get_input_embeddings(self):
        # Return the shared input embedding
        return self.shared

    def set_input_embeddings(self, value):
        # Set a new shared input embedding and propagate it to the encoder and decoder
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def _tie_weights(self):
        # If the config requests tied word embeddings, tie the encoder and decoder embeddings to the shared one
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def get_encoder(self):
        # Return the encoder
        return self.encoder

    def get_decoder(self):
        # Return the decoder
        return self.decoder

    @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.LongTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # The model forward body is omitted in this excerpt
        ...


# Add the model docstring: the PLBART model with a language modeling head, usable for
# code-to-text, text-to-code and code-to-code tasks. PLBART_START_DOCSTRING is reused here.
@add_start_docstrings(
    "The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code.",
    PLBART_START_DOCSTRING,
)
# PLBartForConditionalGeneration: PLBART with a language modeling head, inheriting from PLBartPreTrainedModel
class PLBartForConditionalGeneration(PLBartPreTrainedModel):
    # Prefix of the base model
    base_model_prefix = "model"
    # Keys that may be missing when loading a checkpoint
    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
    # Keys whose weights are tied together
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    # Constructor taking a PLBartConfig
    def __init__(self, config: PLBartConfig):
        # Call the parent constructor
        super().__init__(config)
        # The underlying encoder-decoder model
        self.model = PLBartModel(config)
        # Buffer final_logits_bias, a zero tensor of shape (1, num_embeddings)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        # Language modeling head: linear layer from d_model to the vocabulary size, without bias
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # Initialize the weights
        self.init_weights()

    # Return the encoder
    def get_encoder(self):
        return self.model.get_encoder()

    # Return the decoder
    def get_decoder(self):
        return self.model.get_decoder()

    # Resize the token embeddings and return the new embedding layer
    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
        # Resize via the parent implementation
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # Resize final_logits_bias so that it matches the new number of tokens
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        # Return the new embedding layer
        return new_embeddings

    # Resize final_logits_bias in place; returns nothing
    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        # Current number of tokens
        old_num_tokens = self.final_logits_bias.shape[-1]
        # If the new size is smaller or equal, truncate the bias
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        # Otherwise, extend the bias with zeros
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        # Re-register the buffer with the new bias
        self.register_buffer("final_logits_bias", new_bias)

    # Return the output embedding layer (lm_head)
    def get_output_embeddings(self):
        return self.lm_head

    # Set the output embedding layer (lm_head)
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # Document the forward method with PLBART_INPUTS_DOCSTRING, the Seq2SeqLMOutput return type
    # (via _CONFIG_FOR_DOC) and PLBART_GENERATION_EXAMPLE appended at the end
    @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    @add_end_docstrings(PLBART_GENERATION_EXAMPLE)
    # Forward pass of the conditional-generation model
    def forward(
        self,
        # Input token ids, optional LongTensor
        input_ids: Optional[torch.LongTensor] = None,
        # Attention mask, optional LongTensor
        attention_mask: Optional[torch.LongTensor] = None,
        # Decoder input token ids, optional LongTensor
        decoder_input_ids: Optional[torch.LongTensor] = None,
        # Decoder attention mask, optional Tensor
        decoder_attention_mask: Optional[torch.Tensor] = None,
        # Head mask, optional Tensor
        head_mask: Optional[torch.Tensor] = None,
        # Decoder head mask, optional LongTensor
        decoder_head_mask: Optional[torch.LongTensor] = None,
        # Cross-attention head mask, optional Tensor
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        # Encoder outputs, optional list of FloatTensors
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        # Cached past key/value states, optional list of FloatTensors
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        # Input embeddings, optional FloatTensor
        inputs_embeds: Optional[torch.FloatTensor] = None,
        # Decoder input embeddings, optional FloatTensor
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        # Labels, optional Tensor
        labels: Optional[torch.Tensor] = None,
        # Whether to use the cache, optional bool
        use_cache: Optional[bool] = None,
        # Whether to return attention weights, optional bool
        output_attentions: Optional[bool] = None,
        # Whether to return hidden states, optional bool
        output_hidden_states: Optional[bool] = None,
        # Whether to return a dict-style output, optional bool
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
            Depending on `return_dict` (which falls back to `config.use_return_dict` when not given):
            - if `False`, returns a tuple with `lm_logits` followed by the other model outputs.
            - if `True`, returns a `Seq2SeqLMOutput` object containing the loss, logits, and other outputs.

        """
        # Determine whether to use `return_dict` from self.config or override with provided value
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Adjust `decoder_input_ids` if not provided, using shifted `labels` for autoregressive decoding
        if labels is not None:
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

        # Forward pass through the model with specified inputs and optional masks and caches
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Compute logits for the language model head and apply a bias if provided
        lm_logits = self.lm_head(outputs[0])
        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

        masked_lm_loss = None
        # Compute masked language modeling loss if `labels` are provided
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        # Prepare output depending on whether `return_dict` is `False` or `True`
        if not return_dict:
            # Return a tuple with `lm_logits` followed by other model outputs
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        else:
            # Return a `Seq2SeqLMOutput` object with loss, logits, and various model outputs
            return Seq2SeqLMOutput(
                loss=masked_lm_loss,
                logits=lm_logits,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids: torch.LongTensor,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        **kwargs,  # TODO: Check if this is needed. It is unused?
    ) -> Dict[str, Any]:
        # If past key values are given, trim decoder_input_ids according to the cached length
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to the old behavior: keep only the final ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        # Return a dict containing the inputs used for generation
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        # Shift the labels one position to the right to build the decoder inputs
        return shift_tokens_right(labels, self.config.pad_token_id)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            # Cached cross-attention states don't need to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        return reordered_past
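
Putting the pieces together, a short, hypothetical usage example of the conditional-generation model (the checkpoint name and the code/text pair are purely illustrative):

```
>>> from transformers import AutoTokenizer, PLBartForConditionalGeneration

>>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
>>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")

>>> inputs = tokenizer("def maximum(a, b): return max(a, b)", return_tensors="pt")
>>> labels = tokenizer("returns the larger of two values", return_tensors="pt").input_ids
>>> outputs = model(**inputs, labels=labels)  # labels are shifted internally to build decoder_input_ids
>>> outputs.loss.shape, outputs.logits.shape[-1] == model.config.vocab_size
(torch.Size([]), True)
```
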
# Add a docstring for PLBartForSequenceClassification: a PLBart model with a sequence classification
# head on top (a linear layer on top of the pooled output), e.g. for code classification
@add_start_docstrings(
    """
    PLBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for code
    classification.
    """,
    PLBART_START_DOCSTRING,
)
class PLBartForSequenceClassification(PLBartPreTrainedModel):
    # Keys whose weights are tied: the encoder and decoder share the token embedding
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: PLBartConfig, **kwargs):
        # Call the parent constructor
        super().__init__(config, **kwargs)
        # The underlying PLBartModel
        self.model = PLBartModel(config)
        # Classification head on top of the model
        self.classification_head = PLBartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classifier_dropout,
        )

        # Initialize weights and apply final processing
        self.post_init()

    # Document the forward method: inputs from PLBART_INPUTS_DOCSTRING plus a code-sample docstring
    @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Forward pass of PLBartForSequenceClassification; see the transformers documentation for the parameter details
        pass  # Placeholder for the actual implementation

# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->PLBart
class PLBartDecoderWrapper(PLBartPreTrainedModel):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    def __init__(self, config):
        # Call the parent class constructor
        super().__init__(config)
        # Create a PLBartDecoder instance as the decoder
        self.decoder = PLBartDecoder(config)

    def forward(self, *args, **kwargs):
        # Delegate the forward pass to PLBartDecoder
        return self.decoder(*args, **kwargs)


# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base
class PLBartForCausalLM(PLBartPreTrainedModel):
    # Keys whose weights are tied, used to share the language modeling head weights
    _tied_weights_keys = ["lm_head.weight"]

    # Constructor, taking a configuration object as argument
    def __init__(self, config):
        # Deep-copy the configuration so the original is not modified
        config = copy.deepcopy(config)
        # Mark this instance as a decoder
        config.is_decoder = True
        # Mark this instance as not being an encoder-decoder architecture
        config.is_encoder_decoder = False
        # Call the parent constructor with the deep-copied configuration
        super().__init__(config)
        # Instantiate the PLBartDecoderWrapper backbone from the configuration
        self.model = PLBartDecoderWrapper(config)

        # Language modeling head: a linear layer mapping hidden states to the vocabulary size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    # Return the input embedding layer
    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    # Set the input embedding layer
    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    # Return the output embedding layer
    def get_output_embeddings(self):
        return self.lm_head

    # Set the output embedding layer
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # Set the decoder module
    def set_decoder(self, decoder):
        self.model.decoder = decoder

    # Return the decoder module
    def get_decoder(self):
        return self.model.decoder

    # Forward method; the decorator replaces the return docstring with the documented output type
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        pass  # Placeholder for the actual implementation

    # Prepare inputs for generation, handling the various optional arguments
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
    ):
        # If no attention mask is given, create one filled with ones
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # If past key values are present, compute the cached length and trim the input IDs accordingly
        if past_key_values:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default behavior: keep only the final input ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        # Return the prepared generation inputs as a dict
        return {
            "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }
    # Reorder the cached past key values for beam search
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Initialize the tuple of reordered past key values
        reordered_past = ()
        # Iterate over the past key values of each layer
        for layer_past in past_key_values:
            # Index-select every past state along the beam dimension according to beam_idx, moved to the state's device
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        # Return the reordered past key values
        return reordered_past

.\models\plbart\tokenization_plbart.py

# coding=utf-8
# Copyright 2022, UCLA NLP, The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at:
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
import os  # operating-system utilities
from shutil import copyfile  # file-copy helper from the shutil library
from typing import Any, Dict, List, Optional, Tuple  # typing helpers for annotations

import sentencepiece as spm  # SentencePiece library used for subword tokenization

from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer  # tokenization base classes and helpers
from ...utils import logging  # logging utilities

logger = logging.get_logger(__name__)  # logger for this module

SPIECE_UNDERLINE = "▁"  # special SentencePiece underline symbol

VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
# Names of the vocabulary and tokenizer files

PRETRAINED_VOCAB_FILES_MAP = {
    # Mapping from pretrained checkpoint names to the URLs of their vocabulary files
    "vocab_file": {
        "uclanlp/plbart-base": "https://huggingface.co/uclanlp/plbart-base/resolve/main/sentencepiece.bpe.model",
        "uclanlp/plbart-c-cpp-defect-detection": (
            "https://huggingface.co/uclanlp/plbart-c-cpp-defect-detection/resolve/main/sentencepiece.bpe.model"
        ),
        "uclanlp/plbart-cs-java": "https://huggingface.co/uclanlp/plbart-cs-java/resolve/main/sentencepiece.bpe.model",
        "uclanlp/plbart-en_XX-java": (
            "https://huggingface.co/uclanlp/plbart-en_XX-java/resolve/main/sentencepiece.bpe.model"
        ),
        "uclanlp/plbart-go-en_XX": (
            "https://huggingface.co/uclanlp/plbart-go-en_XX/resolve/main/sentencepiece.bpe.model"
        ),
        "uclanlp/plbart-java-clone-detection": (
            "https://huggingface.co/uclanlp/plbart-java-clone-detection/resolve/main/sentencepiece.bpe.model"
        ),
        "uclanlp/plbart-java-cs": "https://huggingface.co/uclanlp/plbart-java-cs/resolve/main/sentencepiece.bpe.model",
        "uclanlp/plbart-java-en_XX": (
            "https://huggingface.co/uclanlp/plbart-java-en_XX/resolve/main/sentencepiece.bpe.model"
        ),
        "uclanlp/plbart-javascript-en_XX": (
            "https://huggingface.co/uclanlp/plbart-javascript-en_XX/resolve/main/sentencepiece.bpe.model"
        ),
        "uclanlp/plbart-php-en_XX": (
            "https://huggingface.co/uclanlp/plbart-php-en_XX/resolve/main/sentencepiece.bpe.model"
        ),
        "uclanlp/plbart-python-en_XX": (
            "https://huggingface.co/uclanlp/plbart-python-en_XX/resolve/main/sentencepiece.bpe.model"
        ),
        "uclanlp/plbart-refine-java-medium": (
            "https://huggingface.co/uclanlp/plbart-refine-java-medium/resolve/main/sentencepiece.bpe.model"
        ),
        "uclanlp/plbart-refine-java-small": (
            "https://huggingface.co/uclanlp/plbart-refine-java-small/resolve/main/sentencepiece.bpe.model"
        ),
        "uclanlp/plbart-ruby-en_XX": (
            "https://huggingface.co/uclanlp/plbart-ruby-en_XX/resolve/main/sentencepiece.bpe.model"
        ),
    }
}

# Sizes of the pretrained positional embeddings
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "uclanlp/plbart-base": 1024,
    "uclanlp/plbart-c-cpp-defect-detection": 1024,
    "uclanlp/plbart-cs-java": 1024,
    "uclanlp/plbart-en_XX-java": 1024,
    "uclanlp/plbart-go-en_XX": 1024,
    "uclanlp/plbart-java-clone-detection": 1024,
    "uclanlp/plbart-java-cs": 1024,
    "uclanlp/plbart-java-en_XX": 1024,
    "uclanlp/plbart-javascript-en_XX": 1024,
    "uclanlp/plbart-php-en_XX": 1024,
    "uclanlp/plbart-python-en_XX": 1024,
    "uclanlp/plbart-refine-java-medium": 1024,
    "uclanlp/plbart-refine-java-small": 1024,
    "uclanlp/plbart-ruby-en_XX": 1024,
}
# Fairseq language codes
FAIRSEQ_LANGUAGE_CODES = {
    "base": ["__java__", "__python__", "__en_XX__"],
    "multi": ["__java__", "__python__", "__en_XX__", "__javascript__", "__php__", "__ruby__", "__go__"],
}
# Mapping from plain language codes to the Fairseq language code format
FAIRSEQ_LANGUAGE_CODES_MAP = {
    "java": "__java__",
    "python": "__python__",
    "en_XX": "__en_XX__",
    "javascript": "__javascript__",
    "php": "__php__",
    "ruby": "__ruby__",
    "go": "__go__",
}
# Definition of the PLBartTokenizer class
class PLBartTokenizer(PreTrainedTokenizer):
    """
    Construct an PLBART tokenizer.

    Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
    [SentencePiece](https://github.com/google/sentencepiece).

    The tokenization method is `<tokens> <eos> <language code>` for source language documents, and `<language code>
    <tokens> <eos>` for target language documents.
    Args:
        vocab_file (`str`):
            Path to the vocabulary file. This specifies the location of the vocabulary file to be used by the tokenizer.
        src_lang (`str`, *optional*):
            A string representing the source language. If provided, specifies the source language for the tokenizer.
        tgt_lang (`str`, *optional*):
            A string representing the target language. If provided, specifies the target language for the tokenizer.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The start of sequence token. Defines the token used to mark the beginning of a sequence.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token. Defines the token used to mark the end of a sequence.
        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token. Used in scenarios like sequence classification or question answering to separate sequences.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classification token. This token is used as the first token for all tasks.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. If a token is not found in the vocabulary, it is replaced with this token.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The padding token. Used to pad sequences to the same length during batching.
        mask_token(`str`, *optional*, defaults to `"<mask>"`):
            The mask token. Used in masking tasks during training. Not used in multi-tokenizer scenarios.
        language_codes (`str`, *optional*, defaults to `"base"`):
            Specifies what language codes to use. Can be `"base"` or `"multi"`.
        sp_model_kwargs (`dict`, *optional*):
            Additional arguments passed to the `SentencePieceProcessor.__init__()` method. These parameters can configure
            subword regularization and other SentencePiece settings like `enable_sampling`, `nbest_size`, and `alpha`.
            See the [Python wrapper for SentencePiece](https://github.com/google/sentencepiece/tree/master/python) for details.
    Examples:

    ```
    >>> from transformers import PLBartTokenizer

    >>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
    >>> # Example Python code snippet and its English translation, used as model inputs
    >>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
    >>> expected_translation_english = "Returns the maximum value of a b c."
    >>> # Tokenize the Python code together with its English translation and return PyTorch tensors
    >>> inputs = tokenizer(example_python_phrase, text_target=expected_translation_english, return_tensors="pt")
    ```"""

    vocab_files_names = VOCAB_FILES_NAMES  # names of the vocabulary files used by the tokenizer
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES  # maximum input sizes of the pretrained checkpoints
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP  # mapping from checkpoints to vocabulary file URLs
    model_input_names = ["input_ids", "attention_mask"]  # names of the inputs produced by the tokenizer

    prefix_tokens: List[int] = []  # special tokens prepended to every sequence
    suffix_tokens: List[int] = []  # special tokens appended to every sequence
    
    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        language_codes="base",
        tokenizer_file=None,
        src_lang=None,
        tgt_lang=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        additional_special_tokens=None,
        **kwargs,
    ):
        # Constructor: set up the tokenizer's parameters and attributes
    
    def __getstate__(self):
        # Called when serializing the object; returns the state as a dictionary
        state = self.__dict__.copy()
        state["sp_model"] = None
        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
        return state

    def __setstate__(self, d):
        # Called when deserializing the object; restores its state
        self.__dict__ = d

        # For backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        # Reload the SentencePiece model from the serialized proto
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
    
    @property
    def vocab_size(self):
        # Compute the vocabulary size, accounting for the language codes and the fairseq offset
        if self.language_codes == "base":
            return (
                len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1
            )  # Plus 1 for the mask token
        else:
            return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
    
    @property
    def src_lang(self) -> str:
        # Return the source language code
        return self._src_lang

    @src_lang.setter
    def src_lang(self, new_src_lang: str) -> None:
        # Set the source language code and update the special tokens accordingly
        new_src_lang = self._convert_lang_code_special_format(new_src_lang)
        self._src_lang = new_src_lang
        self.set_src_lang_special_tokens(self._src_lang)
    
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        # If the token list already has special tokens, delegate to superclass method
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # Create lists of 1s corresponding to prefix and suffix tokens
        prefix_ones = [1] * len(self.prefix_tokens)
        suffix_ones = [1] * len(self.suffix_tokens)

        # If token_ids_1 is None, return tokens with prefix, sequence tokens (0s), and suffix
        if token_ids_1 is None:
            return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones

        # If token_ids_1 is provided, return tokens with prefix, token_ids_0, token_ids_1, and suffix
        return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. An PLBART sequence has the following format, where `X` represents the sequence:

        - `input_ids` (for encoder) `X [eos, src_lang_code]`
        - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`

        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
        separator.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """

        # If token_ids_1 is None, concatenate prefix, token_ids_0, and suffix tokens
        if token_ids_1 is None:
            return self.prefix_tokens + token_ids_0 + self.suffix_tokens

        # Otherwise, concatenate prefix, token_ids_0, token_ids_1, and suffix tokens
        return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
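A minimal sketch with made-up IDs of how these prefix/suffix lists combine for a single sequence: with no prefix and a suffix of `[eos, lang_code]` (the PLBart source-language setting), the encoder input ends up as `X [eos, src_lang_code]`:

```python
prefix_tokens, suffix_tokens = [], [2, 50001]  # hypothetical eos and __python__ IDs
token_ids_0 = [101, 102, 103]                  # the tokenized sequence X
print(prefix_tokens + token_ids_0 + suffix_tokens)  # [101, 102, 103, 2, 50001]
```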

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create token type IDs from a sequence or a pair of sequences for sequence classification tasks. This is used
        to distinguish between the two sequences in a model that supports sequence pairs.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs representing the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs representing the second sequence in a pair.

        Returns:
            `List[int]`: List of token type IDs (0 or 1) indicating the sequence type for each token.
        """
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. PLBart does not
        make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """

        # Separator token ID used in sequence pairs
        sep = [self.sep_token_id]
        # CLS token ID used in sequence pairs
        cls = [self.cls_token_id]

        # If only one sequence is provided, return the mask for that sequence
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        # If two sequences are provided, return the mask for both sequences concatenated
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    def _build_translation_inputs(
        self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
    ):
        """Used by translation pipeline, to prepare inputs for the generate function"""

        # Ensure source and target languages are provided
        if src_lang is None or tgt_lang is None:
            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")

        # Convert source and target language codes to special format
        self.src_lang = self._convert_lang_code_special_format(src_lang)
        self.tgt_lang = self._convert_lang_code_special_format(tgt_lang)

        # Generate model inputs with special tokens and specified return type
        inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)

        # Convert target language to its corresponding token ID
        tgt_lang_id = self.convert_tokens_to_ids(self.tgt_lang)

        # Add forced beginning-of-sequence token ID to inputs
        inputs["forced_bos_token_id"] = tgt_lang_id

        return inputs

    def get_vocab(self):
        # Create a vocabulary dictionary mapping token strings to their IDs
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        # Include any additional tokens introduced during model training
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[str]:
        # Tokenize input text using SentencePiece model and return as list of strings
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) into an ID using the vocabulary."""
        
        # Check if the token exists in the fairseq mapping
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        
        # Obtain token ID from SentencePiece model
        spm_id = self.sp_model.PieceToId(token)

        # Return unknown token ID if SentencePiece returns 0 (unknown token)
        return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
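A toy illustration of the fairseq offset logic above, with hypothetical values: SentencePiece returns 0 for unknown pieces, which is mapped to the tokenizer's `unk_token_id`, while every other piece ID is shifted by `fairseq_offset` to leave room for the fairseq special tokens at the start of the vocabulary:

```python
fairseq_offset, unk_token_id = 1, 3  # illustrative values

def to_tokenizer_id(spm_id: int) -> int:
    return spm_id + fairseq_offset if spm_id else unk_token_id

print([to_tokenizer_id(i) for i in (0, 5, 250)])  # [3, 6, 251]
```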

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocabulary."""
        
        # Check if the index exists in the fairseq mapping
        if index in self.fairseq_ids_to_tokens:
            return self.fairseq_ids_to_tokens[index]
        
        # Convert index to token using SentencePiece model
        return self.sp_model.IdToPiece(index - self.fairseq_offset)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) into a single string."""
        
        # Concatenate tokens into a single string, replacing special sub-word marker with space
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Check that the save directory exists; log an error and return if it does not
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        # Build the path of the output vocabulary file
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # If the current vocabulary file differs from the output path and exists, copy it to the output path
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        # If the current vocabulary file does not exist, write the serialized SentencePiece model to the output path
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        # Return a tuple containing the output file path
        return (out_vocab_file,)

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
        src_lang: str = "en_XX",
        tgt_texts: Optional[List[str]] = None,
        tgt_lang: str = "python",
        **kwargs,
    ) -> BatchEncoding:
        # Convert the source language code to the special-token format
        self.src_lang = self._convert_lang_code_special_format(src_lang)
        # Convert the target language code to the special-token format
        self.tgt_lang = self._convert_lang_code_special_format(tgt_lang)
        # Delegate to the parent class to prepare the sequence-to-sequence batch
        return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

    def _switch_to_input_mode(self):
        # Switch to input mode by setting the source-language special tokens
        return self.set_src_lang_special_tokens(self.src_lang)

    def _switch_to_target_mode(self):
        # Switch to target mode by setting the target-language special tokens
        return self.set_tgt_lang_special_tokens(self.tgt_lang)

    def set_src_lang_special_tokens(self, src_lang) -> None:
        """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
        # Convert the source language code to the special-token format
        src_lang = self._convert_lang_code_special_format(src_lang)
        # Look up the language code ID for the converted source language
        self.cur_lang_code = self.lang_code_to_id[src_lang] if src_lang is not None else None
        # No prefix tokens
        self.prefix_tokens = []
        # If a language code is set, the suffix is [eos, lang_code]; otherwise it is just [eos]
        if self.cur_lang_code is not None:
            self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
        else:
            self.suffix_tokens = [self.eos_token_id]

    def set_tgt_lang_special_tokens(self, lang: str) -> None:
        """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
        # Convert the target language code to the special-token format
        lang = self._convert_lang_code_special_format(lang)
        # Look up the language code ID for the converted target language
        self.cur_lang_code = self.lang_code_to_id[lang] if lang is not None else None
        # No prefix tokens
        self.prefix_tokens = []
        # If a language code is set, the suffix is [eos, lang_code]; otherwise it is just [eos]
        if self.cur_lang_code is not None:
            self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
        else:
            self.suffix_tokens = [self.eos_token_id]

    def _convert_lang_code_special_format(self, lang: str) -> str:
        """Convert a language code to the special format used by the tokenizer, if required."""
        # If the language code is in the mapping, convert it to the corresponding format; otherwise leave it unchanged
        lang = FAIRSEQ_LANGUAGE_CODES_MAP[lang] if lang in FAIRSEQ_LANGUAGE_CODES_MAP.keys() else lang
        return lang

.\models\plbart\__init__.py

# Import the type-checking flag
from typing import TYPE_CHECKING

# Import the required utilities and dependency checks
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_sentencepiece_available,
    is_tokenizers_available,
    is_torch_available,
)

# Define the module's import structure, starting with the configuration entries
_import_structure = {"configuration_plbart": ["PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "PLBartConfig"]}

# Check whether SentencePiece is available; raise if the optional dependency is missing
try:
    if not is_sentencepiece_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If available, add PLBartTokenizer to the import structure
    _import_structure["tokenization_plbart"] = ["PLBartTokenizer"]

# Check whether Torch is available; raise if the optional dependency is missing
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If available, add the PLBart model classes to the import structure
    _import_structure["modeling_plbart"] = [
        "PLBART_PRETRAINED_MODEL_ARCHIVE_LIST",
        "PLBartForCausalLM",
        "PLBartForConditionalGeneration",
        "PLBartForSequenceClassification",
        "PLBartModel",
        "PLBartPreTrainedModel",
    ]

# In type-checking mode
if TYPE_CHECKING:
    # Import the configuration classes
    from .configuration_plbart import PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP, PLBartConfig

    # Check whether SentencePiece is available; raise if the optional dependency is missing
    try:
        if not is_sentencepiece_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # If available, import PLBartTokenizer
        from .tokenization_plbart import PLBartTokenizer

    # Check whether Torch is available; raise if the optional dependency is missing
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # If available, import the PLBart model classes
        from .modeling_plbart import (
            PLBART_PRETRAINED_MODEL_ARCHIVE_LIST,
            PLBartForCausalLM,
            PLBartForConditionalGeneration,
            PLBartForSequenceClassification,
            PLBartModel,
            PLBartPreTrainedModel,
        )

# Otherwise (not type checking)
else:
    # Use a lazy module so that submodules are only imported on demand
    import sys

    # Map the current module to a _LazyModule that loads submodules on access
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)

.\models\poolformer\configuration_poolformer.py

# coding=utf-8
# Copyright 2022 Sea AI Labs and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" PoolFormer model configuration"""
# 模型配置的说明文档

from collections import OrderedDict
# 导入 OrderedDict 类,用于创建有序字典

from typing import Mapping
# 导入 Mapping 类型提示,用于类型注解

from packaging import version
# 导入 version 模块,用于处理版本号

from ...configuration_utils import PretrainedConfig
# 导入预训练配置类

from ...onnx import OnnxConfig
# 导入 ONNX 配置类

from ...utils import logging
# 导入日志工具模块

logger = logging.get_logger(__name__)
# 获取当前模块的日志记录器

POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "sail/poolformer_s12": "https://huggingface.co/sail/poolformer_s12/resolve/main/config.json",
    # 定义预训练模型名称和对应的配置文件 URL
    # 可在 https://huggingface.co/models?filter=poolformer 查看所有 PoolFormer 模型
}


class PoolFormerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of [`PoolFormerModel`]. It is used to instantiate a
    PoolFormer model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the PoolFormer
    [sail/poolformer_s12](https://huggingface.co/sail/poolformer_s12) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """

    # Model type identifier
    model_type = "poolformer"

    # Constructor defining the PoolFormer architecture hyperparameters
    def __init__(
        self,
        num_channels=3,  # number of input image channels, defaults to 3
        patch_size=16,   # input patch size, defaults to 16
        stride=16,       # input patch stride, defaults to 16
        pool_size=3,     # pooling window size, defaults to 3
        mlp_ratio=4.0,   # ratio of MLP output channels to input channels, defaults to 4.0
        depths=[2, 2, 6, 2],           # depth of each encoder block, defaults to `[2, 2, 6, 2]`
        hidden_sizes=[64, 128, 320, 512],  # hidden size of each encoder block, defaults to `[64, 128, 320, 512]`
        patch_sizes=[7, 3, 3, 3],      # input patch size of each encoder block, defaults to `[7, 3, 3, 3]`
        strides=[4, 2, 2, 2],          # input patch stride of each encoder block, defaults to `[4, 2, 2, 2]`
        padding=[2, 1, 1, 1],          # input patch padding of each encoder block, defaults to `[2, 1, 1, 1]`
        num_encoder_blocks=4,          # number of encoder blocks, defaults to 4
        drop_path_rate=0.0,            # drop rate for the stochastic depth (drop path) layers, defaults to 0.0
        hidden_act="gelu",             # activation function of the hidden layers, defaults to "gelu"
        use_layer_scale=True,          # whether to use layer scale, defaults to True
        layer_scale_init_value=1e-5,   # initial value of the layer scale, defaults to 1e-5
        initializer_range=0.02,        # range used to initialize the weights, defaults to 0.02
        **kwargs,
    ):
        # Set all attributes, then call the parent class constructor
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.stride = stride
        self.padding = padding
        self.pool_size = pool_size
        self.hidden_sizes = hidden_sizes
        self.mlp_ratio = mlp_ratio
        self.depths = depths
        self.patch_sizes = patch_sizes
        self.strides = strides
        self.num_encoder_blocks = num_encoder_blocks
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_layer_scale = use_layer_scale
        self.layer_scale_init_value = layer_scale_init_value
        self.initializer_range = initializer_range
        # Call the parent class constructor
        super().__init__(**kwargs)


class PoolFormerOnnxConfig(OnnxConfig):
    # PoolFormerOnnxConfig inherits from OnnxConfig

    # Minimum torch version required for ONNX export
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Return an ordered dict mapping each input name to its dynamic axes
        return OrderedDict(
            [
                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
            ]
        )

    @property
    def atol_for_validation(self) -> float:
        # Absolute tolerance (0.002) used when validating the exported ONNX model
        return 2e-3
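A small usage sketch (assuming the two classes above are importable from the installed `transformers` package): instantiating the default configuration reproduces the poolformer_s12 hyperparameters, and the ONNX config exposes the dynamic axes of `pixel_values`:

```python
from transformers.models.poolformer.configuration_poolformer import PoolFormerConfig, PoolFormerOnnxConfig

config = PoolFormerConfig()                  # defaults mirror sail/poolformer_s12
print(config.depths, config.hidden_sizes)    # [2, 2, 6, 2] [64, 128, 320, 512]

onnx_config = PoolFormerOnnxConfig(config)
print(dict(onnx_config.inputs))              # {'pixel_values': {0: 'batch', 1: 'num_channels', 2: 'height', 3: 'width'}}
print(onnx_config.atol_for_validation)       # 0.002
```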

.\models\poolformer\convert_poolformer_original_to_pytorch.py

# Imports required by the conversion code below
import argparse
import json
from collections import OrderedDict
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from transformers import PoolFormerConfig, PoolFormerForImageClassification, PoolFormerImageProcessor
from transformers.utils import logging

# Set the logging level to INFO so that progress messages are shown at runtime
logging.set_verbosity_info()

# Logger for this module
logger = logging.get_logger(__name__)


# Helper that renames a key in the model's state dict by subtracting an offset from the original block number
def replace_key_with_offset(key, offset, original_name, new_name):
    """
    Replaces the key by subtracting the offset from the original layer number

    Args:
        key (str): the key to rename
        offset (int): offset used to compute the new block number
        original_name (str): original layer name, used to locate the part of the key to replace
        new_name (str): new layer name that replaces the original one

    Returns:
        str: the renamed key
    """
    # Locate the block number to replace based on the original layer name
    to_find = original_name.split(".")[0]
    key_list = key.split(".")
    orig_block_num = int(key_list[key_list.index(to_find) - 2])
    layer_num = int(key_list[key_list.index(to_find) - 1])

    # Compute the new block number
    new_block_num = orig_block_num - offset

    # Build the new key and perform the replacement
    key = key.replace(f"{orig_block_num}.{layer_num}.{original_name}",
                      f"block.{new_block_num}.{layer_num}.{new_name}")
    return key
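For illustration, a hypothetical checkpoint key run through the helper above: with an offset of 1, block `2` is renumbered to `block.1` and `mlp.fc1` becomes `output.conv1` (the key itself is made up for this example):

```python
key = "poolformer.encoder.2.0.mlp.fc1.weight"  # hypothetical key after the "network" -> "poolformer.encoder" rename
print(replace_key_with_offset(key, offset=1, original_name="mlp.fc1", new_name="output.conv1"))
# poolformer.encoder.block.1.0.output.conv1.weight
```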


def rename_keys(state_dict):
    # Use an ordered dict to hold the new state dict
    new_state_dict = OrderedDict()
    # Counters for the embedding layers found so far and for the patch-embedding offset
    total_embed_found, patch_emb_offset = 0, 0
    # Iterate over the key/value pairs of the given state dict
    for key, value in state_dict.items():
        # If the key starts with "network", replace it with "poolformer.encoder"
        if key.startswith("network"):
            key = key.replace("network", "poolformer.encoder")

        # If the key contains "proj", handle the first embedding and the inner patch-embedding layers
        if "proj" in key:
            # If the key ends with "bias" and does not contain "patch_embed", increase the patch-embedding offset
            if key.endswith("bias") and "patch_embed" not in key:
                patch_emb_offset += 1

            # Replace everything before "proj" with "patch_embeddings.{total_embed_found}."
            # and replace "proj" with "projection"
            to_replace = key[: key.find("proj")]
            key = key.replace(to_replace, f"patch_embeddings.{total_embed_found}.")
            key = key.replace("proj", "projection")

            # If the key ends with "bias", increase the number of embeddings found
            if key.endswith("bias"):
                total_embed_found += 1

        # If the key contains "patch_embeddings", prepend "poolformer.encoder."
        if "patch_embeddings" in key:
            key = "poolformer.encoder." + key

        # If the key contains "mlp.fc1", rename it, accounting for the patch-embedding offset
        if "mlp.fc1" in key:
            key = replace_key_with_offset(key, patch_emb_offset, "mlp.fc1", "output.conv1")

        # If the key contains "mlp.fc2", rename it, accounting for the patch-embedding offset
        if "mlp.fc2" in key:
            key = replace_key_with_offset(key, patch_emb_offset, "mlp.fc2", "output.conv2")

        # If the key contains "norm1", rename it to "before_norm", accounting for the offset
        if "norm1" in key:
            key = replace_key_with_offset(key, patch_emb_offset, "norm1", "before_norm")

        # If the key contains "norm2", rename it to "after_norm", accounting for the offset
        if "norm2" in key:
            key = replace_key_with_offset(key, patch_emb_offset, "norm2", "after_norm")

        # If the key contains "layer_scale_1", renumber its block, accounting for the offset
        if "layer_scale_1" in key:
            key = replace_key_with_offset(key, patch_emb_offset, "layer_scale_1", "layer_scale_1")

        # If the key contains "layer_scale_2", renumber its block, accounting for the offset
        if "layer_scale_2" in key:
            key = replace_key_with_offset(key, patch_emb_offset, "layer_scale_2", "layer_scale_2")

        # If the key contains "head", replace "head" with "classifier"
        if "head" in key:
            key = key.replace("head", "classifier")

        # Store the processed key/value pair in the new state dict
        new_state_dict[key] = value

    # Return the new state dict
    return new_state_dict
# We will verify our results on a COCO image
def prepare_img():
    # URL of the COCO image
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # Fetch the raw bytes with requests and open the image with PIL
    image = Image.open(requests.get(url, stream=True).raw)

    return image


@torch.no_grad()
def convert_poolformer_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our PoolFormer structure.
    """

    # load default PoolFormer configuration
    config = PoolFormerConfig()

    # set attributes based on model_name
    repo_id = "huggingface/label-files"
    # Extract the size suffix from the model name
    size = model_name[-3:]
    config.num_labels = 1000
    filename = "imagenet-1k-id2label.json"
    expected_shape = (1, 1000)

    # set config attributes
    # Download the id-to-label mapping from the HuggingFace Hub and load it
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}

    # Set the configuration parameters according to the model size
    if size == "s12":
        config.depths = [2, 2, 6, 2]
        config.hidden_sizes = [64, 128, 320, 512]
        config.mlp_ratio = 4.0
        crop_pct = 0.9
    elif size == "s24":
        config.depths = [4, 4, 12, 4]
        config.hidden_sizes = [64, 128, 320, 512]
        config.mlp_ratio = 4.0
        crop_pct = 0.9
    elif size == "s36":
        config.depths = [6, 6, 18, 6]
        config.hidden_sizes = [64, 128, 320, 512]
        config.mlp_ratio = 4.0
        config.layer_scale_init_value = 1e-6
        crop_pct = 0.9
    elif size == "m36":
        config.depths = [6, 6, 18, 6]
        config.hidden_sizes = [96, 192, 384, 768]
        config.mlp_ratio = 4.0
        config.layer_scale_init_value = 1e-6
        crop_pct = 0.95
    elif size == "m48":
        config.depths = [8, 8, 24, 8]
        config.hidden_sizes = [96, 192, 384, 768]
        config.mlp_ratio = 4.0
        config.layer_scale_init_value = 1e-6
        crop_pct = 0.95
    else:
        # Raise an error if the size is not supported
        raise ValueError(f"Size {size} not supported")

    # Instantiate the PoolFormerImageProcessor used to preprocess images
    image_processor = PoolFormerImageProcessor(crop_pct=crop_pct)

    # Prepare the test image
    image = prepare_img()
    # Run the image through the processor and get the pixel value tensor
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

    # Log that the conversion is starting
    logger.info(f"Converting model {model_name}...")

    # Load the original model state dict
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    # Rename the keys of the state dict
    state_dict = rename_keys(state_dict)

    # Create the HuggingFace model and load the state dict
    model = PoolFormerForImageClassification(config)
    model.load_state_dict(state_dict)
    model.eval()

    # Define the image processor again
    image_processor = PoolFormerImageProcessor(crop_pct=crop_pct)
    # Prepare the image with prepare_img and get the pixel value tensor
    pixel_values = image_processor(images=prepare_img(), return_tensors="pt").pixel_values

    # Forward pass through the model
    outputs = model(pixel_values)
    logits = outputs.logits

    # Expected logit slices for the different model sizes
    if size == "s12":
        expected_slice = torch.tensor([-0.3045, -0.6758, -0.4869])
    elif size == "s24":
        expected_slice = torch.tensor([0.4402, -0.1374, -0.8045])
    elif size == "s36":
        expected_slice = torch.tensor([-0.6080, -0.5133, -0.5898])
    elif size == "m36":
        expected_slice = torch.tensor([0.3952, 0.2263, -1.2668])
    elif size == "m48":
        expected_slice = torch.tensor([0.1167, -0.0656, -0.3423])
    else:
        # Raise an error for unsupported sizes
        raise ValueError(f"Size {size} not supported")

    # Verify that the logits have the expected shape
    assert logits.shape == expected_shape
    # Verify that the first three logits are close to the expected slice, within an absolute tolerance of 1e-2
    assert torch.allclose(logits[0, :3], expected_slice, atol=1e-2)

    # Finally, save the PyTorch model and the image processor
    logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
    # Create the output directory if it does not exist
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    # Save the model to the given path
    model.save_pretrained(pytorch_dump_folder_path)
    # Print a message about saving the image processor
    print(f"Saving image processor to {pytorch_dump_folder_path}")
    # Save the image processor to the given path
    image_processor.save_pretrained(pytorch_dump_folder_path)
# Run the following block only when the script is executed directly, not when it is imported
if __name__ == "__main__":
    # Create an argument parser
    parser = argparse.ArgumentParser()

    # Command-line argument for the model name, defaulting to "poolformer_s12"
    parser.add_argument(
        "--model_name",
        default="poolformer_s12",
        type=str,
        help="Name of the model you'd like to convert.",
    )

    # Command-line argument for the path to the original PyTorch checkpoint (.pth file)
    parser.add_argument(
        "--checkpoint_path",
        default=None,
        type=str,
        help="Path to the original PyTorch checkpoint (.pth file).",
    )

    # Command-line argument for the output folder of the converted PyTorch model
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        help="Path to the folder to output PyTorch model.",
    )

    # Parse the command-line arguments into the args object
    args = parser.parse_args()

    # Call convert_poolformer_checkpoint with the model name, checkpoint path and output folder from the command line
    convert_poolformer_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path)

.\models\poolformer\feature_extraction_poolformer.py

# coding=utf-8
# Copyright The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
"""Feature extractor class for PoolFormer."""

# Import the warnings module
import warnings

# Import the logging utilities
from ...utils import logging
# Import the PoolFormerImageProcessor class from the image_processing_poolformer module
from .image_processing_poolformer import PoolFormerImageProcessor

# Logger for this module
logger = logging.get_logger(__name__)

# PoolFormerFeatureExtractor inherits from PoolFormerImageProcessor
class PoolFormerFeatureExtractor(PoolFormerImageProcessor):
    # Constructor accepting arbitrary positional and keyword arguments
    def __init__(self, *args, **kwargs) -> None:
        # Emit a FutureWarning: PoolFormerFeatureExtractor will be removed in Transformers v5
        warnings.warn(
            "The class PoolFormerFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
            " Please use PoolFormerImageProcessor instead.",
            FutureWarning,
        )
        # Call the parent class constructor
        super().__init__(*args, **kwargs)

.\models\poolformer\image_processing_poolformer.py

# coding=utf-8
# Copyright The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at:
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for details.
"""Image processor class for PoolFormer."""

from typing import Dict, List, Optional, Union

import numpy as np

# Import the image-processing utilities and helpers
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
    get_resize_output_image_size,
    resize,
    to_channel_dimension_format,
)
from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_kwargs,
    validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

# If the vision library is available, import PIL
if is_vision_available():
    import PIL

# Logger for this module
logger = logging.get_logger(__name__)


# PoolFormer image processor class, inheriting from BaseImageProcessor
class PoolFormerImageProcessor(BaseImageProcessor):
    r"""
    Constructs a PoolFormer image processor.

    """

    # Names of the model inputs
    model_input_names = ["pixel_values"]

    # Constructor setting the various image processor parameters
    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        crop_pct: int = 0.9,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        rescale_factor: Union[int, float] = 1 / 255,
        do_rescale: bool = True,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # If no size is given, default to {"shortest_edge": 224}
        size = size if size is not None else {"shortest_edge": 224}
        # Normalize the size argument via get_size_dict, allowing non-square sizes
        size = get_size_dict(size, default_to_square=False)
        # If no crop size is given, default to {"height": 224, "width": 224}
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        # Normalize the crop size argument via get_size_dict
        crop_size = get_size_dict(crop_size, param_name="crop_size")

        # Whether to resize the image
        self.do_resize = do_resize
        # Target size dictionary used for resizing
        self.size = size
        # Crop percentage
        self.crop_pct = crop_pct
        # Resampling method
        self.resample = resample
        # Whether to center-crop the image
        self.do_center_crop = do_center_crop
        # Crop size dictionary
        self.crop_size = crop_size
        # Whether to rescale pixel values
        self.do_rescale = do_rescale
        # Rescaling factor
        self.rescale_factor = rescale_factor
        # Whether to normalize the image
        self.do_normalize = do_normalize
        # Image mean; defaults to IMAGENET_DEFAULT_MEAN when not provided
        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
        # Image standard deviation; defaults to IMAGENET_DEFAULT_STD when not provided
        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
        # Valid keyword arguments accepted by the processor
        self._valid_processor_keys = [
            "images",
            "do_resize",
            "size",
            "crop_pct",
            "resample",
            "do_center_crop",
            "crop_size",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "image_mean",
            "image_std",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]
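A short usage sketch of the defaults above (a random PIL image stands in for real data): with the default `crop_size` the processed tensor ends up at 224x224.

```python
import numpy as np
from PIL import Image
from transformers import PoolFormerImageProcessor

processor = PoolFormerImageProcessor()  # the defaults shown above; from_pretrained("sail/poolformer_s12") also works
image = Image.fromarray(np.uint8(np.random.rand(480, 640, 3) * 255))  # dummy RGB image
inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
```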

.\models\poolformer\modeling_poolformer.py

# coding=utf-8
# Copyright 2022 Sea AI Labs and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
""" PyTorch PoolFormer model."""

# Import the required libraries
import collections.abc
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# Mapping of activation function names to implementations
from ...activations import ACT2FN
# Model output classes
from ...modeling_outputs import BaseModelOutputWithNoAttention, ImageClassifierOutputWithNoAttention
# Base class for pretrained models
from ...modeling_utils import PreTrainedModel
# Utility functions and logging
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
# PoolFormer configuration class
from .configuration_poolformer import PoolFormerConfig

# Logger for this module
logger = logging.get_logger(__name__)

# Configuration name used in the documentation
_CONFIG_FOR_DOC = "PoolFormerConfig"

# Checkpoint name used in the documentation
_CHECKPOINT_FOR_DOC = "sail/poolformer_s12"
# Expected output shape
_EXPECTED_OUTPUT_SHAPE = [1, 512, 7, 7]

# Image classification checkpoint name
_IMAGE_CLASS_CHECKPOINT = "sail/poolformer_s12"
# Expected image classification output
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"

# List of pretrained model archives
POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "sail/poolformer_s12",
    # See all PoolFormer models at https://huggingface.co/models?filter=poolformer
]


# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    Comment by Ross Wightman: this is the same as the DropConnect implementation created for EfficientNet-style
    networks, but the original name is misleading, since "Drop Connect" is a different form of dropout from a separate
    paper... See the discussion at https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... The layer
    and argument names were changed to "drop path" rather than mixing DropConnect as a layer name with "survival rate"
    as the argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # works with tensors of any dimensionality, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output
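A quick sketch of the stochastic-depth behaviour (reusing the `drop_path` function defined above): entire samples are zeroed with probability `drop_prob` and the survivors are rescaled by `1 / keep_prob`, so the expected value is preserved.

```python
import torch

torch.manual_seed(0)
x = torch.ones(4, 8, 16, 16)                       # (batch, channels, height, width)
out = drop_path(x, drop_prob=0.5, training=True)   # roughly half of the samples become all-zero
print([(out[i] == 0).all().item() for i in range(4)])
print(out.max().item())                            # surviving samples are scaled to 1 / keep_prob = 2.0
```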


# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->PoolFormer
class PoolFormerDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    # Constructor setting the drop probability
    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()  # call the parent class constructor
        self.drop_prob = drop_prob  # probability of dropping the path

    # Forward pass: apply drop path to the hidden states
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    # Extra repr string for this layer: the drop probability
    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)
class PoolFormerEmbeddings(nn.Module):
    """
    Construct Patch Embeddings.
    """

    def __init__(self, hidden_size, num_channels, patch_size, stride, padding, norm_layer=None):
        super().__init__()
        # Make patch_size, stride and padding iterable if they are not already
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride)
        padding = padding if isinstance(padding, collections.abc.Iterable) else (padding, padding)

        # Project the input image channels to the hidden size with a convolution
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=padding)
        # Use the provided normalization layer, or the identity if none was given
        self.norm = norm_layer(hidden_size) if norm_layer else nn.Identity()

    def forward(self, pixel_values):
        # Project the pixel values to obtain the patch embeddings
        embeddings = self.projection(pixel_values)
        # Normalize the projected feature map
        embeddings = self.norm(embeddings)
        return embeddings


class PoolFormerGroupNorm(nn.GroupNorm):
    """
    Group Normalization with 1 group. Input: tensor in shape [B, C, H, W]
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)


class PoolFormerPooling(nn.Module):
    def __init__(self, pool_size):
        super().__init__()
        # Average pooling used as the token mixer
        self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)

    def forward(self, hidden_states):
        # Pool the hidden states and subtract the input, so only the pooled "mixing" signal remains
        return self.pool(hidden_states) - hidden_states
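The subtraction above means the pooling token mixer only propagates local differences; on a constant feature map it outputs exactly zero. A tiny check of this property:

```python
import torch
import torch.nn as nn

pool = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
constant_map = torch.ones(1, 4, 8, 8)
print((pool(constant_map) - constant_map).abs().max().item())  # 0.0: no mixing signal for a constant input
```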


class PoolFormerOutput(nn.Module):
    def __init__(self, config, dropout_prob, hidden_size, intermediate_size):
        super().__init__()
        # 1x1 convolution mapping the hidden size to the intermediate size
        self.conv1 = nn.Conv2d(hidden_size, intermediate_size, 1)
        # 1x1 convolution mapping the intermediate size back to the hidden size
        self.conv2 = nn.Conv2d(intermediate_size, hidden_size, 1)
        # PoolFormerDropPath applies drop path with probability dropout_prob
        self.drop = PoolFormerDropPath(dropout_prob)
        # Pick the activation function according to the configuration and store it in self.act_fn
        if isinstance(config.hidden_act, str):
            self.act_fn = ACT2FN[config.hidden_act]
        else:
            self.act_fn = config.hidden_act

    def forward(self, hidden_states):
        # First 1x1 convolution
        hidden_states = self.conv1(hidden_states)
        # Apply the chosen activation function
        hidden_states = self.act_fn(hidden_states)
        # Apply drop path
        hidden_states = self.drop(hidden_states)
        # Second 1x1 convolution
        hidden_states = self.conv2(hidden_states)
        # Apply drop path again
        hidden_states = self.drop(hidden_states)

        return hidden_states


class PoolFormerLayer(nn.Module):
    """This corresponds to the 'PoolFormerBlock' class in the original implementation."""
    # Constructor for the PoolFormer layer
    def __init__(self, config, num_channels, pool_size, hidden_size, intermediate_size, drop_path):
        super().__init__()
        # Pooling (token mixing) sub-layer
        self.pooling = PoolFormerPooling(pool_size)
        # Output (MLP) sub-layer
        self.output = PoolFormerOutput(config, drop_path, hidden_size, intermediate_size)
        # Group normalization applied before the pooling sub-layer
        self.before_norm = PoolFormerGroupNorm(num_channels)
        # Group normalization applied before the output sub-layer
        self.after_norm = PoolFormerGroupNorm(num_channels)

        # Drop path layer when drop_path > 0, otherwise the identity
        self.drop_path = PoolFormerDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        # Whether layer scale is used
        self.use_layer_scale = config.use_layer_scale
        if config.use_layer_scale:
            # First layer scale parameter
            self.layer_scale_1 = nn.Parameter(
                config.layer_scale_init_value * torch.ones((num_channels)), requires_grad=True
            )
            # Second layer scale parameter
            self.layer_scale_2 = nn.Parameter(
                config.layer_scale_init_value * torch.ones((num_channels)), requires_grad=True
            )

    # Forward pass over the hidden states, returning the processed outputs
    def forward(self, hidden_states):
        # If layer scale is used
        if self.use_layer_scale:
            # Normalize, pool, then apply the first layer scale
            pooling_output = self.pooling(self.before_norm(hidden_states))
            scaled_op = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * pooling_output
            # First residual connection
            hidden_states = hidden_states + self.drop_path(scaled_op)
            outputs = ()

            # Normalize, run the output sub-layer, then apply the second layer scale
            layer_output = self.output(self.after_norm(hidden_states))
            scaled_op = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * layer_output
            # Second residual connection
            output = hidden_states + self.drop_path(scaled_op)

            outputs = (output,) + outputs
            return outputs

        else:
            # Without layer scale: pooling, normalization and drop path, then the residual connection
            pooling_output = self.drop_path(self.pooling(self.before_norm(hidden_states)))
            # First residual connection
            hidden_states = pooling_output + hidden_states
            outputs = ()

            # Second residual connection around the PoolFormerOutput block
            layer_output = self.drop_path(self.output(self.after_norm(hidden_states)))
            output = hidden_states + layer_output

            outputs = (output,) + outputs
            return outputs
class PoolFormerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Stochastic depth decay rule: a linearly spaced list of drop path rates up to config.drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]

        # patch embeddings
        embeddings = []
        for i in range(config.num_encoder_blocks):
            embeddings.append(
                PoolFormerEmbeddings(
                    patch_size=config.patch_sizes[i],
                    stride=config.strides[i],
                    padding=config.padding[i],
                    num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
                    hidden_size=config.hidden_sizes[i],
                )
            )
        self.patch_embeddings = nn.ModuleList(embeddings)

        # Transformer blocks
        blocks = []
        cur = 0
        for i in range(config.num_encoder_blocks):
            # each block consists of layers
            layers = []
            if i != 0:
                cur += config.depths[i - 1]
            for j in range(config.depths[i]):
                layers.append(
                    PoolFormerLayer(
                        config,
                        num_channels=config.hidden_sizes[i],
                        pool_size=config.pool_size,
                        hidden_size=config.hidden_sizes[i],
                        intermediate_size=int(config.hidden_sizes[i] * config.mlp_ratio),
                        drop_path=dpr[cur + j],
                    )
                )
            blocks.append(nn.ModuleList(layers))

        self.block = nn.ModuleList(blocks)

    def forward(self, pixel_values, output_hidden_states=False, return_dict=True):
        all_hidden_states = () if output_hidden_states else None

        hidden_states = pixel_values
        for idx, layers in enumerate(zip(self.patch_embeddings, self.block)):
            embedding_layer, block_layer = layers
            # Get patch embeddings from hidden_states
            hidden_states = embedding_layer(hidden_states)
            # Send the embeddings through the blocks
            for _, blk in enumerate(block_layer):
                layer_outputs = blk(hidden_states)
                hidden_states = layer_outputs[0]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        # Return the base model output (this model produces no attentions)
        return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
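A short standalone sketch of the stochastic depth decay rule computed in __init__ above, with illustrative values (depths=[2, 2, 6, 2] matches the PoolFormer-S12 layout; drop_path_rate=0.1 is chosen only for illustration):

import torch

depths = [2, 2, 6, 2]
drop_path_rate = 0.1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

cur = 0
for i, depth in enumerate(depths):
    if i != 0:
        cur += depths[i - 1]
    # Rates grow linearly from 0.0 in the first layer to drop_path_rate in the last layer
    print(f"block {i}: drop_path rates {[round(dpr[cur + j], 3) for j in range(depth)]}")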


class PoolFormerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = PoolFormerConfig
    base_model_prefix = "poolformer"
    main_input_name = "pixel_values"
    def _init_weights(self, module):
        """Initialize the weights"""
        # Linear and 2D convolution layers
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Initialize the weights from a normal distribution with mean 0 and std from the config
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # Zero the bias if there is one
            if module.bias is not None:
                module.bias.data.zero_()
        # LayerNorm layers
        elif isinstance(module, nn.LayerNorm):
            # Zero the bias
            module.bias.data.zero_()
            # Fill the weight with ones
            module.weight.data.fill_(1.0)

# POOLFORMER_START_DOCSTRING: shared docstring describing usage of the model as a PyTorch Module and its config parameter
POOLFORMER_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`PoolFormerConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# POOLFORMER_INPUTS_DOCSTRING: documents the pixel_values input expected by the model forward methods
POOLFORMER_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`PoolFormerImageProcessor.__call__`] for details.
"""

# Attach the shared docstring to PoolFormerModel via the @add_start_docstrings decorator
@add_start_docstrings(
    "The bare PoolFormer Model transformer outputting raw hidden-states without any specific head on top.",
    POOLFORMER_START_DOCSTRING,
)
class PoolFormerModel(PoolFormerPreTrainedModel):
    def __init__(self, config):
        # Initialize the parent class and store the config
        super().__init__(config)
        self.config = config

        # Encoder backbone
        self.encoder = PoolFormerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Return the input embeddings (the first patch embedding layer)
    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    # Document the forward method with the inputs docstring and a code sample
    @add_start_docstrings_to_model_forward(POOLFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        # Fall back to the config defaults when output_hidden_states / return_dict are not given
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # pixel_values is required
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Run the encoder on the input
        encoder_outputs = self.encoder(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        # Tuple output when return_dict is False
        if not return_dict:
            return (sequence_output, None) + encoder_outputs[1:]

        # Otherwise wrap the result in a BaseModelOutputWithNoAttention
        return BaseModelOutputWithNoAttention(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
        )
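A minimal usage sketch for the bare model, using randomly initialized weights from the default configuration so no checkpoint download is needed (the 224x224 input size is an assumption):

import torch
from transformers import PoolFormerConfig, PoolFormerModel

config = PoolFormerConfig()                       # defaults follow the PoolFormer-S12 layout
model = PoolFormerModel(config).eval()

pixel_values = torch.randn(1, config.num_channels, 224, 224)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)
# last_hidden_state is a feature map, not a token sequence
print(outputs.last_hidden_state.shape)            # e.g. torch.Size([1, 512, 7, 7])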


# Simple pooler head: a single dense layer applied to the hidden states
class PoolFormerFinalPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states):
        # Project the hidden states through the dense layer and return the result
        output = self.dense(hidden_states)
        return output

# PoolFormerForImageClassification: the PoolFormer transformer with an image classification head on top
@add_start_docstrings(
    """
    PoolFormer Model transformer with an image classification head on top
    """,
    POOLFORMER_START_DOCSTRING,
)
class PoolFormerForImageClassification(PoolFormerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.poolformer = PoolFormerModel(config)

        # Final norm: PoolFormerGroupNorm applied to the last hidden representation
        self.norm = PoolFormerGroupNorm(config.hidden_sizes[-1])

        # Classifier head: a linear layer, or an identity mapping when num_labels == 0
        self.classifier = (
            nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    # Document the forward method with the inputs docstring and a code sample
    @add_start_docstrings_to_model_forward(POOLFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # Use return_dict if given, otherwise fall back to self.config.use_return_dict
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Extract image features with the PoolFormer backbone
        outputs = self.poolformer(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Last hidden state (a feature map)
        sequence_output = outputs[0]

        # Normalize, average over the spatial dimensions, then classify
        logits = self.classifier(self.norm(sequence_output).mean([-2, -1]))

        # Compute the loss only when labels are provided
        loss = None
        if labels is not None:
            # Infer self.config.problem_type from the labels if it is not set
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # Pick the loss function according to the problem type
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    # Single-target regression: mean squared error on squeezed tensors
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    # Multi-target regression: mean squared error
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                # Single-label classification: cross-entropy loss
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                # Multi-label classification: binary cross-entropy with logits
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        # Tuple output when return_dict is False
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # Otherwise return an ImageClassifierOutputWithNoAttention
        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
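A minimal classification sketch with randomly initialized weights (the 10-label head and the example inputs are assumptions, chosen only to show the shapes and the loss path):

import torch
from transformers import PoolFormerConfig, PoolFormerForImageClassification

config = PoolFormerConfig(num_labels=10)
model = PoolFormerForImageClassification(config).eval()

pixel_values = torch.randn(2, 3, 224, 224)
labels = torch.tensor([1, 7])                 # integer class indices -> single_label_classification
outputs = model(pixel_values=pixel_values, labels=labels)
print(outputs.logits.shape)                   # torch.Size([2, 10])
print(outputs.loss)                           # cross-entropy loss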

.\models\poolformer\__init__.py

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Type-checking support
from typing import TYPE_CHECKING

# Utilities and optional-dependency helpers
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available

# Import structure describing the public objects of this module
_import_structure = {
    "configuration_poolformer": [
        "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "PoolFormerConfig",
        "PoolFormerOnnxConfig",
    ]
}

# If the vision dependencies are missing, OptionalDependencyNotAvailable is raised and these entries are skipped
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # Feature extractor (legacy) entry
    _import_structure["feature_extraction_poolformer"] = ["PoolFormerFeatureExtractor"]
    # Image processor entry
    _import_structure["image_processing_poolformer"] = ["PoolFormerImageProcessor"]

# If PyTorch is missing, OptionalDependencyNotAvailable is raised and the modeling entries are skipped
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # Modeling entries, available only with PyTorch
    _import_structure["modeling_poolformer"] = [
        "POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "PoolFormerForImageClassification",
        "PoolFormerModel",
        "PoolFormerPreTrainedModel",
    ]

# Under static type checking (e.g. Mypy), perform the real imports
if TYPE_CHECKING:
    # Configuration classes and constants
    from .configuration_poolformer import (
        POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
        PoolFormerConfig,
        PoolFormerOnnxConfig,
    )

    # Import the feature extractor and image processor only if vision support is available
    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .feature_extraction_poolformer import PoolFormerFeatureExtractor
        from .image_processing_poolformer import PoolFormerImageProcessor

    # Import the modeling classes only if PyTorch is available
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_poolformer import (
            POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
            PoolFormerForImageClassification,
            PoolFormerModel,
            PoolFormerPreTrainedModel,
        )

# At runtime, replace this module with a lazy module that imports submodules on first attribute access
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
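Illustrative note: because of the lazy module set up above, importing the package is cheap and the heavy submodules are only loaded on first attribute access. A small sketch of that behaviour (the module path is the one defined in the library):

import importlib

poolformer_pkg = importlib.import_module("transformers.models.poolformer")
# Accessing an attribute triggers the actual import of configuration_poolformer
PoolFormerConfig = poolformer_pkg.PoolFormerConfig
print(PoolFormerConfig.model_type)  # "poolformer"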

.\models\pop2piano\configuration_pop2piano.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Pop2Piano model configuration"""


from ...configuration_utils import PretrainedConfig  # Base class for model configurations
from ...utils import logging  # Logging utilities


logger = logging.get_logger(__name__)  # Module-level logger

# Map from model identifier to the URL of its pretrained configuration file
POP2PIANO_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "sweetcocoa/pop2piano": "https://huggingface.co/sweetcocoa/pop2piano/blob/main/config.json"
}


class Pop2PianoConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Pop2PianoForConditionalGeneration`]. It is used
    to instantiate a Pop2PianoForConditionalGeneration model according to the specified arguments, defining the model
    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
    Pop2Piano [sweetcocoa/pop2piano](https://huggingface.co/sweetcocoa/pop2piano) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Arguments:
        vocab_size (`int`, *optional*, defaults to 2400):
            Vocabulary size of the `Pop2PianoForConditionalGeneration` model. Defines the number of different tokens
            that can be represented by the `inputs_ids` passed when calling `Pop2PianoForConditionalGeneration`.
        composer_vocab_size (`int`, *optional*, defaults to 21):
            Denotes the number of composers.
        d_model (`int`, *optional*, defaults to 512):
            Size of the encoder layers and the pooler layer.
        d_kv (`int`, *optional*, defaults to 64):
            Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will
            be defined as `num_heads * d_kv`.
        d_ff (`int`, *optional*, defaults to 2048):
            Size of the intermediate feed forward layer in each `Pop2PianoBlock`.
        num_layers (`int`, *optional*, defaults to 6):
            Number of hidden layers in the Transformer encoder.
        num_decoder_layers (`int`, *optional*):
            Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
        num_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        relative_attention_num_buckets (`int`, *optional*, defaults to 32):
            The number of buckets to use for each attention layer.
        relative_attention_max_distance (`int`, *optional*, defaults to 128):
            The maximum distance of the longer sequences for the bucket separation.
        dropout_rate (`float`, *optional*, defaults to 0.1):
            The ratio for all dropout layers.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        initializer_factor (`float`, *optional*, defaults to 1.0):
            A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization
            testing).
        feed_forward_proj (`string`, *optional*, defaults to `"gated-gelu"`):
            Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        dense_act_fn (`string`, *optional*, defaults to `"relu"`):
            Type of Activation Function to be used in `Pop2PianoDenseActDense` and in `Pop2PianoDenseGatedActDense`.
    """

    model_type = "pop2piano"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Constructor for the Pop2Piano model configuration
    def __init__(
        self,
        vocab_size=2400,  # vocabulary size, defaults to 2400
        composer_vocab_size=21,  # composer vocabulary size, defaults to 21
        d_model=512,  # hidden size of the Transformer, defaults to 512
        d_kv=64,  # size of the key/value projections per attention head, defaults to 64
        d_ff=2048,  # intermediate size of the feed-forward layers, defaults to 2048
        num_layers=6,  # number of encoder layers, defaults to 6
        num_decoder_layers=None,  # number of decoder layers; falls back to num_layers when None
        num_heads=8,  # number of attention heads, defaults to 8
        relative_attention_num_buckets=32,  # number of buckets for relative position encoding, defaults to 32
        relative_attention_max_distance=128,  # maximum distance for relative position encoding, defaults to 128
        dropout_rate=0.1,  # dropout ratio, defaults to 0.1
        layer_norm_epsilon=1e-6,  # epsilon used in layer normalization, defaults to 1e-6
        initializer_factor=1.0,  # initialization factor, defaults to 1.0
        feed_forward_proj="gated-gelu",  # feed-forward activation type, defaults to "gated-gelu"
        is_encoder_decoder=True,  # whether this is an encoder-decoder model, defaults to True
        use_cache=True,  # whether to use the key/value cache, defaults to True
        pad_token_id=0,  # id of the padding token, defaults to 0
        eos_token_id=1,  # id of the end-of-sequence token, defaults to 1
        dense_act_fn="relu",  # activation used in the dense layers, defaults to "relu"
        **kwargs,  # additional keyword arguments
    ):
        self.vocab_size = vocab_size  # vocabulary size
        self.composer_vocab_size = composer_vocab_size  # composer vocabulary size
        self.d_model = d_model  # hidden size
        self.d_kv = d_kv  # key/value projection size
        self.d_ff = d_ff  # feed-forward intermediate size
        self.num_layers = num_layers  # number of encoder layers
        self.num_decoder_layers = num_decoder_layers if num_decoder_layers is not None else self.num_layers  # decoder layers
        self.num_heads = num_heads  # number of attention heads
        self.relative_attention_num_buckets = relative_attention_num_buckets  # relative attention buckets
        self.relative_attention_max_distance = relative_attention_max_distance  # relative attention max distance
        self.dropout_rate = dropout_rate  # dropout ratio
        self.layer_norm_epsilon = layer_norm_epsilon  # layer norm epsilon
        self.initializer_factor = initializer_factor  # initialization factor
        self.feed_forward_proj = feed_forward_proj  # feed-forward activation type
        self.use_cache = use_cache  # whether to use the cache
        self.dense_act_fn = dense_act_fn  # dense layer activation
        self.is_gated_act = self.feed_forward_proj.split("-")[0] == "gated"  # whether the activation is gated
        self.hidden_size = self.d_model  # alias: hidden size equals d_model
        self.num_attention_heads = num_heads  # alias for the number of attention heads
        self.num_hidden_layers = num_layers  # alias for the number of hidden layers

        # Call the parent constructor with pad_token_id, eos_token_id, is_encoder_decoder and any remaining kwargs
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )
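A small sketch showing the derived attributes set by this constructor, using only the default values:

from transformers import Pop2PianoConfig

config = Pop2PianoConfig()                    # defaults mirror the sweetcocoa/pop2piano architecture
print(config.num_decoder_layers)              # 6, falls back to num_layers when not given
print(config.is_gated_act)                    # True, because feed_forward_proj == "gated-gelu"
print(config.hidden_size == config.d_model)   # True, aliases set for compatibility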

.\models\pop2piano\convert_pop2piano_weights_to_hf.py

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" File for loading the Pop2Piano model weights from the official repository and showing how the tokenizer vocabulary was constructed"""

import json  # JSON serialization for the vocabulary file
import torch  # PyTorch, used to load and save the checkpoints

from transformers import Pop2PianoConfig, Pop2PianoForConditionalGeneration  # Pop2Piano classes


########################## MODEL WEIGHTS ##########################

# These weights were downloaded from the official pop2piano repository:
# https://huggingface.co/sweetcocoa/pop2piano/blob/main/model-1999-val_0.67311615.ckpt
official_weights = torch.load("./model-1999-val_0.67311615.ckpt")
state_dict = {}  # state dict to be filled with the converted weights


# Load the configuration and initialize the model
cfg = Pop2PianoConfig.from_pretrained("sweetcocoa/pop2piano")
model = Pop2PianoForConditionalGeneration(cfg)


# Load the relative attention bias weights
state_dict["encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = official_weights["state_dict"][
    "transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
]
state_dict["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = official_weights["state_dict"][
    "transformer.decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
]

# Load the encoder/decoder token embeddings and final layer norms
state_dict["encoder.embed_tokens.weight"] = official_weights["state_dict"]["transformer.encoder.embed_tokens.weight"]
state_dict["decoder.embed_tokens.weight"] = official_weights["state_dict"]["transformer.decoder.embed_tokens.weight"]

state_dict["encoder.final_layer_norm.weight"] = official_weights["state_dict"][
    "transformer.encoder.final_layer_norm.weight"
]
state_dict["decoder.final_layer_norm.weight"] = official_weights["state_dict"][
    "transformer.decoder.final_layer_norm.weight"
]

# Load lm_head, mel_conditioner.emb and shared weights
state_dict["lm_head.weight"] = official_weights["state_dict"]["transformer.lm_head.weight"]
state_dict["mel_conditioner.embedding.weight"] = official_weights["state_dict"]["mel_conditioner.embedding.weight"]
state_dict["shared.weight"] = official_weights["state_dict"]["transformer.shared.weight"]

# Load the weights of each encoder block
for i in range(cfg.num_layers):
    # block i, layer 0: self-attention
    state_dict[f"encoder.block.{i}.layer.0.SelfAttention.q.weight"] = official_weights["state_dict"][
        f"transformer.encoder.block.{i}.layer.0.SelfAttention.q.weight"
    ]
    # Self-attention weights of encoder block i, layer 0
    state_dict[f"encoder.block.{i}.layer.0.SelfAttention.k.weight"] = official_weights["state_dict"][
        f"transformer.encoder.block.{i}.layer.0.SelfAttention.k.weight"
    ]
    state_dict[f"encoder.block.{i}.layer.0.SelfAttention.v.weight"] = official_weights["state_dict"][
        f"transformer.encoder.block.{i}.layer.0.SelfAttention.v.weight"
    ]
    state_dict[f"encoder.block.{i}.layer.0.SelfAttention.o.weight"] = official_weights["state_dict"][
        f"transformer.encoder.block.{i}.layer.0.SelfAttention.o.weight"
    ]
    state_dict[f"encoder.block.{i}.layer.0.layer_norm.weight"] = official_weights["state_dict"][
        f"transformer.encoder.block.{i}.layer.0.layer_norm.weight"
    ]

    # Feed-forward (DenseReluDense) weights of encoder block i, layer 1
    state_dict[f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"] = official_weights["state_dict"][
        f"transformer.encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"
    ]
    state_dict[f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"] = official_weights["state_dict"][
        f"transformer.encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"
    ]
    state_dict[f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight"] = official_weights["state_dict"][
        f"transformer.encoder.block.{i}.layer.1.DenseReluDense.wo.weight"
    ]
    state_dict[f"encoder.block.{i}.layer.1.layer_norm.weight"] = official_weights["state_dict"][
        f"transformer.encoder.block.{i}.layer.1.layer_norm.weight"
    ]
# Load the weights of each decoder block

# Iterate over the 6 decoder blocks
for i in range(6):
    # layer 0: self-attention
    state_dict[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.0.SelfAttention.q.weight"
    ]
    state_dict[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.0.SelfAttention.k.weight"
    ]
    state_dict[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.0.SelfAttention.v.weight"
    ]
    state_dict[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.0.SelfAttention.o.weight"
    ]
    state_dict[f"decoder.block.{i}.layer.0.layer_norm.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.0.layer_norm.weight"
    ]

    # layer 1: encoder-decoder cross-attention
    state_dict[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.1.EncDecAttention.q.weight"
    ]
    state_dict[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.1.EncDecAttention.k.weight"
    ]
    state_dict[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.1.EncDecAttention.v.weight"
    ]
    state_dict[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.1.EncDecAttention.o.weight"
    ]
    state_dict[f"decoder.block.{i}.layer.1.layer_norm.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.1.layer_norm.weight"
    ]

    # layer 2: feed-forward (DenseReluDense)
    state_dict[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"
    ]
    state_dict[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"
    ]
    state_dict[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.2.DenseReluDense.wo.weight"
    ]
    state_dict[f"decoder.block.{i}.layer.2.layer_norm.weight"] = official_weights["state_dict"][
        f"transformer.decoder.block.{i}.layer.2.layer_norm.weight"
    ]

# Update the model weights with the assembled state dict
model.load_state_dict(state_dict, strict=True)

# Save the model state dict to a file
torch.save(state_dict, "./pytorch_model.bin")

########################## TOKENIZER ##########################

# The tokenize and detokenize methods below follow the official implementation

# link: https://github.com/sweetcocoa/pop2piano/blob/fac11e8dcfc73487513f4588e8d0c22a22f2fdc5/midi_tokenizer.py#L34
# Convert an index of a given token type into a vocabulary id
def tokenize(idx, token_type, n_special=4, n_note=128, n_velocity=2):
    # TOKEN_TIME ids come after the special, note and velocity ranges
    if token_type == "TOKEN_TIME":
        return n_special + n_note + n_velocity + idx
    # TOKEN_VELOCITY ids come after the special and note ranges
    elif token_type == "TOKEN_VELOCITY":
        return n_special + n_note + idx
    # TOKEN_NOTE ids come after the special range
    elif token_type == "TOKEN_NOTE":
        return n_special + idx
    # TOKEN_SPECIAL ids are the first ids
    elif token_type == "TOKEN_SPECIAL":
        return idx
    # Unknown token types map to -1
    else:
        return -1


# link : https://github.com/sweetcocoa/pop2piano/blob/fac11e8dcfc73487513f4588e8d0c22a22f2fdc5/midi_tokenizer.py#L48
# Map a vocabulary id back to its token type and index
def detokenize(idx, n_special=4, n_note=128, n_velocity=2, time_idx_offset=0):
    # Determine which range the id falls into and return the corresponding token type and index
    if idx >= n_special + n_note + n_velocity:
        return "TOKEN_TIME", (idx - (n_special + n_note + n_velocity)) + time_idx_offset
    elif idx >= n_special + n_note:
        return "TOKEN_VELOCITY", idx - (n_special + n_note)
    elif idx >= n_special:
        return "TOKEN_NOTE", idx - n_special
    else:
        return "TOKEN_SPECIAL", idx


# Decoder mapping: vocabulary id -> string representation
decoder = {}
# For every id in the vocabulary, store its "<index>_<token type>" string
for i in range(cfg.vocab_size):
    decoder.update({i: f"{detokenize(i)[1]}_{detokenize(i)[0]}"})

# Encoder mapping: invert the decoder so strings can be looked up quickly during encoding
encoder = {v: k for k, v in decoder.items()}

# Save the encoder mapping as a JSON vocabulary file for later use
with open("./vocab.json", "w") as file:
    file.write(json.dumps(encoder))

.\models\pop2piano\feature_extraction_pop2piano.py

    r"""
    Constructs a Pop2Piano feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods.

    This class extracts rhythm and preprocesses the audio before it is passed to the model. First the audio is passed
    to `RhythmExtractor2013` algorithm which extracts the beat_times, beat positions and estimates their confidence as
    well as tempo in bpm, then beat_times is interpolated and to get beatsteps. Later we calculate
    extrapolated_beatsteps from it to be used in tokenizer. On the other hand audio is resampled to self.sampling_rate
    and preprocessed and then log mel spectogram is computed from that to be used in our transformer model.
    """
    
# Warnings module, used for possible usage warnings
import warnings
# Typing helpers for annotations
from typing import List, Optional, Union

import numpy
import numpy as np

# Audio helpers for mel filter banks and spectrogram computation
from ...audio_utils import mel_filter_bank, spectrogram
# Base class for sequence feature extractors
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
# Container for batched features
from ...feature_extraction_utils import BatchFeature
# Common utilities
from ...utils import (
    TensorType,
    is_essentia_available,
    is_librosa_available,
    is_scipy_available,
    logging,
    requires_backends,
)

# Import Essentia only if it is available
if is_essentia_available():
    import essentia
    import essentia.standard

# Import Librosa only if it is available
if is_librosa_available():
    import librosa

# Import Scipy only if it is available
if is_scipy_available():
    import scipy

# Module-level logger
logger = logging.get_logger(__name__)


class Pop2PianoFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a Pop2Piano feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods.

    This class extracts rhythm and preprocesses the audio before it is passed to the model. First the audio is passed
    to `RhythmExtractor2013` algorithm which extracts the beat_times, beat positions and estimates their confidence as
    well as tempo in bpm, then beat_times is interpolated and to get beatsteps. Later we calculate
    extrapolated_beatsteps from it to be used in tokenizer. On the other hand audio is resampled to self.sampling_rate
    and preprocessed and then log mel spectrogram is computed from that to be used in our transformer model.
    """
    # Names of the model inputs: the log-mel input features, the beat steps, and the extrapolated beat steps
    model_input_names = ["input_features", "beatsteps", "extrapolated_beatstep"]

    def __init__(
        self,
        sampling_rate: int = 22050,
        padding_value: int = 0,
        window_size: int = 4096,
        hop_length: int = 1024,
        min_frequency: float = 10.0,
        feature_size: int = 512,
        num_bars: int = 2,
        **kwargs,
    ):
        # Call the parent constructor with the feature size, sampling rate and padding value
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            **kwargs,
        )
        # Store the sampling rate, padding value, window size, hop length, minimum frequency, feature size and number of bars
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        self.window_size = window_size
        self.hop_length = hop_length
        self.min_frequency = min_frequency
        self.feature_size = feature_size
        self.num_bars = num_bars
        # Precompute the mel filter bank used for the log-mel spectrogram
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=(self.window_size // 2) + 1,
            num_mel_filters=self.feature_size,
            min_frequency=self.min_frequency,
            max_frequency=float(self.sampling_rate // 2),
            sampling_rate=self.sampling_rate,
            norm=None,
            mel_scale="htk",
        )




    def mel_spectrogram(self, sequence: np.ndarray):
        """
        Generates MelSpectrogram.

        Args:
            sequence (`numpy.ndarray`):
                The sequence of which the mel-spectrogram will be computed.
        """
        # List that will collect one mel spectrogram per input sequence
        mel_specs = []
        # Process each sequence in the batch
        for seq in sequence:
            # Hann window used to window each analysis frame
            window = np.hanning(self.window_size + 1)[:-1]
            # Compute the mel spectrogram of the current sequence and collect it
            mel_specs.append(
                spectrogram(
                    waveform=seq,
                    window=window,
                    frame_length=self.window_size,
                    hop_length=self.hop_length,
                    power=2.0,
                    mel_filters=self.mel_filters,
                )
            )
        # Stack the list of mel spectrograms into a numpy array and return it
        mel_specs = np.array(mel_specs)

        return mel_specs
    def extract_rhythm(self, audio: np.ndarray):
        """
        This algorithm(`RhythmExtractor2013`) extracts the beat positions and estimates their confidence as well as
        tempo in bpm for an audio signal. For more information please visit
        https://essentia.upf.edu/reference/std_RhythmExtractor2013.html .

        Args:
            audio(`numpy.ndarray`):
                raw audio waveform which is passed to the Rhythm Extractor.
        """
        # Make sure the required backend (essentia) is installed
        requires_backends(self, ["essentia"])
        # Create the RhythmExtractor2013 estimator with the multifeature method
        essentia_tracker = essentia.standard.RhythmExtractor2013(method="multifeature")
        # Run the extractor on the audio to obtain the rhythm information
        bpm, beat_times, confidence, estimates, essentia_beat_intervals = essentia_tracker(audio)

        # Return the tempo (bpm), beat times, confidence, estimates and beat intervals
        return bpm, beat_times, confidence, estimates, essentia_beat_intervals

    def interpolate_beat_times(
        self, beat_times: numpy.ndarray, steps_per_beat: numpy.ndarray, n_extend: numpy.ndarray
    ):
        """
        This method takes beat_times and then interpolates that using `scipy.interpolate.interp1d` and the output is
        then used to convert raw audio to log-mel-spectrogram.

        Args:
            beat_times (`numpy.ndarray`):
                beat_times is passed into `scipy.interpolate.interp1d` for processing.
            steps_per_beat (`int`):
                used as a parameter to control the interpolation.
            n_extend (`int`):
                used as a parameter to control the interpolation.
        """

        # Make sure the required backend (scipy) is installed
        requires_backends(self, ["scipy"])
        # Build the 1-D interpolation function over the detected beat times
        beat_times_function = scipy.interpolate.interp1d(
            np.arange(beat_times.size),
            beat_times,
            bounds_error=False,
            fill_value="extrapolate",
        )

        # Evaluate the interpolation on an extended, evenly spaced grid of beat positions
        ext_beats = beat_times_function(
            np.linspace(0, beat_times.size + n_extend - 1, beat_times.size * steps_per_beat + n_extend)
        )

        # Return the interpolated (and extrapolated) beat times
        return ext_beats
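A standalone sketch of this extrapolating interpolation with made-up beat times (steps_per_beat=1 and n_extend=2 are illustrative values):

import numpy as np
import scipy.interpolate

beat_times = np.array([0.5, 1.0, 1.5, 2.0])        # four detected beats, 0.5 s apart
f = scipy.interpolate.interp1d(np.arange(beat_times.size), beat_times,
                               bounds_error=False, fill_value="extrapolate")
# steps_per_beat=1, n_extend=2 -> query 6 evenly spaced points over [0, 5]
ext_beats = f(np.linspace(0, beat_times.size + 2 - 1, beat_times.size * 1 + 2))
print(ext_beats)  # [0.5 1.  1.5 2.  2.5 3. ] -- the last two beats are extrapolated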
    def preprocess_mel(self, audio: np.ndarray, beatstep: np.ndarray):
        """
        Preprocessing for log-mel-spectrogram

        Args:
            audio (`numpy.ndarray` of shape `(audio_length, )` ):
                Raw audio waveform to be processed.
            beatstep (`numpy.ndarray`):
                Interpolated values of the raw audio. If beatstep[0] is greater than 0.0, then it will be shifted by
                the value at beatstep[0].
        """

        # The audio must be a single-channel waveform
        if audio is not None and len(audio.shape) != 1:
            raise ValueError(
                f"Expected `audio` to be a single channel audio input of shape `(n, )` but found shape {audio.shape}."
            )

        # If the first beat step is greater than 0.0, shift the whole grid so that it starts at 0.0
        if beatstep[0] > 0.0:
            beatstep = beatstep - beatstep[0]

        # Number of beat steps per chunk and number of available beat steps
        num_steps = self.num_bars * 4
        num_target_steps = len(beatstep)

        # Interpolate the beat times so that they cover the extended number of steps
        extrapolated_beatstep = self.interpolate_beat_times(
            beat_times=beatstep, steps_per_beat=1, n_extend=(self.num_bars + 1) * 4 + 1
        )

        # Sample index ranges of each chunk and the maximum chunk length
        sample_indices = []
        max_feature_length = 0

        # Split the beat grid into chunks of num_steps beats and compute the length of each chunk in samples
        for i in range(0, num_target_steps, num_steps):
            start_idx = i
            end_idx = min(i + num_steps, num_target_steps)
            start_sample = int(extrapolated_beatstep[start_idx] * self.sampling_rate)
            end_sample = int(extrapolated_beatstep[end_idx] * self.sampling_rate)
            sample_indices.append((start_sample, end_sample))
            max_feature_length = max(max_feature_length, end_sample - start_sample)

        # Batch of chunks, each padded to the length of the longest chunk
        padded_batch = []

        # Pad every chunk with zeros up to max_feature_length
        for start_sample, end_sample in sample_indices:
            feature = audio[start_sample:end_sample]
            padded_feature = np.pad(
                feature,
                ((0, max_feature_length - feature.shape[0]),),  # pad on the last (only) dimension
                "constant",
                constant_values=0,  # pad with zeros
            )
            padded_batch.append(padded_feature)

        # Stack the padded chunks into a numpy array
        padded_batch = np.asarray(padded_batch)

        # Return the padded chunks together with the extrapolated beat steps
        return padded_batch, extrapolated_beatstep
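To see how the beat grid is turned into per-chunk sample ranges, here is a small standalone illustration with made-up numbers (sampling_rate=4 and num_bars=2 are chosen only to keep the arithmetic readable):

import numpy as np

sampling_rate, num_bars = 4, 2
extrapolated_beatstep = np.arange(0.0, 5.0, 0.5)  # a beat grid with one entry every 0.5 s
num_steps = num_bars * 4                           # 8 beats per chunk
num_target_steps = 9                               # pretend the original beatstep had 9 entries

for i in range(0, num_target_steps, num_steps):
    start_idx, end_idx = i, min(i + num_steps, num_target_steps)
    start_sample = int(extrapolated_beatstep[start_idx] * sampling_rate)
    end_sample = int(extrapolated_beatstep[end_idx] * sampling_rate)
    print((start_sample, end_sample))              # (0, 16) then (16, 18)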
    # Internal helper that pads a list of features to a common shape
    def _pad(self, features: np.ndarray, add_zero_line=True):
        # Record the shape of every feature
        features_shapes = [each_feature.shape for each_feature in features]
        # Lists collecting the attention masks and the padded features
        attention_masks, padded_features = [], []

        # Iterate over the features and their indices
        for i, each_feature in enumerate(features):
            # 3-D features correspond to "input_features"
            if len(each_feature.shape) == 3:
                # Amount of padding needed along the second dimension
                features_pad_value = max([*zip(*features_shapes)][1]) - features_shapes[i][1]
                # Attention mask of ones over the unpadded part
                attention_mask = np.ones(features_shapes[i][:2], dtype=np.int64)
                # Padding specification for the feature and for the attention mask
                feature_padding = ((0, 0), (0, features_pad_value), (0, 0))
                attention_mask_padding = (feature_padding[0], feature_padding[1])

            # Other shapes correspond to "beatsteps" and "extrapolated_beatstep"
            else:
                # Reshape the feature to 2-D
                each_feature = each_feature.reshape(1, -1)
                # Amount of padding needed along the first dimension
                features_pad_value = max([*zip(*features_shapes)][0]) - features_shapes[i][0]
                # Attention mask of ones, reshaped to 2-D
                attention_mask = np.ones(features_shapes[i], dtype=np.int64).reshape(1, -1)
                # Padding specification shared by the feature and the attention mask
                feature_padding = attention_mask_padding = ((0, 0), (0, features_pad_value))

            # Pad the feature with the constant value self.padding_value
            each_padded_feature = np.pad(each_feature, feature_padding, "constant", constant_values=self.padding_value)
            # Pad the attention mask with the constant value self.padding_value
            attention_mask = np.pad(
                attention_mask, attention_mask_padding, "constant", constant_values=self.padding_value
            )

            # Optionally append a zero line after every example
            if add_zero_line:
                # Length of the zero line: the maximum size along the second dimension
                zero_array_len = max([*zip(*features_shapes)][1])

                # Append a zero row to the padded feature
                each_padded_feature = np.concatenate(
                    [each_padded_feature, np.zeros([1, zero_array_len, self.feature_size])], axis=0
                )
                # Append a zero row to the attention mask
                attention_mask = np.concatenate(
                    [attention_mask, np.zeros([1, zero_array_len], dtype=attention_mask.dtype)], axis=0
                )

            # Collect the padded feature and its attention mask
            padded_features.append(each_padded_feature)
            attention_masks.append(attention_mask)

        # Concatenate the padded features into a single float32 array
        padded_features = np.concatenate(padded_features, axis=0).astype(np.float32)
        # Concatenate the attention masks into a single int64 array
        attention_masks = np.concatenate(attention_masks, axis=0).astype(np.int64)

        # Return the padded features and the attention masks
        return padded_features, attention_masks
    ):
        """
        Pads the inputs to the same length and returns attention_mask.

        Args:
            inputs (`BatchFeature`):
                Processed audio features.
            is_batched (`bool`):
                Whether inputs are batched or not.
            return_attention_mask (`bool`):
                Whether to return attention mask or not.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of a list of Python integers. Acceptable values are:
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
                If nothing is specified, it will return a list of `np.ndarray` arrays.
        Return:
            `BatchFeature` with attention_mask, attention_mask_beatsteps, and attention_mask_extrapolated_beatstep added
            to it:
            - **attention_mask** numpy.ndarray of shape `(batch_size, max_input_features_seq_length)` --
                Example:
                    1, 1, 1, 0, 0 (audio 1, padded to a max length of 5 with 2 zeros indicating padding)

                    0, 0, 0, 0, 0 (zero padding to separate audio 1 and 2)

                    1, 1, 1, 1, 1 (audio 2)

                    0, 0, 0, 0, 0 (zero padding to separate audio 2 and 3)

                    1, 1, 1, 1, 1 (audio 3)
            - **attention_mask_beatsteps** numpy.ndarray of shape `(batch_size, max_beatsteps_seq_length)`
            - **attention_mask_extrapolated_beatstep** numpy.ndarray of shape `(batch_size,
              max_extrapolated_beatstep_seq_length)`
        """

        processed_features_dict = {}
        # Iterate through each feature and pad its values
        for feature_name, feature_value in inputs.items():
            # If the feature is 'input_features', pad it with an additional zero line
            if feature_name == "input_features":
                padded_feature_values, attention_mask = self._pad(feature_value, add_zero_line=True)
                processed_features_dict[feature_name] = padded_feature_values
                # Optionally add attention_mask to processed_features_dict
                if return_attention_mask:
                    processed_features_dict["attention_mask"] = attention_mask
            else:
                # For other features, pad without adding an extra zero line
                padded_feature_values, attention_mask = self._pad(feature_value, add_zero_line=False)
                processed_features_dict[feature_name] = padded_feature_values
                # Optionally add feature-specific attention_mask to processed_features_dict
                if return_attention_mask:
                    processed_features_dict[f"attention_mask_{feature_name}"] = attention_mask

        # If processing a single example and not returning attention_mask, remove the last zero array line
        if not is_batched and not return_attention_mask:
            processed_features_dict["input_features"] = processed_features_dict["input_features"][:-1, ...]

        # Create BatchFeature object with processed features and optionally convert to specified tensor type
        outputs = BatchFeature(processed_features_dict, tensor_type=return_tensors)

        return outputs
    # Make the feature extractor callable like a function
    def __call__(
        self,
        # Raw audio: a numpy array, a list of floats, a list of numpy arrays, or a list of lists of floats
        audio: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        # Sampling rate of the audio, as an int or a list of ints
        sampling_rate: Union[int, List[int]],
        # Number of steps per beat, defaults to 2
        steps_per_beat: int = 2,
        # Whether to resample the audio, defaults to True
        resample: Optional[bool] = True,
        # Whether to return an attention mask, defaults to False
        return_attention_mask: Optional[bool] = False,
        # Which tensor type to return, defaults to None (lists of numpy arrays)
        return_tensors: Optional[Union[str, TensorType]] = None,
        # Additional keyword arguments
        **kwargs,

.\models\pop2piano\modeling_pop2piano.py

# coding=utf-8
# Copyright 2023 The Pop2Piano Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Pop2Piano model."""

import copy
import math
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.generation import GenerationConfig

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_fx_proxy,
    logging,
    replace_return_docstrings,
)
from .configuration_pop2piano import Pop2PianoConfig

logger = logging.get_logger(__name__)

_load_pop2piano_layer_norm = True

try:
    from apex.normalization import FusedRMSNorm

    _load_pop2piano_layer_norm = False

    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pop2PianoLayerNorm")
except ImportError:
    # using the normal Pop2PianoLayerNorm
    pass
except Exception:
    logger.warning("Discovered apex but it failed to load, falling back to Pop2PianoLayerNorm")
    pass


_CONFIG_FOR_DOC = "Pop2PianoConfig"
_CHECKPOINT_FOR_DOC = "sweetcocoa/pop2piano"

POP2PIANO_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "sweetcocoa/pop2piano",
    # See all Pop2Piano models at https://huggingface.co/models?filter=pop2piano
]


POP2PIANO_INPUTS_DOCSTRING = r"""
"""


# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pop2Piano
class Pop2PianoLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the Pop2Piano style. No bias and no subtraction of mean.
        """
        super().__init__()
        # Initialize the weight parameter with ones (no bias)
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Set the epsilon value for numerical stability in variance calculation
        self.variance_epsilon = eps
    # Forward pass over the hidden states
    def forward(self, hidden_states):
        # Pop2Piano uses a layer norm that only rescales (no mean subtraction and no bias), also known as
        # root-mean-square layer normalization, see https://arxiv.org/abs/1910.07467.
        # The variance is therefore computed without mean and bias, and the accumulation for
        # half-precision inputs is done in fp32.

        # Variance of the hidden states: cast to float32, square, and average over the last dimension
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # Normalize the hidden states with the inverse square root of the variance
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # Cast back to half precision if the weight is stored in float16 or bfloat16
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        # Return the scaled hidden states
        return self.weight * hidden_states
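A standalone numeric check of this RMS-style normalization (scale only, no mean subtraction), independent of the module above:

import torch

hidden_size, eps = 4, 1e-6
x = torch.randn(2, 3, hidden_size)
weight = torch.ones(hidden_size)

# y = weight * x / sqrt(mean(x**2) + eps)
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
y = weight * (x * torch.rsqrt(variance + eps))
print(y.pow(2).mean(-1))  # each value is ~1.0: the RMS of every token is normalized to 1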
# If apex's FusedRMSNorm was found, use it in place of Pop2PianoLayerNorm
if not _load_pop2piano_layer_norm:
    Pop2PianoLayerNorm = FusedRMSNorm  # noqa

# Register Pop2PianoLayerNorm in the list of all layer-norm layers
ALL_LAYERNORM_LAYERS.append(Pop2PianoLayerNorm)


# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->Pop2Piano, t5->pop2piano
class Pop2PianoDenseActDense(nn.Module):
    def __init__(self, config: Pop2PianoConfig):
        super().__init__()
        # Linear layer wi: d_model -> d_ff, no bias
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        # Linear layer wo: d_ff -> d_model, no bias
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        # Dropout with the configured dropout rate
        self.dropout = nn.Dropout(config.dropout_rate)
        # Activation function selected from ACT2FN according to the config
        self.act = ACT2FN[config.dense_act_fn]

    # Forward pass over the hidden states
    def forward(self, hidden_states):
        # Project up with wi
        hidden_states = self.wi(hidden_states)
        # Apply the activation
        hidden_states = self.act(hidden_states)
        # Apply dropout
        hidden_states = self.dropout(hidden_states)
        # Cast to the dtype of wo when needed (and when wo is not int8-quantized)
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)
        # Project back down with wo and return
        hidden_states = self.wo(hidden_states)
        return hidden_states


# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Pop2Piano
class Pop2PianoDenseGatedActDense(nn.Module):
    def __init__(self, config: Pop2PianoConfig):
        super().__init__()
        # Two up-projections wi_0 and wi_1: d_model -> d_ff, no bias
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        # Down-projection wo: d_ff -> d_model, no bias
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        # Dropout with the configured dropout rate
        self.dropout = nn.Dropout(config.dropout_rate)
        # Activation function selected from ACT2FN according to the config
        self.act = ACT2FN[config.dense_act_fn]

    # Forward pass over the hidden states
    def forward(self, hidden_states):
        # Gate branch: activation applied to wi_0(hidden_states)
        hidden_gelu = self.act(self.wi_0(hidden_states))
        # Linear branch: wi_1(hidden_states)
        hidden_linear = self.wi_1(hidden_states)
        # Element-wise product of the two branches
        hidden_states = hidden_gelu * hidden_linear
        # Apply dropout
        hidden_states = self.dropout(hidden_states)

        # To make 8-bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
        # See https://github.com/huggingface/transformers/issues/20287.
        # Also make sure the weights are not int8, in case users force `_keep_in_fp32_modules` to be `None`.
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            # Cast the hidden states to the dtype of wo
            hidden_states = hidden_states.to(self.wo.weight.dtype)

        # Project back down with wo and return
        hidden_states = self.wo(hidden_states)
        return hidden_states
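A tiny standalone sketch of the gated feed-forward computation, act(wi_0(x)) * wi_1(x) followed by wo, with hypothetical sizes and ReLU standing in for the configured activation:

import torch
import torch.nn as nn

d_model, d_ff = 8, 16
x = torch.randn(2, 5, d_model)
wi_0 = nn.Linear(d_model, d_ff, bias=False)
wi_1 = nn.Linear(d_model, d_ff, bias=False)
wo = nn.Linear(d_ff, d_model, bias=False)
act = nn.ReLU()  # Pop2Piano's default dense_act_fn is "relu"

y = wo(act(wi_0(x)) * wi_1(x))  # gate branch times linear branch, then down-projection
print(y.shape)                  # torch.Size([2, 5, 8])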


# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->Pop2Piano
class Pop2PianoLayerFF(nn.Module):
    # Constructor taking the model configuration
    def __init__(self, config: Pop2PianoConfig):
        super().__init__()
        # Pick the gated or non-gated feed-forward variant according to the config
        if config.is_gated_act:
            self.DenseReluDense = Pop2PianoDenseGatedActDense(config)
        else:
            self.DenseReluDense = Pop2PianoDenseActDense(config)

        # Layer normalization over d_model with the configured epsilon
        self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        # Dropout with the configured dropout rate
        self.dropout = nn.Dropout(config.dropout_rate)

    # Forward pass: pre-norm feed-forward with a residual connection
    def forward(self, hidden_states):
        # Normalize the input hidden states
        forwarded_states = self.layer_norm(hidden_states)
        # Run them through the feed-forward block
        forwarded_states = self.DenseReluDense(forwarded_states)
        # Residual connection with dropout on the feed-forward output
        hidden_states = hidden_states + self.dropout(forwarded_states)
        # Return the updated hidden states
        return hidden_states
# Copied from transformers.models.t5.modeling_t5.T5Attention; the attention mechanism used by Pop2Piano
class Pop2PianoAttention(nn.Module):
    def __init__(self, config: Pop2PianoConfig, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder  # whether this attention layer is part of the decoder
        self.has_relative_attention_bias = has_relative_attention_bias  # whether it owns a relative attention bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets  # number of relative-position buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance  # maximum relative distance
        self.d_model = config.d_model  # model dimension
        self.key_value_proj_dim = config.d_kv  # dimension of each key/value projection
        self.n_heads = config.num_heads  # number of attention heads
        self.dropout = config.dropout_rate  # dropout rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim  # total projection dimension (heads * head dim)

        # Linear projections for query (q), key (k), value (v) and output (o)
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            # Embedding table holding the relative attention bias per bucket and head
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()  # set of already pruned heads
        self.gradient_checkpointing = False  # gradient checkpointing disabled by default

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        # Find the prunable heads and their indices
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune the linear projections
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update the hyper-parameters
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor - 相对位置的整数张量
            bidirectional: a boolean - 是否是双向注意力
            num_buckets: an integer - 桶的数量
            max_distance: an integer - 最大距离限制

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
            返回一个形状与relative_position相同的张量,包含在区间[0, num_buckets)内的int32值
        """
        relative_buckets = 0  # start from bucket 0

        if bidirectional:
            num_buckets //= 2  # for bidirectional attention, half of the buckets are reserved for positive distances
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            # positive relative positions are offset into the upper half of the buckets
            relative_position = torch.abs(relative_position)  # take the absolute value of the relative position
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
            # for unidirectional attention, clamp relative positions to be non-positive, then negate them

        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in position
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # the other half of the buckets are for logarithmically bigger bins up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        # pick the exact bucket for small distances and the logarithmic bucket for large ones

        return relative_buckets  # tensor of bucket indices, same shape as relative_position
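
    # Worked example of the bucketing (not part of the original source): with the defaults
    # num_buckets=32 and max_distance=128, small distances keep their exact bucket while larger
    # distances fall into coarser, logarithmically spaced buckets.
    #
    #     import torch
    #     rel_pos = torch.tensor([[-64, -8, -1, 0, 1, 8, 64]])
    #     buckets = Pop2PianoAttention._relative_position_bucket(rel_pos, bidirectional=True)
    #     # buckets has the same shape as rel_pos; after halving, 16 buckets cover each direction,
    #     # entries with |distance| < 8 keep their exact bucket, the ±64 entries land in the
    #     # logarithmic buckets, and positive distances are offset by num_buckets // 2 = 16.
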
    def compute_bias(self, query_length, key_length, device=None):
        """Compute binned relative position bias"""
        # if no device is given, fall back to the device of self.relative_attention_bias
        if device is None:
            device = self.relative_attention_bias.weight.device
        # positions of the query sequence, shape (query_length, 1)
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        # positions of the memory (key) sequence, shape (1, key_length)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        # relative positions, shape (query_length, key_length)
        relative_position = memory_position - context_position
        # map the relative positions to buckets, shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # look up the bias for each bucket, shape (query_length, key_length, num_heads)
        values = self.relative_attention_bias(relative_position_bucket)
        # permute to (1, num_heads, query_length, key_length)
        values = values.permute([2, 0, 1]).unsqueeze(0)
        # return the relative position bias tensor
        return values
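
    # Shape sketch (not part of the original source): compute_bias(query_length=q, key_length=k)
    # returns a tensor of shape (1, num_heads, q, k) that is added to the raw attention scores
    # before the softmax.
    #
    #     bias = attn.compute_bias(5, 5)   # assuming `attn` is the hypothetical instance from the sketch above
    #     print(bias.shape)                # torch.Size([1, 4, 5, 5]) with num_heads=4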

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Pop2Piano,t5->pop2piano
class Pop2PianoLayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        # self-attention sub-layer built from the Pop2PianoAttention module
        self.SelfAttention = Pop2PianoAttention(config, has_relative_attention_bias=has_relative_attention_bias)
        # layer normalization using Pop2PianoLayerNorm
        self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        # dropout with probability config.dropout_rate
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        # apply layer normalization to the incoming hidden_states (pre-norm)
        normed_hidden_states = self.layer_norm(hidden_states)
        # run self-attention on the normalized hidden states
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # residual connection: add the dropped-out attention output to the original hidden_states
        hidden_states = hidden_states + self.dropout(attention_output[0])
        # build the output tuple with the updated hidden_states and optional attention outputs
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Pop2Piano,t5->pop2piano
class Pop2PianoLayerCrossAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # encoder-decoder (cross) attention built from the Pop2PianoAttention module
        self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False)
        # layer normalization using Pop2PianoLayerNorm
        self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        # dropout with probability config.dropout_rate
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
    ):
        # apply layer normalization to the incoming hidden_states (pre-norm)
        normed_hidden_states = self.layer_norm(hidden_states)
        # run cross-attention between the decoder states and the encoder key/value states
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
        )
        # residual connection: add the dropped-out attention output to the original hidden_states
        layer_output = hidden_states + self.dropout(attention_output[0])
        # build the output tuple with the updated layer_output and optional attention outputs
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


# Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano
class Pop2PianoBlock(nn.Module):
    # constructor: takes the config and a flag for whether this block holds the relative attention bias
    def __init__(self, config, has_relative_attention_bias=False):
        # call the parent constructor
        super().__init__()
        # whether this block belongs to the decoder
        self.is_decoder = config.is_decoder
        # module list holding the sub-layers of this block
        self.layer = nn.ModuleList()
        # first sub-layer: self-attention, built with Pop2PianoLayerSelfAttention
        self.layer.append(Pop2PianoLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        # decoder blocks additionally get a cross-attention sub-layer
        if self.is_decoder:
            self.layer.append(Pop2PianoLayerCrossAttention(config))

        # final sub-layer: feed-forward network
        self.layer.append(Pop2PianoLayerFF(config))

    # forward pass; takes the arguments below and runs the block computation
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
# Abstract class inheriting from PreTrainedModel that handles weight initialization and the interface for downloading and loading pretrained models
class Pop2PianoPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # configuration class for this model
    config_class = Pop2PianoConfig
    # prefix used for the base model when naming weights
    base_model_prefix = "transformer"
    # model parallelism is not supported
    is_parallelizable = False
    # gradient checkpointing is supported
    supports_gradient_checkpointing = True
    # modules that must not be split across devices
    _no_split_modules = ["Pop2PianoBlock"]
    # modules that must be kept in fp32 precision
    _keep_in_fp32_modules = ["wo"]

    # shift the input ids one position to the right
    def _shift_right(self, input_ids):
        # decoder start token id
        decoder_start_token_id = self.config.decoder_start_token_id
        # padding token id
        pad_token_id = self.config.pad_token_id

        # raise an error if the decoder start token id is not defined
        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In Pop2Piano it is usually set to the pad_token_id."
            )

        # shift the inputs to the right
        if is_torch_fx_proxy(input_ids):
            # item assignment is not supported natively for torch.fx proxies
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        # raise an error if the padding token id is not defined
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in the labels with the padding token id
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
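
    # Worked example (not part of the original source): the labels are shifted one position to the
    # right to build the decoder inputs, and any remaining -100 entries (the usual "ignore" value
    # in labels) are replaced by the pad token. Assuming decoder_start_token_id=0 and pad_token_id=0:
    #
    #     labels            = [[ 23,  57, -100, -100]]
    #     shifted           = [[  0,  23,   57, -100]]   # after the right shift
    #     shifted_input_ids = [[  0,  23,   57,    0]]   # after replacing -100 with pad_token_id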


# Pop2PianoStack: the encoder/decoder stack, inheriting from Pop2PianoPreTrainedModel
class Pop2PianoStack(Pop2PianoPreTrainedModel):
    # Copied from transformers.models.t5.modeling_t5.T5Stack.__init__ with T5->Pop2Piano,t5->pop2piano
    def __init__(self, config, embed_tokens=None):
        super().__init__(config)

        # token embeddings (may be None and set later)
        self.embed_tokens = embed_tokens
        # whether this stack is used as a decoder
        self.is_decoder = config.is_decoder

        # build the list of blocks; only the first block holds the relative attention bias
        self.block = nn.ModuleList(
            [Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
        )
        # final layer normalization
        self.final_layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        # dropout layer
        self.dropout = nn.Dropout(config.dropout_rate)

        # initialize weights and apply final processing
        self.post_init()
        # model parallelism is disabled by default
        self.model_parallel = False
        # device map, None by default
        self.device_map = None
        # gradient checkpointing is disabled by default
        self.gradient_checkpointing = False

    # Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings
    def get_input_embeddings(self):
        return self.embed_tokens

    # Copied from transformers.models.t5.modeling_t5.T5Stack.set_input_embeddings
    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    # forward pass of the stack; the arguments are annotated inline below
    def forward(
        self,
        input_ids=None,  # input token IDs
        attention_mask=None,  # attention mask indicating which tokens should be ignored
        encoder_hidden_states=None,  # hidden states of the encoder (used for cross-attention in the decoder)
        encoder_attention_mask=None,  # attention mask of the encoder, if any
        inputs_embeds=None,  # pre-computed input embeddings, used instead of input_ids
        head_mask=None,  # mask used to disable selected attention heads
        cross_attn_head_mask=None,  # mask for the cross-attention heads
        past_key_values=None,  # cached key/value pairs for auto-regressive generation
        use_cache=None,  # whether to use the cache to speed up decoding
        output_attentions=None,  # whether to return attention weights
        output_hidden_states=None,  # whether to return hidden states
        return_dict=None,  # whether to return a dict-like output
class Pop2PianoConcatEmbeddingToMel(nn.Module):
    """Embedding Matrix for `composer` tokens."""

    def __init__(self, config):
        super().__init__()
        # embedding matrix storing the embedding vectors for the `composer` tokens
        self.embedding = nn.Embedding(num_embeddings=config.composer_vocab_size, embedding_dim=config.d_model)

    def forward(self, feature, index_value, embedding_offset):
        # shift the index by the given offset
        index_shifted = index_value - embedding_offset
        # look up the `composer` token embedding and add a sequence dimension
        composer_embedding = self.embedding(index_shifted).unsqueeze(1)
        # concatenate the composer embedding and the input features along dimension 1
        inputs_embeds = torch.cat([composer_embedding, feature], dim=1)
        return inputs_embeds
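
    # Shape sketch (not part of the original source): prepending the composer embedding grows the
    # sequence dimension by one.
    #
    #     feature:            (batch_size, seq_len, d_model)   # input features from the feature extractor
    #     composer_embedding: (batch_size, 1, d_model)         # one embedded `composer` token per example
    #     inputs_embeds:      (batch_size, seq_len + 1, d_model)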


Pop2Piano_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Pop2PianoConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings("""Pop2Piano Model with a `language modeling` head on top.""", Pop2Piano_START_DOCSTRING)
class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: Pop2PianoConfig):
        super().__init__(config)
        self.config = config
        self.model_dim = config.d_model

        # shared embedding layer used for both the encoder and the decoder inputs
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Pop2PianoConcatEmbeddingToMel instance used to embed the composer tokens
        self.mel_conditioner = Pop2PianoConcatEmbeddingToMel(config)

        # configuration for the encoder
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False

        # build the encoder stack
        self.encoder = Pop2PianoStack(encoder_config, self.shared)

        # configuration for the decoder
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers

        # build the decoder stack
        self.decoder = Pop2PianoStack(decoder_config, self.shared)

        # language modeling head that predicts the next token
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        # update the shared embedding layer
        self.shared = new_embeddings
        # propagate the new embeddings to the encoder and the decoder
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    # set a new output embedding layer (the LM head)
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # return the current output embedding layer
    def get_output_embeddings(self):
        return self.lm_head

    # return the encoder stack
    def get_encoder(self):
        return self.encoder

    # return the decoder stack
    def get_decoder(self):
        return self.decoder

    # get the mel-conditioner outputs, used to control the type of MIDI tokens the model generates
    def get_mel_conditioner_outputs(
        self,
        input_features: torch.FloatTensor,
        composer: str,
        generation_config: GenerationConfig,
        attention_mask: torch.FloatTensor = None,
    ):
        """
        This method is used to concatenate mel conditioner tokens at the front of the input_features in order to
        control the type of MIDI token generated by the model.

        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                input features extracted from the feature extractor.
            composer (`str`):
                composer token which determines the type of MIDI tokens to be generated.
            generation_config (`~generation.GenerationConfig`):
                The generation is used to get the composer-feature_token pair.
            attention_mask (`torch.FloatTensor`, *optional*):
                For batched generation, input_features are padded to have the same shape across all examples.
                `attention_mask` helps determine which areas were padded and which were not:
                - 1 for tokens that are **not padded**,
                - 0 for tokens that are **padded**.
        """
        # mapping from composer name to its feature-token value
        composer_to_feature_token = generation_config.composer_to_feature_token
        # raise a ValueError if the requested composer is not one of the available keys
        if composer not in composer_to_feature_token.keys():
            raise ValueError(
                f"Please choose a composer from {list(composer_to_feature_token.keys())}. Composer received - {composer}"
            )
        # look up the value for the composer and convert it to a tensor
        composer_value = composer_to_feature_token[composer]
        composer_value = torch.tensor(composer_value, device=self.device)
        # repeat the composer value along the batch dimension so it lines up with input_features
        composer_value = composer_value.repeat(input_features.shape[0])

        # smallest feature-token value, used as the embedding offset
        embedding_offset = min(composer_to_feature_token.values())

        # prepend the composer embedding to input_features via self.mel_conditioner
        input_features = self.mel_conditioner(
            feature=input_features,
            index_value=composer_value,
            embedding_offset=embedding_offset,
        )
        # if an attention_mask is provided, zero out the fully padded examples accordingly
        if attention_mask is not None:
            input_features[~attention_mask[:, 0].bool()] = 0.0

            # since self.mel_conditioner prepends one extra position to inputs_embeds, extend the attention_mask the same way to keep the shapes consistent
            attention_mask = torch.cat([attention_mask[:, 0].view(-1, 1), attention_mask], dim=1)
            return input_features, attention_mask

        # if attention_mask is None, return the conditioned input_features and None
        return input_features, None
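
    # Illustrative call (not part of the original source); `model`, `features` and `mask` are assumed
    # to come from a loaded Pop2Piano checkpoint whose generation_config carries composer_to_feature_token:
    #
    #     features, mask = model.get_mel_conditioner_outputs(
    #         input_features=features,                 # (batch, seq_len, hidden)
    #         composer="composer1",
    #         generation_config=model.generation_config,
    #         attention_mask=mask,                     # (batch, seq_len) -> returned as (batch, seq_len + 1)
    #     )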

    # Adds the POP2PIANO_INPUTS_DOCSTRING docstring to the model's forward method
    # Replaces the return docstring with output type Seq2SeqLMOutput and config class _CONFIG_FOR_DOC
    @torch.no_grad()
    def generate(
        self,
        input_features,
        attention_mask=None,
        composer="composer1",
        generation_config=None,
        **kwargs,
    ):
        # generation method that turns input features into MIDI token sequences
        # input_features: input features from the feature extractor
        # attention_mask: attention mask controlling which input positions the model attends to
        # composer: composer token name used for conditioning, defaults to "composer1"
        # generation_config: generation configuration controlling the generation parameters
        # **kwargs: additional optional arguments

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # prepare the model inputs for generation
        # input_ids: input token IDs
        # past_key_values: cached key/value pairs that speed up generation
        # attention_mask: attention mask controlling which input positions the model attends to
        # head_mask: mask controlling which attention heads are active
        # decoder_head_mask: mask for the decoder attention heads
        # cross_attn_head_mask: mask for the cross-attention heads
        # use_cache: whether to use the cache to speed up generation
        # encoder_outputs: encoder outputs consumed during decoding
        # **kwargs: additional optional arguments

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        # build the decoder input ids by shifting the labels one position to the right via _shift_right
        return self._shift_right(labels)

    # reorder the cached past key/values according to the given beam_idx
    def _reorder_cache(self, past_key_values, beam_idx):
        # if the past key/values were not included in the outputs
        if past_key_values is None:
            # suggest setting `use_cache=True` to speed up decoding
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            # return the past key/values unchanged
            return past_key_values

        # tuple collecting the reordered decoder past states
        reordered_decoder_past = ()

        # iterate over the past states of each layer
        for layer_past_states in past_key_values:
            # tuple collecting the reordered past states of the current layer
            reordered_layer_past_states = ()

            # iterate over every past state of the current layer
            for layer_past_state in layer_past_states:
                # select the correct batch indices for the given beam_idx, on the device of the past state
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            # check that the shape of the first reordered past state matches the original one
            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
                raise ValueError(
                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
                )

            # check that the number of reordered past states matches the original number
            if len(reordered_layer_past_states) != len(layer_past_states):
                raise ValueError(
                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
                )

            # append the reordered past states of this layer to the decoder past states
            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)

        # return the reordered decoder past states
        return reordered_decoder_past
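
    # Beam-search sketch (not part of the original source): each cached tensor has a leading
    # dimension of size batch_size * num_beams; `index_select(0, beam_idx)` re-orders that dimension
    # so every beam keeps the cache of the hypothesis it was extended from.
    #
    #     import torch
    #     past = torch.arange(6).view(6, 1, 1, 1)      # stand-in for one cached key/value tensor
    #     beam_idx = torch.tensor([0, 0, 3, 5, 5, 2])  # chosen parent hypothesis for each beam
    #     print(past.index_select(0, beam_idx).flatten().tolist())   # [0, 0, 3, 5, 5, 2]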

.\models\pop2piano\processing_pop2piano.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Processor class for Pop2Piano."""

import os
from typing import List, Optional, Union

import numpy as np

from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin
from ...tokenization_utils import BatchEncoding, PaddingStrategy, TruncationStrategy
from ...utils import TensorType


class Pop2PianoProcessor(ProcessorMixin):
    r"""
    Constructs a Pop2Piano processor which wraps a Pop2Piano feature extractor and a Pop2Piano tokenizer into a single
    processor.

    [`Pop2PianoProcessor`] offers all the functionalities of [`Pop2PianoFeatureExtractor`] and [`Pop2PianoTokenizer`].
    See the docstring of [`~Pop2PianoProcessor.__call__`] and [`~Pop2PianoProcessor.decode`] for more information.

    Args:
        feature_extractor (`Pop2PianoFeatureExtractor`):
            An instance of [`Pop2PianoFeatureExtractor`]. The feature extractor is a required input.
        tokenizer (`Pop2PianoTokenizer`):
            An instance of [`Pop2PianoTokenizer`]. The tokenizer is a required input.
    """

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "Pop2PianoFeatureExtractor"
    tokenizer_class = "Pop2PianoTokenizer"

    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor, tokenizer)

    def __call__(
        self,
        audio: Union[np.ndarray, List[float], List[np.ndarray]] = None,
        sampling_rate: Union[int, List[int]] = None,
        steps_per_beat: int = 2,
        resample: Optional[bool] = True,
        notes: Union[List, TensorType] = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        verbose: bool = True,
        **kwargs,
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        This method uses [`Pop2PianoFeatureExtractor.__call__`] to prepare log-mel-spectrograms for the model, and
        [`Pop2PianoTokenizer.__call__`] to prepare token_ids from notes.

        Please refer to the docstrings of the two methods above for more information.

        Args:
            audio (`Union[np.ndarray, List[float], List[np.ndarray]]`, *optional*):
                Input audio data, as a numpy array, a list of floats, or a list of numpy arrays.
            sampling_rate (`Union[int, List[int]]`, *optional*):
                Sampling rate of the input audio, as an integer or a list of integers.
            steps_per_beat (`int`, *optional*, defaults to 2):
                Number of steps per beat in the musical sequence.
            resample (`bool`, *optional*, defaults to `True`):
                Whether to resample the input audio.
            notes (`Union[List, TensorType]`, *optional*):
                Musical notes to be tokenized, as a list or a `TensorType`.
            padding (`Union[bool, str, PaddingStrategy]`, *optional*):
                Padding strategy applied by the tokenizer.
            truncation (`Union[bool, str, TruncationStrategy]`, *optional*):
                Truncation strategy applied by the tokenizer.
            max_length (`int`, *optional*):
                Maximum length of the tokenized sequence.
            pad_to_multiple_of (`int`, *optional*):
                Pad the sequence length to a multiple of this value.
            verbose (`bool`, *optional*, defaults to `True`):
                Whether to print verbose information during processing.
            **kwargs:
                Additional keyword arguments forwarded to the feature extractor and the tokenizer.

        Returns:
            `BatchFeature` or `BatchEncoding`: the processed inputs for the Pop2Piano model.
        """

        # The feature extractor needs both audio and sampling_rate, while the tokenizer needs notes,
        # so both combinations have to be checked.
        if (audio is None and sampling_rate is None) and (notes is None):
            raise ValueError(
                "You have to specify at least audios and sampling_rate in order to use feature extractor or "
                "notes to use the tokenizer part."
            )

        if audio is not None and sampling_rate is not None:
            # run the feature extractor to build the model inputs
            inputs = self.feature_extractor(
                audio=audio,
                sampling_rate=sampling_rate,
                steps_per_beat=steps_per_beat,
                resample=resample,
                **kwargs,
            )
        if notes is not None:
            # run the tokenizer to turn the notes into token_ids
            encoded_token_ids = self.tokenizer(
                notes=notes,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                verbose=verbose,
                **kwargs,
            )

        if notes is None:
            # no notes given: return only the feature extractor output
            return inputs

        elif audio is None or sampling_rate is None:
            # no audio or sampling rate given: return only the tokenizer output
            return encoded_token_ids

        else:
            # otherwise merge the token_ids produced by the tokenizer into the feature extractor output and return it
            inputs["token_ids"] = encoded_token_ids["token_ids"]
            return inputs
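
    # Usage sketch (not part of the original source): the processor can be called with audio only,
    # with notes only, or with both. `audio`, `sr` and `notes` below are hypothetical inputs, and the
    # checkpoint name follows the upstream Pop2Piano release.
    #
    #     processor = Pop2PianoProcessor.from_pretrained("sweetcocoa/pop2piano")
    #     features = processor(audio=audio, sampling_rate=sr, return_tensors="pt")   # feature extractor only
    #     tokens = processor(notes=notes)                                            # tokenizer only
    #     both = processor(audio=audio, sampling_rate=sr, notes=notes)               # features + "token_ids"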

    def batch_decode(
        self,
        token_ids,
        feature_extractor_output: BatchFeature,
        return_midi: bool = True,
    ) -> BatchEncoding:
        """
        使用 [`Pop2PianoTokenizer.batch_decode`] 方法将模型生成的 token_ids 转换为 midi_notes。

        请查阅上述方法的文档字符串以获取更多信息。
        """

        return self.tokenizer.batch_decode(
            token_ids=token_ids, feature_extractor_output=feature_extractor_output, return_midi=return_midi
        )

    @property
    def model_input_names(self):
        """
        返回模型输入的名称列表,包括分词器和特征提取器的输入名称。

        使用 `self.tokenizer.model_input_names` 和 `self.feature_extractor.model_input_names` 获取输入名称列表,
        并将两者合并后去除重复项后返回。
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        feature_extractor_input_names = self.feature_extractor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))

    def save_pretrained(self, save_directory, **kwargs):
        """
        将模型的预训练文件保存到指定目录中。

        如果 `save_directory` 是文件而不是目录,将引发 ValueError。
        如果目录不存在,则创建目录。
        最后,调用父类的 `save_pretrained` 方法保存预训练文件。

        Args:
            save_directory (str): 要保存预训练文件的目录路径。
            **kwargs: 其他参数传递给 `save_pretrained` 方法。

        Returns:
            Any: `save_pretrained` 方法的返回值。

        Raises:
            ValueError: 如果 `save_directory` 是文件路径而不是目录路径。
        """
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)
        return super().save_pretrained(save_directory, **kwargs)

    @classmethod
    # class method that builds the processor from a pretrained model name or path
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # fetch the constructor arguments via the _get_arguments_from_pretrained class method
        args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
        # instantiate and return the processor with the retrieved arguments
        return cls(*args)
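
# End-to-end sketch (not part of the original source): combining the processor with
# Pop2PianoForConditionalGeneration, following the upstream Pop2Piano example. The checkpoint name
# "sweetcocoa/pop2piano", the "pretty_midi_objects" key and the use of librosa are assumptions taken
# from that example rather than from this file.
#
#     import librosa
#     from transformers import Pop2PianoForConditionalGeneration, Pop2PianoProcessor
#
#     audio, sr = librosa.load("pop_song.wav", sr=44100)
#     model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
#     processor = Pop2PianoProcessor.from_pretrained("sweetcocoa/pop2piano")
#
#     inputs = processor(audio=audio, sampling_rate=sr, return_tensors="pt")
#     generated = model.generate(input_features=inputs["input_features"], composer="composer1")
#     midi = processor.batch_decode(token_ids=generated, feature_extractor_output=inputs)["pretty_midi_objects"][0]
#     midi.write("output.mid")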