Transformers-源码解析-十四-

Transformers 源码解析(十四)

.\models\bartpho\__init__.py

# 版权声明和许可信息,声明版权归 HuggingFace 团队所有,授权遵循 Apache License 2.0
#
# 导入必要的类型检查模块
from typing import TYPE_CHECKING

# 导入必要的依赖和模块,包括自定义的异常 OptionalDependencyNotAvailable,_LazyModule,以及检查 SentencePiece 是否可用的函数 is_sentencepiece_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available

# 定义一个空的导入结构字典
_import_structure = {}

# 尝试检查 SentencePiece 是否可用,若不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_sentencepiece_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # 如果检测到 OptionalDependencyNotAvailable 异常,则忽略并继续执行
    pass
else:
    # 如果 SentencePiece 可用,则添加 BartphoTokenizer 到导入结构字典
    _import_structure["tokenization_bartpho"] = ["BartphoTokenizer"]

# 如果正在进行类型检查(Type Checking)
if TYPE_CHECKING:
    try:
        # 再次检查 SentencePiece 是否可用
        if not is_sentencepiece_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 如果不可用,则忽略异常
        pass
    else:
        # 如果可用,从 tokenization_bartpho 模块中导入 BartphoTokenizer 类
        from .tokenization_bartpho import BartphoTokenizer

# 如果不是在类型检查模式下运行
else:
    # 导入 sys 模块
    import sys

    # 将当前模块替换为 LazyModule,延迟加载导入结构字典中的内容
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\beit\configuration_beit.py

# coding=utf-8
# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" BEiT model configuration"""

from collections import OrderedDict  # 导入OrderedDict类,用于创建有序字典
from typing import Mapping  # 导入Mapping类型,用于类型提示

from packaging import version  # 导入version模块,用于版本处理

from ...configuration_utils import PretrainedConfig  # 导入预训练配置类
from ...onnx import OnnxConfig  # 导入Onnx配置类
from ...utils import logging  # 导入日志工具模块
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices  # 导入背骨网络工具和特征对齐索引获取函数

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "microsoft/beit-base-patch16-224-pt22k": (
        "https://huggingface.co/microsoft/beit-base-patch16-224-pt22k/resolve/main/config.json"
    ),
    # See all BEiT models at https://huggingface.co/models?filter=beit
}

class BeitConfig(BackboneConfigMixin, PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BeitModel`]. It is used to instantiate an BEiT
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the BEiT
    [microsoft/beit-base-patch16-224-pt22k](https://huggingface.co/microsoft/beit-base-patch16-224-pt22k) architecture.

    Example:

    ```
    >>> from transformers import BeitConfig, BeitModel

    >>> # Initializing a BEiT beit-base-patch16-224-pt22k style configuration
    >>> configuration = BeitConfig()

    >>> # Initializing a model (with random weights) from the beit-base-patch16-224-pt22k style configuration
    >>> model = BeitModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "beit"  # 设置模型类型为 "beit"
    # 初始化函数,用于创建一个新的模型实例
    def __init__(
        self,
        vocab_size=8192,  # 词汇表大小,默认为8192
        hidden_size=768,  # 隐藏层大小,默认为768
        num_hidden_layers=12,  # 隐藏层的数量,默认为12
        num_attention_heads=12,  # 注意力头的数量,默认为12
        intermediate_size=3072,  # 中间层大小,默认为3072
        hidden_act="gelu",  # 隐藏层激活函数,默认为GELU
        hidden_dropout_prob=0.0,  # 隐藏层的dropout概率,默认为0.0
        attention_probs_dropout_prob=0.0,  # 注意力概率的dropout概率,默认为0.0
        initializer_range=0.02,  # 初始化范围,默认为0.02
        layer_norm_eps=1e-12,  # 层归一化的epsilon值,默认为1e-12
        image_size=224,  # 图像大小,默认为224
        patch_size=16,  # 补丁大小,默认为16
        num_channels=3,  # 图像通道数,默认为3
        use_mask_token=False,  # 是否使用mask token,默认为False
        use_absolute_position_embeddings=False,  # 是否使用绝对位置嵌入,默认为False
        use_relative_position_bias=False,  # 是否使用相对位置偏置,默认为False
        use_shared_relative_position_bias=False,  # 是否共享相对位置偏置,默认为False
        layer_scale_init_value=0.1,  # 层缩放初始化值,默认为0.1
        drop_path_rate=0.1,  # drop path的概率,默认为0.1
        use_mean_pooling=True,  # 是否使用均值池化,默认为True
        pool_scales=[1, 2, 3, 6],  # 池化尺度列表,默认为[1, 2, 3, 6]
        use_auxiliary_head=True,  # 是否使用辅助头,默认为True
        auxiliary_loss_weight=0.4,  # 辅助损失权重,默认为0.4
        auxiliary_channels=256,  # 辅助头的通道数,默认为256
        auxiliary_num_convs=1,  # 辅助头的卷积层数,默认为1
        auxiliary_concat_input=False,  # 辅助头是否将输入进行拼接,默认为False
        semantic_loss_ignore_index=255,  # 语义损失忽略的索引,默认为255
        out_features=None,  # 输出特征,默认为None
        out_indices=None,  # 输出索引,默认为None
        add_fpn=False,  # 是否添加特征金字塔网络,默认为False
        reshape_hidden_states=True,  # 是否重塑隐藏状态,默认为True
        **kwargs,  # 其他关键字参数
        ):
            super().__init__(**kwargs)
    
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.initializer_range = initializer_range
            self.layer_norm_eps = layer_norm_eps
    
            self.image_size = image_size
            self.patch_size = patch_size
            self.num_channels = num_channels
            self.use_mask_token = use_mask_token
            self.use_absolute_position_embeddings = use_absolute_position_embeddings
            self.use_relative_position_bias = use_relative_position_bias
            self.use_shared_relative_position_bias = use_shared_relative_position_bias
            self.layer_scale_init_value = layer_scale_init_value
            self.drop_path_rate = drop_path_rate
            self.use_mean_pooling = use_mean_pooling
            # decode head attributes (semantic segmentation)
            self.pool_scales = pool_scales
            # auxiliary head attributes (semantic segmentation)
            self.use_auxiliary_head = use_auxiliary_head
            self.auxiliary_loss_weight = auxiliary_loss_weight
            self.auxiliary_channels = auxiliary_channels
            self.auxiliary_num_convs = auxiliary_num_convs
            self.auxiliary_concat_input = auxiliary_concat_input
            self.semantic_loss_ignore_index = semantic_loss_ignore_index
    
            # handle backwards compatibility
            如果传入参数中包含"segmentation_indices",发出警告,建议使用"out_indices"代替
            if "segmentation_indices" in kwargs:
                logger.warning(
                    "The `segmentation_indices` argument is deprecated and will be removed in a future version, use `out_indices` instead.",
                    FutureWarning,
                )
                将"segmentation_indices"参数从kwargs中移除
                out_indices = kwargs.pop("segmentation_indices")
    
            # backbone attributes
            构建阶段名称列表,从"stem"开始,然后是每个隐藏层的阶段名
            self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, self.num_hidden_layers + 1)]
            根据输出特征和输出索引,以及阶段名称,获取对齐的输出特征和输出索引
            self._out_features, self._out_indices = get_aligned_output_features_output_indices(
                out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
            )
            是否添加特征金字塔网络(FPN)
            self.add_fpn = add_fpn
            是否重新整形隐藏状态
            self.reshape_hidden_states = reshape_hidden_states
# 从transformers.models.vit.configuration_vit.ViTOnnxConfig复制而来的类定义,继承自OnnxConfig类
class BeitOnnxConfig(OnnxConfig):
    # 设定torch_onnx_minimum_version属性为1.11版本
    torch_onnx_minimum_version = version.parse("1.11")

    # inputs属性的getter方法,返回一个有序字典,描述了输入数据的索引映射关系
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                # 指定输入名称为"pixel_values",并定义其维度索引映射关系
                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
            ]
        )

    # atol_for_validation属性的getter方法,返回一个浮点数,指定验证时的绝对误差限制
    @property
    def atol_for_validation(self) -> float:
        return 1e-4

.\models\beit\convert_beit_unilm_to_pytorch.py

# 设置编码格式为 UTF-8

# 版权声明和许可证信息
# 版权所有 2021 年的 HuggingFace Inc. 团队。
# 根据 Apache 许可证 2.0 版本进行许可;
# 除非符合许可证,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则分发的软件
# 基于“原样”分发,不提供任何形式的保证或条件。
# 请查阅许可证了解具体的法律条文和限制。

"""从 unilm 代码库转换 BEiT 检查点。"""

import argparse  # 导入命令行参数解析模块
import json  # 导入 JSON 操作模块
from pathlib import Path  # 导入路径操作模块

import requests  # 导入 HTTP 请求模块
import torch  # 导入 PyTorch 深度学习框架
from datasets import load_dataset  # 导入数据集加载模块
from huggingface_hub import hf_hub_download  # 导入 HuggingFace Hub 模型下载工具
from PIL import Image  # 导入图像处理库 PIL

from transformers import (  # 导入 transformers 库中的多个类
    BeitConfig,  # BEiT 模型配置类
    BeitForImageClassification,  # 用于图像分类的 BEiT 模型类
    BeitForMaskedImageModeling,  # 用于图像修复的 BEiT 模型类
    BeitForSemanticSegmentation,  # 用于语义分割的 BEiT 模型类
    BeitImageProcessor,  # BEiT 模型的图像处理器类
)
from transformers.image_utils import PILImageResampling  # 导入图像重采样函数
from transformers.utils import logging  # 导入 transformers 的日志记录模块

logging.set_verbosity_info()  # 设置日志记录级别为信息
logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象


# 这里列出所有需要重命名的键(左边是原始名称,右边是我们的名称)
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
    prefix = "backbone." if is_semantic else ""  # 如果是语义模型,则前缀为 "backbone."
    
    rename_keys = []  # 创建一个空列表用于存储重命名键值对
    for i in range(config.num_hidden_layers):
        # 编码器层:输出投影、两个前馈神经网络和两个层归一化
        rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
        rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
        rename_keys.append(
            (f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
        )
        rename_keys.append(
            (f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
        )
        rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
        rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))

    # 投影层 + 位置嵌入
    # 将以下键值对添加到 rename_keys 列表中,用于重命名模型中的特定参数路径
    rename_keys.extend(
        [
            # 将 "{prefix}cls_token" 改为 "beit.embeddings.cls_token"
            (f"{prefix}cls_token", "beit.embeddings.cls_token"),
            # 将 "{prefix}patch_embed.proj.weight" 改为 "beit.embeddings.patch_embeddings.projection.weight"
            (f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
            # 将 "{prefix}patch_embed.proj.bias" 改为 "beit.embeddings.patch_embeddings.projection.bias"
            (f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
        ]
    )
    
    if has_lm_head:
        # 如果模型包含语言模型头部,则添加以下键值对到 rename_keys 列表中
        rename_keys.extend(
            [
                # 将 "mask_token" 改为 "beit.embeddings.mask_token"
                ("mask_token", "beit.embeddings.mask_token"),
                # 将 "rel_pos_bias.relative_position_bias_table" 改为 "beit.encoder.relative_position_bias.relative_position_bias_table"
                ("rel_pos_bias.relative_position_bias_table", "beit.encoder.relative_position_bias.relative_position_bias_table"),
                # 将 "rel_pos_bias.relative_position_index" 改为 "beit.encoder.relative_position_bias.relative_position_index"
                ("rel_pos_bias.relative_position_index", "beit.encoder.relative_position_bias.relative_position_index"),
                # 将 "norm.weight" 改为 "layernorm.weight"
                ("norm.weight", "layernorm.weight"),
                # 将 "norm.bias" 改为 "layernorm.bias"
                ("norm.bias", "layernorm.bias"),
            ]
        )
    elif is_semantic:
        # 如果模型是语义分割模型,则添加以下键值对到 rename_keys 列表中
        rename_keys.extend(
            [
                # 将 "decode_head.conv_seg.weight" 改为 "decode_head.classifier.weight"
                ("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
                # 将 "decode_head.conv_seg.bias" 改为 "decode_head.classifier.bias"
                ("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
                # 将 "auxiliary_head.conv_seg.weight" 改为 "auxiliary_head.classifier.weight"
                ("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
                # 将 "auxiliary_head.conv_seg.bias" 改为 "auxiliary_head.classifier.bias"
                ("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
            ]
        )
    else:
        # 如果以上条件都不满足,则添加以下键值对到 rename_keys 列表中
        rename_keys.extend(
            [
                # 将 "fc_norm.weight" 改为 "beit.pooler.layernorm.weight"
                ("fc_norm.weight", "beit.pooler.layernorm.weight"),
                # 将 "fc_norm.bias" 改为 "beit.pooler.layernorm.bias"
                ("fc_norm.bias", "beit.pooler.layernorm.bias"),
                # 将 "head.weight" 改为 "classifier.weight"
                ("head.weight", "classifier.weight"),
                # 将 "head.bias" 改为 "classifier.bias"
                ("head.bias", "classifier.bias"),
            ]
        )
    
    return rename_keys
# 将每个编码器层的矩阵拆分为查询(queries)、键(keys)和值(values)
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
    # 遍历每个隐藏层
    for i in range(config.num_hidden_layers):
        # 如果是语义模型,则使用特定的前缀
        prefix = "backbone." if is_semantic else ""

        # 从状态字典中弹出查询、键、值的权重矩阵
        in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
        q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
        v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")

        # 将查询矩阵权重和偏置添加到 BEiT 模型的状态字典中
        state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
            : config.hidden_size, :
        ]
        state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
        # 将键矩阵权重添加到 BEiT 模型的状态字典中
        state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
            config.hidden_size : config.hidden_size * 2, :
        ]
        # 将值矩阵权重和偏置添加到 BEiT 模型的状态字典中
        state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
            -config.hidden_size :, :
        ]
        state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias

        # 弹出并重命名 gamma_1 和 gamma_2 为 lambda_1 和 lambda_2,以防止在 .from_pretrained 方法中被重命名
        gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
        gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
        state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
        state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2

        # 如果模型没有语言模型头部,则处理相对位置偏置表和索引
        if not has_lm_head:
            # 每个层级都有自己的相对位置偏置表和索引
            table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
            index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
            # 将相对位置偏置表和索引添加到 BEiT 模型的状态字典中
            state_dict[
                f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
            ] = table
            state_dict[
                f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
            ] = index


# 重命名状态字典中的键
def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


# 我们将在一张可爱猫咪的图片上验证我们的结果
def prepare_img():
    # 图片链接
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # 使用 requests 获取图片流,并打开为 PIL 图像对象
    im = Image.open(requests.get(url, stream=True).raw)
    return im


# 使用无梯度计算上下文环境,将检查点转换为 BEiT 结构
@torch.no_grad()
def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
    """
    复制/粘贴/调整模型的权重到我们的 BEiT 结构。
    """

    # 定义默认的 BEiT 配置
    config = BeitConfig()
    has_lm_head = False
    is_semantic = False
    repo_id = "huggingface/label-files"
    # 根据 URL 设置配置参数
    if checkpoint_url[-9:-4] == "pt22k":
        # 使用共享的相对位置偏置表和遮蔽标记
        config.use_shared_relative_position_bias = True
        config.use_mask_token = True
        has_lm_head = True
    elif checkpoint_url[-9:-4] == "ft22k":
        # 对ImageNet-22k进行中间微调
        config.use_relative_position_bias = True
        config.num_labels = 21841
        filename = "imagenet-22k-id2label.json"
        # 从指定的HF Hub下载数据集文件,加载ID到标签的映射关系
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        # 该数据集包含21843个标签,但模型只有21841个,因此删除不需要的类别
        del id2label[9205]
        del id2label[15027]
        config.id2label = id2label
        # 构建标签到ID的反向映射
        config.label2id = {v: k for k, v in id2label.items()}
    elif checkpoint_url[-8:-4] == "to1k":
        # 对ImageNet-1k进行微调
        config.use_relative_position_bias = True
        config.num_labels = 1000
        filename = "imagenet-1k-id2label.json"
        # 从指定的HF Hub下载数据集文件,加载ID到标签的映射关系
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        # 构建标签到ID的反向映射
        config.label2id = {v: k for k, v in id2label.items()}
        # 根据URL中的尺寸信息设置图像大小
        if "384" in checkpoint_url:
            config.image_size = 384
        if "512" in checkpoint_url:
            config.image_size = 512
    elif "ade20k" in checkpoint_url:
        # 对ADE20K数据集进行微调
        config.use_relative_position_bias = True
        config.num_labels = 150
        filename = "ade20k-id2label.json"
        # 从指定的HF Hub下载数据集文件,加载ID到标签的映射关系
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        # 构建标签到ID的反向映射
        config.label2id = {v: k for k, v in id2label.items()}
        # 设置图像大小为640,并标记为语义分割任务
        config.image_size = 640
        is_semantic = True
    else:
        raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")

    # 架构的尺寸设置
    if "base" in checkpoint_url:
        pass
    elif "large" in checkpoint_url:
        # 设置大型模型的隐藏层大小、中间层大小、隐藏层层数和注意力头数
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16
        # 如果是ADE20K数据集,设置特定的图像大小和输出索引
        if "ade20k" in checkpoint_url:
            config.image_size = 640
            config.out_indices = [7, 11, 15, 23]
    else:
        raise ValueError("Should either find 'base' or 'large' in checkpoint URL")

    # 加载原始模型的state_dict,并移除/重命名部分键
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
    # 如果不是ADE20K数据集,只加载"model"部分,否则加载"state_dict"部分
    state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]

    # 创建重命名键列表并应用到state_dict
    rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    # 读取QKV(查询、键、值)的信息并应用到state_dict
    read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
    # 如果是语义分割模型
    if is_semantic:
        # 对状态字典中的键添加前缀
        for key, val in state_dict.copy().items():
            val = state_dict.pop(key)
            if key.startswith("backbone.fpn"):
                key = key.replace("backbone.fpn", "fpn")
            state_dict[key] = val

    # 加载 HuggingFace 模型
    if checkpoint_url[-9:-4] == "pt22k":
        # 根据 URL 后缀选择合适的模型:MaskedImageModeling
        model = BeitForMaskedImageModeling(config)
    elif "ade20k" in checkpoint_url:
        # 如果 URL 中包含 "ade20k",选择语义分割模型
        model = BeitForSemanticSegmentation(config)
    else:
        # 默认选择图像分类模型
        model = BeitForImageClassification(config)
    model.eval()  # 将模型设置为评估模式
    model.load_state_dict(state_dict)  # 加载状态字典到模型中

    # 根据是否是语义分割选择图像处理器和图像
    if is_semantic:
        # 创建语义分割图像处理器
        image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
        # 加载测试数据集中的图像
        ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
        image = Image.open(ds[0]["file"])
    else:
        # 创建图像处理器,设置图像大小和重采样方式
        image_processor = BeitImageProcessor(
            size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False
        )
        # 准备图像
        image = prepare_img()

    # 对图像进行编码,返回编码结果
    encoding = image_processor(images=image, return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    # 使用模型进行推理,得到输出
    outputs = model(pixel_values)
    logits = outputs.logits

    # 验证输出 logits 的形状是否符合预期
    expected_shape = torch.Size([1, 1000])  # 默认情况下 logits 形状为 [1, 1000]
    if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"):
        expected_shape = torch.Size([1, 196, 8192])  # 特定模型的预期 logits 形状
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"):
        expected_shape = torch.Size([1, 196, 8192])  # 特定模型的预期 logits 形状
    elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"):
        expected_shape = torch.Size([1, 21841])  # 特定模型的预期 logits 形状
        expected_logits = torch.tensor([2.2288, 2.4671, 0.7395])  # 预期的 logits 值
        expected_class_idx = 2397  # 预期的类别索引
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"):
        expected_shape = torch.Size([1, 21841])  # 特定模型的预期 logits 形状
        expected_logits = torch.tensor([1.6881, -0.2787, 0.5901])  # 预期的 logits 值
        expected_class_idx = 2396  # 预期的类别索引
    elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"):
        expected_logits = torch.tensor([0.1241, 0.0798, -0.6569])  # 预期的 logits 值
        expected_class_idx = 285  # 预期的类别索引
    elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"):
        expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108])  # 预期的 logits 值
        expected_class_idx = 281  # 预期的类别索引
    elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"):
        expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147])  # 预期的 logits 值
        expected_class_idx = 761  # 预期的类别索引
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"):
        expected_logits = torch.tensor([0.4610, -0.0928, 0.2086])  # 预期的 logits 值
        expected_class_idx = 761  # 预期的类别索引
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"):
        expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837])  # 预期的 logits 值
        expected_class_idx = 761  # 预期的类别索引
    elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"):
        # 设置预期的模型输出日志和类别索引,用于后续验证
        expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]])
        expected_class_idx = 761
    elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
        # 设置预期的模型输出日志和类别索引,用于后续验证
        expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
        expected_class_idx = 761
    elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
        # 设置预期的模型输出形状和日志,用于后续验证
        expected_shape = (1, 150, 160, 160)
        expected_logits = torch.tensor(
            [
                [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
                [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
                [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
            ]
        )
    elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
        # 设置预期的模型输出形状和日志,用于后续验证
        expected_shape = (1, 150, 160, 160)
        expected_logits = torch.tensor(
            [
                [[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
                [[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
                [[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
            ]
        )
    else:
        # 如果不是支持的模型类型,则引发错误
        raise ValueError("Can't verify logits as model is not supported")

    if logits.shape != expected_shape:
        # 检查模型输出的形状是否符合预期
        raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
    if not has_lm_head:
        if is_semantic:
            # 如果是语义任务,检查模型输出的前几个元素是否与预期的日志值接近
            if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
                raise ValueError("First elements of logits not as expected")
        else:
            # 如果不是语义任务,打印预测的类别索引并检查模型输出的前几个元素是否与预期的日志值接近
            print("Predicted class idx:", logits.argmax(-1).item())

            if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
                raise ValueError("First elements of logits not as expected")
            if logits.argmax(-1).item() != expected_class_idx:
                raise ValueError("Predicted class index not as expected")

    # 创建保存模型的文件夹(如果不存在)
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    # 保存模型到指定路径
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving image processor to {pytorch_dump_folder_path}")
    # 保存图像处理器到指定路径
    image_processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
    # 如果当前脚本被直接执行,则执行以下代码块

    parser = argparse.ArgumentParser()
    # 创建一个参数解析器对象

    parser.add_argument(
        "--checkpoint_url",
        default="https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth",
        type=str,
        help="URL to the original PyTorch checkpoint (.pth file).",
    )
    # 添加名为--checkpoint_url的命令行参数,设置默认值和类型,并提供帮助信息

    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
    )
    # 添加名为--pytorch_dump_folder_path的命令行参数,设置默认值和类型,并提供帮助信息

    args = parser.parse_args()
    # 解析命令行参数,并将其存储在args对象中

    convert_beit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
    # 调用convert_beit_checkpoint函数,传递命令行参数中的checkpoint_url和pytorch_dump_folder_path作为参数

.\models\beit\feature_extraction_beit.py

# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
# 2021年HuggingFace团队。版权所有。
#
# 根据Apache许可证2.0版(“许可证”)许可;您不得使用此文件,除非符合许可证的规定。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件基于“原样”提供,不提供任何明示或暗示的担保或条件。
# 有关详细信息,请参阅许可证。
"""BEiT的特征提取器类。"""

# 导入警告模块
import warnings

# 导入日志记录工具
from ...utils import logging
# 导入BEiT图像处理器类
from .image_processing_beit import BeitImageProcessor

# 获取日志记录器
logger = logging.get_logger(__name__)

# 定义BEiT特征提取器类,继承自BeitImageProcessor类
class BeitFeatureExtractor(BeitImageProcessor):
    def __init__(self, *args, **kwargs) -> None:
        # 发出警告,提示BeitFeatureExtractor类即将在Transformers版本5中删除,请使用BeitImageProcessor代替
        warnings.warn(
            "The class BeitFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
            " use BeitImageProcessor instead.",
            FutureWarning,
        )
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)

.\models\beit\image_processing_beit.py

# 设置文件编码为 UTF-8
# 版权声明,指出代码版权归 HuggingFace Inc. 团队所有
#
# 根据 Apache 许可证 2.0 版本,只有在遵循许可证的情况下才能使用此文件
# 可以在以下网址获取许可证的副本:http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面同意,否则按“现状”提供软件,不附带任何明示或暗示的担保或条件。
# 有关许可证详细信息,请参见许可证文本。

"""Beit 的图像处理类。"""

# 导入警告模块
import warnings
# 导入类型提示模块
from typing import Any, Dict, List, Optional, Tuple, Union

# 导入 NumPy 库
import numpy as np

# 导入图像处理工具类和函数
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
# 导入图像变换函数
from ...image_transforms import resize, to_channel_dimension_format
# 导入图像处理工具函数
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,  # 导入常用的图像均值
    IMAGENET_STANDARD_STD,   # 导入常用的图像标准差
    ChannelDimension,        # 导入通道维度类
    ImageInput,              # 导入图像输入类
    PILImageResampling,      # 导入 PIL 图像重采样方法枚举
    infer_channel_dimension_format,  # 推断通道维度格式函数
    is_scaled_image,         # 判断图像是否经过缩放函数
    make_list_of_images,     # 将图像转换为图像列表函数
    to_numpy_array,          # 将图像转换为 NumPy 数组函数
    valid_images,            # 检验有效图像函数
    validate_kwargs,         # 验证关键字参数函数
    validate_preprocess_arguments,  # 验证预处理参数函数
)
# 导入通用工具函数和类型
from ...utils import TensorType, is_torch_available, is_torch_tensor, is_vision_available, logging

# 如果 PyTorch 可用,导入 PyTorch 模块
if is_torch_available():
    import torch

# 获取日志记录器
logger = logging.get_logger(__name__)

# 定义 BeitImageProcessor 类,继承自 BaseImageProcessor 类
class BeitImageProcessor(BaseImageProcessor):
    r"""
    构建 BEiT 图像处理器。

    """

    # 模型输入名称列表,仅包含像素值
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,                   # 是否进行调整大小的标志
        size: Dict[str, int] = None,              # 图像大小的字典,包含宽和高
        resample: PILImageResampling = PILImageResampling.BICUBIC,  # PIL 图像重采样方法
        do_center_crop: bool = True,              # 是否进行中心裁剪的标志
        crop_size: Dict[str, int] = None,         # 裁剪尺寸的字典,包含宽和高
        rescale_factor: Union[int, float] = 1 / 255,  # 图像缩放因子
        do_rescale: bool = True,                  # 是否进行图像缩放的标志
        do_normalize: bool = True,                # 是否进行图像标准化的标志
        image_mean: Optional[Union[float, List[float]]] = None,  # 图像均值
        image_std: Optional[Union[float, List[float]]] = None,   # 图像标准差
        do_reduce_labels: bool = False,           # 是否减少标签的标志
        **kwargs,                                 # 其他关键字参数
    ):
        # 调用父类的构造函数
        super().__init__(**kwargs)
    ) -> None:
        # 如果 kwargs 中包含 "reduce_labels" 参数,则发出警告,并将其值赋给 do_reduce_labels
        if "reduce_labels" in kwargs:
            warnings.warn(
                "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use"
                " `do_reduce_labels` instead.",
                FutureWarning,
            )
            do_reduce_labels = kwargs.pop("reduce_labels")
        # 调用父类的初始化方法,传入所有的 kwargs
        super().__init__(**kwargs)
        # 设置 size 变量,如果未指定则使用默认值 {"height": 256, "width": 256}
        size = size if size is not None else {"height": 256, "width": 256}
        # 调用 get_size_dict 函数,确保 size 是一个字典
        size = get_size_dict(size)
        # 设置 crop_size 变量,如果未指定则使用默认值 {"height": 224, "width": 224}
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        # 调用 get_size_dict 函数,确保 crop_size 是一个字典,参数名为 "crop_size"
        crop_size = get_size_dict(crop_size, param_name="crop_size")
        # 设置对象的成员变量
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # 设置对象的成员变量,如果未指定 image_mean 则使用 IMAGENET_STANDARD_MEAN
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        # 设置对象的成员变量,如果未指定 image_std 则使用 IMAGENET_STANDARD_STD
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        # 设置对象的成员变量 do_reduce_labels
        self.do_reduce_labels = do_reduce_labels
        # 设置对象的成员变量 _valid_processor_keys,包含所有可能的处理器参数键名
        self._valid_processor_keys = [
            "images",
            "segmentation_maps",
            "do_resize",
            "size",
            "resample",
            "do_center_crop",
            "crop_size",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "image_mean",
            "image_std",
            "do_reduce_labels",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]

    @classmethod
    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
        """
        Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor
        is created using from_dict and kwargs e.g. `BeitImageProcessor.from_pretrained(checkpoint, reduce_labels=True)`
        """
        # 复制 image_processor_dict,确保原始字典不受影响
        image_processor_dict = image_processor_dict.copy()
        # 如果 kwargs 中包含 "reduce_labels" 参数,则将其值更新到 image_processor_dict 中
        if "reduce_labels" in kwargs:
            image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
        # 调用父类的 from_dict 方法,传入更新后的 image_processor_dict 和其他 kwargs
        return super().from_dict(image_processor_dict, **kwargs)

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    def _preprocess(
        self,
        image: ImageInput,
        do_reduce_labels: bool = None,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocesses an image based on specified operations.

        Args:
            image (`ImageInput`):
                The input image to be preprocessed.
            do_reduce_labels (`bool`, optional):
                Whether to reduce labels using `reduce_label` method.
            do_resize (`bool`, optional):
                Whether to resize the image.
            size (`Dict[str, int]`, optional):
                Target size (height and width) for resizing.
            resample (`PILImageResampling`, optional):
                Resampling filter for resizing the image.
            do_center_crop (`bool`, optional):
                Whether to perform center cropping.
            crop_size (`Dict[str, int]`, optional):
                Size for center cropping (height and width).
            do_rescale (`bool`, optional):
                Whether to rescale the image.
            rescale_factor (`float`, optional):
                Factor for rescaling the image.
            do_normalize (`bool`, optional):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, optional):
                Mean values for normalizing the image.
            image_std (`float` or `List[float]`, optional):
                Standard deviation values for normalizing the image.
            input_data_format (`str` or `ChannelDimension`, optional):
                Format of the input image data.

        Returns:
            `np.ndarray`: Preprocessed image based on the specified operations.
        """
        if do_reduce_labels:
            # Reduce label values using the `reduce_label` method
            image = self.reduce_label(image)

        if do_resize:
            # Resize the image using specified size and resampling filter
            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

        if do_center_crop:
            # Perform center cropping on the image
            image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

        if do_rescale:
            # Rescale the image using the specified factor
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        if do_normalize:
            # Normalize the image using mean and standard deviation
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        return image
    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Preprocesses a single image."""
        # 转换输入图像为 numpy 数组
        image = to_numpy_array(image)
        
        # 如果输入图像已经进行了缩放且设置了 do_rescale=True,则发出警告
        if is_scaled_image(image) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        
        # 推断输入数据格式的通道维度
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)
        
        # 调用 _preprocess 方法,对图像进行预处理
        image = self._preprocess(
            image,
            do_reduce_labels=False,
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            input_data_format=input_data_format,
        )
        
        # 如果指定了 data_format,将图像转换为指定的通道维度格式
        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
        
        # 返回预处理后的图像数组
        return image

    def _preprocess_segmentation_map(
        self,
        segmentation_map: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_reduce_labels: bool = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocesses a single segmentation map.

        """
        # All transformations expect numpy arrays.
        segmentation_map = to_numpy_array(segmentation_map)
        # Add an axis to the segmentation maps for transformations.
        if segmentation_map.ndim == 2:
            segmentation_map = segmentation_map[None, ...]
            added_dimension = True
            input_data_format = ChannelDimension.FIRST
        else:
            added_dimension = False
            # If input_data_format is not specified, infer it based on the segmentation map.
            if input_data_format is None:
                input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
        segmentation_map = self._preprocess(
            image=segmentation_map,
            do_reduce_labels=do_reduce_labels,
            do_resize=do_resize,
            resample=resample,
            size=size,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_normalize=False,
            do_rescale=False,
            input_data_format=ChannelDimension.FIRST,
        )
        # Remove extra axis if added
        if added_dimension:
            segmentation_map = np.squeeze(segmentation_map, axis=0)
        segmentation_map = segmentation_map.astype(np.int64)
        return segmentation_map

    def __call__(self, images, segmentation_maps=None, **kwargs):
        """
        Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
        be passed in as positional arguments.
        """
        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_reduce_labels: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
        ):
        """
        Handles preprocessing of images and segmentation maps with various options for transformations and adjustments.

        """
    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
        """
        Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

        Args:
            outputs ([`BeitForSemanticSegmentation`]):
                Raw outputs of the model.
            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
                predictions will not be resized.

        Returns:
            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
            specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
        """
        # TODO: add support for other frameworks
        
        # Extract logits from the model outputs
        logits = outputs.logits

        # Resize logits and compute semantic segmentation maps if target_sizes is provided
        if target_sizes is not None:
            # Check if the number of logits matches the number of target sizes
            if len(logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

            # Convert target_sizes to numpy array if it's a torch tensor
            if is_torch_tensor(target_sizes):
                target_sizes = target_sizes.numpy()

            # Initialize an empty list for storing semantic segmentation maps
            semantic_segmentation = []

            # Iterate over each element in logits and perform interpolation
            for idx in range(len(logits)):
                # Resize logits using bilinear interpolation
                resized_logits = torch.nn.functional.interpolate(
                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
                )
                # Compute the semantic map by taking the argmax along the channel dimension
                semantic_map = resized_logits[0].argmax(dim=0)
                # Append the computed semantic map to the list
                semantic_segmentation.append(semantic_map)
        else:
            # Compute semantic segmentation by taking the argmax over the channel dimension
            semantic_segmentation = logits.argmax(dim=1)
            # Convert the result to a list of tensors
            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]

        # Return the list of semantic segmentation maps
        return semantic_segmentation

.\models\beit\modeling_beit.py

# 设置文件编码为 UTF-8
# 版权声明 2021 Microsoft Research 和 HuggingFace Inc. 团队,保留所有权利
#
# 根据 Apache 许可证 2.0 版本使用本文件,除非符合许可证的条款,否则不得使用本文件
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则依据"原样"分发本软件,
# 没有任何形式的担保或条件,包括但不限于对适销性或特定用途的适用性的保证。
# 有关详细信息,请参阅许可证。
""" PyTorch BEiT 模型。"""

# 导入必要的库和模块
import collections.abc  # 引入 collections.abc 模块
import math  # 引入 math 模块
from dataclasses import dataclass  # 从 dataclasses 模块导入 dataclass 装饰器
from typing import List, Optional, Tuple, Union  # 导入类型提示

import torch  # 导入 PyTorch 库
import torch.utils.checkpoint  # 导入 PyTorch 的 checkpoint 模块
from torch import Tensor, nn  # 从 PyTorch 导入 Tensor 和 nn(神经网络)模块
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss  # 从 nn 模块导入损失函数

# 导入其他需要的类和函数
from ...activations import ACT2FN  # 从 activations 模块导入 ACT2FN 激活函数
from ...modeling_outputs import (  # 导入模型输出相关的类
    BackboneOutput,
    BaseModelOutput,
    BaseModelOutputWithPooling,
    ImageClassifierOutput,
    MaskedLMOutput,
    SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel  # 从 modeling_utils 模块导入 PreTrainedModel 类
from ...pytorch_utils import (  # 导入 PyTorch 工具函数
    find_pruneable_heads_and_indices,
    meshgrid,
    prune_linear_layer,
)
from ...utils import (  # 导入工具函数和类
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin  # 从 backbone_utils 模块导入 BackboneMixin 类
from .configuration_beit import BeitConfig  # 导入 BEiT 模型的配置类

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

# 概述文件的一般用途
_CONFIG_FOR_DOC = "BeitConfig"

# 基础说明文档
_CHECKPOINT_FOR_DOC = "microsoft/beit-base-patch16-224-pt22k"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]

# 图像分类的说明文档
_IMAGE_CLASS_CHECKPOINT = "microsoft/beit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"

BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "microsoft/beit-base-patch16-224",
    # 查看所有 BEiT 模型的列表 https://huggingface.co/models?filter=beit
]

@dataclass
class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
    """
    [`BeitModel`] 的输出类。
    """
    pass  # 占位符,表示类目前不包含额外的属性或方法,继承自 BaseModelOutputWithPooling 类
    # 接收模型最后一层的隐藏状态,形状为 `(batch_size, sequence_length, hidden_size)` 的张量
    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
    
    # 如果 *config.use_mean_pooling* 设置为 True,则返回补丁标记的最后一层隐藏状态的平均值(不包括 *[CLS]* 标记)。
    # 如果设置为 False,则返回 *[CLS]* 标记的最终隐藏状态。
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
    
    # 可选参数,当 `output_hidden_states=True` 时返回,或者当 `config.output_hidden_states=True` 时返回。
    # 是一个元组,包含 `torch.FloatTensor` 类型的张量:
    #   - 一个是嵌入层的输出
    #   - 其余每一层的输出,形状为 `(batch_size, sequence_length, hidden_size)`
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
    
    # 可选参数,当 `output_attentions=True` 时返回,或者当 `config.output_attentions=True` 时返回。
    # 是一个元组,包含 `torch.FloatTensor` 类型的张量:
    #   - 每一层的注意力权重,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`
    #   这些权重经过注意力 softmax 后得到,用于计算自注意力头中的加权平均值。
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
# 定义一个函数,用于在模型训练时对输入的张量进行路径丢弃(随机深度),通常应用于残差块的主路径。
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        # 如果丢弃概率为0或者当前不处于训练状态,直接返回输入张量
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # 适用于不同维度张量,而不仅仅是2D卷积网络
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # 将随机张量二值化
    output = input.div(keep_prob) * random_tensor
    return output


class BeitDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 调用上面定义的drop_path函数来处理输入的隐藏状态张量
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        # 返回当前DropPath模块的额外表示,包括当前的丢弃概率
        return "p={}".format(self.drop_prob)


# 基于timm实现,可以在此找到:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class BeitEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.

    """

    def __init__(self, config: BeitConfig) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        if config.use_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        else:
            self.mask_token = None
        self.patch_embeddings = BeitPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        if config.use_absolute_position_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        else:
            self.position_embeddings = None
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
    # 定义前向传播方法,接受像素值张量和可选的掩码位置张量,返回处理后的张量
    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
        # 使用 patch_embeddings 方法处理像素值张量,得到嵌入向量和嵌入坐标
        embeddings, (patch_height, patch_width) = self.patch_embeddings(
            pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
        )
        # 获取批次大小、序列长度和嵌入向量的维度
        batch_size, seq_len, _ = embeddings.size()

        # 如果存在掩码位置张量
        if bool_masked_pos is not None:
            # 将掩码位置标记的视觉标记替换为掩码标记
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1 - w) + mask_tokens * w

        # 将 cls_token 扩展到与批次大小和嵌入向量维度相匹配
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        # 如果存在位置嵌入,则将其加到 cls_token 上
        if self.position_embeddings is not None:
            cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]

        # 在序列的开头连接 cls_token 和 embeddings
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # 对 embeddings 应用 dropout
        embeddings = self.dropout(embeddings)

        # 返回处理后的 embeddings 和 patch 的高度、宽度信息
        return embeddings, (patch_height, patch_width)
# 定义一个用于将像素值转换成初始隐藏状态(即补丁嵌入)的类,以便Transformer模型使用。
class BeitPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        # 从配置中获取图像大小和补丁大小
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        # 确保图像大小和补丁大小是可迭代的对象,如果不是则转换为元组
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        
        # 计算补丁数量和补丁形状
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
        
        # 初始化对象的属性
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.patch_shape = patch_shape

        # 使用卷积层进行投影,将输入的通道数转换为隐藏状态的大小
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
        # 获取输入像素值的维度信息
        batch_size, num_channels, height, width = pixel_values.shape
        # 检查输入的通道数是否与配置中设置的通道数一致
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )

        # 使用投影层对输入像素值进行投影,得到嵌入表示
        embeddings = self.projection(pixel_values)
        # 获取投影后的补丁高度和宽度
        patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]

        if position_embedding is not None:
            # 插值位置嵌入到相应的大小
            position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(
                0, 3, 1, 2
            )
            position_embedding = nn.functional.interpolate(
                position_embedding, size=(patch_height, patch_width), mode="bicubic"
            )
            # 将位置嵌入加到投影后的嵌入中
            embeddings = embeddings + position_embedding

        # 将嵌入表示展平,并交换维度顺序以符合Transformer的输入格式
        embeddings = embeddings.flatten(2).transpose(1, 2)

        return embeddings, (patch_height, patch_width)
    def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
        super().__init__()
        # 检查隐藏层大小是否能被注意力头数整除,同时没有嵌入大小属性
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        # 初始化注意力头数和每个头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 初始化查询、键、值的线性变换层
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # 初始化用于随机失活的 Dropout 层
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        # 如果指定了窗口大小,初始化相对位置偏置层
        if window_size:
            self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)
        else:
            self.relative_position_bias = None

    def transpose_for_scores(self, x):
        # 调整张量形状以便进行多头注意力计算
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relative_position_bias: Optional["BeitRelativePositionBias"] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        # 通过调用 self.query 方法生成混合查询向量 mixed_query_layer
        mixed_query_layer = self.query(hidden_states)

        # 使用 self.key 方法生成键向量 key_layer,并通过 transpose_for_scores 方法转置以备注意力计算使用
        key_layer = self.transpose_for_scores(self.key(hidden_states))

        # 使用 self.value 方法生成值向量 value_layer,并通过 transpose_for_scores 方法转置以备注意力计算使用
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # 再次调用 transpose_for_scores 方法转置 mixed_query_layer 以备注意力计算使用
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # 计算注意力分数,采用 query_layer 和 key_layer 的点积
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # 缩放注意力分数,除以 sqrt(attention_head_size)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # 如果存在相对位置偏置,将其加入注意力分数中
        if self.relative_position_bias is not None:
            attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0)

        # 如果给定了 shared relative position bias,也将其加入注意力分数中
        if relative_position_bias is not None:
            attention_scores = attention_scores + relative_position_bias

        # 将注意力分数归一化为概率分布
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # 使用 dropout 方法对注意力概率进行随机失活处理
        attention_probs = self.dropout(attention_probs)

        # 如果给定了 head_mask,将其应用到 attention_probs 上
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # 计算加权后的值向量,得到上下文向量 context_layer
        context_layer = torch.matmul(attention_probs, value_layer)

        # 对 context_layer 进行维度重排,以符合后续计算要求
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # 根据输出设置返回结果,包括 context_layer 和 attention_probs(如果需要)
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
class BeitSelfOutput(nn.Module):
    """
    The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: BeitConfig) -> None:
        super().__init__()
        # Linear transformation for the output of self-attention
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Dropout layer for regularization
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor:
        # Linear transformation
        hidden_states = self.dense(hidden_states)
        # Apply dropout
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class BeitAttention(nn.Module):
    def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
        super().__init__()
        # Self-attention mechanism
        self.attention = BeitSelfAttention(config, window_size=window_size)
        # Output layer after attention
        self.output = BeitSelfOutput(config)
        # Set of pruned attention heads
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        # Find and prune attention heads based on indices
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers for attention components
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update number of attention heads and related sizes, and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relative_position_bias: Optional["BeitRelativePositionBias"] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        # Perform self-attention
        self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias)

        # Output of attention passed through output layer
        attention_output = self.output(self_outputs[0], hidden_states)

        # Collect outputs, including attention matrices if requested
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class BeitIntermediate(nn.Module):
    def __init__(self, config: BeitConfig) -> None:
        super().__init__()
        # Intermediate dense layer
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Activation function for intermediate layer
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
    # 定义一个方法 `forward`,用于前向传播计算
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 将输入的隐藏状态 `hidden_states` 经过全连接层 `dense` 处理
        hidden_states = self.dense(hidden_states)
        # 对处理后的隐藏状态应用激活函数 `intermediate_act_fn`
        hidden_states = self.intermediate_act_fn(hidden_states)

        # 返回处理后的隐藏状态作为输出
        return hidden_states
class BeitOutput(nn.Module):
    def __init__(self, config: BeitConfig) -> None:
        super().__init__()
        # 创建一个全连接层,将输入特征的维度缩放为隐藏层大小
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # 定义一个dropout层,用于随机置零输入张量的一些元素,以减少过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 将输入张量传入全连接层,执行线性变换
        hidden_states = self.dense(hidden_states)
        # 对全连接层的输出执行dropout操作
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class BeitLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0) -> None:
        super().__init__()
        # 设置用于分块feed forward的块大小
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # 序列长度维度
        self.seq_len_dim = 1
        # 使用给定配置和窗口大小创建注意力机制
        self.attention = BeitAttention(config, window_size=window_size)
        # 创建中间层对象,将输入特征映射到隐藏层大小
        self.intermediate = BeitIntermediate(config)
        # 创建输出层对象,将中间层的输出映射到最终的隐藏层大小
        self.output = BeitOutput(config)
        # 应用LayerNorm在隐藏层上,以归一化特征
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 根据dropout路径率初始化drop path对象,如果路径率大于0
        self.drop_path = BeitDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        # 再次应用LayerNorm在隐藏层上,以归一化特征
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # 初始化lambda参数,如果初始值大于0,则创建可学习的参数张量
        init_values = config.layer_scale_init_value
        if init_values > 0:
            self.lambda_1 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
            self.lambda_2 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
        else:
            self.lambda_1, self.lambda_2 = None, None

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relative_position_bias: Optional["BeitRelativePositionBias"] = None,
        ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        # 使用 self.attention 对 hidden_states 应用自注意力机制
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # 在 BEiT 模型中,先对 hidden_states 应用 layernorm
            head_mask,
            output_attentions=output_attentions,
            relative_position_bias=relative_position_bias,
        )
        # 获取自注意力输出
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # 如果需要输出注意力权重,则添加自注意力结果

        # 如果定义了 lambda_1,则对 attention_output 应用缩放
        if self.lambda_1 is not None:
            attention_output = self.lambda_1 * attention_output

        # 第一个残差连接
        hidden_states = self.drop_path(attention_output) + hidden_states

        # 在 BEiT 中,还会在自注意力后应用 layernorm
        layer_output = self.layernorm_after(hidden_states)

        # 经过中间层和输出层
        layer_output = self.intermediate(layer_output)
        layer_output = self.output(layer_output)

        # 如果定义了 lambda_2,则对 layer_output 应用缩放
        if self.lambda_2 is not None:
            layer_output = self.lambda_2 * layer_output

        # 第二个残差连接
        layer_output = self.drop_path(layer_output) + hidden_states

        # 整合最终输出
        outputs = (layer_output,) + outputs

        return outputs
class BeitRelativePositionBias(nn.Module):
    def __init__(self, config: BeitConfig, window_size: tuple) -> None:
        super().__init__()
        self.window_size = window_size
        # 计算相对位置偏置表的大小
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        # 创建一个可学习的参数,用于存储相对位置偏置表,大小为 num_relative_distance x num_attention_heads
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, config.num_attention_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        # 用于描述cls到token、token到cls、cls到cls之间的相对位置关系

        # 获取每个窗口内每个token之间的pair-wise相对位置索引
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # 将坐标向左移动
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(
            size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
        )
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        # 将相对位置索引注册为非参数化缓冲区
        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

    def forward(self) -> torch.Tensor:
        # 根据相对位置索引从相对位置偏置表中获取相对位置偏置
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
        )  # Wh*Ww,Wh*Ww,nH

        # 返回维度变换后的相对位置偏置,维度顺序为 nH, Wh*Ww, Wh*Ww
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
        super().__init__()
        self.config = config
        # 如果配置中使用共享的相对位置偏置,则创建相对位置偏置对象
        if config.use_shared_relative_position_bias:
            self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)
        else:
            self.relative_position_bias = None

        # 根据随机深度衰减规则生成每个层的衰减率列表
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
        # 创建神经网络层的列表,每层使用不同的衰减率和配置
        self.layer = nn.ModuleList(
            [
                BeitLayer(
                    config,
                    window_size=window_size if config.use_relative_position_bias else None,
                    drop_path_rate=dpr[i],
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        # 梯度检查点功能默认关闭
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        # 初始化空元组以保存所有隐藏状态和注意力分数
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # 遍历每个神经网络层进行前向传播
        for i, layer_module in enumerate(self.layer):
            # 如果需要记录隐藏状态,则将当前隐藏状态添加到列表中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 获取当前层的头部掩码
            layer_head_mask = head_mask[i] if head_mask is not None else None

            # 如果启用梯度检查点且在训练模式下,则使用梯度检查点函数调用当前层
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                # 获取相对位置偏置(如果存在)并传递给当前层
                relative_position_bias = (
                    self.relative_position_bias() if self.relative_position_bias is not None else None
                )
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias)

            # 更新隐藏状态为当前层的输出
            hidden_states = layer_outputs[0]

            # 如果需要输出注意力分数,则将当前层的注意力分数添加到列表中
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # 如果需要记录隐藏状态,则将最终隐藏状态添加到列表中
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 根据返回类型决定输出格式
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
BEIT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`BeitConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
    # "The bare Beit Model transformer outputting raw hidden-states without any specific head on top."
    # BEIT_START_DOCSTRING,
    )
class BeitModel(BeitPreTrainedModel):
    def __init__(self, config: BeitConfig, add_pooling_layer: bool = True) -> None:
        super().__init__(config)
        self.config = config

        # 初始化嵌入层和编码器
        self.embeddings = BeitEmbeddings(config)
        self.encoder = BeitEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape)

        # 根据配置选择性地添加层归一化或池化层
        self.layernorm = (
            nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        )
        self.pooler = BeitPooler(config) if add_pooling_layer else None

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        # 返回嵌入层的补丁嵌入
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # 遍历需要修剪的层和头部,并在注意力机制中执行修剪
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BeitModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BeitModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        # Determine whether to return attentions, hidden states, etc., based on input or default config
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Validate input: pixel_values must be specified
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicates that the head is kept active during attention
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # head_mask is reshaped to [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # Embedding process: computes embeddings from pixel values and masked positions
        embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos)

        # Encoder block: applies transformer encoding to embedding output
        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        # Extract sequence output from encoder output and normalize using layer normalization
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        # Pooler layer: computes pooled output if pooler is defined
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        # Return different outputs based on return_dict flag
        if not return_dict:
            # Return tuple of sequence output and pooled output (if available) along with other encoder outputs
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        # Return structured output using BeitModelOutputWithPooling
        return BeitModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
# 定义一个名为 `BeitPooler` 的神经网络模块,用于对 BEiT 模型的隐藏状态进行池化操作
class BeitPooler(nn.Module):
    def __init__(self, config: BeitConfig) -> None:
        super().__init__()
        # 如果配置要求使用均值池化,则使用 LayerNorm 对隐藏状态进行归一化
        self.layernorm = (
            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.layernorm is not None:
            # 如果存在 LayerNorm 对象,则对补丁令牌的最终隐藏状态进行均值池化
            patch_tokens = hidden_states[:, 1:, :]  # 选择除了第一个令牌外的所有令牌的隐藏状态
            pooled_output = self.layernorm(patch_tokens.mean(1))  # 对补丁令牌的隐藏状态进行均值池化并归一化
        else:
            # 否则,通过简单地使用 [CLS] 令牌的最终隐藏状态进行池化
            pooled_output = hidden_states[:, 0]  # 选择 [CLS] 令牌的最终隐藏状态作为池化输出

        return pooled_output


@add_start_docstrings(
    """Beit Model transformer with a 'language' modeling head on top. BEiT does masked image modeling by predicting
    visual tokens of a Vector-Quantize Variational Autoencoder (VQ-VAE), whereas other vision models like ViT and DeiT
    predict RGB pixel values. As a result, this class is incompatible with [`AutoModelForMaskedImageModeling`], so you
    will need to use [`BeitForMaskedImageModeling`] directly if you wish to do masked image modeling with BEiT.""",
    BEIT_START_DOCSTRING,
)
# 定义一个带有语言建模头部的 BEiT 模型变压器。BEiT 通过预测矢量量化变分自动编码器(VQ-VAE)的视觉令牌来进行遮罩图像建模,而像 ViT 和 DeiT 这样的其他视觉模型预测 RGB 像素值。因此,此类与 [`AutoModelForMaskedImageModeling`] 不兼容,如果要使用 BEiT 进行遮罩图像建模,您需要直接使用 [`BeitForMaskedImageModeling`]。
class BeitForMaskedImageModeling(BeitPreTrainedModel):
    def __init__(self, config: BeitConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.beit = BeitModel(config, add_pooling_layer=False)  # 初始化 BEiT 模型,不添加池化层

        # 分类器头部
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # 使用 LayerNorm 对隐藏状态进行归一化
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)  # 线性层用于语言模型的预测

        # 初始化权重并应用最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    # 重写 `forward` 方法,用于模型的前向传播
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,  # 像素值,可选输入
        bool_masked_pos: Optional[torch.BoolTensor] = None,  # 遮罩位置的布尔张量,可选输入
        head_mask: Optional[torch.Tensor] = None,  # 头部遮罩,可选输入
        labels: Optional[torch.Tensor] = None,  # 标签,可选输入
        output_attentions: Optional[bool] = None,  # 是否输出注意力权重,可选输入
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态,可选输入
        return_dict: Optional[bool] = None,  # 是否返回字典格式的输出,可选输入
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 如果 return_dict 不为 None,则使用其值;否则使用 self.config.use_return_dict 的值

        outputs = self.beit(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 调用 self.beit 方法,传入像素数值 pixel_values 和其他参数,根据 return_dict 是否为真决定是否返回字典形式的输出

        sequence_output = outputs[0]
        # 从模型输出中取得序列输出

        sequence_output = self.layernorm(sequence_output)
        # 对序列输出进行 layer normalization 处理

        prediction_scores = self.lm_head(sequence_output[:, 1:])
        # 使用 lm_head 对序列输出的部分进行预测评分,通常是用来生成模型的输出结果

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # 定义交叉熵损失函数,用于计算损失
            masked_lm_loss = loss_fct(prediction_scores[bool_masked_pos], labels)
            # 如果给定了标签 labels,则计算被遮蔽位置 bool_masked_pos 的预测结果与标签之间的损失

        if not return_dict:
            output = (prediction_scores,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            # 如果不要求返回字典形式的输出,则返回预测分数和其他附加输出

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        # 如果需要返回字典形式的输出,则返回一个 MaskedLMOutput 对象,包含损失、预测分数、隐藏状态和注意力权重信息
@add_start_docstrings(
    """
    Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
    hidden states of the patch tokens) e.g. for ImageNet.
    """,
    BEIT_START_DOCSTRING,
)
class BeitForImageClassification(BeitPreTrainedModel):
    def __init__(self, config: BeitConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        # 使用配置初始化 BeitModel,添加池化层以便用于分类任务
        self.beit = BeitModel(config, add_pooling_layer=True)

        # 分类器头部,根据配置的隐藏层大小和标签数量初始化线性层或者恒等映射
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # 初始化权重并应用最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 模型前向传播方法,接收像素值、头部掩码、标签等参数,返回模型输出
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # 根据需要决定是否返回字典形式的输出
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 使用 BEiT 模型进行推理
        outputs = self.beit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 根据返回值是否为字典形式,选择 pooled_output
        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        # 使用分类器对 pooled_output 进行分类得到 logits
        logits = self.classifier(pooled_output)

        # 初始化损失为 None
        loss = None
        # 如果有标签输入
        if labels is not None:
            # 如果问题类型未指定,则根据标签类型和数量确定问题类型
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # 根据问题类型选择损失函数和计算损失
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        # 如果不使用字典形式返回结果,则组装输出并返回
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 如果使用字典形式返回结果,则创建 ImageClassifierOutput 对象并返回
        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class BeitConvModule(nn.Module):
    """
    A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
    layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        padding: Union[int, Tuple[int, int], str] = 0,
        bias: bool = False,
        dilation: Union[int, Tuple[int, int]] = 1,
    ) -> None:
        super().__init__()
        # 定义卷积层,设置输入输出通道数、核大小、填充、是否有偏置、扩张率等参数
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
            dilation=dilation,
        )
        # 定义批归一化层,设置输出通道数
        self.bn = nn.BatchNorm2d(out_channels)
        # 定义激活函数层为ReLU
        self.activation = nn.ReLU()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # 前向传播函数,依次经过卷积、批归一化和ReLU激活
        output = self.conv(input)
        output = self.bn(output)
        output = self.activation(output)

        return output


class BeitPyramidPoolingBlock(nn.Module):
    def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
        super().__init__()
        # 创建自适应平均池化层和BeitConvModule卷积模块,并将其作为列表存储在self.layers中
        self.layers = [
            nn.AdaptiveAvgPool2d(pool_scale),
            BeitConvModule(in_channels, channels, kernel_size=1),
        ]
        # 将每个层添加为模块
        for i, layer in enumerate(self.layers):
            self.add_module(str(i), layer)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        hidden_state = input
        # 依次对输入数据应用self.layers中的每个层,并返回最终的隐藏状态
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


class BeitPyramidPoolingModule(nn.Module):
    """
    Pyramid Pooling Module (PPM) used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        align_corners (bool): align_corners argument of F.interpolate.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.blocks = []
        # 根据给定的pool_scales创建多个BeitPyramidPoolingBlock模块,并添加为子模块
        for i, pool_scale in enumerate(pool_scales):
            block = BeitPyramidPoolingBlock(pool_scale=pool_scale, in_channels=in_channels, channels=channels)
            self.blocks.append(block)
            self.add_module(str(i), block)
    # 定义前向传播方法,接受一个张量 x 作为输入,并返回一个张量列表
    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        # 初始化一个空列表,用于存储各个 PPM 模块的输出张量
        ppm_outs = []
        # 遍历 self.blocks 中的每个 PPM 模块
        for ppm in self.blocks:
            # 对输入 x 应用当前的 PPM 模块,得到该模块的输出张量 ppm_out
            ppm_out = ppm(x)
            # 使用双线性插值方法将 ppm_out 上采样到与输入 x 相同的大小
            upsampled_ppm_out = nn.functional.interpolate(
                ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
            )
            # 将上采样后的 ppm_out 添加到 ppm_outs 列表中
            ppm_outs.append(upsampled_ppm_out)
        # 返回所有 PPM 模块的输出张量组成的列表 ppm_outs
        return ppm_outs
# 定义一个名为 `BeitUperHead` 的类,继承自 `nn.Module`,用于实现场景理解的统一感知解析。
class BeitUperHead(nn.Module):
    """
    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
    [UPerNet](https://arxiv.org/abs/1807.10221).

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    # 初始化方法,接收一个 `BeitConfig` 类型的配置参数 `config`
    def __init__(self, config: BeitConfig) -> None:
        super().__init__()

        # 设置池化尺度,例如 (1, 2, 3, 6)
        self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
        # 设置输入通道数列表,全为 `config.hidden_size`,例如 [768, 768, 768, 768]
        self.in_channels = [config.hidden_size] * 4  # e.g. [768, 768, 768, 768]
        # 设置通道数为 `config.hidden_size`
        self.channels = config.hidden_size
        # 是否对齐角点,默认为 False
        self.align_corners = False
        # 分类器,使用 1x1 卷积将通道数从 `self.channels` 转换为 `config.num_labels`
        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

        # PSP Module,使用 `BeitPyramidPoolingModule` 初始化池化模块
        self.psp_modules = BeitPyramidPoolingModule(
            self.pool_scales,
            self.in_channels[-1],  # 最后一个输入通道数
            self.channels,
            align_corners=self.align_corners,
        )
        # 瓶颈模块,使用 `BeitConvModule` 初始化卷积模块
        self.bottleneck = BeitConvModule(
            self.in_channels[-1] + len(self.pool_scales) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )
        
        # FPN Module,构建特征金字塔网络模块
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        # 遍历除了顶层之外的所有输入通道数
        for in_channels in self.in_channels[:-1]:  # skip the top layer
            # 使用 `BeitConvModule` 初始化侧边卷积模块
            l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1)
            # 使用 `BeitConvModule` 初始化金字塔卷积模块
            fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
            self.lateral_convs.append(l_conv)  # 添加到侧边卷积模块列表
            self.fpn_convs.append(fpn_conv)    # 添加到金字塔卷积模块列表

        # FPN 瓶颈模块,使用 `BeitConvModule` 初始化卷积模块
        self.fpn_bottleneck = BeitConvModule(
            len(self.in_channels) * self.channels,  # 所有输入通道数的总和
            self.channels,
            kernel_size=3,
            padding=1,
        )

    # PSP 前向传播方法,接收输入 `inputs`,返回处理后的输出
    def psp_forward(self, inputs):
        x = inputs[-1]  # 取输入列表的最后一个元素作为输入
        psp_outs = [x]  # 初始化 PSP 输出列表
        psp_outs.extend(self.psp_modules(x))  # 将 PSP 模块的输出扩展到 PSP 输出列表
        psp_outs = torch.cat(psp_outs, dim=1)  # 在通道维度上连接 PSP 输出
        output = self.bottleneck(psp_outs)  # 使用瓶颈模块处理 PSP 输出

        return output  # 返回处理后的输出
    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        # 构建侧边连接
        laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]

        # 将PSP模块的输出添加到侧边连接列表中
        laterals.append(self.psp_forward(encoder_hidden_states))

        # 构建自顶向下路径
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
            )

        # 构建FPN输出
        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
        
        # 将PSP特征追加到FPN输出列表中
        fpn_outs.append(laterals[-1])

        # 对FPN输出进行上采样
        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = nn.functional.interpolate(
                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
            )

        # 在通道维度上连接所有FPN输出
        fpn_outs = torch.cat(fpn_outs, dim=1)

        # 经过FPN瓶颈层处理
        output = self.fpn_bottleneck(fpn_outs)

        # 使用分类器处理最终输出
        output = self.classifier(output)

        # 返回最终结果
        return output
# 定义一个用于语义分割的头部模块,基于 Fully Convolution Networks(FCN)的设计。
# 详见论文 [FCNNet](https://arxiv.org/abs/1411.4038>)。
class BeitFCNHead(nn.Module):
    """
    Fully Convolution Networks for Semantic Segmentation. This head is implemented of
    [FCNNet](https://arxiv.org/abs/1411.4038>).

    Args:
        config (BeitConfig): Configuration.
        in_index (int): Index of the encoder hidden state to use as input. Default: 2.
        kernel_size (int): The kernel size for convolutions in the head. Default: 3.
        dilation (int or tuple): The dilation rate for convolutions in the head. Default: 1.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(
        self, config: BeitConfig, in_index: int = 2, kernel_size: int = 3, dilation: Union[int, Tuple[int, int]] = 1
    ) -> None:
        super().__init__()
        # 初始化头部模块的参数
        self.in_channels = config.hidden_size  # 输入通道数等于隐藏状态的大小
        self.channels = config.auxiliary_channels  # 辅助通道数
        self.num_convs = config.auxiliary_num_convs  # 卷积层的数量
        self.concat_input = config.auxiliary_concat_input  # 是否将输入与卷积输出拼接的标志
        self.in_index = in_index  # 输入隐藏状态的索引位置

        # 计算卷积层的填充大小
        conv_padding = (kernel_size // 2) * dilation
        convs = []
        
        # 添加第一个卷积模块
        convs.append(
            BeitConvModule(
                self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
            )
        )
        
        # 添加剩余的卷积模块
        for i in range(self.num_convs - 1):
            convs.append(
                BeitConvModule(
                    self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
                )
            )
        
        # 如果没有卷积层,则使用 nn.Identity 作为卷积层
        if self.num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        
        # 如果设置了拼接输入标志,则创建用于拼接的卷积模块
        if self.concat_input:
            self.conv_cat = BeitConvModule(
                self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
            )
        
        # 分类器,最终输出的通道数为配置中的标签数量
        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        # 从编码器隐藏状态中取出指定的特征图
        hidden_states = encoder_hidden_states[self.in_index]
        
        # 经过卷积层处理
        output = self.convs(hidden_states)
        
        # 如果设置了拼接输入标志,则将原始输入与卷积输出拼接后再进行处理
        if self.concat_input:
            output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
        
        # 最后经过分类器输出结果
        output = self.classifier(output)
        return output


@add_start_docstrings(
    """
    Beit Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
    """,
    BEIT_START_DOCSTRING,
)
class BeitForSemanticSegmentation(BeitPreTrainedModel):
    """
    Beit Model transformer with a semantic segmentation head for tasks like ADE20k, CityScapes.

    Inherits from BeitPreTrainedModel, which is the base class for all Beit models.
    """
    def __init__(self, config: BeitConfig) -> None:
        # 调用父类的初始化方法,传入配置对象
        super().__init__(config)

        # 从配置对象中获取标签数量
        self.num_labels = config.num_labels
        # 创建一个 BEiT 模型对象,不添加池化层
        self.beit = BeitModel(config, add_pooling_layer=False)

        # FPNs
        # 检查配置中的输出索引是否为四个整数,否则抛出数值错误异常
        if len(self.config.out_indices) != 4:
            raise ValueError(
                "BeitForSemanticSegmentation requires config.out_indices to be a list of 4 integers, "
                "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
                "a base-sized architecture."
            )
        # 定义语义分割头部网络的几个转置卷积操作序列
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
            nn.BatchNorm2d(config.hidden_size),
            nn.GELU(),
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
        )
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
        )
        self.fpn3 = nn.Identity()  # 直接返回输入的恒等映射
        self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)  # 最大池化操作,核大小为2x2

        # 语义分割头部网络
        self.decode_head = BeitUperHead(config)  # 创建解码头部对象
        self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None  # 如果配置中启用辅助头部,则创建辅助头部对象,否则为 None

        # 初始化权重并应用最终处理
        self.post_init()

    def compute_loss(self, logits, auxiliary_logits, labels):
        # 将 logits 上采样到原始图像大小
        upsampled_logits = nn.functional.interpolate(
            logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
        )
        if auxiliary_logits is not None:
            upsampled_auxiliary_logits = nn.functional.interpolate(
                auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
        # 计算加权损失
        loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
        main_loss = loss_fct(upsampled_logits, labels)
        loss = main_loss
        if auxiliary_logits is not None:
            auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
            loss += self.config.auxiliary_loss_weight * auxiliary_loss

        return loss

    @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
# 使用 add_start_docstrings 装饰器添加 BEiT 的背景说明文档,用于与 DETR 和 MaskFormer 等框架集成
@add_start_docstrings(
    """
    BEiT backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    BEIT_START_DOCSTRING,
)
# 定义 BEiT 的骨干网络类,继承自 BeitPreTrainedModel 和 BackboneMixin
class BeitBackbone(BeitPreTrainedModel, BackboneMixin):
    # 初始化函数,接受一个 config 对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)
        # 调用 BackboneMixin 类的初始化方法,初始化骨干网络
        super()._init_backbone(config)

        # 根据配置设置特征的数量为隐藏层大小的列表
        self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
        # 初始化嵌入层
        self.embeddings = BeitEmbeddings(config)
        # 初始化编码器,并传递嵌入层的窗口大小作为参数
        self.encoder = BeitEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape)

        # 如果配置中指定要添加 FPN
        if config.add_fpn:
            # 检查配置中的输出索引列表是否包含四个整数
            if len(self.config.out_indices) != 4:
                # 如果不是,则抛出数值错误异常
                raise ValueError(
                    "BeitBackbone requires config.out_indices to be a list of 4 integers, "
                    "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
                    "a base-sized architecture."
                )
            # 获取隐藏层大小
            hidden_size = config.hidden_size
            # 初始化 FPN1,包括两个转置卷积层和批归一化层,使用 GELU 激活函数
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2),
                nn.BatchNorm2d(hidden_size, eps=config.batch_norm_eps),
                nn.GELU(),
                nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2),
            )

            # 初始化 FPN2,包括一个转置卷积层
            self.fpn2 = nn.Sequential(nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2))
            # 初始化 FPN3,为恒等映射层
            self.fpn3 = nn.Identity()
            # 初始化 FPN4,为最大池化层
            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 初始化权重并应用最终处理
        self.post_init()

    # 获取输入嵌入层的方法
    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    # 重写 forward 方法,接受像素值张量和可选的输出隐藏状态、注意力和返回字典作为参数
    @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        """
        如果 return_dict 不为 None,则使用其值;否则使用 self.config.use_return_dict 的值作为返回结果的字典选择
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        """
        如果 output_hidden_states 不为 None,则使用其值;否则使用 self.config.output_hidden_states 的值作为输出隐藏状态的选择
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        """
        如果 output_attentions 不为 None,则使用其值;否则使用 self.config.output_attentions 的值作为输出注意力的选择
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        """
        获取输入像素值的批次大小
        """
        batch_size = pixel_values.shape[0]
        """
        使用 self.embeddings 处理像素值,得到嵌入输出和每个补丁的高度和宽度
        """
        embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values)

        """
        使用 self.encoder 处理嵌入输出,设置输出隐藏状态和注意力的选择,根据 return_dict 决定是否返回字典
        """
        outputs = self.encoder(
            embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict
        )

        """
        如果 return_dict 为 True,则将隐藏状态存储在 outputs.hidden_states 中;否则在 outputs[1] 中
        """
        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        """
        初始化空的特征映射元组
        """
        feature_maps = ()
        """
        遍历阶段名称和隐藏状态,根据设定的输出特征名称收集对应的隐藏状态
        """
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                """
                如果 self.config.reshape_hidden_states 为 True,则对隐藏状态进行形状调整
                """
                if self.config.reshape_hidden_states:
                    hidden_state = hidden_state[:, 1:, :]  # 移除CLS标记
                    hidden_state = hidden_state.permute(0, 2, 1)  # 调整维度顺序
                    hidden_state = hidden_state.reshape(batch_size, -1, patch_height, patch_width)  # 重塑形状

                """
                将符合条件的隐藏状态添加到特征映射中
                """
                feature_maps += (hidden_state,)

        """
        如果配置中添加了特征金字塔网络(FPN),则对特征映射进行相应的处理
        """
        if self.config.add_fpn:
            feature_maps = [
                self.fpn1(feature_maps[0]),
                self.fpn2(feature_maps[1]),
                self.fpn3(feature_maps[2]),
                self.fpn4(feature_maps[3]),
            ]
            feature_maps = tuple(feature_maps)

        """
        如果不返回字典,则根据输出隐藏状态的设置返回输出元组
        """
        if not return_dict:
            if output_hidden_states:
                output = (feature_maps,) + outputs[1:]
            else:
                output = (feature_maps,) + outputs[2:]
            return output

        """
        返回 BackboneOutput 对象,包含特征映射、隐藏状态和注意力
        """
        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )

.\models\beit\modeling_flax_beit.py

# BEIT_START_DOCSTRING 是一个原始文档字符串的标记,用于后续的文档字符串生成
BEIT_START_DOCSTRING = r"""
    # 这个模型继承自 `FlaxPreTrainedModel`。查看超类的文档,了解库为所有模型实现的通用方法(如下载、保存和从PyTorch模型转换权重)。
    
    # 这个模型还是一个 `flax.linen.Module` 的子类。可以将其作为常规的 Flax linen 模块使用,并参考 Flax 文档了解与一般使用和行为相关的所有内容。
    
    # 最后,这个模型支持 JAX 的一些内置特性,如:
    # - Just-In-Time (JIT) 编译:https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit
    # - 自动微分:https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation
    # - 向量化:https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap
    # - 并行化:https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap
    
    # 参数:
    # config (`BeitConfig`): 包含模型所有参数的配置类。
    # 初始化时使用配置文件不会加载模型的权重,只加载配置。查看 `~FlaxPreTrainedModel.from_pretrained` 方法以加载模型权重。
    
    # dtype (`jax.numpy.dtype`, *optional*, 默认为 `jax.numpy.float32`):
    # 计算时的数据类型。可以是 `jax.numpy.float32`、`jax.numpy.float16`(在GPU上)和 `jax.numpy.bfloat16`(在TPU上)之一。
    # 可用于在GPU或TPU上启用混合精度训练或半精度推理。如果指定了 dtype,所有计算都将使用给定的 `dtype` 进行。
    
    # **注意,这只指定了计算时的数据类型,不影响模型参数的数据类型。**
    # 如果希望更改模型参数的数据类型,请参阅 `~FlaxPreTrainedModel.to_fp16` 和 `~FlaxPreTrainedModel.to_bf16`。
"""

BEIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`AutoImageProcessor.__call__`] for details.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

def relative_position_index_init(window_size: Tuple[int, int]) -> jnp.ndarray:
    """
    Initialize a matrix of relative position indices for tokens inside a window.

    Args:
        window_size: Tuple specifying the height and width of the window.

    Returns:
        jnp.ndarray: Matrix of relative position indices.

    This function computes the relative positions between tokens in a window based on the specified window size.
    """

    num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3

    coords_h = np.arange(window_size[0])
    coords_w = np.arange(window_size[1])
    coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
    coords_flatten = np.reshape(coords, (2, -1))
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = np.transpose(relative_coords, (1, 2, 0))  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1

    relative_position_index = np.zeros(shape=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
    relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
    relative_position_index[0, 0:] = num_relative_distance - 3
    relative_position_index[0:, 0] = num_relative_distance - 2
    relative_position_index[0, 0] = num_relative_distance - 1
    return jnp.array(relative_position_index)


def ones_with_scale(key, shape, scale, dtype=jnp.float32):
    """
    Create a tensor filled with ones scaled by a specified factor.

    Args:
        key: Random key for JAX randomness.
        shape: Shape of the tensor.
        scale: Scaling factor for the ones tensor.
        dtype: Data type of the tensor.

    Returns:
        jnp.ndarray: Tensor filled with ones scaled by `scale`.
    """
    return jnp.ones(shape, dtype) * scale


class FlaxBeitDropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    rate: float

    @nn.module.compact
    def __call__(self, inputs, deterministic: Optional[bool] = True):
        """
        Apply drop path regularization to inputs.

        Args:
            inputs: Input tensor to which drop path is applied.
            deterministic: Whether to apply deterministic or stochastic drop path.

        Returns:
            jnp.ndarray: Output tensor after applying drop path regularization.
        """
        if self.rate == 0.0:
            return inputs
        keep_prob = 1.0 - self.rate
        if deterministic:
            return inputs
        else:
            shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
            rng = self.make_rng("droppath")
            random_tensor = keep_prob + jax.random.uniform(rng, shape=shape, dtype=inputs.dtype)
            binary_tensor = jnp.floor(random_tensor)
            output = inputs / keep_prob * binary_tensor
            return output
    # 定义一个名为 FlaxBeitPatchEmbeddings 的新模块,继承自 nn.Module
    class FlaxBeitPatchEmbeddings(nn.Module):
        # 引入配置类 BeitConfig
        config: BeitConfig
        # 定义计算时使用的数据类型,默认为 jnp.float32
        dtype: jnp.dtype = jnp.float32  # 计算时使用的数据类型

        # 模块的设置方法
        def setup(self):
            # 从配置中获取通道数和图像大小
            self.num_channels = self.config.num_channels
            image_size = self.config.image_size
            patch_size = self.config.patch_size
            # 计算图像被分成的块数和每个块的形状
            num_patches = (image_size // patch_size) * (image_size // patch_size)
            patch_shape = (image_size // patch_size, image_size // patch_size)
            # 设置模块的属性
            self.num_patches = num_patches
            self.patch_shape = patch_shape
            # 创建一个卷积层投影,用于将输入投影到隐藏尺寸空间
            self.projection = nn.Conv(
                self.config.hidden_size,
                kernel_size=(patch_size, patch_size),
                strides=(patch_size, patch_size),
                padding="VALID",
                dtype=self.dtype,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            )

        # 模块的调用方法,处理输入像素值
        def __call__(self, pixel_values):
            # 检查输入像素值的通道数是否与配置中设置的通道数匹配
            num_channels = pixel_values.shape[-1]
            if num_channels != self.num_channels:
                raise ValueError(
                    "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                )
            # 使用投影层处理像素值,得到嵌入表示
            embeddings = self.projection(pixel_values)
            batch_size, _, _, channels = embeddings.shape
            # 将嵌入表示重塑为适当的形状,以便后续处理
            return jnp.reshape(embeddings, (batch_size, -1, channels))


    # 定义一个名为 FlaxBeitEmbeddings 的新模块,继承自 nn.Module
    class FlaxBeitEmbeddings(nn.Module):
        """构建CLS令牌、位置和补丁嵌入。"""

        # 引入配置类 BeitConfig
        config: BeitConfig
        # 定义计算时使用的数据类型,默认为 jnp.float32
        dtype: jnp.dtype = jnp.float32  # 计算时使用的数据类型

        # 模块的设置方法
        def setup(self):
            # 定义一个CLS令牌,初始化为全零,形状为 (1, 1, hidden_size)
            self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
            # 如果配置要求使用掩码令牌,则定义一个掩码令牌,初始化为全零,形状也为 (1, 1, hidden_size)
            if self.config.use_mask_token:
                self.mask_token = self.param("mask_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
            # 创建补丁嵌入模块实例,使用给定的配置和数据类型
            self.patch_embeddings = FlaxBeitPatchEmbeddings(self.config, dtype=self.dtype)
            num_patches = self.patch_embeddings.num_patches
            # 如果配置要求使用绝对位置嵌入,则定义一个绝对位置嵌入参数,初始化为全零,形状为 (1, num_patches + 1, hidden_size)
            if self.config.use_absolute_position_embeddings:
                self.position_embeddings = self.param(
                    "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
                )
            # 定义一个Dropout层,用于随机断开输入单元,以防止过拟合
            self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
    # 定义一个类的调用方法,接受像素值和可选的布尔掩码作为输入参数,并返回嵌入表示
    def __call__(self, pixel_values, bool_masked_pos=None, deterministic=True):
        # 使用patch_embeddings方法将像素值转换为嵌入表示
        embeddings = self.patch_embeddings(pixel_values)
        # 获取嵌入表示的维度信息:批量大小、序列长度、嵌入维度
        batch_size, seq_len, _ = embeddings.shape

        # 创建一个形状与嵌入表示相同的CLS token,并将其数据类型转换为embeddings的数据类型
        cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))
        cls_tokens = cls_tokens.astype(embeddings.dtype)

        # 如果给定了布尔掩码,替换被掩码的视觉令牌为mask_tokens
        if bool_masked_pos is not None:
            # 创建一个形状与嵌入表示相同的mask token,并将其数据类型转换为embeddings的数据类型
            mask_tokens = jnp.broadcast_to(self.mask_token, (batch_size, seq_len, self.config.hidden_size))
            mask_tokens = mask_tokens.astype(embeddings.dtype)
            # 使用布尔掩码来选择性地应用mask_tokens替换embeddings中的视觉令牌
            w = jnp.expand_dims(bool_masked_pos, axis=-1)
            embeddings = embeddings * (1 - w) + mask_tokens * w

        # 将CLS token与嵌入表示连接起来,形成完整的嵌入表示序列
        embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)

        # 如果配置中使用了绝对位置嵌入,将位置嵌入加到嵌入表示上
        if self.config.use_absolute_position_embeddings:
            embeddings = embeddings + self.position_embeddings.astype(embeddings.dtype)

        # 使用dropout方法对嵌入表示进行随机失活,根据deterministic参数确定是否确定性地进行操作
        embeddings = self.dropout(embeddings, deterministic=deterministic)
        # 返回最终的嵌入表示
        return embeddings
    # FlaxBeitRelativePositionBias 类,用于计算相对位置偏置
    class FlaxBeitRelativePositionBias(nn.Module):
        # BeitConfig 类型的配置信息
        config: BeitConfig
        # 窗口大小的元组,表示注意力窗口的尺寸
        window_size: Tuple[int, int]
        # 计算中使用的数据类型,默认为 jnp.float32
        dtype: jnp.dtype = jnp.float32  # the dtype of the computation

        # 模块初始化方法
        def setup(self):
            # 计算相对距离的数量,形状为 (2*Wh-1)*(2*Ww-1) + 3
            num_relative_distance = (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) + 3
            # 创建参数,相对位置偏置表,形状为 (num_relative_distance, num_attention_heads)
            self.relative_position_bias_table = self.param(
                "relative_position_bias_table",
                nn.initializers.zeros,
                (num_relative_distance, self.config.num_attention_heads),
            )  # 2*Wh-1 * 2*Ww-1, nH
            # 类别到标记 & 标记到类别 & 类别到类别

            # 初始化相对位置索引
            self.relative_position_index = relative_position_index_init(self.window_size)

        # 对象调用方法
        def __call__(self):
            # 将相对位置索引重塑为一维数组
            index = self.relative_position_index.reshape(-1)
            # 定义相对位置偏置的形状
            shape = (self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1)
            # 根据索引从相对位置偏置表中获取相对位置偏置,并重塑为指定形状
            relative_position_bias = self.relative_position_bias_table[index].reshape(shape)  # Wh*Ww,Wh*Ww,nH
            # 返回相对位置偏置,并进行维度转置
            return jnp.transpose(relative_position_bias, (2, 0, 1))


    # FlaxBeitSelfAttention 类,实现自注意力机制
    class FlaxBeitSelfAttention(nn.Module):
        # BeitConfig 类型的配置信息
        config: BeitConfig
        # 窗口大小的元组,表示注意力窗口的尺寸
        window_size: Tuple[int, int]
        # 计算中使用的数据类型,默认为 jnp.float32
        dtype: jnp.dtype = jnp.float32  # the dtype of the computation

        # 模块初始化方法
        def setup(self):
            # 检查隐藏层大小是否是注意力头数的倍数,且不是嵌入大小的属性
            if self.config.hidden_size % self.config.num_attention_heads != 0 and not hasattr(
                self.config, "embedding_size"
            ):
                # 抛出数值错误,提示隐藏大小不是注意力头数的倍数
                raise ValueError(
                    f"The hidden size {self.config.hidden_size,} is not a multiple of the number of attention "
                    f"heads {self.config.num_attention_heads}."
                )

            # 初始化查询、键、值的线性变换层
            self.query = nn.Dense(
                self.config.hidden_size,
                dtype=self.dtype,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            )
            self.key = nn.Dense(
                self.config.hidden_size,
                dtype=self.dtype,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                use_bias=False,
            )
            self.value = nn.Dense(
                self.config.hidden_size,
                dtype=self.dtype,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            )

            # 如果定义了窗口大小,创建相对位置偏置对象
            self.relative_position_bias = (
                FlaxBeitRelativePositionBias(self.config, window_size=self.window_size, dtype=self.dtype)
                if self.window_size
                else None
            )

        # 对象调用方法,实现自注意力计算
        def __call__(
            self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False
    ):
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        # 将查询向量转换成多头格式:(batch_size, seq_length, num_attention_heads, head_dim)
        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        
        # 将数值向量转换成多头格式:(batch_size, seq_length, num_attention_heads, head_dim)
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        
        # 将键向量转换成多头格式:(batch_size, seq_length, num_attention_heads, head_dim)
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )

        dropout_rng = None
        # 如果非确定性计算且设置了注意力概率的丢弃率,则创建一个用于丢弃的随机数生成器
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attention_bias = jnp.array(0.0, dtype=self.dtype)
        # 如果存在相对位置偏置,则添加到注意力偏置中
        if self.relative_position_bias is not None:
            attention_bias = jnp.expand_dims(self.relative_position_bias(), 0)
            attention_bias = attention_bias.astype(query_states.dtype)

        # 如果提供了共享的相对位置偏置,则将其加到注意力偏置中
        if relative_position_bias is not None:
            attention_bias = attention_bias + relative_position_bias.astype(attention_bias.dtype)

        # 计算点积注意力的权重
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # 使用注意力权重计算注意力输出
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        # 如果需要输出注意力权重,则返回注意力输出和注意力权重;否则只返回注意力输出
        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
# 定义一个 FlaxBeitSelfOutput 类,继承自 nn.Module
class FlaxBeitSelfOutput(nn.Module):
    # 配置项,使用 BeitConfig 类型
    config: BeitConfig
    # 计算时的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    # 初始化方法
    def setup(self):
        # 定义一个全连接层,输出大小为 self.config.hidden_size
        # 初始化权重使用正态分布,标准差为 self.config.initializer_range
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # 定义一个 Dropout 层,丢弃率为 self.config.hidden_dropout_prob
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    # 调用方法,接收 hidden_states 和 deterministic 参数
    def __call__(self, hidden_states, deterministic: bool = True):
        # 将 hidden_states 输入到全连接层中
        hidden_states = self.dense(hidden_states)
        # 使用 Dropout 处理 hidden_states
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # 返回处理后的 hidden_states
        return hidden_states


# 定义一个 FlaxBeitAttention 类,继承自 nn.Module
class FlaxBeitAttention(nn.Module):
    # 配置项,使用 BeitConfig 类型
    config: BeitConfig
    # 窗口大小,为元组类型,存储两个整数值
    window_size: Tuple[int, int]
    # 计算时的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 初始化方法
    def setup(self):
        # 定义一个自注意力层,使用 FlaxBeitSelfAttention 类
        self.attention = FlaxBeitSelfAttention(self.config, self.window_size, dtype=self.dtype)
        # 定义一个输出层,使用 FlaxBeitSelfOutput 类
        self.output = FlaxBeitSelfOutput(self.config, dtype=self.dtype)

    # 调用方法,接收 hidden_states、relative_position_bias、deterministic 和 output_attentions 参数
    def __call__(
        self, hidden_states, relative_position_bias=None, deterministic=True, output_attentions: bool = False
    ):
        # 执行自注意力层的调用方法,传入相关参数
        attn_outputs = self.attention(
            hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions
        )
        # 获取注意力输出的第一个元素
        attn_output = attn_outputs[0]
        # 将注意力输出传入输出层进行处理
        attn_output = self.output(attn_output, deterministic=deterministic)

        # 初始化 outputs 为包含 attn_output 的元组
        outputs = (attn_output,)

        # 如果 output_attentions 为 True,则将注意力输出的第二个元素加入 outputs 中
        if output_attentions:
            outputs += (attn_outputs[1],)

        # 返回 outputs
        return outputs


# 定义一个 FlaxBeitIntermediate 类,继承自 nn.Module
class FlaxBeitIntermediate(nn.Module):
    # 配置项,使用 BeitConfig 类型
    config: BeitConfig
    # 计算时的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    # 初始化方法
    def setup(self):
        # 定义一个全连接层,输出大小为 self.config.intermediate_size
        # 初始化权重使用正态分布,标准差为 self.config.initializer_range
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # 激活函数为配置中指定的隐藏层激活函数
        self.activation = ACT2FN[self.config.hidden_act]

    # 调用方法,接收 hidden_states 参数
    def __call__(self, hidden_states):
        # 将 hidden_states 输入到全连接层中
        hidden_states = self.dense(hidden_states)
        # 使用激活函数处理 hidden_states
        hidden_states = self.activation(hidden_states)

        # 返回处理后的 hidden_states
        return hidden_states


# 定义一个 FlaxBeitOutput 类,继承自 nn.Module
class FlaxBeitOutput(nn.Module):
    # 配置项,使用 BeitConfig 类型
    config: BeitConfig
    # 计算时的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    # 初始化方法
    def setup(self):
        # 定义一个全连接层,输出大小为 self.config.hidden_size
        # 初始化权重使用正态分布,标准差为 self.config.initializer_range
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # 定义一个 Dropout 层,丢弃率为 self.config.hidden_dropout_prob
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    # 调用方法,接收 hidden_states 和 deterministic 参数
    def __call__(self, hidden_states, deterministic: bool = True):
        # 将 hidden_states 输入到全连接层中
        hidden_states = self.dense(hidden_states)
        # 使用 Dropout 处理 hidden_states
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        # 返回处理后的 hidden_states
        return hidden_states


# 定义一个 FlaxBeitLayer 类,继承自 nn.Module
class FlaxBeitLayer(nn.Module):
    # 配置项,使用 BeitConfig 类型
    config: BeitConfig
    # 窗口大小,为元组类型,存储两个整数值
    window_size: Tuple[int, int]
    # DropPath 的概率
    drop_path_rate: float
    # 计算时的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    # 在初始化方法中设置模型的各个组件
    def setup(self):
        # 初始化注意力机制组件
        self.attention = FlaxBeitAttention(self.config, self.window_size, dtype=self.dtype)
        # 初始化中间层组件
        self.intermediate = FlaxBeitIntermediate(self.config, dtype=self.dtype)
        # 初始化输出层组件
        self.output = FlaxBeitOutput(self.config, dtype=self.dtype)
        # 初始化前层归一化组件,使用给定的 epsilon 参数
        self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # 初始化 DropPath 组件,使用给定的丢弃率
        self.drop_path = FlaxBeitDropPath(rate=self.drop_path_rate)
        # 初始化后层归一化组件,使用给定的 epsilon 参数
        self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

        # 初始化 lambda_1 和 lambda_2 参数,如果初始化值大于 0,则创建参数;否则设为 None
        self.init_values = self.config.layer_scale_init_value
        if self.init_values > 0:
            self.lambda_1 = self.param("lambda_1", ones_with_scale, (self.config.hidden_size), self.init_values)
            self.lambda_2 = self.param("lambda_2", ones_with_scale, (self.config.hidden_size), self.init_values)
        else:
            self.lambda_1 = None
            self.lambda_2 = None

    # 实现调用方法,处理输入的隐藏状态,执行模型的前向传播
    def __call__(
        self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False
    ):
        # 执行自注意力机制,包括前层归一化
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # 在 BEiT 中,自注意力前先进行归一化
            relative_position_bias,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        # 获取自注意力的输出
        attention_output = self_attention_outputs[0]

        # 如果 lambda_1 参数存在,则应用于注意力输出
        if self.lambda_1 is not None:
            attention_output = self.lambda_1.astype(attention_output.dtype) * attention_output

        # 第一次残差连接
        hidden_states = self.drop_path(attention_output, deterministic=deterministic) + hidden_states

        # 在 BEiT 中,层归一化也应用于自注意力后
        layer_output = self.layernorm_after(hidden_states)

        # 执行中间层操作
        layer_output = self.intermediate(layer_output)
        # 执行输出层操作,包括确定性与否
        layer_output = self.output(layer_output, deterministic=deterministic)

        # 如果 lambda_2 参数存在,则应用于中间层输出
        if self.lambda_2 is not None:
            layer_output = self.lambda_2.astype(layer_output.dtype) * layer_output

        # 第二次残差连接,将中间层输出与原始隐藏状态相加
        layer_output = self.drop_path(layer_output, deterministic=deterministic) + hidden_states

        # 返回最终输出,包括中间层输出或者中间层输出及注意力权重(根据需求)
        outputs = (layer_output,)

        if output_attentions:
            outputs += (self_attention_outputs[1],)

        return outputs
class FlaxBeitLayerCollection(nn.Module):
    config: BeitConfig  # 类型注解,指定 config 属性的类型为 BeitConfig
    window_size: Tuple[int, int]  # 类型注解,指定 window_size 属性的类型为元组,包含两个整数
    drop_path_rates: List[float]  # 类型注解,指定 drop_path_rates 属性的类型为列表,包含浮点数
    relative_position_bias: Callable[[], jnp.ndarray]  # 类型注解,指定 relative_position_bias 属性的类型为可调用对象,返回 jnp.ndarray
    dtype: jnp.dtype = jnp.float32  # 类型注解,默认值为 jnp.float32,指定 dtype 属性的类型为 jnp.dtype,表示计算时的数据类型

    def setup(self):
        # 初始化 layers 属性为一个列表,每个元素是一个 FlaxBeitLayer 实例
        self.layers = [
            FlaxBeitLayer(
                self.config,
                window_size=self.window_size if self.config.use_relative_position_bias else None,
                drop_path_rate=self.drop_path_rates[i],
                name=str(i),
                dtype=self.dtype,
            )
            for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 如果 output_attentions 为 True,则初始化空元组 all_attentions,否则设为 None
        all_attentions = () if output_attentions else None
        # 如果 output_hidden_states 为 True,则初始化空元组 all_hidden_states,否则设为 None
        all_hidden_states = () if output_hidden_states else None

        # 遍历 layers 列表中的每个层,并处理 hidden_states
        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)  # 将当前 hidden_states 添加到 all_hidden_states 元组中
            # 根据 self.relative_position_bias 的值初始化 relative_position_bias
            relative_position_bias = self.relative_position_bias() if self.relative_position_bias is not None else None
            # 调用当前层 layer 的处理方法,更新 hidden_states
            layer_outputs = layer(
                hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions
            )
            # 更新 hidden_states 为当前层的输出的第一个元素
            hidden_states = layer_outputs[0]

            # 如果 output_attentions 为 True,则将当前层的注意力加入 all_attentions 元组
            if output_attentions:
                all_attentions += (layer_outputs[1],)

        # 如果 output_hidden_states 为 True,则将最终的 hidden_states 加入 all_hidden_states 元组
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 初始化输出为包含最终 hidden_states 的元组 outputs
        outputs = (hidden_states,)
        # 如果 return_dict 为 False,则返回 outputs 中不为 None 的元素组成的元组
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        # 返回 FlaxBaseModelOutput 对象,包含最终的 hidden_states、所有 hidden_states 和所有 attentions
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class FlaxBeitEncoder(nn.Module):
    config: BeitConfig  # 类型注解,指定 config 属性的类型为 BeitConfig
    window_size: Tuple[int, int]  # 类型注解,指定 window_size 属性的类型为元组,包含两个整数
    dtype: jnp.dtype = jnp.float32  # 类型注解,默认值为 jnp.float32,指定 dtype 属性的类型为 jnp.dtype,表示计算时的数据类型

    def setup(self):
        # 如果 self.config.use_shared_relative_position_bias 为 True,则初始化 relative_position_bias
        if self.config.use_shared_relative_position_bias:
            self.relative_position_bias = FlaxBeitRelativePositionBias(
                config=self.config, window_size=self.window_size, dtype=self.dtype
            )

        # 根据 stochastic depth decay rule 初始化 drop_path_rates 列表
        drop_path_rates = list(np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers))
        # 初始化 layer 属性为 FlaxBeitLayerCollection 实例
        self.layer = FlaxBeitLayerCollection(
            self.config,
            window_size=self.window_size,
            drop_path_rates=drop_path_rates,
            relative_position_bias=self.relative_position_bias
            if self.config.use_shared_relative_position_bias
            else None,
            dtype=self.dtype,
        )
    # 定义一个特殊方法 __call__,使对象可以像函数一样被调用
    def __call__(
        self,
        hidden_states,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 调用对象的 layer 方法,传入参数和关键字参数,并返回结果
        return self.layer(
            hidden_states,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 使用BeitConfig作为配置类
    config_class = BeitConfig
    # base_model_prefix指定基础模型的前缀为"beit"
    base_model_prefix = "beit"
    # main_input_name指定主要输入名称为"pixel_values"
    main_input_name = "pixel_values"
    # module_class用于存储模块类,初始时未指定
    module_class: nn.Module = None

    def __init__(
        self,
        config: BeitConfig,
        input_shape=None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # 根据配置类和其他参数初始化模块
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # 如果未提供输入形状,则使用默认形状
        if input_shape is None:
            input_shape = (1, config.image_size, config.image_size, config.num_channels)
        # 调用父类的初始化方法,传递配置、模块、输入形状、种子、数据类型等参数
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 初始化像素值张量
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        # 分割随机数生成器,用于不同的参数初始化
        params_rng, dropout_rng = jax.random.split(rng)
        dropout_rng, droppath_rng = jax.random.split(dropout_rng)
        rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}

        # 使用模块的初始化方法初始化随机参数
        random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]

        if params is not None:
            # 如果提供了参数,则将随机初始化的参数与提供的参数进行合并
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            # 如果未提供参数,则直接返回随机初始化的参数
            return random_params

    @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        pixel_values,
        bool_masked_pos=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ):
            # 如果 output_attentions 参数为 None,则使用 self.config.output_attentions 的默认值
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            # 如果 output_hidden_states 参数为 None,则使用 self.config.output_hidden_states 的默认值
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            # 如果 return_dict 参数为 None,则使用 self.config.return_dict 的默认值
            return_dict = return_dict if return_dict is not None else self.config.return_dict

            # 将像素值张量进行转置,调整通道顺序为 (batch_size, height, width, channels)
            pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
            # 如果需要处理任何 PRNG(伪随机数发生器),则初始化一个空字典用于存储不同的 PRNG
            rngs = {}
            if dropout_rng is not None:
                # 如果 dropout_rng 不为 None,则使用 JAX 提供的随机数分割函数拆分 PRNG
                dropout_rng, droppath_rng = jax.random.split(dropout_rng)
                rngs["dropout"] = dropout_rng
                rngs["droppath"] = droppath_rng

            # 调用模块的 apply 方法,传递参数、像素值张量、布尔掩码、训练标志、输出注意力、隐藏状态、返回字典和 PRNGs
            return self.module.apply(
                {"params": params or self.params},  # 模型参数
                jnp.array(pixel_values, dtype=jnp.float32),  # 像素值张量,转换为 JAX 的 float32 类型数组
                bool_masked_pos,  # 布尔掩码,指示哪些位置需要屏蔽
                not train,  # 是否为推断模式(非训练模式)
                output_attentions,  # 是否输出注意力权重
                output_hidden_states,  # 是否输出隐藏状态
                return_dict,  # 是否以字典形式返回输出
                rngs=rngs,  # PRNGs 字典,用于模型中的随机数生成
            )
# 定义了一个用于池化操作的 FlaxBeitPooler 类,继承自 nn.Module
class FlaxBeitPooler(nn.Module):
    # 用于配置的 BeitConfig 对象
    config: BeitConfig
    # 计算中使用的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32  # 计算的数据类型

    # 初始化方法
    def setup(self):
        # 如果配置中使用了均值池化
        if self.config.use_mean_pooling:
            # 初始化一个 LayerNorm 层,用于对池化后的输出进行归一化,设定 epsilon 值为配置中的 layer_norm_eps
            self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    # 调用方法,实现池化操作
    def __call__(self, hidden_states):
        # 如果配置中使用了均值池化
        if self.config.use_mean_pooling:
            # 提取除第一个 token 以外的所有 token 的隐藏状态,进行均值池化操作
            patch_tokens = hidden_states[:, 1:, :]
            pooled_output = self.layernorm(jnp.mean(patch_tokens, axis=1))  # 对所有 token 的均值进行 LayerNorm 归一化
        else:
            # 否则,直接使用第一个 token 的隐藏状态作为池化输出
            pooled_output = hidden_states[:, 0]

        return pooled_output


# 定义了一个 FlaxBeitModule 类,继承自 nn.Module
class FlaxBeitModule(nn.Module):
    # 用于配置的 BeitConfig 对象
    config: BeitConfig
    # 计算中使用的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32  # 计算的数据类型
    # 是否添加池化层,默认为 True
    add_pooling_layer: bool = True

    # 初始化方法
    def setup(self):
        # 初始化嵌入层对象,使用配置和数据类型作为参数
        self.embeddings = FlaxBeitEmbeddings(self.config, dtype=self.dtype)
        # 初始化编码器对象,使用配置、窗口大小和数据类型作为参数
        self.encoder = FlaxBeitEncoder(
            self.config, window_size=self.embeddings.patch_embeddings.patch_shape, dtype=self.dtype
        )
        # 如果不使用均值池化,初始化一个 LayerNorm 层,用于对编码器输出进行归一化,设定 epsilon 值为配置中的 layer_norm_eps
        if not self.config.use_mean_pooling:
            self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # 如果需要添加池化层,初始化一个 FlaxBeitPooler 对象
        self.pooler = FlaxBeitPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None

    # 调用方法,实现模型的前向传播
    def __call__(
        self,
        pixel_values,
        bool_masked_pos=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 使用嵌入层对象,根据输入像素值和掩码位置,获取隐藏状态
        hidden_states = self.embeddings(pixel_values, bool_masked_pos, deterministic=deterministic)

        # 使用编码器对象,对隐藏状态进行编码处理,返回输出对象
        outputs = self.encoder(
            hidden_states,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 获取编码器的最终隐藏状态作为处理后的隐藏状态
        hidden_states = outputs[0]

        # 如果不使用均值池化,对隐藏状态进行 LayerNorm 归一化处理
        if not self.config.use_mean_pooling:
            hidden_states = self.layernorm(hidden_states)

        # 如果需要添加池化层,对处理后的隐藏状态进行池化操作
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        # 如果不返回字典形式的输出
        if not return_dict:
            # 如果池化结果为空,则不返回池化结果
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            # 否则返回隐藏状态和池化结果以及其他输出
            return (hidden_states, pooled) + outputs[1:]

        # 返回带有池化输出的 FlaxBeitModelOutputWithPooling 对象
        return FlaxBeitModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# 添加文档字符串注释,描述了 FlaxBeitModel 类的基本信息和用法
@add_start_docstrings(
    "The bare Beit Model transformer outputting raw hidden-states without any specific head on top.",
    BEIT_START_DOCSTRING,
)
# 定义了 FlaxBeitModel 类,继承自 FlaxBeitPreTrainedModel
class FlaxBeitModel(FlaxBeitPreTrainedModel):
    # 指定模块的类为 FlaxBeitModule
    module_class = FlaxBeitModule


# 定义了 FLAX_BEIT_MODEL_DOCSTRING,包含返回值和示例信息
FLAX_BEIT_MODEL_DOCSTRING = """
    Returns:

    Examples:

    ```
    # 导入所需的库和模块
    >>> from transformers import AutoImageProcessor, FlaxBeitModel
    >>> from PIL import Image
    >>> import requests
    
    # 定义要处理的图像的 URL
    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # 使用 requests 模块获取图像的原始字节流,并使用 PIL 打开图像
    >>> image = Image.open(requests.get(url, stream=True).raw)
    
    # 从预训练模型加载图像处理器(AutoImageProcessor)
    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
    # 从预训练模型加载 BEiT 模型(FlaxBeitModel)
    >>> model = FlaxBeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
    
    # 使用图像处理器处理图像,返回 NumPy 张量作为输入
    >>> inputs = image_processor(images=image, return_tensors="np")
    # 使用 BEiT 模型进行推理,输入处理后的图像数据
    >>> outputs = model(**inputs)
    # 获取模型输出中的最后一个隐藏状态
    >>> last_hidden_states = outputs.last_hidden_state
"""

# 调用函数覆盖文档字符串,将 FlaxBeitModel 的文档字符串替换为 FLAX_BEIT_MODEL_DOCSTRING 的内容
overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING)

# 附加和替换函数返回值的文档字符串,指定输出类型为 FlaxBeitModelOutputWithPooling,配置类为 BeitConfig
append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBeitModelOutputWithPooling, config_class=BeitConfig)

# FlaxBeitForMaskedImageModelingModule 类定义
class FlaxBeitForMaskedImageModelingModule(nn.Module):
    # 类的配置参数为 BeitConfig 类型
    config: BeitConfig
    # 计算过程中的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 模块初始化方法
    def setup(self):
        # 创建 FlaxBeitModule 实例,不添加池化层,并使用指定的数据类型
        self.beit = FlaxBeitModule(self.config, add_pooling_layer=False, dtype=self.dtype)

        # 分类器头部初始化
        self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)  # LayerNorm 初始化
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),  # 使用正态分布初始化权重
            dtype=self.dtype,
        )

    # 对象调用方法
    def __call__(
        self,
        pixel_values=None,
        bool_masked_pos=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # 如果 return_dict 为 None,则使用配置中指定的 return_dict 参数
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 self.beit 对象进行前向计算
        outputs = self.beit(
            pixel_values,
            bool_masked_pos,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取模型输出中的序列输出
        sequence_output = outputs[0]
        # 应用 Layernorm
        sequence_output = self.layernorm(sequence_output)
        # 对序列输出进行预测得分计算,去除第一个位置的特殊标记
        prediction_scores = self.lm_head(sequence_output[:, 1:])

        # 如果不使用 return_dict,则返回预测得分和额外的输出状态
        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return output

        # 使用 FlaxMaskedLMOutput 类封装返回结果,包括预测得分、隐藏状态和注意力权重
        return FlaxMaskedLMOutput(
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# 添加文档字符串说明到 FlaxBeitForMaskedImageModeling 类
@add_start_docstrings(
    "Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).",
    BEIT_START_DOCSTRING,
)
class FlaxBeitForMaskedImageModeling(FlaxBeitPreTrainedModel):
    module_class = FlaxBeitForMaskedImageModelingModule


# 定义 FLAX_BEIT_MLM_DOCSTRING,提供 Beit 模型的文档字符串信息
FLAX_BEIT_MLM_DOCSTRING = """
    bool_masked_pos (`numpy.ndarray` of shape `(batch_size, num_patches)`):
        Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

    Returns:
        Beit 模型的输出结果,包含 logits、hidden_states 和 attentions。

    Examples:

    ```
    >>> from transformers import AutoImageProcessor, BeitForMaskedImageModeling
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
    >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    ```
"""
# 使用 overwrite_call_docstring 函数,将 FlaxBeitForMaskedImageModeling 类的文档字符串替换为 FLAX_BEIT_MLM_DOCSTRING 中定义的文档字符串
overwrite_call_docstring(FlaxBeitForMaskedImageModeling, FLAX_BEIT_MLM_DOCSTRING)

# 使用 append_replace_return_docstrings 函数,为 FlaxBeitForMaskedImageModeling 类附加或替换输出类型为 FlaxMaskedLMOutput 和配置类为 BeitConfig 的文档字符串
append_replace_return_docstrings(
    FlaxBeitForMaskedImageModeling, output_type=FlaxMaskedLMOutput, config_class=BeitConfig
)


class FlaxBeitForImageClassificationModule(nn.Module):
    config: BeitConfig
    dtype: jnp.dtype = jnp.float32

    # 设置方法,初始化模块中的 Beit 和分类器组件
    def setup(self):
        self.beit = FlaxBeitModule(config=self.config, dtype=self.dtype, add_pooling_layer=True)
        self.classifier = nn.Dense(
            self.config.num_labels,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    # 调用方法,接受多个输入参数,并根据配置返回相应的输出
    def __call__(
        self,
        pixel_values=None,
        bool_masked_pos=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # 根据配置或者默认设置 return_dict
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 self.beit,传递参数并接收输出
        outputs = self.beit(
            pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 提取池化后的输出
        pooled_output = outputs[1]
        # 使用分类器计算 logits
        logits = self.classifier(pooled_output)

        # 如果 return_dict 为 False,则返回 logits 和其他输出
        if not return_dict:
            output = (logits,) + outputs[2:]
            return output

        # 如果 return_dict 为 True,则返回 FlaxSequenceClassifierOutput 类的实例,包含 logits、hidden_states 和 attentions
        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# 附加或替换 FlaxBeitForImageClassification 类的文档字符串,包括 BEIT_START_DOCSTRING 和从 FLAX_BEIT_CLASSIF_DOCSTRING 中定义的描述
@add_start_docstrings(
    """
    Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
    hidden states of the patch tokens) e.g. for ImageNet.
    """,
    BEIT_START_DOCSTRING,
)
class FlaxBeitForImageClassification(FlaxBeitPreTrainedModel):
    module_class = FlaxBeitForImageClassificationModule


# 将 FLAX_BEIT_CLASSIF_DOCSTRING 中定义的文档字符串替换为 FlaxBeitForImageClassification 类的文档字符串
overwrite_call_docstring(FlaxBeitForImageClassification, FLAX_BEIT_CLASSIF_DOCSTRING)

# 使用 append_replace_return_docstrings 函数,为 FlaxBeitForImageClassification 类附加或替换输出类型为 FlaxSequenceClassifierOutput 的文档字符串
append_replace_return_docstrings(
    # 导入FlaxBeitForImageClassification类,指定输出类型为FlaxSequenceClassifierOutput,使用BeitConfig配置类
    FlaxBeitForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=BeitConfig
# 创建一个名为文件的空列表
files = []
# 遍历整数 i 从 0 到 9(不包括 10)
for i in range(10):
    # 向文件列表添加字符串形式的 i
    files.append(f"file_{i}.txt")

.\models\beit\__init__.py

# 引入类型检查模块,用于条件类型检查
from typing import TYPE_CHECKING

# 从工具模块中引入必要的依赖,包括自定义的异常和延迟加载模块
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_torch_available,
    is_vision_available,
)

# 定义一个字典,用于存储导入结构,包含待导入模块的名称和对应的成员列表
_import_structure = {"configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig", "BeitOnnxConfig"]}

# 检查视觉处理模块是否可用,若不可用则抛出自定义的依赖不可用异常
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 若可用,则添加视觉特征提取模块和图像处理模块到导入结构中
    _import_structure["feature_extraction_beit"] = ["BeitFeatureExtractor"]
    _import_structure["image_processing_beit"] = ["BeitImageProcessor"]

# 检查 Torch 是否可用,若不可用则抛出自定义的依赖不可用异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 若可用,则添加 BEIT 模型相关模块到导入结构中
    _import_structure["modeling_beit"] = [
        "BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "BeitForImageClassification",
        "BeitForMaskedImageModeling",
        "BeitForSemanticSegmentation",
        "BeitModel",
        "BeitPreTrainedModel",
        "BeitBackbone",
    ]

# 检查 Flax 是否可用,若不可用则抛出自定义的依赖不可用异常
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 若可用,则添加 Flax BEIT 模型相关模块到导入结构中
    _import_structure["modeling_flax_beit"] = [
        "FlaxBeitForImageClassification",
        "FlaxBeitForMaskedImageModeling",
        "FlaxBeitModel",
        "FlaxBeitPreTrainedModel",
    ]

# 若在类型检查环境下,则添加详细的导入语句以满足类型检查的需求
if TYPE_CHECKING:
    # 导入 BEIT 配置相关的类和常量
    from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig, BeitOnnxConfig

    try:
        # 检查视觉处理模块是否可用
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入视觉特征提取和图像处理模块
        from .feature_extraction_beit import BeitFeatureExtractor
        from .image_processing_beit import BeitImageProcessor

    try:
        # 检查 Torch 是否可用
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入 BEIT 模型相关的类和常量
        from .modeling_beit import (
            BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
            BeitBackbone,
            BeitForImageClassification,
            BeitForMaskedImageModeling,
            BeitForSemanticSegmentation,
            BeitModel,
            BeitPreTrainedModel,
        )

    try:
        # 检查 Flax 是否可用
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 从当前目录中导入模块
        from .modeling_flax_beit import (
            FlaxBeitForImageClassification,
            FlaxBeitForMaskedImageModeling,
            FlaxBeitModel,
            FlaxBeitPreTrainedModel,
        )
else:
    # 导入 sys 模块,用于动态配置模块信息
    import sys
    # 将当前模块 (__name__) 的内容替换为 _LazyModule 的实例,
    # 以延迟加载模块内容,传入模块名、模块文件名、导入结构和模块规范
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\bert\configuration_bert.py

# coding=utf-8
# 声明版权和许可信息

""" BERT模型配置 """
# 导入必要的模块
from collections import OrderedDict  # 导入OrderedDict类,用于创建有序字典
from typing import Mapping  # 导入Mapping类型提示

from ...configuration_utils import PretrainedConfig  # 导入预训练配置类
from ...onnx import OnnxConfig  # 导入ONNX配置
from ...utils import logging  # 导入日志工具

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# BERT预训练模型配置文件映射
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google-bert/bert-base-uncased": "https://huggingface.co/google-bert/bert-base-uncased/resolve/main/config.json",
    "google-bert/bert-large-uncased": "https://huggingface.co/google-bert/bert-large-uncased/resolve/main/config.json",
    "google-bert/bert-base-cased": "https://huggingface.co/google-bert/bert-base-cased/resolve/main/config.json",
    "google-bert/bert-large-cased": "https://huggingface.co/google-bert/bert-large-cased/resolve/main/config.json",
    "google-bert/bert-base-multilingual-uncased": "https://huggingface.co/google-bert/bert-base-multilingual-uncased/resolve/main/config.json",
    "google-bert/bert-base-multilingual-cased": "https://huggingface.co/google-bert/bert-base-multilingual-cased/resolve/main/config.json",
    "google-bert/bert-base-chinese": "https://huggingface.co/google-bert/bert-base-chinese/resolve/main/config.json",
    "google-bert/bert-base-german-cased": "https://huggingface.co/google-bert/bert-base-german-cased/resolve/main/config.json",
    "google-bert/bert-large-uncased-whole-word-masking": (
        "https://huggingface.co/google-bert/bert-large-uncased-whole-word-masking/resolve/main/config.json"
    ),
    "google-bert/bert-large-cased-whole-word-masking": (
        "https://huggingface.co/google-bert/bert-large-cased-whole-word-masking/resolve/main/config.json"
    ),
    "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": (
        "https://huggingface.co/google-bert/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/config.json"
    ),
    "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": (
        "https://huggingface.co/google-bert/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json"
    ),
    "google-bert/bert-base-cased-finetuned-mrpc": "https://huggingface.co/google-bert/bert-base-cased-finetuned-mrpc/resolve/main/config.json",
    # 定义一个字典,将不同的BERT模型名称映射到其对应的配置文件URL
    "google-bert/bert-base-german-dbmdz-cased": "https://huggingface.co/google-bert/bert-base-german-dbmdz-cased/resolve/main/config.json",
    "google-bert/bert-base-german-dbmdz-uncased": "https://huggingface.co/google-bert/bert-base-german-dbmdz-uncased/resolve/main/config.json",
    "cl-tohoku/bert-base-japanese": "https://huggingface.co/cl-tohoku/bert-base-japanese/resolve/main/config.json",
    # 使用整词掩码技术的日语BERT模型的配置文件URL
    "cl-tohoku/bert-base-japanese-whole-word-masking": (
        "https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/resolve/main/config.json"
    ),
    # 使用字符级整词掩码技术的日语BERT模型的配置文件URL
    "cl-tohoku/bert-base-japanese-char": (
        "https://huggingface.co/cl-tohoku/bert-base-japanese-char/resolve/main/config.json"
    ),
    # 使用字符级整词掩码技术的日语BERT模型的配置文件URL
    "cl-tohoku/bert-base-japanese-char-whole-word-masking": (
        "https://huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/resolve/main/config.json"
    ),
    # 芬兰语大小写BERT模型的配置文件URL
    "TurkuNLP/bert-base-finnish-cased-v1": (
        "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/config.json"
    ),
    # 芬兰语小写BERT模型的配置文件URL
    "TurkuNLP/bert-base-finnish-uncased-v1": (
        "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/config.json"
    ),
    # 荷兰语大小写BERT模型的配置文件URL
    "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/config.json",
    # 查看所有BERT模型的链接,可以在这里找到更多信息:https://huggingface.co/models?filter=bert
# 类定义:BertConfig,继承自PretrainedConfig,用于存储BERT模型的配置信息
class BertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to
    instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the BERT
    [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    
    Examples:  # 示例代码

    ```
    >>> from transformers import BertConfig, BertModel

    >>> # Initializing a BERT google-bert/bert-base-uncased style configuration
    >>> configuration = BertConfig()

    >>> # Initializing a model (with random weights) from the google-bert/bert-base-uncased style configuration
    >>> model = BertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "bert"  # 模型类型设置为"bert"

    # 初始化函数,设置Bert模型的各种配置参数
    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)  # 调用父类的初始化函数

        # 设置BertConfig的各项配置参数
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout


# 类定义:BertOnnxConfig,继承自OnnxConfig
class BertOnnxConfig(OnnxConfig):
    @property
    # 定义一个方法 inputs,返回一个字典结构
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # 如果任务类型是多选题
        if self.task == "multiple-choice":
            # 定义动态轴的顺序,包括批次、选择和序列
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            # 否则,定义动态轴的顺序,包括批次和序列
            dynamic_axis = {0: "batch", 1: "sequence"}
        # 返回一个有序字典,包含输入数据的名称和对应的动态轴顺序
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),         # 输入的标识符 ID,使用动态轴顺序
                ("attention_mask", dynamic_axis),    # 注意力掩码,使用动态轴顺序
                ("token_type_ids", dynamic_axis),    # 令牌类型 ID,使用动态轴顺序
            ]
        )

.\models\bert\convert_bert_original_tf2_checkpoint_to_pytorch.py

    # 版权声明和许可信息
    """
    Copyright 2020 The HuggingFace Team. All rights reserved.
    
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at
    
        http://www.apache.org/licenses/LICENSE-2.0
    
    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
    """
    
    # 引入所需库和模块
    import argparse  # 解析命令行参数的库
    import os  # 操作系统相关功能的库
    import re  # 正则表达式的库
    
    import tensorflow as tf  # TensorFlow 深度学习框架
    import torch  # PyTorch 深度学习框架
    
    from transformers import BertConfig, BertModel  # Hugging Face 提供的 Bert 相关类
    from transformers.utils import logging  # Hugging Face 提供的日志功能
    
    logging.set_verbosity_info()  # 设置日志记录级别为信息
    logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器
    

def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
    # 获取 TensorFlow 检查点文件的绝对路径
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")  # 记录日志,显示转换的 TensorFlow 检查点路径
    
    # 从 TF 模型中加载权重
    init_vars = tf.train.list_variables(tf_path)  # 列出 TensorFlow 模型中的所有变量名和形状
    names = []  # 存储变量名
    arrays = []  # 存储加载的变量数组
    layer_depth = []  # 存储每个变量名的层级深度
    
    # 遍历每个变量名和形状
    for full_name, shape in init_vars:
        # logger.info(f"Loading TF weight {name} with shape {shape}")
        name = full_name.split("/")  # 按斜杠分割变量名,获取各级名称
        
        # 如果是特定的非模型层或优化层,则跳过加载
        if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
            logger.info(f"Skipping non-model layer {full_name}")
            continue
        if "optimizer" in full_name:
            logger.info(f"Skipping optimization layer {full_name}")
            continue
        if name[0] == "model":
            # 忽略初始的 'model' 层级
            name = name[1:]
        
        # 计算变量名的层级深度
        depth = 0
        for _name in name:
            if _name.startswith("layer_with_weights"):
                depth += 1
            else:
                break
        layer_depth.append(depth)
        
        # 加载变量数据
        array = tf.train.load_variable(tf_path, full_name)
        names.append("/".join(name))  # 将分割后的名称重新连接为字符串形式
        arrays.append(array)  # 将加载的变量数组添加到列表中
    
    logger.info(f"Read a total of {len(arrays):,} layers")  # 记录日志,显示总共加载了多少层变量

    # 进行完整性检查
    # 检查层深度列表中是否存在不同的深度值,如果存在则抛出数值错误异常
    if len(set(layer_depth)) != 1:
        raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
    
    # 将层深度列表转换为集合去重,然后转换回列表,并获取唯一的深度值
    layer_depth = list(set(layer_depth))[0]
    
    # 检查模型的层深度是否为1,如果不是则抛出数值错误异常,说明模型包含了除了嵌入/编码器层之外的其他层
    if layer_depth != 1:
        raise ValueError(
            "The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP"
            " heads."
        )

    # 输出日志信息,表明开始转换权重
    logger.info("Converting weights...")
    
    # 返回已转换的模型对象
    return model
# 将 TensorFlow 2.x 的检查点文件转换为 PyTorch 模型的函数
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
    # 打印日志信息,加载基于指定配置文件的模型
    logger.info(f"Loading model based on config from {config_path}...")
    # 从 JSON 文件中加载配置信息
    config = BertConfig.from_json_file(config_path)
    # 根据配置创建 BertModel 实例
    model = BertModel(config)

    # 打印日志信息,加载 TensorFlow 2.x 检查点的权重
    logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
    # 调用函数加载 TensorFlow 2.x 检查点中的权重到 PyTorch 模型中
    load_tf2_weights_in_bert(model, tf_checkpoint_path, config)

    # 打印日志信息,保存 PyTorch 模型
    logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
    # 使用 PyTorch 的函数保存模型的状态字典到指定路径
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    # 创建参数解析器
    parser = argparse.ArgumentParser()
    # 添加命令行参数,指定 TensorFlow 2.x 检查点路径
    parser.add_argument(
        "--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path."
    )
    # 添加命令行参数,指定 BERT 模型的配置文件路径
    parser.add_argument(
        "--bert_config_file",
        type=str,
        required=True,
        help="The config json file corresponding to the BERT model. This specifies the model architecture.",
    )
    # 添加命令行参数,指定输出的 PyTorch 模型路径(包括文件名)
    parser.add_argument(
        "--pytorch_dump_path",
        type=str,
        required=True,
        help="Path to the output PyTorch model (must include filename).",
    )
    # 解析命令行参数
    args = parser.parse_args()
    # 调用转换函数,传入解析得到的参数
    convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

.\models\bert\convert_bert_original_tf_checkpoint_to_pytorch.py

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""


import argparse  # 导入用于处理命令行参数的模块

import torch  # 导入 PyTorch 库

from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert  # 导入转换所需的类和函数
from transformers.utils import logging  # 导入日志记录工具


logging.set_verbosity_info()  # 设置日志记录的详细程度为 info


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # 初始化一个 PyTorch 模型
    config = BertConfig.from_json_file(bert_config_file)
    print(f"Building PyTorch model from configuration: {config}")  # 打印模型配置信息
    model = BertForPreTraining(config)  # 使用配置创建 BertForPreTraining 模型对象

    # 从 TensorFlow checkpoint 中加载权重
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # 保存 PyTorch 模型
    print(f"Save PyTorch model to {pytorch_dump_path}")  # 打印保存路径信息
    torch.save(model.state_dict(), pytorch_dump_path)  # 将模型的状态字典保存到指定路径


if __name__ == "__main__":
    parser = argparse.ArgumentParser()  # 创建参数解析器对象
    # 必选参数
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )  # 添加 TensorFlow checkpoint 路径参数
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained BERT model. \n"
            "This specifies the model architecture."
        ),
    )  # 添加 BERT 配置文件路径参数
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )  # 添加输出 PyTorch 模型路径参数
    args = parser.parse_args()  # 解析命令行参数
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)  # 调用转换函数并传入参数

.\models\bert\convert_bert_pytorch_checkpoint_to_original_tf.py

# 设置 Python 文件的编码格式为 UTF-8
# Copyright 2018 The HuggingFace Inc. team.
# 声明脚本使用 Apache 2.0 版本许可协议,详见链接
# 只有符合许可协议的情况下才能使用本文件
# 请访问上述链接获取详细信息
# 除非法律另有规定或书面同意,否则不得使用此文件
# 此文件按原样发布,不附带任何形式的保证或条件
# 详见许可协议以了解更多信息

"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
# 导入需要的库和模块
import argparse  # 导入命令行参数解析模块
import os  # 导入操作系统相关功能模块

import numpy as np  # 导入数值计算库 NumPy
import tensorflow as tf  # 导入 TensorFlow 深度学习框架
import torch  # 导入 PyTorch 深度学习框架

from transformers import BertModel  # 从 transformers 库中导入 BertModel 类


def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
    """
    Args:
        model: BertModel Pytorch model instance to be converted
        ckpt_dir: Tensorflow model directory
        model_name: model name

    Currently supported HF models:

        - Y BertModel
        - N BertForMaskedLM
        - N BertForPreTraining
        - N BertForMultipleChoice
        - N BertForNextSentencePrediction
        - N BertForSequenceClassification
        - N BertForQuestionAnswering
    """
    # 定义需要转置的张量名称列表
    tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")

    # 定义变量名称映射规则列表
    var_map = (
        ("layer.", "layer_"),
        ("word_embeddings.weight", "word_embeddings"),
        ("position_embeddings.weight", "position_embeddings"),
        ("token_type_embeddings.weight", "token_type_embeddings"),
        (".", "/"),
        ("LayerNorm/weight", "LayerNorm/gamma"),
        ("LayerNorm/bias", "LayerNorm/beta"),
        ("weight", "kernel"),
    )

    # 如果指定的 TensorFlow 模型目录不存在,则创建该目录
    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    # 获取 PyTorch 模型的状态字典
    state_dict = model.state_dict()

    def to_tf_var_name(name: str):
        # 根据变量映射规则将 PyTorch 变量名转换为 TensorFlow 变量名
        for patt, repl in iter(var_map):
            name = name.replace(patt, repl)
        return f"bert/{name}"

    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
        # 根据张量的数据类型和形状,在 TensorFlow 中创建新的变量
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
        session.run(tf_var)
        return tf_var

    # 重置 TensorFlow 默认计算图
    tf.reset_default_graph()
    # 使用 TensorFlow 创建一个会话(Session),并将其命名为 session
    with tf.Session() as session:
        # 遍历 state_dict 中的每一个变量名
        for var_name in state_dict:
            # 将变量名 var_name 转换为 TensorFlow 的变量名 tf_name
            tf_name = to_tf_var_name(var_name)
            # 将 PyTorch 张量转换为 NumPy 数组,存储在 torch_tensor 中
            torch_tensor = state_dict[var_name].numpy()
            # 如果 var_name 中包含在 tensors_to_transpose 中的任何字符串,则对 torch_tensor 进行转置操作
            if any(x in var_name for x in tensors_to_transpose):
                torch_tensor = torch_tensor.T
            # 使用 create_tf_var 函数在 TensorFlow 中创建变量 tf_var,使用 session 进行管理
            tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
            # 将 torch_tensor 转换为 tf_var 的数据类型,并赋值给 tf_var
            tf_var.assign(tf.cast(torch_tensor, tf_var.dtype))
            # 在 TensorFlow 中运行 tf_var,将结果存储在 tf_weight 中
            tf_weight = session.run(tf_var)
            # 打印成功创建的 TensorFlow 变量 tf_name 和其与 torch_tensor 是否全部接近的比较结果
            print(f"Successfully created {tf_name}: {np.allclose(tf_weight, torch_tensor)}")

        # 使用 tf.train.Saver() 创建一个 Saver 对象,保存所有可训练变量的状态
        saver = tf.train.Saver(tf.trainable_variables())
        # 将这些变量的状态保存到指定的文件路径下,文件名为 model_name 替换 '-' 为 '_' 后加上 '.ckpt' 后缀
        saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
# 主函数,程序的入口点
def main(raw_args=None):
    # 创建参数解析器对象
    parser = argparse.ArgumentParser()
    # 添加解析器的命令行参数:模型名称,必须提供,用于指定模型的名称,如 google-bert/bert-base-uncased
    parser.add_argument("--model_name", type=str, required=True, help="model name e.g. google-bert/bert-base-uncased")
    # 添加解析器的命令行参数:缓存目录,可选,默认为 None,用于指定包含 PyTorch 模型的目录
    parser.add_argument(
        "--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
    )
    # 添加解析器的命令行参数:PyTorch 模型路径,必须提供,用于指定 PyTorch 模型的路径,如 /path/to/<pytorch-model-name>.bin
    parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
    # 添加解析器的命令行参数:TensorFlow 缓存目录,必须提供,用于指定保存 TensorFlow 模型的目录
    parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
    # 解析命令行参数,将结果存储在 args 中
    args = parser.parse_args(raw_args)

    # 从预训练模型中加载 BertModel 对象
    model = BertModel.from_pretrained(
        pretrained_model_name_or_path=args.model_name,  # 使用命令行参数指定的模型名称或路径
        state_dict=torch.load(args.pytorch_model_path),  # 加载指定路径下的 PyTorch 模型参数
        cache_dir=args.cache_dir,  # 使用命令行参数指定的缓存目录
    )

    # 将 PyTorch 模型转换为 TensorFlow 格式的检查点文件
    convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)


if __name__ == "__main__":
    main()

.\models\bert\convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py

# 打印加载基于给定配置文件的模型信息
print(f"Loading model based on config from {config_path}...")
# 从指定路径加载BERT配置信息
config = BertConfig.from_json_file(config_path)
# 使用加载的配置信息创建一个BertForMaskedLM模型实例
model = BertForMaskedLM(config)

# 以下是待继续完善的部分,涉及模型的各个层次
    # 遍历每个隐藏层的索引,从0到config.num_hidden_layers-1
    for layer_index in range(0, config.num_hidden_layers):
        # 获取当前层的BertLayer对象
        layer: BertLayer = model.bert.encoder.layer[layer_index]

        # Self-attention部分
        # 获取当前层的self-attention模块
        self_attn: BertSelfAttention = layer.attention.self

        # 设置self-attention中的query权重数据
        self_attn.query.weight.data = get_encoder_attention_layer_array(
            layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape
        )
        # 设置self-attention中的query偏置数据
        self_attn.query.bias.data = get_encoder_attention_layer_array(
            layer_index, "_query_dense/bias", self_attn.query.bias.data.shape
        )
        # 设置self-attention中的key权重数据
        self_attn.key.weight.data = get_encoder_attention_layer_array(
            layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape
        )
        # 设置self-attention中的key偏置数据
        self_attn.key.bias.data = get_encoder_attention_layer_array(
            layer_index, "_key_dense/bias", self_attn.key.bias.data.shape
        )
        # 设置self-attention中的value权重数据
        self_attn.value.weight.data = get_encoder_attention_layer_array(
            layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape
        )
        # 设置self-attention中的value偏置数据
        self_attn.value.bias.data = get_encoder_attention_layer_array(
            layer_index, "_value_dense/bias", self_attn.value.bias.data.shape
        )

        # Self-attention输出部分
        # 获取self-attention输出层对象
        self_output: BertSelfOutput = layer.attention.output

        # 设置self-attention输出层中dense层的权重数据
        self_output.dense.weight.data = get_encoder_attention_layer_array(
            layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape
        )
        # 设置self-attention输出层中dense层的偏置数据
        self_output.dense.bias.data = get_encoder_attention_layer_array(
            layer_index, "_output_dense/bias", self_output.dense.bias.data.shape
        )

        # 设置self-attention输出层中LayerNorm的权重数据
        self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma")
        # 设置self-attention输出层中LayerNorm的偏置数据
        self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta")

        # Intermediate部分
        # 获取当前层的Intermediate对象
        intermediate: BertIntermediate = layer.intermediate

        # 设置Intermediate层中dense层的权重数据
        intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel")
        # 设置Intermediate层中dense层的偏置数据
        intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias")

        # Output部分
        # 获取当前层的Output对象
        bert_output: BertOutput = layer.output

        # 设置Output层中dense层的权重数据
        bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel")
        # 设置Output层中dense层的偏置数据
        bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias")

        # 设置Output层中LayerNorm的权重数据
        bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma")
        # 设置Output层中LayerNorm的偏置数据
        bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta")

    # Embeddings部分
    # 设置BERT模型的位置嵌入权重数据
    model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings")
    # 设置BERT模型的token类型嵌入权重数据
    model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings")
    # 设置BERT模型的嵌入层LayerNorm的权重数据
    model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma")
    # 设置BERT模型的嵌入层LayerNorm的偏置数据为从文件中获取的编码器数组
    model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta")

    # LM头部
    lm_head = model.cls.predictions.transform

    # 设置LM头部中dense层的权重数据为从文件中获取的masked LM数组
    lm_head.dense.weight.data = get_masked_lm_array("dense/kernel")
    # 设置LM头部中dense层的偏置数据为从文件中获取的masked LM数组
    lm_head.dense.bias.data = get_masked_lm_array("dense/bias")

    # 设置LM头部中LayerNorm层的权重数据为从文件中获取的masked LM数组
    lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma")
    # 设置LM头部中LayerNorm层的偏置数据为从文件中获取的masked LM数组
    lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta")

    # 设置BERT模型的嵌入层中词嵌入权重数据为从文件中获取的masked LM数组
    model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table")

    # 设置BERT模型的池化层为一个新的BertPooler对象,根据配置信息
    model.bert.pooler = BertPooler(config=config)
    # 设置BERT模型的池化层dense层的权重数据为从文件中获取的编码器数组
    model.bert.pooler.dense.weight.data: BertPooler = get_encoder_array("_pooler_layer/kernel")
    # 设置BERT模型的池化层dense层的偏置数据为从文件中获取的编码器数组
    model.bert.pooler.dense.bias.data: BertPooler = get_encoder_array("_pooler_layer/bias")

    # 导出最终的模型到指定的PyTorch保存路径
    model.save_pretrained(pytorch_dump_path)

    # 集成测试 - 应该能够无错误加载 ;)
    # 从指定的PyTorch保存路径加载一个新的BertForMaskedLM模型
    new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path)
    # 打印新模型的评估结果
    print(new_model.eval())

    # 打印信息:模型转换成功完成!
    print("Model conversion was done successfully!")
if __name__ == "__main__":
    # 当该模块被直接运行时执行以下代码
    parser = argparse.ArgumentParser()
    # 创建参数解析器对象
    parser.add_argument(
        "--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint path."
    )
    # 添加命令行参数:TensorFlow Token Dropping 检查点的路径
    parser.add_argument(
        "--bert_config_file",
        type=str,
        required=True,
        help="The config json file corresponding to the BERT model. This specifies the model architecture.",
    )
    # 添加命令行参数:BERT 模型配置文件的路径,指定了模型的架构
    parser.add_argument(
        "--pytorch_dump_path",
        type=str,
        required=True,
        help="Path to the output PyTorch model.",
    )
    # 添加命令行参数:PyTorch 模型输出路径
    args = parser.parse_args()
    # 解析命令行参数,并将其存储在 args 对象中
    convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
    # 调用函数 convert_checkpoint_to_pytorch,传入解析得到的参数

.\models\bert\modeling_bert.py

# coding=utf-8
# 版权声明,包括Google AI Language Team和HuggingFace Inc.的版权声明
# 版权声明,包括NVIDIA CORPORATION的版权声明
#
# 根据Apache许可证2.0版(“许可证”)许可使用本文件;
# 除非符合许可证,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件根据“原样”分发,
# 没有任何形式的担保或条件,包括但不限于明示或暗示的任何担保或条件。
# 有关详细信息,请参阅许可证。
"""PyTorch BERT模型。"""


import math
import os
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# 从本地库中导入一些函数和类
from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
# 从configuration_bert模块导入BertConfig类
from .configuration_bert import BertConfig

# 获取logger对象
logger = logging.get_logger(__name__)

# 以下是用于文档的检查点和配置信息
_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"

# Token分类任务的文档字符串信息
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
_TOKEN_CLASS_EXPECTED_OUTPUT = (
    "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
)
_TOKEN_CLASS_EXPECTED_LOSS = 0.01

# 问答任务的文档字符串信息
_CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2"
_QA_EXPECTED_OUTPUT = "'a nice puppet'"
_QA_EXPECTED_LOSS = 7.41
_QA_TARGET_START_INDEX = 14
_QA_TARGET_END_INDEX = 15

# 序列分类任务的文档字符串信息
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
_SEQ_CLASS_EXPECTED_LOSS = 0.01

# BERT预训练模型存档列表
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google-bert/bert-base-uncased",
    "google-bert/bert-large-uncased",
    "google-bert/bert-base-cased",
    "google-bert/bert-large-cased",
    "google-bert/bert-base-multilingual-uncased",
    "google-bert/bert-base-multilingual-cased",
    "google-bert/bert-base-chinese",
    "google-bert/bert-base-german-cased",
    "google-bert/bert-large-uncased-whole-word-masking",
]
    # 定义一个包含多个字符串的列表,每个字符串表示一个预训练的BERT模型的名称
    [
        "google-bert/bert-large-cased-whole-word-masking",
        "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad",
        "google-bert/bert-large-cased-whole-word-masking-finetuned-squad",
        "google-bert/bert-base-cased-finetuned-mrpc",
        "google-bert/bert-base-german-dbmdz-cased",
        "google-bert/bert-base-german-dbmdz-uncased",
        "cl-tohoku/bert-base-japanese",
        "cl-tohoku/bert-base-japanese-whole-word-masking",
        "cl-tohoku/bert-base-japanese-char",
        "cl-tohoku/bert-base-japanese-char-whole-word-masking",
        "TurkuNLP/bert-base-finnish-cased-v1",
        "TurkuNLP/bert-base-finnish-uncased-v1",
        "wietsedv/bert-base-dutch-cased",
        # 查看所有BERT模型,请访问 https://huggingface.co/models?filter=bert
    ]
# 加载 TensorFlow 权重到 PyTorch 模型中
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re  # 导入正则表达式模块
        import numpy as np  # 导入 NumPy 模块
        import tensorflow as tf  # 导入 TensorFlow 模块
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    
    tf_path = os.path.abspath(tf_checkpoint_path)  # 获取 TensorFlow checkpoint 文件的绝对路径
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")  # 记录日志,显示转换的 TensorFlow checkpoint 路径
    
    # 从 TF 模型中加载权重
    init_vars = tf.train.list_variables(tf_path)  # 获取 TensorFlow checkpoint 中的所有变量名和形状
    names = []
    arrays = []
    
    # 遍历初始化变量,加载变量的值
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")  # 记录日志,显示加载的 TF 权重名和形状
        array = tf.train.load_variable(tf_path, name)  # 加载 TensorFlow checkpoint 中指定变量的值
        names.append(name)
        arrays.append(array)

    # 将加载的 TensorFlow 权重映射到 PyTorch 模型中的相应位置
    for name, array in zip(names, arrays):
        name = name.split("/")
        
        # 跳过不需要加载的变量,如优化器中的动量信息和全局步数等
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")  # 记录日志,显示跳过的变量名
            continue
        
        pointer = model
        
        # 根据变量名将 TensorFlow 权重映射到 PyTorch 模型的对应位置
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
                
            # 根据不同的变量名前缀,映射到 PyTorch 模型的不同部分
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")  # 记录日志,显示跳过的变量名
                    continue
                    
            # 如果变量名包含下划线加数字的格式,则按数字索引进一步定位
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        
        # 如果变量名以 "_embeddings" 结尾,则映射到 PyTorch 模型的权重部分
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)  # 转置数组,用于卷积核权重
        
        # 检查加载的权重形状与 PyTorch 模型对应部分的形状是否匹配
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            e.args += (pointer.shape, array.shape)
            raise
        
        logger.info(f"Initialize PyTorch weight {name}")  # 记录日志,显示初始化的 PyTorch 权重名
        pointer.data = torch.from_numpy(array)  # 将 NumPy 数组转换为 PyTorch 张量赋给指针
        
    return model  # 返回加载了 TensorFlow 权重的 PyTorch 模型
    """Construct the embeddings from word, position and token_type embeddings."""

    # 初始化函数,用于构建包含单词、位置和token类型嵌入的模型
    def __init__(self, config):
        super().__init__()
        
        # 创建一个单词嵌入层,用于将单词索引映射到隐藏状态空间
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        
        # 创建一个位置嵌入层,用于将位置索引映射到隐藏状态空间
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        
        # 创建一个token类型嵌入层,用于将token类型索引映射到隐藏状态空间
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # 使用LayerNorm来标准化隐藏状态空间中的每个特征向量,以增强模型的训练效果
        # 注意,这里LayerNorm没有采用蛇形命名,是为了兼容TensorFlow模型变量名,以便加载任何TensorFlow检查点文件
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        
        # Dropout层,用于在训练过程中随机丢弃部分神经元,以防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        
        # 确定位置嵌入的类型,默认为"absolute",但可以在配置中指定
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        
        # 注册缓冲区,用于存储位置索引,在序列化时可以导出
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        
        # 注册缓冲区,用于存储token类型索引,在序列化时可以导出
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    # 前向传播函数,定义了模型的正向运算过程
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
        ...
    # 定义函数forward,接受input_ids、inputs_embeds、token_type_ids、position_ids和past_key_values_length作为输入,返回torch.Tensor类型的输出
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        # 如果传入了input_ids,则获取其尺寸作为input_shape
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            # 否则获取inputs_embeds的所有维度尺寸除了最后一个维度
            input_shape = inputs_embeds.size()[:-1]

        # 获取序列的长度seq_length,即input_shape的第二个维度
        seq_length = input_shape[1]

        # 如果未提供position_ids,则从self.position_ids中选择对应序列长度的部分作为position_ids
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # 设置token_type_ids为构造函数中注册的缓冲区,该缓冲区通常为全零,用于在不传递token_type_ids的情况下追踪模型,解决问题#5664
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # 从self.token_type_ids中获取缓冲的token_type_ids,并扩展到与输入形状相匹配
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                # 如果self中没有token_type_ids属性,则创建全零的token_type_ids张量,dtype为torch.long,设备为self.position_ids的设备
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        # 如果inputs_embeds为None,则使用self.word_embeddings将input_ids转换为嵌入向量
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        
        # 根据token_type_ids获取对应的token_type_embeddings
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # 计算嵌入向量embeddings,将inputs_embeds和token_type_embeddings相加
        embeddings = inputs_embeds + token_type_embeddings
        
        # 如果位置嵌入类型为"absolute",则从self.position_embeddings获取对应的位置嵌入并加到embeddings中
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        
        # 对embeddings进行LayerNorm处理
        embeddings = self.LayerNorm(embeddings)
        
        # 对LayerNorm后的embeddings进行dropout处理
        embeddings = self.dropout(embeddings)
        
        # 返回处理后的embeddings作为输出
        return embeddings
class BertSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # 检查隐藏层大小是否能够被注意力头数整除,或者是否具有嵌入大小属性
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        # 初始化注意力头数和每个头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 初始化查询、键、值的线性变换层
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # 初始化 dropout 层
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        
        # 设置位置嵌入类型,默认为绝对位置编码
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        # 如果位置嵌入类型是相对位置编码之一,则初始化距离嵌入层
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        # 是否为解码器
        self.is_decoder = config.is_decoder

    # 将输入张量重塑为注意力分数的形状
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    # 前向传播函数定义
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,


class BertSelfOutput(nn.Module):
def init(self, config):
super().init()
# 初始化全连接层、LayerNorm 和 dropout
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

# 前向传播函数定义
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
    # 全连接层
    hidden_states = self.dense(hidden_states)
    # dropout
    hidden_states = self.dropout(hidden_states)
    # LayerNorm
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
    return hidden_states
# 初始化方法,接受配置参数和位置嵌入类型,调用父类的初始化方法
def __init__(self, config, position_embedding_type=None):
    super().__init__()
    # 创建 BertSelfAttention 对象,并保存在 self.self 属性中
    self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
    # 创建 BertSelfOutput 对象,并保存在 self.output 属性中
    self.output = BertSelfOutput(config)
    # 初始化一个空集合,用于存储被剪枝的注意力头索引
    self.pruned_heads = set()

# 剪枝注意力头的方法
def prune_heads(self, heads):
    # 如果传入的头索引集合为空,直接返回
    if len(heads) == 0:
        return
    # 调用函数找到可剪枝的头部索引和对应的位置索引
    heads, index = find_pruneable_heads_and_indices(
        heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
    )

    # 对 BertSelfAttention 中的 query、key、value 线性层进行剪枝操作
    self.self.query = prune_linear_layer(self.self.query, index)
    self.self.key = prune_linear_layer(self.self.key, index)
    self.self.value = prune_linear_layer(self.self.value, index)
    # 对 BertSelfOutput 中的 dense 线性层进行剪枝操作
    self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

    # 更新剩余的注意力头数目和总头部大小
    self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
    self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
    # 将剪枝的头部索引添加到 pruned_heads 集合中
    self.pruned_heads = self.pruned_heads.union(heads)

# 前向传播方法,接收多个输入张量和可选参数,返回一个张量元组
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
    # 调用 BertSelfAttention 的 forward 方法进行自注意力计算
    self_outputs = self.self(
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        past_key_value,
        output_attentions,
    )
    # 将自注意力输出结果和原始隐藏状态传入 BertSelfOutput 对象进行输出计算
    attention_output = self.output(self_outputs[0], hidden_states)
    # 如果需要输出注意力权重,则将注意力权重添加到输出中
    outputs = (attention_output,) + self_outputs[1:]  # 如果需要,添加注意力权重
    return outputs

BertLayer 类的定义,继承自 nn.Module,表示这是一个神经网络模块

class BertLayer(nn.Module):
# 初始化方法,接受一个 config 参数
def init(self, config):
super().init() # 调用父类 nn.Module 的初始化方法
# 设置用于前向传播中的块大小
self.chunk_size_feed_forward = config.chunk_size_feed_forward
# 序列长度的维度,通常为 1
self.seq_len_dim = 1
# BertAttention 类的实例化,使用 config 参数进行初始化
self.attention = BertAttention(config)
# 是否作为解码器使用的标志
self.is_decoder = config.is_decoder
# 是否添加跨注意力机制的标志
self.add_cross_attention = config.add_cross_attention
# 如果添加了跨注意力机制
if self.add_cross_attention:
# 如果不是解码器,则抛出异常
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
# 使用绝对位置编码类型,实例化 BertAttention 类
self.crossattention = BertAttention(config, position_embedding_type="absolute")
# BertIntermediate 类的实例化,使用 config 参数进行初始化
self.intermediate = BertIntermediate(config)
# BertOutput 类的实例化,使用 config 参数进行初始化
self.output = BertOutput(config)

# 前向传播方法,接收多个输入参数并返回一个 torch.Tensor 对象
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
    ):
    # 使用 BertAttention 类处理隐藏状态,根据需要传入不同的参数
    hidden_states = self.attention(
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        past_key_value,
        output_attentions,
    )
    # 如果添加了跨注意力机制,使用 crossattention 处理隐藏状态
    if self.add_cross_attention:
        hidden_states = self.crossattention(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
    # 使用 BertIntermediate 类处理中间状态,传入隐藏状态
    hidden_states = self.intermediate(hidden_states)
    # 使用 BertOutput 类处理输出,传入中间状态和输入张量
    hidden_states = self.output(hidden_states, input_tensor=hidden_states)
    # 返回处理后的隐藏状态张量
    return hidden_states
) -> Tuple[torch.Tensor]:  
    # 函数签名说明该函数返回一个包含torch.Tensor的元组
    # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
    # 如果存在过去的key/value缓存,只选择其前两个元素,否则为None
    self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
    # 调用self.attention方法进行自注意力计算
    self_attention_outputs = self.attention(
        hidden_states,
        attention_mask,
        head_mask,
        output_attentions=output_attentions,
        past_key_value=self_attn_past_key_value,
    )
    # 获取自注意力计算的输出
    attention_output = self_attention_outputs[0]

    # if decoder, the last output is tuple of self-attn cache
    # 如果是解码器,最后一个输出是自注意力缓存的元组
    if self.is_decoder:
        # 去除第一个和最后一个元素,因为它们是self_attention_outputs中的self-attention结果和cross-attention结果
        outputs = self_attention_outputs[1:-1]
        # 最后一个元素是当前时刻的key/value缓存
        present_key_value = self_attention_outputs[-1]
    else:
        # 如果不是解码器,则输出包括self-attention结果(如果输出注意力权重的话)
        outputs = self_attention_outputs[1:]  # 如果输出注意力权重,添加self-attention
    

    cross_attn_present_key_value = None
    # 如果是解码器并且有编码器的隐藏状态
    if self.is_decoder and encoder_hidden_states is not None:
        # 如果没有crossattention属性,抛出值错误
        if not hasattr(self, "crossattention"):
            raise ValueError(
                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                " by setting `config.add_cross_attention=True`"
            )

        # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
        # 如果存在过去的key/value缓存,选择其倒数第二个和最后一个元素,否则为None
        cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
        # 调用self.crossattention方法进行跨注意力计算
        cross_attention_outputs = self.crossattention(
            attention_output,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            cross_attn_past_key_value,
            output_attentions,
        )
        # 获取跨注意力计算的输出
        attention_output = cross_attention_outputs[0]
        # 添加cross-attention结果到outputs中,去除第一个和最后一个元素
        outputs = outputs + cross_attention_outputs[1:-1]

        # add cross-attn cache to positions 3,4 of present_key_value tuple
        # 将cross-attn缓存添加到present_key_value元组的倒数第二个和最后一个位置
        cross_attn_present_key_value = cross_attention_outputs[-1]
        present_key_value = present_key_value + cross_attn_present_key_value

    # 应用分块技术对attention_output应用self.feed_forward_chunk方法
    layer_output = apply_chunking_to_forward(
        self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
    )
    # 将layer_output作为第一个元素添加到outputs元组中
    outputs = (layer_output,) + outputs

    # if decoder, return the attn key/values as the last output
    # 如果是解码器,将attn key/values作为最后一个输出返回
    if self.is_decoder:
        outputs = outputs + (present_key_value,)

    # 返回outputs作为函数的输出结果
    return outputs

def feed_forward_chunk(self, attention_output):
    # 将attention_output作为输入,首先应用self.intermediate方法
    intermediate_output = self.intermediate(attention_output)
    # 然后将中间输出作为输入,应用self.output方法
    layer_output = self.output(intermediate_output, attention_output)
    # 返回最终的层输出结果
    return layer_output

定义一个名为 BertEncoder 的类,继承自 nn.Module 类

class BertEncoder(nn.Module):
# 初始化方法,接收一个 config 参数
def init(self, config):
# 调用父类的初始化方法
super().init()
# 将传入的 config 参数保存在实例变量 self.config 中
self.config = config
# 创建一个 nn.ModuleList 对象 self.layer,其中包含 config.num_hidden_layers 个 BertLayer 对象
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
# 初始化一个标志变量 gradient_checkpointing,并设置为 False
self.gradient_checkpointing = False

# 前向传播方法定义
def forward(
    self,
    hidden_states: torch.Tensor,  # 输入的隐藏状态张量
    attention_mask: Optional[torch.FloatTensor] = None,  # 可选的注意力掩码张量
    head_mask: Optional[torch.FloatTensor] = None,  # 可选的头部掩码张量
    encoder_hidden_states: Optional[torch.FloatTensor] = None,  # 可选的编码器隐藏状态张量
    encoder_attention_mask: Optional[torch.FloatTensor] = None,  # 可选的编码器注意力掩码张量
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,  # 可选的过去的键-值对
    use_cache: Optional[bool] = None,  # 可选的缓存标志
    output_attentions: Optional[bool] = False,  # 可选的输出注意力张量标志,默认为 False
    output_hidden_states: Optional[bool] = False,  # 可选的输出隐藏状态标志,默认为 False
    return_dict: Optional[bool] = True,  # 可选的返回字典标志,默认为 True
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
    # 初始化空的所有隐藏状态列表,如果不需要输出隐藏状态则为 None
    all_hidden_states = () if output_hidden_states else None
    # 初始化空的所有自注意力权重列表,如果不需要输出注意力权重则为 None
    all_self_attentions = () if output_attentions else None
    # 初始化空的所有交叉注意力权重列表,如果不需要输出交叉注意力权重或模型不支持则为 None
    all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

    # 如果启用了梯度检查点且处于训练模式下
    if self.gradient_checkpointing and self.training:
        # 如果设置了 use_cache=True,发出警告并设置 use_cache=False
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # 如果需要使用缓存,初始化下一个解码器缓存为一个空元组
    next_decoder_cache = () if use_cache else None
    # 遍历每一层解码器层
    for i, layer_module in enumerate(self.layer):
        # 如果需要输出隐藏状态,将当前隐藏状态加入到所有隐藏状态列表中
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 获取当前层的头部掩码,如果未提供则为 None
        layer_head_mask = head_mask[i] if head_mask is not None else None
        # 获取先前的键值对,如果未提供则为 None
        past_key_value = past_key_values[i] if past_key_values is not None else None

        # 如果启用了梯度检查点且处于训练模式下
        if self.gradient_checkpointing and self.training:
            # 使用梯度检查点函数计算当前层的输出
            layer_outputs = self._gradient_checkpointing_func(
                layer_module.__call__,
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
            )
        else:
            # 否则直接调用当前层模块计算输出
            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
            )

        # 更新隐藏状态为当前层的输出的第一个元素
        hidden_states = layer_outputs[0]
        # 如果使用缓存,将当前层的缓存状态添加到下一个解码器缓存中
        if use_cache:
            next_decoder_cache += (layer_outputs[-1],)
        # 如果需要输出注意力权重,将当前层的自注意力权重加入到所有自注意力列表中
        if output_attentions:
            all_self_attentions = all_self_attentions + (layer_outputs[1],)
            # 如果模型支持并需要输出交叉注意力,将当前层的交叉注意力加入到所有交叉注意力列表中
            if self.config.add_cross_attention:
                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

    # 如果需要输出隐藏状态,将最终隐藏状态加入到所有隐藏状态列表中
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    # 如果不需要返回字典形式的输出,返回一个元组
    if not return_dict:
        return tuple(
            v
            for v in [
                hidden_states,
                next_decoder_cache,
                all_hidden_states,
                all_self_attentions,
                all_cross_attentions,
            ]
            if v is not None
        )
    # 否则返回 BaseModelOutputWithPastAndCrossAttentions 对象
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=next_decoder_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
        cross_attentions=all_cross_attentions,
    )

class BertPooler(nn.Module):
def init(self, config):
super().init()
self.dense = nn.Linear(config.hidden_size, config.hidden_size) # 创建一个全连接层,输入输出大小相同
self.activation = nn.Tanh() # 创建一个tanh激活函数实例

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # 我们通过简单地取对应于第一个标记的隐藏状态来“汇聚”模型。
    first_token_tensor = hidden_states[:, 0]  # 选择隐藏状态张量的第一个标记
    pooled_output = self.dense(first_token_tensor)  # 将第一个标记的隐藏状态传入全连接层
    pooled_output = self.activation(pooled_output)  # 使用tanh激活函数进行激活
    return pooled_output

class BertPredictionHeadTransform(nn.Module):
def init(self, config):
super().init()
self.dense = nn.Linear(config.hidden_size, config.hidden_size) # 创建一个全连接层,输入输出大小相同
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act] # 如果配置中的隐藏激活函数是字符串,则使用预定义的激活函数映射
else:
self.transform_act_fn = config.hidden_act # 否则直接使用配置中的激活函数
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # 创建一个LayerNorm层

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    hidden_states = self.dense(hidden_states)  # 将隐藏状态传入全连接层
    hidden_states = self.transform_act_fn(hidden_states)  # 应用预定义或配置中的激活函数
    hidden_states = self.LayerNorm(hidden_states)  # 应用LayerNorm
    return hidden_states

class BertLMPredictionHead(nn.Module):
def init(self, config):
super().init()
self.transform = BertPredictionHeadTransform(config) # 创建一个预测头变换器

    # 输出权重与输入嵌入相同,但每个标记都有一个仅输出的偏置项。
    self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # 创建一个线性层,无偏置

    self.bias = nn.Parameter(torch.zeros(config.vocab_size))  # 创建一个偏置参数

    # 需要一个链接来确保偏置随 `resize_token_embeddings` 正确调整大小
    self.decoder.bias = self.bias  # 将偏置参数链接到解码器的偏置

def forward(self, hidden_states):
    hidden_states = self.transform(hidden_states)  # 应用变换器到隐藏状态
    hidden_states = self.decoder(hidden_states)  # 应用解码器到变换后的隐藏状态
    return hidden_states

class BertOnlyMLMHead(nn.Module):
def init(self, config):
super().init()
self.predictions = BertLMPredictionHead(config) # 创建一个MLM头部预测器

def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
    prediction_scores = self.predictions(sequence_output)  # 使用预测器进行序列输出的预测
    return prediction_scores

class BertOnlyNSPHead(nn.Module):
def init(self, config):
super().init()
self.seq_relationship = nn.Linear(config.hidden_size, 2) # 创建一个线性层,用于NSP头部

def forward(self, pooled_output):
    seq_relationship_score = self.seq_relationship(pooled_output)  # 计算汇聚输出的关系分数
    return seq_relationship_score

class BertPreTrainingHeads(nn.Module):
def init(self, config):
super().init()
self.predictions = BertLMPredictionHead(config) # 创建一个预测头部
self.seq_relationship = nn.Linear(config.hidden_size, 2) # 创建一个线性层,用于NSP头部
# 定义一个方法 forward,接收 sequence_outputpooled_output 作为参数
def forward(self, sequence_output, pooled_output):
# 调用 self.predictions 方法,传入 sequence_output 参数,返回预测分数
prediction_scores = self.predictions(sequence_output)
# 调用 self.seq_relationship 方法,传入 pooled_output 参数,返回序列关系分数
seq_relationship_score = self.seq_relationship(pooled_output)
# 返回预测分数和序列关系分数作为结果
return prediction_scores, seq_relationship_score

定义一个名为 BertPreTrainedModel 的类,继承自 PreTrainedModel 类

class BertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""

# 指定配置类为 BertConfig
config_class = BertConfig
# 指定加载 TensorFlow 权重的函数为 load_tf_weights_in_bert
load_tf_weights = load_tf_weights_in_bert
# 指定基础模型的前缀为 "bert"
base_model_prefix = "bert"
# 支持梯度检查点
supports_gradient_checkpointing = True

# 初始化权重的函数,根据模块类型不同进行初始化
def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, nn.Linear):
        # 对线性层的权重使用正态分布初始化,标准差为 self.config.initializer_range
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        # 如果存在偏置项,则将其初始化为零
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        # 对嵌入层的权重使用正态分布初始化,标准差为 self.config.initializer_range
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        # 如果指定了填充索引,则将对应的权重初始化为零
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        # 对 LayerNorm 层的偏置项初始化为零
        module.bias.data.zero_()
        # 对 LayerNorm 层的权重初始化为 1.0
        module.weight.data.fill_(1.0)

使用 dataclass 装饰器定义 BertForPreTrainingOutput 类,继承自 ModelOutput 类

@dataclass
class BertForPreTrainingOutput(ModelOutput):
"""
Output type of [BertForPreTraining].

Args:
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
        before SoftMax).
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
"""

# 可选项:损失,当提供 labels 时返回,torch.FloatTensor,形状为 (1,)
loss: Optional[torch.FloatTensor] = None
# 预测 logits:语言建模头部的预测分数,形状为 (batch_size, sequence_length, config.vocab_size)
prediction_logits: torch.FloatTensor = None
# 序列关系 logits:下一个序列预测头部的预测分数,形状为 (batch_size, 2)
seq_relationship_logits: torch.FloatTensor = None
# 定义一个可选的变量 `hidden_states`,类型为包含单个 `torch.FloatTensor` 的元组或空值
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
# 定义一个可选的变量 `attentions`,类型为包含单个 `torch.FloatTensor` 的元组或空值
attentions: Optional[Tuple[torch.FloatTensor]] = None

BERT_START_DOCSTRING 是一个原始的文档字符串,用于描述 BERT 模型的基本信息和用法。

这个模型继承自 PreTrainedModel,可以查看其父类文档了解通用方法,如下载、保存、调整输入嵌入大小、剪枝等。

同时,这个模型也是 PyTorch 的 torch.nn.Module 子类,可以像普通的 PyTorch 模块一样使用,相关用法和行为请参考 PyTorch 文档。

BERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (torch.LongTensor of shape ({0})):
# 输入序列的标记索引,对应词汇表中的位置。

        # 可以使用 [`AutoTokenizer`] 获取这些索引。参见 [`PreTrainedTokenizer.encode`] 和
        # [`PreTrainedTokenizer.__call__`] 获取详细信息。

        # [什么是输入 ID?](../glossary#input-ids)
    attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
        # 遮盖掩码,用于在填充的标记索引上避免执行注意力操作。遮盖值在 `[0, 1]` 之间:

        # - 1 表示对应的标记是 **未遮盖的**,
        # - 0 表示对应的标记是 **遮盖的**。

        # [什么是注意力遮盖?](../glossary#attention-mask)
    token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
        # 分段标记索引,指示输入的第一部分和第二部分。索引选择在 `[0, 1]`:

        # - 0 对应 *句子 A* 的标记,
        # - 1 对应 *句子 B* 的标记。

        # [什么是标记类型 ID?](../glossary#token-type-ids)
    position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
        # 输入序列中每个标记的位置索引,在位置嵌入中使用。索引选择范围为 `[0, config.max_position_embeddings - 1]`。

        # [什么是位置 ID?](../glossary#position-ids)
    head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
        # 用于屏蔽自注意力模块中选定头部的掩码。掩码值在 `[0, 1]` 之间:

        # - 1 表示头部 **未被屏蔽**,
        # - 0 表示头部 **被屏蔽**。
    inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
        # 可选参数,可以直接传递嵌入表示,而不是传递 `input_ids`。如果希望更好地控制如何将 `input_ids` 索引转换为关联向量,
        # 这种方法非常有用,而不使用模型的内部嵌入查找矩阵。
    output_attentions (`bool`, *optional*):
        # 是否返回所有注意力层的注意力张量。返回的张量中有关于 `attentions` 的更多详细信息。
    output_hidden_states (`bool`, *optional*):
        # 是否返回所有层的隐藏状态。返回的张量中有关于 `hidden_states` 的更多详细信息。
    return_dict (`bool`, *optional*):
        # 是否返回 [`~utils.ModelOutput`] 而不是简单的元组。

"""
@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
)
class BertModel(BertPreTrainedModel):
"""
BertModel类,继承自BertPreTrainedModel,表示一个Bert模型,可以输出原始的隐藏状态,没有特定的输出头部。

The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

该模型可以作为编码器(仅具有自注意力)或解码器,如果是解码器,会在自注意力层之间添加交叉注意力层,遵循[Attention is
all you need](https://arxiv.org/abs/1706.03762) 中描述的架构,作者为Ashish Vaswani, Noam Shazeer, Niki Parmar,
Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser 和 Illia Polosukhin。

To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.

要作为解码器使用,模型需要用配置中的`is_decoder`参数初始化为`True`。在Seq2Seq模型中使用时,模型需要用`is_decoder`和
`add_cross_attention`两个参数初始化为`True`;然后预期在前向传递中作为输入的`encoder_hidden_states`。
"""

def __init__(self, config, add_pooling_layer=True):
    super().__init__(config)
    self.config = config

    # 初始化嵌入层
    self.embeddings = BertEmbeddings(config)
    # 初始化编码器层
    self.encoder = BertEncoder(config)

    # 如果需要添加汇聚层,则初始化汇聚层;否则为None
    self.pooler = BertPooler(config) if add_pooling_layer else None

    # 初始化权重并应用最终处理
    self.post_init()

def get_input_embeddings(self):
    # 返回嵌入层中的词嵌入
    return self.embeddings.word_embeddings

def set_input_embeddings(self, value):
    # 设置嵌入层的词嵌入为指定的值
    self.embeddings.word_embeddings = value

def _prune_heads(self, heads_to_prune):
    """
    Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
    class PreTrainedModel
    """
    # 对模型的注意力头进行修剪
    for layer, heads in heads_to_prune.items():
        self.encoder.layer[layer].attention.prune_heads(heads)

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=BaseModelOutputWithPoolingAndCrossAttentions,
    config_class=_CONFIG_FOR_DOC,
)
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    """
    前向传递函数,接受多个参数用于构建Bert模型的输入和控制输出格式。

    Args:
        input_ids (Optional[torch.Tensor], optional): 输入的token ID张量,默认为None。
        attention_mask (Optional[torch.Tensor], optional): 注意力掩码张量,默认为None。
        token_type_ids (Optional[torch.Tensor], optional): 分段类型ID张量,默认为None。
        position_ids (Optional[torch.Tensor], optional): 位置ID张量,默认为None。
        head_mask (Optional[torch.Tensor], optional): 头部掩码张量,默认为None。
        inputs_embeds (Optional[torch.Tensor], optional): 嵌入输入张量,默认为None。
        encoder_hidden_states (Optional[torch.Tensor], optional): 编码器隐藏状态张量,默认为None。
        encoder_attention_mask (Optional[torch.Tensor], optional): 编码器的注意力掩码张量,默认为None。
        past_key_values (Optional[List[torch.FloatTensor]], optional): 过去的键值对列表,默认为None。
        use_cache (Optional[bool], optional): 是否使用缓存,默认为None。
        output_attentions (Optional[bool], optional): 是否输出注意力,默认为None。
        output_hidden_states (Optional[bool], optional): 是否输出隐藏状态,默认为None。
        return_dict (Optional[bool], optional): 是否返回字典格式的输出,默认为None。

    Returns:
        根据参数不同可能返回不同形式的输出,详见具体参数说明。
    """
    # 实际的前向传递逻辑由具体的Bert模型实现,这里是接口定义和文档说明
    pass
Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
sentence prediction (classification)` head.

描述BERT模型的结构,包括两个预训练阶段添加的顶部头部:一个用于掩码语言建模,另一个用于下一句预测分类。

""",
BERT_START_DOCSTRING,

添加额外的文档字符串注释,并引用了 BERT_START_DOCSTRING 变量。

@add_start_docstrings(
"""Bert Model with a masked language modeling head on top for MLM fine-tuning.""", BERT_START_DOCSTRING
)
class BertForMaskedLM(BertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["position_ids"]

def __init__(self, config):
    super().__init__(config)

    if not config.is_decoder:
        logger.warning("If you want to use `BertForMaskedLM` as a standalone, add `is_decoder=True.`")

    self.bert = BertModel(config, add_pooling_layer=False)
    self.cls = BertOnlyMLMHead(config)

    # Initialize weights and apply final processing
    self.post_init()

def get_output_embeddings(self):
    """
    Retrieve the output embedding layer for predictions.
    """
    return self.cls.predictions.decoder

def set_output_embeddings(self, new_embeddings):
    """
    Set new output embeddings for the prediction layer.
    """
    self.cls.predictions.decoder = new_embeddings

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=CausalLMOutputWithCrossAttentions,
    config_class=_CONFIG_FOR_DOC,
)
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    next_sentence_label: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    """
    Forward pass for BertForMaskedLM model.
    """
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,  # 输入的token IDs序列,可选的Tensor类型
    attention_mask: Optional[torch.Tensor] = None,  # 注意力遮罩,可选的Tensor类型
    token_type_ids: Optional[torch.Tensor] = None,  # token类型IDs,可选的Tensor类型
    position_ids: Optional[torch.Tensor] = None,  # 位置IDs,可选的Tensor类型
    head_mask: Optional[torch.Tensor] = None,  # 头部遮罩,可选的Tensor类型
    inputs_embeds: Optional[torch.Tensor] = None,  # 输入的嵌入向量,可选的Tensor类型
    encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器隐藏状态,可选的Tensor类型
    encoder_attention_mask: Optional[torch.Tensor] = None,  # 编码器注意力遮罩,可选的Tensor类型
    labels: Optional[torch.Tensor] = None,  # 标签,可选的Tensor类型
    past_key_values: Optional[List[torch.Tensor]] = None,  # 过去的键值,可选的Tensor列表类型
    use_cache: Optional[bool] = None,  # 是否使用缓存,可选的布尔类型
    output_attentions: Optional[bool] = None,  # 是否输出注意力,可选的布尔类型
    output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态,可选的布尔类型
    return_dict: Optional[bool] = None,  # 是否返回字典,可选的布尔类型
):
    # 此方法定义了模型的前向传播逻辑,接收多种输入参数,并根据需要返回不同的输出

def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, use_cache=True, **model_kwargs
):
    input_shape = input_ids.shape  # 获取输入IDs的形状

    # 如果未提供注意力遮罩,则创建全为1的遮罩,保证所有token被处理
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_shape)

    # 如果传入了过去的键值(用于生成),则截断输入的token IDs
    if past_key_values is not None:
        past_length = past_key_values[0][0].shape[2]

        # 一些生成方法可能只传递最后一个输入ID
        if input_ids.shape[1] > past_length:
            remove_prefix_length = past_length
        else:
            # 默认行为:只保留最后一个ID
            remove_prefix_length = input_ids.shape[1] - 1

        input_ids = input_ids[:, remove_prefix_length:]

    # 返回准备好的输入字典,用于生成(或解码)阶段
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "use_cache": use_cache,
    }

def _reorder_cache(self, past_key_values, beam_idx):
    reordered_past = ()
    for layer_past in past_key_values:
        # 重排过去的键值,以适应beam搜索的索引顺序
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
        )
    return reordered_past

为带有顶部 语言建模 头部的 Bert 模型添加文档字符串

@add_start_docstrings("""Bert Model with a language modeling head on top.""", BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
# 定义绑定权重的键名列表
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

# 初始化函数,接受一个配置对象作为参数
def __init__(self, config):
    # 调用父类的初始化函数
    super().__init__(config)

    # 如果配置标记为解码器,则发出警告
    if config.is_decoder:
        logger.warning(
            "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
            "bi-directional self-attention."
        )

    # 创建 BertModel 实例,不包含池化层
    self.bert = BertModel(config, add_pooling_layer=False)
    # 创建 BertOnlyMLMHead 实例
    self.cls = BertOnlyMLMHead(config)

    # 执行额外的初始化操作,如权重初始化和最终处理
    self.post_init()

# 返回输出嵌入的函数
def get_output_embeddings(self):
    return self.cls.predictions.decoder

# 设置输出嵌入的函数,接受新的嵌入作为参数
def set_output_embeddings(self, new_embeddings):
    self.cls.predictions.decoder = new_embeddings

# 前向传播函数,接受多个输入参数,根据文档字符串描述了各参数的含义
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=MaskedLMOutput,
    config_class=_CONFIG_FOR_DOC,
    expected_output="'paris'",
    expected_loss=0.88,
)
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
        config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
        loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
    """

    # 确定是否使用返回字典形式
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # 将输入传递给BERT模型,获取输出
    outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    # 从BERT模型的输出中获取序列输出
    sequence_output = outputs[0]
    # 通过分类层获取预测得分
    prediction_scores = self.cls(sequence_output)

    masked_lm_loss = None
    # 如果提供了标签,则计算masked language modeling loss
    if labels is not None:
        loss_fct = CrossEntropyLoss()  # 交叉熵损失函数,用于计算损失
        masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

    # 如果不要求返回字典形式的输出
    if not return_dict:
        # 构造输出元组
        output = (prediction_scores,) + outputs[2:]
        return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

    # 返回MaskedLMOutput对象,封装了loss、logits、hidden_states和attentions
    return MaskedLMOutput(
        loss=masked_lm_loss,
        logits=prediction_scores,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
    input_shape = input_ids.shape
    effective_batch_size = input_shape[0]

    # 添加一个虚拟token,用于生成
    if self.config.pad_token_id is None:
        raise ValueError("The PAD token should be defined for generation")

    # 修改attention_mask,在末尾添加一个全零列
    attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
    # 创建全为pad_token_id的虚拟token
    dummy_token = torch.full(
        (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
    )
    # 将虚拟token拼接到input_ids末尾
    input_ids = torch.cat([input_ids, dummy_token], dim=1)

    # 返回包含输入信息的字典
    return {"input_ids": input_ids, "attention_mask": attention_mask}

定义一个带有“下一个句子预测(分类)”头部的 Bert 模型。

使用 BERT_START_DOCSTRING 和 BERT_START_DOCSTRING 描述模型的基本信息。

@add_start_docstrings(
"""Bert Model with a next sentence prediction (classification) head on top.""",
BERT_START_DOCSTRING,
)
class BertForNextSentencePrediction(BertPreTrainedModel):

# 初始化方法,接受一个配置对象 config 作为参数。
def __init__(self, config):
    super().__init__(config)
    
    # 初始化 BertModel,加载预训练的 BERT 模型。
    self.bert = BertModel(config)
    
    # 初始化 BertOnlyNSPHead,用于执行仅包含 NSP(Next Sentence Prediction)的任务。
    self.cls = BertOnlyNSPHead(config)

    # 执行额外的初始化步骤和最终处理。
    self.post_init()

# 前向传播函数,接受多个输入参数,包括 input_ids、attention_mask 等。
# 使用 BERT_INPUTS_DOCSTRING 描述输入的详细信息,格式为 batch_size, sequence_length。
# 使用 NextSentencePredictorOutput 类型描述输出,配置类为 _CONFIG_FOR_DOC。
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
    ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
        (see `input_ids` docstring). Indices should be in `[0, 1]`:

        - 0 indicates sequence B is a continuation of sequence A,
        - 1 indicates sequence B is a random sequence.

    Returns:
        Depending on `return_dict`:
        - If `return_dict` is `False`, returns a tuple with `next_sentence_loss` (if computed) and `seq_relationship_scores` and possibly additional hidden states.
        - If `return_dict` is `True`, returns a `NextSentencePredictorOutput` object containing `loss`, `logits`, `hidden_states`, and `attentions`.

    Example:

    ```
    >>> from transformers import AutoTokenizer, BertForNextSentencePrediction
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    >>> model = BertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

    >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
    >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

    >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
    >>> logits = outputs.logits
    >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
    ```
    """

    if "next_sentence_label" in kwargs:
        # 发出警告,`next_sentence_label` 参数已废弃,建议使用 `labels` 参数代替
        warnings.warn(
            "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
            " `labels` instead.",
            FutureWarning,
        )
        # 将 `next_sentence_label` 的值赋给 `labels` 变量,并从 `kwargs` 中移除该参数
        labels = kwargs.pop("next_sentence_label")

    # 根据 `return_dict` 是否为 `None` 确定是否使用配置中的返回字典设置
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # 使用 BERT 模型处理输入数据
    outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    # 从 BERT 输出中获取池化后的输出
    pooled_output = outputs[1]

    # 使用分类层处理池化输出,得到序列关系的分数
    seq_relationship_scores = self.cls(pooled_output)

    # 初始化下一个句子预测的损失为 None
    next_sentence_loss = None
    # 如果提供了 `labels` 参数,则计算下一个句子预测的交叉熵损失
    if labels is not None:
        loss_fct = CrossEntropyLoss()
        next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

    # 如果 `return_dict` 为 False,则返回一个包含 `next_sentence_loss` 和可能的其他隐藏状态的元组
    if not return_dict:
        output = (seq_relationship_scores,) + outputs[2:]
        return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

    # 如果 `return_dict` 为 True,则返回一个 `NextSentencePredictorOutput` 对象
    return NextSentencePredictorOutput(
        loss=next_sentence_loss,
        logits=seq_relationship_scores,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

使用装饰器添加开始文档字符串,描述了这是一个在Bert模型基础上增加了顶部序列分类/回归头的转换器类,例如用于GLUE任务。

@add_start_docstrings(
"""
Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
output) e.g. for GLUE tasks.
""",
BERT_START_DOCSTRING, # 引用了BERT的起始文档字符串
)
class BertForSequenceClassification(BertPreTrainedModel):
def init(self, config):
super().init(config)
self.num_labels = config.num_labels # 从配置中获取标签数量
self.config = config

    self.bert = BertModel(config)  # 初始化BERT模型
    classifier_dropout = (
        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
    )
    self.dropout = nn.Dropout(classifier_dropout)  # 初始化dropout层
    self.classifier = nn.Linear(config.hidden_size, config.num_labels)  # 初始化线性分类器层

    # 初始化权重并进行最终处理
    self.post_init()

# 使用装饰器添加模型前向传播的开始文档字符串,描述了输入参数的预期形状
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
# 使用装饰器添加代码示例的文档字符串,展示了模型的检查点、输出类型、配置类、预期输出和损失
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
    output_type=SequenceClassifierOutput,
    config_class=_CONFIG_FOR_DOC,
    expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
    expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
)
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
        config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
        `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
    """
    # 如果 return_dict 不为 None,则使用它;否则使用 self.config.use_return_dict
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # 将输入传递给 BERT 模型进行前向传播
    outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    # 获取 BERT 输出中的池化输出(通常是 CLS token 的输出)
    pooled_output = outputs[1]

    # 对池化输出应用 dropout
    pooled_output = self.dropout(pooled_output)
    
    # 将 dropout 后的输出传递给分类器,得到 logits
    logits = self.classifier(pooled_output)

    # 初始化损失值
    loss = None
    
    # 如果 labels 不为 None,则计算损失
    if labels is not None:
        # 根据配置决定问题类型,如果未指定,则根据 num_labels 判断是回归还是分类问题
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        # 根据问题类型选择相应的损失函数
        if self.config.problem_type == "regression":
            loss_fct = MSELoss()
            if self.num_labels == 1:
                loss = loss_fct(logits.squeeze(), labels.squeeze())
            else:
                loss = loss_fct(logits, labels)
        elif self.config.problem_type == "single_label_classification":
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        elif self.config.problem_type == "multi_label_classification":
            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(logits, labels)
    
    # 如果 return_dict 为 False,则返回 logits 和 BERT 模型的隐藏状态
    if not return_dict:
        output = (logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output

    # 如果 return_dict 为 True,则构造 SequenceClassifierOutput 对象返回
    return SequenceClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

"""
Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
"""

导入必要的库和模块

@add_start_docstrings(
"""
Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
BERT_START_DOCSTRING,
)

定义 BertForMultipleChoice 类,继承自 BertPreTrainedModel

class BertForMultipleChoice(BertPreTrainedModel):
# 初始化方法
def init(self, config):
# 调用父类的初始化方法
super().init(config)

    # 加载预训练的 BERT 模型
    self.bert = BertModel(config)
    # 设置分类器的 dropout 概率
    classifier_dropout = (
        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
    )
    # 定义 dropout 层
    self.dropout = nn.Dropout(classifier_dropout)
    # 定义分类器线性层
    self.classifier = nn.Linear(config.hidden_size, 1)

    # 初始化权重并应用最终处理
    self.post_init()

# 前向传播方法
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=MultipleChoiceModelOutput,
    config_class=_CONFIG_FOR_DOC,
)
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    # 执行前向传播,处理输入参数,返回模型输出

    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,



    # 执行前向传播,处理输入参数,返回模型输出
    ...
) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
        num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
        `input_ids` above)
    """
    # 如果 return_dict 为 None,则使用模型配置中的默认设置
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    # 获取输入张量 input_ids 的第二维度大小作为选择数
    num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

    # 将 input_ids 重新视图为二维张量,第一维为 -1,第二维与原始最后一维相同
    input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
    # 将 attention_mask 重新视图为二维张量,第一维为 -1,第二维与原始最后一维相同
    attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
    # 将 token_type_ids 重新视图为二维张量,第一维为 -1,第二维与原始最后一维相同
    token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
    # 将 position_ids 重新视图为二维张量,第一维为 -1,第二维与原始最后一维相同
    position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
    # 将 inputs_embeds 重新视图为三维张量,第一维为 -1,第二维和第三维与原始相同
    inputs_embeds = (
        inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
        if inputs_embeds is not None
        else None
    )

    # 将输入张量传递给 BERT 模型进行处理,并获取模型的输出
    outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    # 获取汇总输出(pooled_output),这通常是 BERT 模型的第二个输出
    pooled_output = outputs[1]

    # 对汇总输出应用 dropout
    pooled_output = self.dropout(pooled_output)
    # 使用分类器(通常是一个线性层)对汇总输出进行分类预测
    logits = self.classifier(pooled_output)
    # 重新调整 logits 的形状,使其匹配 num_choices 的维度
    reshaped_logits = logits.view(-1, num_choices)

    loss = None
    # 如果提供了标签,则计算交叉熵损失
    if labels is not None:
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(reshaped_logits, labels)

    # 如果不使用 return_dict,按照非字典格式返回输出
    if not return_dict:
        output = (reshaped_logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output

    # 如果使用 return_dict,按照字典格式返回 MultipleChoiceModelOutput
    return MultipleChoiceModelOutput(
        loss=loss,
        logits=reshaped_logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

给 BertForTokenClassification 类添加文档字符串,描述其作用和用途,特别是用于命名实体识别 (NER) 等任务的 token 分类模型

@add_start_docstrings(
"""
Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
BERT_START_DOCSTRING,
)
class BertForTokenClassification(BertPreTrainedModel):
def init(self, config):
super().init(config)
self.num_labels = config.num_labels

    # 初始化 Bert 模型,不添加池化层
    self.bert = BertModel(config, add_pooling_layer=False)
    # 根据配置设置分类器的 dropout
    classifier_dropout = (
        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
    )
    self.dropout = nn.Dropout(classifier_dropout)
    # 线性层,将隐藏状态输出映射到标签数量的空间
    self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    # 初始化权重并应用最终处理
    self.post_init()

# 给 forward 方法添加文档字符串,描述其输入和输出,使用了 BERT_INPUTS_DOCSTRING 中的说明
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
# 添加代码示例文档字符串,显示了如何从检查点加载模型并进行 token 分类
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
    output_type=TokenClassifierOutput,
    config_class=_CONFIG_FOR_DOC,
    expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
    expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
)
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    ):
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
    """
    # 如果 return_dict 不为 None,则使用其值;否则使用 self.config.use_return_dict 的设置
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # 使用 BERT 模型处理输入数据,并返回其输出结果
    outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    # 获取 BERT 模型的输出序列表示
    sequence_output = outputs[0]

    # 对输出序列进行 dropout 处理
    sequence_output = self.dropout(sequence_output)
    # 将 dropout 后的序列输出结果输入分类器,得到分类器的 logits
    logits = self.classifier(sequence_output)

    # 初始化损失为 None
    loss = None
    # 如果提供了标签(labels),则计算交叉熵损失
    if labels is not None:
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

    # 如果 return_dict 为 False,则返回额外的输出信息
    if not return_dict:
        output = (logits,) + outputs[2:]  # 保留分类器 logits 和额外的输出信息
        return ((loss,) + output) if loss is not None else output

    # 如果 return_dict 为 True,则使用 TokenClassifierOutput 类封装输出结果
    return TokenClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,  # 返回所有隐藏状态
        attentions=outputs.attentions,        # 返回所有注意力权重
    )

定义一个 Bert 模型,用于提取式问答任务(如 SQuAD),在隐藏状态的输出上方添加一个线性层,用于计算“起始位置对数”和“结束位置对数”。

@add_start_docstrings(
"""
Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layers on top of the hidden-states output to compute span start logits and span end logits).
""",
BERT_START_DOCSTRING,
)
class BertForQuestionAnswering(BertPreTrainedModel):
def init(self, config):
super().init(config)
# 设置模型的标签数目
self.num_labels = config.num_labels

    # 初始化 Bert 模型,不包含池化层
    self.bert = BertModel(config, add_pooling_layer=False)
    # 线性层,用于输出问题答案的起始位置和结束位置的对数概率
    self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

    # 初始化权重并应用最终处理
    self.post_init()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_QA,
    output_type=QuestionAnsweringModelOutput,
    config_class=_CONFIG_FOR_DOC,
    qa_target_start_index=_QA_TARGET_START_INDEX,
    qa_target_end_index=_QA_TARGET_END_INDEX,
    expected_output=_QA_EXPECTED_OUTPUT,
    expected_loss=_QA_EXPECTED_LOSS,
)
# 前向传播方法,接受多种输入参数,计算模型的输出
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    start_positions: Optional[torch.Tensor] = None,
    end_positions: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
    r"""
    start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for position (index) of the start of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
        are not taken into account for computing the loss.
    end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for position (index) of the end of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
        are not taken into account for computing the loss.
    """
    # 初始化是否返回字典形式的输出,默认为模型配置中的设定
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # 使用 BERT 模型处理输入数据,获取模型的输出
    outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    # 从 BERT 输出中获取序列输出
    sequence_output = outputs[0]

    # 将序列输出传入问答模型的输出层,得到起始位置和结束位置的 logits
    logits = self.qa_outputs(sequence_output)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1).contiguous()
    end_logits = end_logits.squeeze(-1).contiguous()

    total_loss = None
    if start_positions is not None and end_positions is not None:
        # 如果在多 GPU 上运行,需要添加维度以适应多 GPU 并行计算
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        # 将超出模型输入长度的位置设置为模型最大输入长度,防止超出范围
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        # 定义交叉熵损失函数,忽略指定的索引
        loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2

    if not return_dict:
        # 如果不返回字典形式的输出,返回起始位置 logits 和结束位置 logits
        output = (start_logits, end_logits) + outputs[2:]
        return ((total_loss,) + output) if total_loss is not None else output

    # 返回字典形式的输出,包括损失值、起始位置 logits、结束位置 logits、隐藏状态和注意力权重
    return QuestionAnsweringModelOutput(
        loss=total_loss,
        start_logits=start_logits,
        end_logits=end_logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
posted @ 2024-06-30 15:34  绝不原创的飞龙  阅读(6)  评论(0编辑  收藏  举报