Transformers-源码解析-九十三-

Transformers 源码解析(九十三)

.\models\reformer\tokenization_reformer_fast.py

# 设置文件编码为 UTF-8
# 版权声明及许可协议信息
# 引入操作系统模块和复制文件函数
# 引入类型提示模块中的 Optional 和 Tuple 类型
import os
from shutil import copyfile
from typing import Optional, Tuple

# 从 tokenization_utils_fast 中引入 PreTrainedTokenizerFast 类
# 从 utils 中引入 is_sentencepiece_available 和 logging 函数
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging

# 如果 sentencepiece 可用,从 tokenization_reformer 中引入 ReformerTokenizer 类,否则为 None
if is_sentencepiece_available():
    from .tokenization_reformer import ReformerTokenizer
else:
    ReformerTokenizer = None

# 获取日志记录器对象
logger = logging.get_logger(__name__)

# 定义特殊的单词分隔符
SPIECE_UNDERLINE = "▁"

# 定义词汇文件名映射
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}

# 定义预训练模型的词汇文件映射
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/reformer-crime-and-punishment": (
            "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
        )
    },
    "tokenizer_file": {
        "google/reformer-crime-and-punishment": (
            "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/tokenizer.json"
        )
    },
}

# 定义预训练模型的位置嵌入尺寸
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/reformer-crime-and-punishment": 524288,
}

# 定义 ReformerTokenizerFast 类,继承自 PreTrainedTokenizerFast
class ReformerTokenizerFast(PreTrainedTokenizerFast):
    """
    构建一个“快速”Reformer分词器(由HuggingFace的tokenizers库支持)。基于Unigram模型。

    这个分词器继承自 PreTrainedTokenizerFast,包含大多数主要方法。用户应该参考这个超类来获取更多关于这些方法的信息。
    """
    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer.
    """

    # 获取预定义的文件名常量列表
    vocab_files_names = VOCAB_FILES_NAMES
    # 获取预训练模型使用的词汇文件映射
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # 获取预训练位置嵌入的最大模型输入尺寸
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # 定义模型输入名称列表
    model_input_names = ["input_ids", "attention_mask"]
    # 慢速标记器类定义为 ReformerTokenizer
    slow_tokenizer_class = ReformerTokenizer

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        eos_token="</s>",
        unk_token="<unk>",
        additional_special_tokens=[],
        **kwargs,
    ):
        # 调用父类的初始化方法,传递参数以设置词汇文件、标记器文件、特殊标记等
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            eos_token=eos_token,
            unk_token=unk_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        # 将参数中的词汇文件路径保存到对象属性中
        self.vocab_file = vocab_file

    @property
    def can_save_slow_tokenizer(self) -> bool:
        # 检查当前对象是否具备保存慢速标记器所需的信息,主要是检查词汇文件是否存在
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # 如果无法保存慢速标记器,则引发 ValueError 异常
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        # 如果保存路径不是一个目录,则记录错误并返回
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        # 指定输出词汇文件的路径
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # 如果当前词汇文件路径与输出路径不一致,则复制词汇文件到输出路径
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        # 返回保存的词汇文件路径的元组
        return (out_vocab_file,)

.\models\reformer\__init__.py

# 导入必要的类型检查模块
from typing import TYPE_CHECKING

# 导入所需的工具函数和异常类
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_sentencepiece_available,
    is_tokenizers_available,
    is_torch_available,
)

# 定义模块的导入结构字典,包含相关配置和类名
_import_structure = {"configuration_reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"]}

# 检查是否存在 sentencepiece 库,若不可用则抛出异常
try:
    if not is_sentencepiece_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 若可用,则加入 tokenization_reformer 模块到导入结构中
    _import_structure["tokenization_reformer"] = ["ReformerTokenizer"]

# 检查是否存在 tokenizers 库,若不可用则抛出异常
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 若可用,则加入 tokenization_reformer_fast 模块到导入结构中
    _import_structure["tokenization_reformer_fast"] = ["ReformerTokenizerFast"]

# 检查是否存在 torch 库,若不可用则抛出异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 若可用,则加入 modeling_reformer 模块到导入结构中,包含多个类和常量
    _import_structure["modeling_reformer"] = [
        "REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "ReformerAttention",
        "ReformerForMaskedLM",
        "ReformerForQuestionAnswering",
        "ReformerForSequenceClassification",
        "ReformerLayer",
        "ReformerModel",
        "ReformerModelWithLMHead",
        "ReformerPreTrainedModel",
    ]

# 如果在类型检查模式下
if TYPE_CHECKING:
    # 导入 configuration_reformer 模块中的指定内容
    from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig

    try:
        # 检查是否存在 sentencepiece 库,若不可用则抛出异常
        if not is_sentencepiece_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 若可用,则导入 tokenization_reformer 模块中的 ReformerTokenizer
        from .tokenization_reformer import ReformerTokenizer

    try:
        # 检查是否存在 tokenizers 库,若不可用则抛出异常
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 若可用,则导入 tokenization_reformer_fast 模块中的 ReformerTokenizerFast
        from .tokenization_reformer_fast import ReformerTokenizerFast

    try:
        # 检查是否存在 torch 库,若不可用则抛出异常
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入模块中的一系列符号和类
        from .modeling_reformer import (
            REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,  # 导入预训练模型的存档列表
            ReformerAttention,                      # 导入Reformer模型中的Attention类
            ReformerForMaskedLM,                    # 导入用于Masked Language Modeling的Reformer模型类
            ReformerForQuestionAnswering,           # 导入用于问答任务的Reformer模型类
            ReformerForSequenceClassification,      # 导入用于序列分类任务的Reformer模型类
            ReformerLayer,                          # 导入Reformer模型的一个层类
            ReformerModel,                          # 导入Reformer模型类
            ReformerModelWithLMHead,                # 导入带有LM头的Reformer模型类
            ReformerPreTrainedModel,                # 导入预训练的Reformer模型类
        )
else:
    # 导入 sys 模块,用于动态修改当前模块的属性
    import sys

    # 使用 sys.modules[__name__] 将当前模块替换为 LazyModule 的实例
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\regnet\configuration_regnet.py

# 导入所需模块和类
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# 获取全局日志记录器
logger = logging.get_logger(__name__)

# 预训练配置文件的映射字典,将模型名称映射到配置文件的URL
REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/regnet-y-040": "https://huggingface.co/facebook/regnet-y-040/blob/main/config.json",
}

# RegNetConfig 类,继承自 PretrainedConfig 类,用于存储 RegNet 模型的配置信息
class RegNetConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RegNetModel`]. It is used to instantiate a RegNet
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the RegNet
    [facebook/regnet-y-040](https://huggingface.co/facebook/regnet-y-040) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        embedding_size (`int`, *optional*, defaults to 64):
            Dimensionality (hidden size) for the embedding layer.
        hidden_sizes (`List[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`):
            Dimensionality (hidden size) at each stage.
        depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 3]`):
            Depth (number of layers) for each stage.
        layer_type (`str`, *optional*, defaults to `"y"`):
            The layer to use, it can be either `"x" or `"y"`. An `x` layer is a ResNet's BottleNeck layer with
            `reduction` fixed to `1`. While a `y` layer is a `x` but with squeeze and excitation. Please refer to the
            paper for a detailed explanation of how these layers were constructed.
        hidden_act (`str`, *optional*, defaults to `"relu"`):
            The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
            are supported.
        downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
            If `True`, the first stage will downsample the inputs using a `stride` of 2.

    Example:
    ```
    >>> from transformers import RegNetConfig, RegNetModel

    >>> # Initializing a RegNet regnet-y-40 style configuration

    ```
    """
    configuration = RegNetConfig()
    # 使用 RegNetConfig 类创建一个配置对象

    model = RegNetModel(configuration)
    # 使用 RegNetModel 类基于给定的配置对象创建一个模型对象

    configuration = model.config
    # 获取模型对象的配置信息
    """

    model_type = "regnet"
    # 定义模型类型为 "regnet"

    layer_types = ["x", "y"]
    # 支持的层类型列表,包括 'x' 和 'y'

    def __init__(
        self,
        num_channels=3,
        embedding_size=32,
        hidden_sizes=[128, 192, 512, 1088],
        depths=[2, 6, 12, 2],
        groups_width=64,
        layer_type="y",
        hidden_act="relu",
        **kwargs,
    ):
        # 调用父类构造函数初始化对象
        super().__init__(**kwargs)

        # 检查给定的 layer_type 是否在支持的层类型列表中,如果不在则抛出错误
        if layer_type not in self.layer_types:
            raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}")

        # 设置对象的各个属性值
        self.num_channels = num_channels
        self.embedding_size = embedding_size
        self.hidden_sizes = hidden_sizes
        self.depths = depths
        self.groups_width = groups_width
        self.layer_type = layer_type
        self.hidden_act = hidden_act

        # 始终在第一阶段进行下采样
        self.downsample_in_first_stage = True
    ```

.\models\regnet\convert_regnet_seer_10b_to_pytorch.py

# coding=utf-8
# 版权所有 2022 年 HuggingFace Inc. 团队
#
# 根据 Apache 许可证版本 2.0(“许可证”)许可;
# 除非符合许可证要求,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件按“原样”分发,
# 没有任何形式的明示或暗示保证或条件。
# 有关许可下的详细信息,请参阅许可证。
"""转换 RegNet 10B 检查点为 vissl 格式。"""
# 您需要安装 classy vision 的特定版本
# pip install git+https://github.com/FrancescoSaverioZuppichini/ClassyVision.git@convert_weights

import argparse
import json
import os
import re
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from classy_vision.models.regnet import RegNet, RegNetParams  # 导入 RegNet 相关模块
from huggingface_hub import cached_download, hf_hub_url  # 导入缓存下载和 HF Hub URL 相关模块
from torch import Tensor
from vissl.models.model_helpers import get_trunk_forward_outputs  # 导入 vissl 模型助手函数

from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging


logging.set_verbosity_info()  # 设置日志记录级别为 info
logger = logging.get_logger()  # 获取日志记录器


@dataclass
class Tracker:
    """
    追踪器类,用于跟踪模块的前向传播过程,并记录子模块和参数信息。
    """
    module: nn.Module  # 要追踪的模块
    traced: List[nn.Module] = field(default_factory=list)  # 记录已追踪的模块列表
    handles: list = field(default_factory=list)  # 模块注册的钩子句柄列表
    name2module: Dict[str, nn.Module] = field(default_factory=OrderedDict)  # 模块名称到模块对象的字典

    def _forward_hook(self, m, inputs: Tensor, outputs: Tensor, name: str):
        """
        前向传播钩子函数,用于处理模块的前向传播输出。
        """
        has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d)
        if has_not_submodules:
            self.traced.append(m)
            self.name2module[name] = m

    def __call__(self, x: Tensor):
        """
        执行追踪器对象,注册前向传播钩子,并进行模块的前向传播。
        """
        for name, m in self.module.named_modules():
            self.handles.append(m.register_forward_hook(partial(self._forward_hook, name=name)))
        self.module(x)
        [x.remove() for x in self.handles]  # 移除注册的所有前向传播钩子
        return self

    @property
    def parametrized(self):
        """
        属性方法,返回具有可学习参数的模块字典。
        """
        return {k: v for k, v in self.name2module.items() if len(list(v.state_dict().keys())) > 0}


class FakeRegNetVisslWrapper(nn.Module):
    """
    模拟 vissl 操作而无需传递配置文件的 RegNet 包装器。
    """
    pass
    # 初始化函数,用于创建一个特征提取器对象
    def __init__(self, model: nn.Module):
        # 调用父类的初始化方法
        super().__init__()

        # 定义特征块列表,用于存储特征块的名称和对应的模块
        feature_blocks: List[Tuple[str, nn.Module]] = []

        # 添加模型的起始卷积层作为特征块 "conv1"
        feature_blocks.append(("conv1", model.stem))

        # 遍历模型的主干输出的每个子模块
        for k, v in model.trunk_output.named_children():
            # 断言子模块的名称以 "block" 开头,以确保符合预期
            assert k.startswith("block"), f"Unexpected layer name {k}"

            # 计算当前特征块的索引
            block_index = len(feature_blocks) + 1

            # 添加当前子模块作为特征块 "resN",其中 N 是索引
            feature_blocks.append((f"res{block_index}", v))

        # 使用特征块列表创建 nn.ModuleDict 对象,用于管理特征块
        self._feature_blocks = nn.ModuleDict(feature_blocks)

    # 前向传播函数,接受输入张量 x,并返回特征提取器的输出
    def forward(self, x: Tensor):
        # 调用 get_trunk_forward_outputs 函数获取主干网络的前向传播输出
        return get_trunk_forward_outputs(
            x,
            out_feat_keys=None,  # 不指定输出特征键值,表示返回所有特征块的输出
            feature_blocks=self._feature_blocks,  # 使用初始化时创建的特征块字典
        )
class FakeRegNetParams(RegNetParams):
    """
    Used to instantiate a RegNet model from Classy Vision with the same depth as the 10B one but with super small
    parameters, so we can trace it in memory.
    """

    def get_expanded_params(self):
        # 返回一个列表,每个元素是一个元组,描述了不同配置的参数
        return [(8, 2, 2, 8, 1.0), (8, 2, 7, 8, 1.0), (8, 2, 17, 8, 1.0), (8, 2, 1, 8, 1.0)]


def get_from_to_our_keys(model_name: str) -> Dict[str, str]:
    """
    Returns a dictionary that maps from original model's key -> our implementation's keys
    """

    # 创建我们的模型(使用小的权重)
    our_config = RegNetConfig(depths=[2, 7, 17, 1], hidden_sizes=[8, 8, 8, 8], groups_width=8)
    if "in1k" in model_name:
        our_model = RegNetForImageClassification(our_config)
    else:
        our_model = RegNetModel(our_config)

    # 创建原始模型(使用小的权重)
    from_model = FakeRegNetVisslWrapper(
        RegNet(FakeRegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52))
    )

    with torch.no_grad():
        from_model = from_model.eval()
        our_model = our_model.eval()

        x = torch.randn((1, 3, 32, 32))
        # 对两个模型进行追踪
        dest_tracker = Tracker(our_model)
        dest_traced = dest_tracker(x).parametrized

        pprint(dest_tracker.name2module)
        src_tracker = Tracker(from_model)
        src_traced = src_tracker(x).parametrized

    # 将模块字典转换为参数字典
    def to_params_dict(dict_with_modules):
        params_dict = OrderedDict()
        for name, module in dict_with_modules.items():
            for param_name, param in module.state_dict().items():
                params_dict[f"{name}.{param_name}"] = param
        return params_dict

    from_to_ours_keys = {}

    src_state_dict = to_params_dict(src_traced)
    dst_state_dict = to_params_dict(dest_traced)

    # 将原始模型和我们模型的键映射关系存储到字典中
    for (src_key, src_param), (dest_key, dest_param) in zip(src_state_dict.items(), dst_state_dict.items()):
        from_to_ours_keys[src_key] = dest_key
        logger.info(f"{src_key} -> {dest_key}")

    # 如果模型名中包含 "in1k",则表明它可能有一个分类头(经过微调)
    if "in1k" in model_name:
        from_to_ours_keys["0.clf.0.weight"] = "classifier.1.weight"
        from_to_ours_keys["0.clf.0.bias"] = "classifier.1.bias"

    return from_to_ours_keys


def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
    filename = "imagenet-1k-id2label.json"
    num_labels = 1000

    repo_id = "huggingface/label-files"
    num_labels = num_labels
    id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename, repo_type="dataset")), "r"))
    id2label = {int(k): v for k, v in id2label.items()}

    id2label = id2label
    label2id = {v: k for k, v in id2label.items()}

    # 使用部分函数创建 ImageNetPreTrainedConfig 对象
    ImageNetPreTrainedConfig = partial(RegNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)
    # 定义一个字典,映射模型名称到预训练配置对象
    names_to_config = {
        "regnet-y-10b-seer": ImageNetPreTrainedConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
        ),
        # 在 ImageNet 上微调
        "regnet-y-10b-seer-in1k": ImageNetPreTrainedConfig(
            depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010
        ),
    }

    # 添加 SEER 模型权重逻辑
    def load_using_classy_vision(checkpoint_url: str) -> Tuple[Dict, Dict]:
        # 从给定 URL 加载模型状态字典,保存在内存中,并映射到 CPU
        files = torch.hub.load_state_dict_from_url(checkpoint_url, model_dir=str(save_directory), map_location="cpu")
        # 检查是否有头部信息,如果有,则添加到模型状态字典中
        model_state_dict = files["classy_state_dict"]["base_model"]["model"]
        return model_state_dict["trunk"], model_state_dict["heads"]

    # 定义一个字典,将模型名称映射到从 URL 加载模型状态字典的部分函数
    names_to_from_model = {
        "regnet-y-10b-seer": partial(
            load_using_classy_vision,
            "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet10B/model_iteration124500_conso.torch",
        ),
        "regnet-y-10b-seer-in1k": partial(
            load_using_classy_vision,
            "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_10b_finetuned_in1k_model_phase28_conso.torch",
        ),
    }

    # 获取从原始模型到我们模型键的映射
    from_to_ours_keys = get_from_to_our_keys(model_name)

    # 检查是否已经存在模型的状态字典文件
    if not (save_directory / f"{model_name}.pth").exists():
        logger.info("Loading original state_dict.")
        # 加载模型的原始状态字典 trunk 和 head
        from_state_dict_trunk, from_state_dict_head = names_to_from_model[model_name]()
        from_state_dict = from_state_dict_trunk
        if "in1k" in model_name:
            # 如果模型名称中包含 "in1k",则将头部信息添加到模型状态字典中
            from_state_dict = {**from_state_dict_trunk, **from_state_dict_head}
        logger.info("Done!")

        # 创建一个空字典来存储转换后的状态字典
        converted_state_dict = {}

        # 初始化未使用的键列表
        not_used_keys = list(from_state_dict.keys())
        # 定义一个正则表达式来匹配要移除的模型键中的特定字符串
        regex = r"\.block.-part."
        # 迭代处理原始模型状态字典的每个键
        for key in from_state_dict.keys():
            # 从模型键中移除特定字符串以获取源键
            src_key = re.sub(regex, "", key)
            # 使用映射表将源键转换为我们模型的目标键
            dest_key = from_to_ours_keys[src_key]
            # 将参数与目标键存储到转换后的状态字典中
            converted_state_dict[dest_key] = from_state_dict[key]
            # 从未使用的键列表中移除当前键
            not_used_keys.remove(key)
        # 检查是否所有的键都已经更新
        assert len(not_used_keys) == 0, f"Some keys where not used {','.join(not_used_keys)}"

        logger.info(f"The following keys were not used: {','.join(not_used_keys)}")

        # 将转换后的状态字典保存到磁盘
        torch.save(converted_state_dict, save_directory / f"{model_name}.pth")

        # 释放转换后的状态字典的内存
        del converted_state_dict
    else:
        logger.info("The state_dict was already stored on disk.")
    # 如果需要将模型推送到 Hub
    if push_to_hub:
        # 记录环境变量中的 HF_TOKEN
        logger.info(f"Token is {os.environ['HF_TOKEN']}")
        # 输出信息:加载我们的模型
        logger.info("Loading our model.")
        # 根据模型名称获取配置
        our_config = names_to_config[model_name]
        # 默认使用 RegNetModel 作为模型函数
        our_model_func = RegNetModel
        # 如果模型名称中包含 "in1k",则使用 RegNetForImageClassification
        if "in1k" in model_name:
            our_model_func = RegNetForImageClassification
        # 创建我们的模型实例
        our_model = our_model_func(our_config)
        # 将我们的模型放置到 meta 设备上(移除所有权重)
        our_model.to(torch.device("meta"))
        # 输出信息:在我们的模型中加载 state_dict
        logger.info("Loading state_dict in our model.")
        # 获取我们模型当前的 state_dict 的键集合
        state_dict_keys = our_model.state_dict().keys()
        # 以低内存方式加载预训练模型
        PreTrainedModel._load_pretrained_model_low_mem(
            our_model, state_dict_keys, [save_directory / f"{model_name}.pth"]
        )
        # 输出信息:最终进行推送操作
        logger.info("Finally, pushing!")
        # 将模型推送到 Hub
        our_model.push_to_hub(
            repo_path_or_name=save_directory / model_name,
            commit_message="Add model",
            output_dir=save_directory / model_name,
        )
        # 设定图像处理器的尺寸
        size = 384
        # 输出信息:我们可以使用 convnext 模型
        logger.info("we can use the convnext one")
        # 从预训练模型 facebook/convnext-base-224-22k-1k 创建图像处理器实例
        image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k", size=size)
        # 将图像处理器推送到 Hub
        image_processor.push_to_hub(
            repo_path_or_name=save_directory / model_name,
            commit_message="Add image processor",
            output_dir=save_directory / model_name,
        )
if __name__ == "__main__":
    # 如果当前脚本作为主程序运行,则执行以下代码块

    parser = argparse.ArgumentParser()
    # 创建参数解析器对象

    # Required parameters
    parser.add_argument(
        "--model_name",
        default=None,
        type=str,
        help=(
            "The name of the model you wish to convert, it must be one of the supported regnet* architecture,"
            " currently: regnetx-*, regnety-*. If `None`, all of them will the converted."
        ),
    )
    # 添加名为 `--model_name` 的参数,用于指定要转换的模型名称,必须是支持的 regnet* 架构之一

    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=Path,
        required=True,
        help="Path to the output PyTorch model directory.",
    )
    # 添加名为 `--pytorch_dump_folder_path` 的参数,指定输出的 PyTorch 模型目录的路径,此参数为必选

    parser.add_argument(
        "--push_to_hub",
        default=True,
        type=bool,
        required=False,
        help="If True, push model and image processor to the hub.",
    )
    # 添加名为 `--push_to_hub` 的参数,如果设置为 True,则推送模型和图像处理器到指定的 Hub

    args = parser.parse_args()
    # 解析命令行参数并存储到 `args` 对象中

    pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path
    # 从参数对象 `args` 中获取 PyTorch 模型目录路径,并赋值给变量 `pytorch_dump_folder_path`

    pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True)
    # 创建 PyTorch 模型目录,如果目录已存在则忽略,同时创建必要的父目录

    convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub)
    # 调用函数 `convert_weights_and_push`,将 PyTorch 模型目录路径、模型名称和推送标志作为参数传递给该函数

.\models\regnet\convert_regnet_to_pytorch.py

# coding=utf-8
# 声明编码格式为 UTF-8

# Copyright 2022 The HuggingFace Inc. team.
# 版权声明为 2022 年 HuggingFace Inc. 团队所有

# Licensed under the Apache License, Version 2.0 (the "License");
# 使用 Apache 许可证 2.0 版本进行许可

# you may not use this file except in compliance with the License.
# 您除非遵守许可证,否则不得使用此文件

# You may obtain a copy of the License at
# 您可以在以下网址获取许可证副本

#     http://www.apache.org/licenses/LICENSE-2.0
#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 除非法律要求或书面同意,否则根据许可证分发的软件是基于"原样"分发的,
# 没有任何形式的明示或暗示担保或条件

# See the License for the specific language governing permissions and
# limitations under the License.
# 请查看许可证以了解具体的语言授权和限制

"""Convert RegNet checkpoints from timm and vissl."""

import argparse
import json
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Tuple

import timm
import torch
import torch.nn as nn
from classy_vision.models.regnet import RegNet, RegNetParams, RegNetY32gf, RegNetY64gf, RegNetY128gf
from huggingface_hub import cached_download, hf_hub_url
from torch import Tensor
from vissl.models.model_helpers import get_trunk_forward_outputs

from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel
from transformers.utils import logging

# 设置日志输出为 info 级别
logging.set_verbosity_info()
# 获取日志记录器
logger = logging.get_logger()


@dataclass
class Tracker:
    # 跟踪器类,用于追踪模块的前向传播行为
    module: nn.Module
    traced: List[nn.Module] = field(default_factory=list)
    handles: list = field(default_factory=list)

    def _forward_hook(self, m, inputs: Tensor, outputs: Tensor):
        # 前向钩子函数,记录非子模块的模块到追踪列表
        has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d)
        if has_not_submodules:
            self.traced.append(m)

    def __call__(self, x: Tensor):
        # 对模块进行前向传播,记录前向钩子处理
        for m in self.module.modules():
            self.handles.append(m.register_forward_hook(self._forward_hook))
        self.module(x)
        [x.remove() for x in self.handles]  # 移除注册的钩子
        return self

    @property
    def parametrized(self):
        # 检查追踪的模块列表中是否有可学习的参数
        return list(filter(lambda x: len(list(x.state_dict().keys())) > 0, self.traced))


@dataclass
class ModuleTransfer:
    # 模块传输类,用于在不同模型之间传输权重参数
    src: nn.Module
    dest: nn.Module
    verbose: int = 1
    src_skip: List = field(default_factory=list)
    dest_skip: List = field(default_factory=list)
    raise_if_mismatch: bool = True
    # 定义一个方法,使得对象可以像函数一样被调用,传入参数 x,其类型为 Tensor
    def __call__(self, x: Tensor):
        """
        Transfer the weights of `self.src` to `self.dest` by performing a forward pass using `x` as input. Under the
        hood we tracked all the operations in both modules.
        """
        # 使用 Tracker 对象跟踪 self.dest 模块的前向传播结果,并获取参数化的结果
        dest_traced = Tracker(self.dest)(x).parametrized
        # 使用 Tracker 对象跟踪 self.src 模块的前向传播结果,并获取参数化的结果
        src_traced = Tracker(self.src)(x).parametrized

        # 过滤掉 self.src_skip 中指定类型的操作,得到过滤后的 src_traced 列表
        src_traced = list(filter(lambda x: type(x) not in self.src_skip, src_traced))
        # 过滤掉 self.dest_skip 中指定类型的操作,得到过滤后的 dest_traced 列表
        dest_traced = list(filter(lambda x: type(x) not in self.dest_skip, dest_traced))

        # 如果 dest_traced 和 src_traced 的长度不同,并且设置了 raise_if_mismatch 标志,则抛出异常
        if len(dest_traced) != len(src_traced) and self.raise_if_mismatch:
            raise Exception(
                f"Numbers of operations are different. Source module has {len(src_traced)} operations while"
                f" destination module has {len(dest_traced)}."
            )

        # 逐一将 src_m 的状态字典加载到 dest_m 中
        for dest_m, src_m in zip(dest_traced, src_traced):
            dest_m.load_state_dict(src_m.state_dict())
            # 如果设置了 verbose 标志为 1,则打印详细信息表示权重转移情况
            if self.verbose == 1:
                print(f"Transfered from={src_m} to={dest_m}")
class FakeRegNetVisslWrapper(nn.Module):
    """
    Fake wrapper for RegNet that mimics what vissl does without the need to pass a config file.
    """

    def __init__(self, model: nn.Module):
        super().__init__()

        feature_blocks: List[Tuple[str, nn.Module]] = []
        # - get the stem
        feature_blocks.append(("conv1", model.stem))  # 将模型的 stem 添加到特征块列表中,命名为 'conv1'
        # - get all the feature blocks
        for k, v in model.trunk_output.named_children():
            assert k.startswith("block"), f"Unexpected layer name {k}"
            block_index = len(feature_blocks) + 1
            feature_blocks.append((f"res{block_index}", v))  # 将模型 trunk_output 中的每个块添加到特征块列表中

        self._feature_blocks = nn.ModuleDict(feature_blocks)  # 将特征块列表转换为 nn.ModuleDict,存储在对象属性 _feature_blocks 中

    def forward(self, x: Tensor):
        return get_trunk_forward_outputs(
            x,
            out_feat_keys=None,
            feature_blocks=self._feature_blocks,
        )


class NameToFromModelFuncMap(dict):
    """
    A Dictionary with some additional logic to return a function that creates the correct original model.
    """

    def convert_name_to_timm(self, x: str) -> str:
        x_split = x.split("-")
        return x_split[0] + x_split[1] + "_" + "".join(x_split[2:])

    def __getitem__(self, x: str) -> Callable[[], Tuple[nn.Module, Dict]]:
        # default to timm!
        if x not in self:
            x = self.convert_name_to_timm(x)  # 将 x 转换为 timm 模型的名称格式
            val = partial(lambda: (timm.create_model(x, pretrained=True).eval(), None))  # 创建一个 lambda 函数,返回预训练的 timm 模型和空字典
        else:
            val = super().__getitem__(x)  # 调用父类 dict 的 __getitem__ 方法获取对应项

        return val


class NameToOurModelFuncMap(dict):
    """
    A Dictionary with some additional logic to return the correct hugging face RegNet class reference.
    """

    def __getitem__(self, x: str) -> Callable[[], nn.Module]:
        if "seer" in x and "in1k" not in x:
            val = RegNetModel  # 如果 x 包含 "seer" 且不包含 "in1k",返回 RegNetModel 类引用
        else:
            val = RegNetForImageClassification  # 否则返回 RegNetForImageClassification 类引用
        return val


def manually_copy_vissl_head(from_state_dict, to_state_dict, keys: List[Tuple[str, str]]):
    for from_key, to_key in keys:
        to_state_dict[to_key] = from_state_dict[from_key].clone()  # 复制 from_state_dict 中的权重到 to_state_dict 中,并使用 clone() 方法克隆张量
        print(f"Copied key={from_key} to={to_key}")  # 打印复制的键名和目标键名
    return to_state_dict


def convert_weight_and_push(
    name: str,
    from_model_func: Callable[[], nn.Module],
    our_model_func: Callable[[], nn.Module],
    config: RegNetConfig,
    save_directory: Path,
    push_to_hub: bool = True,
):
    print(f"Converting {name}...")  # 打印转换的模型名称
    with torch.no_grad():
        from_model, from_state_dict = from_model_func()  # 调用 from_model_func 获取源模型和其状态字典
        our_model = our_model_func(config).eval()  # 调用 our_model_func 创建我们的模型,并转换为评估模式
        module_transfer = ModuleTransfer(src=from_model, dest=our_model, raise_if_mismatch=False)  # 使用 ModuleTransfer 将源模型的权重转移到我们的模型中
        x = torch.randn((1, 3, 224, 224))  # 创建一个随机张量作为输入
        module_transfer(x)  # 执行模型权重转换
    # 如果有给定的 from_state_dict,则需要手动复制特定的头部参数
    if from_state_dict is not None:
        keys = []
        # 对于 seer - in1k finetuned 模型,需要手动复制头部参数
        if "seer" in name and "in1k" in name:
            keys = [("0.clf.0.weight", "classifier.1.weight"), ("0.clf.0.bias", "classifier.1.bias")]
        # 手动复制头部参数,并获取复制后的状态字典
        to_state_dict = manually_copy_vissl_head(from_state_dict, our_model.state_dict(), keys)
        # 使用复制后的状态字典加载我们的模型
        our_model.load_state_dict(to_state_dict)

    # 获取我们模型的输出,同时要求输出隐藏状态
    our_outputs = our_model(x, output_hidden_states=True)
    # 根据模型类型选择输出 logits 或者最后的隐藏状态
    our_output = (
        our_outputs.logits if isinstance(our_model, RegNetForImageClassification) else our_outputs.last_hidden_state
    )

    # 获取原始模型的输出
    from_output = from_model(x)
    # 如果原始模型的输出是一个列表,则选择最后一个元素作为输出
    from_output = from_output[-1] if isinstance(from_output, list) else from_output

    # 对于 vissl seer 模型,因为不使用任何配置文件,实际上没有头部,因此直接使用最后的隐藏状态
    if "seer" in name and "in1k" in name:
        our_output = our_outputs.hidden_states[-1]

    # 断言两个模型的输出是否近似相等,否则抛出异常
    assert torch.allclose(from_output, our_output), "The model logits don't match the original one."

    # 如果需要推送到 Hub
    if push_to_hub:
        # 将我们的模型推送到 Hub
        our_model.push_to_hub(
            repo_path_or_name=save_directory / name,
            commit_message="Add model",
            use_temp_dir=True,
        )

        # 根据模型名称选择图像处理器的大小
        size = 224 if "seer" not in name else 384
        # 使用预训练的 convnext-base-224-22k-1k 模型创建图像处理器
        image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k", size=size)
        # 将图像处理器推送到 Hub
        image_processor.push_to_hub(
            repo_path_or_name=save_directory / name,
            commit_message="Add image processor",
            use_temp_dir=True,
        )

        # 打印推送成功的消息
        print(f"Pushed {name}")
def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
    # 定义文件名和标签数量
    filename = "imagenet-1k-id2label.json"
    num_labels = 1000
    expected_shape = (1, num_labels)

    # Hub repo ID
    repo_id = "huggingface/label-files"
    num_labels = num_labels  # 更新 num_labels 变量

    # 从 Hub 下载并加载 id 到 label 的映射关系
    id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename, repo_type="dataset")), "r"))
    id2label = {int(k): v for k, v in id2label.items()}  # 转换键为整数类型

    id2label = id2label  # 重复赋值,可能是误操作
    label2id = {v: k for k, v in id2label.items()}  # 创建 label 到 id 的映射关系

    # 创建一个配置对象,使用部分函数创建 RegNet 配置
    ImageNetPreTrainedConfig = partial(RegNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)

    # 初始化两个映射对象
    names_to_ours_model_map = NameToOurModelFuncMap()
    names_to_from_model_map = NameToFromModelFuncMap()

    # 添加 SEER weights 逻辑

    # 定义一个函数,通过 Classy Vision 加载模型和状态字典
    def load_using_classy_vision(checkpoint_url: str, model_func: Callable[[], nn.Module]) -> Tuple[nn.Module, Dict]:
        # 从 URL 加载模型状态字典到指定目录
        files = torch.hub.load_state_dict_from_url(checkpoint_url, model_dir=str(save_directory), map_location="cpu")
        model = model_func()
        # 检查是否有头部,如果有则添加到模型
        model_state_dict = files["classy_state_dict"]["base_model"]["model"]
        state_dict = model_state_dict["trunk"]
        model.load_state_dict(state_dict)
        return model.eval(), model_state_dict["heads"]

    # 预训练模型映射

    # 添加 regnet-y-320-seer 的映射
    names_to_from_model_map["regnet-y-320-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet32d/seer_regnet32gf_model_iteration244000.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY32gf()),
    )

    # 添加 regnet-y-640-seer 的映射
    names_to_from_model_map["regnet-y-640-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet64/seer_regnet64gf_model_final_checkpoint_phase0.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY64gf()),
    )

    # 添加 regnet-y-1280-seer 的映射
    names_to_from_model_map["regnet-y-1280-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_regnet128Gf_cnstant_bs32_node16_sinkhorn10_proto16k_syncBN64_warmup8k/model_final_checkpoint_phase0.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY128gf()),
    )

    # 添加 regnet-y-10b-seer 的映射
    names_to_from_model_map["regnet-y-10b-seer"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet10B/model_iteration124500_conso.torch",
        lambda: FakeRegNetVisslWrapper(
            RegNet(RegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52))
        ),
    )

    # IN1K finetuned 映射

    # 添加 regnet-y-320-seer-in1k 的映射
    names_to_from_model_map["regnet-y-320-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet32_finetuned_in1k_model_final_checkpoint_phase78.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY32gf()),
    )
    # 将模型名称映射到加载模型函数的部分函数调用,使用 Classy Vision 加载模型
    names_to_from_model_map["regnet-y-640-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet64_finetuned_in1k_model_final_checkpoint_phase78.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY64gf()),
    )

    # 将模型名称映射到加载模型函数的部分函数调用,使用 Classy Vision 加载模型
    names_to_from_model_map["regnet-y-1280-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet128_finetuned_in1k_model_final_checkpoint_phase78.torch",
        lambda: FakeRegNetVisslWrapper(RegNetY128gf()),
    )

    # 将模型名称映射到加载模型函数的部分函数调用,使用 Classy Vision 加载模型
    names_to_from_model_map["regnet-y-10b-seer-in1k"] = partial(
        load_using_classy_vision,
        "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_10b_finetuned_in1k_model_phase28_conso.torch",
        lambda: FakeRegNetVisslWrapper(
            RegNet(RegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52))
        ),
    )

    # 如果指定了模型名称,则转换权重并推送到指定的模型名称下
    if model_name:
        convert_weight_and_push(
            model_name,
            names_to_from_model_map[model_name],  # 使用模型名称从映射中获取加载模型函数
            names_to_ours_model_map[model_name],  # 使用模型名称从映射中获取我们的模型名称
            names_to_config[model_name],  # 使用模型名称从映射中获取配置
            save_directory,  # 保存目录路径
            push_to_hub,  # 是否推送到 Hub
        )
    else:
        # 否则,对于每个模型名称和其对应的配置,转换权重并推送到对应的模型名称下
        for model_name, config in names_to_config.items():
            convert_weight_and_push(
                model_name,
                names_to_from_model_map[model_name],  # 使用模型名称从映射中获取加载模型函数
                names_to_ours_model_map[model_name],  # 使用模型名称从映射中获取我们的模型名称
                config,  # 使用当前配置
                save_directory,  # 保存目录路径
                push_to_hub,  # 是否推送到 Hub
            )
    
    # 返回配置和期望的形状
    return config, expected_shape
if __name__ == "__main__":
    # 如果当前脚本作为主程序执行,则执行以下操作

    parser = argparse.ArgumentParser()
    # 创建参数解析器对象

    # Required parameters
    parser.add_argument(
        "--model_name",
        default=None,
        type=str,
        help=(
            "The name of the model you wish to convert, it must be one of the supported regnet* architecture,"
            " currently: regnetx-*, regnety-*. If `None`, all of them will the converted."
        ),
    )
    # 添加一个必需的参数选项,用于指定要转换的模型名称

    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=Path,
        required=True,
        help="Path to the output PyTorch model directory.",
    )
    # 添加一个必需的参数选项,用于指定输出的 PyTorch 模型目录路径

    parser.add_argument(
        "--push_to_hub",
        default=True,
        type=bool,
        required=False,
        help="If True, push model and image processor to the hub.",
    )
    # 添加一个可选的参数选项,指定是否将模型和图像处理器推送到 hub

    args = parser.parse_args()
    # 解析命令行参数并将其存储在 args 变量中

    pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path
    # 从参数中获取 PyTorch 模型目录路径

    pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True)
    # 创建 PyTorch 模型目录,如果目录已存在则忽略,同时创建必要的父目录

    convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub)
    # 调用函数,将权重转换并推送到指定的 PyTorch 模型目录

.\models\regnet\modeling_flax_regnet.py

# 导入所需的模块和库
from functools import partial  # 导入partial函数,用于创建带预设参数的可调用对象
from typing import Optional, Tuple  # 引入类型提示,用于函数参数和返回类型的声明

import flax.linen as nn  # 导入Flax的linen模块,用于定义Flax模型
import jax  # 导入JAX,用于自动微分和并行计算
import jax.numpy as jnp  # 导入JAX的NumPy接口,用于数值计算
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # 导入FrozenDict和相关函数,用于不可变字典的管理
from flax.traverse_util import flatten_dict, unflatten_dict  # 导入flatten_dict和unflatten_dict函数,用于字典的扁平化和恢复

from transformers import RegNetConfig  # 导入RegNetConfig类,用于配置RegNet模型
from transformers.modeling_flax_outputs import (  # 导入Flax模型输出相关类
    FlaxBaseModelOutputWithNoAttention,
    FlaxBaseModelOutputWithPooling,
    FlaxBaseModelOutputWithPoolingAndNoAttention,
    FlaxImageClassifierOutputWithNoAttention,
)
from transformers.modeling_flax_utils import (  # 导入Flax模型工具函数
    ACT2FN,
    FlaxPreTrainedModel,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from transformers.utils import (  # 导入transformers工具函数
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)

# 定义模型文档字符串
REGNET_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

"""
    # 定义函数参数:
    #   config ([`RegNetConfig`]): 模型配置类,包含模型的所有参数。
    #       仅使用配置文件初始化,不加载与模型相关的权重,只加载配置信息。
    #       查看 [`~FlaxPreTrainedModel.from_pretrained`] 方法以加载模型权重。
    #   dtype (`jax.numpy.dtype`, *可选*, 默认为 `jax.numpy.float32`):
    #       计算的数据类型。可以是 `jax.numpy.float32`, `jax.numpy.float16` (在GPU上),以及 `jax.numpy.bfloat16` (在TPU上)。
    #       
    #       这可以用于在GPU或TPU上启用混合精度训练或半精度推断。如果指定,则所有计算将使用给定的 `dtype` 执行。
    #       
    #       **请注意,这仅指定计算的数据类型,并不影响模型参数的数据类型。**
    #       
    #       如果要更改模型参数的数据类型,请参阅 [`~FlaxPreTrainedModel.to_fp16`] 和 [`~FlaxPreTrainedModel.to_bf16`]。
"""

REGNET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`RegNetImageProcessor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


# Copied from transformers.models.resnet.modeling_flax_resnet.Identity
class Identity(nn.Module):
    """Identity function."""
    
    @nn.compact
    def __call__(self, x, **kwargs):
        return x
    RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
    downsample the input using `stride=2`.
    """
    
    # 定义一个类,用于执行RegNet的shortcut操作,将残差特征投影到正确的大小。如果需要,还可以使用 `stride=2` 对输入进行降采样。
    class RegNetShortcut(nn.Module):
        
        # 初始化函数,设置输出通道数、步幅为2、数据类型为 jnp.float32
        def __init__(self, out_channels: int, stride: int = 2, dtype: jnp.dtype = jnp.float32):
            self.out_channels = out_channels
            self.stride = stride
            self.dtype = dtype
            
        # 在设置阶段定义操作:使用 1x1 的卷积层进行投影,无偏置,采用截断正态分布进行初始化,数据类型为 self.dtype
        def setup(self):
            self.convolution = nn.Conv(
                self.out_channels,
                kernel_size=(1, 1),
                strides=self.stride,
                use_bias=False,
                kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
                dtype=self.dtype,
            )
            # 使用批量归一化层,设置动量为 0.9,epsilon 为 1e-05,数据类型为 self.dtype
            self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)
        
        # 调用实例时执行的操作:对输入 x 进行卷积投影,然后对投影结果进行归一化
        def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
            hidden_state = self.convolution(x)
            hidden_state = self.normalization(hidden_state, use_running_average=deterministic)
            return hidden_state
class FlaxRegNetSELayerCollection(nn.Module):
    in_channels: int
    reduced_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # 定义第一个卷积层,用于 SE 层集合
        self.conv_1 = nn.Conv(
            self.reduced_channels,
            kernel_size=(1, 1),
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
            dtype=self.dtype,
            name="0",
        )  # 0 is the name used in corresponding pytorch implementation
        # 定义第二个卷积层,用于 SE 层集合
        self.conv_2 = nn.Conv(
            self.in_channels,
            kernel_size=(1, 1),
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
            dtype=self.dtype,
            name="2",
        )  # 2 is the name used in corresponding pytorch implementation

    def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray:
        # 对输入的隐藏状态应用第一个卷积层
        hidden_state = self.conv_1(hidden_state)
        # 应用 ReLU 激活函数
        hidden_state = nn.relu(hidden_state)
        # 对处理后的隐藏状态应用第二个卷积层
        hidden_state = self.conv_2(hidden_state)
        # 计算注意力,使用 sigmoid 激活函数
        attention = nn.sigmoid(hidden_state)

        return attention


class FlaxRegNetSELayer(nn.Module):
    """
    Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507).
    """

    in_channels: int
    reduced_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # 定义平均池化操作的部分函数
        self.pooler = partial(nn.avg_pool, padding=((0, 0), (0, 0)))
        # 初始化 SE 层集合作为注意力机制
        self.attention = FlaxRegNetSELayerCollection(self.in_channels, self.reduced_channels, dtype=self.dtype)

    def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray:
        # 对输入的隐藏状态进行平均池化
        pooled = self.pooler(
            hidden_state,
            window_shape=(hidden_state.shape[1], hidden_state.shape[2]),
            strides=(hidden_state.shape[1], hidden_state.shape[2]),
        )
        # 应用注意力机制得到注意力张量
        attention = self.attention(pooled)
        # 将原始隐藏状态与注意力张量相乘,执行 SE 操作
        hidden_state = hidden_state * attention
        return hidden_state


class FlaxRegNetXLayerCollection(nn.Module):
    config: RegNetConfig
    out_channels: int
    stride: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # 计算组数,确保每个组的宽度不小于 1
        groups = max(1, self.out_channels // self.config.groups_width)

        # 定义层集合,包括三个卷积层
        self.layer = [
            FlaxRegNetConvLayer(
                self.out_channels,
                kernel_size=1,
                activation=self.config.hidden_act,
                dtype=self.dtype,
                name="0",
            ),
            FlaxRegNetConvLayer(
                self.out_channels,
                stride=self.stride,
                groups=groups,
                activation=self.config.hidden_act,
                dtype=self.dtype,
                name="1",
            ),
            FlaxRegNetConvLayer(
                self.out_channels,
                kernel_size=1,
                activation=None,
                dtype=self.dtype,
                name="2",
            ),
        ]
    # 定义一个特殊方法 __call__,使得该类的实例对象可以像函数一样被调用
    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        # 遍历该类实例对象中的每一个层(layer)
        for layer in self.layer:
            # 对隐藏状态(hidden_state)依次应用每一层(layer)的操作
            hidden_state = layer(hidden_state, deterministic=deterministic)
        # 返回经过所有层操作后的最终隐藏状态
        return hidden_state
# 定义一个 FlaxRegNetXLayer 类,表示 RegNet 的 X 层模块
class FlaxRegNetXLayer(nn.Module):
    """
    RegNet 的层,由三个 3x3 卷积组成,与 ResNet 的瓶颈层相同,但 reduction = 1。
    """

    # 类属性:RegNet 的配置
    config: RegNetConfig
    # 输入通道数
    in_channels: int
    # 输出通道数
    out_channels: int
    # 步幅,默认为 1
    stride: int = 1
    # 数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 设置方法,用于初始化层的各个组件
    def setup(self):
        # 判断是否需要应用 shortcut
        should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1
        # 如果需要应用 shortcut,则初始化为 FlaxRegNetShortCut 对象;否则初始化为 Identity 对象
        self.shortcut = (
            FlaxRegNetShortCut(
                self.out_channels,
                stride=self.stride,
                dtype=self.dtype,
            )
            if should_apply_shortcut
            else Identity()
        )
        # 初始化层对象为 FlaxRegNetXLayerCollection 实例
        self.layer = FlaxRegNetXLayerCollection(
            self.config,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            stride=self.stride,
            dtype=self.dtype,
        )
        # 激活函数为 ACT2FN 中根据配置选择的隐藏激活函数
        self.activation_func = ACT2FN[self.config.hidden_act]

    # 调用方法,用于前向传播
    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        # 将输入 hidden_state 赋值给 residual 作为残差连接的起始值
        residual = hidden_state
        # 经过 X 层的主体卷积操作
        hidden_state = self.layer(hidden_state)
        # 对残差应用 shortcut 操作
        residual = self.shortcut(residual, deterministic=deterministic)
        # 将主体卷积的输出与 shortcut 的结果相加,实现残差连接
        hidden_state += residual
        # 应用激活函数到最终输出
        hidden_state = self.activation_func(hidden_state)
        return hidden_state


# 定义一个 FlaxRegNetYLayerCollection 类,表示 RegNet 的 Y 层的卷积集合
class FlaxRegNetYLayerCollection(nn.Module):
    config: RegNetConfig
    in_channels: int
    out_channels: int
    stride: int = 1
    dtype: jnp.dtype = jnp.float32

    # 设置方法,用于初始化层的各个组件
    def setup(self):
        # 计算组数,用于分组卷积
        groups = max(1, self.out_channels // self.config.groups_width)

        # 初始化层对象为包含四个子层的列表
        self.layer = [
            # 第一个卷积层,1x1 卷积
            FlaxRegNetConvLayer(
                self.out_channels,
                kernel_size=1,
                activation=self.config.hidden_act,
                dtype=self.dtype,
                name="0",
            ),
            # 第二个卷积层,3x3 卷积,带分组卷积
            FlaxRegNetConvLayer(
                self.out_channels,
                stride=self.stride,
                groups=groups,
                activation=self.config.hidden_act,
                dtype=self.dtype,
                name="1",
            ),
            # Squeeze and Excitation 模块
            FlaxRegNetSELayer(
                self.out_channels,
                reduced_channels=int(round(self.in_channels / 4)),
                dtype=self.dtype,
                name="2",
            ),
            # 第四个卷积层,1x1 卷积,不带激活函数
            FlaxRegNetConvLayer(
                self.out_channels,
                kernel_size=1,
                activation=None,
                dtype=self.dtype,
                name="3",
            ),
        ]

    # 调用方法,用于前向传播
    def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray:
        # 依次对每个子层进行前向传播
        for layer in self.layer:
            hidden_state = layer(hidden_state)
        return hidden_state


# 定义一个 FlaxRegNetYLayer 类,表示 RegNet 的 Y 层模块,是 X 层与 Squeeze and Excitation 的组合
class FlaxRegNetYLayer(nn.Module):
    """
    RegNet 的 Y 层:包含一个 X 层和 Squeeze and Excitation 模块。
    """

    config: RegNetConfig
    in_channels: int
    out_channels: int
    stride: int = 1
    dtype: jnp.dtype = jnp.float32
    # 定义设置方法,用于初始化网络层
    def setup(self):
        # 检查是否需要应用快捷连接,条件为输入通道数不等于输出通道数或步长不为1
        should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1

        # 根据条件选择不同的快捷连接方式,若需要则使用 FlaxRegNetShortCut,否则使用 Identity()
        self.shortcut = (
            FlaxRegNetShortCut(
                self.out_channels,
                stride=self.stride,
                dtype=self.dtype,
            )
            if should_apply_shortcut
            else Identity()
        )

        # 创建网络层集合对象 FlaxRegNetYLayerCollection
        self.layer = FlaxRegNetYLayerCollection(
            self.config,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            stride=self.stride,
            dtype=self.dtype,
        )

        # 选择激活函数,根据配置选择对应的激活函数
        self.activation_func = ACT2FN[self.config.hidden_act]

    # 定义调用方法,用于执行前向传播计算
    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        # 将输入的隐藏状态作为残差保存
        residual = hidden_state
        # 通过网络层集合对象计算新的隐藏状态
        hidden_state = self.layer(hidden_state)
        # 根据快捷连接计算残差项
        residual = self.shortcut(residual, deterministic=deterministic)
        # 将残差项加到新的隐藏状态上
        hidden_state += residual
        # 应用选择的激活函数到更新后的隐藏状态上
        hidden_state = self.activation_func(hidden_state)
        # 返回更新后的隐藏状态作为输出
        return hidden_state
class FlaxRegNetStageLayersCollection(nn.Module):
    """
    A RegNet stage composed by stacked layers.
    """

    config: RegNetConfig  # 存储RegNet配置的对象
    in_channels: int  # 输入通道数
    out_channels: int  # 输出通道数
    stride: int = 2  # 步幅,默认为2
    depth: int = 2  # 层的深度,默认为2
    dtype: jnp.dtype = jnp.float32  # 数据类型,默认为32位浮点数

    def setup(self):
        layer = FlaxRegNetXLayer if self.config.layer_type == "x" else FlaxRegNetYLayer

        layers = [
            # downsampling is done in the first layer with stride of 2
            layer(
                self.config,
                self.in_channels,
                self.out_channels,
                stride=self.stride,
                dtype=self.dtype,
                name="0",
            )
        ]

        for i in range(self.depth - 1):
            layers.append(
                layer(
                    self.config,
                    self.out_channels,
                    self.out_channels,
                    dtype=self.dtype,
                    name=str(i + 1),
                )
            )

        self.layers = layers  # 存储所有层的列表

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        hidden_state = x
        for layer in self.layers:
            hidden_state = layer(hidden_state, deterministic=deterministic)
        return hidden_state


# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStage with ResNet->RegNet
class FlaxRegNetStage(nn.Module):
    """
    A RegNet stage composed by stacked layers.
    """

    config: RegNetConfig  # 存储RegNet配置的对象
    in_channels: int  # 输入通道数
    out_channels: int  # 输出通道数
    stride: int = 2  # 步幅,默认为2
    depth: int = 2  # 层的深度,默认为2
    dtype: jnp.dtype = jnp.float32  # 数据类型,默认为32位浮点数

    def setup(self):
        self.layers = FlaxRegNetStageLayersCollection(
            self.config,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            stride=self.stride,
            depth=self.depth,
            dtype=self.dtype,
        )

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        return self.layers(x, deterministic=deterministic)


# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStageCollection with ResNet->RegNet
class FlaxRegNetStageCollection(nn.Module):
    config: RegNetConfig  # 存储RegNet配置的对象
    dtype: jnp.dtype = jnp.float32  # 数据类型,默认为32位浮点数
    # 定义初始化方法,用于设置网络结构
    def setup(self):
        # 计算每个阶段输入输出通道数的元组列表
        in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:])
        # 创建阶段列表,并添加第一个阶段的配置
        stages = [
            FlaxRegNetStage(
                self.config,
                self.config.embedding_size,
                self.config.hidden_sizes[0],
                stride=2 if self.config.downsample_in_first_stage else 1,
                depth=self.config.depths[0],
                dtype=self.dtype,
                name="0",
            )
        ]

        # 遍历计算后续阶段的配置并添加到阶段列表中
        for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])):
            stages.append(
                FlaxRegNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1))
            )

        # 将创建好的阶段列表赋值给对象的stages属性
        self.stages = stages

    # 定义调用方法,实现模型的前向传播
    def __call__(
        self,
        hidden_state: jnp.ndarray,
        output_hidden_states: bool = False,
        deterministic: bool = True,
    ) -> FlaxBaseModelOutputWithNoAttention:
        # 如果需要输出隐藏状态,则初始化一个空元组用于存储隐藏状态
        hidden_states = () if output_hidden_states else None

        # 遍历每个阶段模块进行前向传播
        for stage_module in self.stages:
            # 如果需要输出隐藏状态,则将当前隐藏状态转置后添加到隐藏状态元组中
            if output_hidden_states:
                hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),)

            # 调用当前阶段模块进行前向传播,更新隐藏状态
            hidden_state = stage_module(hidden_state, deterministic=deterministic)

        # 返回最终的隐藏状态和可能的隐藏状态元组
        return hidden_state, hidden_states
# 从 transformers.models.resnet.modeling_flax_resnet.FlaxResNetEncoder 复制而来,将 ResNet 修改为 RegNet
class FlaxRegNetEncoder(nn.Module):
    # 使用 RegNetConfig 类型的 config 属性
    config: RegNetConfig
    # 使用 jnp.float32 类型的 dtype 属性
    dtype: jnp.dtype = jnp.float32

    # 模块初始化方法
    def setup(self):
        # 使用 RegNetConfig 和 dtype 创建 FlaxRegNetStageCollection 实例
        self.stages = FlaxRegNetStageCollection(self.config, dtype=self.dtype)

    # 调用方法,接受 hidden_state, output_hidden_states, return_dict, deterministic 参数,并返回 FlaxBaseModelOutputWithNoAttention 类型的值
    def __call__(
        self,
        hidden_state: jnp.ndarray,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ) -> FlaxBaseModelOutputWithNoAttention:
        # 调用 self.stages 处理 hidden_state,并根据参数设置返回 hidden_state 和 hidden_states
        hidden_state, hidden_states = self.stages(
            hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic
        )

        # 如果 output_hidden_states 为真,则添加转置后的 hidden_state 到 hidden_states 元组中
        if output_hidden_states:
            hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),)

        # 如果 return_dict 为假,则返回非空的 hidden_state 和 hidden_states 元组
        if not return_dict:
            return tuple(v for v in [hidden_state, hidden_states] if v is not None)

        # 返回 FlaxBaseModelOutputWithNoAttention 实例,包含 last_hidden_state 和 hidden_states
        return FlaxBaseModelOutputWithNoAttention(
            last_hidden_state=hidden_state,
            hidden_states=hidden_states,
        )


# 从 transformers.models.resnet.modeling_flax_resnet.FlaxResNetPreTrainedModel 复制而来,将 ResNet 修改为 RegNet,resnet 修改为 regnet,RESNET 修改为 REGNET
class FlaxRegNetPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 使用 RegNetConfig 类型的 config_class 属性
    config_class = RegNetConfig
    # 使用 "regnet" 字符串的 base_model_prefix 属性
    base_model_prefix = "regnet"
    # 使用 "pixel_values" 字符串的 main_input_name 属性
    main_input_name = "pixel_values"
    # 使用 NoneType 的 module_class 属性
    module_class: nn.Module = None

    # 初始化方法,接受 config, input_shape, seed, dtype, _do_init 等参数
    def __init__(
        self,
        config: RegNetConfig,
        input_shape=(1, 224, 224, 3),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # 使用 config, dtype, **kwargs 创建 module 实例
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # 如果 input_shape 为 None,则使用 (1, config.image_size, config.image_size, config.num_channels)
        if input_shape is None:
            input_shape = (1, config.image_size, config.image_size, config.num_channels)
        # 调用父类的初始化方法,传递 config, module, input_shape, seed, dtype, _do_init 等参数
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    # 初始化权重方法,接受 rng, input_shape, params 等参数,并返回 FrozenDict 类型的值
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 创建 dtype 为 self.dtype 的 jnp.zeros 像素值数组 pixel_values
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        # 创建包含 rng 的 rngs 字典
        rngs = {"params": rng}

        # 使用 self.module.init 初始化随机参数 random_params,并设置 return_dict 为 False
        random_params = self.module.init(rngs, pixel_values, return_dict=False)

        # 如果 params 非空,则扁平化并合并 random_params 和 params 中缺失的键
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            # 否则返回 random_params
            return random_params

    # 将 REGNET_INPUTS_DOCSTRING 添加到模型前向传播方法的装饰器
    @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
    def __call__(
        self,
        pixel_values,
        params: dict = None,
        train: bool = False,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # 将输入的像素值进行维度转换,调整通道顺序为 (0, 2, 3, 1)
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # 处理可能需要的伪随机数生成器
        rngs = {}

        # 调用模块的应用方法,传递参数和数据
        return self.module.apply(
            {
                "params": params["params"] if params is not None else self.params["params"],
                "batch_stats": params["batch_stats"] if params is not None else self.params["batch_stats"],
            },
            jnp.array(pixel_values, dtype=jnp.float32),  # 转换后的像素值数组
            not train,  # 训练模式取反,传递给模块的参数
            output_hidden_states,  # 是否返回隐藏状态的标志
            return_dict,  # 是否返回字典形式的输出
            rngs=rngs,  # 伪随机数生成器
            mutable=["batch_stats"] if train else False,  # 当训练为真时,返回包含 batch_stats 的元组
        )
# 从transformers.models.resnet.modeling_flax_resnet.FlaxResNetModule复制到此处,修改ResNet为RegNet
class FlaxRegNetModule(nn.Module):
    config: RegNetConfig  # 模型配置为RegNetConfig类型
    dtype: jnp.dtype = jnp.float32  # 计算中使用的数据类型为jnp.float32

    def setup(self):
        self.embedder = FlaxRegNetEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxRegNetEncoder(self.config, dtype=self.dtype)

        # 在ResNet中使用的自适应平均池化
        self.pooler = partial(
            nn.avg_pool,
            padding=((0, 0), (0, 0)),
        )

    def __call__(
        self,
        pixel_values,
        deterministic: bool = True,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> FlaxBaseModelOutputWithPoolingAndNoAttention:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        embedding_output = self.embedder(pixel_values, deterministic=deterministic)

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        last_hidden_state = encoder_outputs[0]

        # 对最后隐藏状态进行自适应平均池化
        pooled_output = self.pooler(
            last_hidden_state,
            window_shape=(last_hidden_state.shape[1], last_hidden_state.shape[2]),
            strides=(last_hidden_state.shape[1], last_hidden_state.shape[2]),
        ).transpose(0, 3, 1, 2)

        last_hidden_state = last_hidden_state.transpose(0, 3, 1, 2)

        if not return_dict:
            # 如果不返回字典形式,则返回元组形式的输出
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        # 返回带池化和无注意力的基础模型输出的字典形式
        return FlaxBaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )


@add_start_docstrings(
    "The bare RegNet model outputting raw features without any specific head on top.",
    REGNET_START_DOCSTRING,
)
class FlaxRegNetModel(FlaxRegNetPreTrainedModel):
    module_class = FlaxRegNetModule


# FLAX_VISION_MODEL_DOCSTRING字符串文档
FLAX_VISION_MODEL_DOCSTRING = """
    Returns:

    Examples:

    ```
    >>> from transformers import AutoImageProcessor, FlaxRegNetModel
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040")
    >>> model = FlaxRegNetModel.from_pretrained("facebook/regnet-y-040")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
"""

# 覆盖FlaxRegNetModel的调用文档字符串
overwrite_call_docstring(FlaxRegNetModel, FLAX_VISION_MODEL_DOCSTRING)
# 使用函数`append_replace_return_docstrings`设置FlaxRegNetModel的文档字符串,指定输出类型为FlaxBaseModelOutputWithPooling,
# 并使用RegNetConfig进行配置。

# 从`transformers.models.resnet.modeling_flax_resnet.FlaxResNetClassifierCollection`复制代码到`FlaxRegNetClassifierCollection`,
# 将ResNet更改为RegNet。这个类用于创建RegNet模型的分类器集合。

class FlaxRegNetClassifierCollection(nn.Module):
    # 使用RegNetConfig配置模型
    config: RegNetConfig
    # 默认数据类型为jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 模块设置方法
    def setup(self):
        # 创建具有config.num_labels输出的全连接层作为分类器,数据类型为dtype,名称为"1"
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype, name="1")

    # 调用方法,将输入x经过分类器处理后返回结果
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.classifier(x)


# 从`transformers.models.resnet.modeling_flax_resnet.FlaxResNetForImageClassificationModule`复制代码到`FlaxRegNetForImageClassificationModule`,
# 将ResNet更改为RegNet,同时修改resnet->regnet, RESNET->REGNET。这个类用于创建RegNet用于图像分类的模块。

class FlaxRegNetForImageClassificationModule(nn.Module):
    # 使用RegNetConfig配置模型
    config: RegNetConfig
    # 默认数据类型为jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 模块设置方法
    def setup(self):
        # 创建RegNet模块,使用给定的config和dtype
        self.regnet = FlaxRegNetModule(config=self.config, dtype=self.dtype)

        # 根据配置决定是否创建分类器集合或保持身份映射
        if self.config.num_labels > 0:
            self.classifier = FlaxRegNetClassifierCollection(self.config, dtype=self.dtype)
        else:
            self.classifier = Identity()

    # 调用方法,根据输入参数进行前向传播,返回分类结果
    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_hidden_states=None,
        return_dict=None,
    ):
        # 根据参数设置返回字典的使用与否
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用RegNet模块进行前向传播
        outputs = self.regnet(
            pixel_values,
            deterministic=deterministic,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 如果使用返回字典,则获取池化后的输出;否则直接从输出中获取池化后的特征
        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        # 将池化后的特征输入到分类器中,获取最终的logits
        logits = self.classifier(pooled_output[:, :, 0, 0])

        # 如果不使用返回字典,则将logits与额外的隐藏状态一起返回
        if not return_dict:
            output = (logits,) + outputs[2:]
            return output

        # 使用FlaxImageClassifierOutputWithNoAttention将logits和隐藏状态输出
        return FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states)


@add_start_docstrings(
    """
    用于在RegNet模型顶部添加图像分类头的模型,例如在ImageNet上使用线性层对池化特征进行分类。
    """,
    REGNET_START_DOCSTRING,
)
# 使用`add_start_docstrings`添加模型文档字符串,描述其用途和示例。
class FlaxRegNetForImageClassification(FlaxRegNetPreTrainedModel):
    # 模块类别设置为FlaxRegNetForImageClassificationModule
    module_class = FlaxRegNetForImageClassificationModule


FLAX_VISION_CLASSIF_DOCSTRING = """
    Returns:

    Example:

    ```
    >>> from transformers import AutoImageProcessor, FlaxRegNetForImageClassification
    >>> from PIL import Image
    >>> import jax
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040")
    >>> model = FlaxRegNetForImageClassification.from_pretrained("facebook/regnet-y-040")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits

    >>> # model predicts one of the 1000 ImageNet classes


# 注释:
# 此部分为文档字符串`FLAX_VISION_CLASSIF_DOCSTRING`,提供了该模型的返回值说明和使用示例。
    # 使用 JAX 提供的 numpy 模块计算 logits 中每个样本预测的类别索引
    predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
    # 打印预测的类别,根据模型配置中的 id2label 映射将索引转换为标签名称并输出
    print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
"""
覆盖调用函数的文档字符串为指定的文档字符串。
"""
overwrite_call_docstring(FlaxRegNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)

"""
向指定类追加或替换返回值文档字符串。
"""
append_replace_return_docstrings(
    FlaxRegNetForImageClassification,
    output_type=FlaxImageClassifierOutputWithNoAttention,
    config_class=RegNetConfig,
)

.\models\regnet\modeling_regnet.py

# coding=utf-8
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch RegNet model."""

from typing import Optional

import torch
import torch.utils.checkpoint
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import (
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPoolingAndNoAttention,
    ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_regnet import RegNetConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "RegNetConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/regnet-y-040"
_EXPECTED_OUTPUT_SHAPE = [1, 1088, 7, 7]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"

REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/regnet-y-040",
    # See all regnet models at https://huggingface.co/models?filter=regnet
]


class RegNetConvLayer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        groups: int = 1,
        activation: Optional[str] = "relu",
    ):
        super().__init__()
        # 定义卷积层,设置卷积核大小、步长、填充方式、分组数和是否使用偏置
        self.convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=groups,
            bias=False,
        )
        # 定义批归一化层
        self.normalization = nn.BatchNorm2d(out_channels)
        # 根据激活函数名称选择激活函数,或者使用恒等映射
        self.activation = ACT2FN[activation] if activation is not None else nn.Identity()

    def forward(self, hidden_state):
        # 执行卷积操作
        hidden_state = self.convolution(hidden_state)
        # 执行批归一化操作
        hidden_state = self.normalization(hidden_state)
        # 执行激活函数操作
        hidden_state = self.activation(hidden_state)
        return hidden_state


class RegNetEmbeddings(nn.Module):
    """
    RegNet Embedddings (stem) composed of a single aggressive convolution.
    """
    # 初始化函数,接受一个配置对象作为参数
    def __init__(self, config: RegNetConfig):
        # 调用父类的初始化方法
        super().__init__()
        # 创建一个 RegNetConvLayer 实例作为 embedder 属性,配置如下参数:
        # - config.num_channels: 输入通道数
        # - config.embedding_size: 嵌入向量的大小
        # - kernel_size=3: 卷积核大小为 3x3
        # - stride=2: 步长为 2
        # - activation=config.hidden_act: 激活函数由配置对象中的 hidden_act 决定
        self.embedder = RegNetConvLayer(
            config.num_channels, config.embedding_size, kernel_size=3, stride=2, activation=config.hidden_act
        )
        # 将配置对象中的 num_channels 属性赋值给实例的 num_channels 属性
        self.num_channels = config.num_channels

    # 前向传播函数,接受像素值作为输入
    def forward(self, pixel_values):
        # 获取像素值的通道数
        num_channels = pixel_values.shape[1]
        # 如果像素值的通道数与实例属性中的 num_channels 不匹配,抛出 ValueError
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # 将像素值传递给 embedder 进行处理,得到隐藏状态 hidden_state
        hidden_state = self.embedder(pixel_values)
        # 返回隐藏状态 hidden_state
        return hidden_state
# 从transformers.models.resnet.modeling_resnet.ResNetShortCut复制并修改为RegNetShortCut
class RegNetShortCut(nn.Module):
    """
    RegNet的shortcut,用于将残差特征投影到正确的大小。如果需要,还用于使用`stride=2`对输入进行下采样。
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        # 使用1x1的卷积层进行投影,并设置步长和无偏置
        self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
        # 添加批归一化层
        self.normalization = nn.BatchNorm2d(out_channels)

    def forward(self, input: Tensor) -> Tensor:
        # 对输入进行1x1卷积操作
        hidden_state = self.convolution(input)
        # 对卷积结果进行批归一化
        hidden_state = self.normalization(hidden_state)
        return hidden_state


class RegNetSELayer(nn.Module):
    """
    压缩与激发层(SE),在[Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507)中提出。
    """

    def __init__(self, in_channels: int, reduced_channels: int):
        super().__init__()
        # 自适应平均池化层,将输入大小池化为(1, 1)
        self.pooler = nn.AdaptiveAvgPool2d((1, 1))
        # SE结构,包括两个1x1卷积层,ReLU激活函数和Sigmoid激活函数
        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, reduced_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(reduced_channels, in_channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, hidden_state):
        # 输入为b c h w,将其池化为b c 1 1
        pooled = self.pooler(hidden_state)
        # 使用SE结构计算注意力权重
        attention = self.attention(pooled)
        # 使用注意力权重加权输入特征
        hidden_state = hidden_state * attention
        return hidden_state


class RegNetXLayer(nn.Module):
    """
    RegNet的层,由三个3x3的卷积组成,与ResNet的瓶颈层相同,但reduction=1。
    """

    def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        # 确定是否应用shortcut,以及设置groups参数
        should_apply_shortcut = in_channels != out_channels or stride != 1
        groups = max(1, out_channels // config.groups_width)
        # 设置shortcut连接,如果需要则使用RegNetShortCut,否则使用身份映射
        self.shortcut = (
            RegNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
        )
        # 设计层的顺序:第一层1x1卷积,第二层3x3卷积(可能使用分组卷积),第三层1x1卷积
        self.layer = nn.Sequential(
            RegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act),
            RegNetConvLayer(out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act),
            RegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None),
        )
        # 设定激活函数
        self.activation = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        # 保留输入的残差连接
        residual = hidden_state
        # 执行层内的卷积操作
        hidden_state = self.layer(hidden_state)
        # 应用shortcut连接
        residual = self.shortcut(residual)
        # 将残差添加到层的输出中
        hidden_state += residual
        # 应用激活函数
        hidden_state = self.activation(hidden_state)
        return hidden_state


class RegNetYLayer(nn.Module):
    """
    RegNet的Y层:一个带有Squeeze和Excitation的X层。
    """
    # 初始化函数,用于初始化一个网络模块
    def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1):
        # 调用父类的初始化函数
        super().__init__()
        # 根据输入输出通道数和步长判断是否需要应用快捷连接
        should_apply_shortcut = in_channels != out_channels or stride != 1
        # 计算分组卷积的分组数,确保至少有一个分组
        groups = max(1, out_channels // config.groups_width)
        # 如果需要应用快捷连接,则创建RegNetShortCut对象;否则创建nn.Identity对象
        self.shortcut = (
            RegNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
        )
        # 创建一个包含多个子模块的序列模块
        self.layer = nn.Sequential(
            # 第一个卷积层:输入通道数到输出通道数的卷积,卷积核大小为1,激活函数为config中指定的隐藏层激活函数
            RegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act),
            # 第二个卷积层:输出通道数到输出通道数的卷积,卷积核大小为3(由步长决定),分组卷积数为groups,激活函数为config中指定的隐藏层激活函数
            RegNetConvLayer(out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act),
            # Squeeze-and-Excitation(SE)模块:对输出通道数进行SE操作,减少通道数为输入通道数的四分之一
            RegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4))),
            # 第三个卷积层:输出通道数到输出通道数的卷积,卷积核大小为1,无激活函数
            RegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None),
        )
        # 激活函数,从配置中选择合适的激活函数
        self.activation = ACT2FN[config.hidden_act]

    # 前向传播函数,用于定义数据从输入到输出的流程
    def forward(self, hidden_state):
        # 保存输入作为残差
        residual = hidden_state
        # 将输入通过序列模块进行前向传播
        hidden_state = self.layer(hidden_state)
        # 使用快捷连接模块对残差进行转换
        residual = self.shortcut(residual)
        # 将前向传播结果与转换后的残差相加
        hidden_state += residual
        # 对相加后的结果应用激活函数
        hidden_state = self.activation(hidden_state)
        # 返回处理后的输出结果
        return hidden_state
class RegNetStage(nn.Module):
    """
    A RegNet stage composed by stacked layers.
    """

    def __init__(
        self,
        config: RegNetConfig,
        in_channels: int,
        out_channels: int,
        stride: int = 2,
        depth: int = 2,
    ):
        super().__init__()

        # 根据配置选择不同类型的层
        layer = RegNetXLayer if config.layer_type == "x" else RegNetYLayer

        # 使用 nn.Sequential 定义层的序列
        self.layers = nn.Sequential(
            # 第一层进行下采样,步幅为2
            layer(
                config,
                in_channels,
                out_channels,
                stride=stride,
            ),
            *[layer(config, out_channels, out_channels) for _ in range(depth - 1)],
        )

    def forward(self, hidden_state):
        # 前向传播函数,依次通过每一层
        hidden_state = self.layers(hidden_state)
        return hidden_state


class RegNetEncoder(nn.Module):
    def __init__(self, config: RegNetConfig):
        super().__init__()
        self.stages = nn.ModuleList([])

        # 根据配置决定是否在第一阶段的第一层进行输入下采样
        self.stages.append(
            RegNetStage(
                config,
                config.embedding_size,
                config.hidden_sizes[0],
                stride=2 if config.downsample_in_first_stage else 1,
                depth=config.depths[0],
            )
        )

        # 逐阶段定义 RegNetStage,并连接起来
        in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
        for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
            self.stages.append(RegNetStage(config, in_channels, out_channels, depth=depth))

    def forward(
        self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
    ) -> BaseModelOutputWithNoAttention:
        hidden_states = () if output_hidden_states else None

        # 逐阶段通过 RegNetStage 进行前向传播
        for stage_module in self.stages:
            if output_hidden_states:
                hidden_states = hidden_states + (hidden_state,)

            hidden_state = stage_module(hidden_state)

        if output_hidden_states:
            hidden_states = hidden_states + (hidden_state,)

        # 根据 return_dict 返回不同的输出格式
        if not return_dict:
            return tuple(v for v in [hidden_state, hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)


class RegNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = RegNetConfig
    base_model_prefix = "regnet"
    main_input_name = "pixel_values"

    # 从 transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights 复制而来的初始化权重函数
    # 定义一个方法 `_init_weights`,用于初始化神经网络模块的权重
    def _init_weights(self, module):
        # 如果传入的模块是 nn.Conv2d 类型,则使用 Kaiming 正态分布初始化权重
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
        # 如果传入的模块是 nn.BatchNorm2d 或 nn.GroupNorm 类型,则初始化权重为常数 1,偏置为常数 0
        elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
REGNET_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`RegNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

REGNET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`ConvNextImageProcessor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    "The bare RegNet model outputting raw features without any specific head on top.",
    REGNET_START_DOCSTRING,
)
# Copied from transformers.models.resnet.modeling_resnet.ResNetModel with RESNET->REGNET,ResNet->RegNet
class RegNetModel(RegNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embedder = RegNetEmbeddings(config)
        self.encoder = RegNetEncoder(config)
        self.pooler = nn.AdaptiveAvgPool2d((1, 1))
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
    ):
        """
        Perform the forward pass of the RegNet model.

        Args:
            pixel_values (torch.FloatTensor): Pixel values of shape `(batch_size, num_channels, height, width)`.
                These values are obtained using an `AutoImageProcessor`.

            output_hidden_states (bool, optional): Whether or not to return hidden states of all layers.
                Refer to `hidden_states` in the returned tensors for details.

            return_dict (bool, optional): Whether to return a `ModelOutput` instead of a tuple.

        Returns:
            Depending on `return_dict`, either a `ModelOutput` or a tuple of outputs from the model.
        """
        # Forward pass logic goes here
        pass
    ) -> BaseModelOutputWithPoolingAndNoAttention:
        # 函数声明,指定返回类型为BaseModelOutputWithPoolingAndNoAttention

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果输出隐藏状态参数不为空,则使用该参数;否则使用self.config.output_hidden_states

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 如果返回字典参数不为空,则使用该参数;否则使用self.config.use_return_dict

        embedding_output = self.embedder(pixel_values)
        # 将像素值传入嵌入器(embedder),获取嵌入输出

        encoder_outputs = self.encoder(
            embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict
        )
        # 使用编码器(encoder)处理嵌入输出,可以选择输出隐藏状态和是否返回字典

        last_hidden_state = encoder_outputs[0]
        # 获取编码器输出的最后隐藏状态

        pooled_output = self.pooler(last_hidden_state)
        # 使用池化器(pooler)对最后隐藏状态进行池化操作,得到池化输出

        if not return_dict:
            # 如果不返回字典
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
            # 返回最后隐藏状态、池化输出,以及编码器输出的其余部分

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )
        # 如果返回字典,则使用BaseModelOutputWithPoolingAndNoAttention类创建并返回一个对象,包括最后隐藏状态、池化输出和所有隐藏状态
@add_start_docstrings(
    """
    RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    REGNET_START_DOCSTRING,
)
# 定义 RegNetForImageClassification 类,继承自 RegNetPreTrainedModel 类
class RegNetForImageClassification(RegNetPreTrainedModel):
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)
        # 设置分类标签数量
        self.num_labels = config.num_labels
        # 初始化 RegNetModel,并赋值给 self.regnet
        self.regnet = RegNetModel(config)
        # 定义分类器,使用 nn.Sequential 定义层序列
        self.classifier = nn.Sequential(
            nn.Flatten(),  # 将输入展平
            # 如果配置中有标签数量大于零,则添加全连接层;否则使用恒等映射
            nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(),
        )
        # 执行初始化权重和最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    # 重写 forward 方法,接受像素值、标签等参数,返回模型输出
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # 输入参数详细文档字符串已添加

        # 在此处输入参数详细文档字符串已添加
        ):
        # 正文函数方法
    ) -> ImageClassifierOutputWithNoAttention:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # 如果 return_dict 不为 None,则使用给定的 return_dict;否则使用 self.config.use_return_dict
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 regnet 方法进行图像处理,返回输出结果
        outputs = self.regnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        # 如果 return_dict 为 True,则使用 outputs.pooler_output 作为 pooled_output;否则使用 outputs 的第二个元素
        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        # 使用 classifier 模型对 pooled_output 进行分类得到 logits
        logits = self.classifier(pooled_output)

        # 初始化 loss 为 None
        loss = None

        # 如果 labels 不为 None,则计算损失函数
        if labels is not None:
            # 如果 self.config.problem_type 为 None,则根据条件设置 problem_type
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # 根据 problem_type 计算相应的损失函数
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    # 对于单标签回归任务,计算 logits.squeeze() 和 labels.squeeze() 的均方误差损失
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    # 对于多标签回归任务,计算 logits 和 labels 的均方误差损失
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                # 对于单标签分类任务,使用交叉熵损失函数
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                # 对于多标签分类任务,使用带 logits 的二元交叉熵损失函数
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        # 如果 return_dict 为 False,则返回 logits 和 outputs 的其他部分
        if not return_dict:
            output = (logits,) + outputs[2:]
            return (loss,) + output if loss is not None else output

        # 返回 ImageClassifierOutputWithNoAttention 对象,包括 loss、logits 和 hidden_states
        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)

.\models\regnet\modeling_tf_regnet.py

# 设置文件编码为 UTF-8
# 版权声明和版权信息,表明该文件的版权归 Meta Platforms, Inc. 和 The HuggingFace Inc. 团队所有
#
# 根据 Apache 许可证 2.0 版本(“许可证”)授权;
# 除非符合许可证的要求,否则不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则按“原样”分发的软件
# 无任何明示或暗示的担保或条件。
# 请参阅许可证了解特定语言下的权限和限制。
""" TensorFlow RegNet 模型."""

from typing import Optional, Tuple, Union

import tensorflow as tf

# 从相应模块导入必要的功能和类
from ...activations_tf import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_tf_outputs import (
    TFBaseModelOutputWithNoAttention,
    TFBaseModelOutputWithPoolingAndNoAttention,
    TFSequenceClassifierOutput,
)
from ...modeling_tf_utils import (
    TFPreTrainedModel,
    TFSequenceClassificationLoss,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_regnet import RegNetConfig

# 获取日志记录器实例
logger = logging.get_logger(__name__)

# 用于文档的通用配置
_CONFIG_FOR_DOC = "RegNetConfig"

# 用于文档的基本检查点
_CHECKPOINT_FOR_DOC = "facebook/regnet-y-040"
# 预期输出的形状
_EXPECTED_OUTPUT_SHAPE = [1, 1088, 7, 7]

# 图像分类相关的检查点
_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"

# TFRegNet 模型的预训练模型存档列表
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/regnet-y-040",
    # 查看所有 RegNet 模型:https://huggingface.co/models?filter=regnet
]

# 定义 TFRegNetConvLayer 类,继承自 keras.layers.Layer
class TFRegNetConvLayer(keras.layers.Layer):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        groups: int = 1,
        activation: Optional[str] = "relu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # 对输入进行零填充,以确保输出大小与输入大小相同
        self.padding = keras.layers.ZeroPadding2D(padding=kernel_size // 2)
        # 定义卷积层,设置卷积核大小、步长、填充方式和组数
        self.convolution = keras.layers.Conv2D(
            filters=out_channels,
            kernel_size=kernel_size,
            strides=stride,
            padding="VALID",
            groups=groups,
            use_bias=False,
            name="convolution",
        )
        # 批量归一化层,用于规范化卷积层的输出
        self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
        # 激活函数,根据给定的激活函数名称选择激活函数或者返回标识函数
        self.activation = ACT2FN[activation] if activation is not None else tf.identity
        self.in_channels = in_channels
        self.out_channels = out_channels
    # 定义一个方法,用于调用神经网络层的操作
    def call(self, hidden_state):
        # 对输入的隐藏状态进行填充,并进行卷积操作
        hidden_state = self.convolution(self.padding(hidden_state))
        # 对卷积后的结果进行规范化(例如批量归一化)
        hidden_state = self.normalization(hidden_state)
        # 对规范化后的结果应用激活函数(如ReLU)
        hidden_state = self.activation(hidden_state)
        # 返回处理后的隐藏状态
        return hidden_state

    # 定义一个方法,用于构建神经网络层
    def build(self, input_shape=None):
        # 如果网络层已经构建,则直接返回
        if self.built:
            return
        # 标记网络层为已构建状态
        self.built = True
        # 如果存在卷积操作,并且未被构建,则构建卷积操作
        if getattr(self, "convolution", None) is not None:
            with tf.name_scope(self.convolution.name):
                # 使用输入通道数构建卷积层
                self.convolution.build([None, None, None, self.in_channels])
        # 如果存在规范化操作,并且未被构建,则构建规范化操作
        if getattr(self, "normalization", None) is not None:
            with tf.name_scope(self.normalization.name):
                # 使用输出通道数构建规范化层
                self.normalization.build([None, None, None, self.out_channels])
class TFRegNetEmbeddings(keras.layers.Layer):
    """
    RegNet Embeddings (stem) composed of a single aggressive convolution.
    """

    def __init__(self, config: RegNetConfig, **kwargs):
        super().__init__(**kwargs)
        self.num_channels = config.num_channels  # 从配置中获取通道数
        self.embedder = TFRegNetConvLayer(
            in_channels=config.num_channels,  # 输入通道数
            out_channels=config.embedding_size,  # 输出通道数(嵌入维度)
            kernel_size=3,  # 卷积核大小
            stride=2,  # 步长
            activation=config.hidden_act,  # 激活函数
            name="embedder",  # 层的名称
        )

    def call(self, pixel_values):
        num_channels = shape_list(pixel_values)[1]  # 获取像素值的通道数
        if tf.executing_eagerly() and num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )

        # 当在 CPU 上运行时,`keras.layers.Conv2D` 不支持 `NCHW` 格式。
        # 因此将输入格式从 `NCHW` 转换为 `NHWC`。
        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))  # 转置像素值的维度顺序
        hidden_state = self.embedder(pixel_values)  # 嵌入器处理像素值
        return hidden_state

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embedder", None) is not None:
            with tf.name_scope(self.embedder.name):
                self.embedder.build(None)  # 构建嵌入器层


class TFRegNetShortCut(keras.layers.Layer):
    """
    RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
    downsample the input using `stride=2`.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.convolution = keras.layers.Conv2D(
            filters=out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution"
        )  # 1x1 卷积层,用于投影和下采样
        self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")  # 批量归一化层
        self.in_channels = in_channels  # 输入通道数
        self.out_channels = out_channels  # 输出通道数

    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
        return self.normalization(self.convolution(inputs), training=training)  # 应用卷积和归一化操作

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "convolution", None) is not None:
            with tf.name_scope(self.convolution.name):
                self.convolution.build([None, None, None, self.in_channels])  # 构建卷积层
        if getattr(self, "normalization", None) is not None:
            with tf.name_scope(self.normalization.name):
                self.normalization.build([None, None, None, self.out_channels])  # 构建归一化层


class TFRegNetSELayer(keras.layers.Layer):
    """
    Placeholder for the SE (Squeeze-and-Excitation) Layer in RegNet, to be implemented.
    This layer is intended for enhancing channel-wise relationships adaptively.
    """
    Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507).
    """

    # 定义 Squeeze-and-Excitation(SE)层的类
    def __init__(self, in_channels: int, reduced_channels: int, **kwargs):
        super().__init__(**kwargs)
        # 创建全局平均池化层,用于计算特征图的平均值
        self.pooler = keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler")
        # 创建注意力机制的两个卷积层,用于生成注意力权重
        self.attention = [
            keras.layers.Conv2D(filters=reduced_channels, kernel_size=1, activation="relu", name="attention.0"),
            keras.layers.Conv2D(filters=in_channels, kernel_size=1, activation="sigmoid", name="attention.2"),
        ]
        # 记录输入通道数和降维后的通道数
        self.in_channels = in_channels
        self.reduced_channels = reduced_channels

    # 定义 SE 层的前向传播函数
    def call(self, hidden_state):
        # 对输入的特征图进行全局平均池化,生成池化后的结果
        pooled = self.pooler(hidden_state)
        # 对池化后的结果分别通过两个注意力卷积层,生成注意力权重
        for layer_module in self.attention:
            pooled = layer_module(pooled)
        # 将原始特征图与注意力权重相乘,增强特征表示
        hidden_state = hidden_state * pooled
        return hidden_state

    # 构建 SE 层,确保每个组件都被正确地构建和连接
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 构建全局平均池化层
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build((None, None, None, None))
        # 构建注意力卷积层
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention[0].name):
                self.attention[0].build([None, None, None, self.in_channels])
            with tf.name_scope(self.attention[1].name):
                self.attention[1].build([None, None, None, self.reduced_channels])
# 定义 TFRegNetXLayer 类,表示 RegNet 模型中的一个层,类似于 ResNet 的瓶颈层,但具有不同的特性。
class TFRegNetXLayer(keras.layers.Layer):
    """
    RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1.
    """

    # 初始化方法,设置层的参数和结构
    def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs):
        super().__init__(**kwargs)
        # 检查是否需要应用快捷连接,根据输入和输出通道数以及步长来判断
        should_apply_shortcut = in_channels != out_channels or stride != 1
        # 如果需要应用快捷连接,则创建 TFRegNetShortCut 实例作为 shortcut 属性;否则创建线性激活函数作为 shortcut 属性
        self.shortcut = (
            TFRegNetShortCut(in_channels, out_channels, stride=stride, name="shortcut")
            if should_apply_shortcut
            else keras.layers.Activation("linear", name="shortcut")
        )
        # 定义三个卷积层的列表,每一层都是 TFRegNetConvLayer 类的实例,用于构建层内部的特征提取流程
        self.layers = [
            TFRegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
            TFRegNetConvLayer(
                out_channels, out_channels, stride=stride, groups=max(1, out_channels // config.groups_width),
                activation=config.hidden_act, name="layer.1"
            ),
            TFRegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None, name="layer.2"),
        ]
        # 激活函数根据配置文件中的隐藏激活函数来选择
        self.activation = ACT2FN[config.hidden_act]

    # 定义层的前向传播逻辑
    def call(self, hidden_state):
        # 保存输入的残差连接
        residual = hidden_state
        # 遍历每一层卷积,依次对 hidden_state 进行特征提取
        for layer_module in self.layers:
            hidden_state = layer_module(hidden_state)
        # 将残差连接通过快捷连接层进行处理
        residual = self.shortcut(residual)
        # 将特征提取后的 hidden_state 与处理后的残差相加
        hidden_state += residual
        # 使用预定义的激活函数对输出进行激活
        hidden_state = self.activation(hidden_state)
        return hidden_state

    # 构建方法,用于在第一次调用前构建层的变量
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果定义了快捷连接,则构建快捷连接层
        if getattr(self, "shortcut", None) is not None:
            with tf.name_scope(self.shortcut.name):
                self.shortcut.build(None)
        # 构建每一个卷积层
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)


class TFRegNetYLayer(keras.layers.Layer):
    """
    RegNet's Y layer: an X layer with Squeeze and Excitation.
    """
    # 初始化函数,用于初始化模型对象
    def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)
        # 确定是否应用快捷连接(shortcut),条件是输入通道数不等于输出通道数或步长不为1
        should_apply_shortcut = in_channels != out_channels or stride != 1
        # 计算组数,确保至少有一个组
        groups = max(1, out_channels // config.groups_width)
        # 如果应用快捷连接,则创建一个 TFRegNetShortCut 对象作为快捷连接,否则创建线性激活函数作为快捷连接
        self.shortcut = (
            TFRegNetShortCut(in_channels, out_channels, stride=stride, name="shortcut")
            if should_apply_shortcut
            else keras.layers.Activation("linear", name="shortcut")
        )
        # 定义模型的层列表,包括几个 TFRegNetConvLayer 层和一个 TFRegNetSELayer 层
        self.layers = [
            TFRegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
            TFRegNetConvLayer(
                out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1"
            ),
            TFRegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4)), name="layer.2"),
            TFRegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None, name="layer.3"),
        ]
        # 激活函数使用根据配置选择的激活函数
        self.activation = ACT2FN[config.hidden_act]

    # 调用函数,用于模型的前向传播
    def call(self, hidden_state):
        # 将输入状态作为残差
        residual = hidden_state
        # 遍历模型的每一层,并对输入状态进行处理
        for layer_module in self.layers:
            hidden_state = layer_module(hidden_state)
        # 将残差通过快捷连接处理
        residual = self.shortcut(residual)
        # 将处理后的状态与残差相加
        hidden_state += residual
        # 应用激活函数到最终的隐藏状态
        hidden_state = self.activation(hidden_state)
        # 返回最终的隐藏状态
        return hidden_state

    # 构建函数,用于构建模型的层次结构
    def build(self, input_shape=None):
        # 如果模型已经构建过,则直接返回
        if self.built:
            return
        # 标记模型已经构建
        self.built = True
        # 如果存在快捷连接,则构建快捷连接
        if getattr(self, "shortcut", None) is not None:
            with tf.name_scope(self.shortcut.name):
                self.shortcut.build(None)
        # 遍历每一层,并构建每一层
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
class TFRegNetStage(keras.layers.Layer):
    """
    A RegNet stage composed by stacked layers.
    """

    def __init__(
        self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs
    ):
        super().__init__(**kwargs)

        # 根据配置选择使用 TFRegNetXLayer 或 TFRegNetYLayer 作为层
        layer = TFRegNetXLayer if config.layer_type == "x" else TFRegNetYLayer

        # 创建层列表,第一层可能使用 stride=2 进行下采样
        self.layers = [
            layer(config, in_channels, out_channels, stride=stride, name="layers.0"),
            *[layer(config, out_channels, out_channels, name=f"layers.{i+1}") for i in range(depth - 1)],
        ]

    def call(self, hidden_state):
        # 逐层调用各层的 call 方法
        for layer_module in self.layers:
            hidden_state = layer_module(hidden_state)
        return hidden_state

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)


class TFRegNetEncoder(keras.layers.Layer):
    def __init__(self, config: RegNetConfig, **kwargs):
        super().__init__(**kwargs)
        self.stages = []

        # 根据配置中的 downsample_in_first_stage 决定第一阶段是否进行输入的下采样
        self.stages.append(
            TFRegNetStage(
                config,
                config.embedding_size,
                config.hidden_sizes[0],
                stride=2 if config.downsample_in_first_stage else 1,
                depth=config.depths[0],
                name="stages.0",
            )
        )

        # 构建多个阶段,每个阶段包含多个 TFRegNetStage
        in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
        for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, config.depths[1:])):
            self.stages.append(TFRegNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{i+1}"))

    def call(
        self, hidden_state: tf.Tensor, output_hidden_states: bool = False, return_dict: bool = True
    ) -> TFBaseModelOutputWithNoAttention:
        hidden_states = () if output_hidden_states else None

        # 逐阶段调用 TFRegNetStage 的 call 方法,收集隐藏状态
        for stage_module in self.stages:
            if output_hidden_states:
                hidden_states = hidden_states + (hidden_state,)
            hidden_state = stage_module(hidden_state)

        if output_hidden_states:
            hidden_states = hidden_states + (hidden_state,)

        # 根据 return_dict 决定返回的结果类型
        if not return_dict:
            return tuple(v for v in [hidden_state, hidden_states] if v is not None)
        return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        for stage in self.stages:
            with tf.name_scope(stage.name):
                stage.build(None)
class TFRegNetMainLayer(keras.layers.Layer):
    # 使用 RegNetConfig 类来配置模型参数
    config_class = RegNetConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        # 创建 TFRegNetEmbeddings 实例作为嵌入层
        self.embedder = TFRegNetEmbeddings(config, name="embedder")
        # 创建 TFRegNetEncoder 实例作为编码器
        self.encoder = TFRegNetEncoder(config, name="encoder")
        # 创建全局平均池化层,用于池化特征
        self.pooler = keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler")

    @unpack_inputs
    def call(
        self,
        pixel_values: tf.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> TFBaseModelOutputWithPoolingAndNoAttention:
        # 根据需要设置是否输出隐藏状态和是否返回字典形式结果
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 通过嵌入层处理输入数据
        embedding_output = self.embedder(pixel_values, training=training)

        # 使用编码器处理嵌入输出
        encoder_outputs = self.encoder(
            embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
        )

        # 获取最后一个隐藏状态
        last_hidden_state = encoder_outputs[0]
        # 对最终池化的输出进行全局维度转换
        pooled_output = self.pooler(last_hidden_state)

        # 将池化的输出格式转换为 NCHW 格式,确保模块的一致性
        pooled_output = tf.transpose(pooled_output, perm=(0, 3, 1, 2))
        last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))

        # 如果需要输出隐藏状态,则将所有隐藏状态也转换为 NCHW 格式
        if output_hidden_states:
            hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])

        # 如果不返回字典形式结果,则返回元组形式的输出
        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        # 如果返回字典形式结果,则构造 TFBaseModelOutputWithPoolingAndNoAttention 对象
        return TFBaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果嵌入层已定义,则构建嵌入层
        if getattr(self, "embedder", None) is not None:
            with tf.name_scope(self.embedder.name):
                self.embedder.build(None)
        # 如果编码器已定义,则构建编码器
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        # 如果池化层已定义,则构建池化层
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build((None, None, None, None))


class TFRegNetPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 使用 RegNetConfig 类来配置模型参数
    config_class = RegNetConfig
    # 指定基础模型的前缀名称为 "regnet"
    base_model_prefix = "regnet"
    # 模型的主要输入名称为 "pixel_values"
    main_input_name = "pixel_values"

    @property
    # 定义一个方法input_signature,用于返回输入数据的签名信息,通常在 TensorFlow 的模型定义中使用
    def input_signature(self):
        # 返回一个字典,描述了输入张量的规格和数据类型
        return {"pixel_values": tf.TensorSpec(shape=(None, self.config.num_channels, 224, 224), dtype=tf.float32)}
# 定义用于文档字符串的模型描述和参数说明,使用原始的三重引号格式化字符串
REGNET_START_DOCSTRING = r"""
    This model is a Tensorflow
    [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a
    regular Tensorflow Module and refer to the Tensorflow documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`RegNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

# 定义用于输入参数文档字符串的格式化字符串
REGNET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`ConveNextImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# 使用装饰器为类添加起始文档字符串和额外的模型前向传播方法文档
@add_start_docstrings(
    "The bare RegNet model outputting raw features without any specific head on top.",
    REGNET_START_DOCSTRING,
)
class TFRegNetModel(TFRegNetPreTrainedModel):
    def __init__(self, config: RegNetConfig, *inputs, **kwargs):
        # 调用父类的初始化方法,传递模型配置和额外的输入参数
        super().__init__(config, *inputs, **kwargs)
        # 创建主要的RegNet层,使用给定的配置和命名为"regnet"
        self.regnet = TFRegNetMainLayer(config, name="regnet")

    # 使用装饰器为call方法添加起始文档字符串、输入参数和代码示例文档
    @unpack_inputs
    @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def call(
        self,
        pixel_values: tf.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPoolingAndNoAttention, Tuple[tf.Tensor]]:
        # 如果没有明确指定输出隐藏状态,使用模型配置中的设定
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果没有明确指定返回字典形式的输出,使用模型配置中的设定
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用RegNet主层进行前向传播,传递像素值、输出隐藏状态选项、返回字典选项和训练模式
        outputs = self.regnet(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # 如果不返回字典形式的输出,以元组形式返回
        if not return_dict:
            return (outputs[0],) + outputs[1:]

        # 返回TFBaseModelOutputWithPoolingAndNoAttention类型的输出,包括最终隐藏状态和池化输出
        return TFBaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=outputs.last_hidden_state,
            pooler_output=outputs.pooler_output,
            hidden_states=outputs.hidden_states,
        )
    # 如果模型已经构建完成,则直接返回,不进行重复构建
    if self.built:
        return
    
    # 将模型标记为已构建状态
    self.built = True
    
    # 检查是否存在名为 "regnet" 的属性,如果存在则执行以下操作
    if getattr(self, "regnet", None) is not None:
        # 使用 TensorFlow 的命名空间为 regnet 构建模型
        with tf.name_scope(self.regnet.name):
            # 调用 regnet 对象的 build 方法,传入 None 作为输入形状
            self.regnet.build(None)
@add_start_docstrings(
    """
    RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    REGNET_START_DOCSTRING,
)
class TFRegNetForImageClassification(TFRegNetPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: RegNetConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.regnet = TFRegNetMainLayer(config, name="regnet")
        # classification head
        self.classifier = [
            keras.layers.Flatten(),  # 将输入展平以供后续全连接层使用
            keras.layers.Dense(config.num_labels, name="classifier.1") if config.num_labels > 0 else tf.identity,  # 分类器的全连接层
        ]

    @unpack_inputs
    @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)  # 添加模型前向传播的文档字符串
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        labels: Optional[tf.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # 设置是否输出隐藏状态
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict  # 设置是否使用返回字典

        outputs = self.regnet(
            pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training  # 调用 RegNet 主层进行前向传播
        )

        pooled_output = outputs.pooler_output if return_dict else outputs[1]  # 获取汇聚输出或指定位置的输出

        flattened_output = self.classifier[0](pooled_output)  # 使用展平层处理汇聚输出
        logits = self.classifier[1](flattened_output)  # 使用全连接层计算 logits

        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)  # 计算损失,若无标签则损失为 None

        if not return_dict:
            output = (logits,) + outputs[2:]  # 组合输出,包括 logits 和可能的其他输出
            return ((loss,) + output) if loss is not None else output  # 返回损失与输出,或者仅输出

        return TFSequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)  # 返回包装的输出对象
    # 定义神经网络层的构建方法,如果已经构建过则直接返回
    def build(self, input_shape=None):
        if self.built:
            return
        # 将标志位设置为已构建
        self.built = True
        # 如果存在名为"regnet"的属性,并且不为None,则构建regnet部分
        if getattr(self, "regnet", None) is not None:
            # 在命名空间内构建regnet
            with tf.name_scope(self.regnet.name):
                self.regnet.build(None)
        # 如果存在名为"classifier"的属性,并且不为None,则构建classifier部分
        if getattr(self, "classifier", None) is not None:
            # 在命名空间内构建classifier[1]
            with tf.name_scope(self.classifier[1].name):
                # 构建classifier[1],输入形状为[None, None, None, self.config.hidden_sizes[-1]]
                self.classifier[1].build([None, None, None, self.config.hidden_sizes[-1]])

.\models\regnet\__init__.py

# 版权声明和许可证信息
#
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入类型检查相关模块
from typing import TYPE_CHECKING

# 导入必要的依赖和模块
# 引入了一些特定的异常和工具函数
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_tf_available,
    is_torch_available,
)

# 定义了一个导入结构的字典,包含了模块和其对应的导入内容
_import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]}

# 检查是否可用 Torch,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,添加 Torch 模型相关的导入内容到导入结构字典中
    _import_structure["modeling_regnet"] = [
        "REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
        "RegNetForImageClassification",
        "RegNetModel",
        "RegNetPreTrainedModel",
    ]

# 类似地检查 TensorFlow 的可用性,并添加相应的导入内容到导入结构字典中
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_regnet"] = [
        "TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFRegNetForImageClassification",
        "TFRegNetModel",
        "TFRegNetPreTrainedModel",
    ]

# 类似地检查 Flax 的可用性,并添加相应的导入内容到导入结构字典中
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_regnet"] = [
        "FlaxRegNetForImageClassification",
        "FlaxRegNetModel",
        "FlaxRegNetPreTrainedModel",
    ]

# 如果是类型检查阶段,则导入更多的内容以支持类型检查
if TYPE_CHECKING:
    from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_regnet import (
            REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
            RegNetForImageClassification,
            RegNetModel,
            RegNetPreTrainedModel,
        )

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_regnet import (
            TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFRegNetForImageClassification,
            TFRegNetModel,
            TFRegNetPreTrainedModel,
        )

    # Flax 在类型检查中的导入暂时略过,因为前面已经处理过了
    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 如果前面的条件不满足,则执行以下代码块
        # 从当前目录下的 `modeling_flax_regnet` 模块中导入以下三个类
        from .modeling_flax_regnet import (
            FlaxRegNetForImageClassification,
            FlaxRegNetModel,
            FlaxRegNetPreTrainedModel,
        )
else:
    # 如果条件不满足,导入 sys 模块
    import sys
    # 将当前模块替换为一个懒加载模块,传入当前模块名、文件路径和导入结构
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)

.\models\rembert\configuration_rembert.py

# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" RemBERT model configuration"""

from collections import OrderedDict  # 导入有序字典
from typing import Mapping  # 导入类型提示 Mapping

from ...configuration_utils import PretrainedConfig  # 导入预训练配置类
from ...onnx import OnnxConfig  # 导入ONNX配置
from ...utils import logging  # 导入日志工具


logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

# RemBERT预训练配置文件映射字典,指定不同模型的配置文件下载链接
REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/rembert": "https://huggingface.co/google/rembert/resolve/main/config.json",
    # 查看所有RemBERT模型:https://huggingface.co/models?filter=rembert
}


class RemBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RemBertModel`]. It is used to instantiate an
    RemBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the RemBERT
    [google/rembert](https://huggingface.co/google/rembert) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Example:

    ```
    >>> from transformers import RemBertModel, RemBertConfig

    >>> # Initializing a RemBERT rembert style configuration
    >>> configuration = RemBertConfig()

    >>> # Initializing a model from the rembert style configuration
    >>> model = RemBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "rembert"  # 模型类型为rembert

    def __init__(
        self,
        vocab_size=250300,  # 词汇表大小,默认为250300
        hidden_size=1152,  # 隐藏层大小,默认为1152
        num_hidden_layers=32,  # 隐藏层层数,默认为32
        num_attention_heads=18,  # 注意力头数,默认为18
        input_embedding_size=256,  # 输入嵌入大小,默认为256
        output_embedding_size=1664,  # 输出嵌入大小,默认为1664
        intermediate_size=4608,  # 中间层大小,默认为4608
        hidden_act="gelu",  # 隐藏层激活函数,默认为GELU
        hidden_dropout_prob=0.0,  # 隐藏层Dropout概率,默认为0.0
        attention_probs_dropout_prob=0.0,  # 注意力Dropout概率,默认为0.0
        classifier_dropout_prob=0.1,  # 分类器Dropout概率,默认为0.1
        max_position_embeddings=512,  # 最大位置嵌入数,默认为512
        type_vocab_size=2,  # 类型词汇表大小,默认为2
        initializer_range=0.02,  # 初始化范围,默认为0.02
        layer_norm_eps=1e-12,  # 层归一化的epsilon,默认为1e-12
        use_cache=True,  # 是否使用缓存,默认为True
        pad_token_id=0,  # 填充token的ID,默认为0
        bos_token_id=312,  # 开始token的ID,默认为312
        eos_token_id=313,  # 结束token的ID,默认为313
        **kwargs,  # 其他关键字参数
        # 调用父类的初始化方法,设置模型的特殊 token ID 和其他关键参数
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        # 设置模型的词汇表大小
        self.vocab_size = vocab_size
        # 设置输入词嵌入的维度大小
        self.input_embedding_size = input_embedding_size
        # 设置输出词嵌入的维度大小
        self.output_embedding_size = output_embedding_size
        # 设置最大位置嵌入的数量
        self.max_position_embeddings = max_position_embeddings
        # 设置隐藏层的大小
        self.hidden_size = hidden_size
        # 设置隐藏层数量
        self.num_hidden_layers = num_hidden_layers
        # 设置注意力头的数量
        self.num_attention_heads = num_attention_heads
        # 设置中间层的大小
        self.intermediate_size = intermediate_size
        # 设置隐藏层的激活函数
        self.hidden_act = hidden_act
        # 设置隐藏层的丢弃率
        self.hidden_dropout_prob = hidden_dropout_prob
        # 设置注意力概率的丢弃率
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        # 设置分类器的丢弃率
        self.classifier_dropout_prob = classifier_dropout_prob
        # 设置初始化范围
        self.initializer_range = initializer_range
        # 设置类型词汇表的大小
        self.type_vocab_size = type_vocab_size
        # 设置层归一化的 epsilon 值
        self.layer_norm_eps = layer_norm_eps
        # 设置是否使用缓存
        self.use_cache = use_cache
        # 设置是否将词嵌入进行绑定
        self.tie_word_embeddings = False
# 定义一个自定义的配置类 RemBertOnnxConfig,继承自 OnnxConfig 类
class RemBertOnnxConfig(OnnxConfig):

    # 定义一个属性 inputs,返回一个映射,其键为字符串,值为映射(键为整数,值为字符串)
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # 如果任务类型为 "multiple-choice"
        if self.task == "multiple-choice":
            # 设置动态轴 dynamic_axis 为 {0: "batch", 1: "choice", 2: "sequence"}
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            # 否则设置动态轴 dynamic_axis 为 {0: "batch", 1: "sequence"}
            dynamic_axis = {0: "batch", 1: "sequence"}
        
        # 返回一个有序字典,包含三个键值对,分别是 ("input_ids", dynamic_axis),("attention_mask", dynamic_axis),("token_type_ids", dynamic_axis)
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
                ("token_type_ids", dynamic_axis),
            ]
        )

    # 定义一个属性 atol_for_validation,返回一个浮点数,表示验证时的绝对容差
    @property
    def atol_for_validation(self) -> float:
        # 返回绝对容差的数值,设定为 1e-4
        return 1e-4

.\models\rembert\convert_rembert_tf_checkpoint_to_pytorch.py

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert RemBERT checkpoint."""


import argparse  # 导入命令行参数解析模块

import torch  # 导入 PyTorch 深度学习库

from transformers import RemBertConfig, RemBertModel, load_tf_weights_in_rembert  # 导入转换所需的类和函数
from transformers.utils import logging  # 导入日志模块


logging.set_verbosity_info()  # 设置日志输出级别为信息


def convert_rembert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # 初始化 PyTorch 模型
    config = RemBertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))  # 打印模型配置信息
    model = RemBertModel(config)

    # 从 TensorFlow checkpoint 加载权重
    load_tf_weights_in_rembert(model, config, tf_checkpoint_path)

    # 保存 PyTorch 模型
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()  # 创建参数解析器

    # 必需参数
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--rembert_config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained RemBERT model. \n"
            "This specifies the model architecture."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()  # 解析命令行参数
    convert_rembert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.rembert_config_file, args.pytorch_dump_path)

.\models\rembert\modeling_rembert.py

# 设置文件编码为 UTF-8
# 版权声明和许可证信息
# 此代码使用 Apache License, Version 2.0 许可证,详细信息可查阅 http://www.apache.org/licenses/LICENSE-2.0
# 除非适用法律要求或书面同意,本软件按"原样"分发,不附带任何形式的担保或条件
# 请查阅许可证了解更多信息
""" PyTorch RemBERT 模型。"""

# 导入必要的库和模块
import math
import os
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# 导入自定义模块和类
from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_rembert import RemBertConfig

# 获取日志记录器对象
logger = logging.get_logger(__name__)

# 文档中的配置和检查点
_CONFIG_FOR_DOC = "RemBertConfig"
_CHECKPOINT_FOR_DOC = "google/rembert"

# RemBERT 预训练模型的存档列表
REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/rembert",
    # 查看所有 RemBERT 模型:https://huggingface.co/models?filter=rembert
]


def load_tf_weights_in_rembert(model, config, tf_checkpoint_path):
    """从 TensorFlow checkpoints 中加载权重到 PyTorch 模型中。"""
    try:
        import re  # 导入正则表达式模块
        import numpy as np  # 导入 NumPy 模块
        import tensorflow as tf  # 导入 TensorFlow 模块
    except ImportError:
        logger.error(
            "在 PyTorch 中加载 TensorFlow 模型需要安装 TensorFlow。请访问 "
            "https://www.tensorflow.org/install/ 获取安装指南。"
        )
        raise  # 抛出 ImportError 异常
    tf_path = os.path.abspath(tf_checkpoint_path)  # 获取 TensorFlow checkpoints 的绝对路径
    logger.info(f"从 {tf_path} 转换 TensorFlow checkpoints")  # 记录日志信息
    # 从 TF 模型中加载权重
    init_vars = tf.train.list_variables(tf_path)  # 获取 TensorFlow checkpoints 的变量列表
    names = []  # 初始化空列表存储变量名
    arrays = []  # 初始化空列表存储权重数组
    for name, shape in init_vars:
        # 检查点占用12Gb,通过不加载无用变量来节省内存
        # 输出嵌入和cls在分类时会被重置
        if any(deny in name for deny in ("adam_v", "adam_m", "output_embedding", "cls")):
            # 如果变量名包含"adam_v", "adam_m", "output_embedding", "cls"中的任意一个,跳过加载
            # logger.info("Skipping loading of %s", name)
            continue
        logger.info(f"Loading TF weight {name} with shape {shape}")
        # 使用TensorFlow的函数加载变量
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        # 将名称中的前缀"bert/"替换为"rembert/"
        name = name.replace("bert/", "rembert/")
        # "pooler"是一个线性层
        # 如果名称包含"pooler/dense",则替换为"pooler"
        # name = name.replace("pooler/dense", "pooler")

        # 将名称按"/"分割
        name = name.split("/")
        # "adam_v"和"adam_m"是AdamWeightDecayOptimizer中用于计算m和v的变量,预训练模型不需要这些变量
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            # 如果变量名符合形如"A-Za-z+_\d+"的正则表达式,分割出作用域名称和数字
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            # 根据作用域名称选择指针位置
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            # 如果作用域名称长度大于等于2,选择指定数字位置的指针
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        # 如果变量名以"_embeddings"结尾,选择权重指针
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # 如果变量名是"kernel",转置数组
            array = np.transpose(array)
        try:
            # 检查指针和数组的形状是否匹配,如果不匹配抛出异常
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        # 将NumPy数组转换为PyTorch张量,初始化权重指针
        pointer.data = torch.from_numpy(array)
    return model
# 定义一个名为 RemBertEmbeddings 的神经网络模块,用于构建来自单词、位置和标记类型嵌入的嵌入向量。
class RemBertEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 创建单词嵌入层,vocab_size 表示词汇表大小,input_embedding_size 表示嵌入向量的维度,
        # padding_idx 指定了填充标记的索引,以便在计算时将其置零。
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.input_embedding_size, padding_idx=config.pad_token_id
        )
        # 创建位置嵌入层,max_position_embeddings 表示最大的位置编码数量,input_embedding_size 表示嵌入向量的维度。
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.input_embedding_size)
        # 创建标记类型嵌入层,type_vocab_size 表示标记类型的数量,input_embedding_size 表示嵌入向量的维度。
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.input_embedding_size)

        # 使用 LayerNorm 进行归一化处理,保持与 TensorFlow 模型变量名的一致性,
        # 并能够加载任意 TensorFlow 检查点文件。
        self.LayerNorm = nn.LayerNorm(config.input_embedding_size, eps=config.layer_norm_eps)
        # Dropout 层,用于在训练过程中随机将一部分输入单元置零,以防止过拟合。
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) 在序列化时在内存中是连续的,并在导出时被导出。
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        # 如果给定 input_ids,则获取其形状;否则,获取 inputs_embeds 的形状。
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        # 获取序列的长度(即时间步数)。
        seq_length = input_shape[1]

        # 如果未提供 position_ids,则从预定义的 position_ids 中选择一段对应于序列长度的位置编码。
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # 如果未提供 token_type_ids,则创建一个与输入形状相同的全零张量。
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        # 如果未提供 inputs_embeds,则使用 input_ids 从 word_embeddings 中获取嵌入向量。
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        
        # 获取 token_type_ids 对应的标记类型嵌入向量。
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # 将单词嵌入向量和标记类型嵌入向量相加。
        embeddings = inputs_embeds + token_type_embeddings
        # 获取 position_ids 对应的位置嵌入向量。
        position_embeddings = self.position_embeddings(position_ids)
        # 将位置嵌入向量加到之前的结果中。
        embeddings += position_embeddings
        # 应用 LayerNorm 进行归一化处理。
        embeddings = self.LayerNorm(embeddings)
        # 应用 Dropout 进行随机丢弃部分输入。
        embeddings = self.dropout(embeddings)
        # 返回最终的嵌入向量。
        return embeddings


# 从 transformers.models.bert.modeling_bert.BertPooler 复制并修改为 RemBertPooler 类。
class RemBertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 全连接层,输入和输出大小均为 hidden_size,用于池化模型隐藏状态。
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 激活函数 tanh,用于非线性变换。
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 简单地通过获取第一个标记对应的隐藏状态来“池化”模型。
        first_token_tensor = hidden_states[:, 0]
        # 经过全连接层处理。
        pooled_output = self.dense(first_token_tensor)
        # 应用激活函数。
        pooled_output = self.activation(pooled_output)
        # 返回池化后的输出。
        return pooled_output
# 定义一个名为 RemBertSelfAttention 的类,继承自 nn.Module
class RemBertSelfAttention(nn.Module):
    # 初始化方法,接受一个 config 对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 如果隐藏层大小不是注意力头数的整数倍,且 config 对象没有 embedding_size 属性,则抛出 ValueError 异常
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        # 将注意力头数和每个注意力头的大小设置为类的属性
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 定义用于生成查询、键和值的线性层,并作为类的属性
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # 定义用于在注意力计算过程中进行 dropout 的层,并作为类的属性
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        # 判断是否为解码器,设置为类的属性
        self.is_decoder = config.is_decoder

    # 定义一个方法用于调整输入张量的形状,以适应多头注意力的计算
    def transpose_for_scores(self, x):
        # 计算新的形状,将最后一维分解为注意力头数和每个头的大小,并对输入张量进行相应的变形
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        # 将维度进行置换,以便后续计算
        return x.permute(0, 2, 1, 3)

    # 定义前向传播方法,接收隐藏状态、注意力掩码等作为输入,并进行注意力计算
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Tuple[Tuple[torch.FloatTensor]] = None,
        output_attentions: bool = False,
    ):
        # 方法体中的实现会在下面补充



# 定义一个名为 RemBertSelfOutput 的类,继承自 nn.Module
class RemBertSelfOutput(nn.Module):
    # 初始化方法,接受一个 config 对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 定义用于变换隐藏状态维度的线性层、LayerNorm 层和 dropout 层,并作为类的属性
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    # 定义前向传播方法,接收隐藏状态和输入张量作为输入,并返回处理后的隐藏状态
    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 使用线性层处理隐藏状态
        hidden_states = self.dense(hidden_states)
        # 使用 dropout 层对处理后的隐藏状态进行随机失活
        hidden_states = self.dropout(hidden_states)
        # 对处理后的隐藏状态进行 LayerNorm
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        # 返回处理后的隐藏状态
        return hidden_states



# 定义一个名为 RemBertAttention 的类,继承自 nn.Module
class RemBertAttention(nn.Module):
    # 初始化方法,接受一个 config 对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 定义自注意力和自注意力输出层,并作为类的属性
        self.self = RemBertSelfAttention(config)
        self.output = RemBertSelfOutput(config)
        # 初始化一个空集合用于存储被修剪的注意力头
        self.pruned_heads = set()

    # 方法体中的实现会在下面补充
    # 剪枝注意力头部
    def prune_heads(self, heads):
        # 如果头部列表为空,则直接返回
        if len(heads) == 0:
            return
        
        # 调用函数查找可剪枝的注意力头部和对应的索引
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # 剪枝线性层:对自注意力模块中的query、key、value以及输出dense层进行剪枝
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # 更新超参数并记录已剪枝的头部
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    # 从transformers.models.bert.modeling_bert.BertAttention.forward中复制而来
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # 调用self模块的forward方法,传递参数并接收返回的self_outputs
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        
        # 使用self_outputs[0]和hidden_states调用output模块,得到attention_output
        attention_output = self.output(self_outputs[0], hidden_states)
        
        # 如果需要输出attentions,则在outputs中加入attention信息
        outputs = (attention_output,) + self_outputs[1:]  # 如果有attentions,将其加入outputs
        return outputs
# 基于修改后的RemBert的中间层实现,继承自nn.Module类
class RemBertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 创建一个全连接层,将输入特征大小为config.hidden_size转换为config.intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # 根据配置选择激活函数,可能是预定义的激活函数映射表ACT2FN中的函数,或者是直接指定的函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    # 前向传播函数,输入hidden_states是一个torch.Tensor,输出也是一个torch.Tensor
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 将输入的hidden_states经过全连接层self.dense进行线性变换
        hidden_states = self.dense(hidden_states)
        # 将线性变换后的结果经过激活函数self.intermediate_act_fn进行非线性变换
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


# 基于修改后的RemBert的输出层实现,继承自nn.Module类
class RemBertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 创建一个全连接层,将输入特征大小为config.intermediate_size转换为config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # 创建LayerNorm层,归一化config.hidden_size维度的张量,eps是归一化的参数
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 创建Dropout层,以config.hidden_dropout_prob的概率随机将输入元素置零
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    # 前向传播函数,输入hidden_states和input_tensor都是torch.Tensor,输出也是一个torch.Tensor
    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 将输入的hidden_states经过全连接层self.dense进行线性变换
        hidden_states = self.dense(hidden_states)
        # 将线性变换后的结果经过Dropout层self.dropout进行随机置零处理
        hidden_states = self.dropout(hidden_states)
        # 将Dropout后的结果与input_tensor相加,并经过LayerNorm层self.LayerNorm进行归一化处理
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


# 基于修改后的RemBert的层实现,继承自nn.Module类
class RemBertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 设定feed forward过程中的chunk大小
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # 序列长度维度设为1
        self.seq_len_dim = 1
        # 创建RemBertAttention对象
        self.attention = RemBertAttention(config)
        # 是否作为解码器使用
        self.is_decoder = config.is_decoder
        # 是否添加交叉注意力
        self.add_cross_attention = config.add_cross_attention
        # 如果添加交叉注意力但不作为解码器使用,抛出异常
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            # 创建另一个RemBertAttention对象用于交叉注意力
            self.crossattention = RemBertAttention(config)
        # 创建RemBertIntermediate对象
        self.intermediate = RemBertIntermediate(config)
        # 创建RemBertOutput对象
        self.output = RemBertOutput(config)

    # 基于transformers库中BertLayer的前向传播函数
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # 声明函数的输入和输出类型注解,这里返回一个 torch.Tensor 的元组

        # 如果过去的键/值对不为 None,则提取自注意力的缓存键/值对(单向),位置在索引 1 和 2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        
        # 执行自注意力计算,传入隐藏状态、注意力掩码、头部掩码等参数
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        
        # 获取自注意力计算的输出
        attention_output = self_attention_outputs[0]

        # 如果是解码器,最后一个输出是自注意力缓存的元组
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            # 否则,将自注意力计算的输出加入到输出中
            outputs = self_attention_outputs[1:]  # 如果输出注意力权重,则添加自注意力
        

        cross_attn_present_key_value = None
        
        # 如果是解码器且有编码器的隐藏状态
        if self.is_decoder and encoder_hidden_states is not None:
            # 如果未定义交叉注意力层,则抛出错误
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # 提取过去键/值对的交叉注意力缓存,位置在倒数第二和最后的两个位置
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            
            # 执行交叉注意力计算,传入自注意力输出、注意力掩码、头部掩码、编码器隐藏状态等参数
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            
            # 获取交叉注意力计算的输出
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # 如果输出注意力权重,则添加交叉注意力

            # 将交叉注意力缓存添加到当前键/值对的末尾位置
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # 将注意力输出应用分块处理,传入分块处理函数、分块大小、序列长度维度和注意力输出
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        
        # 将分块处理后的输出添加到结果元组中
        outputs = (layer_output,) + outputs

        # 如果是解码器,将注意力的键/值对作为最后的输出添加到结果中
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        # 返回所有的输出
        return outputs

    # 从 transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk 复制过来的函数
    def feed_forward_chunk(self, attention_output):
        # 执行前馈网络的一部分,传入注意力输出
        intermediate_output = self.intermediate(attention_output)
        
        # 应用输出层,并返回最终的层输出
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
class RemBertEncoder(nn.Module):
    # RemBert 编码器模块,继承自 nn.Module
    def __init__(self, config):
        super().__init__()
        self.config = config

        # 输入嵌入层的线性映射,将输入的嵌入大小映射到隐藏大小
        self.embedding_hidden_mapping_in = nn.Linear(config.input_embedding_size, config.hidden_size)
        
        # 创建一个由多个 RemBert 层组成的层列表,层数由配置中的 num_hidden_layers 决定
        self.layer = nn.ModuleList([RemBertLayer(config) for _ in range(config.num_hidden_layers)])
        
        # 是否启用梯度检查点,默认为 False
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 前向传播方法,接受多个输入参数并返回一个 tensor
        # 具体操作由每个 RemBert 层来处理
        pass  # 这里应该有实际的前向传播代码,这里只是为了演示注释结构


# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->RemBert
class RemBertPredictionHeadTransform(nn.Module):
    # RemBert 预测头变换模块,继承自 nn.Module
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        
        # 根据配置中的激活函数类型选择对应的激活函数
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        
        # LayerNorm 层,对隐藏状态进行归一化处理
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 前向传播方法,对输入的隐藏状态进行线性变换、激活函数变换和归一化处理
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class RemBertLMPredictionHead(nn.Module):
    # RemBert 语言模型预测头模块,继承自 nn.Module
    def __init__(self, config):
        super().__init__()
        
        # 全连接层,将隐藏状态映射到输出嵌入大小
        self.dense = nn.Linear(config.hidden_size, config.output_embedding_size)
        
        # 解码器层,将输出嵌入映射到词汇表大小
        self.decoder = nn.Linear(config.output_embedding_size, config.vocab_size)
        
        # 根据配置中的激活函数类型选择对应的激活函数
        self.activation = ACT2FN[config.hidden_act]
        
        # LayerNorm 层,对输出嵌入进行归一化处理
        self.LayerNorm = nn.LayerNorm(config.output_embedding_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 前向传播方法,对输入的隐藏状态进行线性变换、激活函数变换、归一化和解码处理
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->RemBert
class RemBertOnlyMLMHead(nn.Module):
    # 仅包含 RemBert 语言模型头模块,继承自 nn.Module
    def __init__(self, config):
        super().__init__()
        
        # RemBert 语言模型预测头模块
        self.predictions = RemBertLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        # 前向传播方法,接受序列输出并返回预测分数
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class RemBertPreTrainedModel(PreTrainedModel):
    """
    RemBert 预训练模型基类,继承自 PreTrainedModel
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 定义配置类为 RemBertConfig
    config_class = RemBertConfig
    # 加载 TensorFlow 权重函数为 load_tf_weights_in_rembert
    load_tf_weights = load_tf_weights_in_rembert
    # 基础模型前缀为 "rembert"
    base_model_prefix = "rembert"
    # 支持梯度检查点
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        # 如果模块是线性层
        if isinstance(module, nn.Linear):
            # 使用正态分布初始化权重,标准差为配置中的 initializer_range
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果存在偏置项,则将其初始化为零
            if module.bias is not None:
                module.bias.data.zero_()
        # 如果模块是嵌入层
        elif isinstance(module, nn.Embedding):
            # 使用正态分布初始化权重,标准差为配置中的 initializer_range
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果指定了填充索引,则将填充索引对应的权重初始化为零
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # 如果模块是 LayerNorm 层
        elif isinstance(module, nn.LayerNorm):
            # 将偏置项初始化为零
            module.bias.data.zero_()
            # 将权重初始化为全1
            module.weight.data.fill_(1.0)
# REMBERT_START_DOCSTRING 是一个原始文档字符串,描述了一个 PyTorch 模型类 RemBert 的基本信息和用法建议
REMBERT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`RemBertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# REMBERT_INPUTS_DOCSTRING 是一个空白的文档字符串,用于描述模型的输入参数和示例,但当前为空
REMBERT_INPUTS_DOCSTRING = r"""
"""
        Args:
            input_ids (`torch.LongTensor` of shape `({0})`):
                # 输入序列标记在词汇表中的索引

                # 可以使用 `AutoTokenizer` 获取这些索引。参见 `PreTrainedTokenizer.encode` 和 `PreTrainedTokenizer.__call__` 获取详情。

                # [什么是输入 ID?](../glossary#input-ids)
            attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
                # 遮罩,避免在填充标记索引上进行注意力计算。遮罩的取值范围为 `[0, 1]`:

                # - 1 表示**未遮罩**的标记,
                # - 0 表示**已遮罩**的标记。

                # [什么是注意力遮罩?](../glossary#attention-mask)
            token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
                # 段标记索引,用于指示输入的第一部分和第二部分。索引取值范围为 `[0, 1]`:

                # - 0 对应**句子 A** 的标记,
                # - 1 对应**句子 B** 的标记。

                # [什么是标记类型 ID?](../glossary#token-type-ids)
            position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
                # 每个输入序列标记在位置嵌入中的位置索引。索引取值范围为 `[0, config.max_position_embeddings - 1]`。

                # [什么是位置 ID?](../glossary#position-ids)
            head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
                # 遮罩,用于屏蔽自注意力模块中选定的注意力头部。遮罩的取值范围为 `[0, 1]`:

                # - 1 表示**未遮罩**的头部,
                # - 0 表示**已遮罩**的头部。

            inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
                # 可选项,可以直接传递嵌入表示而不是 `input_ids`。这对于控制如何将 `input_ids` 索引转换为相关向量比模型内部的嵌入查找矩阵更有用。

            output_attentions (`bool`, *optional*):
                # 是否返回所有注意力层的注意力张量。更多细节请参见返回的张量中的 `attentions`。

            output_hidden_states (`bool`, *optional*):
                # 是否返回所有层的隐藏状态。更多细节请参见返回的张量中的 `hidden_states`。

            return_dict (`bool`, *optional*):
                # 是否返回 [`~utils.ModelOutput`] 而不是普通元组。
    """
    @add_start_docstrings(
        "The bare RemBERT Model transformer outputting raw hidden-states without any specific head on top.",
        REMBERT_START_DOCSTRING,
    )
    """
    class RemBertModel(RemBertPreTrainedModel):
        """
        The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
        cross-attention is added between the self-attention layers, following the architecture described in [Attention is
        all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
        Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

        To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
        to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
        `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
        """

        def __init__(self, config, add_pooling_layer=True):
            super().__init__(config)
            self.config = config

            # Initialize embeddings based on configuration
            self.embeddings = RemBertEmbeddings(config)
            # Initialize encoder based on configuration
            self.encoder = RemBertEncoder(config)

            # Optionally initialize a pooling layer based on configuration
            self.pooler = RemBertPooler(config) if add_pooling_layer else None

            # Initialize weights and apply final processing
            self.post_init()

        def get_input_embeddings(self):
            # Return the word embeddings from the embeddings layer
            return self.embeddings.word_embeddings

        def set_input_embeddings(self, value):
            # Set new word embeddings for the embeddings layer
            self.embeddings.word_embeddings = value

        def _prune_heads(self, heads_to_prune):
            """
            Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
            class PreTrainedModel
            """
            for layer, heads in heads_to_prune.items():
                # Prune specified heads in the attention layers of the encoder
                self.encoder.layer[layer].attention.prune_heads(heads)

        @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
        @add_code_sample_docstrings(
            checkpoint="google/rembert",
            output_type=BaseModelOutputWithPastAndCrossAttentions,
            config_class=_CONFIG_FOR_DOC,
        )
        def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.LongTensor] = None,
            token_type_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            ):
            """
            Forward pass for the RemBERT model.

            Args:
                input_ids: Indices of input sequence tokens in the vocabulary.
                attention_mask: Mask to avoid performing attention on padding token indices.
                token_type_ids: Segment token indices to indicate first and second portions of the inputs.
                position_ids: Indices of positions of each input sequence tokens in the position embeddings.
                head_mask: Mask to nullify selected heads of the attention modules.
                inputs_embeds: Overrides the model's base input word embeddings if provided.
                encoder_hidden_states: Hidden states of the encoder to feed into the cross-attention layer.
                encoder_attention_mask: Mask to avoid performing attention on encoder hidden states.
                past_key_values: Cached key-value pairs for fast autoregressive decoding.
                use_cache: Whether or not to use the past key-value caches.
                output_attentions: Whether or not to return attentions weights.
                output_hidden_states: Whether or not to return hidden states.
                return_dict: Whether or not to return a dictionary as output.

            Returns:
                BaseModelOutputWithPastAndCrossAttentions: Model output.

            Notes:
                Args above are based on REMBERT_INPUTS_DOCSTRING for batch size and sequence length.
            """
            # Actual implementation of the forward pass will follow here, specific to RemBERT's architecture and functionality
            pass
# 用装饰器添加文档字符串,描述这是一个在 `language modeling` 模型基础上加上头部的 RemBERT 模型
@add_start_docstrings("""RemBERT Model with a `language modeling` head on top.""", REMBERT_START_DOCSTRING)
# 定义一个 RemBertForMaskedLM 类,继承自 RemBertPreTrainedModel 类
class RemBertForMaskedLM(RemBertPreTrainedModel):
    # 定义一个类变量,指定共享权重的键值
    _tied_weights_keys = ["cls.predictions.decoder.weight"]

    # 初始化方法,接受一个 config 对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)

        # 如果 config 设置为 decoder 模式,则发出警告信息
        if config.is_decoder:
            logger.warning(
                "If you want to use `RemBertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # 创建一个 RemBertModel 对象,不添加池化层
        self.rembert = RemBertModel(config, add_pooling_layer=False)
        # 创建一个 RemBertOnlyMLMHead 对象
        self.cls = RemBertOnlyMLMHead(config)

        # 初始化权重并进行最终处理
        self.post_init()

    # 获取输出 embeddings 的方法,返回预测解码器的权重
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    # 设置输出 embeddings 的方法,更新预测解码器的权重
    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    # 重写 forward 方法,接受多个输入参数
    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="google/rembert",
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        # 如果 `return_dict` 参数为 None,则根据配置决定是否使用返回字典
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 使用输入参数调用 `rembert` 方法,获取输出结果
        outputs = self.rembert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从 `outputs` 中获取序列输出
        sequence_output = outputs[0]
        
        # 使用分类头部 `cls` 对序列输出进行预测
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            # 定义交叉熵损失函数,用于计算masked language modeling loss
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            # 计算预测分数和标签之间的损失
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # 如果 `return_dict` 为 False,返回结果元组
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        # 如果 `return_dict` 为 True,返回 MaskedLMOutput 对象
        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        # 获取输入 `input_ids` 的形状
        input_shape = input_ids.shape
        # 获取有效的批量大小
        effective_batch_size = input_shape[0]

        # 确保 PAD token 已定义用于生成
        assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
        # 将注意力掩码与新生成的零张量连接,以扩展序列长度
        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        # 创建一个填充了 PAD token 的虚拟令牌
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        # 在输入 `input_ids` 的末尾连接虚拟令牌
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        # 返回包含输入 `input_ids` 和注意力掩码的字典
        return {"input_ids": input_ids, "attention_mask": attention_mask}
# 使用装饰器添加文档字符串,描述了这个类的作用是在 CLM fine-tuning 上使用 RemBERT 模型,并带有语言建模头部
@add_start_docstrings(
    """RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING
)
# 定义 RemBertForCausalLM 类,继承自 RemBertPreTrainedModel 类
class RemBertForCausalLM(RemBertPreTrainedModel):
    # 类属性,指定权重共享的键名
    _tied_weights_keys = ["cls.predictions.decoder.weight"]

    # 初始化方法,接受一个 config 对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)

        # 如果配置中不是解码器,发出警告
        if not config.is_decoder:
            logger.warning("If you want to use `RemBertForCausalLM` as a standalone, add `is_decoder=True.`")

        # 初始化 RemBERT 模型,不添加池化层
        self.rembert = RemBertModel(config, add_pooling_layer=False)
        # 初始化仅包含 MLM 头部的类
        self.cls = RemBertOnlyMLMHead(config)

        # 调用后处理初始化方法
        self.post_init()

    # 返回输出嵌入层的方法
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    # 设置输出嵌入层的方法,接受新的嵌入层作为参数
    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    # 前向传播方法,接受多个输入参数,具体参数的作用通过装饰器和替换返回值的方式进行文档化
    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 准备生成输入的方法,接受输入 ID,过去键值,注意力掩码等参数
        def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
            # 获取输入张量的形状
            input_shape = input_ids.shape

            # 如果注意力掩码为空,创建一个与输入形状相同的全为 1 的张量
            if attention_mask is None:
                attention_mask = input_ids.new_ones(input_shape)

            # 如果传入了过去键值,裁剪输入 ID
            if past_key_values is not None:
                # 获取过去键值的长度
                past_length = past_key_values[0][0].shape[2]

                # 如果输入 ID 的长度大于过去键值的长度,移除前缀长度为过去键值的长度
                if input_ids.shape[1] > past_length:
                    remove_prefix_length = past_length
                else:
                    # 否则,默认只保留最后一个 ID
                    remove_prefix_length = input_ids.shape[1] - 1

                # 更新输入 ID
                input_ids = input_ids[:, remove_prefix_length:]

            # 返回包含输入 ID、注意力掩码和过去键值的字典
            return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
    # 重新排序缓存中的过去键值,以匹配新的束搜索索引顺序
    def _reorder_cache(self, past_key_values, beam_idx):
        # 初始化一个空元组来存储重新排序后的过去状态
        reordered_past = ()
        # 遍历每一层的过去状态
        for layer_past in past_key_values:
            # 对每个层的过去状态的前两项进行重新排序,根据给定的束搜索索引
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                # 将每层的第三项及其后的项保持不变地添加到重新排序后的元组中
                + layer_past[2:],
            )
        # 返回重新排序后的过去状态元组
        return reordered_past
# 使用装饰器添加文档字符串,描述了这个类是基于RemBERT模型的序列分类/回归模型,适用于GLUE任务等应用。
@add_start_docstrings(
    """
    RemBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    REMBERT_START_DOCSTRING,
)
class RemBertForSequenceClassification(RemBertPreTrainedModel):
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)
        # 设置分类器的类别数
        self.num_labels = config.num_labels
        # 初始化RemBERT模型
        self.rembert = RemBertModel(config)
        # Dropout层,用于防止过拟合
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        # 分类器,线性层,将RemBERT模型的输出映射到类别数量上
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并进行最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="google/rembert",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # 根据是否提供 return_dict 参数来确定是否返回字典类型的输出
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 RemBERT 模型进行前向传播
        outputs = self.rembert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取经过池化的输出
        pooled_output = outputs[1]

        # 对经过池化的输出进行 dropout 处理
        pooled_output = self.dropout(pooled_output)
        # 将 dropout 后的结果输入分类器进行分类
        logits = self.classifier(pooled_output)

        # 初始化损失值
        loss = None
        # 如果提供了标签,则计算损失值
        if labels is not None:
            # 根据配置和标签的数据类型确定问题类型
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # 根据问题类型选择合适的损失函数
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        # 如果不需要返回字典类型的输出,则返回元组
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 返回 SequenceClassifierOutput 对象,包括损失、logits、隐藏状态和注意力权重
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# 使用多项选择任务头部的 RemBERT 模型(在汇总输出上方添加了一个线性层和 softmax),例如适用于 RocStories/SWAG 任务
@add_start_docstrings(
    """
    RemBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    REMBERT_START_DOCSTRING,
)
class RemBertForMultipleChoice(RemBertPreTrainedModel):
    def __init__(self, config):
        # 调用父类构造函数,初始化 RemBERT 模型
        super().__init__(config)

        # 初始化 RemBERT 模型
        self.rembert = RemBertModel(config)
        # 添加一个 dropout 层
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        # 添加一个线性层,用于分类
        self.classifier = nn.Linear(config.hidden_size, 1)

        # 初始化权重并应用最终处理
        self.post_init()

    # 为前向传播方法添加文档字符串注释
    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    # 添加代码示例的文档字符串注释
    @add_code_sample_docstrings(
        checkpoint="google/rembert",
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # 前向传播方法
    def forward(
        self,
        input_ids: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # 根据 `return_dict` 参数确定是否返回字典类型的结果
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 获取输入的选项数量,即每个样本的选择数目
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # 将输入数据展平成二维张量,以便适应模型输入要求
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        # 调用模型的前向传播方法,获取输出结果
        outputs = self.rembert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 提取汇聚后的输出表示
        pooled_output = outputs[1]

        # 对汇聚后的输出进行dropout处理
        pooled_output = self.dropout(pooled_output)
        # 使用分类器对处理后的输出进行分类预测
        logits = self.classifier(pooled_output)
        # 将预测的 logits 重塑成(batch_size, num_choices)形状
        reshaped_logits = logits.view(-1, num_choices)

        # 初始化损失值为None
        loss = None
        # 如果有提供标签,则计算交叉熵损失
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        # 如果不要求返回字典类型的结果,则将输出整理成元组形式返回
        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 如果要求返回字典类型的结果,则创建MultipleChoiceModelOutput对象返回
        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
RemBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
"""
# 继承自RemBertPreTrainedModel的RemBertForTokenClassification类,用于在RemBERT模型上添加一个用于标记分类的头部
class RemBertForTokenClassification(RemBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # 初始化RemBERT模型,不添加池化层
        self.rembert = RemBertModel(config, add_pooling_layer=False)
        # Dropout层,用于防止过拟合
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        # 分类器线性层,将隐藏状态输出映射到标签数量的空间
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并应用最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="google/rembert",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # 前向传播函数,接受多个输入参数,并返回模型的输出或损失
    def forward(
        self,
        input_ids: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 使用RemBERT模型进行前向传播
        outputs = self.rembert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取序列输出
        sequence_output = outputs[0]

        # 应用Dropout层
        sequence_output = self.dropout(sequence_output)
        # 使用分类器线性层计算logits
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # 计算交叉熵损失
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # 如果不使用return_dict,则返回元组形式的输出
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 如果使用return_dict,则返回TokenClassifierOutput对象
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    RemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    REMBERT_START_DOCSTRING,



# 定义 RemBERT 模型,用于抽取式问答任务(如 SQuAD),其在隐藏状态输出之上带有一个用于计算“span start logits”和“span end logits”的线性分类头部。
# REMBERT_START_DOCSTRING 是一个预定义的文档字符串常量,可能包含 RemBERT 模型的详细描述或指导。
# 定义一个继承自 RemBertPreTrainedModel 的问题回答模型 RemBertForQuestionAnswering 类
class RemBertForQuestionAnswering(RemBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # 初始化时设置类别数目
        self.num_labels = config.num_labels

        # 使用 RemBertModel 创建一个 RemBert 对象,不添加池化层
        self.rembert = RemBertModel(config, add_pooling_layer=False)
        
        # 使用 nn.Linear 初始化一个线性层,用于生成问题回答的输出
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并进行最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="google/rembert",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # 定义模型的前向传播方法,接受多种输入参数并返回预测结果
    def forward(
        self,
        input_ids: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
    
        # Determine if we should return a dictionary based on the provided argument or default configuration
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    
        # Pass inputs through the RoBERTa model
        outputs = self.rembert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    
        # Extract the sequence output from RoBERTa model outputs
        sequence_output = outputs[0]
    
        # Compute logits for start and end positions from the sequence output
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
    
        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If the start_positions and end_positions tensors have more than one dimension, squeeze them
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            
            # Clamp the start_positions and end_positions to valid ranges within the sequence length
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)
    
            # Define the loss function and compute start and end loss
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            
            # Compute total loss as the average of start and end loss
            total_loss = (start_loss + end_loss) / 2
    
        # If return_dict is False, return outputs as tuple
        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output
    
        # If return_dict is True, return structured output using QuestionAnsweringModelOutput class
        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
posted @ 2024-06-29 16:57  绝不原创的飞龙  阅读(6)  评论(0编辑  收藏  举报