Transformers Source Code Walkthrough (Part 119)

.\models\vit_hybrid\convert_vit_hybrid_timm_to_pytorch.py

# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ViT hybrid checkpoints from the timm library."""

import argparse  # command-line argument parsing
import json  # JSON handling
from pathlib import Path  # filesystem path handling

import requests  # HTTP requests
import timm  # the timm library, which hosts the original hybrid ViT checkpoints
import torch  # PyTorch
from huggingface_hub import hf_hub_download  # download files from the Hugging Face Hub
from PIL import Image  # image loading and manipulation
from timm.data import resolve_data_config  # resolve the preprocessing config of a timm model
from timm.data.transforms_factory import create_transform  # build the matching preprocessing transform

from transformers import (
    BitConfig,  # configuration class for the BiT backbone
    ViTHybridConfig,  # configuration class for ViT Hybrid
    ViTHybridForImageClassification,  # ViT Hybrid model with an image-classification head
    ViTHybridImageProcessor,  # image processor for ViT Hybrid
    ViTHybridModel,  # bare ViT Hybrid model
)
from transformers.image_utils import PILImageResampling  # PIL resampling enum
from transformers.utils import logging  # logging utilities

logging.set_verbosity_info()  # set logging verbosity to INFO
logger = logging.get_logger(__name__)  # logger for this module

# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, base_model=False):
    rename_keys = []  # list of (original_name, our_name) pairs

    # fmt: off
    # stem:
    rename_keys.append(("cls_token", "vit.embeddings.cls_token"))  # 添加 cls_token 的重命名映射
    rename_keys.append(("pos_embed", "vit.embeddings.position_embeddings"))  # 添加 pos_embed 的重命名映射

    rename_keys.append(("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"))  # 添加 patch_embed.proj.weight 的重命名映射
    rename_keys.append(("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"))  # 添加 patch_embed.proj.bias 的重命名映射

    # backbone
    rename_keys.append(("patch_embed.backbone.stem.conv.weight", "vit.embeddings.patch_embeddings.backbone.bit.embedder.convolution.weight"))  # 添加 patch_embed.backbone.stem.conv.weight 的重命名映射
    rename_keys.append(("patch_embed.backbone.stem.norm.weight", "vit.embeddings.patch_embeddings.backbone.bit.embedder.norm.weight"))  # 添加 patch_embed.backbone.stem.norm.weight 的重命名映射
    rename_keys.append(("patch_embed.backbone.stem.norm.bias", "vit.embeddings.patch_embeddings.backbone.bit.embedder.norm.bias"))  # 添加 patch_embed.backbone.stem.norm.bias 的重命名映射
    # fmt: on
    # iterate over the depth of every stage in the backbone config
    for stage_idx in range(len(config.backbone_config.depths)):
        # iterate over every layer of the current stage
        for layer_idx in range(config.backbone_config.depths[stage_idx]):
            # map the convolution and norm weights of each block to the corresponding location in the Transformers model
            rename_keys.append((
                f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.conv1.weight",
                f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.conv1.weight"))
            rename_keys.append((
                f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm1.weight",
                f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm1.weight"))
            rename_keys.append((
                f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm1.bias",
                f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm1.bias"))
            rename_keys.append((
                f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.conv2.weight",
                f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.conv2.weight"))
            rename_keys.append((
                f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm2.weight",
                f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm2.weight"))
            rename_keys.append((
                f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm2.bias",
                f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm2.bias"))
            rename_keys.append((
                f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.conv3.weight",
                f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.conv3.weight"))
            rename_keys.append((
                f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm3.weight",
                f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm3.weight"))
            rename_keys.append((
                f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm3.bias",
                f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm3.bias"))

        # map the downsample convolution and norm of the first block of each stage to the corresponding Transformers location
        rename_keys.append((
            f"patch_embed.backbone.stages.{stage_idx}.blocks.0.downsample.conv.weight",
            f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.0.downsample.conv.weight"))
        rename_keys.append((
            f"patch_embed.backbone.stages.{stage_idx}.blocks.0.downsample.norm.weight",
            f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.0.downsample.norm.weight"))
        rename_keys.append((
            f"patch_embed.backbone.stages.{stage_idx}.blocks.0.downsample.norm.bias",
            f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.0.downsample.norm.bias"))

    # transformer encoder
    for i in range(config.num_hidden_layers):
        # encoder layers: output projection, 2 feedforward layers and 2 layernorms
        rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
        rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
        rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
        rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
        rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))

    if base_model:
        # base model: layernorm + pooler
        rename_keys.extend(
            [
                ("norm.weight", "layernorm.weight"),
                ("norm.bias", "layernorm.bias"),
                ("pre_logits.fc.weight", "pooler.dense.weight"),
                ("pre_logits.fc.bias", "pooler.dense.bias"),
            ]
        )

        # if just the base model, we should remove "vit" from all keys that start with "vit"
        rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
    else:
        # layernorm + classification head
        rename_keys.extend(
            [
                ("norm.weight", "vit.layernorm.weight"),
                ("norm.bias", "vit.layernorm.bias"),
                ("head.weight", "classifier.weight"),
                ("head.bias", "classifier.bias"),
            ]
        )
    # fmt: on
    return rename_keys
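
As an illustrative aside (not part of the original script), the rename table can be inspected by building the same default configuration that `convert_vit_checkpoint` uses further below and printing a few of the resulting pairs:

from transformers import BitConfig, ViTHybridConfig

backbone_config = BitConfig(
    global_padding="same",
    layer_type="bottleneck",
    depths=(3, 4, 9),
    out_features=["stage3"],
    embedding_dynamic_padding=True,
)
config = ViTHybridConfig(backbone_config=backbone_config, image_size=384, num_labels=1000)

for src, dest in create_rename_keys(config)[:3]:
    print(f"{src} -> {dest}")  # e.g. cls_token -> vit.embeddings.cls_token
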
# split up each encoder layer's fused weight matrix into query, key and value matrices
def read_in_q_k_v(state_dict, config, base_model=False):
    # iterate over the encoder layers
    for i in range(config.num_hidden_layers):
        if base_model:
            prefix = ""
        else:
            prefix = "vit."
        
        # read in the weights and bias of the input projection layer (in timm, this is a single matrix plus bias)
        in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
        in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
        
        # add the query, key and value weights (in that order) to the state dict
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
            : config.hidden_size, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
            config.hidden_size : config.hidden_size * 2, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
            config.hidden_size : config.hidden_size * 2
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
            -config.hidden_size :, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]


# remove the classification head weights and bias from the state dict
def remove_classification_head_(state_dict):
    ignore_keys = ["head.weight", "head.bias"]
    for k in ignore_keys:
        state_dict.pop(k, None)


# rename a key in the dict from old to new
def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


# we verify our conversion on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@torch.no_grad()
def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path, push_to_hub=False):
    """
    将模型权重复制/粘贴/调整为我们的ViT结构。
    """

    # define default ViT Hybrid configuration
    backbone_config = BitConfig(
        global_padding="same",
        layer_type="bottleneck",
        depths=(3, 4, 9),
        out_features=["stage3"],
        embedding_dynamic_padding=True,
    )
    config = ViTHybridConfig(backbone_config=backbone_config, image_size=384, num_labels=1000)
    base_model = False

    # load the original model from timm
    timm_model = timm.create_model(vit_name, pretrained=True)
    timm_model.eval()

    # load the state_dict of the original model and remove/rename some keys
    state_dict = timm_model.state_dict()
    if base_model:
        remove_classification_head_(state_dict)
    rename_keys = create_rename_keys(config, base_model)  # build the (old, new) rename pairs defined above
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, base_model)

    repo_id = "huggingface/label-files"
    filename = "imagenet-1k-id2label.json"
    # download the ImageNet-1k label file from the Hugging Face Hub and load it as JSON
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    # convert the keys of id2label to integers, keeping the label strings as values
    id2label = {int(k): v for k, v in id2label.items()}
    # set the id-to-label mapping on the config
    config.id2label = id2label
    # invert id2label to build the label-to-id mapping and set it on the config
    config.label2id = {v: k for k, v in id2label.items()}

    # load the HuggingFace model
    if vit_name[-5:] == "in21k":
        # if vit_name ends with "in21k", instantiate a ViTHybridModel in eval mode
        model = ViTHybridModel(config).eval()
    else:
        # otherwise instantiate a ViTHybridForImageClassification in eval mode
        model = ViTHybridForImageClassification(config).eval()
    # load the converted state dict into the model
    model.load_state_dict(state_dict)

    # create the image processor
    transform = create_transform(**resolve_data_config({}, model=timm_model))
    timm_transforms = transform.transforms

    # map timm interpolation strings to PIL resampling enums
    pillow_resamplings = {
        "bilinear": PILImageResampling.BILINEAR,
        "bicubic": PILImageResampling.BICUBIC,
        "nearest": PILImageResampling.NEAREST,
    }

    # instantiate a ViTHybridImageProcessor with the same preprocessing as the timm transform
    processor = ViTHybridImageProcessor(
        do_resize=True,
        size={"shortest_edge": timm_transforms[0].size},
        resample=pillow_resamplings[timm_transforms[0].interpolation.value],
        do_center_crop=True,
        crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]},
        do_normalize=True,
        image_mean=timm_transforms[-1].mean.tolist(),
        image_std=timm_transforms[-1].std.tolist(),
    )

    # prepare the test image
    image = prepare_img()
    # apply the timm transform and add a batch dimension
    timm_pixel_values = transform(image).unsqueeze(0)
    # process the same image with our image processor
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # verify that both pipelines produce the same pixel values
    assert torch.allclose(timm_pixel_values, pixel_values)

    # run the converted model on the processed pixel values (no gradients needed)
    with torch.no_grad():
        outputs = model(pixel_values)
        logits = outputs.logits

    # print the predicted class (argmax over the logits)
    print("Predicted class:", logits.argmax(-1).item())

    # for a base model, forward the features through the timm model and check that shapes and values match
    if base_model:
        timm_pooled_output = timm_model.forward_features(pixel_values)
        assert timm_pooled_output.shape == outputs.pooler_output.shape
        assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3)
    else:
        # otherwise run the timm model directly and check that the logits match in shape and value
        timm_logits = timm_model(pixel_values)
        assert timm_logits.shape == outputs.logits.shape
        assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
    print("Looks ok!")

    # if a dump folder was given, save the model and processor there
    if pytorch_dump_folder_path is not None:
        # make sure the output directory exists
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        print(f"Saving model {vit_name} to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        print(f"Saving processor to {pytorch_dump_folder_path}")
        processor.save_pretrained(pytorch_dump_folder_path)

    # optionally push the model and processor to the Hugging Face Hub
    if push_to_hub:
        print(f"Pushing model and processor to the hub {vit_name}")
        model.push_to_hub(f"ybelkada/{vit_name}")
        processor.push_to_hub(f"ybelkada/{vit_name}")


if __name__ == "__main__":
    # entry point when the script is run directly

    parser = argparse.ArgumentParser()
    # create the argument parser

    # Required parameters
    parser.add_argument(
        "--vit_name",
        default="vit_base_r50_s16_384",
        type=str,
        help="Name of the hybrid ViT timm model you'd like to convert.",
    )
    # `--vit_name`: name of the hybrid ViT timm model to convert (defaults to "vit_base_r50_s16_384")

    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    # `--pytorch_dump_folder_path`: directory where the converted PyTorch model will be saved

    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether to upload the model to the HuggingFace hub."
    )
    # `--push_to_hub`: when set, upload the converted model to the Hugging Face Hub

    args = parser.parse_args()
    # parse the command-line arguments

    convert_vit_checkpoint(args.vit_name, args.pytorch_dump_folder_path, args.push_to_hub)
    # run the conversion with the parsed arguments
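
For reference, the conversion can also be driven programmatically instead of via the command line; a minimal sketch (the output directory below is just an example path):

convert_vit_checkpoint(
    vit_name="vit_base_r50_s16_384",
    pytorch_dump_folder_path="./vit-hybrid-base-bit-384",  # example output path
    push_to_hub=False,
)
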

.\models\vit_hybrid\image_processing_vit_hybrid.py

    r"""
    Constructs a ViT Hybrid image processor.
    """

    def __init__(self, image_size: Union[int, List[int]], mean: Optional[List[float]] = OPENAI_CLIP_MEAN,
                 std: Optional[List[float]] = OPENAI_CLIP_STD, resampling: PILImageResampling = PILImageResampling.BILINEAR,
                 channel_dim: ChannelDimension = ChannelDimension.LAST, dtype: TensorType = np.float32):
        """
        Initialize the ViT Hybrid image processor.

        Parameters:
        - image_size (Union[int, List[int]]): Desired size of the output image.
        - mean (Optional[List[float]]): Mean values for normalization, defaults to OPENAI_CLIP_MEAN.
        - std (Optional[List[float]]): Standard deviation values for normalization, defaults to OPENAI_CLIP_STD.
        - resampling (PILImageResampling): Resampling method for image resizing, defaults to PILImageResampling.BILINEAR.
        - channel_dim (ChannelDimension): Channel dimension format, defaults to ChannelDimension.LAST.
        - dtype (TensorType): Data type of the processed images, defaults to np.float32.
        """
        super().__init__()
        self.image_size = image_size
        self.mean = mean
        self.std = std
        self.resampling = resampling
        self.channel_dim = channel_dim
        self.dtype = dtype

    def __call__(self, images: Union[ImageInput, List[ImageInput]], return_tensors: bool = True,
                 **kwargs) -> Union[BatchFeature, List[BatchFeature]]:
        """
        Process a single image or a batch of images.

        Parameters:
        - images (Union[ImageInput, List[ImageInput]]): Input image(s) to be processed.
        - return_tensors (bool): Whether to return tensors (True) or numpy arrays (False), defaults to True.
        - **kwargs: Additional keyword arguments for preprocessing.

        Returns:
        - Union[BatchFeature, List[BatchFeature]]: Processed image(s) as tensors or numpy arrays.
        """
        # Validate and preprocess input arguments
        images = make_list_of_images(images)
        validate_kwargs(kwargs)
        validate_preprocess_arguments(self.mean, self.std, self.image_size, self.channel_dim)

        # Resize images to the desired size
        resized_images = [resize(image, self.image_size, self.resampling) for image in images]

        # Convert images to RGB format if needed
        rgb_images = [convert_to_rgb(image) for image in resized_images]

        # Ensure images have the correct channel dimension format
        formatted_images = [to_channel_dimension_format(image, self.channel_dim) for image in rgb_images]

        # Normalize images
        normalized_images = [self._normalize_image(image) for image in formatted_images]

        # Convert images to numpy arrays or tensors based on return_tensors flag
        if return_tensors:
            return np.stack(normalized_images).astype(self.dtype)
        else:
            return normalized_images

    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
        """
        Normalize the image data using mean and standard deviation.

        Parameters:
        - image (np.ndarray): Input image data.

        Returns:
        - np.ndarray: Normalized image data.
        """
        mean = np.array(self.mean)
        std = np.array(self.std)
        return (image.astype(np.float32) - mean) / std
    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in the `preprocess` method.
        size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
            Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
            the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
            `preprocess` method.
        crop_size (`Dict[str, int]` *optional*, defaults to 224):
            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
            method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
    """

    # names of the model inputs expected by this processor
    model_input_names = ["pixel_values"]
    # constructor: sets up the default preprocessing options
    def __init__(
        self,
        do_resize: bool = True,  # whether to resize the input images
        size: Dict[str, int] = None,  # target size after resizing; defaults to a shortest edge of 224
        resample: PILImageResampling = PILImageResampling.BICUBIC,  # resampling filter used when resizing
        do_center_crop: bool = True,  # whether to center crop the images
        crop_size: Dict[str, int] = None,  # size of the center crop; defaults to 224x224
        do_rescale: bool = True,  # whether to rescale pixel values
        rescale_factor: Union[int, float] = 1 / 255,  # rescaling factor, defaults to 1/255
        do_normalize: bool = True,  # whether to normalize the images
        image_mean: Optional[Union[float, List[float]]] = None,  # normalization mean, defaults to the OpenAI CLIP mean
        image_std: Optional[Union[float, List[float]]] = None,  # normalization std, defaults to the OpenAI CLIP std
        do_convert_rgb: bool = True,  # whether to convert images to RGB
        **kwargs,  # additional keyword arguments forwarded to the base class
    ) -> None:
        # initialize the base class with any extra kwargs
        super().__init__(**kwargs)
        # default size: shortest edge of 224
        size = size if size is not None else {"shortest_edge": 224}
        # normalize the size dict (not defaulting to a square)
        size = get_size_dict(size, default_to_square=False)
        # default crop size: 224x224
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        # normalize the crop size dict (defaulting to a square)
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")

        # store the preprocessing options on the instance
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.do_convert_rgb = do_convert_rgb
        # list of keyword arguments accepted by the preprocess method
        self._valid_processor_keys = [
            "images",
            "do_resize",
            "size",
            "resample",
            "do_center_crop",
            "crop_size",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "image_mean",
            "image_std",
            "do_convert_rgb",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]

    # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
    def resize(
        self,
        image: np.ndarray,  # image to resize, as a NumPy array
        size: Dict[str, int],  # target size
        resample: PILImageResampling = PILImageResampling.BICUBIC,  # resampling filter, bicubic by default
        data_format: Optional[Union[str, ChannelDimension]] = None,  # channel dimension format of the output
        input_data_format: Optional[Union[str, ChannelDimension]] = None,  # channel dimension format of the input
        **kwargs,  # extra arguments forwarded to the resize function
    ) -> np.ndarray:
        """
        Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
        resized to keep the input aspect ratio.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        # by default, resize to a square
        default_to_square = True
        if "shortest_edge" in size:
            # if a shortest edge is given, resize so that the shortest edge matches it
            size = size["shortest_edge"]
            default_to_square = False
        elif "height" in size and "width" in size:
            # if height and width are given, resize to exactly that size
            size = (size["height"], size["width"])
        else:
            # the size dict must contain either "shortest_edge" or both "height" and "width"
            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")

        # compute the output size of the resized image
        output_size = get_resize_output_image_size(
            image,
            size=size,
            default_to_square=default_to_square,
            input_data_format=input_data_format,
        )
        # resize the image and return it
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
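
A short usage sketch of the processor defined above, using its defaults (shortest edge 224, center crop 224x224, CLIP mean/std); the dummy image is only there to make the snippet self-contained:

from PIL import Image
from transformers import ViTHybridImageProcessor

processor = ViTHybridImageProcessor()
dummy_image = Image.new("RGB", (640, 480))
inputs = processor(images=dummy_image, return_tensors="pt")
print(inputs.pixel_values.shape)  # torch.Size([1, 3, 224, 224])
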

.\models\vit_hybrid\modeling_vit_hybrid.py

# coding=utf-8
# Copyright 2022 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" PyTorch ViT Hybrid model. """

import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# model outputs and utility imports
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging

# backbone loading utilities
from ...utils.backbone_utils import load_backbone
from .configuration_vit_hybrid import ViTHybridConfig

# module-level logger
logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "ViTHybridConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "google/vit-hybrid-base-bit-384"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/vit-hybrid-base-bit-384"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"

# archive list of pretrained ViT Hybrid models
VIT_HYBRID_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/vit-hybrid-base-bit-384",
    # See all ViT Hybrid models at https://huggingface.co/models?filter=vit-hybrid
]

class ViTHybridEmbeddings(nn.Module):
    """
    构建CLS标记、位置和补丁嵌入。可选择添加掩码标记。
    """

    # 从 transformers.models.vit.modeling_vit.ViTEmbeddings.__init__ 复制而来,将 ViT 改为 ViTHybrid
    def __init__(self, config: ViTHybridConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        # learnable CLS token
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))

        # learnable mask token, only created when use_mask_token is True
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None

        # patch embedding layer (backbone + projection)
        self.patch_embeddings = ViTHybridPatchEmbeddings(config)

        # number of patches, needed for the position embeddings
        num_patches = self.patch_embeddings.num_patches

        # learnable position embeddings (+1 for the CLS token)
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))

        # dropout applied to the embeddings
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # keep a reference to the config
        self.config = config
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        # number of patches in the current embeddings and number of positions in the pre-trained position encodings
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        
        # if they match and the image is square, the position embeddings can be used as-is
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        
        # split the position embeddings into the class token part and the patch part
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        
        # embedding dimension
        dim = embeddings.shape[-1]
        
        # number of patches along each spatial dimension, given the patch size
        height = height // self.config.patch_size
        width = width // self.config.patch_size
        
        # add a small value to height and width to avoid floating point errors during interpolation
        height, width = height + 0.1, width + 0.1
        
        # reshape the patch position embeddings to a 2D grid for interpolation
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        
        # bicubic interpolation of the patch position embeddings to the new grid size
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        
        # sanity check: the interpolated grid must match the requested height and width
        if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
            raise ValueError(f"Invalid height or width: {height}, {width}")
        
        # flatten the interpolated grid back and prepend the class position embedding
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        # unpack the input shape: batch size, channels, height, width
        batch_size, num_channels, height, width = pixel_values.shape
        # compute the patch embeddings, optionally interpolating the position encodings
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        # if masked positions are provided, replace the masked patch embeddings with the mask token
        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            # expand the mask token over the batch and sequence dimensions
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # build the mask tensor and cast it to the dtype of the mask tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            # keep unmasked embeddings, replace masked ones with the mask token
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # prepend the [CLS] token to the patch embeddings
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add the position embeddings to every token
        if interpolate_pos_encoding:
            # interpolate the position encodings when the input resolution differs from pre-training
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        # apply dropout to the embeddings
        embeddings = self.dropout(embeddings)

        return embeddings
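
A toy illustration (with made-up sizes, not the real checkpoint geometry) of the concatenation and broadcasting performed in the forward pass above:

import torch

batch_size, num_patches, hidden_size = 2, 9, 8  # toy sizes
patch_embeds = torch.randn(batch_size, num_patches, hidden_size)
cls_token = torch.zeros(1, 1, hidden_size)
position_embeddings = torch.randn(1, num_patches + 1, hidden_size)

cls_tokens = cls_token.expand(batch_size, -1, -1)          # (2, 1, 8)
embeddings = torch.cat((cls_tokens, patch_embeds), dim=1)  # (2, 10, 8)
embeddings = embeddings + position_embeddings              # broadcast over the batch dimension
print(embeddings.shape)  # torch.Size([2, 10, 8])
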
# ViTHybridPatchEmbeddings turns pixel values into patch embeddings for the Transformer, using a convolutional backbone (BiT) followed by a projection.
class ViTHybridPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config, feature_size=None):
        super().__init__()
        
        # read image size and patch size from the config
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        # make sure image size and patch size are tuples
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)

        # load the backbone specified in the config
        self.backbone = load_backbone(config)
        
        # only a "bit" backbone is supported; otherwise raise a ValueError
        if self.backbone.config.model_type != "bit":
            raise ValueError(f"Backbone model type {self.backbone.config.model_type} is not supported.")

        # channel dimension of the backbone's final feature map
        feature_dim = self.backbone.channels[-1]
        
        # if no feature size was given, derive it from the backbone feature map shape in the config
        if feature_size is None:
            feature_map = config.backbone_featmap_shape

            # spatial size is the last two dims, channel dim is the second dim of the feature map shape
            feature_size = feature_map[-2:]
            feature_dim = feature_map[1]
        else:
            # make sure the provided feature size is a tuple
            feature_size = (
                feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size)
            )
            # channel dimension of the backbone's final feature map
            feature_dim = self.backbone.channels[-1]
        
        # grid size: feature map size divided by the patch size
        self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
        # total number of patches
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels

        # project the backbone features to the hidden size with a convolution whose kernel and stride equal the patch size
        self.projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        # unpack the input shape
        _, num_channels, height, width = pixel_values.shape
        
        # the channel dimension must match the configuration
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        
        # when not interpolating position encodings, check the input resolution
        if not interpolate_pos_encoding:
            # the input image size must match the model's configured image size
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )
        
        # run the backbone and take its last feature map
        features = self.backbone(pixel_values).feature_maps[-1]
        
        # project the feature map, flatten the spatial dims and transpose to (batch, seq_len, hidden_size)
        embeddings = self.projection(features).flatten(2).transpose(1, 2)
        
        # return the patch embeddings
        return embeddings


# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->ViTHybrid
class ViTHybridSelfAttention(nn.Module):
    def __init__(self, config: ViTHybridConfig) -> None:
        super().__init__()
        # the hidden size must be a multiple of the number of attention heads (unless an embedding_size is defined)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # linear projections for queries, keys and values, each of output size all_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        # dropout applied to the attention probabilities
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # reshape x to (batch, num_heads, seq_len, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # project the hidden states to queries
        mixed_query_layer = self.query(hidden_states)

        # project to keys and values and reshape everything to multi-head format
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # raw attention scores: dot product between queries and keys
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # scale the scores by the square root of the head size
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # softmax over the last dimension to obtain attention probabilities
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # dropout on the attention probabilities (this effectively drops entire tokens to attend to)
        attention_probs = self.dropout(attention_probs)

        # apply the head mask if one is provided
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # context layer: weighted sum of the values
        context_layer = torch.matmul(attention_probs, value_layer)

        # reshape the context back to (batch, seq_len, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        # include the attention probabilities in the output if requested
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
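
The forward pass above is plain scaled dot-product attention. A self-contained sketch with toy shapes (no dropout, no head mask) reproduces the same math and checks it against PyTorch's fused implementation (available in torch >= 2.0):

import math

import torch

batch, heads, seq_len, head_dim = 1, 2, 5, 4  # toy sizes
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(head_dim)
probs = scores.softmax(dim=-1)        # attention probabilities
context = torch.matmul(probs, value)  # (1, 2, 5, 4)

expected = torch.nn.functional.scaled_dot_product_attention(query, key, value)
assert torch.allclose(context, expected, atol=1e-6)
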
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTHybrid
class ViTHybridSelfOutput(nn.Module):
    """
    在 ViTHybridLayer 中定义残差连接,而不是像其他模型一样在此处定义,这是因为每个块前都应用了 layernorm。
    """

    def __init__(self, config: ViTHybridConfig) -> None:
        super().__init__()
        # dense projection back to the hidden size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # dropout on the projected output
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # project the attention output
        hidden_states = self.dense(hidden_states)
        # apply dropout
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTHybrid
class ViTHybridAttention(nn.Module):
    def __init__(self, config: ViTHybridConfig) -> None:
        super().__init__()
        # self-attention module
        self.attention = ViTHybridSelfAttention(config)
        # output projection module
        self.output = ViTHybridSelfOutput(config)
        # keep track of pruned attention heads
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        # find the prunable heads and their indices
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # prune the query, key, value and output projections
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # update the hyperparameters and remember the pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # run self-attention
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        # project the attention output
        attention_output = self.output(self_outputs[0], hidden_states)

        # add the attention probabilities to the outputs if they were requested
        outputs = (attention_output,) + self_outputs[1:]
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTHybrid
class ViTHybridIntermediate(nn.Module):
    def __init__(self, config: ViTHybridConfig) -> None:
        super().__init__()
        # dense projection from hidden_size to intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

        # resolve the activation function: look it up in ACT2FN when given as a string
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # linear projection followed by the non-linear activation
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->ViTHybrid
# ViTHybridOutput projects the intermediate representation back to hidden_size and adds the residual
class ViTHybridOutput(nn.Module):
    def __init__(self, config: ViTHybridConfig) -> None:
        super().__init__()
        # dense projection from intermediate_size back to hidden_size
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # dropout with probability hidden_dropout_prob
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # project and apply dropout
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # residual connection: add the block input back
        hidden_states = hidden_states + input_tensor

        return hidden_states


# ViTHybridLayer: a single Transformer block (attention + MLP, with pre-layernorm)
class ViTHybridLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTHybridConfig) -> None:
        super().__init__()
        # chunk size and sequence dimension used for chunked feed-forward
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        # self-attention, intermediate (MLP up-projection) and output (MLP down-projection) sub-modules
        self.attention = ViTHybridAttention(config)
        self.intermediate = ViTHybridIntermediate(config)
        self.output = ViTHybridOutput(config)
        # two layernorms: one before self-attention, one after
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    # forward pass of one Transformer block
    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # apply layernorm, then self-attention
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViTHybrid, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection: add the attention output to the original hidden states
        hidden_states = attention_output + hidden_states.to(attention_output.device)

        # in ViTHybrid, layernorm is also applied after self-attention, before the MLP
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection happens inside self.output
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs
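
Schematically, the block above follows the pre-LayerNorm Transformer pattern: normalize, transform, then add the residual. A compact sketch of the same data flow with stand-in modules (toy sizes, `nn.Identity` in place of the attention module):

import torch
from torch import nn

hidden_size, intermediate_size = 8, 16  # toy sizes
layernorm_before, layernorm_after = nn.LayerNorm(hidden_size), nn.LayerNorm(hidden_size)
attention = nn.Identity()  # stand-in for ViTHybridAttention
mlp = nn.Sequential(
    nn.Linear(hidden_size, intermediate_size), nn.GELU(), nn.Linear(intermediate_size, hidden_size)
)

hidden_states = torch.randn(2, 5, hidden_size)
hidden_states = attention(layernorm_before(hidden_states)) + hidden_states  # residual 1
hidden_states = mlp(layernorm_after(hidden_states)) + hidden_states         # residual 2
print(hidden_states.shape)  # torch.Size([2, 5, 8])
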


# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTHybrid
# ViTHybridEncoder: a stack of ViTHybridLayer blocks
class ViTHybridEncoder(nn.Module):
    def __init__(self, config: ViTHybridConfig) -> None:
        super().__init__()
        # keep the config around
        self.config = config
        # stack of num_hidden_layers Transformer blocks
        self.layer = nn.ModuleList([ViTHybridLayer(config) for _ in range(config.num_hidden_layers)])
        # gradient checkpointing is disabled by default
        self.gradient_checkpointing = False

    # forward pass: takes hidden states, an optional head mask, and flags controlling attentions/hidden states/return type
    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        # containers for all hidden states / attention maps, if requested
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # iterate over the Transformer blocks
        for i, layer_module in enumerate(self.layer):
            # store the hidden states before each block, if requested
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # head mask for the current layer (or None)
            layer_head_mask = head_mask[i] if head_mask is not None else None

            # with gradient checkpointing enabled during training, wrap the layer call in the checkpointing function
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                # otherwise call the layer directly
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            # the new hidden states are the first element of the layer outputs
            hidden_states = layer_outputs[0]

            # collect the attention weights of this layer, if requested
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # add the final hidden states, if requested
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # when not returning a dict, return a tuple of the non-None values
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        # otherwise wrap everything in a BaseModelOutput
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values of the input image. This tensor represents the image in the form of batches,
            channels, height, and width.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask indicating which heads of the self-attention mechanism to mask out. It can be provided as
            a 1D tensor for a single layer model or a 2D tensor for multi-layer models. Values are in the
            range [0, 1]:

            - 1 indicates that the head is **not masked**,
            - 0 indicates that the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to include the attention tensors from all attention layers in the output. Refer to
            the returned tensors for details on the `attentions` field.

        output_hidden_states (`bool`, *optional*):
            Whether or not to include the hidden states from all layers in the output. Refer to the returned
            tensors for details on the `hidden_states` field.

        return_dict (`bool`, *optional*):
            Whether to return a [`~utils.ModelOutput`] instead of a tuple. If True, the output will be wrapped
            in a standardized model output format for ease of use and consistency.
"""
@add_start_docstrings(
    "The bare ViT Hybrid Model transformer outputting raw hidden-states without any specific head on top.",
    VIT_START_DOCSTRING,
)
# Copied from transformers.models.vit.modeling_vit.ViTModel with ViT->ViTHybrid
class ViTHybridModel(ViTHybridPreTrainedModel):
    def __init__(self, config: ViTHybridConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
        super().__init__(config)
        self.config = config

        # embeddings (with optional mask token) and encoder
        self.embeddings = ViTHybridEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = ViTHybridEncoder(config)

        # final layernorm over the encoder output
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # optional pooler on top of the [CLS] token
        self.pooler = ViTHybridPooler(config) if add_pooling_layer else None

        # initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> ViTHybridPatchEmbeddings:
        # the patch embedding module acts as the input embeddings
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # for every (layer, heads) pair, prune the requested heads in that layer's attention module
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        # fall back to the config defaults when the flags are not provided
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # pixel_values is required
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)

        # cast pixel_values to the dtype expected by the patch embedding projection, if necessary
        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
        if pixel_values.dtype != expected_dtype:
            pixel_values = pixel_values.to(expected_dtype)

        # compute the embeddings from the pixel values
        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        # run the encoder on the embeddings
        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # sequence output: apply the final layernorm
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        # pool the sequence output if a pooler is present
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        # without return_dict, return a plain tuple of the head outputs plus the remaining encoder outputs
        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        # otherwise return a BaseModelOutputWithPooling
        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->ViTHybrid
class ViTHybridPooler(nn.Module):
    def __init__(self, config: ViTHybridConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # linear layer with equal input and output size
        self.activation = nn.Tanh()  # tanh activation

    def forward(self, hidden_states):
        # we "pool" the model by simply taking the hidden state corresponding to the first token
        first_token_tensor = hidden_states[:, 0]  # hidden state of the [CLS] token
        pooled_output = self.dense(first_token_tensor)  # project it
        pooled_output = self.activation(pooled_output)  # apply tanh
        return pooled_output


@add_start_docstrings(
    """
    ViT Hybrid Model transformer with an image classification head on top (a linear layer on top of the final hidden
    state of the [CLS] token) e.g. for ImageNet.
    """,
    VIT_START_DOCSTRING,
)
# 从transformers.models.vit.modeling_vit.ViTForImageClassification复制而来,将ViT替换为ViTHybrid
class ViTHybridForImageClassification(ViTHybridPreTrainedModel):
    def __init__(self, config: ViTHybridConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels  # 设置分类标签数目
        self.vit = ViTHybridModel(config, add_pooling_layer=False)  # 创建一个ViTHybridModel模型实例,不添加汇集层

        # 分类器头部
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        # 如果有分类标签,则创建一个线性层作为分类器头部;否则使用恒等映射函数Identity()

        # 初始化权重并应用最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # 确保 return_dict 变量有值,如果没有提供则使用 self.config.use_return_dict 的值
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 使用 ViT 模型进行前向传播
        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # 获取模型输出中的序列输出
        sequence_output = outputs[0]

        # 对序列输出的第一个位置进行分类预测
        logits = self.classifier(sequence_output[:, 0, :])

        # 初始化损失为 None
        loss = None
        if labels is not None:
            # 将 labels 移动到与 logits 相同的设备上,以支持模型并行计算
            labels = labels.to(logits.device)
            # 根据配置确定问题类型,如果未指定则根据 num_labels 来判断
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # 根据问题类型计算损失函数
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        # 如果不需要返回字典形式的结果,则按元组形式返回输出
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # 返回 ImageClassifierOutput 对象,包含损失、logits、隐藏状态和注意力权重
        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
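
上面的 `ViTHybridForImageClassification` 在推理时的典型用法可以用下面的示意片段说明(检查点名称与图片 URL 仅作示例,实际使用时请换成可用的权重和图片):

```python
import requests
import torch
from PIL import Image

from transformers import ViTHybridForImageClassification, ViTHybridImageProcessor

# 检查点名称与图片 URL 仅作示例
checkpoint = "google/vit-hybrid-base-bit-384"
url = "http://images.cocodataset.org/val2017/000000039769.jpg"

image = Image.open(requests.get(url, stream=True).raw)

processor = ViTHybridImageProcessor.from_pretrained(checkpoint)
model = ViTHybridForImageClassification.from_pretrained(checkpoint)
model.eval()

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # 形状为 (batch_size, num_labels)

predicted_class = logits.argmax(-1).item()
print(model.config.id2label[predicted_class])
```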

.\models\vit_hybrid\__init__.py

# 版权声明和许可证信息,指明代码版权归HuggingFace团队所有,使用Apache License, Version 2.0许可证
#
# 导入必要的类型检查工具
from typing import TYPE_CHECKING

# 导入自定义的异常和模块延迟加载工具,用于处理可能缺失的依赖
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available

# 定义模块的导入结构字典,包含配置和模型信息
_import_structure = {"configuration_vit_hybrid": ["VIT_HYBRID_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTHybridConfig"]}

# 尝试导入torch,若不可用则抛出OptionalDependencyNotAvailable异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 若torch可用,则添加模型相关的导入信息到_import_structure字典
    _import_structure["modeling_vit_hybrid"] = [
        "VIT_HYBRID_PRETRAINED_MODEL_ARCHIVE_LIST",
        "ViTHybridForImageClassification",
        "ViTHybridModel",
        "ViTHybridPreTrainedModel",
    ]

# 尝试导入vision,若不可用则抛出OptionalDependencyNotAvailable异常
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 若vision可用,则添加图像处理相关的导入信息到_import_structure字典
    _import_structure["image_processing_vit_hybrid"] = ["ViTHybridImageProcessor"]

# 如果正在进行类型检查,导入具体的配置和模型类
if TYPE_CHECKING:
    from .configuration_vit_hybrid import VIT_HYBRID_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTHybridConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_vit_hybrid import (
            VIT_HYBRID_PRETRAINED_MODEL_ARCHIVE_LIST,
            ViTHybridForImageClassification,
            ViTHybridModel,
            ViTHybridPreTrainedModel,
        )

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_vit_hybrid import ViTHybridImageProcessor

# 如果不是类型检查环境,则进行模块的延迟加载设置
else:
    import sys

    # 将当前模块替换为延迟加载模块
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
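
借助 `_LazyModule`,上面这些子模块只有在第一次被访问时才会真正导入。下面是一个使用层面的示意(假设本地已安装 torch 与视觉相关依赖):

```python
# 顶层导入只注册了 _import_structure,并不会立刻加载 modeling / image_processing 子模块
from transformers.models import vit_hybrid

print(type(vit_hybrid))  # 打印出的类型应为 _LazyModule

# 第一次访问属性时才触发对应子模块的真实导入
config_cls = vit_hybrid.ViTHybridConfig
processor_cls = vit_hybrid.ViTHybridImageProcessor
print(config_cls.__name__, processor_cls.__name__)
```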

.\models\vit_mae\configuration_vit_mae.py

# coding=utf-8
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ViT MAE model configuration"""

# 导入必要的模块和函数
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# 获取日志记录器
logger = logging.get_logger(__name__)

# 预训练配置文件的映射表,指定每个预训练模型对应的配置文件的 URL
VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/vit-mae-base": "https://huggingface.co/facebook/vit-mae-base/resolve/main/config.json",
    # 可以查看所有 ViT MAE 模型的列表:https://huggingface.co/models?filter=vit-mae
}

# ViTMAEConfig 类,继承自 PretrainedConfig 类
class ViTMAEConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ViTMAEModel`]. It is used to instantiate an ViT
    MAE model according to the specified arguments, defining the model architecture. Instantiating a configuration with
    the defaults will yield a similar configuration to that of the ViT
    [facebook/vit-mae-base](https://huggingface.co/facebook/vit-mae-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """
    # 隐藏层的维度,包括编码器层和池化层
    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        
        # Transformer 编码器中隐藏层的数量
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        
        # Transformer 编码器中每个注意力层的注意力头数
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        
        # Transformer 编码器中"中间"(即前馈)层的维度
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        
        # 编码器和池化器中的非线性激活函数
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        
        # 嵌入层、编码器和池化器中所有全连接层的 dropout 概率
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        
        # 注意力概率的 dropout 比率
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        
        # 初始化所有权重矩阵的截断正态分布的标准差
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        
        # 层归一化层使用的 epsilon 值
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        
        # 每个图像的大小(分辨率)
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        
        # 每个图像块(patch)的大小(分辨率)
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        
        # 输入通道的数量
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        
        # 是否为查询、键和值添加偏置
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        
        # 解码器中每个注意力层的注意力头数
        decoder_num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the decoder.
        
        # 解码器的维度
        decoder_hidden_size (`int`, *optional*, defaults to 512):
            Dimensionality of the decoder.
        
        # 解码器中隐藏层的数量
        decoder_num_hidden_layers (`int`, *optional*, defaults to 8):
            Number of hidden layers in the decoder.
        
        # 解码器中"中间"(即前馈)层的维度
        decoder_intermediate_size (`int`, *optional*, defaults to 2048):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the decoder.
        
        # 输入序列中掩码标记的比例
        mask_ratio (`float`, *optional*, defaults to 0.75):
            The ratio of the number of masked tokens in the input sequence.
        
        # 是否使用归一化像素进行训练
        norm_pix_loss (`bool`, *optional*, defaults to `False`):
            Whether or not to train with normalized pixels (see Table 3 in the paper). Using normalized pixels improved
            representation quality in the experiments of the authors.
    Example:

    >>> from transformers import ViTMAEConfig, ViTMAEModel

    >>> # 使用 vit-mae-base 风格的默认参数初始化一个配置
    >>> configuration = ViTMAEConfig()
    
    >>> # 初始化一个模型(带有随机权重),使用 vit-mae-base 风格的配置
    >>> model = ViTMAEModel(configuration)
    
    >>> # 访问模型的配置信息
    >>> configuration = model.config

.\models\vit_mae\convert_vit_mae_to_pytorch.py

# 导入必要的模块和库
import argparse  # 用于解析命令行参数
import requests  # 用于发送 HTTP 请求
import torch  # PyTorch 深度学习库
from PIL import Image  # Python Imaging Library,用于图像处理

# 从transformers模块导入相关的类和函数
from transformers import ViTMAEConfig, ViTMAEForPreTraining, ViTMAEImageProcessor


# 定义函数,根据特定规则重命名输入的名称
def rename_key(name):
    if "cls_token" in name:
        name = name.replace("cls_token", "vit.embeddings.cls_token")
    if "mask_token" in name:
        name = name.replace("mask_token", "decoder.mask_token")
    if "decoder_pos_embed" in name:
        name = name.replace("decoder_pos_embed", "decoder.decoder_pos_embed")
    if "pos_embed" in name and "decoder" not in name:
        name = name.replace("pos_embed", "vit.embeddings.position_embeddings")
    if "patch_embed.proj" in name:
        name = name.replace("patch_embed.proj", "vit.embeddings.patch_embeddings.projection")
    if "patch_embed.norm" in name:
        name = name.replace("patch_embed.norm", "vit.embeddings.norm")
    if "decoder_blocks" in name:
        name = name.replace("decoder_blocks", "decoder.decoder_layers")
    if "blocks" in name:
        name = name.replace("blocks", "vit.encoder.layer")
    if "attn.proj" in name:
        name = name.replace("attn.proj", "attention.output.dense")
    if "attn" in name:
        name = name.replace("attn", "attention.self")
    if "norm1" in name:
        name = name.replace("norm1", "layernorm_before")
    if "norm2" in name:
        name = name.replace("norm2", "layernorm_after")
    if "mlp.fc1" in name:
        name = name.replace("mlp.fc1", "intermediate.dense")
    if "mlp.fc2" in name:
        name = name.replace("mlp.fc2", "output.dense")
    if "decoder_embed" in name:
        name = name.replace("decoder_embed", "decoder.decoder_embed")
    if "decoder_norm" in name:
        name = name.replace("decoder_norm", "decoder.decoder_norm")
    if "decoder_pred" in name:
        name = name.replace("decoder_pred", "decoder.decoder_pred")
    if "norm.weight" in name and "decoder" not in name:
        name = name.replace("norm.weight", "vit.layernorm.weight")
    if "norm.bias" in name and "decoder" not in name:
        name = name.replace("norm.bias", "vit.layernorm.bias")

    return name
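
为了更直观地看清这些替换规则的先后作用,下面用几个假设的键名演示 `rename_key` 的输入输出(预期结果按上面函数的逻辑手工推演,仅作示意):

```python
# 仅作示意:用几个假设的原始键名验证重命名逻辑(rename_key 即上面定义的函数)
examples = [
    "cls_token",
    "blocks.0.attn.proj.weight",
    "blocks.3.mlp.fc1.bias",
    "decoder_blocks.1.norm1.weight",
]
for name in examples:
    print(name, "->", rename_key(name))

# 预期输出:
# cls_token -> vit.embeddings.cls_token
# blocks.0.attn.proj.weight -> vit.encoder.layer.0.attention.output.dense.weight
# blocks.3.mlp.fc1.bias -> vit.encoder.layer.3.intermediate.dense.bias
# decoder_blocks.1.norm1.weight -> decoder.decoder_layers.1.layernorm_before.weight
```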


# 定义函数,将给定的模型状态字典转换为新的配置
def convert_state_dict(orig_state_dict, config):
    # 遍历原始状态字典的复制版本的键
    for key in orig_state_dict.copy().keys():
        # 弹出当前键对应的值
        val = orig_state_dict.pop(key)

        # 检查键名中是否包含 "qkv"
        if "qkv" in key:
            # 根据 "." 分割键名
            key_split = key.split(".")
            # 获取层编号
            layer_num = int(key_split[1])

            # 检查键名是否包含 "decoder_blocks"
            if "decoder_blocks" in key:
                # 根据配置文件获取隐藏大小
                dim = config.decoder_hidden_size
                prefix = "decoder.decoder_layers."

                # 根据键名中是否包含 "weight" 还是 "bias",更新对应的查询、键、值的权重或偏置
                if "weight" in key:
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.weight"] = val[:dim, :]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.weight"] = val[dim : dim * 2, :]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
                elif "bias" in key:
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.bias"] = val[:dim]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.bias"] = val[-dim:]
            else:
                # 根据配置文件获取隐藏大小
                dim = config.hidden_size
                prefix = "vit.encoder.layer."

                # 根据键名中是否包含 "weight" 还是 "bias",更新对应的查询、键、值的权重或偏置
                if "weight" in key:
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.weight"] = val[:dim, :]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.weight"] = val[dim : dim * 2, :]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
                elif "bias" in key:
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.bias"] = val[:dim]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.bias"] = val[-dim:]

        else:
            # 对不包含 "qkv" 的键名进行重命名处理后更新到原始状态字典中
            orig_state_dict[rename_key(key)] = val

    # 返回更新后的原始状态字典
    return orig_state_dict
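
这里切分融合 qkv 参数的关键在于:timm 检查点中 qkv 的权重形状为 `(3*dim, dim)`、偏置形状为 `(3*dim,)`,前、中、后三段分别对应 query、key、value。下面用随机张量做一个形状层面的示意(dim 取值仅为假设):

```python
import torch

dim = 8  # 假设的隐藏维度,仅作示意
qkv_weight = torch.randn(3 * dim, dim)  # 融合的 qkv 权重
qkv_bias = torch.randn(3 * dim)

query_w, key_w, value_w = qkv_weight[:dim, :], qkv_weight[dim : dim * 2, :], qkv_weight[-dim:, :]
query_b, key_b, value_b = qkv_bias[:dim], qkv_bias[dim : dim * 2], qkv_bias[-dim:]

# 切分后的三份权重与单独的 query/key/value 线性层形状一致,拼回去应与原始权重完全相同
assert query_w.shape == key_w.shape == value_w.shape == (dim, dim)
assert torch.equal(torch.cat([query_w, key_w, value_w], dim=0), qkv_weight)
```
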
def convert_vit_mae_checkpoint(checkpoint_url, pytorch_dump_folder_path):
    # 创建一个 ViTMAEConfig 对象
    config = ViTMAEConfig()
    
    # 根据 checkpoint_url 的内容设置不同的配置项
    if "large" in checkpoint_url:
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16
    elif "huge" in checkpoint_url:
        config.patch_size = 14
        config.hidden_size = 1280
        config.intermediate_size = 5120
        config.num_hidden_layers = 32
        config.num_attention_heads = 16

    # 使用配置初始化 ViTMAEForPreTraining 模型
    model = ViTMAEForPreTraining(config)

    # 从指定 URL 加载 PyTorch state_dict,并获取其中的 "model" 部分
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["model"]

    # 创建 ViTMAEImageProcessor 对象,指定图片大小
    image_processor = ViTMAEImageProcessor(size=config.image_size)

    # 将加载的 state_dict 转换为新的 state_dict 格式
    new_state_dict = convert_state_dict(state_dict, config)

    # 加载模型的新 state_dict
    model.load_state_dict(new_state_dict)
    model.eval()

    # 指定要处理的图片 URL
    url = "https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg"

    # 使用 requests 获取图片并打开为 Image 对象,再次初始化 image_processor
    image = Image.open(requests.get(url, stream=True).raw)
    image_processor = ViTMAEImageProcessor(size=config.image_size)
    
    # 使用 image_processor 处理图片,返回 PyTorch 张量格式的输入
    inputs = image_processor(images=image, return_tensors="pt")

    # 执行模型的前向传播
    torch.manual_seed(2)
    outputs = model(**inputs)
    logits = outputs.logits

    # 根据 checkpoint_url 设置期望的输出结果的一部分(slice)
    if "large" in checkpoint_url:
        expected_slice = torch.tensor(
            [[-0.7309, -0.7128, -1.0169], [-1.0161, -0.9058, -1.1878], [-1.0478, -0.9411, -1.1911]]
        )
    elif "huge" in checkpoint_url:
        expected_slice = torch.tensor(
            [[-1.1599, -0.9199, -1.2221], [-1.1952, -0.9269, -1.2307], [-1.2143, -0.9337, -1.2262]]
        )
    else:
        expected_slice = torch.tensor(
            [[-0.9192, -0.8481, -1.1259], [-1.1349, -1.0034, -1.2599], [-1.1757, -1.0429, -1.2726]]
        )

    # 验证模型输出是否与期望的 slice 相似
    assert torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4)

    # 打印消息,保存模型到指定路径
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)

    # 打印消息,保存 image_processor 到指定路径
    print(f"Saving image processor to {pytorch_dump_folder_path}")


if __name__ == "__main__":
    # 解析命令行参数
    parser = argparse.ArgumentParser()
    # 必需参数
    parser.add_argument(
        "--checkpoint_url",
        default="https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth",
        type=str,
        help="URL of the checkpoint you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )

    args = parser.parse_args()
    # 调用函数进行模型转换
    convert_vit_mae_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
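
该脚本既可以按上面的 argparse 方式从命令行运行,也可以在 Python 中直接调用转换函数。下面是一个直接调用的示意(checkpoint URL 为脚本的默认值,输出目录为假设路径):

```python
# 直接调用转换函数的示意;输出目录为假设路径,运行时会下载并加载较大的检查点
convert_vit_mae_checkpoint(
    checkpoint_url="https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth",
    pytorch_dump_folder_path="./vit-mae-base-converted",
)
```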

.\models\vit_mae\modeling_tf_vit_mae.py

# coding=utf-8
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 ViT MAE (masked autoencoder) model."""

# Importing necessary modules and libraries
from __future__ import annotations  # Allow forward references in type annotations

import collections.abc  # Import for abstract base classes
import math  # Import for mathematical functions
from copy import deepcopy  # Import for deep copying objects
from dataclasses import dataclass  # Import for creating structured data classes
from typing import Optional, Tuple, Union  # Import for type hints

import numpy as np  # Import for numerical operations with arrays
import tensorflow as tf  # Import TensorFlow library

# Importing specific functions and classes from custom modules
from ...activations_tf import get_tf_activation  # Import activation function retriever
from ...file_utils import (
    ModelOutput,  # Import base class for model outputs
    add_start_docstrings,  # Import function for adding docstrings to functions
    add_start_docstrings_to_model_forward,  # Import function for adding docstrings to model forward pass
    replace_return_docstrings,  # Import function for replacing return docstrings
)
from ...modeling_tf_outputs import TFBaseModelOutput  # Import base model output class for TensorFlow
from ...modeling_tf_utils import (
    TFModelInputType,  # Import type hint for model input in TensorFlow
    TFPreTrainedModel,  # Import base class for pre-trained models in TensorFlow
    get_initializer,  # Import function for getting weight initializers
    keras,  # Import Keras submodule from TensorFlow
    keras_serializable,  # Import decorator for serializing Keras layers
    unpack_inputs,  # Import function for unpacking model inputs
)
from ...tf_utils import shape_list, stable_softmax  # Import utility functions for TensorFlow
from ...utils import logging  # Import logging utilities
from .configuration_vit_mae import ViTMAEConfig  # Import configuration class for ViT MAE model


logger = logging.get_logger(__name__)  # Get logger instance for current module

_CONFIG_FOR_DOC = "ViTMAEConfig"  # Documentation string for configuration class
_CHECKPOINT_FOR_DOC = "facebook/vit-mae-base"  # Documentation string for model checkpoint

@dataclass
class TFViTMAEModelOutput(ModelOutput):
    """
    Class for TFViTMAEModel's outputs, with potential hidden states and attentions.
    """
    # 定义函数的参数及其类型注解
    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            模型最后一层的隐藏状态序列的张量。
        mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            指示哪些补丁被掩码(1)和哪些未被掩码(0)的张量。
        ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            包含(打乱后的)掩码补丁的原始索引的张量。
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            元组的 `tf.Tensor` (一个用于嵌入输出 + 每层输出的一个)的形状为 `(batch_size, sequence_length, hidden_size)`。
            模型在每一层输出的隐藏状态以及初始嵌入输出。
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            元组的 `tf.Tensor` (每层一个)的形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
            注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    """
    
    # 初始化函数的参数为默认值为 None
    last_hidden_state: tf.Tensor = None
    mask: tf.Tensor = None
    ids_restore: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None
@dataclass
class TFViTMAEDecoderOutput(ModelOutput):
    """
    TFViTMAEDecoderOutput 类用于存储 TFViTMAEDecoder 的输出结果,可能包含隐藏状态和注意力权重。

    Args:
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
            像素重建的逻辑回归结果。
        hidden_states (`tuple(tf.Tensor)`, *optional*, 当 `output_hidden_states=True` 时返回或 `config.output_hidden_states=True` 时返回):
            包含 `tf.Tensor` 元组(一个用于嵌入的输出 + 每层的一个输出),形状为 `(batch_size, sequence_length, hidden_size)`。
            模型每层的隐藏状态以及初始嵌入的输出。
        attentions (`tuple(tf.Tensor)`, *optional*, 当 `output_attentions=True` 时返回或 `config.output_attentions=True` 时返回):
            包含 `tf.Tensor` 元组(每层一个),形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
            经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFViTMAEForPreTrainingOutput(ModelOutput):
    """
    TFViTMAEForPreTrainingOutput 类用于存储 TFViTMAEForPreTraining 的输出结果,可能包含隐藏状态和注意力权重。

    Args:
        loss (`tf.Tensor` of shape `(1,)`):
            像素重建损失。
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
            像素重建的逻辑回归结果。
        mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            指示哪些补丁被掩盖(1)和哪些没有(0)的张量。
        ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            包含(打乱的)掩盖补丁的原始索引的张量。
        hidden_states (`tuple(tf.Tensor)`, *optional*, 当 `output_hidden_states=True` 时返回或 `config.output_hidden_states=True` 时返回):
            包含 `tf.Tensor` 元组(一个用于嵌入的输出 + 每层的一个输出),形状为 `(batch_size, sequence_length, hidden_size)`。
            模型每层的隐藏状态以及初始嵌入的输出。
        attentions (`tuple(tf.Tensor)`, *optional*, 当 `output_attentions=True` 时返回或 `config.output_attentions=True` 时返回):
            包含 `tf.Tensor` 元组(每层一个),形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
            经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    loss: tf.Tensor | None = None
    logits: tf.Tensor = None
    mask: tf.Tensor = None
    ids_restore: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    # attentions 是一个变量,类型是 Tuple[tf.Tensor] 或者 None
    attentions: Tuple[tf.Tensor] | None = None
# 创建二维 sin/cos 位置嵌入的函数
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """
    Create 2D sin/cos positional embeddings.

    Args:
        embed_dim (`int`):
            Embedding dimension.
        grid_size (`int`):
            The grid height and width.
        add_cls_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add a classification (CLS) token.

    Returns:
        (`tf.Tensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the position
        embeddings (with or without classification token)
    """
    # 创建高度和宽度的网格
    grid_h = tf.range(grid_size, dtype=tf.float32)
    grid_w = tf.range(grid_size, dtype=tf.float32)
    grid = tf.meshgrid(grid_w, grid_h)  # 这里宽度先行
    grid = tf.stack(grid, axis=0)

    grid = tf.reshape(grid, [2, 1, grid_size, grid_size])
    # 从网格获取二维 sin/cos 位置嵌入
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if add_cls_token:
        # 如果需要添加 CLS token,则在位置嵌入前面加一个全零向量
        pos_embed = tf.concat([tf.zeros((1, embed_dim)), pos_embed], axis=0)
    return pos_embed


# 从网格获取二维 sin/cos 位置嵌入
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    # 使用一半维度来编码 grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = tf.concat([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


# 从网格获取一维 sin/cos 位置嵌入
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    omega = tf.range(embed_dim // 2, dtype="float32")
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = tf.reshape(pos, [-1])  # (M,)
    out = tf.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    # 一半位置获取正弦模式,另一半获取余弦模式,然后串联起来
    emb_sin = tf.sin(out)  # (M, D/2)
    emb_cos = tf.cos(out)  # (M, D/2)

    emb = tf.concat([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
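
这几个位置嵌入函数的输出形状可以用一个小网格快速验证(embed_dim、grid_size 的取值仅为示意,假设上面两个函数已在当前作用域中):

```python
import tensorflow as tf

# 二维情形:embed_dim=8,grid_size=4,带 CLS 令牌
pos_embed = get_2d_sincos_pos_embed(embed_dim=8, grid_size=4, add_cls_token=True)
print(pos_embed.shape)  # (1 + 4*4, 8) = (17, 8)

# 一维情形:4 个位置,每个位置编码为 8 维
emb_1d = get_1d_sincos_pos_embed_from_grid(8, tf.range(4, dtype=tf.float32))
print(emb_1d.shape)  # (4, 8)
```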


class TFViTMAEEmbeddings(keras.layers.Layer):
    """
    构建 CLS token、位置和补丁嵌入。
    """

    def __init__(self, config: ViTMAEConfig, **kwargs):
        super().__init__(**kwargs)

        self.patch_embeddings = TFViTMAEPatchEmbeddings(config, name="patch_embeddings")
        self.num_patches = self.patch_embeddings.num_patches

        self.config = config

    def build(self, input_shape=None):
        # 在建立该层时,创建用于分类特殊令牌(CLS)的权重矩阵,形状为 (1, 1, 隐藏层大小)
        self.cls_token = self.add_weight(
            shape=(1, 1, self.config.hidden_size),
            initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
            trainable=True,
            name="cls_token",
        )

        # 创建位置嵌入矩阵,形状为 (1, num_patches + 1, 隐藏层大小),使用零值初始化
        self.position_embeddings = self.add_weight(
            shape=(1, self.num_patches + 1, self.config.hidden_size),
            initializer="zeros",
            trainable=False,  # 固定的正弦-余弦位置嵌入,不参与训练
            name="position_embeddings",
        )

        # 调用函数 `get_2d_sincos_pos_embed` 生成二维正弦-余弦位置嵌入
        pos_embed = get_2d_sincos_pos_embed(
            self.position_embeddings.shape[-1],
            int(self.patch_embeddings.num_patches**0.5),
            add_cls_token=True,
        )[None, ...]

        # 将生成的位置嵌入赋值给 self.position_embeddings
        self.position_embeddings.assign(pos_embed)

        # 如果模型已经构建完成,则直接返回
        if self.built:
            return

        # 标记模型已经构建
        self.built = True

        # 如果 self.patch_embeddings 属性存在,则调用它的 build 方法
        if getattr(self, "patch_embeddings", None) is not None:
            with tf.name_scope(self.patch_embeddings.name):
                self.patch_embeddings.build(None)

def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):
    """
    执行每个样本的随机遮盖,通过每个样本的乱序实现。每个样本的乱序由参数 argsort 的随机噪声完成。

    Args:
        sequence (`tf.Tensor` of shape `(batch_size, sequence_length, dim)`)
        noise (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*),主要用于测试目的,
            控制随机性并保持可重现性
    """
    # 获取 sequence 的形状信息:batch_size, sequence_length, dim
    batch_size, seq_length, dim = shape_list(sequence)
    
    # 计算保留的长度,以保证不被遮盖的部分占比为 self.config.mask_ratio
    len_keep = int(seq_length * (1 - self.config.mask_ratio))

    # 如果没有提供噪声数据,则生成一个均匀分布在 [0, 1) 区间的随机噪声
    if noise is None:
        noise = tf.random.uniform(shape=(batch_size, seq_length), minval=0.0, maxval=1.0)  # 噪声范围在 [0, 1)

    # 对每个样本的噪声进行排序
    ids_shuffle = tf.argsort(noise, axis=1)  # 升序排序:小的表示保留,大的表示移除
    ids_restore = tf.argsort(ids_shuffle, axis=1)

    # 保留前 len_keep 部分的序号
    ids_keep = ids_shuffle[:, :len_keep]
    sequence_unmasked = tf.gather(
        sequence,
        axis=1,
        batch_dims=1,
        indices=ids_keep,
    )

    # 生成二进制遮罩:0 表示保留,1 表示移除
    # 这个方法是必需的,因为 TF 的 EagerTensors 不支持直接的赋值操作
    mask_keep = tf.zeros((batch_size, len_keep))
    mask_remove = tf.ones((batch_size, seq_length - len_keep))
    mask = tf.concat([mask_keep, mask_remove], axis=-1)

    # 根据 ids_restore 恢复原始顺序,得到最终的二进制遮罩
    mask = tf.gather(mask, axis=1, batch_dims=1, indices=ids_restore)

    return sequence_unmasked, mask, ids_restore
    def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor:
        # 使用 patch_embeddings 方法将像素值转换为嵌入向量
        embeddings = self.patch_embeddings(pixel_values)

        # 添加位置嵌入,不包括 cls 标记
        embeddings = embeddings + self.position_embeddings[:, 1:, :]

        # 执行随机遮蔽:将 embeddings 进行部分遮蔽,生成 mask,并记录遮蔽前的位置 ids_restore
        embeddings, mask, ids_restore = self.random_masking(embeddings, noise)

        # 添加 cls 标记
        # 从 self.cls_token 和 self.position_embeddings 中获取 cls 标记,并复制到每个样本序列的开头
        cls_token = self.cls_token + self.position_embeddings[:, :1, :]
        cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1))
        # 将 cls 标记与 embeddings 拼接起来
        embeddings = tf.concat([cls_tokens, embeddings], axis=1)

        # 返回处理后的 embeddings、mask 和 ids_restore
        return embeddings, mask, ids_restore
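
上面 `random_masking` 的"按噪声排序做随机遮蔽"过程,可以用一个极小的独立示例复现(序列长度 4、mask_ratio 0.5,噪声为手工指定,数值仅作示意):

```python
import tensorflow as tf

# 仅作示意:batch_size=1,seq_length=4,dim=2,mask_ratio=0.5
sequence = tf.reshape(tf.range(8, dtype=tf.float32), (1, 4, 2))
noise = tf.constant([[0.7, 0.1, 0.9, 0.3]])  # 噪声越小越倾向于被保留
len_keep = int(4 * (1 - 0.5))                # 保留 2 个 patch

ids_shuffle = tf.argsort(noise, axis=1)        # [[1, 3, 0, 2]]
ids_restore = tf.argsort(ids_shuffle, axis=1)  # [[2, 0, 3, 1]]
ids_keep = ids_shuffle[:, :len_keep]           # [[1, 3]]:保留第 1、3 个 patch

sequence_unmasked = tf.gather(sequence, axis=1, batch_dims=1, indices=ids_keep)  # (1, 2, 2)
mask = tf.concat([tf.zeros((1, len_keep)), tf.ones((1, 4 - len_keep))], axis=-1)
mask = tf.gather(mask, axis=1, batch_dims=1, indices=ids_restore)
print(mask.numpy())  # [[1. 0. 1. 0.]]:第 0、2 个 patch 被遮蔽,第 1、3 个被保留
```
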
class TFViTMAEPatchEmbeddings(keras.layers.Layer):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config: ViTMAEConfig, **kwargs):
        super().__init__(**kwargs)
        # 从配置中获取图像大小和patch大小
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size
        # 如果图像大小和patch大小不是可迭代对象,转换为元组形式
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        # 计算图像中的patch数量
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.num_channels = num_channels
        self.config = config

        # 定义卷积层,用于将输入的像素值转换为patch embeddings
        self.projection = keras.layers.Conv2D(
            filters=hidden_size,
            kernel_size=patch_size,
            strides=patch_size,
            padding="valid",
            data_format="channels_last",
            kernel_initializer="glorot_uniform",  # 使用glorot_uniform初始化卷积核
            bias_initializer="zeros",  # 使用零初始化偏置
            name="projection",
        )

    def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
        # 获取输入张量的形状信息
        batch_size, num_channels, height, width = shape_list(pixel_values)
        
        # 在动态执行模式下,检查通道数是否与配置中设置的一致
        if tf.executing_eagerly():
            if num_channels != self.num_channels:
                raise ValueError(
                    "Make sure that the channel dimension of the pixel values match with the one set in the"
                    " configuration."
                )
            # 检查输入图像的尺寸是否与配置中设置的一致
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )

        # 在CPU上运行时,keras.layers.Conv2D不支持NCHW格式,需要将输入格式从NCHW转换为NHWC
        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

        # 将输入像素值投影到隐藏空间中
        projection = self.projection(pixel_values)

        # 将2D空间维度变换为单一的时间维度
        # shape = (batch_size, num_patches, out_channels=embed_dim)
        num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
        x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))

        return x
    # 定义 build 方法,用于构建神经网络层的结构
    def build(self, input_shape=None):
        # 如果已经构建过,直接返回,避免重复构建
        if self.built:
            return
        # 标记当前层已经构建
        self.built = True
        # 如果存在投影层,则构建投影层
        if getattr(self, "projection", None) is not None:
            # 在 TensorFlow 中,使用 name_scope 可以定义操作的命名空间
            with tf.name_scope(self.projection.name):
                # 构建投影层,输入的形状是 [None, None, None, self.num_channels]
                self.projection.build([None, None, None, self.num_channels])
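
patch 数量与卷积投影输出形状之间的关系,可以用默认配置粗略验证一下(image_size=224、patch_size=16、hidden_size=768 为 ViTMAEConfig 的默认值;这里单独构造一个 Conv2D 做形状示意,并非复用上面的层):

```python
import tensorflow as tf
from tensorflow import keras

image_size, patch_size, hidden_size = 224, 16, 768
num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196

projection = keras.layers.Conv2D(
    filters=hidden_size, kernel_size=patch_size, strides=patch_size, padding="valid"
)

pixel_values = tf.random.uniform((2, image_size, image_size, 3))  # NHWC(已按上文从 NCHW 转置)
patches = tf.reshape(projection(pixel_values), (2, num_patches, -1))
print(patches.shape)  # (2, 196, 768)
```
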
# 从transformers.models.vit.modeling_tf_vit.TFViTSelfAttention复制到TFViTMAESelfAttention,并修改为ViT->ViTMAE
class TFViTMAESelfAttention(keras.layers.Layer):
    def __init__(self, config: ViTMAEConfig, **kwargs):
        super().__init__(**kwargs)

        # 检查隐藏大小是否能被注意力头数整除
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        # 初始化注意力头数和每个头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        # 创建用于查询、键、值的全连接层,并初始化权重
        self.query = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        # 添加 dropout 层,用于注意力概率的随机失活
        self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
        self.config = config

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # 将形状从 [batch_size, seq_length, all_head_size] 重塑为 [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # 将张量从 [batch_size, seq_length, num_attention_heads, attention_head_size] 转置为 [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # 获取隐藏状态张量的批量大小
        batch_size = shape_list(hidden_states)[0]
        # 对隐藏状态进行查询操作,生成混合查询层
        mixed_query_layer = self.query(inputs=hidden_states)
        # 对隐藏状态进行键操作,生成混合键层
        mixed_key_layer = self.key(inputs=hidden_states)
        # 对隐藏状态进行值操作,生成混合值层
        mixed_value_layer = self.value(inputs=hidden_states)
        # 将混合查询层转置以便计算注意力分数
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        # 将混合键层转置以便计算注意力分数
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        # 将混合值层转置以便计算注意力分数
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # 计算查询与键之间的点积,得到原始注意力分数
        # 形状为(batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        # 计算注意力分数的缩放系数 dk
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        # 将注意力分数归一化为概率
        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # 对注意力概率进行 dropout 处理
        attention_probs = self.dropout(inputs=attention_probs, training=training)

        # 如果存在头部掩码,则应用头部掩码
        if head_mask is not None:
            attention_probs = tf.multiply(attention_probs, head_mask)

        # 计算注意力输出值
        attention_output = tf.matmul(attention_probs, value_layer)
        # 调整注意力输出值的维度顺序
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])

        # 将注意力输出值重新形状为(batch_size, seq_len_q, all_head_size)
        attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
        # 构建输出元组,根据需要包含注意力概率
        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

        return outputs

    def build(self, input_shape=None):
        # 如果已经构建过,则直接返回
        if self.built:
            return
        self.built = True
        # 如果存在查询层,构建查询层
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        # 如果存在键层,构建键层
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        # 如果存在值层,构建值层
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])
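
自注意力部分的核心计算(缩放点积注意力 + 合并注意力头)可以用一个脱离类的小示例复现;这里用 `tf.nn.softmax` 代替模块内的 `stable_softmax`,各维度取值仅作示意:

```python
import math

import tensorflow as tf

# 仅作示意:batch=1,heads=2,seq_len=3,head_dim=4
q = tf.random.normal((1, 2, 3, 4))
k = tf.random.normal((1, 2, 3, 4))
v = tf.random.normal((1, 2, 3, 4))

scores = tf.matmul(q, k, transpose_b=True) / math.sqrt(4)  # (1, 2, 3, 3)
probs = tf.nn.softmax(scores, axis=-1)                     # 注意力权重
context = tf.matmul(probs, v)                              # (1, 2, 3, 4)

# 合并注意力头:先转置回 (batch, seq_len, heads, head_dim),再展平为 all_head_size
context = tf.transpose(context, perm=[0, 2, 1, 3])
context = tf.reshape(context, (1, 3, 2 * 4))
print(context.shape)  # (1, 3, 8)
```
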
# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfOutput with ViT->ViTMAE
class TFViTMAESelfOutput(keras.layers.Layer):
    """
    The residual connection is defined in TFViTMAELayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ViTMAEConfig, **kwargs):
        super().__init__(**kwargs)

        # 定义一个全连接层,用于将输入的隐藏状态转换为指定大小的输出
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # 定义一个 dropout 层,用于在训练时随机置零部分神经元,防止过拟合
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        # 使用全连接层对隐藏状态进行转换
        hidden_states = self.dense(inputs=hidden_states)
        # 在训练时对转换后的输出应用 dropout
        hidden_states = self.dropout(inputs=hidden_states, training=training)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果已定义全连接层,构建全连接层
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])


# Copied from transformers.models.vit.modeling_tf_vit.TFViTAttention with ViT->ViTMAE
class TFViTMAEAttention(keras.layers.Layer):
    def __init__(self, config: ViTMAEConfig, **kwargs):
        super().__init__(**kwargs)

        # 定义注意力层对象,用于处理自注意力机制
        self.self_attention = TFViTMAESelfAttention(config, name="attention")
        # 定义输出层对象,负责接收注意力层输出并进行处理
        self.dense_output = TFViTMAESelfOutput(config, name="output")

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(
        self,
        input_tensor: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # 调用自注意力层,处理输入张量,返回处理结果和可能的注意力分布(如果输出的话)
        self_outputs = self.self_attention(
            hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
        )
        # 调用输出层,接收自注意力层的输出和输入张量,并返回处理后的结果
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
        )
        # 将输出整合为一个元组,包括处理后的注意力输出和可能的注意力分布
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果已定义自注意力层,构建自注意力层
        if getattr(self, "self_attention", None) is not None:
            with tf.name_scope(self.self_attention.name):
                self.self_attention.build(None)
        # 如果已定义输出层,构建输出层
        if getattr(self, "dense_output", None) is not None:
            with tf.name_scope(self.dense_output.name):
                self.dense_output.build(None)


# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->ViTMAE
class TFViTMAEIntermediate(keras.layers.Layer):
    # 初始化函数,用于创建一个新的 ViTMAE 层实例
    def __init__(self, config: ViTMAEConfig, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)

        # 创建一个密集连接层,用于处理输入特征
        self.dense = keras.layers.Dense(
            units=config.intermediate_size,  # 设置层的输出维度
            kernel_initializer=get_initializer(config.initializer_range),  # 使用指定的初始化器初始化权重矩阵
            name="dense"  # 设置层的名称
        )

        # 根据配置选择激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = get_tf_activation(config.hidden_act)  # 获取指定名称的 TensorFlow 激活函数
        else:
            self.intermediate_act_fn = config.hidden_act  # 直接使用给定的激活函数
        self.config = config  # 保存配置对象

    # 调用函数,用于定义层的正向传播逻辑
    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)  # 将输入数据传递给密集连接层
        hidden_states = self.intermediate_act_fn(hidden_states)  # 应用中间激活函数

        return hidden_states  # 返回处理后的特征表示

    # 构建函数,用于构建层的参数
    def build(self, input_shape=None):
        if self.built:  # 如果已经构建过,直接返回
            return
        self.built = True  # 标记为已构建
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):  # 使用名称空间管理密集连接层
                self.dense.build([None, None, self.config.hidden_size])  # 构建密集连接层的参数
# Copied from transformers.models.vit.modeling_tf_vit.TFViTOutput with ViT->ViTMAE
class TFViTMAEOutput(keras.layers.Layer):
    def __init__(self, config: ViTMAEConfig, **kwargs):
        super().__init__(**kwargs)

        # 定义一个全连接层,输出维度为 config.hidden_size,使用指定的初始化方法
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # 定义一个 Dropout 层,使用给定的 dropout 率
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        # 对输入的 hidden_states 进行全连接操作
        hidden_states = self.dense(inputs=hidden_states)
        # 在训练时对全连接结果进行 dropout 处理
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        # 将 dropout 后的结果与 input_tensor 相加
        hidden_states = hidden_states + input_tensor

        return hidden_states

    def build(self, input_shape=None):
        # 如果已经构建过,直接返回
        if self.built:
            return
        self.built = True
        # 如果存在 self.dense 层,则根据输入形状构建全连接层
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.intermediate_size])


# Copied from transformers.models.vit.modeling_tf_vit.TFViTLayer with ViT->ViTMAE
class TFViTMAELayer(keras.layers.Layer):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTMAEConfig, **kwargs):
        super().__init__(**kwargs)

        # 初始化 TFViTMAEAttention 层
        self.attention = TFViTMAEAttention(config, name="attention")
        # 初始化 TFViTMAEIntermediate 层
        self.intermediate = TFViTMAEIntermediate(config, name="intermediate")
        # 初始化 TFViTMAEOutput 层
        self.vit_output = TFViTMAEOutput(config, name="output")

        # 初始化 layernorm 层,在每个 block 的开始和结束进行归一化
        self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
        self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # 调用 self.attention 进行注意力计算,ViTMAE 中在 self-attention 前应用 layernorm
        attention_outputs = self.attention(
            input_tensor=self.layernorm_before(inputs=hidden_states),  # 在 self-attention 前应用 layernorm
            head_mask=head_mask,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = attention_outputs[0]

        # 第一个残差连接
        hidden_states = attention_output + hidden_states

        # ViTMAE 中在 self-attention 后同样应用 layernorm
        layer_output = self.layernorm_after(inputs=hidden_states)

        # 使用 intermediate 层处理输出
        intermediate_output = self.intermediate(hidden_states=layer_output)

        # 第二个残差连接在此处完成
        layer_output = self.vit_output(
            hidden_states=intermediate_output, input_tensor=hidden_states, training=training
        )
        outputs = (layer_output,) + attention_outputs[1:]  # 如果有需要,添加注意力信息到输出中

        return outputs

    def build(self, input_shape=None):
        # 如果已经构建过,直接返回
        if self.built:
            return
        self.built = True

        # 构建 attention 层
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)

        # 构建 intermediate 层
        if getattr(self, "intermediate", None) is not None:
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)

        # 构建 vit_output 层
        if getattr(self, "vit_output", None) is not None:
            with tf.name_scope(self.vit_output.name):
                self.vit_output.build(None)

        # 构建 layernorm_before 层
        if getattr(self, "layernorm_before", None) is not None:
            with tf.name_scope(self.layernorm_before.name):
                self.layernorm_before.build([None, None, self.config.hidden_size])

        # 构建 layernorm_after 层
        if getattr(self, "layernorm_after", None) is not None:
            with tf.name_scope(self.layernorm_after.name):
                self.layernorm_after.build([None, None, self.config.hidden_size])
# 从 transformers.models.vit.modeling_tf_vit.TFViTEncoder 复制代码,将 ViT 更改为 ViTMAE
class TFViTMAEEncoder(keras.layers.Layer):
    def __init__(self, config: ViTMAEConfig, **kwargs):
        super().__init__(**kwargs)

        # 初始化编码器的多层子模块 TFViTMAELayer,并命名为"layer_._{i}"
        self.layer = [TFViTMAELayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        # 初始化存储所有隐藏状态的元组,如果不需要输出隐藏状态则为 None
        all_hidden_states = () if output_hidden_states else None
        # 初始化存储所有注意力权重的元组,如果不需要输出注意力权重则为 None
        all_attentions = () if output_attentions else None

        # 遍历每一层编码器
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # 如果需要输出隐藏状态,则将当前隐藏状态添加到 all_hidden_states 中
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 调用当前层的编码器模块,计算输出
            layer_outputs = layer_module(
                hidden_states=hidden_states,
                head_mask=head_mask[i],
                output_attentions=output_attentions,
                training=training,
            )
            # 更新当前隐藏状态为当前层的输出
            hidden_states = layer_outputs[0]

            if output_attentions:
                # 如果需要输出注意力权重,则将当前层的注意力权重添加到 all_attentions 中
                all_attentions = all_attentions + (layer_outputs[1],)

        # 添加最后一层的隐藏状态到 all_hidden_states
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果不需要返回字典格式的输出,则返回所有非空的元组值
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        # 返回 TFBaseModelOutput 类的对象,包含最后的隐藏状态、所有隐藏状态和所有注意力权重
        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

    def build(self, input_shape=None):
        # 如果已经构建过则直接返回
        if self.built:
            return
        self.built = True
        # 构建每一层的编码器模块
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)


@keras_serializable
class TFViTMAEMainLayer(keras.layers.Layer):
    config_class = ViTMAEConfig

    def __init__(self, config: ViTMAEConfig, **kwargs):
        super().__init__(**kwargs)

        # 初始化 ViTMAE 主层的配置
        self.config = config

        # 初始化 ViTMAE 主层的嵌入层 TFViTMAEEmbeddings,并命名为"embeddings"
        self.embeddings = TFViTMAEEmbeddings(config, name="embeddings")
        # 初始化 ViTMAE 主层的编码器 TFViTMAEEncoder,并命名为"encoder"
        self.encoder = TFViTMAEEncoder(config, name="encoder")
        # 初始化 ViTMAE 主层的层归一化层 LayerNormalization,使用指定的 epsilon 值,并命名为"layernorm"
        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")

    def get_input_embeddings(self) -> keras.layers.Layer:
        # 返回嵌入层 TFViTMAEEmbeddings 的补丁嵌入
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # 抛出未实现的错误,子类需要实现具体的头部修剪逻辑
        raise NotImplementedError

    @unpack_inputs
    # 定义一个方法 `call`,用于执行模型推断或训练
    def call(
        self,
        pixel_values: TFModelInputType | None = None,  # 输入像素值,可以为空
        noise: tf.Tensor = None,  # 噪声张量,默认为空
        head_mask: np.ndarray | tf.Tensor | None = None,  # 头部掩码,可以是 NumPy 数组、张量或空
        output_attentions: Optional[bool] = None,  # 是否输出注意力权重,可选布尔值
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态,可选布尔值
        return_dict: Optional[bool] = None,  # 是否以字典形式返回结果,可选布尔值
        training: bool = False,  # 是否处于训练模式,默认为 False
    ) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
        # 调用嵌入层的方法获取嵌入输出、掩码和恢复的 IDs
        embedding_output, mask, ids_restore = self.embeddings(
            pixel_values=pixel_values, training=training, noise=noise
        )

        # 如果需要,准备头部掩码
        # 在头部掩码中,1.0 表示保留该头部
        # attention_probs 的形状为 bsz x n_heads x N x N
        # 输入的 head_mask 形状为 [num_heads] 或 [num_hidden_layers x num_heads]
        # 并且 head_mask 被转换为形状 [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            # 如果存在头部掩码,但当前未实现如何处理
            raise NotImplementedError
        else:
            # 如果头部掩码为空,则创建一个空列表,长度为隐藏层数
            head_mask = [None] * self.config.num_hidden_layers

        # 使用编码器处理嵌入输出
        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # 获取编码器输出的序列输出
        sequence_output = encoder_outputs[0]

        # 应用层归一化到序列输出
        sequence_output = self.layernorm(inputs=sequence_output)

        # 如果不要求以字典形式返回结果,则返回元组
        if not return_dict:
            return (sequence_output, mask, ids_restore) + encoder_outputs[1:]

        # 以 TFViTMAEModelOutput 对象形式返回结果,包括最后的隐藏状态、掩码、恢复的 IDs、隐藏状态和注意力权重
        return TFViTMAEModelOutput(
            last_hidden_state=sequence_output,
            mask=mask,
            ids_restore=ids_restore,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    # 定义构建方法 build,用于在需要时构建模型
    def build(self, input_shape=None):
        # 如果已经构建过模型,则直接返回
        if self.built:
            return

        # 设置标记为已构建
        self.built = True

        # 如果存在嵌入层,构建嵌入层
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)

        # 如果存在编码器,构建编码器
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)

        # 如果存在层归一化,构建层归一化,设置形状为 [None, None, self.config.hidden_size]
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, None, self.config.hidden_size])
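

# 补充说明(非库中源码,仅为示意):由于 TFViTMAEEmbeddings 会按 config.mask_ratio 随机丢弃 patch,
# TFViTMAEMainLayer 输出的 last_hidden_state 序列长度只有 1 + num_patches * (1 - mask_ratio)。
# 下面按 facebook/vit-mae-base 的常见默认取值(224x224、patch_size=16、mask_ratio=0.75,此处作为假设)粗算一下:
image_size, patch_size, mask_ratio = 224, 16, 0.75
num_patches = (image_size // patch_size) ** 2        # 14 * 14 = 196 个 patch
len_keep = int(num_patches * (1 - mask_ratio))       # 编码器只看到 49 个可见 patch
print(num_patches, len_keep, 1 + len_keep)           # 196 49 50(含 CLS token)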
"""
    Documenting the expected input formats for ViT-MAE models when using TensorFlow. This docstring serves as a guide
    for users on how to provide inputs to the model.

    TensorFlow models in `transformers` support two input formats:
    - Passing all inputs as keyword arguments.
    - Passing all inputs in a list, tuple, or dictionary as the first positional argument.

    This flexibility ensures compatibility with TensorFlow's Keras API and other functional usage scenarios.

    Args:
        pixel_values (Tensor): Input pixel values representing the image. This is the only input that ViT-MAE
            actually consumes; `attention_mask` and `token_type_ids` below are listed purely to illustrate the
            generic input convention shared by all TensorFlow models in `transformers`.
        attention_mask (Tensor, optional): Mask to avoid performing attention on padding tokens (not used by ViT-MAE).
        token_type_ids (Tensor, optional): Segment token indices to distinguish different parts of the input (not
            used by ViT-MAE).

    Usage Examples:
        - Using keyword arguments: `model(pixel_values=pixel_values)`
        - Using a list or tuple as the first positional argument:
          `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
        - Using a dictionary keyed by input names: `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`

    Note:
        For custom layers or models using Keras Functional API, ensure inputs match the documented formats.

    Reference:
        [Transformers documentation](https://huggingface.co/transformers/model_doc/vit.html)
"""
    Args:
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.
        head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used
            in eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""


# 注释:
"""
Transformer 模型的 ViTMAE 版本的 TensorFlow 实现,输出原始隐藏状态而不带特定的输出头部。

Args:
    config (ViTMAEConfig): ViTMAE 模型的配置对象。
    *inputs: 可变长度的输入参数。
    **kwargs: 关键字参数。

Attributes:
    vit (TFViTMAEMainLayer): ViTMAE 主层对象。

Methods:
    get_input_embeddings(): 获取输入嵌入层的方法。
    call(): 模型的前向传播方法,接受多种参数并返回模型输出。
    build(input_shape=None): 构建模型的方法,用于初始化网络层。

Examples:
    ```
    >>> from transformers import AutoImageProcessor, TFViTMAEModel
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
    >>> model = TFViTMAEModel.from_pretrained("facebook/vit-mae-base")

    >>> inputs = image_processor(images=image, return_tensors="tf")
    >>> outputs = model(**inputs)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
"""
class TFViTMAEModel(TFViTMAEPreTrainedModel):
    def __init__(self, config: ViTMAEConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # 初始化 ViTMAE 主层对象
        self.vit = TFViTMAEMainLayer(config, name="vit")

    def get_input_embeddings(self):
        # 调用 ViTMAE 主层对象的输入嵌入层方法
        return self.vit.get_input_embeddings()

    @unpack_inputs
    @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFViTMAEModelOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        noise: tf.Tensor = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
        r"""
        模型的前向传播方法,接受多种参数并返回模型输出。

        Args:
            pixel_values (TFModelInputType | None): 输入的像素值,可以为 None。
            noise (tf.Tensor): 噪声张量,默认为 None。
            head_mask (np.ndarray | tf.Tensor | None): 头部掩码,可以为 None。
            output_attentions (Optional[bool]): 是否输出注意力权重。
            output_hidden_states (Optional[bool]): 是否输出隐藏状态。
            return_dict (Optional[bool]): 是否返回字典形式的输出。
            training (bool): 是否处于训练模式。

        Returns:
            Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]: 模型的输出结果。

        Examples:
            ```
            >>> from transformers import AutoImageProcessor, TFViTMAEModel
            >>> from PIL import Image
            >>> import requests

            >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
            >>> image = Image.open(requests.get(url, stream=True).raw)

            >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
            >>> model = TFViTMAEModel.from_pretrained("facebook/vit-mae-base")

            >>> inputs = image_processor(images=image, return_tensors="tf")
            >>> outputs = model(**inputs)
            >>> last_hidden_states = outputs.last_hidden_state
            ```
        """
        # 调用 ViTMAE 主层对象的前向传播方法
        outputs = self.vit(
            pixel_values=pixel_values,
            noise=noise,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def build(self, input_shape=None):
        # 如果已经构建过,直接返回
        if self.built:
            return
        # 设置已构建标志为 True
        self.built = True
        # 如果存在 ViTMAE 主层对象,则在命名作用域内构建它
        if getattr(self, "vit", None) is not None:
            with tf.name_scope(self.vit.name):
                self.vit.build(None)


# ViTMAE 的解码器层:根据可见 patch 的编码结果与 mask token 重建被掩码的 patch
class TFViTMAEDecoder(keras.layers.Layer):
    # 初始化函数,用于创建对象实例时的初始化操作
    def __init__(self, config, num_patches, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)
        
        # 创建一个全连接层作为解码器的嵌入层,用于将输入映射到解码器的隐藏大小
        self.decoder_embed = keras.layers.Dense(config.decoder_hidden_size, name="decoder_embed")

        # 深拷贝配置对象,用于配置解码器层的参数
        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        
        # 创建多层解码器,每层使用相同的配置
        self.decoder_layers = [
            TFViTMAELayer(decoder_config, name=f"decoder_layers.{j}") for j in range(config.decoder_num_hidden_layers)
        ]

        # 创建层归一化层,用于归一化解码器的输出
        self.decoder_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="decoder_norm")
        
        # 创建解码器预测层,将解码器输出映射回原始图像块的大小和通道数
        self.decoder_pred = keras.layers.Dense(
            config.patch_size**2 * config.num_channels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="decoder_pred",
        )  # encoder to decoder

        # 保存配置对象和图像块数量
        self.config = config
        self.num_patches = num_patches

    # 构建模型,用于在图层创建完成后的初始化和构建操作
    def build(self, input_shape=None):
        # 创建一个权重,用作掩码令牌,形状为 (1, 1, 解码器隐藏大小)
        self.mask_token = self.add_weight(
            shape=(1, 1, self.config.decoder_hidden_size),
            initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
            trainable=True,
            name="mask_token",
        )
        
        # 创建解码器位置嵌入权重,形状为 (1, 图像块数量+1, 解码器隐藏大小),初始化为零
        self.decoder_pos_embed = self.add_weight(
            shape=(1, self.num_patches + 1, self.config.decoder_hidden_size),
            initializer="zeros",
            trainable=False,
            name="decoder_pos_embed",
        )
        
        # 使用函数生成二维正弦余弦位置嵌入,并将结果赋值给解码器位置嵌入权重
        decoder_pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1],
            int(self.num_patches**0.5),
            add_cls_token=True,
        )[None, ...]
        self.decoder_pos_embed.assign(decoder_pos_embed)

        # 如果已经构建完成则直接返回
        if self.built:
            return
        self.built = True
        
        # 如果存在解码器嵌入层,则构建该层
        if getattr(self, "decoder_embed", None) is not None:
            with tf.name_scope(self.decoder_embed.name):
                self.decoder_embed.build([None, None, self.config.hidden_size])
        
        # 如果存在解码器归一化层,则构建该层
        if getattr(self, "decoder_norm", None) is not None:
            with tf.name_scope(self.decoder_norm.name):
                self.decoder_norm.build([None, None, self.config.decoder_hidden_size])
        
        # 如果存在解码器预测层,则构建该层
        if getattr(self, "decoder_pred", None) is not None:
            with tf.name_scope(self.decoder_pred.name):
                self.decoder_pred.build([None, None, self.config.decoder_hidden_size])
        
        # 如果存在解码器层,则分别构建每一层
        if getattr(self, "decoder_layers", None) is not None:
            for layer in self.decoder_layers:
                with tf.name_scope(layer.name):
                    layer.build(None)

    # 调用模型,实现模型的前向计算
    def call(
        self,
        hidden_states,
        ids_restore,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        # 将 tokens 嵌入到解码器的隐藏空间
        x = self.decoder_embed(hidden_states)

        # 将mask tokens附加到序列
        mask_tokens = tf.tile(
            self.mask_token,
            (shape_list(x)[0], shape_list(ids_restore)[1] + 1 - shape_list(x)[1], 1),
        )
        x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1)  # 没有cls token
        x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore)  # 取消洗牌
        x = tf.concat([x[:, :1, :], x_], axis=1)  # 添加cls token

        # 添加位置嵌入
        hidden_states = x + self.decoder_pos_embed

        # 应用Transformer层(块)
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.decoder_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                head_mask=None,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 对隐藏状态进行归一化
        hidden_states = self.decoder_norm(hidden_states)

        # 预测器投影
        logits = self.decoder_pred(hidden_states)

        # 移除cls token
        logits = logits[:, 1:, :]

        # 根据return_dict决定返回的内容
        if not return_dict:
            return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
        return TFViTMAEDecoderOutput(logits=logits, hidden_states=all_hidden_states, attentions=all_self_attentions)
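

# 补充示例(非库中源码,仅为示意):上面 call 中的 tf.gather(..., batch_dims=1, indices=ids_restore)
# 负责"反洗牌":把可见 patch 的编码结果和 mask token 放回原始 patch 顺序。用一个玩具例子演示:
import tensorflow as tf

x_ = tf.constant([[[1.0], [0.0], [0.0], [0.0]]])   # 可见 patch 的编码(1.0)在前,3 个 mask token(0.0)在后
ids_restore = tf.constant([[2, 0, 3, 1]])          # 假设打乱顺序为 [1, 3, 0, 2] 时得到的还原索引
unshuffled = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore)
print(unshuffled.numpy().squeeze())                # [0. 1. 0. 0.]:可见 patch 回到了原始位置 1
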
@add_start_docstrings(
    "The ViTMAE Model transformer with the decoder on top for self-supervised pre-training.",
    VIT_MAE_START_DOCSTRING,
)
class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # 初始化 ViT 主层
        self.vit = TFViTMAEMainLayer(config, name="vit")
        
        # 初始化解码器,传入配置和从 ViT 主层获取的补丁数
        self.decoder = TFViTMAEDecoder(
            config,
            num_patches=self.vit.embeddings.num_patches,
            name="decoder",
        )

    def get_input_embeddings(self):
        # 返回 ViT 主层的输入嵌入
        return self.vit.get_input_embeddings()

    def _prune_heads(self, heads_to_prune):
        # 抛出未实现错误,用于剪枝操作
        raise NotImplementedError

    def patchify(self, pixel_values):
        """
        Args:
            pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
                Pixel values.

        Returns:
            `tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
                Patchified pixel values.
        """
        patch_size, num_channels = self.config.patch_size, self.config.num_channels
        
        # 确保通道在最后一个维度
        if shape_list(pixel_values)[1] == num_channels:
            pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

        # 断言检查
        tf.debugging.assert_equal(
            shape_list(pixel_values)[1],
            shape_list(pixel_values)[2],
            message="Make sure the pixel values have a squared size",
        )
        tf.debugging.assert_equal(
            shape_list(pixel_values)[1] % patch_size,
            0,
            message="Make sure the pixel values have a size that is divisible by the patch size",
        )
        tf.debugging.assert_equal(
            shape_list(pixel_values)[3],
            num_channels,
            message=(
                "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
            ),
        )

        # 补丁化处理
        batch_size = shape_list(pixel_values)[0]
        num_patches_one_direction = shape_list(pixel_values)[2] // patch_size
        patchified_pixel_values = tf.reshape(
            pixel_values,
            (batch_size, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size, num_channels),
        )
        patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
        patchified_pixel_values = tf.reshape(
            patchified_pixel_values,
            (batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels),
        )
        return patchified_pixel_values
    def unpatchify(self, patchified_pixel_values):
        """
        Args:
            patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
                Patchified pixel values.

        Returns:
            `tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
                Pixel values.
        """
        # 从patchified_pixel_values中获取patch大小和通道数
        patch_size, num_channels = self.config.patch_size, self.config.num_channels
        # 计算每个方向上的patch数量,应该是整数
        num_patches_one_direction = int(shape_list(patchified_pixel_values)[1] ** 0.5)
        # 进行健全性检查,确保patch数量是可以平方的
        tf.debugging.assert_equal(
            num_patches_one_direction * num_patches_one_direction,
            shape_list(patchified_pixel_values)[1],
            message="Make sure that the number of patches can be squared",
        )

        # 解除patchification
        batch_size = shape_list(patchified_pixel_values)[0]
        patchified_pixel_values = tf.reshape(
            patchified_pixel_values,
            (batch_size, num_patches_one_direction, num_patches_one_direction, patch_size, patch_size, num_channels),
        )
        patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
        # 重新组织成完整的像素值形状
        pixel_values = tf.reshape(
            patchified_pixel_values,
            (batch_size, num_patches_one_direction * patch_size, num_patches_one_direction * patch_size, num_channels),
        )
        return pixel_values
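
# 补充示例(非库中源码,仅为示意):patchify 与 unpatchify 互为逆操作。
# 下面用一张 4x4、单通道的"图像"(patch_size=2,均为假设的玩具取值)独立复现这两步并验证往返一致:
import tensorflow as tf

b, h, p, c = 1, 4, 2, 1
pixel_values = tf.reshape(tf.range(16, dtype=tf.float32), (b, h, h, c))
n = h // p                                                      # 每个方向 2 个 patch
patches = tf.reshape(pixel_values, (b, n, p, n, p, c))
patches = tf.einsum("nhpwqc->nhwpqc", patches)
patches = tf.reshape(patches, (b, n * n, p**2 * c))             # patchify 结果:(1, 4, 4)
restored = tf.reshape(patches, (b, n, n, p, p, c))
restored = tf.einsum("nhwpqc->nhpwqc", restored)
restored = tf.reshape(restored, (b, n * p, n * p, c))           # unpatchify 结果:(1, 4, 4, 1)
print(bool(tf.reduce_all(restored == pixel_values)))            # True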

    def forward_loss(self, pixel_values, pred, mask):
        """
        Args:
            pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
                Pixel values.
            pred (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
                Predicted pixel values.
            mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
                Tensor indicating which patches are masked (1) and which are not (0).

        Returns:
            `tf.Tensor`: Pixel reconstruction loss.
        """
        # 将像素值进行patchify处理
        target = self.patchify(pixel_values)
        # 如果设置了像素损失的归一化,则进行归一化处理
        if self.config.norm_pix_loss:
            mean = tf.reduce_mean(target, axis=-1, keepdims=True)
            var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        # 计算损失,即预测值与目标值之间的平方差
        loss = (pred - target) ** 2
        loss = tf.reduce_mean(loss, axis=-1)  # [batch_size, num_patches], mean loss per patch

        # 计算仅在掩码处损失的平均损失
        loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)  # mean loss on removed patches
        loss = tf.reshape(loss, (1,))
        return loss
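
# 补充示例(非库中源码,仅为示意):forward_loss 只在被掩码的 patch 上平均重建误差。
# 用玩具数据复现这一步(batch=1,2 个 patch,其中第 0 个被掩码,取值均为假设):
import tensorflow as tf

pred = tf.constant([[[1.0, 1.0], [0.0, 0.0]]])              # (batch, num_patches, patch_size**2 * num_channels)
target = tf.constant([[[0.0, 0.0], [0.0, 0.0]]])
mask = tf.constant([[1.0, 0.0]])                            # 1 表示该 patch 被掩码

loss = tf.reduce_mean((pred - target) ** 2, axis=-1)        # 每个 patch 的 MSE:[[1., 0.]]
loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)     # 只对被掩码的 patch 求平均:1.0
print(float(loss))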

    @unpack_inputs
    @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFViTMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        noise: tf.Tensor = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]:
        r"""
        Returns:

        Examples:

        ```
        >>> from transformers import AutoImageProcessor, TFViTMAEForPreTraining
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
        >>> model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

        >>> inputs = image_processor(images=image, return_tensors="tf")
        >>> outputs = model(**inputs)
        >>> loss = outputs.loss
        >>> mask = outputs.mask
        >>> ids_restore = outputs.ids_restore
        ```"""

        # 根据传入的参数设置是否返回字典格式的输出
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 Vision Transformer 模型进行前向传播
        outputs = self.vit(
            pixel_values=pixel_values,
            noise=noise,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # 从输出中获取最后隐藏状态、恢复的图像标识和掩码
        latent = outputs.last_hidden_state
        ids_restore = outputs.ids_restore
        mask = outputs.mask

        # 使用解码器生成的 logits 计算前向损失
        decoder_outputs = self.decoder(latent, ids_restore)  # [batch_size, num_patches, patch_size**2*3]
        logits = decoder_outputs.logits
        loss = self.forward_loss(pixel_values, logits, mask)

        # 根据是否返回字典格式决定返回的输出形式
        if not return_dict:
            output = (logits, mask, ids_restore) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 返回 TFViTMAEForPreTrainingOutput 对象,包含损失、logits、掩码、恢复的图像标识、隐藏状态和注意力矩阵
        return TFViTMAEForPreTrainingOutput(
            loss=loss,
            logits=logits,
            mask=mask,
            ids_restore=ids_restore,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return

        # 标记模型已经构建
        self.built = True

        # 如果已定义 Vision Transformer 模型,构建其结构
        if getattr(self, "vit", None) is not None:
            with tf.name_scope(self.vit.name):
                self.vit.build(None)

        # 如果已定义解码器模型,构建其结构
        if getattr(self, "decoder", None) is not None:
            with tf.name_scope(self.decoder.name):
                self.decoder.build(None)
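
# 补充示例(非库中源码,仅为示意):假设 `model`、`inputs`、`outputs` 来自上面
# TFViTMAEForPreTraining 的 docstring 示例(facebook/vit-mae-base,224x224、patch_size=16),
# 可以用 unpatchify 把重建 logits 还原回图像张量,并统计被掩码的 patch 数:
import tensorflow as tf

reconstruction = model.unpatchify(outputs.logits)     # (1, 224, 224, 3),通道在最后一维
num_masked = tf.reduce_sum(outputs.mask)              # 196 * 0.75 = 147 个 patch 被掩码
print(reconstruction.shape, int(num_masked))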

.\models\vit_mae\modeling_vit_mae.py

# coding=utf-8
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch ViT MAE (masked autoencoder) model."""


import collections.abc
import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Set, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN  # 导入激活函数映射表
from ...modeling_outputs import BaseModelOutput  # 导入基础模型输出类
from ...modeling_utils import PreTrainedModel  # 导入预训练模型基类
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer  # 导入相关的PyTorch工具函数
from ...utils import (  # 导入常用工具函数和类
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_vit_mae import ViTMAEConfig  # 导入ViT MAE模型的配置类


logger = logging.get_logger(__name__)  # 获取日志记录器

_CONFIG_FOR_DOC = "ViTMAEConfig"  # 用于文档的配置类名称
_CHECKPOINT_FOR_DOC = "facebook/vit-mae-base"  # 用于文档的预训练模型名称

VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/vit-mae-base",
    # See all ViTMAE models at https://huggingface.co/models?filter=vit_mae
]


@dataclass
class ViTMAEModelOutput(ModelOutput):
    """
    ViTMAEModel 的输出类,包含(可能返回的)隐藏状态和注意力权重。
    """

    # `last_hidden_state`参数:模型最后一层的隐藏状态输出,形状为`(batch_size, sequence_length, hidden_size)`
    last_hidden_state: torch.FloatTensor = None
    
    # `mask`参数:指示哪些补丁被屏蔽(1)和哪些没有(0)的张量,形状为`(batch_size, sequence_length)`
    mask: torch.LongTensor = None
    
    # `ids_restore`参数:包含(打乱后的)屏蔽补丁的原始索引的张量,形状为`(batch_size, sequence_length)`
    ids_restore: torch.LongTensor = None
    
    # `hidden_states`参数(可选):元组的`torch.FloatTensor`(如果`output_hidden_states=True`或`config.output_hidden_states=True`时返回),
    # 包含模型每一层的隐藏状态输出加上初始嵌入输出,形状为`(batch_size, sequence_length, hidden_size)`
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    
    # `attentions`参数(可选):元组的`torch.FloatTensor`(如果`output_attentions=True`或`config.output_attentions=True`时返回),
    # 包含每一层的注意力权重,形状为`(batch_size, num_heads, sequence_length, sequence_length)`。
    # 这些是经过注意力 softmax 后的注意力权重,用于计算自注意力头中的加权平均值。
    attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class ViTMAEDecoderOutput(ModelOutput):
    """
    ViTMAEDecoder的输出类,包含潜在的隐藏状态和注意力。

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
            像素重建的logits。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, 当 `output_hidden_states=True` 时返回或当 `config.output_hidden_states=True` 时返回):
            `torch.FloatTensor` 的元组(一个用于嵌入的输出 + 每层的输出),形状为 `(batch_size, sequence_length, hidden_size)`。
            模型每层输出的隐藏状态以及初始嵌入输出。
        attentions (`tuple(torch.FloatTensor)`, *optional*, 当 `output_attentions=True` 时返回或当 `config.output_attentions=True` 时返回):
            `torch.FloatTensor` 的元组(每层一个)形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
            经过注意力softmax后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class ViTMAEForPreTrainingOutput(ModelOutput):
    """
    ViTMAEForPreTraining的输出类,包含潜在的隐藏状态和注意力。

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`):
            像素重建损失。
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
            像素重建的logits。
        mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            指示哪些补丁被屏蔽(1)哪些没有(0)的张量。
        ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            包含(打乱的)屏蔽补丁的原始索引的张量。
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, 当 `output_hidden_states=True` 时返回或当 `config.output_hidden_states=True` 时返回):
            `torch.FloatTensor` 的元组(一个用于嵌入的输出 + 每层的输出),形状为 `(batch_size, sequence_length, hidden_size)`。
            模型每层输出的隐藏状态以及初始嵌入输出。
        attentions (`tuple(torch.FloatTensor)`, *optional*, 当 `output_attentions=True` 时返回或当 `config.output_attentions=True` 时返回):
            `torch.FloatTensor` 的元组(每层一个)形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
            经过注意力softmax后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    loss: Optional[torch.FloatTensor] = None
    # 定义变量 logits,类型为 torch.FloatTensor,初始值为 None
    logits: torch.FloatTensor = None
    # 定义变量 mask,类型为 torch.LongTensor,初始值为 None
    mask: torch.LongTensor = None
    # 定义变量 ids_restore,类型为 torch.LongTensor,初始值为 None
    ids_restore: torch.LongTensor = None
    # 定义变量 hidden_states,类型为 Optional[Tuple[torch.FloatTensor]],初始值为 None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 定义变量 attentions,类型为 Optional[Tuple[torch.FloatTensor]],初始值为 None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """
    Create 2D sin/cos positional embeddings.

    Args:
        embed_dim (`int`):
            Embedding dimension.
        grid_size (`int`):
            The grid height and width.
        add_cls_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add a classification (CLS) token.

    Returns:
        (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the
        position embeddings (with or without classification token)
    """
    # Generate a grid of height and width using numpy arrays
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    # Reshape the grid to prepare for positional embeddings calculation
    grid = grid.reshape([2, 1, grid_size, grid_size])
    # Compute positional embeddings using grid coordinates
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    
    # Optionally add a classification token to the positional embeddings
    if add_cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Generate 2D sin/cos positional embeddings from grid coordinates.

    Args:
        embed_dim (`int`):
            Embedding dimension.
        grid (`numpy.ndarray`):
            Grid coordinates of shape (2, 1, grid_size, grid_size).

    Returns:
        (`numpy.ndarray`): Positional embeddings of shape (grid_size*grid_size, embed_dim)
    """
    # Ensure embedding dimension is even
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    # Generate sin/cos positional embeddings separately for height and width
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    # Concatenate embeddings for height and width to form 2D embeddings
    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Generate 1D sin/cos positional embeddings from positions.

    Args:
        embed_dim (`int`):
            Embedding dimension.
        pos (`numpy.ndarray`):
            Positions to be encoded, shape (M,).

    Returns:
        (`numpy.ndarray`): Positional embeddings of shape (M, embed_dim)
    """
    # Ensure embedding dimension is even
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    # Generate frequencies for sin/cos functions
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    # Reshape positions for matrix multiplication
    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    # Compute sin and cos embeddings
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    # Concatenate sin and cos embeddings
    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
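

# 补充示例(非库中源码,仅为示意):假设上面三个函数已在当前作用域中定义,
# 用 hidden_size=768、14x14 的 patch 网格(均为假设取值)检查 2D sin-cos 位置嵌入的形状与取值:
import numpy as np

pos_embed = get_2d_sincos_pos_embed(768, 14, add_cls_token=True)
print(pos_embed.shape)                          # (1 + 14 * 14, 768) = (197, 768)
print(pos_embed[0].min(), pos_embed[0].max())   # CLS 位置的嵌入全为 0
print(bool(np.abs(pos_embed).max() <= 1.0))     # sin/cos 取值范围在 [-1, 1] 内:True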


class ViTMAEEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings.

    """

    def __init__(self, config):
        """
        Initializes ViTMAEEmbeddings module.

        Args:
            config (`object`):
                Configuration object containing model parameters.
        """
        super().__init__()

        # Define CLS token as a learnable parameter
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        # Initialize patch embeddings using ViTMAEPatchEmbeddings module
        self.patch_embeddings = ViTMAEPatchEmbeddings(config)
        # Obtain number of patches from patch_embeddings
        self.num_patches = self.patch_embeddings.num_patches
        # Fixed sin-cos positional embeddings for patches and CLS token
        self.position_embeddings = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, config.hidden_size), requires_grad=False
        )
        self.config = config
        self.initialize_weights()

    def initialize_weights(self):
        # 使用 sin-cos 嵌入初始化(并冻结)位置嵌入
        pos_embed = get_2d_sincos_pos_embed(
            self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches**0.5), add_cls_token=True
        )
        self.position_embeddings.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # 使用类似 nn.Linear 的方式初始化 patch_embeddings(而不是 nn.Conv2d)
        w = self.patch_embeddings.projection.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm 的 trunc_normal_(std=.02) 实际上是 normal_(std=0.02),因为截断过大(2.)
        torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)

    def random_masking(self, sequence, noise=None):
        """
        执行每个样本的随机掩码操作,通过每个样本的排序随机噪声来进行。排序随机噪声通过 argsort 实现。

        Args:
            sequence (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`)
                输入序列张量
            noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *可选*) 
                主要用于测试目的,控制随机性以及保持可重现性
        """
        batch_size, seq_length, dim = sequence.shape
        len_keep = int(seq_length * (1 - self.config.mask_ratio))

        if noise is None:
            noise = torch.rand(batch_size, seq_length, device=sequence.device)  # 噪声范围在 [0, 1]

        # 对每个样本的噪声进行排序
        ids_shuffle = torch.argsort(noise, dim=1)  # 升序:小的保留,大的移除
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # 保留第一个子集
        ids_keep = ids_shuffle[:, :len_keep]
        sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

        # 生成二进制掩码:0 表示保留,1 表示移除
        mask = torch.ones([batch_size, seq_length], device=sequence.device)
        mask[:, :len_keep] = 0
        # 解除排序以获得二进制掩码
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return sequence_unmasked, mask, ids_restore
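
# 补充示例(非库中源码,仅为示意):用 4 个 patch、mask_ratio=0.75 的玩具输入复现 random_masking 的核心逻辑,
# 说明 argsort(noise) 如何同时给出"保留哪些 patch"(ids_keep)和"如何还原顺序"(ids_restore):
import torch

seq = torch.arange(4.0).reshape(1, 4, 1)              # 4 个 patch,每个用一个标量表示
noise = torch.tensor([[0.7, 0.1, 0.9, 0.4]])          # 固定噪声:噪声最小的 patch(索引 1)被保留
len_keep = int(4 * (1 - 0.75))                        # 保留 1 个 patch

ids_shuffle = torch.argsort(noise, dim=1)             # [[1, 3, 0, 2]]
ids_restore = torch.argsort(ids_shuffle, dim=1)       # [[2, 0, 3, 1]]
ids_keep = ids_shuffle[:, :len_keep]
kept = torch.gather(seq, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, seq.shape[-1]))

mask = torch.ones(1, 4)
mask[:, :len_keep] = 0                                # 打乱顺序下:前 len_keep 个为保留
mask = torch.gather(mask, dim=1, index=ids_restore)   # 还原顺序后:[[1., 0., 1., 1.]]
print(kept.squeeze().item(), mask)                    # 1.0 —— 只有原始位置 1 的 patch 未被掩码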

    def forward(self, pixel_values, noise=None):
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values)

        # 添加位置嵌入,不包括 cls token
        embeddings = embeddings + self.position_embeddings[:, 1:, :]

        # 掩码操作:长度变为 length * config.mask_ratio
        embeddings, mask, ids_restore = self.random_masking(embeddings, noise)

        # 添加 cls token
        cls_token = self.cls_token + self.position_embeddings[:, :1, :]
        cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        return embeddings, mask, ids_restore
# 定义一个名为 ViTMAEPatchEmbeddings 的类,继承自 nn.Module,用于将形状为 `(batch_size, num_channels, height, width)` 的像素值转换成形状为 `(batch_size, seq_length, hidden_size)` 的初始隐藏状态(patch embeddings),以供 Transformer 使用。
class ViTMAEPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        # 从配置中获取图像大小和patch大小
        image_size, patch_size = config.image_size, config.patch_size
        # 从配置中获取通道数和隐藏大小
        num_channels, hidden_size = config.num_channels, config.hidden_size
        # 如果图像大小和patch大小不是可迭代对象,则转换为元组
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        # 计算图像中的patch数量
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        # 使用 nn.Conv2d 定义投影层,将输入通道数转换为隐藏大小,使用patch大小的卷积核和步幅
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
        # 获取输入张量的维度信息
        batch_size, num_channels, height, width = pixel_values.shape
        # 如果输入通道数与配置中的不匹配,则抛出 ValueError 异常
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # 如果输入图像尺寸与配置中的不匹配,则抛出 ValueError 异常
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        # 将输入张量通过投影层进行处理,然后展平成二维张量,并交换维度顺序
        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return x
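

# 补充示例(非库中源码,仅为示意):kernel_size = stride = patch_size 的 Conv2d 相当于"切 patch + 线性投影"。
# 用假设的典型取值(224x224、patch 16、hidden 768)看一下形状变化:
import torch
from torch import nn

projection = nn.Conv2d(3, 768, kernel_size=16, stride=16)
pixel_values = torch.randn(2, 3, 224, 224)
x = projection(pixel_values)              # (2, 768, 14, 14):每个网格位置对应一个 patch
x = x.flatten(2).transpose(1, 2)          # (2, 196, 768):展平成 patch 序列
print(x.shape)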


# 从 transformers.models.vit.modeling_vit.ViTSelfAttention 模型复制并重命名为 ViTMAESelfAttention
class ViTMAESelfAttention(nn.Module):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        # 检查隐藏大小是否可以被注意力头数整除,如果不是,则抛出 ValueError 异常
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        # 初始化注意力头数和每个注意力头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 定义查询、键和值的线性变换层
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        # 定义 dropout 层,用于注意力概率的随机丢弃
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
    # 将输入张量 x 进行维度重塑,以便进行注意力计算
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    # 模型的前向传播方法
    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # 计算混合查询向量
        mixed_query_layer = self.query(hidden_states)

        # 对键值对进行维度重塑,为了计算注意力得分
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # 计算注意力得分,即查询向量与键向量的点积
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # 缩放注意力得分
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # 将注意力得分归一化为概率分布
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # 对注意力概率进行 dropout 处理
        attention_probs = self.dropout(attention_probs)

        # 如果指定了头部掩码,应用头部掩码
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # 计算上下文向量,即注意力概率与值向量的加权和
        context_layer = torch.matmul(attention_probs, value_layer)

        # 将上下文向量维度重塑为 [batch_size, seq_length, all_head_size]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        # 准备输出结果,包括上下文层和(可选的)注意力概率
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
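

# 补充示例(非库中源码,仅为示意):transpose_for_scores 的形状变换与缩放点积注意力各步的形状
# (假设 batch=2、序列长度 50、12 个头、每头 64 维;这里令 Q=K=V,只演示形状):
import torch

batch_size, seq_len, num_heads, head_size = 2, 50, 12, 64
hidden = torch.randn(batch_size, seq_len, num_heads * head_size)                  # (2, 50, 768)

x = hidden.view(batch_size, seq_len, num_heads, head_size).permute(0, 2, 1, 3)    # (2, 12, 50, 64)
scores = torch.matmul(x, x.transpose(-1, -2)) / (head_size ** 0.5)                # (2, 12, 50, 50)
probs = torch.softmax(scores, dim=-1)
context = torch.matmul(probs, x)                                                  # (2, 12, 50, 64)
context = context.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1)            # (2, 50, 768)
print(scores.shape, context.shape)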
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
class ViTMAESelfOutput(nn.Module):
    """
    The residual connection is defined in ViTMAELayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        # 定义一个全连接层,将输入特征空间转换为隐藏状态大小
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 定义一个dropout层,以减少过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 输入经过全连接层
        hidden_states = self.dense(hidden_states)
        # 经过dropout层
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
class ViTMAEAttention(nn.Module):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        # 创建一个ViTMAESelfAttention对象
        self.attention = ViTMAESelfAttention(config)
        # 创建一个ViTMAESelfOutput对象
        self.output = ViTMAESelfOutput(config)
        # 初始化一个空集合,用于存储需要被修剪的注意力头
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        # 寻找需要被修剪的头的索引
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # 修剪线性层
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # 更新超参数并存储修剪后的注意力头
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # 调用self.attention进行自注意力计算
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        # 将self输出传递给self.output进行进一步处理
        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # 如果需要输出注意力权重,则添加到outputs中
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
class ViTMAEIntermediate(nn.Module):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        # 创建一个线性层,将隐藏状态大小转换为中间层大小
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # 如果隐藏激活函数是字符串,则使用相应的激活函数映射;否则使用给定的激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
    # 定义一个前向传播方法,接收隐藏状态作为输入张量,并返回处理后的张量
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 使用全连接层对隐藏状态进行线性变换
        hidden_states = self.dense(hidden_states)
        # 应用激活函数到线性变换后的隐藏状态张量
        hidden_states = self.intermediate_act_fn(hidden_states)

        # 返回处理后的隐藏状态张量作为输出
        return hidden_states
# 从 transformers.models.vit.modeling_vit.ViTOutput 复制而来,被重命名为 ViTMAEOutput
class ViTMAEOutput(nn.Module):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        # 创建一个全连接层,将输入维度转换为 config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # 定义一个 dropout 层,使用 config.hidden_dropout_prob 的概率进行随机失活
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 将输入的 hidden_states 通过全连接层 dense 进行线性变换
        hidden_states = self.dense(hidden_states)
        # 对线性变换后的结果进行 dropout 操作
        hidden_states = self.dropout(hidden_states)

        # 将 dropout 后的结果与输入的 input_tensor 相加,实现残差连接
        hidden_states = hidden_states + input_tensor

        return hidden_states


# 从 transformers.models.vit.modeling_vit.ViTLayer 复制而来,被重命名为 ViTMAELayer
class ViTMAELayer(nn.Module):
    """对应 timm 实现中的 Block 类。"""

    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        # 设置 chunk_size_feed_forward 和 seq_len_dim 参数
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        # 初始化 self-attention、中间层和输出层
        self.attention = ViTMAEAttention(config)
        self.intermediate = ViTMAEIntermediate(config)
        self.output = ViTMAEOutput(config)
        # 在 self-attention 前应用 layernorm
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 在 self-attention 后再次应用 layernorm
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # 在 ViTMAE 中,在 self-attention 前应用 layernorm
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # 如果输出注意力权重,则添加自注意力权重

        # 第一个残差连接
        hidden_states = attention_output + hidden_states

        # 在 ViTMAE 中,self-attention 后再次应用 layernorm
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # 第二个残差连接在此处完成
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs
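

# 补充示例(非库中源码,仅为示意):ViTMAELayer 采用 pre-norm 结构,可以概括为
#   y = x + Attention(LayerNorm(x));z = y + MLP(LayerNorm(y))。
# 下面用占位模块(attn 用 Identity 代替真实的 ViTMAEAttention,属于假设)演示这一残差连接方式:
import torch
from torch import nn

hidden_size = 768
ln1, ln2 = nn.LayerNorm(hidden_size), nn.LayerNorm(hidden_size)
attn = nn.Identity()                      # 占位:真实实现为 ViTMAEAttention
mlp = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(), nn.Linear(4 * hidden_size, hidden_size))

x = torch.randn(2, 50, hidden_size)
y = x + attn(ln1(x))                      # 第一个残差连接:注意力作用于归一化后的输入
z = y + mlp(ln2(y))                       # 第二个残差连接:MLP 同样作用于归一化后的输入
print(z.shape)                            # torch.Size([2, 50, 768])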


# 从 transformers.models.vit.modeling_vit.ViTEncoder 复制而来,被重命名为 ViTMAEEncoder
class ViTMAEEncoder(nn.Module):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.config = config
        # 使用 ViTMAELayer 构建层的列表,重复 config.num_hidden_layers 次
        self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)])
        # 设置梯度检查点为 False
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    # 定义函数的返回类型为一个元组或BaseModelOutput类型
    ) -> Union[tuple, BaseModelOutput]:
        # 如果不输出隐藏状态,则初始化为空元组;否则为None
        all_hidden_states = () if output_hidden_states else None
        # 如果不输出注意力权重,则初始化为空元组;否则为None
        all_self_attentions = () if output_attentions else None
    
        # 遍历模型的每一层
        for i, layer_module in enumerate(self.layer):
            # 如果需要输出隐藏状态
            if output_hidden_states:
                # 将当前层的隐藏状态添加到all_hidden_states元组中
                all_hidden_states = all_hidden_states + (hidden_states,)
    
            # 如果给定了head_mask,则使用当前层的head_mask;否则为None
            layer_head_mask = head_mask[i] if head_mask is not None else None
    
            # 如果开启了梯度检查点并且处于训练模式
            if self.gradient_checkpointing and self.training:
                # 使用梯度检查点函数来调用当前层的forward方法
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                # 否则直接调用当前层的forward方法
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
    
            # 获取当前层的输出隐藏状态
            hidden_states = layer_outputs[0]
    
            # 如果需要输出注意力权重
            if output_attentions:
                # 将当前层的注意力权重添加到all_self_attentions元组中
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
    
        # 如果需要输出隐藏状态
        if output_hidden_states:
            # 将最终的隐藏状态添加到all_hidden_states元组中
            all_hidden_states = all_hidden_states + (hidden_states,)
    
        # 如果不需要返回字典格式的结果
        if not return_dict:
            # 返回非None的元组,包括隐藏状态、所有隐藏状态和所有注意力权重
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        
        # 否则,返回BaseModelOutput类型的对象,包括最终的隐藏状态、所有隐藏状态和所有注意力权重
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
# ViTMAE 预训练模型的抽象基类:负责权重初始化,并提供下载和加载预训练模型的简单接口;
# 下面的 config_class、base_model_prefix 等类属性以及 _init_weights 方法都属于该基类。
class ViTMAEPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 设置配置类,用于模型配置参数的初始化
    config_class = ViTMAEConfig
    # 基础模型前缀,用于标识模型
    base_model_prefix = "vit"
    # 主要输入名称,指定模型的主要输入是像素值
    main_input_name = "pixel_values"
    # 支持梯度检查点
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        # 初始化模型权重
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # 对于线性层和卷积层,使用正态分布初始化权重,均值为 0,标准差为配置文件中的初始化范围
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果存在偏置,则将偏置初始化为零
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # 对于 LayerNorm 层,初始化偏置为零,权重为 1
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


# 使用 `add_start_docstrings` 装饰器为 `ViTMAEModel` 类添加文档字符串,描述其为一个输出原始隐藏状态、
# 顶部没有特定输出头的 ViTMAE 模型;其中包含关于模型使用和行为的一般信息,详情可参考 PyTorch 文档。
@add_start_docstrings(
    "The bare ViTMAE Model transformer outputting raw hidden-states without any specific head on top.",
    VIT_MAE_START_DOCSTRING,
)
class ViTMAEModel(ViTMAEPreTrainedModel):
    # 初始化函数,接受一个配置参数 config
    def __init__(self, config):
        # 调用父类的初始化函数,传入配置参数 config
        super().__init__(config)
        # 将配置参数 config 存储在对象的属性中
        self.config = config

        # 初始化 ViTMAE 模型的嵌入层和编码器
        self.embeddings = ViTMAEEmbeddings(config)
        self.encoder = ViTMAEEncoder(config)

        # 初始化 LayerNorm 层,用于规范化隐藏层的输出
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # 调用模型的后初始化方法,初始化权重并应用最终处理
        self.post_init()

    # 获取输入嵌入层的方法
    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    # 剪枝模型中的注意力头部方法
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # 遍历需要剪枝的层及其对应的注意力头部列表
        for layer, heads in heads_to_prune.items():
            # 调用编码器中特定层的注意力机制对象的剪枝方法
            self.encoder.layer[layer].attention.prune_heads(heads)

    # 前向传播方法,实现模型的推理过程
    @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ViTMAEModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        noise: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ViTMAEModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 如果 output_attentions 不为 None,则使用其值;否则使用模型配置中的 output_attentions 值

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果 output_hidden_states 不为 None,则使用其值;否则使用模型配置中的 output_hidden_states 值

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 如果 return_dict 不为 None,则使用其值;否则使用模型配置中的 use_return_dict 值

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        # 如果 pixel_values 为 None,则抛出 ValueError 异常提示需要指定 pixel_values

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        # 根据需要准备头部掩码
        # head_mask 是一个形状为 [num_hidden_layers x batch x num_heads x seq_length x seq_length] 的掩码数组

        embedding_output, mask, ids_restore = self.embeddings(pixel_values, noise=noise)
        # 将 pixel_values 通过 embeddings 方法转换为嵌入输出 embedding_output,同时生成 mask 和 ids_restore

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 使用编码器处理嵌入输出,可以选择传入头部掩码、注意力输出、隐藏状态输出和是否使用返回字典

        sequence_output = encoder_outputs[0]
        # 从编码器输出中获取序列输出

        sequence_output = self.layernorm(sequence_output)
        # 应用 layernorm 对序列输出进行归一化处理

        if not return_dict:
            return (sequence_output, mask, ids_restore) + encoder_outputs[1:]
        # 如果不使用返回字典,则返回序列输出、mask、ids_restore,以及编码器输出的其余部分

        return ViTMAEModelOutput(
            last_hidden_state=sequence_output,
            mask=mask,
            ids_restore=ids_restore,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
        # 使用 ViTMAEModelOutput 封装输出,包括最终隐藏状态、mask、ids_restore、隐藏状态数组和注意力数组
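

# 补充示例(非库中源码,仅为示意):用一个随机初始化的小配置(32x32 图像、patch 16,
# 共 4 个 patch,mask_ratio=0.75,均为假设取值)验证 ViTMAEModel 各输出的形状:
import torch
from transformers import ViTMAEConfig, ViTMAEModel

config = ViTMAEConfig(image_size=32, patch_size=16, hidden_size=32, num_hidden_layers=2,
                      num_attention_heads=2, intermediate_size=64, mask_ratio=0.75)
model = ViTMAEModel(config)

outputs = model(torch.randn(1, 3, 32, 32))
print(outputs.last_hidden_state.shape)                   # (1, 2, 32):1 个 CLS token + 1 个可见 patch
print(outputs.mask.shape, outputs.ids_restore.shape)     # (1, 4) (1, 4)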
# ViTMAE 解码器模型类,继承自 nn.Module
class ViTMAEDecoder(nn.Module):
    def __init__(self, config, num_patches):
        super().__init__()
        # 初始化解码器的嵌入层,将隐藏大小转换为解码器隐藏大小
        self.decoder_embed = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True)
        # 定义掩码令牌作为可学习参数
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
        # 初始化解码器位置嵌入,使用固定的正弦-余弦嵌入
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, config.decoder_hidden_size), requires_grad=False
        )  # fixed sin-cos embedding

        # 复制配置并调整以用于解码器
        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        # 创建解码器层列表
        self.decoder_layers = nn.ModuleList(
            [ViTMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
        )

        # 初始化解码器层归一化层
        self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
        # 定义解码器的预测线性层,将解码器隐藏大小映射为图像块的像素数和通道数
        self.decoder_pred = nn.Linear(
            config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True
        )  # encoder to decoder
        # 是否使用梯度检查点,默认为 False
        self.gradient_checkpointing = False
        # 存储模型配置
        self.config = config
        # 初始化权重
        self.initialize_weights(num_patches)

    def initialize_weights(self, num_patches):
        # 使用正弦-余弦嵌入初始化(并冻结)位置嵌入
        decoder_pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], int(num_patches**0.5), add_cls_token=True
        )
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # 使用 timm 的截断正态分布初始化掩码令牌
        # timm's trunc_normal_(std=.02) 实际上相当于 normal_(std=0.02),因为截断值太大(2.)
        torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range)

    def forward(
        self,
        hidden_states,
        ids_restore,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        # embed tokens
        # 使用解码器的嵌入层将隐藏状态转换为嵌入表示
        x = self.decoder_embed(hidden_states)

        # append mask tokens to sequence
        # 将掩码标记追加到序列中
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        # 根据恢复的标识索引重新排列张量
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        # 添加位置嵌入
        hidden_states = x + self.decoder_pos_embed

        # apply Transformer layers (blocks)
        # 应用 Transformer 层(块)
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.decoder_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # 使用梯度检查点函数处理层调用(用于内存效率)
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    None,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # normalize output using layer norm
        # 使用层归一化对隐藏状态进行标准化
        hidden_states = self.decoder_norm(hidden_states)

        # predictor projection
        # 预测器投影
        logits = self.decoder_pred(hidden_states)

        # remove cls token
        # 移除 cls 标记
        logits = logits[:, 1:, :]

        if not return_dict:
            # 如果不返回字典形式的输出,按顺序返回结果元组
            return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
        # 返回 ViTMAEDecoderOutput 对象,包含 logits、hidden_states 和 attentions
        return ViTMAEDecoderOutput(
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
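
上面 forward 中最关键的一步是把 mask token 拼接回序列,并用 ids_restore 通过 torch.gather 还原补丁的原始顺序(unshuffle)。下面是一个与库实现无关的玩具示例(张量尺寸均为假设值),单独演示这一步:

import torch

batch_size, num_patches, dim = 1, 6, 4
len_keep = 2                                        # 假设只有 2 个补丁未被掩码
ids_restore = torch.tensor([[3, 0, 4, 1, 5, 2]])    # 假设的还原索引

x = torch.randn(batch_size, 1 + len_keep, dim)      # CLS + 可见补丁的解码器嵌入
mask_token = torch.zeros(1, 1, dim)

# 追加 mask token,使序列长度恢复为 num_patches,再按 ids_restore 还原顺序
mask_tokens = mask_token.repeat(batch_size, ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)                                   # 去掉 CLS 后拼接
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, dim))     # unshuffle
x = torch.cat([x[:, :1, :], x_], dim=1)                                             # 重新加上 CLS

print(x.shape)   # torch.Size([1, 7, 4]),即 (batch, num_patches + 1, dim)
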
@add_start_docstrings(
    """The ViTMAE Model transformer with the decoder on top for self-supervised pre-training.

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>

    """,
    VIT_MAE_START_DOCSTRING,
)
class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Initialize the ViTMAE model with the provided configuration
        self.vit = ViTMAEModel(config)

        # Initialize the ViTMAE decoder using the config and number of patches from the embeddings
        self.decoder = ViTMAEDecoder(config, num_patches=self.vit.embeddings.num_patches)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Return the patch embeddings from the ViTMAE model's embeddings
        return self.vit.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model.

        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            # Prune heads in the attention mechanism of the encoder layer
            self.vit.encoder.layer[layer].attention.prune_heads(heads)

    def patchify(self, pixel_values):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
                Patchified pixel values.
        """
        patch_size, num_channels = self.config.patch_size, self.config.num_channels
        
        # Perform sanity checks on pixel values
        if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0):
            raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
        if pixel_values.shape[1] != num_channels:
            raise ValueError(
                "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
            )

        # Patchify the pixel values
        batch_size = pixel_values.shape[0]
        num_patches_one_direction = pixel_values.shape[2] // patch_size
        patchified_pixel_values = pixel_values.reshape(
            batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size
        )
        patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
        patchified_pixel_values = patchified_pixel_values.reshape(
            batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels
        )
        return patchified_pixel_values
    def unpatchify(self, patchified_pixel_values):
        """
        Args:
            patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
                Patchified pixel values.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
                Pixel values.
        """
        # 从配置中获取补丁大小和通道数
        patch_size, num_channels = self.config.patch_size, self.config.num_channels
        
        # 计算单个方向上的补丁数量
        num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5)
        
        # 检查补丁数量是否可以完全平方
        if num_patches_one_direction**2 != patchified_pixel_values.shape[1]:
            raise ValueError("Make sure that the number of patches can be squared")
        
        # 对补丁化的像素值进行重塑,以进行反补丁化
        batch_size = patchified_pixel_values.shape[0]
        patchified_pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_patches_one_direction,
            num_patches_one_direction,
            patch_size,
            patch_size,
            num_channels,
        )
        
        # 使用 `einsum` 函数重新排列张量维度,完成反补丁化
        patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
        
        # 最终重塑像素值张量,以恢复原始图像形状
        pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_channels,
            num_patches_one_direction * patch_size,
            num_patches_one_direction * patch_size,
        )
        
        return pixel_values
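
patchify 与 unpatchify 互为逆操作。下面用默认配置做一个形状与往返一致性的示意(随机初始化模型仅用于调用这两个方法,并非库中源码):

import torch
from transformers import ViTMAEConfig, ViTMAEForPreTraining

config = ViTMAEConfig()                      # 默认 image_size=224, patch_size=16, num_channels=3
model = ViTMAEForPreTraining(config)         # 随机初始化即可

pixel_values = torch.randn(2, 3, 224, 224)
patches = model.patchify(pixel_values)
print(patches.shape)                         # torch.Size([2, 196, 768]),即 (batch, (224/16)**2, 16*16*3)

reconstructed = model.unpatchify(patches)
print(torch.allclose(reconstructed, pixel_values))   # True,往返无损
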

    def forward_loss(self, pixel_values, pred, mask):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values.
            pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
                Predicted pixel values.
            mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Tensor indicating which patches are masked (1) and which are not (0).

        Returns:
            `torch.FloatTensor`: Pixel reconstruction loss.
        """
        # 对目标像素值进行补丁化
        target = self.patchify(pixel_values)
        
        # 如果配置中指定了像素归一化损失,则进行像素归一化
        if self.config.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5
        
        # 计算像素重建损失
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        
        # 根据掩码计算被移除补丁的平均损失
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        
        return loss
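
forward_loss 的逻辑可以概括为:先对每个补丁计算像素 MSE,再只在被掩码(mask==1)的补丁上取平均;norm_pix_loss 打开时目标补丁会先做均值/方差归一化。下面是一个等价的独立小示例(张量均为随机假设值):

import torch

pred = torch.randn(2, 196, 768)               # 解码器预测
target = torch.randn(2, 196, 768)             # patchify 后的目标像素
mask = (torch.rand(2, 196) > 0.25).float()    # 1 表示被掩码、需要重建的补丁

# 可选:norm_pix_loss,对每个目标补丁做均值/方差归一化
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5

loss_per_patch = ((pred - target) ** 2).mean(dim=-1)    # (batch, num_patches)
loss = (loss_per_patch * mask).sum() / mask.sum()       # 只在被掩码的补丁上平均
print(loss.item())
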

    @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ViTMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        noise: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 设置返回字典的选项,如果未指定则使用配置中的默认值

        outputs = self.vit(
            pixel_values,
            noise=noise,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 使用 Vision Transformer 处理器进行前向传播,生成模型的输出

        latent = outputs.last_hidden_state
        # 获取模型输出的最后隐藏状态作为潜变量

        ids_restore = outputs.ids_restore
        # 获取模型输出中的 ids_restore 属性,用于恢复图像补丁

        mask = outputs.mask
        # 获取模型输出中的 mask 属性,用于掩码处理

        decoder_outputs = self.decoder(latent, ids_restore)
        # 使用解码器对潜变量和 ids_restore 进行解码

        logits = decoder_outputs.logits  # shape (batch_size, num_patches, patch_size*patch_size*num_channels)
        # 从解码器的输出中获取 logits,表示重建图像的预测值

        loss = self.forward_loss(pixel_values, logits, mask)
        # 计算损失,用于评估重建图像的准确性

        if not return_dict:
            output = (logits, mask, ids_restore) + outputs[2:]
            # 如果不使用返回字典,则返回一个包含 logits、mask 和 ids_restore 的元组,以及可能的额外输出
            return ((loss,) + output) if loss is not None else output
            # 如果存在损失,则将损失加入返回结果中

        return ViTMAEForPreTrainingOutput(
            loss=loss,
            logits=logits,
            mask=mask,
            ids_restore=ids_restore,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        # 如果使用返回字典,则将损失、logits、mask、ids_restore、隐藏状态和注意力作为 ViTMAEForPreTrainingOutput 的实例返回
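
下面给出 ViTMAEForPreTraining 的端到端使用示意(非库中源码,假设网络可用、可访问 facebook/vit-mae-base 检查点):

import torch
import requests
from PIL import Image
from transformers import AutoImageProcessor, ViTMAEForPreTraining

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.loss)              # 像素重建损失(标量)
print(outputs.logits.shape)      # torch.Size([1, 196, 768])

# 注意:若 config.norm_pix_loss 为 True,logits 位于按补丁归一化后的像素空间,
# 直接 unpatchify 得到的并不是原始像素值
reconstruction = model.unpatchify(outputs.logits)
print(reconstruction.shape)      # torch.Size([1, 3, 224, 224])
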

.\models\vit_mae\__init__.py

# 引入依赖类型检查
from typing import TYPE_CHECKING

# 引入内部工具函数和异常类
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_tf_available,
    is_torch_available,
)

# 定义模块导入结构
_import_structure = {"configuration_vit_mae": ["VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMAEConfig"]}

# 检查是否支持 Torch,若不支持则抛出异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 若支持 Torch,则添加相关模型定义到导入结构中
    _import_structure["modeling_vit_mae"] = [
        "VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST",
        "ViTMAEForPreTraining",
        "ViTMAELayer",
        "ViTMAEModel",
        "ViTMAEPreTrainedModel",
    ]

# 检查是否支持 TensorFlow,若不支持则抛出异常
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 若支持 TensorFlow,则添加相关模型定义到导入结构中
    _import_structure["modeling_tf_vit_mae"] = [
        "TFViTMAEForPreTraining",
        "TFViTMAEModel",
        "TFViTMAEPreTrainedModel",
    ]

# 如果处于类型检查模式
if TYPE_CHECKING:
    # 从特定模块导入配置和模型定义
    from .configuration_vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig

    try:
        # 再次检查是否支持 Torch,若不支持则抛出异常
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 若支持 Torch,则从模型定义中导入相关类
        from .modeling_vit_mae import (
            VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST,
            ViTMAEForPreTraining,
            ViTMAELayer,
            ViTMAEModel,
            ViTMAEPreTrainedModel,
        )

    try:
        # 再次检查是否支持 TensorFlow,若不支持则抛出异常
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 若支持 TensorFlow,则从模型定义中导入相关类
        from .modeling_tf_vit_mae import TFViTMAEForPreTraining, TFViTMAEModel, TFViTMAEPreTrainedModel

# 如果不处于类型检查模式
else:
    # 动态创建一个懒加载模块,用于按需导入所需的模型和配置
    import sys
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
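
懒加载的效果可以用下面的小示意验证(假设已安装 PyTorch):真正的 modeling_vit_mae 模块只有在第一次访问对应属性时才会被导入,单纯 import 包本身开销很小。

from transformers.models import vit_mae

print(vit_mae.ViTMAEConfig().model_type)    # "vit_mae",此时只需要导入 configuration_vit_mae
model_cls = vit_mae.ViTMAEForPreTraining    # 访问该属性才触发 modeling_vit_mae(依赖 torch)的导入
print(model_cls.__name__)                   # "ViTMAEForPreTraining"
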

.\models\vit_msn\configuration_vit_msn.py

# coding=utf-8
# 指定文件编码格式为UTF-8

# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
# 版权声明,版权归Facebook AI和HuggingFace Inc.团队所有

# Licensed under the Apache License, Version 2.0 (the "License");
# 根据Apache License 2.0许可协议授权使用该文件

# you may not use this file except in compliance with the License.
# 除非遵守许可协议,否则不能使用该文件

# You may obtain a copy of the License at
# 您可以在此处获取许可协议的副本

#     http://www.apache.org/licenses/LICENSE-2.0
#     http://www.apache.org/licenses/LICENSE-2.0 的链接

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 没有任何形式的明示或暗示的担保或条件,包括但不限于

# See the License for the specific language governing permissions and
# limitations under the License.
# 许可协议详细说明了授权的特定语言和限制条件

""" ViT MSN model configuration"""
# ViT MSN模型的配置信息

# Import necessary libraries
# 导入必要的库
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# Get logger instance
# 获取日志记录器实例
logger = logging.get_logger(__name__)

# Dictionary mapping model names to their respective config.json file URLs
# 字典,将模型名称映射到其相应的config.json文件的URL
VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "sayakpaul/vit-msn-base": "https://huggingface.co/sayakpaul/vit-msn-base/resolve/main/config.json",
    # See all ViT MSN models at https://huggingface.co/models?filter=vit_msn
}

# Configuration class for ViT MSN model inheriting PretrainedConfig
# ViT MSN模型的配置类,继承自PretrainedConfig
class ViTMSNConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ViTMSNModel`]. It is used to instantiate an ViT
    MSN model according to the specified arguments, defining the model architecture. Instantiating a configuration with
    the defaults will yield a similar configuration to that of the ViT
    [facebook/vit_msn_base](https://huggingface.co/facebook/vit_msn_base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """
    # 这是用于存储ViTMSNModel配置的配置类。根据指定的参数实例化ViT MSN模型,定义模型架构。
    # 使用默认参数实例化配置将产生与ViT facebook/vit_msn_base架构类似的配置。

    # Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
    # 配置对象继承自[`PretrainedConfig`],可用于控制模型输出。阅读[`PretrainedConfig`]的文档获取更多信息。
    # 设定模型类型为 "vit_msn"
    model_type = "vit_msn"

    # 定义初始化方法,接受多个可选参数
    def __init__(
        self,
        hidden_size=768,  # 编码器层和池化层的维度大小,默认为768
        num_hidden_layers=12,  # Transformer 编码器中隐藏层的数量,默认为12
        num_attention_heads=12,  # Transformer 编码器中每个注意力层的注意头数量,默认为12
        intermediate_size=3072,  # Transformer 编码器中"中间"(即前馈)层的维度,默认为3072
        hidden_act="gelu",  # 编码器和池化器中的非线性激活函数,默认为"gelu"
        hidden_dropout_prob=0.0,  # 嵌入层、编码器和池化器中所有全连接层的dropout概率,默认为0.0
        attention_probs_dropout_prob=0.0,  # 注意力概率的dropout比率,默认为0.0
        initializer_range=0.02,  # 用于初始化所有权重矩阵的截断正态分布的标准差,默认为0.02
        layer_norm_eps=1e-06,  # 层归一化层使用的 epsilon,默认为1e-06
        image_size=224,  # 每个图像的大小(分辨率),默认为224
        patch_size=16,  # 每个补丁的大小(分辨率),默认为16
        num_channels=3,  # 输入通道的数量,默认为3
        qkv_bias=True,  # 是否向查询、键和值中添加偏置,默认为True
        **kwargs,  # 其他可选参数
    ):
        # 调用父类的初始化方法,传递所有关键字参数
        super().__init__(**kwargs)

        # 设置隐藏层大小
        self.hidden_size = hidden_size
        # 设置隐藏层数量
        self.num_hidden_layers = num_hidden_layers
        # 设置注意力头的数量
        self.num_attention_heads = num_attention_heads
        # 设置中间层大小
        self.intermediate_size = intermediate_size
        # 设置隐藏层激活函数
        self.hidden_act = hidden_act
        # 设置隐藏层的dropout概率
        self.hidden_dropout_prob = hidden_dropout_prob
        # 设置注意力概率的dropout概率
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        # 设置初始化范围
        self.initializer_range = initializer_range
        # 设置层归一化的epsilon值
        self.layer_norm_eps = layer_norm_eps
        # 设置图像大小
        self.image_size = image_size
        # 设置patch(补丁)的大小
        self.patch_size = patch_size
        # 设置通道数
        self.num_channels = num_channels
        # 设置qkv偏置
        self.qkv_bias = qkv_bias

.\models\vit_msn\convert_msn_to_pytorch.py

# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ViT MSN checkpoints from the original repository: https://github.com/facebookresearch/msn"""

import argparse  # 导入解析命令行参数的库
import json  # 导入处理 JSON 格式数据的库

import requests  # 导入进行 HTTP 请求的库
import torch  # 导入 PyTorch 深度学习库
from huggingface_hub import hf_hub_download  # 导入从 Hugging Face Hub 下载模型的功能
from PIL import Image  # 导入 Python Imaging Library,用于处理图像

from transformers import ViTImageProcessor, ViTMSNConfig, ViTMSNModel  # 导入用于处理 ViT 模型的相关类
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD  # 导入图像处理相关的常量


torch.set_grad_enabled(False)  # 禁用梯度计算

# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, base_model=False):
    rename_keys = []  # 初始化空列表,用于存储重命名的键值对
    for i in range(config.num_hidden_layers):
        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
        rename_keys.append((f"module.blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
        rename_keys.append((f"module.blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
        rename_keys.append(
            (f"module.blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight")
        )
        rename_keys.append((f"module.blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
        rename_keys.append((f"module.blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
        rename_keys.append((f"module.blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
        rename_keys.append((f"module.blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
        rename_keys.append((f"module.blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
        rename_keys.append((f"module.blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
        rename_keys.append((f"module.blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))

    # projection layer + position embeddings
    rename_keys.extend(
        [
            ("module.cls_token", "vit.embeddings.cls_token"),
            ("module.patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
            ("module.patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
            ("module.pos_embed", "vit.embeddings.position_embeddings"),
        ]
    )
    # 如果存在基础模型,则执行以下操作
    if base_model:
        # 将以下键值对添加到 rename_keys 列表中,用于重命名
        rename_keys.extend(
            [
                ("module.norm.weight", "layernorm.weight"),  # 将 "module.norm.weight" 重命名为 "layernorm.weight"
                ("module.norm.bias", "layernorm.bias"),      # 将 "module.norm.bias" 重命名为 "layernorm.bias"
            ]
        )

        # 如果只有基础模型,需要从所有以 "vit" 开头的键中删除 "vit"
        rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
    else:
        # 如果没有基础模型,则执行以下操作
        # 将以下键值对添加到 rename_keys 列表中,用于重命名
        rename_keys.extend(
            [
                ("norm.weight", "vit.layernorm.weight"),   # 将 "norm.weight" 重命名为 "vit.layernorm.weight"
                ("norm.bias", "vit.layernorm.bias"),       # 将 "norm.bias" 重命名为 "vit.layernorm.bias"
                ("head.weight", "classifier.weight"),      # 将 "head.weight" 重命名为 "classifier.weight"
                ("head.bias", "classifier.bias"),          # 将 "head.bias" 重命名为 "classifier.bias"
            ]
        )

    # 返回重命名后的键值对列表
    return rename_keys
# 将每个编码器层的权重矩阵分割为查询(query)、键(key)和值(value)
def read_in_q_k_v(state_dict, config, base_model=False):
    # 遍历每个编码器层
    for i in range(config.num_hidden_layers):
        if base_model:
            prefix = ""
        else:
            prefix = "vit."
        
        # 读取输入投影层的权重和偏置(在 timm 中,这是一个单独的矩阵加上偏置)
        in_proj_weight = state_dict.pop(f"module.blocks.{i}.attn.qkv.weight")
        in_proj_bias = state_dict.pop(f"module.blocks.{i}.attn.qkv.bias")
        
        # 将查询(query)、键(key)和值(value)依次添加到状态字典中
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
            : config.hidden_size, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
            config.hidden_size : config.hidden_size * 2, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
            config.hidden_size : config.hidden_size * 2
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
            -config.hidden_size :, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
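
read_in_q_k_v 的核心是把融合的 qkv 权重(形状为 (3*hidden_size, hidden_size))按行切成 query/key/value 三份。下面用一个小的玩具张量演示同样的切片方式(与上面函数中的切片逻辑一致,数值为假设):

import torch

hidden_size = 4
in_proj_weight = torch.arange(3 * hidden_size * hidden_size, dtype=torch.float32).reshape(
    3 * hidden_size, hidden_size
)

q_w = in_proj_weight[:hidden_size, :]                      # 前 hidden_size 行 -> query
k_w = in_proj_weight[hidden_size : hidden_size * 2, :]     # 中间 hidden_size 行 -> key
v_w = in_proj_weight[-hidden_size:, :]                     # 最后 hidden_size 行 -> value
print(q_w.shape, k_w.shape, v_w.shape)                     # 均为 torch.Size([4, 4])
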


# 从状态字典中移除分类头部的权重和偏置
def remove_classification_head_(state_dict):
    ignore_keys = ["head.weight", "head.bias"]
    for k in ignore_keys:
        state_dict.pop(k, None)


# 从状态字典中移除投影头部的相关键
def remove_projection_head(state_dict):
    # 投影头部在自监督预训练中使用,但在下游任务中不需要
    ignore_keys = [
        "module.fc.fc1.weight",
        "module.fc.fc1.bias",
        "module.fc.bn1.weight",
        "module.fc.bn1.bias",
        "module.fc.bn1.running_mean",
        "module.fc.bn1.running_var",
        "module.fc.bn1.num_batches_tracked",
        "module.fc.fc2.weight",
        "module.fc.fc2.bias",
        "module.fc.bn2.weight",
        "module.fc.bn2.bias",
        "module.fc.bn2.running_mean",
        "module.fc.bn2.running_var",
        "module.fc.bn2.num_batches_tracked",
        "module.fc.fc3.weight",
        "module.fc.fc3.bias",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)


# 将字典中的键从旧名称重命名为新名称
def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


# 将 ViT-MSN 模型的检查点转换为 PyTorch 模型
def convert_vit_msn_checkpoint(checkpoint_url, pytorch_dump_folder_path):
    config = ViTMSNConfig()
    config.num_labels = 1000

    repo_id = "datasets/huggingface/label-files"
    filename = "imagenet-1k-id2label.json"
    
    # 从 HF Hub 下载 imagenet-1k-id2label.json 文件,并加载为 id 到 label 的映射字典
    id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}
    # 根据 checkpoint_url 的内容设置不同的配置参数
    if "s16" in checkpoint_url:
        # 如果包含 "s16",设置较小的隐藏层大小、中间层大小和注意力头数
        config.hidden_size = 384
        config.intermediate_size = 1536
        config.num_attention_heads = 6
    elif "l16" in checkpoint_url:
        # 如果包含 "l16",设置较大的隐藏层大小、中间层大小、层数、注意力头数和隐藏层的 dropout 概率
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16
        config.hidden_dropout_prob = 0.1
    elif "b4" in checkpoint_url:
        # 如果包含 "b4",设置较小的图像块大小
        config.patch_size = 4
    elif "l7" in checkpoint_url:
        # 如果包含 "l7",设置较大的图像块大小、较大的隐藏层大小、中间层大小、层数、注意力头数和隐藏层的 dropout 概率
        config.patch_size = 7
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16
        config.hidden_dropout_prob = 0.1

    # 使用配置参数初始化 ViTMSNModel 模型
    model = ViTMSNModel(config)

    # 从指定的 URL 加载预训练模型的状态字典
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["target_encoder"]

    # 创建图像处理器对象,设置图像大小为 config.image_size
    image_processor = ViTImageProcessor(size=config.image_size)

    # 移除模型状态字典中的投影头部分
    remove_projection_head(state_dict)
    
    # 根据配置创建新的键名映射列表
    rename_keys = create_rename_keys(config, base_model=True)

    # 对状态字典中的键名进行重命名
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    
    # 读取状态字典中的查询、键、值信息,针对基础模型
    read_in_q_k_v(state_dict, config, base_model=True)

    # 加载模型的状态字典
    model.load_state_dict(state_dict)
    # 设置模型为评估模式
    model.eval()

    # 设置图像 URL
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"

    # 使用 requests 库获取图像数据流,并用 PIL 库打开图像
    image = Image.open(requests.get(url, stream=True).raw)
    
    # 创建图像处理器对象,设置图像大小、图像均值和标准差
    image_processor = ViTImageProcessor(
        size=config.image_size, image_mean=IMAGENET_DEFAULT_MEAN, image_std=IMAGENET_DEFAULT_STD
    )
    
    # 对输入图像进行处理,返回 PyTorch 张量
    inputs = image_processor(images=image, return_tensors="pt")

    # 执行前向传播
    torch.manual_seed(2)
    outputs = model(**inputs)
    # 获取最后一层隐藏状态
    last_hidden_state = outputs.last_hidden_state

    # 验证预测的对数值是否接近预期值
    if "s16" in checkpoint_url:
        expected_slice = torch.tensor([[-1.0915, -1.4876, -1.1809]])
    elif "b16" in checkpoint_url:
        expected_slice = torch.tensor([[14.2889, -18.9045, 11.7281]])
    elif "l16" in checkpoint_url:
        expected_slice = torch.tensor([[41.5028, -22.8681, 45.6475]])
    elif "b4" in checkpoint_url:
        expected_slice = torch.tensor([[-4.3868, 5.2932, -0.4137]])
    else:
        expected_slice = torch.tensor([[-0.1792, -0.6465, 2.4263]])

    # 使用 assert 验证张量的所有元素是否在指定的误差范围内接近预期的值
    assert torch.allclose(last_hidden_state[:, 0, :3], expected_slice, atol=1e-4)

    # 打印模型保存的路径
    print(f"Saving model to {pytorch_dump_folder_path}")
    # 将模型保存到指定路径
    model.save_pretrained(pytorch_dump_folder_path)

    # 打印图像处理器保存的路径
    print(f"Saving image processor to {pytorch_dump_folder_path}")
    # 将图像处理器保存到指定路径
    image_processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
    # 如果当前脚本作为主程序运行,则执行以下代码
    parser = argparse.ArgumentParser()
    # 创建命令行参数解析器对象

    # 必选参数
    parser.add_argument(
        "--checkpoint_url",
        default="https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar",
        type=str,
        help="URL of the checkpoint you'd like to convert."
    )
    # 添加命令行参数,指定模型检查点的下载链接,默认为 Facebook 提供的一个预训练模型的链接

    parser.add_argument(
        "--pytorch_dump_folder_path", 
        default=None, 
        type=str, 
        help="Path to the output PyTorch model directory."
    )
    # 添加命令行参数,指定输出的 PyTorch 模型保存目录的路径,默认为 None,即没有指定路径

    # 解析命令行参数,并将其存储在 args 变量中
    args = parser.parse_args()

    # 调用 convert_vit_msn_checkpoint 函数,传入命令行参数中指定的模型下载链接和保存路径
    convert_vit_msn_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
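
作为参考,下面是在 Python 侧直接调用转换函数的示意(输出目录为假设值,需要网络可用才能下载检查点):

convert_vit_msn_checkpoint(
    checkpoint_url="https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar",
    pytorch_dump_folder_path="./vit-msn-small",   # 假设的输出目录
)
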

.\models\vit_msn\modeling_vit_msn.py

# coding=utf-8
# 版权所有 2022 年 Facebook AI 和 HuggingFace Inc. 团队。保留所有权利。
#
# 根据 Apache 许可证 2.0 版本("许可证")许可;
# 除非符合许可证要求,否则不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件基于"原样"提供,
# 没有任何形式的明示或暗示的保证或条件。
# 有关特定语言的权限,请参阅许可证。
""" PyTorch ViT MSN(masked siamese network)模型。"""

# 导入必要的库
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# 从本地库中导入相关函数和类
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_vit_msn import ViTMSNConfig

# 获取日志记录器
logger = logging.get_logger(__name__)

# 文档用配置和检查点
_CONFIG_FOR_DOC = "ViTMSNConfig"
_CHECKPOINT_FOR_DOC = "facebook/vit-msn-small"
VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/vit-msn-small",
    # 查看所有 ViTMSN 模型 https://huggingface.co/models?filter=vit_msn
]

class ViTMSNEmbeddings(nn.Module):
    """
    构建 CLS 令牌、位置和补丁嵌入。可选地,也包括掩码令牌。
    """

    def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        # 初始化 CLS 令牌参数
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        
        # 如果使用掩码令牌,则初始化掩码令牌参数
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
        
        # 初始化补丁嵌入层
        self.patch_embeddings = ViTMSNPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        
        # 初始化位置嵌入参数
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        
        # 初始化 dropout 层
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        
        # 保存配置信息
        self.config = config
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        # 计算当前嵌入张量中的图块数量和预训练位置编码的位置数量
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # 如果图块数量与位置数量相等,并且高度与宽度相同,则直接返回预训练的位置编码
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        
        # 从预训练的位置编码中提取类别位置编码和图块位置编码
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]

        # 获取张量的维度信息
        dim = embeddings.shape[-1]

        # 计算图块窗口的高度和宽度
        patch_window_height = height // self.config.patch_size
        patch_window_width = width // self.config.patch_size

        # 为了避免插值时的浮点数误差,向高度和宽度添加一个小数值
        patch_window_height, patch_window_width = patch_window_height + 0.1, patch_window_width + 0.1

        # 将图块位置编码重塑为合适的形状,并进行维度置换
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        # 使用双三次插值对图块位置编码进行插值
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                patch_window_height / math.sqrt(num_positions),
                patch_window_width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )

        # 再次进行维度置换和重塑,以便与类别位置编码拼接
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        # 返回拼接后的位置编码张量
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        # 获取输入张量的形状信息
        batch_size, num_channels, height, width = pixel_values.shape
        # 使用 patch_embeddings 方法将像素值转换为嵌入向量
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        if bool_masked_pos is not None:
            # 获取嵌入向量的序列长度
            seq_length = embeddings.shape[1]
            # 扩展 mask_token 到与 embeddings 相同的形状,用于替换被遮盖的视觉 tokens
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # 创建一个掩码,将布尔类型的遮盖位置转换为与 mask_tokens 相同类型的张量
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            # 使用 mask 对 embeddings 进行覆盖处理,替换遮盖位置的 tokens
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # 将 [CLS] token 添加到嵌入的 patch tokens 中
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        # 在第一维度上连接 cls_tokens 和 embeddings
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # 添加位置编码到每个 token
        if interpolate_pos_encoding:
            # 使用 interpolate_pos_encoding 方法对 embeddings 进行插值处理
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            # 直接添加预先计算好的位置编码到 embeddings
            embeddings = embeddings + self.position_embeddings

        # 对 embeddings 应用 dropout 操作
        embeddings = self.dropout(embeddings)

        # 返回最终的嵌入向量张量
        return embeddings
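
interpolate_pos_encoding 的作用是让模型可以处理与预训练分辨率不同的输入。下面用一个随机初始化的小模型做形状示意(配置参数均为假设的缩小版,无需下载权重):

import torch
from transformers import ViTMSNConfig, ViTMSNModel

config = ViTMSNConfig(
    image_size=224, patch_size=16, hidden_size=96,
    num_hidden_layers=2, num_attention_heads=4, intermediate_size=192,
)
model = ViTMSNModel(config)

with torch.no_grad():
    out_224 = model(torch.randn(1, 3, 224, 224))
    out_384 = model(torch.randn(1, 3, 384, 384), interpolate_pos_encoding=True)

print(out_224.last_hidden_state.shape)   # torch.Size([1, 197, 96]):14*14 个补丁 + CLS
print(out_384.last_hidden_state.shape)   # torch.Size([1, 577, 96]):24*24 个补丁 + CLS
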
# 从transformers.models.vit.modeling_vit.ViTPatchEmbeddings复制而来,修改为ViTMSN的实现
class ViTMSNPatchEmbeddings(nn.Module):
    """
    这个类将形状为`(batch_size, num_channels, height, width)`的`pixel_values`转换为形状为`(batch_size, seq_length, hidden_size)`的初始隐藏状态(patch embeddings),
    以供Transformer使用。
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        # 将image_size和patch_size转换为元组(tuple),如果它们不是可迭代对象
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        
        # 计算patch的数量
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        # 使用Conv2d进行投影,将输入的num_channels维度转换为hidden_size维度
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        
        # 检查输入的像素值是否与配置中的num_channels匹配
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        
        # 如果不插值位置编码,检查输入图像的尺寸是否与配置中的image_size匹配
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )
        
        # 对输入的像素值进行投影,并将结果展平和转置,以生成patch embeddings
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        
        return embeddings


# 从transformers.models.vit.modeling_vit.ViTSelfAttention复制而来,修改为ViTMSN的实现
class ViTMSNSelfAttention(nn.Module):
    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        # 检查隐藏层大小是否可以被注意力头数整除,并且配置中没有嵌入大小的属性
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            # 如果不符合条件,抛出数值错误异常
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        # 初始化注意力头数和每个头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 初始化查询、键、值的线性层
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        # 初始化注意力概率的Dropout
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # 调整张量形状以便计算注意力得分
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # 计算混合查询层
        mixed_query_layer = self.query(hidden_states)

        # 计算键和值的转置以便计算注意力得分
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # 计算"查询"和"键"之间的点积,得到原始注意力分数
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # 将注意力分数除以sqrt(注意力头的大小)进行缩放
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # 对注意力分数进行归一化,得到注意力概率
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # 应用Dropout到注意力概率上,实际上是以一定概率将整个token置零以进行注意
        attention_probs = self.dropout(attention_probs)

        # 如果有头部遮罩,应用头部遮罩到注意力概率上
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # 计算上下文层,将注意力概率乘以值层
        context_layer = torch.matmul(attention_probs, value_layer)

        # 调整上下文层的形状以适应输出
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        # 根据需要返回上下文层和注意力概率,或者仅返回上下文层
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
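
上面的 forward 就是标准的缩放点积注意力。下面用一个独立的玩具示例梳理各步的张量形状(尺寸均为假设值):

import math
import torch

batch, heads, seq_len, head_dim = 2, 6, 197, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)   # (batch, heads, seq, seq)
probs = torch.softmax(scores, dim=-1)                                 # 注意力概率
context = torch.matmul(probs, v)                                      # (batch, heads, seq, head_dim)

# 合并注意力头,得到 (batch, seq, hidden_size)
context = context.permute(0, 2, 1, 3).reshape(batch, seq_len, heads * head_dim)
print(context.shape)   # torch.Size([2, 197, 384])
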
# 从 transformers.models.vit.modeling_vit.ViTSelfOutput 复制并修改为 ViT->ViTMSN
class ViTMSNSelfOutput(nn.Module):
    """
    The residual connection is defined in ViTMSNLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """
    
    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        # 定义一个全连接层,输入和输出维度都为 config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 定义一个 Dropout 层,使用 config.hidden_dropout_prob 的概率进行随机失活
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 将输入的 hidden_states 应用全连接层 self.dense
        hidden_states = self.dense(hidden_states)
        # 对全连接层的输出应用 Dropout
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# 从 transformers.models.vit.modeling_vit.ViTAttention 复制并修改为 ViT->ViTMSN
class ViTMSNAttention(nn.Module):
    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        # 初始化 ViTMSNSelfAttention 和 ViTMSNSelfOutput 层
        self.attention = ViTMSNSelfAttention(config)
        self.output = ViTMSNSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        # 找到可裁剪的注意力头部并索引
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # 裁剪线性层
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # 更新超参数并存储已裁剪的头部
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # 将输入的 hidden_states 通过注意力层 self.attention 进行处理
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        # 将注意力层的输出通过 self.output 层进行处理
        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # 如果有需要,添加注意力信息到输出中
        return outputs


# 从 transformers.models.vit.modeling_vit.ViTIntermediate 复制并修改为 ViT->ViTMSN
class ViTMSNIntermediate(nn.Module):
    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        # 定义一个全连接层,输入维度为 config.hidden_size,输出维度为 config.intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # 根据配置选择激活函数,存储在 self.intermediate_act_fn 中
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
    # 定义一个前向传播方法,接受隐藏状态作为输入张量,并返回处理后的张量
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 使用全连接层对隐藏状态进行线性变换
        hidden_states = self.dense(hidden_states)
        # 对线性变换后的结果应用激活函数,例如ReLU等
        hidden_states = self.intermediate_act_fn(hidden_states)

        # 返回处理后的隐藏状态张量作为输出
        return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->ViTMSN
class ViTMSNOutput(nn.Module):
    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        # 定义一个全连接层,将输入特征维度转换为配置中指定的隐藏层大小
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # 定义一个dropout层,用于随机置零输入张量的部分元素,以防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 将输入的隐藏状态通过全连接层进行线性变换
        hidden_states = self.dense(hidden_states)
        # 对变换后的隐藏状态进行dropout处理
        hidden_states = self.dropout(hidden_states)

        # 将dropout后的隐藏状态与输入张量相加,实现残差连接
        hidden_states = hidden_states + input_tensor

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN
class ViTMSNLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        # 定义块大小用于分块前馈网络的处理
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # 序列长度维度,默认为1,用于处理自注意力
        self.seq_len_dim = 1
        # 定义注意力层,使用ViTMSNAttention类处理自注意力机制
        self.attention = ViTMSNAttention(config)
        # 定义中间层,使用ViTMSNIntermediate类处理中间层操作
        self.intermediate = ViTMSNIntermediate(config)
        # 定义输出层,使用ViTMSNOutput类处理输出层操作
        self.output = ViTMSNOutput(config)
        # 定义前层归一化层,使用LayerNorm对隐藏状态进行归一化
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 定义后层归一化层,同样使用LayerNorm对隐藏状态进行归一化
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # 在ViTMSN中,先对隐藏状态进行前层归一化
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # 如果输出注意力权重,将其加入到输出元组中

        # 第一个残差连接
        hidden_states = attention_output + hidden_states

        # 在ViTMSN中,也会在自注意力后进行后层归一化
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # 第二个残差连接
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMSN
class ViTMSNEncoder(nn.Module):
    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        self.config = config
        # 使用ViTMSNLayer构建编码器的多层堆叠
        self.layer = nn.ModuleList([ViTMSNLayer(config) for _ in range(config.num_hidden_layers)])
        # 是否使用梯度检查点,默认为False
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        # 如果不输出隐藏状态,则初始化一个空元组
        all_hidden_states = () if output_hidden_states else None
        # 如果不输出注意力权重,则初始化一个空元组
        all_self_attentions = () if output_attentions else None

        # 遍历每个 Transformer 层
        for i, layer_module in enumerate(self.layer):
            # 如果需要输出隐藏状态,则将当前层的隐藏状态加入到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 获取当前层的头部掩码
            layer_head_mask = head_mask[i] if head_mask is not None else None

            # 如果启用了梯度检查点并且处于训练状态
            if self.gradient_checkpointing and self.training:
                # 使用梯度检查点函数进行前向传播计算
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                # 普通情况下直接调用当前层的前向传播方法
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            # 更新隐藏状态为当前层的输出的第一个元素(通常是最终隐藏状态)
            hidden_states = layer_outputs[0]

            # 如果需要输出注意力权重,则将当前层的注意力权重加入到 all_self_attentions 中
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # 如果需要输出隐藏状态,则将最终的隐藏状态加入到 all_hidden_states 中
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果不以字典形式返回结果,则返回所有非空元素的元组
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        # 以 BaseModelOutput 对象形式返回结果,包含最终隐藏状态、所有隐藏状态、所有注意力权重
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
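
打开 output_hidden_states / output_attentions 时,hidden_states 会包含 num_hidden_layers + 1 个张量(嵌入输出加上每一层的输出),attentions 包含 num_hidden_layers 个张量。可以用一个随机初始化的小模型验证(配置参数为假设的缩小版):

import torch
from transformers import ViTMSNConfig, ViTMSNModel

config = ViTMSNConfig(
    image_size=32, patch_size=16, hidden_size=64,
    num_hidden_layers=4, num_attention_heads=4, intermediate_size=128,
)
model = ViTMSNModel(config)

with torch.no_grad():
    outputs = model(torch.randn(1, 3, 32, 32), output_hidden_states=True, output_attentions=True)

print(len(outputs.hidden_states))   # 5 == num_hidden_layers + 1
print(len(outputs.attentions))      # 4
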
class ViTMSNPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 使用 ViTMSNConfig 作为模型配置类
    config_class = ViTMSNConfig
    # 模型基础名称前缀
    base_model_prefix = "vit"
    # 主要输入名称为 pixel_values
    main_input_name = "pixel_values"
    # 支持梯度检查点
    supports_gradient_checkpointing = True

    # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
    # when creating pre-training scripts.
    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # 与 TF 版本略有不同,TF 使用截断正态分布进行初始化
            # 参考 https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # 初始化 LayerNorm 的权重
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


VIT_MSN_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`ViTMSNConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

VIT_MSN_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# 为 ViTMSNModel 类添加文档字符串,描述该模型输出原始隐藏状态而不带特定的输出头部
@add_start_docstrings(
    "The bare ViTMSN Model outputting raw hidden-states without any specific head on top.",
    VIT_MSN_START_DOCSTRING,
)
class ViTMSNModel(ViTMSNPreTrainedModel):
    def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False):
        # 调用父类的初始化方法,传入配置对象
        super().__init__(config)
        # 保存配置对象
        self.config = config

        # 初始化嵌入层对象
        self.embeddings = ViTMSNEmbeddings(config, use_mask_token=use_mask_token)
        # 初始化编码器对象
        self.encoder = ViTMSNEncoder(config)

        # 初始化 LayerNorm 层,用于归一化隐藏状态向量
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # 调用后处理方法,用于权重初始化和最终处理
        self.post_init()

    def get_input_embeddings(self) -> ViTMSNPatchEmbeddings:
        # 返回嵌入层的 patch_embeddings 属性,用于获取输入嵌入
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # 遍历需要剪枝的层和头部的字典
        for layer, heads in heads_to_prune.items():
            # 在编码器的指定层中,调用注意力头部的剪枝方法
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 设置是否输出注意力权重,默认为模型配置中的输出设置
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 设置是否输出隐藏状态,默认为模型配置中的输出设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 设置是否返回字典格式的输出,默认为模型配置中的设置使用返回字典

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        # 如果输入的像素值为 None,则抛出值错误异常

        # Prepare head mask if needed
        # 准备需要的头部掩码
        # 1.0 in head_mask indicate we keep the head
        # head_mask 中的 1.0 表示我们保留该头部
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        # 根据输入的头部掩码参数获取头部掩码,确保其形状符合模型的隐藏层数量和序列长度

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )
        # 将像素值输入到嵌入层,根据 bool_masked_pos 和 interpolate_pos_encoding 参数进行相应的处理

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 将嵌入输出传入编码器,获取编码器的输出结果

        sequence_output = encoder_outputs[0]
        # 从编码器输出中获取序列输出
        sequence_output = self.layernorm(sequence_output)
        # 序列输出经过 LayerNorm 处理

        if not return_dict:
            head_outputs = (sequence_output,)
            return head_outputs + encoder_outputs[1:]
        # 如果不要求返回字典格式,则返回头部输出和编码器其他输出

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
        # 返回模型的基础输出,包括最后的隐藏状态、隐藏状态列表和注意力权重列表
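
ViTMSNModel 的典型特征提取用法如下(非库中源码,假设网络可用、可访问 facebook/vit-msn-small 检查点):

import torch
import requests
from PIL import Image
from transformers import AutoImageProcessor, ViTMSNModel

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
processor = AutoImageProcessor.from_pretrained("facebook/vit-msn-small")
model = ViTMSNModel.from_pretrained("facebook/vit-msn-small")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)   # torch.Size([1, 197, 384])
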
# 注意:我们尚未为分类头部准备权重。此类用于希望对基础模型(ViTMSNModel)进行微调的用户。
@add_start_docstrings(
    """
    在顶部具有图像分类头的 ViTMSN 模型,例如用于 ImageNet。
    """,
    VIT_MSN_START_DOCSTRING,
)
class ViTMSNForImageClassification(ViTMSNPreTrainedModel):
    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vit = ViTMSNModel(config)

        # 分类器头部
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # 初始化权重并应用最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,