Transformers-Source-Code-Analysis-117

Transformers Source Code Analysis (117)

.\models\vision_text_dual_encoder\__init__.py

# Import required modules and helpers, including the optional-dependency exception and the lazy-module utility
from typing import TYPE_CHECKING
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_tf_available,
    is_torch_available,
)

# Define the import structure used for lazy module loading
_import_structure = {
    "configuration_vision_text_dual_encoder": ["VisionTextDualEncoderConfig"],
    "processing_vision_text_dual_encoder": ["VisionTextDualEncoderProcessor"],
}

# Check whether torch is available; raise OptionalDependencyNotAvailable if not
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If available, add VisionTextDualEncoderModel to _import_structure
    _import_structure["modeling_vision_text_dual_encoder"] = ["VisionTextDualEncoderModel"]

# Check whether flax is available; raise OptionalDependencyNotAvailable if not
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If available, add FlaxVisionTextDualEncoderModel to _import_structure
    _import_structure["modeling_flax_vision_text_dual_encoder"] = ["FlaxVisionTextDualEncoderModel"]

# Check whether TensorFlow is available; raise OptionalDependencyNotAvailable if not
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If available, add TFVisionTextDualEncoderModel to _import_structure
    _import_structure["modeling_tf_vision_text_dual_encoder"] = ["TFVisionTextDualEncoderModel"]

# When type checking, import the concrete classes from their modules
if TYPE_CHECKING:
    from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig
    from .processing_vision_text_dual_encoder import VisionTextDualEncoderProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # If torch is available, import VisualTextDualEncoderModel's PyTorch class from modeling_vision_text_dual_encoder
        from .modeling_vision_text_dual_encoder import VisionTextDualEncoderModel

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # If flax is available, import FlaxVisionTextDualEncoderModel from modeling_flax_vision_text_dual_encoder
        from .modeling_flax_vision_text_dual_encoder import FlaxVisionTextDualEncoderModel

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # If TensorFlow is available, import TFVisionTextDualEncoderModel from modeling_tf_vision_text_dual_encoder
        from .modeling_tf_vision_text_dual_encoder import TFVisionTextDualEncoderModel

# At runtime (not type checking), register a _LazyModule so imports are resolved lazily
else:
    import sys

    # Replace the current module with a _LazyModule that defers the actual imports
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
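
As a quick illustration of what the lazy module buys the end user, the sketch below (assuming an environment with transformers and torch installed) shows that accessing the attribute is what actually triggers the import of the modeling file:

```python
import transformers

# Nothing from modeling_vision_text_dual_encoder has been imported yet;
# attribute access on the lazy module resolves and caches the real class.
model_cls = transformers.VisionTextDualEncoderModel
print(model_cls.__module__)
# transformers.models.vision_text_dual_encoder.modeling_vision_text_dual_encoder
```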

.\models\visual_bert\configuration_visual_bert.py

# The file encoding is declared as UTF-8

# Copyright notice: the code is copyright the HuggingFace Inc. team and licensed under the Apache License, Version 2.0
# The file may not be used except in compliance with the License

# Import the configuration base class PretrainedConfig and the logging utility
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# Get a logger for this module
logger = logging.get_logger(__name__)

# Map from VisualBERT pretrained model identifiers to the download URLs of their config files
VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "uclanlp/visualbert-vqa": "https://huggingface.co/uclanlp/visualbert-vqa/resolve/main/config.json",
    "uclanlp/visualbert-vqa-pre": "https://huggingface.co/uclanlp/visualbert-vqa-pre/resolve/main/config.json",
    "uclanlp/visualbert-vqa-coco-pre": (
        "https://huggingface.co/uclanlp/visualbert-vqa-coco-pre/resolve/main/config.json"
    ),
    "uclanlp/visualbert-vcr": "https://huggingface.co/uclanlp/visualbert-vcr/resolve/main/config.json",
    "uclanlp/visualbert-vcr-pre": "https://huggingface.co/uclanlp/visualbert-vcr-pre/resolve/main/config.json",
    "uclanlp/visualbert-vcr-coco-pre": (
        "https://huggingface.co/uclanlp/visualbert-vcr-coco-pre/resolve/main/config.json"
    ),
    "uclanlp/visualbert-nlvr2": "https://huggingface.co/uclanlp/visualbert-nlvr2/resolve/main/config.json",
    "uclanlp/visualbert-nlvr2-pre": "https://huggingface.co/uclanlp/visualbert-nlvr2-pre/resolve/main/config.json",
    "uclanlp/visualbert-nlvr2-coco-pre": (
        "https://huggingface.co/uclanlp/visualbert-nlvr2-coco-pre/resolve/main/config.json"
    ),
    # See all VisualBERT models at https://huggingface.co/models?filter=visual_bert
}

# VisualBertConfig, a subclass of PretrainedConfig
class VisualBertConfig(PretrainedConfig):
    r"""
    这是 VisualBERT 模型的配置类,用于存储 [`VisualBertModel`] 的配置信息。根据指定的参数实例化一个 VisualBERT 模型,
    定义模型架构。使用默认配置实例化一个配置对象将产生类似于 VisualBERT [uclanlp/visualbert-vqa-coco-pre]
    (https://huggingface.co/uclanlp/visualbert-vqa-coco-pre) 架构的配置。

    配置对象继承自 [`PretrainedConfig`],可用于控制模型的输出。阅读 [`PretrainedConfig`] 的文档以获取更多信息。

    Example:

    ```
    >>> from transformers import VisualBertConfig, VisualBertModel

    >>> # 初始化一个 VisualBERT visualbert-vqa-coco-pre 风格的配置
    >>> configuration = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
    ```

    """
    # 初始化一个 VisualBertModel 模型,使用给定的配置参数
    model = VisualBertModel(configuration)

    # 访问模型的配置信息
    configuration = model.config
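
Besides loading a pretrained configuration, the class can be instantiated directly. A minimal sketch (the value 512 is illustrative; `visual_embedding_dim` is the same argument the conversion script below overrides):

```python
from transformers import VisualBertConfig

# Configuration with a custom visual feature dimension (e.g. 512-d detector features)
config = VisualBertConfig(visual_embedding_dim=512)
print(config.hidden_size, config.visual_embedding_dim)  # 768 512
```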

.\models\visual_bert\convert_visual_bert_original_pytorch_checkpoint_to_pytorch.py

# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert VisualBert checkpoint."""
# Module docstring: this script converts original VisualBert checkpoints to the transformers format.

import argparse
from collections import OrderedDict
from pathlib import Path

import torch

from transformers import (
    VisualBertConfig,
    VisualBertForMultipleChoice,
    VisualBertForPreTraining,
    VisualBertForQuestionAnswering,
    VisualBertForVisualReasoning,
)
# Import the required model classes and configuration

from transformers.utils import logging

# Set the logging verbosity to info
logging.set_verbosity_info()
# Get a logger for this module
logger = logging.get_logger(__name__)

# Prefix pairs used to rename state-dict keys
rename_keys_prefix = [
    ("bert.bert", "visual_bert"),
    ("bert.cls", "cls"),
    ("bert.classifier", "cls"),
    ("token_type_embeddings_visual", "visual_token_type_embeddings"),
    ("position_embeddings_visual", "visual_position_embeddings"),
    ("projection", "visual_projection"),
]

# List of accepted checkpoint file names
ACCEPTABLE_CHECKPOINTS = [
    "nlvr2_coco_pre_trained.th",
    "nlvr2_fine_tuned.th",
    "nlvr2_pre_trained.th",
    "vcr_coco_pre_train.th",
    "vcr_fine_tune.th",
    "vcr_pre_train.th",
    "vqa_coco_pre_trained.th",
    "vqa_fine_tuned.th",
    "vqa_pre_trained.th",
]

# Load a model state dict from a checkpoint file
def load_state_dict(checkpoint_path):
    # Load the checkpoint's state dict on CPU
    sd = torch.load(checkpoint_path, map_location="cpu")
    return sd

# Build a new state dict adapted to the VisualBert model from the original dict and the config
def get_new_dict(d, config, rename_keys_prefix=rename_keys_prefix):
    new_d = OrderedDict()
    # Add the 'visual_bert.embeddings.position_ids' key: a tensor holding the position IDs
    new_d["visual_bert.embeddings.position_ids"] = torch.arange(config.max_position_embeddings).expand((1, -1))
    # Iterate over the keys of the original dict
    for key in d:
        if "detector" in key:
            # Skip keys belonging to the detector
            continue
        new_key = key
        # Rename the key using the predefined prefix pairs
        for name_pair in rename_keys_prefix:
            new_key = new_key.replace(name_pair[0], name_pair[1])
        # Store the renamed key with its original value
        new_d[new_key] = d[key]
        # Special case: for 'bert.cls.predictions.decoder.weight', also add 'cls.predictions.decoder.bias'
        if key == "bert.cls.predictions.decoder.weight":
            new_d["cls.predictions.decoder.bias"] = new_d["cls.predictions.bias"]
    return new_d
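
To make the renaming loop concrete, here is a minimal sketch that applies the same prefix replacements to a single made-up checkpoint key (the key name is hypothetical):

```python
# Subset of the prefix pairs defined above
rename_pairs = [("bert.bert", "visual_bert"), ("bert.cls", "cls")]

key = "bert.bert.encoder.layer.0.attention.self.query.weight"
for old, new in rename_pairs:
    key = key.replace(old, new)
print(key)  # visual_bert.encoder.layer.0.attention.self.query.weight
```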

# Run the checkpoint conversion without tracking gradients
@torch.no_grad()
def convert_visual_bert_checkpoint(checkpoint_path, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our VisualBERT structure.
    """
    # Make sure the provided checkpoint file name is one of the accepted checkpoints
    assert (
        checkpoint_path.split("/")[-1] in ACCEPTABLE_CHECKPOINTS
    ), f"The checkpoint provided must be in {ACCEPTABLE_CHECKPOINTS}."

    # Determine the configuration
    # If the checkpoint path contains the string "pre", the model type is pretraining
    if "pre" in checkpoint_path:
        model_type = "pretraining"
        # Choose the config parameters based on substrings of the checkpoint path
        if "vcr" in checkpoint_path:
            config_params = {"visual_embedding_dim": 512}
        elif "vqa_advanced" in checkpoint_path:
            config_params = {"visual_embedding_dim": 2048}
        elif "vqa" in checkpoint_path:
            config_params = {"visual_embedding_dim": 2048}
        elif "nlvr" in checkpoint_path:
            config_params = {"visual_embedding_dim": 1024}
        else:
            # No suitable implementation found: raise NotImplementedError
            raise NotImplementedError(f"No implementation found for `{checkpoint_path}`.")
    else:
        # Otherwise, determine the model type and config parameters from other substrings in the path
        if "vcr" in checkpoint_path:
            config_params = {"visual_embedding_dim": 512}
            model_type = "multichoice"
        elif "vqa_advanced" in checkpoint_path:
            config_params = {"visual_embedding_dim": 2048}
            model_type = "vqa_advanced"
        elif "vqa" in checkpoint_path:
            config_params = {"visual_embedding_dim": 2048, "num_labels": 3129}
            model_type = "vqa"
        elif "nlvr" in checkpoint_path:
            config_params = {
                "visual_embedding_dim": 1024,
                "num_labels": 2,
            }
            model_type = "nlvr"

    # Build the VisualBertConfig from the chosen parameters
    config = VisualBertConfig(**config_params)

    # Load the original state dict
    state_dict = load_state_dict(checkpoint_path)

    # Convert it into the new VisualBert state dict
    new_state_dict = get_new_dict(state_dict, config)

    # Instantiate the model class matching the model type
    if model_type == "pretraining":
        model = VisualBertForPreTraining(config)
    elif model_type == "vqa":
        model = VisualBertForQuestionAnswering(config)
    elif model_type == "nlvr":
        model = VisualBertForVisualReasoning(config)
    elif model_type == "multichoice":
        model = VisualBertForMultipleChoice(config)

    # Load the new state dict into the model
    model.load_state_dict(new_state_dict)

    # Make sure the PyTorch dump folder exists, creating it if necessary
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)

    # Save the model in the PyTorch pretrained format
    model.save_pretrained(pytorch_dump_folder_path)

if __name__ == "__main__":
    # Runs only when the script is executed directly (not imported)
    parser = argparse.ArgumentParser()

    # Required positional arguments
    parser.add_argument("orig_checkpoint_path", type=str, help="A path to .th on local filesystem.")
    parser.add_argument("pytorch_dump_folder_path", type=str, help="Path to the output PyTorch model.")

    # Parse the command-line arguments
    args = parser.parse_args()

    # Run the conversion with the provided checkpoint path and output folder
    convert_visual_bert_checkpoint(args.orig_checkpoint_path, args.pytorch_dump_folder_path)
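
For reference, the converter can be driven from the command line or called directly from Python; a sketch with illustrative paths (the checkpoint file name must be one of ACCEPTABLE_CHECKPOINTS):

```python
# Equivalent to:
#   python convert_visual_bert_original_pytorch_checkpoint_to_pytorch.py vqa_coco_pre_trained.th ./visualbert-vqa-coco-pre
convert_visual_bert_checkpoint(
    checkpoint_path="vqa_coco_pre_trained.th",
    pytorch_dump_folder_path="./visualbert-vqa-coco-pre",
)
```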

.\models\visual_bert\modeling_visual_bert.py

# coding=utf-8
# Copyright 2021 The UCLA NLP Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch VisualBERT model."""

# Import required libraries and modules
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax

# Import internal modules and helpers
from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MultipleChoiceModelOutput,
    SequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_visual_bert import VisualBertConfig

# Get the module logger
logger = logging.get_logger(__name__)

# Config and checkpoint names used in the docstrings
_CONFIG_FOR_DOC = "VisualBertConfig"
_CHECKPOINT_FOR_DOC = "uclanlp/visualbert-vqa-coco-pre"

# List of pretrained model archives
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "uclanlp/visualbert-vqa",
    "uclanlp/visualbert-vqa-pre",
    "uclanlp/visualbert-vqa-coco-pre",
    "uclanlp/visualbert-vcr",
    "uclanlp/visualbert-vcr-pre",
    "uclanlp/visualbert-vcr-coco-pre",
    "uclanlp/visualbert-nlvr2",
    "uclanlp/visualbert-nlvr2-pre",
    "uclanlp/visualbert-nlvr2-coco-pre",
    # See all VisualBERT models at https://huggingface.co/models?filter=visual_bert
]

# VisualBertEmbeddings builds the embeddings from word, position, token-type and visual embeddings
class VisualBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings and visual embeddings."""
    # Constructor, takes a config object
    def __init__(self, config):
        # Initialize the parent nn.Module
        super().__init__()

        # Word embedding layer mapping token indices to hidden vectors (vocab size, hidden size, padding index)
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Position embedding layer mapping position indices to hidden vectors
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # Token-type embedding layer mapping segment indices to hidden vectors
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # LayerNorm for normalizing the hidden vectors
        # Note: the name is not snake_case so that TensorFlow checkpoints with matching variable names can be loaded
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Dropout layer to randomly drop units during training and reduce overfitting
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Register the position index buffer used to feed position information to the model
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

        # Visual-feature handling
        # Token-type embedding layer for visual tokens
        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # Position embedding layer for visual tokens
        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # Optionally initialize the visual embeddings from the text embeddings
        if config.special_visual_initialize:
            # Initialize the visual token-type embeddings from the textual token-type embeddings
            self.visual_token_type_embeddings.weight.data = nn.Parameter(
                self.token_type_embeddings.weight.data.clone(), requires_grad=True
            )
            # Initialize the visual position embeddings from the textual position embeddings
            self.visual_position_embeddings.weight.data = nn.Parameter(
                self.position_embeddings.weight.data.clone(), requires_grad=True
            )

        # Projection layer mapping visual features of size visual_embedding_dim to hidden_size
        self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)
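
A small standalone sketch of what the visual projection does, using illustrative dimensions (36 region features of size 2048, as an external detector might produce):

```python
import torch
from torch import nn

visual_embedding_dim, hidden_size = 2048, 768
visual_projection = nn.Linear(visual_embedding_dim, hidden_size)

visual_embeds = torch.randn(1, 36, visual_embedding_dim)  # (batch, regions, feature dim)
projected = visual_projection(visual_embeds)
print(projected.shape)  # torch.Size([1, 36, 768])
```
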
class VisualBertSelfAttention(nn.Module):
    # Multi-head self-attention module for VisualBert
    def __init__(self, config):
        # Constructor, takes a config object
        super().__init__()
        # The hidden size must be divisible by the number of attention heads, unless the config defines `embedding_size`
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        # Number of attention heads and size of each head
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Linear projections for queries, keys and values
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # Dropout applied to the attention probabilities
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # Reshape (batch, seq, hidden) into (batch, heads, seq, head_size) for the attention score computation
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        # Project the hidden states into the query space
        mixed_query_layer = self.query(hidden_states)

        # Key and value projections, reshaped into per-head tensors via transpose_for_scores
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Reshape the query projection into per-head tensors
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Raw attention scores: dot product between "query" and "key"
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # Scale the scores by sqrt(head size)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Apply the (additive) attention mask if provided
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize the scores into probabilities
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # Apply dropout to the attention probabilities
        attention_probs = self.dropout(attention_probs)

        # Apply the head mask if provided
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # Context layer: weighted sum of the value vectors
        context_layer = torch.matmul(attention_probs, value_layer)

        # Reshape back to (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # Return the context layer, plus the attention probabilities when requested
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
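
The reshaping done by `transpose_for_scores` is easiest to see with concrete shapes; a standalone sketch with illustrative sizes:

```python
import torch

batch, seq_len, num_heads, head_size = 2, 5, 12, 64
hidden = torch.randn(batch, seq_len, num_heads * head_size)

# (batch, seq, hidden) -> (batch, heads, seq, head_size), as in transpose_for_scores
per_head = hidden.view(batch, seq_len, num_heads, head_size).permute(0, 2, 1, 3)
print(per_head.shape)  # torch.Size([2, 12, 5, 64])

# The attention score matrix is then (batch, heads, seq, seq)
scores = per_head @ per_head.transpose(-1, -2) / head_size**0.5
print(scores.shape)  # torch.Size([2, 12, 5, 5])
```
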
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->VisualBert
class VisualBertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Linear layer mapping hidden_size to hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # LayerNorm applied to the hidden states
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Dropout layer to reduce overfitting
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # Linear transformation
        hidden_states = self.dense(hidden_states)
        # Apply dropout
        hidden_states = self.dropout(hidden_states)
        # Add the residual connection and apply LayerNorm
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class VisualBertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Self-attention block
        self.self = VisualBertSelfAttention(config)
        # Output block (projection, dropout, residual connection, LayerNorm)
        self.output = VisualBertSelfOutput(config)
        # Set of attention heads that have already been pruned
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        # Find the prunable heads and the indices of the dimensions to keep
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune the linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update the hyper-parameters and record the pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        # Run the self-attention block on the hidden states
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions,
        )
        # Feed the attention output through the output block
        attention_output = self.output(self_outputs[0], hidden_states)
        # Append the attention weights if they were requested
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->VisualBert
class VisualBertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Linear layer mapping hidden_size to intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Select the activation function from the config
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Linear transformation
        hidden_states = self.dense(hidden_states)
        # Non-linear activation
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->VisualBert
# VisualBertOutput projects the intermediate output back to hidden_size, with dropout, residual connection and LayerNorm
class VisualBertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Linear layer mapping intermediate_size back to hidden_size
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # LayerNorm applied to the hidden states
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Dropout layer to reduce overfitting
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # Transform the intermediate output
        hidden_states = self.dense(hidden_states)
        # Apply dropout
        hidden_states = self.dropout(hidden_states)
        # Add the residual connection and apply LayerNorm
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


# VisualBertLayer: one transformer block (self-attention followed by a feed-forward block)
class VisualBertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Chunk size used for the feed-forward pass
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Index of the sequence-length dimension
        self.seq_len_dim = 1
        # Self-attention block
        self.attention = VisualBertAttention(config)
        # Intermediate (feed-forward) block
        self.intermediate = VisualBertIntermediate(config)
        # Output block
        self.output = VisualBertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        # Run the self-attention block
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        # Attention output
        attention_output = self_attention_outputs[0]

        # Keep the attention weights if they were requested
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # Apply the feed-forward block in chunks along the sequence dimension
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        # Prepend the layer output to the outputs tuple
        outputs = (layer_output,) + outputs

        return outputs

    # Feed-forward chunk: intermediate projection followed by the output block
    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
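
`apply_chunking_to_forward` slices the feed-forward computation along the sequence dimension to save memory; the result is identical to running it in one shot. A sketch, assuming a recent transformers version where the helper lives in `transformers.pytorch_utils`:

```python
import torch
from transformers.pytorch_utils import apply_chunking_to_forward

hidden = torch.randn(2, 6, 8)  # (batch, seq_len, hidden)
ff = torch.nn.Linear(8, 8)

def feed_forward_chunk(x):
    return ff(x)

# chunk_size=0 processes the whole sequence at once; chunk_size=2 splits the
# sequence dimension (dim=1) into slices of length 2 and concatenates the results
full = apply_chunking_to_forward(feed_forward_chunk, 0, 1, hidden)
chunked = apply_chunking_to_forward(feed_forward_chunk, 2, 1, hidden)
print(torch.allclose(full, chunked))  # True
```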


# VisualBertEncoder: a stack of VisualBertLayer blocks
class VisualBertEncoder(nn.Module):
    # Constructor, takes a config object
    def __init__(self, config):
        super().__init__()
        # Keep a reference to the config
        self.config = config
        # nn.ModuleList holding config.num_hidden_layers VisualBertLayer blocks
        self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)])
        # Gradient checkpointing is disabled by default
        self.gradient_checkpointing = False

    # Forward pass: run every layer, optionally collecting hidden states and attention weights
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        # Collect hidden states in a tuple only when output_hidden_states is requested, otherwise None
        all_hidden_states = () if output_hidden_states else None
        # Collect attention weights in a tuple only when output_attentions is requested, otherwise None
        all_self_attentions = () if output_attentions else None

        # Iterate over every layer of the model
        for i, layer_module in enumerate(self.layer):
            # If hidden states are requested, record the current hidden states
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Head mask for the current layer, if any
            layer_head_mask = head_mask[i] if head_mask is not None else None

            # With gradient checkpointing enabled during training, run the layer through the checkpointing helper
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                # Otherwise call the layer directly
                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

            # The new hidden states are the first element of the layer outputs
            hidden_states = layer_outputs[0]
            # If attention weights are requested, record this layer's attention weights
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # Record the final hidden states if hidden states are requested
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # When not returning a dict, pack the non-None values into a tuple
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )
        # Otherwise return a BaseModelOutput
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
        )
# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->VisualBert
class VisualBertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Linear layer with input and output dimension config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Tanh activation
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Pool the sequence by taking the hidden state of the first token
        first_token_tensor = hidden_states[:, 0]
        # Transform the first-token hidden state
        pooled_output = self.dense(first_token_tensor)
        # Apply the activation
        pooled_output = self.activation(pooled_output)
        return pooled_output


# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->VisualBert
class VisualBertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Linear layer with input and output dimension config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Select the activation function from the config
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        # LayerNorm applied to the transformed hidden states
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Linear transformation
        hidden_states = self.dense(hidden_states)
        # Apply the activation
        hidden_states = self.transform_act_fn(hidden_states)
        # Apply LayerNorm
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->VisualBert
class VisualBertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Transformation applied before the decoder
        self.transform = VisualBertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is an output-only bias for each token
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Output bias, one value per vocabulary token
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Link the two variables so the bias is resized correctly with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        # Transform the hidden states
        hidden_states = self.transform(hidden_states)
        # Project onto the vocabulary
        hidden_states = self.decoder(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->VisualBert
class VisualBertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Masked language modeling head
        self.predictions = VisualBertLMPredictionHead(config)
        # Linear layer for the sentence-image relationship classification
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        # Token-level prediction scores
        prediction_scores = self.predictions(sequence_output)
        # Sentence-image relationship score
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class VisualBertPreTrainedModel(PreTrainedModel):
    """
    VisualBert 的预训练模型基类
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = VisualBertConfig
    base_model_prefix = "visual_bert"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        # Linear and embedding layers: initialize weights from a normal distribution
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses truncated_normal for initialization
            # cf. https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

        # LayerNorm layers: zero the bias and set the weight to 1
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        # Linear layers with a bias: zero the bias
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    """
    Output type of `VisualBertForPreTraining`.

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the sentence-image prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Prediction scores of the sentence-image prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """


# Decorator adding the standard docstring to the bare VisualBert model class
@add_start_docstrings(
    "The bare VisualBert Model transformer outputting raw hidden-states without any specific head on top.",
    VISUAL_BERT_START_DOCSTRING,
)

class VisualBertModel(VisualBertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        # Initialize VisualBert model components
        self.embeddings = VisualBertEmbeddings(config)  # VisualBert embeddings module
        self.encoder = VisualBertEncoder(config)        # VisualBert encoder module

        # Optionally add a pooling layer based on add_pooling_layer flag
        self.pooler = VisualBertPooler(config) if add_pooling_layer else None

        # Determine if to bypass the transformer and add an additional layer
        self.bypass_transformer = config.bypass_transformer
        if self.bypass_transformer:
            self.additional_layer = VisualBertLayer(config)  # Additional layer for bypassing transformer

        # Initialize weights and perform final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # Iterate over layers and prune specified attention heads
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        visual_embeds: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.LongTensor] = None,
        visual_token_type_ids: Optional[torch.LongTensor] = None,
        image_text_alignment: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass of the VisualBert model.

        Args:
            input_ids (torch.LongTensor, optional): Input token IDs. Defaults to None.
            attention_mask (torch.LongTensor, optional): Mask to avoid performing attention on padding token indices.
                Defaults to None.
            token_type_ids (torch.LongTensor, optional): Segment token indices to differentiate image and text tokens.
                Defaults to None.
            position_ids (torch.LongTensor, optional): Indices of each input token in its position. Defaults to None.
            head_mask (torch.LongTensor, optional): Mask to nullify selected heads of the self-attention modules.
                Defaults to None.
            inputs_embeds (torch.FloatTensor, optional): Embedded input tokens. Defaults to None.
            visual_embeds (torch.FloatTensor, optional): Embedded visual features. Defaults to None.
            visual_attention_mask (torch.LongTensor, optional): Mask for visual features to avoid attending on padding.
                Defaults to None.
            visual_token_type_ids (torch.LongTensor, optional): Segment token indices for visual inputs. Defaults to None.
            image_text_alignment (torch.LongTensor, optional): Alignment between image and text tokens. Defaults to None.
            output_attentions (bool, optional): Whether to output attentions weights. Defaults to None.
            output_hidden_states (bool, optional): Whether to output hidden states. Defaults to None.
            return_dict (bool, optional): Whether to return a dictionary. Defaults to None.

        Returns:
            BaseModelOutputWithPooling: Output of the VisualBert model with optional pooling.
        """
        # Implementation of forward pass for VisualBert model
        pass
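
A minimal usage sketch, following the pattern from the VisualBert documentation; the random `visual_embeds` stand in for region features that would normally come from an external object detector (2048-dimensional features match the vqa-coco-pre checkpoint):

```python
import torch
from transformers import BertTokenizer, VisualBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre")

inputs = tokenizer("What is the man eating?", return_tensors="pt")
visual_embeds = torch.randn(1, 36, 2048)
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

outputs = model(
    **inputs,
    visual_embeds=visual_embeds,
    visual_token_type_ids=visual_token_type_ids,
    visual_attention_mask=visual_attention_mask,
)
print(outputs.last_hidden_state.shape)  # (1, text_len + 36, 768)
```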

@add_start_docstrings(
    """
    VisualBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence-image prediction (classification)` head.
    """,
    VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForPreTraining(VisualBertPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
    # Constructor, takes a config object
    def __init__(self, config):
        # Initialize the parent class with the config
        super().__init__(config)

        # The base VisualBertModel
        self.visual_bert = VisualBertModel(config)
        # The pretraining heads (masked LM + sentence-image prediction)
        self.cls = VisualBertPreTrainingHeads(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Return the output embedding decoder used for the predictions
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    # Replace the output embedding decoder
    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    # Forward pass, taking the text inputs, the visual inputs and the pretraining labels
    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        visual_embeds: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.LongTensor] = None,
        visual_token_type_ids: Optional[torch.LongTensor] = None,
        image_text_alignment: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        sentence_image_labels: Optional[torch.LongTensor] = None,
@add_start_docstrings(
    """
    VisualBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
    a softmax) e.g. for VCR tasks.
    """,
    VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForMultipleChoice(VisualBertPreTrainedModel):
    # Constructor, takes a config object
    def __init__(self, config):
        # Initialize the parent class
        super().__init__(config)

        # The base VisualBertModel
        self.visual_bert = VisualBertModel(config)
        # Dropout layer using the configured hidden dropout probability
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Linear layer mapping the pooled output to a single logit per choice
        self.cls = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(
        VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
    # Forward pass, taking the text and visual inputs plus optional labels
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        visual_embeds: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.LongTensor] = None,
        visual_token_type_ids: Optional[torch.LongTensor] = None,
        image_text_alignment: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,



@add_start_docstrings(
    """
    VisualBert Model with a classification/regression head on top (a dropout and a linear layer on top of the pooled
    output) for VQA.
    """,
    VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForQuestionAnswering(VisualBertPreTrainedModel):
    # Constructor, takes a config object
    def __init__(self, config):
        # Initialize the parent class
        super().__init__(config)
        # Number of labels taken from the config
        self.num_labels = config.num_labels

        # The base VisualBertModel
        self.visual_bert = VisualBertModel(config)
        # Dropout layer using the configured hidden dropout probability
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Linear layer mapping the pooled output to num_labels logits
        self.cls = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    # Forward pass; all of the inputs are optional tensors
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        # Token IDs representing the input sequence
        attention_mask: Optional[torch.LongTensor] = None,
        # Mask indicating which positions should be ignored in the attention computation
        token_type_ids: Optional[torch.LongTensor] = None,
        # Type IDs distinguishing different sentences or segments
        position_ids: Optional[torch.LongTensor] = None,
        # Position IDs giving the position of each token in the sequence
        head_mask: Optional[torch.LongTensor] = None,
        # Head mask specifying which attention heads to mask out
        inputs_embeds: Optional[torch.FloatTensor] = None,
        # Precomputed input embeddings used instead of input_ids
        visual_embeds: Optional[torch.FloatTensor] = None,
        # Embeddings of the visual inputs, e.g. image region features
        visual_attention_mask: Optional[torch.LongTensor] = None,
        # Attention mask for the visual inputs
        visual_token_type_ids: Optional[torch.LongTensor] = None,
        # Type IDs for the visual inputs
        image_text_alignment: Optional[torch.LongTensor] = None,
        # Alignment information between image regions and text tokens
        output_attentions: Optional[bool] = None,
        # Whether to return the attention weights
        output_hidden_states: Optional[bool] = None,
        # Whether to return the hidden states
        return_dict: Optional[bool] = None,
        # Whether to return a dict-style output
        labels: Optional[torch.LongTensor] = None,
        # Labels used for supervised training
@add_start_docstrings(
    """
    VisualBert Model with a sequence classification head on top (a dropout and a linear layer on top of the pooled
    output) for Visual Reasoning e.g. for NLVR task.
    """,
    VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForVisualReasoning(VisualBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Initialize the VisualBertModel with the provided configuration
        self.visual_bert = VisualBertModel(config)
        
        # Dropout layer with dropout probability as specified in the configuration
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        
        # Linear layer for classification, mapping hidden_size to num_labels
        self.cls = nn.Linear(config.hidden_size, config.num_labels)  # 2

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        visual_embeds: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.LongTensor] = None,
        visual_token_type_ids: Optional[torch.LongTensor] = None,
        image_text_alignment: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):


class VisualBertRegionToPhraseAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        # Number of attention heads is set to 1 for this module
        self.num_attention_heads = 1  # config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Linear transformations for query, key, and value vectors
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # Dropout layer for attention probabilities
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # Reshape and permute dimensions for multi-head attention
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    # Forward pass: score each region (key) against each phrase position (query)
    def forward(self, query, key, attention_mask):
        # Cast the attention mask to the query dtype
        attention_mask = attention_mask.to(query.dtype)
        # Add broadcast dimensions so the mask lines up with the attention scores
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        # Turn the mask into large negative values so masked positions vanish in the softmax
        attention_mask = (1.0 - attention_mask) * torch.finfo(query.dtype).min

        # Project the query tensor
        mixed_query_layer = self.query(query)
        # Project the key tensor
        mixed_key_layer = self.key(key)

        # Reshape the projections into per-head tensors for the score computation
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)

        # Attention scores: dot product between queries and keys
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # Scale the scores to stabilize gradients
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Add the additive mask so invalid positions are suppressed
        attention_scores = attention_scores + attention_mask

        # Drop the singleton head dimension so the scores have the expected shape
        attention_scores = attention_scores.squeeze(1)
        # Return the attention scores
        return attention_scores
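
The `(1.0 - mask) * torch.finfo(dtype).min` trick converts a 0/1 padding mask into an additive mask that removes padded positions after the softmax; a tiny standalone sketch:

```python
import torch

mask = torch.tensor([[1.0, 1.0, 0.0]])  # 1 = real position, 0 = padding
additive = (1.0 - mask) * torch.finfo(torch.float32).min

scores = torch.zeros(1, 3) + additive
print(torch.softmax(scores, dim=-1))  # tensor([[0.5000, 0.5000, 0.0000]])
```
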
# Decorator adding a docstring describing this model's purpose: region-to-phrase alignment, e.g. for the Flickr30 Entities task
@add_start_docstrings(
    """
    VisualBert Model with a Masked Language Modeling head and an attention layer on top for Region-to-Phrase Alignment
    e.g. for Flickr30 Entities task.
    """,
    VISUAL_BERT_START_DOCSTRING,  # reuse the shared VisualBert docstring template
)
class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
    # Keys whose weights are tied together
    _tied_weights_keys = ["cls.predictions.decoder.bias"]

    # Constructor, takes a config object
    def __init__(self, config):
        # Initialize the parent class
        super().__init__(config)

        # The base VisualBertModel
        self.visual_bert = VisualBertModel(config)
        # Dropout layer using the configured hidden dropout probability
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Pretraining heads used for the MLM predictions
        self.cls = VisualBertPreTrainingHeads(config)
        # Attention module aligning regions to phrases
        self.attention = VisualBertRegionToPhraseAttention(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Forward pass
    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        visual_embeds: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.LongTensor] = None,
        visual_token_type_ids: Optional[torch.LongTensor] = None,
        image_text_alignment: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        region_to_phrase_position: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,

.\models\visual_bert\__init__.py

# Copyright and license notice
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import the type-checking helper
from typing import TYPE_CHECKING

# Import the optional-dependency exception and the lazy-module helper
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# Define the import structure: modules and the names they export
_import_structure = {"configuration_visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"]}

# Check whether torch is available; raise OptionalDependencyNotAvailable if not
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If torch is available, register the modeling module exports
    _import_structure["modeling_visual_bert"] = [
        "VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "VisualBertForMultipleChoice",
        "VisualBertForPreTraining",
        "VisualBertForQuestionAnswering",
        "VisualBertForRegionToPhraseAlignment",
        "VisualBertForVisualReasoning",
        "VisualBertLayer",
        "VisualBertModel",
        "VisualBertPreTrainedModel",
    ]

# When type checking
if TYPE_CHECKING:
    # Import the configuration classes and constants
    from .configuration_visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig

    # Check whether torch is available; skip the model imports if not
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Import the modeling classes and constants
        from .modeling_visual_bert import (
            VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            VisualBertForMultipleChoice,
            VisualBertForPreTraining,
            VisualBertForQuestionAnswering,
            VisualBertForRegionToPhraseAlignment,
            VisualBertForVisualReasoning,
            VisualBertLayer,
            VisualBertModel,
            VisualBertPreTrainedModel,
        )

# At runtime, replace this module with a _LazyModule so the imports are resolved lazily
else:
    import sys

    # Use _LazyModule to defer loading of the submodules
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\vit\configuration_vit.py

# coding=utf-8
# Copyright 2021 Google AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ViT model configuration
"""

from collections import OrderedDict  # ordered dictionary
from typing import Mapping  # typing helper for the ONNX input mapping

from packaging import version  # version parsing utility

from ...configuration_utils import PretrainedConfig  # base configuration class
from ...onnx import OnnxConfig  # base ONNX export configuration
from ...utils import logging  # logging utility

logger = logging.get_logger(__name__)  # logger for this module

VIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/vit-base-patch16-224": "https://huggingface.co/vit-base-patch16-224/resolve/main/config.json",
    # See all ViT models at https://huggingface.co/models?filter=vit
}


class ViTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ViTModel`]. It is used to instantiate an ViT
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the ViT
    [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """
        # 定义 ViT 模型的配置类
        Args:
            hidden_size (`int`, *optional*, defaults to 768):
                编码器层和池化层的维度。
            num_hidden_layers (`int`, *optional*, defaults to 12):
                Transformer 编码器中的隐藏层数量。
            num_attention_heads (`int`, *optional*, defaults to 12):
                Transformer 编码器中每个注意力层的注意力头数量。
            intermediate_size (`int`, *optional*, defaults to 3072):
                Transformer 编码器中“中间”(即前馈)层的维度。
            hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
                编码器和池化器中的非线性激活函数(函数或字符串)。如果是字符串,支持 "gelu"、"relu"、"selu" 和 "gelu_new"。
            hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
                嵌入、编码器和池化层中所有全连接层的 dropout 概率。
            attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
                注意力概率的 dropout 比例。
            initializer_range (`float`, *optional*, defaults to 0.02):
                用于初始化所有权重矩阵的截断正态初始化器的标准差。
            layer_norm_eps (`float`, *optional*, defaults to 1e-12):
                层规范化层使用的 epsilon。
            image_size (`int`, *optional*, defaults to 224):
                每个图像的大小(分辨率)。
            patch_size (`int`, *optional*, defaults to 16):
                每个图块的大小(分辨率)。
            num_channels (`int`, *optional*, defaults to 3):
                输入通道的数量。
            qkv_bias (`bool`, *optional*, defaults to `True`):
                是否为查询、键和值添加偏置。
            encoder_stride (`int`, *optional*, defaults to 16):
                用于遮蔽图像建模中解码器头部中空间分辨率的增加因子。

        Example:

        ```
        >>> from transformers import ViTConfig, ViTModel

        >>> # 初始化一个 ViT vit-base-patch16-224 风格的配置
        >>> configuration = ViTConfig()

        >>> # 根据 vit-base-patch16-224 风格的配置初始化一个模型(使用随机权重)
        >>> model = ViTModel(configuration)

        >>> # 访问模型配置
        >>> configuration = model.config
        ```"""
        # 调用父类的初始化方法,传递所有的关键字参数
        super().__init__(**kwargs)

        # 设置隐藏层的大小
        self.hidden_size = hidden_size
        # 设置隐藏层的数量
        self.num_hidden_layers = num_hidden_layers
        # 设置注意力头的数量
        self.num_attention_heads = num_attention_heads
        # 设置中间层的大小
        self.intermediate_size = intermediate_size
        # 设置隐藏层的激活函数类型
        self.hidden_act = hidden_act
        # 设置隐藏层的dropout概率
        self.hidden_dropout_prob = hidden_dropout_prob
        # 设置注意力概率的dropout概率
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        # 设置初始化范围
        self.initializer_range = initializer_range
        # 设置层归一化的epsilon值
        self.layer_norm_eps = layer_norm_eps
        # 设置图像的大小
        self.image_size = image_size
        # 设置图像块的大小
        self.patch_size = patch_size
        # 设置通道的数量
        self.num_channels = num_channels
        # 设置查询-键-值的偏置
        self.qkv_bias = qkv_bias
        # 设置编码器的步长
        self.encoder_stride = encoder_stride
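
上面的配置决定了 ViT 的输入序列长度:图像被切成 (image_size / patch_size)² 个补丁,再加上一个 [CLS] token。下面用默认配置做一个简单验证:

from transformers import ViTConfig

config = ViTConfig()  # 默认 image_size=224, patch_size=16
num_patches = (config.image_size // config.patch_size) ** 2
print(num_patches)      # 196
print(num_patches + 1)  # 197,即 Transformer 编码器实际处理的序列长度
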
class ViTOnnxConfig(OnnxConfig):
    # 定义一个继承自OnnxConfig的类ViTOnnxConfig

    # 设置torch_onnx_minimum_version属性为最低要求的版本号为1.11
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # 定义一个属性inputs,返回一个OrderedDict,其中包含输入数据的描述
        return OrderedDict(
            [
                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
                # 描述像素值输入的结构,使用字典表示各维度的含义
            ]
        )

    @property
    def atol_for_validation(self) -> float:
        # 定义一个属性atol_for_validation,返回用于验证的绝对误差阈值为1e-4
        return 1e-4
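
下面是一个检查 `ViTOnnxConfig` 的小示例(`ViTOnnxConfig` 可以直接用一个 `ViTConfig` 实例构造,这里仅打印其输入定义与验证阈值):

from transformers import ViTConfig
from transformers.models.vit.configuration_vit import ViTOnnxConfig

onnx_config = ViTOnnxConfig(ViTConfig())
print(onnx_config.inputs)
# OrderedDict([('pixel_values', {0: 'batch', 1: 'num_channels', 2: 'height', 3: 'width'})])
print(onnx_config.atol_for_validation)  # 0.0001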

.\models\vit\convert_dino_to_pytorch.py

# coding=utf-8
# 定义编码格式为 UTF-8,确保脚本中可以使用中文注释和字符串

# 版权声明和许可证信息,指明代码的使用条款和条件
# Copyright 2021 The HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert ViT checkpoints trained with the DINO method."""

# 导入必要的模块和库
import argparse  # 用于解析命令行参数
import json  # 用于处理 JSON 格式数据
from pathlib import Path  # 用于处理文件路径的类

import requests  # 用于进行网络请求
import torch  # PyTorch 深度学习框架
from huggingface_hub import hf_hub_download  # 从 Hugging Face Hub 下载模型和文件
from PIL import Image  # Python Imaging Library,用于处理图像

# 导入需要转换的 ViT 相关类和函数
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel
from transformers.utils import logging  # 引入日志记录功能

# 设置日志输出等级为信息级别
logging.set_verbosity_info()
# 获取当前模块的日志记录器对象
logger = logging.get_logger(__name__)

# 定义函数:生成需要重命名的键值对列表
# 根据 ViT 模型配置,将原始的键名映射为新的键名
def create_rename_keys(config, base_model=False):
    rename_keys = []
    # 遍历 ViT 模型的所有隐藏层
    for i in range(config.num_hidden_layers):
        # 对每一层进行重命名映射
        rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
        rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
        rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
        rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
        rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))

    # 还需重命名投影层和位置嵌入
    rename_keys.extend(
        [
            ("cls_token", "vit.embeddings.cls_token"),
            ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
            ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
            ("pos_embed", "vit.embeddings.position_embeddings"),
        ]
    )
    # 如果存在基础模型(base_model为真),执行以下操作:
    if base_model:
        # 将以下键值对追加到rename_keys列表中,用于后续重命名
        rename_keys.extend(
            [
                ("norm.weight", "layernorm.weight"),  # 将"norm.weight"重命名为"layernorm.weight"
                ("norm.bias", "layernorm.bias"),      # 将"norm.bias"重命名为"layernorm.bias"
            ]
        )

        # 对于以"vit"开头的所有键名,去除开头的"vit"(如果仅有基础模型时使用)
        rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
    else:
        # 如果不存在基础模型,执行以下操作:
        # 将以下键值对追加到rename_keys列表中,用于后续重命名
        rename_keys.extend(
            [
                ("norm.weight", "vit.layernorm.weight"),   # 将"norm.weight"重命名为"vit.layernorm.weight"
                ("norm.bias", "vit.layernorm.bias"),       # 将"norm.bias"重命名为"vit.layernorm.bias"
                ("head.weight", "classifier.weight"),      # 将"head.weight"重命名为"classifier.weight"
                ("head.bias", "classifier.bias"),          # 将"head.bias"重命名为"classifier.bias"
            ]
        )

    # 返回处理后的rename_keys列表,其中包含了根据条件不同而进行的键重命名操作
    return rename_keys
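
下面是一个最小示例,展示该函数生成的键名映射的前几项(假设在本脚本的上下文中直接调用 `create_rename_keys`):

from transformers import ViTConfig

demo_config = ViTConfig(num_hidden_layers=1)
for src, dest in create_rename_keys(demo_config, base_model=False)[:3]:
    print(src, "->", dest)
# blocks.0.norm1.weight -> vit.encoder.layer.0.layernorm_before.weight
# blocks.0.norm1.bias -> vit.encoder.layer.0.layernorm_before.bias
# blocks.0.attn.proj.weight -> vit.encoder.layer.0.attention.output.dense.weight
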
# 将每个编码器层的矩阵分割为查询(queries)、键(keys)和值(values)
def read_in_q_k_v(state_dict, config, base_model=False):
    # 遍历每一个编码器层
    for i in range(config.num_hidden_layers):
        # 如果是基础模型,则前缀为空字符串;否则前缀为"vit."
        if base_model:
            prefix = ""
        else:
            prefix = "vit."
        
        # 从状态字典中弹出输入投影层的权重和偏置(在timm中,这是一个单独的矩阵加上偏置)
        in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
        in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
        
        # 将查询(query)、键(key)、和值(value)依次添加到状态字典中
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
            : config.hidden_size, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
            config.hidden_size : config.hidden_size * 2, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
            config.hidden_size : config.hidden_size * 2
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
            -config.hidden_size :, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]


# 从状态字典中移除分类头部(classification head)
def remove_classification_head_(state_dict):
    ignore_keys = ["head.weight", "head.bias"]
    for k in ignore_keys:
        state_dict.pop(k, None)


# 将字典中的旧键(old)重命名为新键(new)
def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


# 准备一张可爱猫咪的图片,用于验证我们的转换结果
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


# 使用torch.no_grad()装饰器,将函数设置为无需梯度的上下文
@torch.no_grad()
def convert_vit_checkpoint(model_name, pytorch_dump_folder_path, base_model=True):
    """
    将模型的权重复制/粘贴/调整到我们的ViT结构中。
    """

    # 定义默认的ViT配置
    config = ViTConfig()
    
    # 如果模型名称的最后一个字符是"8",则设置patch_size为8
    if model_name[-1] == "8":
        config.patch_size = 8
    
    # 如果不是基础模型,则设置num_labels为1000,并加载对应的标签文件
    if not base_model:
        config.num_labels = 1000
        repo_id = "huggingface/label-files"
        filename = "imagenet-1k-id2label.json"
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
    
    # 如果模型名称在指定的列表中,则设置ViT的隐藏层大小、中间层大小等
    if model_name in ["dino_vits8", "dino_vits16"]:
        config.hidden_size = 384
        config.intermediate_size = 1536
        config.num_hidden_layers = 12
        config.num_attention_heads = 6
    
    # 从torch hub加载原始模型
    original_model = torch.hub.load("facebookresearch/dino:main", model_name)
    # 将原始模型设置为评估模式
    original_model.eval()

    # 加载原始模型的状态字典,并移除/重命名一些键
    state_dict = original_model.state_dict()
    if base_model:
        # 如果指定了基础模型,移除分类头部分的参数
        remove_classification_head_(state_dict)
    
    # 根据配置文件创建需要重命名的键列表
    rename_keys = create_rename_keys(config, base_model=base_model)
    
    # 遍历重命名键列表,逐一重命名状态字典中的键
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    
    # 根据状态字典和配置信息读入查询、键、值的数据
    read_in_q_k_v(state_dict, config, base_model)

    # 加载 HuggingFace 模型
    if base_model:
        # 如果指定了基础模型,创建 ViT 模型对象(不添加池化层)并设置为评估模式
        model = ViTModel(config, add_pooling_layer=False).eval()
    else:
        # 否则创建用于图像分类的 ViT 模型对象并设置为评估模式
        model = ViTForImageClassification(config).eval()
    
    # 加载模型的状态字典
    model.load_state_dict(state_dict)

    # 使用 ViTImageProcessor 准备图像并编码
    image_processor = ViTImageProcessor()
    encoding = image_processor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding["pixel_values"]
    
    # 将图像数据输入模型并获取输出
    outputs = model(pixel_values)

    if base_model:
        # 如果指定了基础模型,还需要使用原始模型对图像进行预测
        final_hidden_state_cls_token = original_model(pixel_values)
        # 断言原始模型的分类标记的最终隐藏状态与当前模型输出的第一个位置的隐藏状态在给定的误差范围内相等
        assert torch.allclose(final_hidden_state_cls_token, outputs.last_hidden_state[:, 0, :], atol=1e-1)
    else:
        # 否则,直接使用原始模型获取分类 logits
        logits = original_model(pixel_values)
        # 断言原始模型输出的 logits 形状与当前模型输出的 logits 形状相等
        assert logits.shape == outputs.logits.shape
        # 断言两个 logits 张量在给定的误差范围内相等
        assert torch.allclose(logits, outputs.logits, atol=1e-3)

    # 创建存储 PyTorch 模型的文件夹(如果不存在)
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    # 打印保存模型的消息
    print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
    # 将模型保存到指定路径
    model.save_pretrained(pytorch_dump_folder_path)
    # 打印保存图像处理器的消息
    print(f"Saving image processor to {pytorch_dump_folder_path}")
    # 将图像处理器保存到指定路径
    image_processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
    # 如果作为主程序运行,则执行以下代码块

    parser = argparse.ArgumentParser()
    # 创建参数解析器对象

    # 必填参数
    parser.add_argument(
        "--model_name",
        default="dino_vitb16",
        type=str,
        help="Name of the model trained with DINO you'd like to convert.",
    )
    # 模型名称,指定使用的 DINO 训练的模型名称,默认为 dino_vitb16

    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    # PyTorch 模型输出目录的路径,用于存储转换后的模型

    parser.add_argument(
        "--base_model",
        action="store_true",
        help="Whether to only convert the base model (no projection head weights).",
    )
    # 是否仅转换基础模型(不包括投影头权重)

    parser.set_defaults(base_model=True)
    # 设置默认参数,base_model 默认为 True

    args = parser.parse_args()
    # 解析命令行参数并存储到 args 变量中

    convert_vit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.base_model)
    # 调用 convert_vit_checkpoint 函数,传递解析后的参数 model_name、pytorch_dump_folder_path 和 base_model
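
该脚本的一个示例调用方式(需要能访问网络,以便从 torch hub 下载 DINO 权重):`python convert_dino_to_pytorch.py --model_name dino_vitb16 --pytorch_dump_folder_path ./dino-vitb16`。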

.\models\vit\convert_vit_timm_to_pytorch.py
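
# 注:原文件开头的编码声明、版权许可与导入语句在此处被省略;以下导入是根据后文用到的符号推断补全的,仅供参考。
import argparse
from pathlib import Path

import requests
import timm
import torch
from PIL import Image
from timm.data import ImageNetInfo, infer_imagenet_subset

from transformers import DeiTImageProcessor, ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel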

# 定义用于重命名权重键的函数,根据给定的配置和是否基于基础模型来生成重命名规则列表
def create_rename_keys(config, base_model=False):
    rename_keys = []
    # 遍历所有编码器层
    for i in range(config.num_hidden_layers):
        # 添加权重重命名规则:输入层的归一化权重
        rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
        # 添加权重重命名规则:输入层的归一化偏置
        rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
        # 添加权重重命名规则:注意力机制输出的投影权重
        rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
        # 添加权重重命名规则:注意力机制输出的投影偏置
        rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
        # 添加权重重命名规则:输出层的归一化权重
        rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
        # 添加权重重命名规则:输出层的归一化偏置
        rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
        # 添加权重重命名规则:中间层的全连接层1权重
        rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
        # 添加权重重命名规则:中间层的全连接层1偏置
        rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
        # 添加权重重命名规则:中间层的全连接层2权重
        rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
        # 添加权重重命名规则:中间层的全连接层2偏置
        rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))

    # 添加权重重命名规则:CLS token
    rename_keys.append(("cls_token", "vit.embeddings.cls_token"))
    # 添加权重重命名规则:补丁嵌入的投影权重
    rename_keys.append(("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"))
    # 添加权重重命名规则:补丁嵌入的投影偏置
    rename_keys.append(("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"))
    # 添加权重重命名规则:位置嵌入
    rename_keys.append(("pos_embed", "vit.embeddings.position_embeddings"))


    # 如果 base_model 为 True(仅转换不带分类头的基础模型),执行以下操作:
    if base_model:
        # 将以下键值对添加到rename_keys列表中,用于重命名模型参数:
        # 将"norm.weight"重命名为"layernorm.weight"
        # 将"norm.bias"重命名为"layernorm.bias"
        rename_keys.extend(
            [
                ("norm.weight", "layernorm.weight"),
                ("norm.bias", "layernorm.bias"),
            ]
        )

        # 如果仅有基础模型,需要移除所有以"vit"开头的键中的前缀"vit"
        # 对rename_keys中的每对键值对进行检查和可能的修改
        rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
    else:
        # 如果不仅有基础模型,而是包含分类头(layernorm + classification head)
        # 将以下键值对添加到rename_keys列表中,用于重命名模型参数:
        # 将"norm.weight"重命名为"vit.layernorm.weight"
        # 将"norm.bias"重命名为"vit.layernorm.bias"
        # 将"head.weight"重命名为"classifier.weight"
        # 将"head.bias"重命名为"classifier.bias"
        rename_keys.extend(
            [
                ("norm.weight", "vit.layernorm.weight"),
                ("norm.bias", "vit.layernorm.bias"),
                ("head.weight", "classifier.weight"),
                ("head.bias", "classifier.bias"),
            ]
        )

    # 返回最终的重命名后的键值对列表
    return rename_keys
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, base_model=False):
    # 遍历每个编码器层,分离出查询(query)、键(keys)和值(values)的权重和偏置
    for i in range(config.num_hidden_layers):
        if base_model:
            prefix = ""
        else:
            prefix = "vit."
        # 读取输入投影层的权重和偏置(在timm中,这是一个包含权重和偏置的单一矩阵)
        in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
        in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
        # 将查询(query)、键(keys)和值(values)按顺序添加到状态字典中
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
            : config.hidden_size, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
            config.hidden_size : config.hidden_size * 2, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
            config.hidden_size : config.hidden_size * 2
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
            -config.hidden_size :, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]


def remove_classification_head_(state_dict):
    # 移除状态字典中的分类头部权重和偏置项
    ignore_keys = ["head.weight", "head.bias"]
    for k in ignore_keys:
        state_dict.pop(k, None)


def rename_key(dct, old, new):
    # 重命名字典中的键
    val = dct.pop(old)
    dct[new] = val


# We will verify our results on an image of cute cats
def prepare_img():
    # 准备一张可爱猫咪的图像来验证结果
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@torch.no_grad()
def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our ViT structure.
    """

    # define default ViT configuration
    config = ViTConfig()
    base_model = False

    # load original model from timm
    timm_model = timm.create_model(vit_name, pretrained=True)
    timm_model.eval()

    # detect unsupported ViT models in transformers
    # fc_norm is present
    # 检测transformers中不支持的ViT模型
    if not isinstance(getattr(timm_model, "fc_norm", None), torch.nn.Identity):
        raise ValueError(f"{vit_name} is not supported in transformers because of the presence of fc_norm.")

    # use of global average pooling in combination (or without) class token
    # 检测在使用全局平均池化时(或没有类令牌时),transformers中不支持的ViT模型
    if getattr(timm_model, "global_pool", None) == "avg":
        raise ValueError(f"{vit_name} is not supported in transformers because of use of global average pooling.")

    # CLIP style vit with norm_pre layer present
    # 检测是否存在norm_pre层,以确定是否是类似CLIP风格的ViT模型
    # 检查是否为 CLIP 风格的 ViT,且其 norm_pre 层不是 torch.nn.Identity
    if "clip" in vit_name and not isinstance(getattr(timm_model, "norm_pre", None), torch.nn.Identity):
        raise ValueError(
            f"{vit_name} is not supported in transformers because it's a CLIP style ViT with norm_pre layer."
        )

    # 检查是否为 SigLIP 风格的 ViT,且具有 attn_pool 层
    if "siglip" in vit_name and getattr(timm_model, "global_pool", None) == "map":
        raise ValueError(
            f"{vit_name} is not supported in transformers because it's a SigLIP style ViT with attn_pool."
        )

    # 检查 ViT 模型的 blocks[0] 中是否使用了 layer scale
    if not isinstance(getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity) or not isinstance(
        getattr(timm_model.blocks[0], "ls2", None), torch.nn.Identity
    ):
        raise ValueError(f"{vit_name} is not supported in transformers because it uses a layer scale in its blocks.")

    # 检查是否为混合 ResNet-ViT 模型,即 patch_embed 不是 timm.layers.PatchEmbed 类型
    if not isinstance(timm_model.patch_embed, timm.layers.PatchEmbed):
        raise ValueError(f"{vit_name} is not supported in transformers because it is a hybrid ResNet-ViT.")

    # 从 patch embedding 子模块中获取 patch 大小和图像大小
    config.patch_size = timm_model.patch_embed.patch_size[0]
    config.image_size = timm_model.patch_embed.img_size[0]

    # 从 timm 模型中获取特定于架构的参数
    config.hidden_size = timm_model.embed_dim
    config.intermediate_size = timm_model.blocks[0].mlp.fc1.out_features
    config.num_hidden_layers = len(timm_model.blocks)
    config.num_attention_heads = timm_model.blocks[0].attn.num_heads

    # 检查模型是否有分类头
    if timm_model.num_classes != 0:
        # 设置分类标签数量
        config.num_labels = timm_model.num_classes
        # 推断出 timm 模型的 ImageNet 子集
        imagenet_subset = infer_imagenet_subset(timm_model)
        dataset_info = ImageNetInfo(imagenet_subset)
        # 设置 id 到 label 名称的映射和 label 名称到 id 的映射
        config.id2label = {i: dataset_info.index_to_label_name(i) for i in range(dataset_info.num_classes())}
        config.label2id = {v: k for k, v in config.id2label.items()}
    else:
        # 若没有分类头,则模型将被转换为仅提取特征的模式
        print(f"{vit_name} is going to be converted as a feature extractor only.")
        base_model = True

    # 加载原始模型的 state_dict
    state_dict = timm_model.state_dict()

    # 如果是基础模型,移除和重命名 state_dict 中的一些键
    if base_model:
        remove_classification_head_(state_dict)
    rename_keys = create_rename_keys(config, base_model)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, base_model)

    # 加载 HuggingFace 模型
    if base_model:
        model = ViTModel(config, add_pooling_layer=False).eval()
    else:
        model = ViTForImageClassification(config).eval()
    model.load_state_dict(state_dict)

    # 在图像处理器 ViTImageProcessor/DeiTImageProcessor 上检查图像的输出
    if "deit" in vit_name:
        image_processor = DeiTImageProcessor(size=config.image_size)
    # 否则(非 DeiT 模型)使用 ViTImageProcessor 处理图像数据
    else:
        image_processor = ViTImageProcessor(size=config.image_size)
    
    # 对准备好的图像数据进行编码,返回 PyTorch 张量表示
    encoding = image_processor(images=prepare_img(), return_tensors="pt")
    
    # 提取像素数值
    pixel_values = encoding["pixel_values"]
    
    # 使用模型进行推理,得到输出结果
    outputs = model(pixel_values)

    # 如果存在基础模型:
    if base_model:
        # 使用 timm_model 提取特征,并进行形状断言
        timm_pooled_output = timm_model.forward_features(pixel_values)
        assert timm_pooled_output.shape == outputs.last_hidden_state.shape
        assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1)
    else:
        # 使用 timm_model 进行推理,得到 logits,并进行形状断言
        timm_logits = timm_model(pixel_values)
        assert timm_logits.shape == outputs.logits.shape
        assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)

    # 确保指定路径下的文件夹存在,用于保存 PyTorch 模型
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    
    # 打印保存模型的信息
    print(f"Saving model {vit_name} to {pytorch_dump_folder_path}")
    
    # 将模型保存到指定路径
    model.save_pretrained(pytorch_dump_folder_path)
    
    # 打印保存图像处理器的信息
    print(f"Saving image processor to {pytorch_dump_folder_path}")
    
    # 将图像处理器保存到指定路径
    image_processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
    # 如果作为主程序执行以下代码块

    parser = argparse.ArgumentParser()
    # 创建参数解析器对象

    # 必选参数
    parser.add_argument(
        "--vit_name",
        default="vit_base_patch16_224",
        type=str,
        help="Name of the ViT timm model you'd like to convert.",
    )
    # 添加一个名为--vit_name的参数,默认值为"vit_base_patch16_224",类型为字符串,用于指定要转换的 ViT 模型的名称

    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    # 添加一个名为--pytorch_dump_folder_path的参数,值默认为None,类型为字符串,用于指定输出 PyTorch 模型的目录路径

    args = parser.parse_args()
    # 解析命令行参数并存储到args变量中

    convert_vit_checkpoint(args.vit_name, args.pytorch_dump_folder_path)
    # 调用函数convert_vit_checkpoint,传递解析得到的--vit_name和--pytorch_dump_folder_path参数


这段代码是一个命令行程序的入口点,使用argparse模块解析命令行参数,然后调用函数`convert_vit_checkpoint`进行处理。
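
该脚本的一个示例调用方式(需要安装 timm,预训练权重会自动下载):`python convert_vit_timm_to_pytorch.py --vit_name vit_base_patch16_224 --pytorch_dump_folder_path ./vit-base-patch16-224`。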

.\models\vit\feature_extraction_vit.py

# coding=utf-8
# 定义文件编码格式为 UTF-8

# 版权声明和许可证信息,指定代码的版权和使用条款
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Feature extractor class for ViT."""
# 导入警告模块,用于向用户显示警告信息
import warnings

# 导入 logging 模块,用于记录日志
from ...utils import logging
# 从本地模块中导入 ViTImageProcessor 类
from .image_processing_vit import ViTImageProcessor

# 获取当前模块的日志记录器对象
logger = logging.get_logger(__name__)

# 定义 ViTFeatureExtractor 类,继承自 ViTImageProcessor 类
class ViTFeatureExtractor(ViTImageProcessor):
    def __init__(self, *args, **kwargs) -> None:
        # 发出关于类被弃用的警告信息,建议使用 ViTImageProcessor 类代替
        warnings.warn(
            "The class ViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
            " use ViTImageProcessor instead.",
            FutureWarning,
        )
        # 调用父类的构造函数,传入所有参数和关键字参数
        super().__init__(*args, **kwargs)
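
下面的小示例演示这一弃用行为:实例化 `ViTFeatureExtractor` 会触发 `FutureWarning`,推荐改用功能相同的 `ViTImageProcessor`:

import warnings

from transformers import ViTFeatureExtractor, ViTImageProcessor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    feature_extractor = ViTFeatureExtractor()  # 触发 FutureWarning
print(caught[0].category)  # <class 'FutureWarning'>

image_processor = ViTImageProcessor()  # 推荐的替代写法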

.\models\vit\image_processing_vit.py

# 导入必要的库和模块,包括类型提示、NumPy等
from typing import Dict, List, Optional, Union

import numpy as np

# 导入自定义的图像处理工具函数和类
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
# 导入图像变换函数和常量
from ...image_transforms import resize, to_channel_dimension_format
# 导入图像处理相关的工具函数和常量,如均值、标准差、通道格式等
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_kwargs,
    validate_preprocess_arguments,
)
# 导入通用的工具函数,如日志记录
from ...utils import TensorType, logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义一个 ViT 图像处理器类,继承自 BaseImageProcessor 类
class ViTImageProcessor(BaseImageProcessor):
    r"""
    Constructs a ViT image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
            size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
            `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """
    # 设置模型输入名称为"pixel_values"
    model_input_names = ["pixel_values"]

    # 初始化函数,设置各个参数的默认值,可以通过`preprocess`方法中的对应参数进行覆盖
    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs,
    ) -> None:
        # 调用父类初始化方法,并传递所有关键字参数
        super().__init__(**kwargs)
        # 如果 size 参数不为 None,则使用指定的尺寸;否则使用默认尺寸 {"height": 224, "width": 224}
        size = size if size is not None else {"height": 224, "width": 224}
        # 根据 size 获取一个标准化的尺寸字典
        size = get_size_dict(size)
        # 初始化是否执行 resize 操作的标志
        self.do_resize = do_resize
        # 初始化是否执行 rescale 操作的标志
        self.do_rescale = do_rescale
        # 初始化是否执行 normalize 操作的标志
        self.do_normalize = do_normalize
        # 将 size 存储到对象属性中
        self.size = size
        # 设定图像 resize 时使用的 resample 方法
        self.resample = resample
        # 设定图像 rescale 时的缩放因子
        self.rescale_factor = rescale_factor
        # 如果 image_mean 不为 None,则使用给定的 image_mean;否则使用 IMAGENET_STANDARD_MEAN
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        # 如果 image_std 不为 None,则使用给定的 image_std;否则使用 IMAGENET_STANDARD_STD
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        # 初始化一个有效的处理器键列表,用于检查处理器的有效性
        self._valid_processor_keys = [
            "images",
            "do_resize",
            "size",
            "resample",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "image_mean",
            "image_std",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]
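
    # 注:原文在此处缺失 resize 方法的签名,以下参数列表根据其下方的 docstring 与官方实现补全,仅供参考。
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,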
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.
        """
        size = get_size_dict(size)  # 获取调整大小后的字典格式大小
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
        output_size = (size["height"], size["width"])  # 获取调整后的输出大小
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ):
        """
        Preprocesses a batch of images including resizing, rescaling, normalization, and tensor conversion.

        Args:
            images (`ImageInput`):
                Batch of input images.
            do_resize (`bool`, *optional*):
                Whether to resize the images. If `True`, resizing will be performed according to `size`.
            size (`Dict[str, int]`, *optional*):
                Dictionary specifying the target size for resizing each image in the batch.
            resample (`PILImageResampling`, *optional*):
                Resampling method to use for resizing, default is `PILImageResampling.BILINEAR`.
            do_rescale (`bool`, *optional*):
                Whether to rescale the images. If `True`, images will be scaled by `rescale_factor`.
            rescale_factor (`float`, *optional*):
                Scaling factor for rescaling images.
            do_normalize (`bool`, *optional*):
                Whether to normalize the images.
            image_mean (`float` or `List[float]`, *optional*):
                Mean values for normalization.
            image_std (`float` or `List[float]`, *optional*):
                Standard deviation values for normalization.
            return_tensors (`str` or `TensorType`, *optional*):
                If specified, converts the output to the desired tensor type (e.g., `"torch"` or `"tensorflow"`).
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output images.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the input images.

        Returns:
            `np.ndarray` or `TensorType`: Preprocessed batch of images.
        """

.\models\vit\modeling_flax_vit.py

# 导入必要的模块和类
from typing import Optional, Tuple

import flax.linen as nn  # 导入 Flax 的 linen 模块
import jax  # 导入 JAX 库
import jax.numpy as jnp  # 导入 JAX 的 NumPy 接口
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # 导入 Flax 的 FrozenDict 相关函数
from flax.linen.attention import dot_product_attention_weights  # 导入 dot_product_attention_weights 函数
from flax.traverse_util import flatten_dict, unflatten_dict  # 导入 flatten_dict 和 unflatten_dict 函数

from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput  # 导入输出相关类
from ...modeling_flax_utils import (  # 导入 FlaxPreTrainedModel 和其他实用函数
    ACT2FN,
    FlaxPreTrainedModel,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward  # 导入添加文档字符串的函数
from .configuration_vit import ViTConfig  # 导入 ViTConfig 配置类

VIT_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

"""
    # Parameters: 定义函数的参数列表及其描述
    # config ([`ViTConfig`]): 使用ViTConfig类配置模型的参数
    #     通过配置文件初始化不会加载模型的权重,仅加载配置。请查看[`~FlaxPreTrainedModel.from_pretrained`]方法以加载模型权重。
    # dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
    #     计算时所使用的数据类型。可以是`jax.numpy.float32`、`jax.numpy.float16`(在GPU上)和`jax.numpy.bfloat16`(在TPU上)之一。
    #     
    #     可用于在GPU或TPU上启用混合精度训练或半精度推断。如果指定了dtype,则所有计算将使用给定的`dtype`执行。
    #     
    #     **请注意,此仅指定计算的数据类型,不影响模型参数的数据类型。**
    #     
    #     如果您希望更改模型参数的数据类型,请参阅[`~FlaxPreTrainedModel.to_fp16`]和[`~FlaxPreTrainedModel.to_bf16`]。
"""
Summary of `VIT_INPUTS_DOCSTRING`, the docstring template describing the inputs expected by the ViT model's forward call.

It provides detailed information about the expected inputs, including:
- `pixel_values`: A numpy array containing pixel values of shape `(batch_size, num_channels, height, width)`.
  This array can be obtained using an `AutoImageProcessor`.
- `output_attentions`: An optional boolean indicating whether to return attention tensors from all layers.
- `output_hidden_states`: An optional boolean indicating whether to return hidden states from all layers.
- `return_dict`: An optional boolean indicating whether to return a `ModelOutput` instead of a tuple.
"""
class FlaxViTPatchEmbeddings(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Calculate number of patches in the image
        image_size = self.config.image_size
        patch_size = self.config.patch_size
        num_patches = (image_size // patch_size) * (image_size // patch_size)
        self.num_patches = num_patches
        self.num_channels = self.config.num_channels
        # Initialize a convolutional layer for projecting patches
        self.projection = nn.Conv(
            self.config.hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
        )

    def __call__(self, pixel_values):
        # Check if the number of channels in pixel values matches the configuration
        num_channels = pixel_values.shape[-1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # Project pixel values into patch embeddings
        embeddings = self.projection(pixel_values)
        batch_size, _, _, channels = embeddings.shape
        # Reshape embeddings to match the expected output format
        return jnp.reshape(embeddings, (batch_size, -1, channels))
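
下面用一个最小示例检查补丁嵌入的输出形状(假设已安装 jax/flax;注意 Flax 版本的输入是 channels-last 格式):

import jax
import jax.numpy as jnp

from transformers import ViTConfig
from transformers.models.vit.modeling_flax_vit import FlaxViTPatchEmbeddings

config = ViTConfig()
module = FlaxViTPatchEmbeddings(config)
pixel_values = jnp.zeros((1, config.image_size, config.image_size, config.num_channels))
params = module.init(jax.random.PRNGKey(0), pixel_values)
out = module.apply(params, pixel_values)
print(out.shape)  # (1, 196, 768):196 个补丁,每个补丁对应一个 768 维向量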


class FlaxViTEmbeddings(nn.Module):
    """Construct the CLS token, position and patch embeddings."""

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    def setup(self):
        # 在模型设置阶段初始化特殊的 "CLS" 标记,它是一个参数,根据给定的初始化器创建
        # 使用方差缩放初始化器(以配置的初始化范围的平方为参数),初始化 "CLS" 标记
        self.cls_token = self.param(
            "cls_token",
            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
            (1, 1, self.config.hidden_size),
        )

        # 创建图像块的嵌入表示对象,使用ViT中的补丁嵌入(Patch Embeddings)
        self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)

        # 计算图像块的数量,并初始化位置嵌入,将其视为模型参数
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = self.param(
            "position_embeddings",
            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
            (1, num_patches + 1, self.config.hidden_size),
        )

        # 初始化一个dropout层,根据配置中的隐藏层dropout概率
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, pixel_values, deterministic=True):
        # 获取输入张量的批次大小
        batch_size = pixel_values.shape[0]

        # 对输入的像素值计算补丁嵌入(Patch Embeddings)
        embeddings = self.patch_embeddings(pixel_values)

        # 创建形状与批次大小匹配的 "CLS" 标记,并将其广播到嵌入的第二个维度
        cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))

        # 将 "CLS" 标记与图像块嵌入连接起来,形成完整的嵌入表示
        embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)

        # 将位置嵌入添加到嵌入向量中,以增强位置信息
        embeddings = embeddings + self.position_embeddings

        # 应用dropout,用于模型训练时的随机失活,以防止过拟合
        embeddings = self.dropout(embeddings, deterministic=deterministic)

        # 返回最终的嵌入表示作为模型的输出
        return embeddings
# 定义一个名为 FlaxViTSelfAttention 的自定义 Flax(Linen)模块
class FlaxViTSelfAttention(nn.Module):
    # ViT 模型的配置信息
    config: ViTConfig
    # 计算时所使用的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    # 模块的初始化方法
    def setup(self):
        # 检查 hidden_size 是否能被 num_attention_heads 整除
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:"
                " {self.config.num_attention_heads}"
            )

        # 初始化查询(query)的全连接层
        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
            ),
            use_bias=self.config.qkv_bias,
        )
        # 初始化键(key)的全连接层
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
            ),
            use_bias=self.config.qkv_bias,
        )
        # 初始化值(value)的全连接层
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
            ),
            use_bias=self.config.qkv_bias,
        )

    # 模块的调用方法,实现自注意力机制
    def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
        # 计算每个注意力头的维度
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        # 使用查询向量进行全连接操作并重塑形状
        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        # 使用值向量进行全连接操作并重塑形状
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        # 使用键向量进行全连接操作并重塑形状
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )

        # 初始化一个用于 dropout 的随机数生成器
        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        # 计算自注意力权重
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # 使用注意力权重和值向量计算最终的自注意力输出
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        # 重塑输出的形状
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        # 如果需要输出注意力权重,将其包含在输出中
        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
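
上面的 einsum 约定可以用一个小例子来验证各张量的形状(batch=2、heads=12、seq=197、head_dim=64,均为示例数值):

import jax.numpy as jnp

attn_weights = jnp.zeros((2, 12, 197, 197))  # 对应 "...hqk"
value_states = jnp.zeros((2, 197, 12, 64))   # 对应 "...khd"
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
print(attn_output.shape)  # (2, 197, 12, 64),随后会被 reshape 成 (2, 197, 768)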


# 定义一个名为 FlaxViTSelfOutput 的自定义 Flax(Linen)模块
class FlaxViTSelfOutput(nn.Module):
    # ViT 模型的配置信息
    config: ViTConfig
    # 计算时所使用的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    # 定义类的初始化方法,用于设置模型参数
    def setup(self):
        # 初始化一个全连接层,设置输出大小为self.config.hidden_size
        # 使用方差缩放初始化方法,基于截断的正态分布,参数为self.config.initializer_range的平方
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
            dtype=self.dtype,
        )
        # 初始化一个Dropout层,设置丢弃率为self.config.hidden_dropout_prob
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    # 定义类的调用方法,用于模型推理过程
    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        # 将输入的隐藏状态通过全连接层进行线性变换
        hidden_states = self.dense(hidden_states)
        # 对线性变换后的隐藏状态应用Dropout层,用于随机丢弃部分神经元的输出
        # 根据deterministic参数决定是否以确定性方式进行操作
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # 返回经过全连接层和Dropout层处理后的隐藏状态作为最终的模型输出
        return hidden_states
class FlaxViTAttention(nn.Module):
    config: ViTConfig  # 类属性,指定 ViT 的配置
    dtype: jnp.dtype = jnp.float32  # 类属性,默认使用 jnp.float32 数据类型

    def setup(self):
        self.attention = FlaxViTSelfAttention(self.config, dtype=self.dtype)
        self.output = FlaxViTSelfOutput(self.config, dtype=self.dtype)
        # 初始化 self.attention 和 self.output,使用指定的配置和数据类型

    def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False):
        attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
        # 调用 self.attention 对象处理 hidden_states,获取注意力输出
        attn_output = attn_outputs[0]
        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
        # 使用 self.output 处理注意力输出和 hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_outputs[1],)
            # 如果需要输出注意力信息,则将注意力权重信息添加到输出中

        return outputs
        # 返回处理后的 hidden_states 和可能的注意力信息


class FlaxViTIntermediate(nn.Module):
    config: ViTConfig  # 类属性,指定 ViT 的配置
    dtype: jnp.dtype = jnp.float32  # 类属性,默认使用 jnp.float32 数据类型

    def setup(self):
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
            dtype=self.dtype,
        )
        # 初始化 self.dense,使用指定的中间层大小和初始化方法

        self.activation = ACT2FN[self.config.hidden_act]
        # 根据配置选择激活函数

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        # 使用 self.dense 处理 hidden_states
        hidden_states = self.activation(hidden_states)
        # 使用选择的激活函数处理 hidden_states
        return hidden_states
        # 返回处理后的 hidden_states


class FlaxViTOutput(nn.Module):
    config: ViTConfig  # 类属性,指定 ViT 的配置
    dtype: jnp.dtype = jnp.float32  # 类属性,默认使用 jnp.float32 数据类型

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
            dtype=self.dtype,
        )
        # 初始化 self.dense,使用指定的输出层大小和初始化方法

        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        # 初始化 self.dropout,使用指定的 dropout 比率

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        # 使用 self.dense 处理 hidden_states
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        # 使用 self.dropout 处理 hidden_states
        hidden_states = hidden_states + attention_output
        # 将处理后的 hidden_states 与 attention_output 相加作为最终输出
        return hidden_states
        # 返回处理后的 hidden_states


class FlaxViTLayer(nn.Module):
    config: ViTConfig  # 类属性,指定 ViT 的配置
    dtype: jnp.dtype = jnp.float32  # 类属性,默认使用 jnp.float32 数据类型

    def setup(self):
        self.attention = FlaxViTAttention(self.config, dtype=self.dtype)
        # 初始化 self.attention,使用指定的配置和数据类型
        self.intermediate = FlaxViTIntermediate(self.config, dtype=self.dtype)
        # 初始化 self.intermediate,使用指定的配置和数据类型
        self.output = FlaxViTOutput(self.config, dtype=self.dtype)
        # 初始化 self.output,使用指定的配置和数据类型
        self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # 初始化 self.layernorm_before,使用指定的层归一化参数和数据类型
        self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # 初始化 self.layernorm_after,使用指定的层归一化参数和数据类型
    # 定义一个调用方法,用于处理隐藏状态,接收是否确定性处理的参数和是否输出注意力的参数
    def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
        # 使用自注意力机制处理前的层归一化,这是 ViT 中的操作顺序
        attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # 在 ViT 中,自注意力前会进行层归一化
            deterministic=deterministic,
            output_attentions=output_attentions,
        )

        attention_output = attention_outputs[0]

        # 第一个残差连接
        attention_output = attention_output + hidden_states

        # 在 ViT 中,自注意力后同样会进行层归一化
        layer_output = self.layernorm_after(attention_output)

        # 经过中间层的处理
        hidden_states = self.intermediate(layer_output)

        # 输出层的最终处理,同时传入注意力输出和隐藏状态的处理结果
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)

        outputs = (hidden_states,)

        # 如果需要输出注意力的话,将注意力输出加入到返回的元组中
        if output_attentions:
            outputs += (attention_outputs[1],)
        
        # 返回处理后的输出结果元组
        return outputs
# 定义一个 FlaxViTLayerCollection 类,继承自 nn.Module
class FlaxViTLayerCollection(nn.Module):
    # 定义类变量 config,类型为 ViTConfig
    config: ViTConfig
    # 定义 dtype 变量,默认为 jnp.float32,用于计算的数据类型
    dtype: jnp.dtype = jnp.float32

    # 初始化函数,设置网络层集合
    def setup(self):
        # 创建 self.layers 列表,包含 self.config.num_hidden_layers 个 FlaxViTLayer 实例
        self.layers = [
            FlaxViTLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
        ]

    # 实现 __call__ 方法,用于模型的前向传播
    def __call__(
        self,
        hidden_states,  # 输入的隐藏状态张量
        deterministic: bool = True,  # 是否使用确定性计算,默认为 True
        output_attentions: bool = False,  # 是否输出注意力权重,默认为 False
        output_hidden_states: bool = False,  # 是否输出所有隐藏状态,默认为 False
        return_dict: bool = True,  # 是否返回字典格式的输出,默认为 True
    ):
        # 如果 output_attentions 为 True,则初始化 all_attentions 为空元组,否则为 None
        all_attentions = () if output_attentions else None
        # 如果 output_hidden_states 为 True,则初始化 all_hidden_states 为空元组,否则为 None
        all_hidden_states = () if output_hidden_states else None

        # 遍历 self.layers 中的每一层
        for i, layer in enumerate(self.layers):
            # 如果输出所有隐藏状态
            if output_hidden_states:
                # 将当前隐藏状态加入 all_hidden_states 中
                all_hidden_states += (hidden_states,)

            # 调用当前层的 __call__ 方法,进行前向传播
            layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions)

            # 更新 hidden_states 为当前层的输出的第一个元素(即隐藏状态)
            hidden_states = layer_outputs[0]

            # 如果输出注意力权重
            if output_attentions:
                # 将当前层的注意力权重加入 all_attentions 中
                all_attentions += (layer_outputs[1],)

        # 如果输出所有隐藏状态
        if output_hidden_states:
            # 将最终的隐藏状态加入 all_hidden_states 中
            all_hidden_states += (hidden_states,)

        # 将最终的隐藏状态作为元组 outputs 的第一个元素
        outputs = (hidden_states,)

        # 如果不返回字典格式的输出
        if not return_dict:
            # 返回 outputs 中非 None 的元素作为元组
            return tuple(v for v in outputs if v is not None)

        # 返回 FlaxBaseModelOutput 类的实例,包含最终的隐藏状态、所有隐藏状态和注意力权重
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


# 定义一个 FlaxViTEncoder 类,继承自 nn.Module
class FlaxViTEncoder(nn.Module):
    # 定义类变量 config,类型为 ViTConfig
    config: ViTConfig
    # 定义 dtype 变量,默认为 jnp.float32,用于计算的数据类型
    dtype: jnp.dtype = jnp.float32

    # 初始化函数,创建 FlaxViTLayerCollection 类的实例作为 self.layer
    def setup(self):
        self.layer = FlaxViTLayerCollection(self.config, dtype=self.dtype)

    # 实现 __call__ 方法,用于模型的前向传播
    def __call__(
        self,
        hidden_states,  # 输入的隐藏状态张量
        deterministic: bool = True,  # 是否使用确定性计算,默认为 True
        output_attentions: bool = False,  # 是否输出注意力权重,默认为 False
        output_hidden_states: bool = False,  # 是否输出所有隐藏状态,默认为 False
        return_dict: bool = True,  # 是否返回字典格式的输出,默认为 True
    ):
        # 调用 self.layer 的 __call__ 方法进行前向传播
        return self.layer(
            hidden_states,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


# 定义一个 FlaxViTPooler 类,继承自 nn.Module
class FlaxViTPooler(nn.Module):
    # 定义类变量 config,类型为 ViTConfig
    config: ViTConfig
    # 定义 dtype 变量,默认为 jnp.float32,用于计算的数据类型
    dtype: jnp.dtype = jnp.float32

    # 初始化函数,设置池化层为全连接层
    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
            dtype=self.dtype,
        )

    # 实现 __call__ 方法,用于模型的前向传播
    def __call__(self, hidden_states):
        # 取出每个样本的第一个位置的隐藏状态作为池化结果
        cls_hidden_state = hidden_states[:, 0]
        # 将 cls_hidden_state 输入到全连接层中
        cls_hidden_state = self.dense(cls_hidden_state)
        # 对全连接层的输出进行 tanh 激活函数处理
        return nn.tanh(cls_hidden_state)


# 定义一个 FlaxViTPreTrainedModel 类,继承自 FlaxPreTrainedModel
class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
    """
    一个处理权重初始化和简单接口以下载和加载预训练模型的抽象类。
    """

    # 类变量,指定配置类为 ViTConfig
    config_class = ViTConfig
    # 模型的基础名称前缀为 "vit"
    base_model_prefix = "vit"
    # 主输入名称为 "pixel_values"
    main_input_name = "pixel_values"
    # 模型类变量,默认为 None
    module_class: nn.Module = None

    # 初始化函数,接受配置对象、输入形状、种子值、数据类型和其他关键字参数
    def __init__(
        self,
        config: ViTConfig,
        input_shape=None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # 使用给定的配置和数据类型实例化模块对象
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        
        # 如果未提供输入形状,则设置为默认的图像大小和通道数
        if input_shape is None:
            input_shape = (1, config.image_size, config.image_size, config.num_channels)
        
        # 调用父类的初始化方法,传递配置、模块、输入形状、种子、数据类型和是否初始化参数
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    # 初始化模型权重的函数,接受随机数生成器、输入形状和参数字典
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 初始化输入张量,全零张量
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        # 切分随机数生成器为参数随机数和 dropout 随机数
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # 使用模块对象的初始化方法初始化随机参数
        random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]

        # 如果传入了已有参数,则将缺失的参数从随机参数中添加进去
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    # 调用函数,接受像素值、参数字典、dropout 随机数生成器、训练标志、是否输出注意力、隐藏状态和是否返回字典
    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        pixel_values,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 如果未指定输出注意力,默认使用配置中的设置
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 如果未指定输出隐藏状态,默认使用配置中的设置
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果未指定是否返回字典,默认使用配置中的设置
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # 转置像素值的维度,将通道维度移到最后一个位置
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
        
        # 如果需要处理 dropout 的随机数生成器,则添加到随机数字典中
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # 应用模块对象的前向传播函数,传入参数字典或者自身的参数、像素值、训练标志、输出注意力、隐藏状态、返回字典和随机数字典
        return self.module.apply(
            {"params": params or self.params},
            jnp.array(pixel_values, dtype=jnp.float32),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )
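
# 简化示意(非库中实际代码):Flax 版模型在 __call__ 中接收与 PyTorch 相同的
# NCHW 布局输入,并在内部转置为 NHWC 之后再交给 Flax 模块;这里仅演示该转置。
import jax.numpy as jnp

_pixel_values = jnp.zeros((1, 3, 224, 224))                      # (batch, channels, height, width)
_pixel_values_nhwc = jnp.transpose(_pixel_values, (0, 2, 3, 1))  # (batch, height, width, channels)
assert _pixel_values_nhwc.shape == (1, 224, 224, 3)
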
class FlaxViTModule(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        # 初始化 ViT 模型的嵌入层
        self.embeddings = FlaxViTEmbeddings(self.config, dtype=self.dtype)
        # 初始化 ViT 模型的编码器层
        self.encoder = FlaxViTEncoder(self.config, dtype=self.dtype)
        # 初始化 LayerNorm 层,用于规范化隐藏状态
        self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # 如果设置了添加池化层选项,初始化 ViT 模型的池化层
        self.pooler = FlaxViTPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None

    def __call__(
        self,
        pixel_values,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 将像素值转换为隐藏状态向量
        hidden_states = self.embeddings(pixel_values, deterministic=deterministic)

        # 使用编码器处理隐藏状态向量,获取输出
        outputs = self.encoder(
            hidden_states,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 获取编码器的输出的隐藏状态向量
        hidden_states = outputs[0]
        # 对隐藏状态向量进行 LayerNorm 规范化
        hidden_states = self.layernorm(hidden_states)
        # 如果设置了池化层且存在,对隐藏状态向量进行池化操作
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        # 如果不返回字典,根据池化层的存在性选择性返回结果
        if not return_dict:
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        # 返回包含池化输出的 FlaxBaseModelOutputWithPooling 对象
        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
    VIT_START_DOCSTRING,
)
class FlaxViTModel(FlaxViTPreTrainedModel):
    module_class = FlaxViTModule


FLAX_VISION_MODEL_DOCSTRING = """
    Returns:

    Examples:

    ```
    >>> from transformers import AutoImageProcessor, FlaxViTModel
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    >>> model = FlaxViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
"""

overwrite_call_docstring(FlaxViTModel, FLAX_VISION_MODEL_DOCSTRING)
append_replace_return_docstrings(FlaxViTModel, output_type=FlaxBaseModelOutputWithPooling, config_class=ViTConfig)


class FlaxViTForImageClassificationModule(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32
    # 在对象初始化时设置模型结构,使用指定的配置和数据类型,不添加池化层
    def setup(self):
        self.vit = FlaxViTModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
        
        # 初始化分类器,设置输出类别数和数据类型,并使用截断正态分布初始化权重
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
        )

    # 实现对象的调用功能,接受像素值、确定性标志、是否输出注意力和隐藏状态、返回字典等参数
    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # 如果没有显式指定返回字典的用法,则使用配置中的默认设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用预训练的 ViT 模型进行前向传播
        outputs = self.vit(
            pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 提取模型输出的隐藏状态,并通过分类器获取对应的 logits
        hidden_states = outputs[0]
        logits = self.classifier(hidden_states[:, 0, :])

        # 如果不需要返回字典形式的输出,则组装为元组
        if not return_dict:
            output = (logits,) + outputs[2:]  # 包括 logits 和额外的隐藏状态
            return output

        # 返回格式化后的分类器输出对象,包括 logits、隐藏状态和注意力权重
        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# 在FlaxViTForImageClassification类之前添加文档字符串,描述其作为ViT模型转换器的用途,顶部有一个基于图像分类的线性层的简要说明,例如用于ImageNet。
@add_start_docstrings(
    """
    ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    """,
    VIT_START_DOCSTRING,  # 添加ViT模型的起始文档字符串
)

# 指定FlaxViTForImageClassification的模块类为FlaxViTForImageClassificationModule
class FlaxViTForImageClassification(FlaxViTPreTrainedModel):
    module_class = FlaxViTForImageClassificationModule


# FLAX_VISION_CLASSIF_DOCSTRING 是一个多行字符串,用于描述FlaxViTForImageClassification类的返回值和示例
FLAX_VISION_CLASSIF_DOCSTRING = """
    Returns:

    Example:

    ```
    >>> from transformers import AutoImageProcessor, FlaxViTForImageClassification
    >>> from PIL import Image
    >>> import jax
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
    >>> model = FlaxViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits

    >>> # model predicts one of the 1000 ImageNet classes
    >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
    >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
    ```
"""

# 使用overwrite_call_docstring函数,将FLAX_VISION_CLASSIF_DOCSTRING的内容覆盖到FlaxViTForImageClassification类的文档字符串中
overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)

# 使用append_replace_return_docstrings函数,将输出类型设置为FlaxSequenceClassifierOutput,并指定配置类为ViTConfig
append_replace_return_docstrings(
    FlaxViTForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=ViTConfig
)

.\models\vit\modeling_tf_vit.py

# coding=utf-8
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 ViT model."""

from __future__ import annotations

import collections.abc
import math
from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
from ...modeling_tf_utils import (
    TFModelInputType,
    TFPreTrainedModel,
    TFSequenceClassificationLoss,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_vit import ViTConfig

# Logger setup for this module
logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "ViTConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"

class TFViTEmbeddings(keras.layers.Layer):
    """
    Construct the CLS token, position and patch embeddings.
    """

    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)
        
        # Initialize patch embeddings using TFViTPatchEmbeddings layer
        self.patch_embeddings = TFViTPatchEmbeddings(config, name="patch_embeddings")
        # Dropout layer with rate from configuration
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def build(self, input_shape=None):
        # Get number of patches from patch embeddings layer
        num_patches = self.patch_embeddings.num_patches
        
        # Initialize CLS token embedding
        self.cls_token = self.add_weight(
            shape=(1, 1, self.config.hidden_size),
            initializer=get_initializer(self.config.initializer_range),
            trainable=True,
            name="cls_token",
        )
        
        # Initialize position embeddings based on number of patches
        self.position_embeddings = self.add_weight(
            shape=(1, num_patches + 1, self.config.hidden_size),
            initializer=get_initializer(self.config.initializer_range),
            trainable=True,
            name="position_embeddings",
        )

        # Check if layer is already built to avoid re-building
        if self.built:
            return
        
        self.built = True
        
        # Build patch embeddings layer if it exists
        if getattr(self, "patch_embeddings", None) is not None:
            with tf.name_scope(self.patch_embeddings.name):
                self.patch_embeddings.build(None)
    def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        # 获取嵌入张量的形状信息:batch_size为批大小,seq_len为序列长度,dim为嵌入维度
        batch_size, seq_len, dim = shape_list(embeddings)
        # 计算图像分块的数量(即序列长度减去1)
        num_patches = seq_len - 1

        # 获取预训练位置编码张量的形状信息:num_positions为位置编码的数量
        _, num_positions, _ = shape_list(self.position_embeddings)
        num_positions -= 1

        # 如果图像分块数量等于位置编码数量且图像高度等于宽度,则直接返回位置编码张量
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        
        # 从位置编码张量中分离出类别位置编码
        class_pos_embed = self.position_embeddings[:, :1]
        # 从位置编码张量中分离出图像分块位置编码
        patch_pos_embed = self.position_embeddings[:, 1:]
        # 计算新的图像分块数量
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # 使用双三次插值法对图像分块位置编码进行调整
        patch_pos_embed = tf.image.resize(
            images=tf.reshape(
                patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
            ),
            size=(h0, w0),
            method="bicubic",
        )

        # 检查调整后的图像分块位置编码张量形状是否与预期一致
        shape = shape_list(patch_pos_embed)
        assert h0 == shape[-3] and w0 == shape[-2]
        # 重新整形调整后的图像分块位置编码张量
        patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
        # 将类别位置编码和调整后的图像分块位置编码拼接在一起作为最终的位置编码张量
        return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)

    def call(
        self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
    ) -> tf.Tensor:
        # 获取输入像素张量的形状信息:batch_size为批大小,num_channels为通道数,height为高度,width为宽度
        batch_size, num_channels, height, width = shape_list(pixel_values)
        # 将像素值转换为嵌入向量,并根据需要进行位置编码的插值
        embeddings = self.patch_embeddings(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, training=training
        )

        # 将[CLS]令牌添加到嵌入的补丁令牌中
        cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
        embeddings = tf.concat((cls_tokens, embeddings), axis=1)

        # 如果需要插值位置编码,则将插值后的位置编码添加到每个令牌中
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            # 否则,将原始位置编码添加到每个令牌中
            embeddings = embeddings + self.position_embeddings

        # 在训练时对嵌入向量应用丢弃操作
        embeddings = self.dropout(embeddings, training=training)

        # 返回最终的嵌入向量
        return embeddings
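
# 简化示意(非库中实际代码):interpolate_pos_encoding 的核心是把 (1, N, D) 的 patch
# 位置编码重排成 (1, √N, √N, D) 的网格,用双三次插值缩放到新的网格大小,再展平回序列。
# 下面用随机张量演示形状变化(假设原网格为 14x14,目标网格为 24x24)。
import math
import tensorflow as tf

_dim, _num_positions = 768, 196                               # 14 x 14 个 patch
_grid = int(math.sqrt(_num_positions))
_patch_pos_embed = tf.random.normal((1, _num_positions, _dim))

_resized = tf.image.resize(
    tf.reshape(_patch_pos_embed, (1, _grid, _grid, _dim)),
    size=(24, 24),
    method="bicubic",
)
_patch_pos_embed_new = tf.reshape(_resized, (1, -1, _dim))    # (1, 576, dim)
assert _patch_pos_embed_new.shape == (1, 24 * 24, _dim)
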
# 基于 timm 实现,可以在此处找到:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class TFViTPatchEmbeddings(keras.layers.Layer):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)
        
        # 从配置中获取图像大小和补丁大小
        image_size, patch_size = config.image_size, config.patch_size
        # 从配置中获取通道数和隐藏大小
        num_channels, hidden_size = config.num_channels, config.hidden_size

        # 如果图像大小和补丁大小不是可迭代对象,则转换为元组
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        
        # 计算图像中的补丁数量
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        
        # 设置实例变量
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.num_channels = num_channels
        self.config = config

        # 创建卷积层,用于将图像补丁投影到隐藏大小的向量空间
        self.projection = keras.layers.Conv2D(
            filters=hidden_size,                              # 输出通道数为隐藏大小
            kernel_size=patch_size,                           # 卷积核大小设为补丁大小
            strides=patch_size,                               # 步幅设为补丁大小
            padding="valid",                                  # 使用有效填充
            data_format="channels_last",                       # 输入格式为通道在后
            use_bias=True,                                    # 使用偏置项
            kernel_initializer=get_initializer(self.config.initializer_range),  # 卷积核初始化器
            bias_initializer="zeros",                         # 偏置项初始化器设为零
            name="projection",                                # 层的名称为“projection”
        )

    # 调用函数,将像素值转换为图像补丁嵌入向量
    # pixel_values: 输入的像素值张量
    # interpolate_pos_encoding: 是否插值位置编码
    # training: 是否在训练模式下
    def call(
        self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
    ) -> tf.Tensor:
        # 获取输入张量的形状信息:batch_size, num_channels, height, width
        batch_size, num_channels, height, width = shape_list(pixel_values)
        
        # 如果在即时执行模式下,并且像素值的通道数不等于配置中设置的通道数,则引发值错误
        if tf.executing_eagerly() and num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        
        # 如果不需要插值位置编码
        if not interpolate_pos_encoding:
            if tf.executing_eagerly():
                # 如果高度或宽度与模型期望的图像尺寸不匹配,则引发值错误
                if height != self.image_size[0] or width != self.image_size[1]:
                    raise ValueError(
                        f"Input image size ({height}*{width}) doesn't match model"
                        f" ({self.image_size[0]}*{self.image_size[1]})."
                    )

        # 当在 CPU 上运行时,`keras.layers.Conv2D` 不支持 `NCHW` 格式,因此将输入格式从 `NCHW` 转换为 `NHWC`
        # 形状变为:(batch_size, in_height, in_width, in_channels=num_channels)
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

        # 对输入进行投影操作
        projection = self.projection(pixel_values)

        # 将二维空间维度转换为单一的时间维度
        # 形状变为:(batch_size, num_patches, out_channels=embed_dim)
        num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
        embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))

        return embeddings

    def build(self, input_shape=None):
        # 如果已经构建过网络,则直接返回
        if self.built:
            return
        
        # 标记网络已构建
        self.built = True
        
        # 如果已存在投影层,则构建投影层
        if getattr(self, "projection", None) is not None:
            with tf.name_scope(self.projection.name):
                # 构建投影层,输入形状为 [None, None, None, self.num_channels]
                self.projection.build([None, None, None, self.num_channels])
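
# 简化示意(非库中实际代码):kernel_size 与 strides 都等于 patch_size 的 Conv2D
# 等价于“切块 + 线性投影”。224x224 的图像、16x16 的 patch 会得到 14*14=196 个 patch。
import tensorflow as tf

_hidden_size, _patch_size = 768, 16
_projection = tf.keras.layers.Conv2D(filters=_hidden_size, kernel_size=_patch_size, strides=_patch_size)

_images_nhwc = tf.random.normal((1, 224, 224, 3))             # NHWC 布局(转置之后)
_patches = _projection(_images_nhwc)                          # (1, 14, 14, hidden_size)
_embeddings = tf.reshape(_patches, (1, -1, _hidden_size))     # (1, 196, hidden_size)
assert _embeddings.shape == (1, 196, _hidden_size)
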
class TFViTSelfAttention(keras.layers.Layer):
    # 定义一个名为TFViTSelfAttention的自定义Layer类
    def __init__(self, config: ViTConfig, **kwargs):
        # 初始化函数,接受一个ViTConfig类型的config对象和其他可选参数
        super().__init__(**kwargs)
        # 调用父类的初始化函数

        if config.hidden_size % config.num_attention_heads != 0:
            # 如果隐藏层大小不能被注意力头的数量整除
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        # 设置注意力头的数量
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        # 计算每个注意力头的大小
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # 计算所有注意力头的总大小
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
        # 计算注意力头大小的平方根

        self.query = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        # 创建查询矩阵
        self.key = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        # 创建键矩阵
        self.value = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        # 创建值矩阵
        self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
        # 创建dropout层
        self.config = config
        # 保存config对象

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # 定义一个函数,将输入的tensor进行维度变换,返回变换后的tensor
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
        # 将tensor从[batch_size, seq_length, all_head_size]形状变换为[batch_size, seq_length, num_attention_heads, attention_head_size]

        return tf.transpose(tensor, perm=[0, 2, 1, 3])
        # 将tensor从[batch_size, seq_length, num_attention_heads, attention_head_size]形状变换为[batch_size, num_attention_heads, seq_length, attention_head_size]

    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # 获取隐藏状态张量的批量大小
        batch_size = shape_list(hidden_states)[0]
        # 通过 self.query 对象计算混合的查询层
        mixed_query_layer = self.query(inputs=hidden_states)
        # 通过 self.key 对象计算混合的键层
        mixed_key_layer = self.key(inputs=hidden_states)
        # 通过 self.value 对象计算混合的值层
        mixed_value_layer = self.value(inputs=hidden_states)
        # 将混合的查询层转置以便进行注意力分数计算
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        # 将混合的键层转置以便进行注意力分数计算
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        # 将混合的值层转置以便进行注意力分数计算
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # 计算查询与键的点积,得到原始注意力分数
        # 形状为 (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        # 计算缩放系数 dk,并将注意力分数进行缩放
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        # 将注意力分数归一化为概率
        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # 使用 dropout 随机屏蔽注意力概率中的部分内容,用于模型训练中的稳定性
        attention_probs = self.dropout(inputs=attention_probs, training=training)

        # 如果有头部掩码 head_mask,则将其应用到注意力概率中
        if head_mask is not None:
            attention_probs = tf.multiply(attention_probs, head_mask)

        # 计算注意力输出值
        attention_output = tf.matmul(attention_probs, value_layer)
        # 调整输出张量的维度顺序
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])

        # 重新整形注意力输出张量的形状为 (batch_size, seq_len_q, all_head_size)
        attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
        # 输出包含注意力输出张量和注意力概率的元组,如果输出注意力分布则包含注意力概率
        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

        return outputs

    def build(self, input_shape=None):
        # 如果已经构建过网络层,则直接返回
        if self.built:
            return
        # 设置网络层为已构建状态
        self.built = True
        # 如果 self.query 存在,则构建查询层
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        # 如果 self.key 存在,则构建键层
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        # 如果 self.value 存在,则构建值层
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])
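
# 简化示意(非库中实际代码):多头注意力先把 (batch, seq, all_head_size) 重排成
# (batch, num_heads, seq, head_size),再做缩放点积注意力;这里用随机张量验证各步形状。
import tensorflow as tf

_batch, _seq_len, _num_heads, _head_size = 2, 197, 12, 64
_x = tf.random.normal((_batch, _seq_len, _num_heads * _head_size))


def _split_heads(tensor):
    tensor = tf.reshape(tensor, (_batch, -1, _num_heads, _head_size))
    return tf.transpose(tensor, perm=[0, 2, 1, 3])            # (batch, num_heads, seq, head_size)


_q, _k, _v = _split_heads(_x), _split_heads(_x), _split_heads(_x)
_scores = tf.matmul(_q, _k, transpose_b=True) / tf.sqrt(float(_head_size))  # (batch, heads, seq, seq)
_probs = tf.nn.softmax(_scores, axis=-1)
_context = tf.matmul(_probs, _v)                              # (batch, heads, seq, head_size)
_context = tf.reshape(tf.transpose(_context, perm=[0, 2, 1, 3]), (_batch, _seq_len, -1))
assert _context.shape == (_batch, _seq_len, _num_heads * _head_size)
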
class TFViTSelfOutput(keras.layers.Layer):
    """
    The residual connection is defined in TFViTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        # 定义一个全连接层,用于变换隐藏状态的维度
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # 定义一个dropout层,用于随机失活隐藏状态
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        # 保存配置信息
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        # 对隐藏状态进行全连接层变换
        hidden_states = self.dense(inputs=hidden_states)
        # 在训练时对全连接层输出进行dropout处理
        hidden_states = self.dropout(inputs=hidden_states, training=training)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # 构建全连接层,指定输入形状和隐藏单元数
                self.dense.build([None, None, self.config.hidden_size])


class TFViTAttention(keras.layers.Layer):
    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        # 定义自注意力层,用于计算注意力分数
        self.self_attention = TFViTSelfAttention(config, name="attention")
        # 定义输出层,处理自注意力层的输出
        self.dense_output = TFViTSelfOutput(config, name="output")

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(
        self,
        input_tensor: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # 进行自注意力计算,得到自注意力层的输出
        self_outputs = self.self_attention(
            hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
        )
        # 经过输出层处理自注意力层的输出
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
        )
        # 构建最终的输出元组,包括注意力输出和可能的其他返回值
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "self_attention", None) is not None:
            with tf.name_scope(self.self_attention.name):
                # 构建自注意力层
                self.self_attention.build(None)
        if getattr(self, "dense_output", None) is not None:
            with tf.name_scope(self.dense_output.name):
                # 构建输出层
                self.dense_output.build(None)


class TFViTIntermediate(keras.layers.Layer):
    # 对应 Transformer 前馈网络的第一个全连接层:hidden_size -> intermediate_size,后接激活函数
    # 初始化方法,用于创建一个新的 TFViTIntermediate 对象
    def __init__(self, config: ViTConfig, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)

        # 创建一个全连接层对象,设置单元数为config.intermediate_size,
        # 内核初始化器为config.initializer_range指定的初始化器,层名为"dense"
        self.dense = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        # 如果config.hidden_act是字符串类型,则通过get_tf_activation函数获取对应的激活函数
        # 否则直接使用config.hidden_act作为中间激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.intermediate_act_fn = config.hidden_act
        
        # 将配置信息保存在self.config中
        self.config = config

    # 调用方法,用于执行实际的前向传播操作
    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 将输入hidden_states传入全连接层self.dense,得到输出hidden_states
        hidden_states = self.dense(inputs=hidden_states)
        # 将全连接层的输出hidden_states通过中间激活函数self.intermediate_act_fn进行激活处理
        hidden_states = self.intermediate_act_fn(hidden_states)

        # 返回处理后的hidden_states作为本层的输出
        return hidden_states

    # 构建方法,用于构建层的参数和状态
    def build(self, input_shape=None):
        # 如果已经构建过,直接返回
        if self.built:
            return
        
        # 设置标志位built为True,表示已经构建过
        self.built = True
        
        # 如果self.dense层存在,则使用tf.name_scope设置作用域为self.dense.name,
        # 并构建全连接层self.dense,输入形状为[None, None, self.config.hidden_size]
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
class TFViTOutput(keras.layers.Layer):
    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        # 创建一个全连接层,用于将输入转换到隐藏大小
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # 创建一个 dropout 层,用于在训练时随机失活部分神经元
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        # 使用全连接层对隐藏状态进行线性变换
        hidden_states = self.dense(inputs=hidden_states)
        # 在训练时对转换后的隐藏状态应用 dropout
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        # 将 dropout 后的隐藏状态与输入张量相加,实现残差连接
        hidden_states = hidden_states + input_tensor

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            # 使用 dense 层的名称作为命名空间,构建其权重
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.intermediate_size])


class TFViTLayer(keras.layers.Layer):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        # 创建注意力机制层
        self.attention = TFViTAttention(config, name="attention")
        # 创建中间层
        self.intermediate = TFViTIntermediate(config, name="intermediate")
        # 创建 ViT 输出层
        self.vit_output = TFViTOutput(config, name="output")

        # 创建前层归一化层,用于 ViT 中的前置处理
        self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
        # 创建后层归一化层,用于 ViT 中的后置处理
        self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # 对隐藏状态应用前层归一化
        attention_outputs = self.attention(
            input_tensor=self.layernorm_before(inputs=hidden_states),
            head_mask=head_mask,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = attention_outputs[0]

        # 第一个残差连接
        hidden_states = attention_output + hidden_states

        # 对输出应用后层归一化
        layer_output = self.layernorm_after(inputs=hidden_states)

        intermediate_output = self.intermediate(hidden_states=layer_output)

        # 第二个残差连接
        layer_output = self.vit_output(
            hidden_states=intermediate_output, input_tensor=hidden_states, training=training
        )
        outputs = (layer_output,) + attention_outputs[1:]  # 如果输出注意力信息,则添加到输出中

        return outputs
    # 在构建模型时调用的方法,用于设置模型的各个组件
    def build(self, input_shape=None):
        # 如果模型已经构建过,则直接返回,避免重复构建
        if self.built:
            return
        # 设置标志位,表示模型已经构建
        self.built = True
        
        # 如果存在注意力组件,则构建注意力组件
        if getattr(self, "attention", None) is not None:
            # 在命名空间中构建注意力组件
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        
        # 如果存在中间层组件,则构建中间层组件
        if getattr(self, "intermediate", None) is not None:
            # 在命名空间中构建中间层组件
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)
        
        # 如果存在ViT输出组件,则构建ViT输出组件
        if getattr(self, "vit_output", None) is not None:
            # 在命名空间中构建ViT输出组件
            with tf.name_scope(self.vit_output.name):
                self.vit_output.build(None)
        
        # 如果存在层归一化前组件,则构建层归一化前组件
        if getattr(self, "layernorm_before", None) is not None:
            # 在命名空间中构建层归一化前组件,输入形状为 [None, None, self.config.hidden_size]
            with tf.name_scope(self.layernorm_before.name):
                self.layernorm_before.build([None, None, self.config.hidden_size])
        
        # 如果存在层归一化后组件,则构建层归一化后组件
        if getattr(self, "layernorm_after", None) is not None:
            # 在命名空间中构建层归一化后组件,输入形状为 [None, None, self.config.hidden_size]
            with tf.name_scope(self.layernorm_after.name):
                self.layernorm_after.build([None, None, self.config.hidden_size])
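
# 简化示意(非库中实际代码):上面的 TFViTLayer 采用 pre-LayerNorm 结构——先归一化
# 再进子层,残差连接加在子层外:x = x + Attn(LN(x)),随后 x = x + MLP(LN(x))。
# 下面用恒等函数代替注意力与 MLP,仅演示残差的组织方式。
def _vit_block(x, attention, mlp, ln_before, ln_after):
    x = x + attention(ln_before(x))   # 第一个残差:注意力子层
    x = x + mlp(ln_after(x))          # 第二个残差:MLP 子层
    return x


_identity = lambda v: v
assert _vit_block(1.0, _identity, _identity, _identity, _identity) == 4.0  # 1 -> 2 -> 4
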
class TFViTEncoder(keras.layers.Layer):
    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        # 初始化 ViT 编码器的各个层
        self.layer = [TFViTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        # 初始化输出变量
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # 遍历每个编码器层
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # 如果需要输出隐藏状态,则记录当前隐藏状态
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 调用当前编码器层的计算
            layer_outputs = layer_module(
                hidden_states=hidden_states,
                head_mask=head_mask[i],
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                # 如果需要输出注意力权重,则记录当前层的注意力权重
                all_attentions = all_attentions + (layer_outputs[1],)

        # 添加最后一个编码器层的隐藏状态输出
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 根据 return_dict 参数决定返回结果的形式
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        # 返回 TFBaseModelOutput 对象,包含最后隐藏状态、所有隐藏状态和所有注意力权重
        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

    def build(self, input_shape=None):
        # 如果已经构建过,直接返回
        if self.built:
            return
        self.built = True
        # 构建每个编码器层
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)


@keras_serializable
class TFViTMainLayer(keras.layers.Layer):
    config_class = ViTConfig

    def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, **kwargs):
        super().__init__(**kwargs)

        # 初始化 ViT 主层的配置
        self.config = config

        # 初始化 ViT 主层的各个子层
        self.embeddings = TFViTEmbeddings(config, name="embeddings")
        self.encoder = TFViTEncoder(config, name="encoder")
        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
        self.pooler = TFViTPooler(config, name="pooler") if add_pooling_layer else None

    def get_input_embeddings(self) -> keras.layers.Layer:
        # 返回输入嵌入层的 patch embeddings
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # 剪枝模型的注意力头,具体实现未给出
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        # 检查是否提供了 pixel_values 参数,如果没有则抛出数值错误
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # 通过嵌入层处理输入的像素值,包括插值位置编码和训练模式
        embedding_output = self.embeddings(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            training=training,
        )

        # 如果需要,准备头部遮罩
        # 在 head_mask 中为 1.0 表示保留对应的注意力头部
        # attention_probs 的形状为 bsz x n_heads x N x N
        # 输入的 head_mask 形状为 [num_heads] 或 [num_hidden_layers x num_heads]
        # 将 head_mask 转换为形状 [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            # 如果 head_mask 为 None,则创建一个与隐藏层数量相同的空列表
            head_mask = [None] * self.config.num_hidden_layers

        # 使用编码器处理嵌入的输出,传入头部遮罩和其他可选参数
        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # 获取编码器的序列输出
        sequence_output = encoder_outputs[0]
        # 对序列输出进行 LayerNormalization 处理
        sequence_output = self.layernorm(inputs=sequence_output)
        # 如果定义了池化器,则对序列输出进行池化操作
        pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None

        # 如果 return_dict 为 False,则返回一个包含序列输出和池化输出的元组,以及其他编码器输出
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        # 如果 return_dict 为 True,则构建 TFBaseModelOutputWithPooling 对象并返回
        return TFBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        # 如果模型已经构建,则直接返回
        if self.built:
            return
        self.built = True
        # 如果存在 embeddings 属性,则构建 embeddings
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        # 如果存在 encoder 属性,则构建 encoder
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        # 如果存在 layernorm 属性,则根据配置的隐藏大小构建 layernorm
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, None, self.config.hidden_size])
        # 如果存在 pooler 属性,则构建 pooler
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build(None)
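
# 简化示意(非库中实际代码):TFViTMainLayer 的整体数据流与形状变化
# (以默认配置 image_size=224、patch_size=16、hidden_size=768 为例):
#   pixel_values (1, 3, 224, 224)
#     -> patch embeddings (1, 196, 768)            # (224/16)^2 = 196 个 patch
#     -> 拼接 [CLS] 并加位置编码 (1, 197, 768)
#     -> encoder + layernorm: last_hidden_state (1, 197, 768),即 _EXPECTED_OUTPUT_SHAPE
#     -> pooler(取 [CLS]):pooler_output (1, 768)
_num_patches = (224 // 16) ** 2
assert [1, _num_patches + 1, 768] == [1, 197, 768]
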
    """
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
    """
    Args:
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.
        head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        interpolate_pos_encoding (`bool`, *optional*):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False``):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""

@add_start_docstrings(
    """
    ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    """,
    VIT_START_DOCSTRING,
)
class TFViTModel(TFViTPreTrainedModel):
    def __init__(self, config: ViTConfig, *inputs, add_pooling_layer=True, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # Initialize the main ViT layer using TFViTMainLayer with optional pooling
        self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        # Pass inputs to the ViT model and return outputs
        outputs = self.vit(
            pixel_values=pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "vit", None) is not None:
            with tf.name_scope(self.vit.name):
                self.vit.build(None)


class TFViTPooler(keras.layers.Layer):
    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        # Initialize a dense layer for pooling with specified units, activation, and initializer
        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # Pooling operation by taking the hidden state of the first token
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(inputs=first_token_tensor)

        return pooled_output

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
    """
    <Tip>

        Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
        setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
        position embeddings to the higher resolution.

    </Tip>
    """,
    VIT_START_DOCSTRING,
# 定义一个图像分类模型,继承自TFViTPreTrainedModel和TFSequenceClassificationLoss类
class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: ViTConfig, *inputs, **kwargs):
        # 调用父类的初始化方法
        super().__init__(config, *inputs, **kwargs)

        # 设置模型的标签数量
        self.num_labels = config.num_labels
        # 创建一个ViT主层对象,不包含池化层
        self.vit = TFViTMainLayer(config, add_pooling_layer=False, name="vit")

        # 分类器头部
        self.classifier = keras.layers.Dense(
            units=config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
        )
        # 保存配置对象
        self.config = config

    # 调用模型的前向传播方法,处理输入参数并返回输出结果
    @unpack_inputs
    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        # 调用ViT主层的前向传播,获取输出结果
        outputs = self.vit(
            pixel_values=pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            training=training,
        )
        # 从ViT输出中提取序列输出
        sequence_output = outputs[0]
        # 将序列输出传递给分类器获取logits
        logits = self.classifier(inputs=sequence_output[:, 0, :])
        # 如果存在标签,计算损失
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        # 如果不返回字典格式的输出,按照元组格式构建返回结果
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 返回TFSequenceClassifierOutput对象,包含损失、logits、隐藏状态和注意力权重
        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    # 定义神经网络层的构建方法,input_shape 参数表示输入形状,默认为 None
    def build(self, input_shape=None):
        # 如果已经构建过,则直接返回,避免重复构建
        if self.built:
            return
        # 标记该层已经构建
        self.built = True
        
        # 如果存在名为 "vit" 的属性,执行以下操作
        if getattr(self, "vit", None) is not None:
            # 在 TensorFlow 中创建名为 self.vit.name 的命名空间
            with tf.name_scope(self.vit.name):
                # 调用 self.vit 对象的 build 方法,参数为 None,即不指定输入形状
                self.vit.build(None)
        
        # 如果存在名为 "classifier" 的属性,执行以下操作
        if getattr(self, "classifier", None) is not None:
            # 在 TensorFlow 中创建名为 self.classifier.name 的命名空间
            with tf.name_scope(self.classifier.name):
                # 调用 self.classifier 对象的 build 方法,参数为 [None, None, self.config.hidden_size]
                # 表示指定输入形状为 [None, None, self.config.hidden_size]
                self.classifier.build([None, None, self.config.hidden_size])
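
# 简化的使用示例(非库中实际代码,运行时需联网下载预训练权重),与上文 Flax 版
# 文档字符串中的示例对应,只是把返回张量格式换成 TensorFlow:
if __name__ == "__main__":
    import requests
    import tensorflow as tf
    from PIL import Image

    from transformers import AutoImageProcessor, TFViTForImageClassification

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
    model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

    inputs = image_processor(images=image, return_tensors="tf")
    logits = model(**inputs).logits
    predicted_class_idx = int(tf.argmax(logits, axis=-1)[0])
    print("Predicted class:", model.config.id2label[predicted_class_idx])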

.\models\vit\modeling_vit.py

# coding=utf-8
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch ViT model."""

import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    ImageClassifierOutput,
    MaskedImageModelingOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_vit import ViTConfig

# Get logger for this module
logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "ViTConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"

# List of pretrained ViT model archives
VIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/vit-base-patch16-224",
    # See all ViT models at https://huggingface.co/models?filter=vit
]


class ViTEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    """

    def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        # Initialize CLS token as a learnable parameter
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))

        # Initialize mask token if `use_mask_token` is True
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None

        # Initialize patch embeddings
        self.patch_embeddings = ViTPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        
        # Initialize position embeddings for patches and CLS token
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
        
        # Dropout layer with dropout probability specified in config
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        
        # Store configuration
        self.config = config
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        # number of patch tokens in the current embeddings (excluding the CLS token)
        num_patches = embeddings.shape[1] - 1
        # number of positions in the pre-trained position embeddings (excluding the CLS position)
        num_positions = self.position_embeddings.shape[1] - 1
        # if the patch count and input resolution match the pre-training setup, no interpolation is needed
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        # position embedding of the CLS token
        class_pos_embed = self.position_embeddings[:, 0]
        # position embeddings of the patch tokens
        patch_pos_embed = self.position_embeddings[:, 1:]
        # feature dimension of the embeddings
        dim = embeddings.shape[-1]
        # patch grid size for the new input resolution
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # add a small offset to avoid floating point errors during interpolation
        h0, w0 = h0 + 0.1, w0 + 0.1
        # reshape the patch position embeddings into a 2D grid
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        # move the feature dimension in front of the spatial dimensions
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        # resize the grid of position embeddings with bicubic interpolation
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        # sanity-check that the interpolated grid has the expected spatial size
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        # flatten the grid back into a sequence of position embeddings
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        # prepend the CLS position embedding and return
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        # batch size, channel count and spatial size of the input
        batch_size, num_channels, height, width = pixel_values.shape
        # turn the pixel values into patch embeddings
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        # if a boolean mask is given, replace the masked patches
        if bool_masked_pos is not None:
            # sequence length of the patch embeddings
            seq_length = embeddings.shape[1]
            # expand the mask token to match the batch size and sequence length
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # blend: masked visual tokens are replaced by the mask token
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # prepend the [CLS] token to the patch embeddings
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add position embeddings to every token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        # apply dropout to the embeddings
        embeddings = self.dropout(embeddings)

        return embeddings
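
A quick standalone sketch of the resizing that interpolate_pos_encoding performs (the sizes below are illustrative ViT-Base numbers, not read from this file): a 224x224 image with 16x16 patches gives a 14x14 position grid, while a 384x384 input needs 24x24, so the stored grid is resized bicubically. For simplicity the sketch passes an explicit target size instead of the scale-factor trick used above.

import math
import torch
import torch.nn.functional as F

hidden_size, num_positions = 768, 14 * 14            # pre-trained at 224x224 with 16x16 patches
pos_embed = torch.randn(1, num_positions, hidden_size)

new_h = new_w = 384 // 16                            # 24x24 = 576 patches at 384x384
grid = int(math.sqrt(num_positions))                 # 14
pos_grid = pos_embed.reshape(1, grid, grid, hidden_size).permute(0, 3, 1, 2)
resized = F.interpolate(pos_grid, size=(new_h, new_w), mode="bicubic", align_corners=False)
resized = resized.permute(0, 2, 3, 1).reshape(1, new_h * new_w, hidden_size)
print(resized.shape)                                 # torch.Size([1, 576, 768])
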
class ViTPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        # image size and patch size from the configuration
        image_size, patch_size = config.image_size, config.patch_size
        # number of input channels and hidden size from the configuration
        num_channels, hidden_size = config.num_channels, config.hidden_size

        # normalize the image size to a (height, width) tuple
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        # normalize the patch size to a (height, width) tuple
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        # total number of patches in the image grid
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        # convolutional projection from the input channels to the hidden size
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        # shape of the input tensor
        batch_size, num_channels, height, width = pixel_values.shape
        # check that the channel dimension matches the configuration
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        # unless position encodings will be interpolated, the input size must match the configured image size
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )
        # project the pixels, flatten the spatial grid and transpose to (batch, num_patches, hidden_size)
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings
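
A minimal sketch of the projection above, assuming the default ViT-Base settings (3-channel 224x224 input, 16x16 patches, hidden size 768); the numbers are illustrative, not read from a config:

import torch
from torch import nn

# Same three steps as ViTPatchEmbeddings.forward: conv projection, flatten, transpose.
projection = nn.Conv2d(3, 768, kernel_size=16, stride=16)
pixel_values = torch.randn(1, 3, 224, 224)
patches = projection(pixel_values)               # (1, 768, 14, 14)
embeddings = patches.flatten(2).transpose(1, 2)  # (1, 196, 768)
print(embeddings.shape)                          # torch.Size([1, 196, 768])
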


class ViTSelfAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        # the hidden size must be divisible by the number of attention heads (unless an explicit embedding size is set)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        # number of attention heads and size of each head
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # query, key and value projections, with an optional bias
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        # dropout applied to the attention probabilities
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # Reshape the last dimension into (num_heads, head_size) and move the head
        # dimension forward so that attention can be computed per head.
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # project the hidden states to queries
        mixed_query_layer = self.query(hidden_states)

        # reshape keys, values and queries to (batch, num_heads, seq_len, head_size)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # raw attention scores: dot product between queries and keys
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # scale by sqrt(head size) to keep the softmax in a stable range
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # normalize the scores into attention probabilities
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # dropout on the attention probabilities
        attention_probs = self.dropout(attention_probs)

        # apply the head mask if one was given
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # weighted sum of the values
        context_layer = torch.matmul(attention_probs, value_layer)

        # merge the head dimension back into the hidden dimension
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        # optionally return the attention probabilities alongside the context
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
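
For the default ViT-Base configuration the shapes in this forward pass work out to 12 heads of size 64 over a sequence of 197 tokens (196 patches plus the CLS token). A self-contained shape sketch with those illustrative numbers:

import math
import torch

batch, seq_len, num_heads, head_size = 1, 197, 12, 64
q = k = v = torch.randn(batch, num_heads, seq_len, head_size)

scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_size)  # (1, 12, 197, 197)
probs = scores.softmax(dim=-1)
context = torch.matmul(probs, v)                                      # (1, 12, 197, 64)
context = context.permute(0, 2, 1, 3).reshape(batch, seq_len, num_heads * head_size)
print(context.shape)                                                  # torch.Size([1, 197, 768])
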
class ViTSelfOutput(nn.Module):
    """
    The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        # dense projection with the same input and output size (config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # dropout applied to the projected hidden states
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # linear projection
        hidden_states = self.dense(hidden_states)
        # dropout for regularization
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class ViTAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        # self-attention module that computes the attention distribution
        self.attention = ViTSelfAttention(config)
        # output module applied to the self-attention result
        self.output = ViTSelfOutput(config)
        # set of attention heads that have already been pruned
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        # find the heads to prune and the corresponding indices in the projection matrices
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # prune the query, key and value projections and the output dense layer
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # update the hyper-parameters and remember which heads were pruned
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # run the self-attention module
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        # feed the self-attention result through the output projection
        attention_output = self.output(self_outputs[0], hidden_states)

        # add the attention probabilities to the outputs if they were requested
        outputs = (attention_output,) + self_outputs[1:]
        return outputs
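
Head pruning is normally driven from a full model rather than by calling ViTAttention.prune_heads directly: PreTrainedModel.prune_heads dispatches to ViTModel._prune_heads (defined further below), which calls the method above for each layer. A short sketch, with arbitrary example indices and an assumed checkpoint name:

from transformers import ViTModel

model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
# Remove heads 0 and 2 in layer 0 and head 5 in layer 3; the query/key/value and
# output projections of those layers are shrunk accordingly.
model.prune_heads({0: [0, 2], 3: [5]})
print(model.encoder.layer[0].attention.attention.num_attention_heads)  # 10 for ViT-Base
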


class ViTIntermediate(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        # dense projection from the hidden size to config.intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # resolve the activation function from the configuration
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # linear projection
        hidden_states = self.dense(hidden_states)
        # apply the configured activation function
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class ViTOutput(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        # dense projection from config.intermediate_size back down to config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # dropout applied after the projection
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # linear projection
        hidden_states = self.dense(hidden_states)
        # dropout for regularization
        hidden_states = self.dropout(hidden_states)

        # residual connection: add the block input back to the projected hidden states
        hidden_states = hidden_states + input_tensor

        return hidden_states
class ViTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward  # chunk size for chunked feed-forward
        self.seq_len_dim = 1  # dimension that holds the sequence length
        self.attention = ViTAttention(config)  # self-attention block
        self.intermediate = ViTIntermediate(config)  # MLP up-projection
        self.output = ViTOutput(config)  # MLP down-projection plus residual
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # layernorm applied before attention
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # layernorm applied before the MLP

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViT, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]  # the attention context
        outputs = self_attention_outputs[1:]  # attention probabilities, if they were requested

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViT, layernorm is also applied before the MLP block
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)  # MLP up-projection and activation

        # second residual connection is done inside ViTOutput
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs  # prepend the block output

        return outputs
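
The ordering above is the "pre-LayerNorm" Transformer variant: normalization happens before each sub-block and the residual additions use the un-normalized inputs. Schematically (not the actual module code, just the data flow under those assumptions):

def vit_block(x, attn, mlp, ln_before, ln_after):
    x = x + attn(ln_before(x))  # first residual: attention over the normalized input
    x = x + mlp(ln_after(x))    # second residual: MLP over the normalized sum
    return x
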


class ViTEncoder(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])  # stack of ViTLayer blocks
        self.gradient_checkpointing = False  # activation checkpointing is off by default; enabling it trades compute for memory

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        ) -> Union[tuple, BaseModelOutput]:
        # collect hidden states across layers only if requested
        all_hidden_states = () if output_hidden_states else None
        # collect attention probabilities across layers only if requested
        all_self_attentions = () if output_attentions else None

        # iterate over the Transformer layers
        for i, layer_module in enumerate(self.layer):
            # record the hidden states entering this layer
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # head mask for this layer, if any
            layer_head_mask = head_mask[i] if head_mask is not None else None

            # use activation checkpointing during training when it is enabled
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                # otherwise call the layer directly
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            # the first element of the layer output is the new hidden states
            hidden_states = layer_outputs[0]

            # record the attention probabilities of this layer if requested
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # append the final hidden states
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # return a plain tuple of the non-None values when return_dict is False
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        # otherwise wrap everything in a BaseModelOutput
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
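
The gradient_checkpointing flag is not meant to be flipped by hand; on a full model it is switched through the PreTrainedModel API. A short sketch of typical training-time usage (checkpoint name assumed):

from transformers import ViTModel

# Recompute activations in the backward pass instead of storing them, trading compute for memory.
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
model.gradient_checkpointing_enable()
model.train()
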
    """
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """
    # 接受输入参数:
    # pixel_values: 表示像素值的张量,形状为 `(batch_size, num_channels, height, width)`
    #               可以使用 `AutoImageProcessor` 获取。详见 `ViTImageProcessor.__call__` 的说明。
    # head_mask: 可选参数,形状可以是 `(num_heads,)` 或者 `(num_layers, num_heads)` 的张量。
    #            用于掩盖自注意力模块中选定的头部。掩盖值在 `[0, 1]` 范围内:
    #            - 1 表示该头部**未被掩盖**,
    #            - 0 表示该头部**被掩盖**。
    # output_attentions: 可选参数,布尔值,是否返回所有注意力层的注意力张量。
    #                    返回的张量中的 `attentions` 字段包含更多细节。
    # output_hidden_states: 可选参数,布尔值,是否返回所有层的隐藏状态。
    #                       返回的张量中的 `hidden_states` 字段包含更多细节。
    # interpolate_pos_encoding: 可选参数,布尔值,是否插值预训练的位置编码。
    # return_dict: 可选参数,布尔值,是否返回一个 `~utils.ModelOutput` 对象而不是普通元组。
"""
@add_start_docstrings(
    "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
    VIT_START_DOCSTRING,
)
"""



class ViTModel(ViTPreTrainedModel):
    """
    ViT Model class inheriting from ViTPreTrainedModel.
    """

    def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
        """
        Initializes a ViTModel instance.

        Args:
            config (ViTConfig): Configuration class instance defining model architecture.
            add_pooling_layer (bool): Whether to add a pooling layer on top of the encoder.
            use_mask_token (bool): Whether to use a mask token for the model.

        """
        super().__init__(config)
        self.config = config

        # Initialize embeddings and encoder
        self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = ViTEncoder(config)

        # Layer normalization
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        
        # Optional pooling layer
        self.pooler = ViTPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> ViTPatchEmbeddings:
        """
        Returns the patch embeddings used as input to the model.
        """
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model.

        Args:
            heads_to_prune (Dict[int, List[int]]): Dictionary of layers and heads to prune.

        See base class PreTrainedModel for more details.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    """
    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    """



    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        # fall back to the configuration defaults when the flags are not given explicitly
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # pixel values are required
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
        # cast the pixel values to the dtype expected by the patch projection if necessary
        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
        if pixel_values.dtype != expected_dtype:
            pixel_values = pixel_values.to(expected_dtype)

        # compute the embeddings, applying masking and position-encoding interpolation if requested
        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        # run the encoder on the embedded patches
        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # final sequence of hidden states
        sequence_output = encoder_outputs[0]
        # apply the final layer normalization
        sequence_output = self.layernorm(sequence_output)
        # pool the sequence output if a pooler is present
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        # without return_dict, return a plain tuple of the head outputs and the remaining encoder outputs
        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        # otherwise wrap everything in a BaseModelOutputWithPooling
        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
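
End-to-end usage of the bare model, following the documented checkpoint above; the expected [1, 197, 768] output shape is 196 patch tokens plus the CLS token (the image URL is the standard COCO example used in the library docs):

import torch
import requests
from PIL import Image
from transformers import AutoImageProcessor, ViTModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 197, 768])
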
class ViTPooler(nn.Module):
    def __init__(self, config: ViTConfig):
        super().__init__()
        # dense projection with the same input and output size as the hidden states
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # tanh activation applied to the projected CLS state
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # "pool" the model by taking the hidden state of the first ([CLS]) token
        first_token_tensor = hidden_states[:, 0]
        # project the CLS hidden state
        pooled_output = self.dense(first_token_tensor)
        # apply the tanh activation
        pooled_output = self.activation(pooled_output)
        return pooled_output
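
The pooler therefore only ever looks at the CLS token. A quick shape check with illustrative ViT-Base sizes:

import torch
from torch import nn

hidden_states = torch.randn(1, 197, 768)          # (batch, seq_len, hidden_size)
dense, activation = nn.Linear(768, 768), nn.Tanh()
pooled = activation(dense(hidden_states[:, 0]))   # CLS token only
print(pooled.shape)                               # torch.Size([1, 768])
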


@add_start_docstrings(
    """ViT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    """,
    VIT_START_DOCSTRING,
)
class ViTForMaskedImageModeling(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__(config)

        # ViT backbone without a pooling layer, with a mask token for masked patches
        self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)

        # decoder that reconstructs pixels from the encoded patch tokens
        self.decoder = nn.Sequential(
            # 1x1 convolution that expands the hidden size to encoder_stride**2 * num_channels
            nn.Conv2d(
                in_channels=config.hidden_size,
                out_channels=config.encoder_stride**2 * config.num_channels,
                kernel_size=1,
            ),
            # pixel shuffle rearranges those channels into a full-resolution image
            nn.PixelShuffle(config.encoder_stride),
        )

        # Initialize weights and apply final processing
        self.post_init()
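
A sketch of what this decoder does to the encoder output, using illustrative ViT-Base numbers (hidden size 768, encoder_stride 16, 3 channels, 14x14 patch grid): the patch tokens are reshaped back into a feature map, the 1x1 convolution expands the channels to 16*16*3, and PixelShuffle rearranges them into a 224x224 reconstruction.

import torch
from torch import nn

hidden_size, encoder_stride, num_channels = 768, 16, 3
decoder = nn.Sequential(
    nn.Conv2d(hidden_size, encoder_stride**2 * num_channels, kernel_size=1),
    nn.PixelShuffle(encoder_stride),
)

# Patch tokens (without the CLS token) reshaped into a (batch, hidden, 14, 14) feature map.
sequence_output = torch.randn(1, 196, hidden_size)
feature_map = sequence_output.permute(0, 2, 1).reshape(1, hidden_size, 14, 14)
reconstruction = decoder(feature_map)
print(reconstruction.shape)  # torch.Size([1, 3, 224, 224])
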

    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,


class ViTForImageClassification(ViTPreTrainedModel):
    # Vision Transformer with an image classification head on top of the [CLS] token
    def __init__(self, config: ViTConfig) -> None:
        super().__init__(config)

        # number of classification labels
        self.num_labels = config.num_labels
        # ViT backbone without a pooling layer
        self.vit = ViTModel(config, add_pooling_layer=False)

        # Classifier head: a linear layer when there are labels, otherwise an identity mapping
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    # forward pass: takes pixel values, an optional head mask and labels, and returns the classification output
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # fall back to the configuration default when return_dict is not given
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # run the Vision Transformer backbone
        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # sequence of final hidden states
        sequence_output = outputs[0]

        # classify from the hidden state of the [CLS] token
        logits = self.classifier(sequence_output[:, 0, :])

        # compute the loss only when labels are provided
        loss = None
        if labels is not None:
            # move labels to the same device as the logits to support model parallelism
            labels = labels.to(logits.device)
            # infer the problem type if it was not set explicitly
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # pick the loss function that matches the problem type
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        # without return_dict, return a tuple of logits plus the remaining backbone outputs
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # otherwise wrap everything in an ImageClassifierOutput
        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
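
Typical inference with the classification head, using the documented checkpoint above; for the standard COCO example image the expected prediction is "Egyptian cat":

import torch
import requests
from PIL import Image
from transformers import AutoImageProcessor, ViTForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_class_idx = logits.argmax(-1).item()
print(model.config.id2label[predicted_class_idx])  # Egyptian cat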