Transformers-Source-Code-Walkthrough-49

Transformers Source Code Walkthrough (Part 49)

.\models\flaubert\__init__.py

# Import the modules and functions required for type checking and lazy loading
from typing import TYPE_CHECKING

# Import the custom exception class, the lazy-module helper, and the backend availability checks
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available

# Define the module's import structure
_import_structure = {
    "configuration_flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig", "FlaubertOnnxConfig"],
    "tokenization_flaubert": ["FlaubertTokenizer"],
}

# Check whether PyTorch is available; if not, raise the custom exception
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If PyTorch is available, add the PyTorch Flaubert model classes to the import structure
    _import_structure["modeling_flaubert"] = [
        "FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "FlaubertForMultipleChoice",
        "FlaubertForQuestionAnswering",
        "FlaubertForQuestionAnsweringSimple",
        "FlaubertForSequenceClassification",
        "FlaubertForTokenClassification",
        "FlaubertModel",
        "FlaubertWithLMHeadModel",
        "FlaubertPreTrainedModel",
    ]

# Check whether TensorFlow is available; if not, raise the custom exception
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If TensorFlow is available, add the TensorFlow Flaubert model classes to the import structure
    _import_structure["modeling_tf_flaubert"] = [
        "TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFFlaubertForMultipleChoice",
        "TFFlaubertForQuestionAnsweringSimple",
        "TFFlaubertForSequenceClassification",
        "TFFlaubertForTokenClassification",
        "TFFlaubertModel",
        "TFFlaubertPreTrainedModel",
        "TFFlaubertWithLMHeadModel",
    ]

# During type checking, import the configuration, tokenizer, and model classes directly
if TYPE_CHECKING:
    from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig, FlaubertOnnxConfig
    from .tokenization_flaubert import FlaubertTokenizer

    # Check whether PyTorch is available; if so, import the PyTorch Flaubert model classes
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flaubert import (
            FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            FlaubertForMultipleChoice,
            FlaubertForQuestionAnswering,
            FlaubertForQuestionAnsweringSimple,
            FlaubertForSequenceClassification,
            FlaubertForTokenClassification,
            FlaubertModel,
            FlaubertPreTrainedModel,
            FlaubertWithLMHeadModel,
        )

    # Check whether TensorFlow is available; if so, import the TensorFlow Flaubert model classes
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Import the required classes from the modeling_tf_flaubert submodule of this package
        from .modeling_tf_flaubert import (
            TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,  # list of pretrained model archives
            TFFlaubertForMultipleChoice,  # multiple-choice head
            TFFlaubertForQuestionAnsweringSimple,  # simple question-answering head
            TFFlaubertForSequenceClassification,  # sequence-classification head
            TFFlaubertForTokenClassification,  # token-classification head
            TFFlaubertModel,  # base Flaubert model
            TFFlaubertPreTrainedModel,  # base class for pretrained Flaubert models
            TFFlaubertWithLMHeadModel,  # language-modeling head
        )
else:
    import sys

    # Replace this module's entry in sys.modules with a lazy-loading module object (_LazyModule).
    # _LazyModule receives the module name (__name__), the module's file (__file__), the import
    # structure (_import_structure), and the module spec (__spec__); submodules are only imported
    # when one of their attributes is first accessed.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
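# Side note (not part of the library file above): the deferred-import behaviour provided by `_LazyModule`
# can be illustrated with a small, self-contained sketch. `LazyModule`, `configuration_demo`, and `DemoConfig`
# below are made-up names for illustration only; the real implementation lives in `transformers.utils`.
import importlib
import types


class LazyModule(types.ModuleType):
    """Minimal stand-in for transformers' _LazyModule: a submodule is only imported
    the first time one of its exported attributes is accessed."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported attribute back to the submodule that defines it.
        self._attr_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }
        self.__all__ = list(self._attr_to_module)

    def __getattr__(self, name):
        if name not in self._attr_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        submodule = importlib.import_module(f".{self._attr_to_module[name]}", self.__name__)
        value = getattr(submodule, name)
        setattr(self, name, value)  # cache the attribute so __getattr__ is not hit again
        return value


# Hypothetical usage, mirroring the last line of the real __init__.py:
# sys.modules[__name__] = LazyModule(__name__, {"configuration_demo": ["DemoConfig"]})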

.\models\flava\configuration_flava.py

# coding=utf-8
# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" FLAVA model configurations"""

import os  # filesystem/path utilities
from typing import Any, Dict, Union  # typing helpers

from ...configuration_utils import PretrainedConfig  # base class for all model configurations
from ...utils import logging  # transformers logging utilities

# Logger for this module
logger = logging.get_logger(__name__)

FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/flava-full": "https://huggingface.co/facebook/flava-full/resolve/main/config.json",
}
# Map from pretrained FLAVA checkpoint names to their config.json URLs

class FlavaImageConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`FlavaImageModel`]. It is used to instantiate an
    FLAVA model according to the specified arguments, defining the model architecture.

    Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
    [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """

    # Model type identifier for the FLAVA image model
    model_type = "flava_image_model"

    # Constructor: stores all architecture hyperparameters on the config object
    def __init__(
        self,
        hidden_size: int = 768,                             # 隐藏层大小,默认为768
        num_hidden_layers: int = 12,                        # 隐藏层数,默认为12
        num_attention_heads: int = 12,                      # 注意力头数,默认为12
        intermediate_size: int = 3072,                      # 中间层大小,默认为3072
        hidden_act: int = "gelu",                           # 隐藏层激活函数,默认为"gelu"
        hidden_dropout_prob: float = 0.0,                   # 隐藏层dropout概率,默认为0.0
        attention_probs_dropout_prob: float = 0.0,          # 注意力概率dropout概率,默认为0.0
        initializer_range: float = 0.02,                    # 初始化范围,默认为0.02
        layer_norm_eps: float = 1e-12,                      # LayerNorm的epsilon,默认为1e-12
        image_size: int = 224,                              # 图像大小,默认为224
        patch_size: int = 16,                               # 图像块大小,默认为16
        num_channels: int = 3,                              # 图像通道数,默认为3
        qkv_bias: bool = True,                              # 是否在QKV中使用偏置,默认为True
        mask_token: bool = True,                            # 是否使用掩码token,默认为True
        vocab_size: int = 8192,                             # 词汇表大小,默认为8192
        **kwargs,                                           # 其他关键字参数
    ):
        super().__init__(**kwargs)                          # 调用父类初始化方法

        self.hidden_size = hidden_size                      # 设置隐藏层大小
        self.num_hidden_layers = num_hidden_layers          # 设置隐藏层数
        self.num_attention_heads = num_attention_heads      # 设置注意力头数
        self.intermediate_size = intermediate_size          # 设置中间层大小
        self.hidden_act = hidden_act                        # 设置隐藏层激活函数
        self.hidden_dropout_prob = hidden_dropout_prob      # 设置隐藏层dropout概率
        self.attention_probs_dropout_prob = attention_probs_dropout_prob  # 设置注意力概率dropout概率
        self.initializer_range = initializer_range          # 设置初始化范围
        self.layer_norm_eps = layer_norm_eps                # 设置LayerNorm的epsilon
        self.image_size = image_size                        # 设置图像大小
        self.patch_size = patch_size                        # 设置图像块大小
        self.num_channels = num_channels                    # 设置图像通道数
        self.qkv_bias = qkv_bias                            # 设置是否在QKV中使用偏置
        self.mask_token = mask_token                        # 设置是否使用掩码token
        self.vocab_size = vocab_size                        # 设置词汇表大小

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)                    # 设置kwargs中的token

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)  # 获取配置字典和更新后的kwargs

        # 如果从FlavaConfig加载,获取图像配置字典
        if config_dict.get("model_type") == "flava":
            config_dict = config_dict["image_config"]

        # 如果配置字典中包含model_type,并且cls有model_type属性,并且不匹配时,发出警告
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)          # 从配置字典和kwargs创建实例
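# Illustrative usage of FlavaImageConfig (not part of the library source; it mirrors the example style used by
# the other FLAVA configs in this file and assumes transformers with PyTorch is installed):
from transformers import FlavaImageConfig, FlavaImageModel

# Instantiate a config with facebook/flava-full style defaults
configuration = FlavaImageConfig()

# Instantiate an image model (with random weights) from that configuration
model = FlavaImageModel(configuration)

# The effective configuration can be read back from the model
configuration = model.config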
class FlavaTextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`FlavaTextModel`]. It is used to instantiate an
    FLAVA model according to the specified arguments, defining the model architecture.

    Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
    [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Example:

    ```
    >>> from transformers import FlavaTextConfig, FlavaTextModel

    >>> # Initializing a FlavaTextModel with  style configuration
    >>> configuration = FlavaTextConfig()

    >>> # Initializing a FlavaTextModel model (with random weights) from the style configuration
    >>> model = FlavaTextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "flava_text_model"

    def __init__(
        self,
        vocab_size: int = 30522,
        type_vocab_size: int = 2,
        max_position_embeddings: int = 512,
        position_embedding_type: str = "absolute",
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_act: str = "gelu",
        hidden_dropout_prob: float = 0.0,
        attention_probs_dropout_prob: float = 0.0,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        pad_token_id: int = 0,
        qkv_bias: bool = True,
        **kwargs,
    ):
        # 调用父类构造函数,初始化继承自 PretrainedConfig 的属性
        super().__init__(**kwargs)

        # 初始化配置参数
        self.vocab_size = vocab_size                      # 词汇表大小
        self.type_vocab_size = type_vocab_size            # 类型词汇表大小
        self.max_position_embeddings = max_position_embeddings  # 最大位置嵌入长度
        self.position_embedding_type = position_embedding_type  # 位置嵌入类型
        self.hidden_size = hidden_size                    # 隐藏层大小
        self.num_hidden_layers = num_hidden_layers        # 隐藏层层数
        self.num_attention_heads = num_attention_heads    # 注意力头数
        self.intermediate_size = intermediate_size        # 中间层大小
        self.hidden_act = hidden_act                      # 隐藏层激活函数
        self.hidden_dropout_prob = hidden_dropout_prob    # 隐藏层 dropout 概率
        self.attention_probs_dropout_prob = attention_probs_dropout_prob  # 注意力 dropout 概率
        self.initializer_range = initializer_range        # 初始化范围
        self.layer_norm_eps = layer_norm_eps              # 层归一化 epsilon
        self.qkv_bias = qkv_bias                          # 是否使用 QKV 偏置
        self.pad_token_id = pad_token_id                  # 填充 token 的 ID

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        # 在kwargs中设置token
        cls._set_token_in_kwargs(kwargs)

        # 获取预训练模型的配置字典和更新后的kwargs
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # 如果配置字典指定的模型类型是"flava",则使用其text_config作为配置字典
        if config_dict.get("model_type") == "flava":
            config_dict = config_dict["text_config"]

        # 如果配置字典中有"model_type"字段,并且类(cls)具有"model_type"属性,
        # 且配置字典中的模型类型不等于类的模型类型,则发出警告
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        # 使用配置字典和kwargs创建实例
        return cls.from_dict(config_dict, **kwargs)
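# Because of the `model_type == "flava"` branch above, the nested text configuration can be pulled straight out
# of a full FLAVA checkpoint. A hedged sketch (network access and the facebook/flava-full checkpoint are assumed):
from transformers import FlavaTextConfig

text_config = FlavaTextConfig.from_pretrained("facebook/flava-full")
print(text_config.model_type)  # "flava_text_model"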
# Configuration class storing the configuration of a FlavaMultimodalModel; inherits from PretrainedConfig.
class FlavaMultimodalConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`FlavaMultimodalModel`]. It is used to instantiate
    an FLAVA model according to the specified arguments, defining the model architecture.

    Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
    [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 6):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        use_cls_token (`bool`, *optional*, defaults to `True`):
            Whether to use an extra CLS token for multimodal settings. Usually needed by the FLAVA model.

    Example:

    ```
    >>> from transformers import FlavaMultimodalConfig, FlavaMultimodalModel

    >>> # Initializing a FlavaMultimodalModel with  style configuration
    >>> configuration = FlavaMultimodalConfig()

    >>> # Initializing a FlavaMultimodalModel model (with random weights) from the style configuration
    >>> model = FlavaMultimodalModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    
    # 类属性,指定模型类型为 flava_multimodal_model
    model_type = "flava_multimodal_model"
    # 初始化方法,用于初始化模型配置参数
    def __init__(
        self,
        hidden_size: int = 768,                      # 隐藏层大小,默认为768
        num_hidden_layers: int = 6,                  # 隐藏层数,默认为6
        num_attention_heads: int = 12,               # 注意力头的数量,默认为12
        intermediate_size: int = 3072,               # 中间层大小,默认为3072
        hidden_act: int = "gelu",                    # 隐藏层激活函数,默认为gelu
        hidden_dropout_prob: int = 0.0,              # 隐藏层的dropout概率,默认为0.0
        attention_probs_dropout_prob: int = 0.0,     # 注意力概率的dropout概率,默认为0.0
        initializer_range: float = 0.02,             # 初始化范围,默认为0.02
        layer_norm_eps: float = 1e-12,               # 层归一化的epsilon值,默认为1e-12
        qkv_bias: bool = True,                       # 是否在QKV中使用偏置,默认为True
        use_cls_token: bool = True,                  # 是否使用CLS令牌,默认为True
        **kwargs,                                    # 其余关键字参数
    ):
        # 调用父类的初始化方法,传入其余的关键字参数
        super().__init__(**kwargs)

        # 设置模型的各种配置参数
        self.hidden_size = hidden_size                 # 设置隐藏层大小
        self.num_hidden_layers = num_hidden_layers     # 设置隐藏层数
        self.num_attention_heads = num_attention_heads # 设置注意力头的数量
        self.intermediate_size = intermediate_size     # 设置中间层大小
        self.hidden_act = hidden_act                   # 设置隐藏层激活函数
        self.hidden_dropout_prob = hidden_dropout_prob # 设置隐藏层的dropout概率
        self.attention_probs_dropout_prob = attention_probs_dropout_prob  # 设置注意力概率的dropout概率
        self.initializer_range = initializer_range     # 设置初始化范围
        self.layer_norm_eps = layer_norm_eps           # 设置层归一化的epsilon值
        self.qkv_bias = qkv_bias                       # 设置是否在QKV中使用偏置
        self.use_cls_token = use_cls_token             # 设置是否使用CLS令牌

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        # 在kwargs中设置token
        cls._set_token_in_kwargs(kwargs)

        # 获取模型的配置字典和剩余的kwargs参数
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # 如果配置字典中的模型类型是"flava",则使用多模态配置字典
        if config_dict.get("model_type") == "flava":
            config_dict = config_dict["multimodal_config"]

        # 如果配置字典中包含"model_type"且模型类型与当前类的模型类型不匹配,发出警告
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        # 使用配置字典和kwargs参数创建一个新的实例
        return cls.from_dict(config_dict, **kwargs)
class FlavaImageCodebookConfig(PretrainedConfig):
    model_type = "flava_image_codebook"

    r"""
    [`FlavaImageCodebookConfig`] is the configuration class to store the configuration of a [`FlavaImageCodebook`]. It
    is used to instantiate an FLAVA model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
    [facebook/flava-image-codebook](https://huggingface.co/facebook/flava-image-codebook) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_groups (`int`, defaults to 4):
            Number of groups to be created. This parameter as of now doesn't affect the model and is used for some
            internal calculation and estimations.
        input_channels (`int`, defaults to 3):
            Number of channels in the image to be passed.
        num_blocks_per_group (`int`, defaults to 2):
            Number of conv-based blocks per group.
        hidden_size (`int`, defaults to 256):
            Size of hidden dim for the blocks.
        vocab_size (`int`, defaults to 8192):
            Size of the output vocabulary for the codebook.
        freeze (`bool`, defaults to `True`):
            Whether to freeze the weights of the model.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        kwargs (*optional*):
            Dictionary of keyword arguments.

    Example:

    ```
    >>> from transformers import FlavaImageCodebookConfig, FlavaImageCodebook

    >>> # Initializing a FlavaImageCodebook with style configuration
    >>> configuration = FlavaImageCodebookConfig()

    >>> # Initializing a FlavaImageCodebook model (with random weights) from the style configuration
    >>> model = FlavaImageCodebook(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    def __init__(
        self,
        num_groups: int = 4,
        input_channels: int = 3,
        num_blocks_per_group: int = 2,
        hidden_size: int = 256,
        vocab_size: int = 8192,
        freeze: int = True,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        # 调用父类的初始化方法,传入所有额外的关键字参数
        super().__init__(**kwargs)
        # 设置对象的各项配置参数
        self.num_groups = num_groups  # 设置创建的组数
        self.input_channels = input_channels  # 设置传递的图像通道数
        self.num_blocks_per_group = num_blocks_per_group  # 设置每组的卷积块数
        self.hidden_size = hidden_size  # 设置块的隐藏维度大小
        self.vocab_size = vocab_size  # 设置代码本输出词汇表的大小
        self.freeze = freeze  # 设置是否冻结模型权重
        self.initializer_range = initializer_range  # 设置权重矩阵初始化的标准差范围

    @classmethod
    # 根据预训练模型名称或路径及额外参数创建一个预训练配置对象
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        # 在 kwargs 中设置 token
        cls._set_token_in_kwargs(kwargs)

        # 获取预训练模型的配置字典和更新后的 kwargs
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # 如果配置字典指定模型类型为 "flava",则从中获取图像代码本配置字典
        if config_dict.get("model_type") == "flava":
            config_dict = config_dict["image_codebook_config"]

        # 如果配置字典中包含 "model_type",并且类有 "model_type" 属性,并且配置的模型类型与类中定义的不同,
        # 则发出警告提示,因为这可能会导致不同配置的模型出现错误。
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        # 根据配置字典和 kwargs 创建预训练配置对象并返回
        return cls.from_dict(config_dict, **kwargs)
# `FlavaConfig` is the top-level configuration class; it inherits from `PretrainedConfig`
class FlavaConfig(PretrainedConfig):
    r"""
    [`FlavaConfig`] is the configuration class to store the configuration of a [`FlavaModel`]. It is used to
    instantiate FLAVA model according to the specified arguments, defining the text model, image model, image codebook
    and multimodal model configs. Instantiating a configuration with the defaults will yield a similar configuration to
    that of the FLAVA [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        text_config (`dict`, *optional*):
            文本配置选项的字典,用于初始化 FlavaTextConfig。
        image_config (`dict`, *optional*):
            图像配置选项的字典,用于初始化 FlavaImageConfig。
        multimodal_config (`dict`, *optional*):
            多模态配置选项的字典,用于初始化 FlavaMultimodalConfig。
        hidden_size (`int`, *optional*, defaults to 768):
            编码器层和汇聚层的维度。
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            层归一化层使用的 epsilon 值。
        projection_dim (`int`, *optional*, defaults to 512):
            文本和图像投影层的维度。
        logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
            *logit_scale* 参数的初始值,默认按照 FLAVA/CLIP 实现使用的默认值。
        initializer_range (`float`, *optional*, defaults to 0.02):
            用于初始化所有权重矩阵的截断正态分布的标准差。
        ce_ignore_index (`int`, *optional*, defaults to -100):
            交叉熵忽略的索引。
        mim_weight (`float`, *optional*, defaults to 1.0):
            分配给 MIM(Masked Image Modeling)单模态损失的权重。
        mlm_weight (`float`, *optional*, defaults to 1.0):
            分配给 MLM(Masked Language Modeling)单模态损失的权重。
        global_contrastive_weight (`float`, *optional*, defaults to 1.0):
            分配给全局对比度交叉对齐损失的权重。
        itm_weight (`float`, *optional*, defaults to 1.0):
            分配给图像-文本匹配多模态损失的权重。
        mmm_image_weight (`float`, *optional*, defaults to 1.0):
            分配给 MMM 损失的图像部分的权重。
        mmm_text_weight (`float`, *optional*, defaults to 1.0):
            分配给 MMM 损失的文本部分的权重。
        global_backprop_contrastive (`bool`, *optional*, defaults to `True`):
            是否在对比度损失中通过所有工作者进行全局反向传播。
        skip_unmasked_multimodal_encoder (`bool`, *optional*, defaults to `True`):
            是否跳过运行未屏蔽的多模态编码器,其输出不被 FLAVA 损失使用。
        return_loss (`bool`, *optional*, defaults to `True`):
            是否返回损失值。

        kwargs (*optional*):
            关键字参数的字典。

    Example:

    ```
    >>> from transformers import FlavaConfig, FlavaModel, FlavaForPreTraining

    >>> # Initializing a FlavaConfig with style configuration
    >>> configuration = FlavaConfig()

    >>> # Initializing a FlavaModel and FlavaForPreTraining model (with random weights) from the style configuration
    >>> model = FlavaModel(configuration)
    >>> model_pre = FlavaForPreTraining(configuration)

    >>> # Accessing the model configurations
    >>> configuration = model.config
    >>> configuration_pre = model_pre.config
    ```"""

    model_type = "flava"

    def __init__(
        self,
        image_config: Dict[str, Any] = None,
        text_config: Dict[str, Any] = None,
        multimodal_config: Dict[str, Any] = None,
        image_codebook_config: Dict[str, Any] = None,
        hidden_size: int = 768,
        layer_norm_eps: float = 1e-12,
        projection_dim: int = 768,
        init_codebook: bool = True,
        logit_scale_init_value: float = 2.6592,
        initializer_range: float = 0.02,
        ce_ignore_index: int = -100,
        mim_weight: float = 1.0,
        mlm_weight: float = 1.0,
        global_contrastive_weight: float = 1.0,
        itm_weight: float = 1.0,
        mmm_image_weight: float = 1.0,
        mmm_text_weight: float = 1.0,
        global_backprop_contrastive: bool = True,
        skip_unmasked_multimodal_encoder: bool = True,
        return_loss: bool = True,
        **kwargs,
    ):

    @classmethod
    def from_configs(
        cls,
        image_config: FlavaImageConfig,
        text_config: FlavaTextConfig,
        multimodal_config: FlavaMultimodalConfig,
        image_codebook_config: FlavaImageCodebookConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`FlavaConfig`] (or a derived class) from flava text model configuration, flava image model
        configuration, flava multimodal model and flava codebook model configuration.

        Returns:
            [`FlavaConfig`]: An instance of a configuration object
        """

        # Build a FlavaConfig instance from the individual sub-configurations
        return cls(
            image_config=image_config.to_dict(),
            text_config=text_config.to_dict(),
            multimodal_config=multimodal_config.to_dict(),
            image_codebook_config=image_codebook_config.to_dict(),
            **kwargs,
        )
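# Putting the pieces together: assuming the `from_configs` helper shown above, the top-level FlavaConfig can be
# assembled from individually customised sub-configurations (illustrative sketch, not part of the library source):
from transformers import (
    FlavaConfig,
    FlavaImageCodebookConfig,
    FlavaImageConfig,
    FlavaMultimodalConfig,
    FlavaTextConfig,
)

# Customise one sub-config and keep defaults for the rest
image_config = FlavaImageConfig(image_size=224, patch_size=16)
config = FlavaConfig.from_configs(
    image_config=image_config,
    text_config=FlavaTextConfig(),
    multimodal_config=FlavaMultimodalConfig(),
    image_codebook_config=FlavaImageCodebookConfig(),
)
print(config.image_config.patch_size)  # 16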

.\models\flava\convert_dalle_to_flava_codebook.py

# 导入必要的模块和库
import argparse  # 导入用于解析命令行参数的模块
import os  # 导入操作系统相关功能的模块

import torch  # 导入PyTorch深度学习库

from transformers import FlavaImageCodebook, FlavaImageCodebookConfig  # 导入transformers库中的模型和配置


def rreplace(s, old, new, occurrence):
    # 从字符串末尾向前查找并替换指定次数的子字符串
    li = s.rsplit(old, occurrence)
    return new.join(li)
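# For clarity: unlike str.replace, this only rewrites the *last* occurrence, which the ".w"/".b" suffix renaming
# below depends on. For example:
#   rreplace("blocks.0.w.w", ".w", ".weight", 1)  -> "blocks.0.w.weight"
#   "blocks.0.w.w".replace(".w", ".weight")       -> "blocks.0.weight.weight"  (would corrupt the key)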


def count_parameters(state_dict):
    # Despite the name, this sums the *values* of all parameters (a cheap checksum used to verify the
    # conversion), skipping any parameter whose key contains "encoder.embeddings"
    return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items())


def upgrade_state_dict(state_dict):
    # 更新模型状态字典中的键名,以符合transformers的设计规范
    upgrade = {}

    group_keys = ["group_1", "group_2", "group_3", "group_4"]
    for key, value in state_dict.items():
        for group_key in group_keys:
            if group_key in key:
                key = key.replace(f"{group_key}.", f"{group_key}.group.")

        if "res_path" in key:
            key = key.replace("res_path.", "res_path.path.")

        if key.endswith(".w"):
            key = rreplace(key, ".w", ".weight", 1)
        if key.endswith(".b"):
            key = rreplace(key, ".b", ".bias", 1)

        upgrade[key] = value.float()

    return upgrade


@torch.no_grad()
def convert_dalle_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None, save_checkpoint=True):
    """
    复制/粘贴/调整模型的权重以适应transformers设计。
    """
    from dall_e import Encoder  # 导入dall-e项目中的Encoder模型

    encoder = Encoder()  # 实例化Encoder模型对象
    if os.path.exists(checkpoint_path):
        ckpt = torch.load(checkpoint_path)  # 如果本地存在checkpoint文件,则加载
    else:
        ckpt = torch.hub.load_state_dict_from_url(checkpoint_path)  # 否则,从URL加载模型权重

    if isinstance(ckpt, Encoder):
        ckpt = ckpt.state_dict()  # 如果加载的是Encoder对象,则获取其状态字典
    encoder.load_state_dict(ckpt)  # 加载Encoder模型的权重

    if config_path is not None:
        config = FlavaImageCodebookConfig.from_pretrained(config_path)  # 如果提供了配置文件路径,则从中加载配置
    else:
        config = FlavaImageCodebookConfig()  # 否则使用默认配置

    hf_model = FlavaImageCodebook(config).eval()  # 根据配置实例化FlavaImageCodebook模型,并设置为评估模式
    state_dict = encoder.state_dict()  # 获取Encoder模型的状态字典

    hf_state_dict = upgrade_state_dict(state_dict)  # 将Encoder模型的状态字典转换为适应transformers的格式
    hf_model.load_state_dict(hf_state_dict)  # 加载适应transformers格式的状态字典到FlavaImageCodebook模型
    hf_state_dict = hf_model.state_dict()  # 获取转换后的模型状态字典
    hf_count = count_parameters(hf_state_dict)  # 统计转换后模型的参数数量
    state_dict_count = count_parameters(state_dict)  # 统计原始Encoder模型的参数数量

    assert torch.allclose(hf_count, state_dict_count, atol=1e-3)  # the two parameter-value checksums must match closely

    if save_checkpoint:
        hf_model.save_pretrained(pytorch_dump_folder_path)  # 如果指定保存checkpoint,则保存模型到指定路径
    else:
        return hf_state_dict  # 否则返回转换后的模型状态字典


if __name__ == "__main__":
    parser = argparse.ArgumentParser()  # 创建命令行参数解析器对象
    # 解析命令行参数,获取用户输入的 PyTorch 模型输出路径
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    # 解析命令行参数,获取用户输入的 flava 检查点路径
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to flava checkpoint")
    # 解析命令行参数,获取用户输入的模型配置文件路径(通常是一个 JSON 文件)
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    # 解析命令行参数,将所有参数解析并存储在 args 对象中
    args = parser.parse_args()
    
    # 调用函数 convert_dalle_checkpoint,传递解析后的参数:
    # args.checkpoint_path:flava 检查点路径
    # args.pytorch_dump_folder_path:PyTorch 模型输出路径
    # args.config_path:模型配置文件路径
    convert_dalle_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)

.\models\flava\convert_flava_original_pytorch_to_hf.py

# 导入命令行参数解析库
import argparse
# 导入操作系统功能模块
import os

# 导入PyTorch库
import torch

# 导入transformers库中的FlavaConfig和FlavaForPreTraining类
from transformers import FlavaConfig, FlavaForPreTraining
# 导入convert_dalle_to_flava_codebook模块中的convert_dalle_checkpoint函数
from transformers.models.flava.convert_dalle_to_flava_codebook import convert_dalle_checkpoint


# Helper: compute a checksum over the model parameters
def count_parameters(state_dict):
    # Despite the name, this sums the *values* of all parameters (a cheap checksum used to verify the
    # conversion), skipping any parameter whose key contains "encoder.embeddings"
    return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items())


# 定义函数:升级模型状态字典
def upgrade_state_dict(state_dict, codebook_state_dict):
    # 初始化升级后的状态字典
    upgrade = {}

    # 遍历原始状态字典中的键值对
    for key, value in state_dict.items():
        # 如果键名中包含"text_encoder.embeddings"或"image_encoder.embeddings",则跳过处理
        if "text_encoder.embeddings" in key or "image_encoder.embeddings" in key:
            continue

        # 替换键名中的特定子串,以适配新模型结构
        key = key.replace("heads.cmd.mim_head.cls.predictions", "mmm_image_head")
        key = key.replace("heads.cmd.mlm_head.cls.predictions", "mmm_text_head")
        key = key.replace("heads.cmd.itm_head.cls", "itm_head")
        key = key.replace("heads.cmd.itm_head.pooler", "itm_head.pooler")
        key = key.replace("heads.cmd.clip_head.logit_scale", "flava.logit_scale")
        key = key.replace("heads.fairseq_mlm.cls.predictions", "mlm_head")
        key = key.replace("heads.imagenet.mim_head.cls.predictions", "mim_head")
        key = key.replace("mm_text_projection", "flava.text_to_mm_projection")
        key = key.replace("mm_image_projection", "flava.image_to_mm_projection")
        key = key.replace("image_encoder.module", "flava.image_model")
        key = key.replace("text_encoder.module", "flava.text_model")
        key = key.replace("mm_encoder.module.encoder.cls_token", "flava.multimodal_model.cls_token")
        key = key.replace("mm_encoder.module", "flava.multimodal_model")
        key = key.replace("text_projection", "flava.text_projection")
        key = key.replace("image_projection", "flava.image_projection")

        # 将处理后的键值对应存入升级后的状态字典
        upgrade[key] = value.float()

    # 将代码簿状态字典中的键值对应存入升级后的状态字典,前缀为"image_codebook."
    for key, value in codebook_state_dict.items():
        upgrade[f"image_codebook.{key}"] = value

    return upgrade
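# Quick illustration of the renaming performed above with a toy state dict (keys are hypothetical, for
# illustration only):
#
#   toy_state_dict = {"heads.cmd.itm_head.cls.weight": torch.zeros(2), "mm_text_projection.weight": torch.zeros(2)}
#   toy_codebook = {"blocks.0.weight": torch.zeros(2)}
#   renamed = upgrade_state_dict(toy_state_dict, toy_codebook)
#   sorted(renamed) == ["flava.text_to_mm_projection.weight", "image_codebook.blocks.0.weight", "itm_head.weight"]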


# 定义函数:转换FLAVA模型的检查点
@torch.no_grad()
def convert_flava_checkpoint(checkpoint_path, codebook_path, pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak model's weights to transformers design.
    将模型权重复制/粘贴/调整为transformers设计。
    """
    # 如果提供了配置文件路径,则从预训练配置文件中加载配置,否则使用默认配置
    if config_path is not None:
        config = FlavaConfig.from_pretrained(config_path)
    else:
        config = FlavaConfig()

    # 创建一个FlavaForPreTraining模型,并设置为评估模式
    hf_model = FlavaForPreTraining(config).eval()
    # 调用函数 `convert_dalle_checkpoint`,将 `codebook_path` 转换为 DALL-E 模型的状态字典
    codebook_state_dict = convert_dalle_checkpoint(codebook_path, None, save_checkpoint=False)

    # 检查 `checkpoint_path` 是否存在
    if os.path.exists(checkpoint_path):
        # 如果存在,则从本地加载 PyTorch 模型状态字典
        state_dict = torch.load(checkpoint_path, map_location="cpu")
    else:
        # 如果不存在,则从指定的 URL 加载 PyTorch 模型状态字典
        state_dict = torch.hub.load_state_dict_from_url(checkpoint_path, map_location="cpu")

    # 升级模型的状态字典 `state_dict`,结合 `codebook_state_dict`
    hf_state_dict = upgrade_state_dict(state_dict, codebook_state_dict)

    # 将升级后的状态字典加载到 `hf_model` 中
    hf_model.load_state_dict(hf_state_dict)

    # 获取 `hf_model` 的当前状态字典
    hf_state_dict = hf_model.state_dict()

    # Compute the parameter-value checksum of the converted model
    hf_count = count_parameters(hf_state_dict)

    # Compute the combined checksum of the original `state_dict` and `codebook_state_dict`
    state_dict_count = count_parameters(state_dict) + count_parameters(codebook_state_dict)

    # The two checksums must agree within 1e-3, otherwise the conversion dropped or altered weights
    assert torch.allclose(hf_count, state_dict_count, atol=1e-3)

    # 将 `hf_model` 的模型权重保存到指定的 PyTorch 转储文件夹路径中
    hf_model.save_pretrained(pytorch_dump_folder_path)
# 如果当前脚本作为主程序运行,则执行以下代码块
if __name__ == "__main__":
    # 创建参数解析器对象
    parser = argparse.ArgumentParser()
    # 添加命令行参数,用于指定输出的 PyTorch 模型的路径
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    # 添加命令行参数,用于指定 flava checkpoint 的路径
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to flava checkpoint")
    # 添加命令行参数,用于指定 flava codebook checkpoint 的路径
    parser.add_argument("--codebook_path", default=None, type=str, help="Path to flava codebook checkpoint")
    # 添加命令行参数,用于指定待转换模型的 hf config.json 文件的路径
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    # 解析命令行参数
    args = parser.parse_args()

    # 调用函数 convert_flava_checkpoint,传入命令行参数中指定的路径信息
    convert_flava_checkpoint(args.checkpoint_path, args.codebook_path, args.pytorch_dump_folder_path, args.config_path)

.\models\flava\feature_extraction_flava.py

# coding=utf-8
# 文件编码声明为 UTF-8

# 版权声明及许可证信息
# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FLAVA 的特征提取器类。"""

# 引入警告模块
import warnings

# 引入日志模块
from ...utils import logging
# 引入 FLAVA 图像处理模块中的特征提取器类
from .image_processing_flava import FlavaImageProcessor

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# FLAVA 特征提取器类,继承自 FlavaImageProcessor 类
class FlavaFeatureExtractor(FlavaImageProcessor):
    # 初始化方法
    def __init__(self, *args, **kwargs) -> None:
        # 发出警告,表明 FlavaFeatureExtractor 类已弃用,将在 Transformers 版本 5 中移除,建议使用 FlavaImageProcessor 替代
        warnings.warn(
            "The class FlavaFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
            " use FlavaImageProcessor instead.",
            FutureWarning,
        )
        # 调用父类的初始化方法,传递所有位置参数和关键字参数
        super().__init__(*args, **kwargs)
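# Illustrative check of the deprecation path above (not part of the library source; assumes the vision extras,
# i.e. Pillow, are installed so the processor can be constructed):
import warnings

from transformers import FlavaFeatureExtractor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    extractor = FlavaFeatureExtractor()  # emits the FutureWarning defined above

assert any(issubclass(w.category, FutureWarning) for w in caught)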

.\models\flava\image_processing_flava.py

# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Flava."""

import math
import random
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format
from ...image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_kwargs,
    validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging


if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)


# These values are taken from CLIP
FLAVA_IMAGE_MEAN = OPENAI_CLIP_MEAN
FLAVA_IMAGE_STD = OPENAI_CLIP_STD
FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0]
FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0]
LOGIT_LAPLACE_EPS: float = 0.1


# Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py
class FlavaMaskingGenerator:
    def __init__(
        self,
        input_size: Union[int, Tuple[int, int]] = 14,
        total_mask_patches: int = 75,
        mask_group_max_patches: Optional[int] = None,
        mask_group_min_patches: int = 16,
        mask_group_min_aspect_ratio: Optional[float] = 0.3,
        mask_group_max_aspect_ratio: float = None,
    ):
        # 如果输入大小不是元组,则将其转换为元组
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        # 初始化输入的高度和宽度
        self.height, self.width = input_size

        # Total number of patches in the grid (height * width) and the masking budget
        self.num_patches = self.height * self.width
        self.total_mask_patches = total_mask_patches

        # 设定每个掩码组的最小和最大片段数
        self.mask_group_min_patches = mask_group_min_patches
        self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches

        # 根据最小和最大纵横比计算对数纵横比的范围
        mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
        self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))
    # 返回对象的字符串表示,描述了 MaskingGenerator 实例的参数和范围
    def __repr__(self):
        repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.mask_group_min_patches,
            self.mask_group_max_patches,
            self.total_mask_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )
        return repr_str

    # 返回生成器的高度和宽度
    def get_shape(self):
        return self.height, self.width

    # 执行掩码生成的核心方法,修改给定的掩码并返回修改的像素数
    def _mask(self, mask, max_mask_patches):
        delta = 0
        for _attempt in range(10):
            # 随机确定目标区域的面积
            target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
            # 随机生成长宽比
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            # 根据面积和长宽比计算高度和宽度
            height = int(round(math.sqrt(target_area * aspect_ratio)))
            width = int(round(math.sqrt(target_area / aspect_ratio)))
            # 如果生成的掩码区域在合理范围内
            if width < self.width and height < self.height:
                top = random.randint(0, self.height - height)
                left = random.randint(0, self.width - width)

                # 计算当前掩码区域中已经被掩盖的像素数
                num_masked = mask[top : top + height, left : left + width].sum()
                # Only proceed if the region still contains unmasked patches and masking them stays within the budget
                if 0 < height * width - num_masked <= max_mask_patches:
                    # 将新区域中未被掩盖的像素进行掩盖
                    for i in range(top, top + height):
                        for j in range(left, left + width):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                # 如果有像素被掩盖,则结束循环
                if delta > 0:
                    break
        # 返回掩盖操作导致的像素数变化
        return delta

    # 生成器的调用方法,生成并返回掩码
    def __call__(self):
        # 创建一个与生成器相同形状的零矩阵作为掩码
        mask = np.zeros(shape=self.get_shape(), dtype=int)
        mask_count = 0
        # 循环生成掩码,直到达到指定的总掩盖像素数
        while mask_count < self.total_mask_patches:
            # 每次最多可以生成的掩码数
            max_mask_patches = self.total_mask_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)

            # 执行掩码生成,获取本次生成的掩码像素数
            delta = self._mask(mask, max_mask_patches)
            # 如果没有新的像素被掩盖,则结束生成过程
            if delta == 0:
                break
            else:
                mask_count += delta

        # 返回生成的掩码
        return mask
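# Quick sanity check of the generator defined above (illustrative, using the defaults documented in __init__):
# the returned mask is a (14, 14) integer array containing at most `total_mask_patches` ones.
generator = FlavaMaskingGenerator(input_size=14, total_mask_patches=75)
mask = generator()
print(mask.shape, int(mask.sum()))  # (14, 14) and a value <= 75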
class FlavaImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Flava image processor.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        # 是否进行调整大小
        do_resize: bool = True,
        # 图像大小
        size: Dict[str, int] = None,
        # 重采样方法
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        # 是否进行中心裁剪
        do_center_crop: bool = True,
        # 裁剪大小
        crop_size: Dict[str, int] = None,
        # 是否进行重新缩放
        do_rescale: bool = True,
        # 缩放系数
        rescale_factor: Union[int, float] = 1 / 255,
        # 是否进行归一化
        do_normalize: bool = True,
        # 图像均值
        image_mean: Optional[Union[float, Iterable[float]]] = None,
        # 图像标准差
        image_std: Optional[Union[float, Iterable[float]]] = None,
        # Mask 相关参数
        return_image_mask: bool = False,
        input_size_patches: int = 14,
        total_mask_patches: int = 75,
        mask_group_min_patches: int = 16,
        mask_group_max_patches: Optional[int] = None,
        mask_group_min_aspect_ratio: float = 0.3,
        mask_group_max_aspect_ratio: Optional[float] = None,
        # Codebook 相关参数
        return_codebook_pixels: bool = False,
        codebook_do_resize: bool = True,
        codebook_size: bool = None,
        codebook_resample: int = PILImageResampling.LANCZOS,
        codebook_do_center_crop: bool = True,
        codebook_crop_size: int = None,
        codebook_do_rescale: bool = True,
        codebook_rescale_factor: Union[int, float] = 1 / 255,
        codebook_do_map_pixels: bool = True,
        codebook_do_normalize: bool = True,
        codebook_image_mean: Optional[Union[float, Iterable[float]]] = None,
        codebook_image_std: Optional[Union[float, Iterable[float]]] = None,
        **kwargs,
    ):
    @classmethod
    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
        """
        重写基类的 `from_dict` 方法,以确保在使用 from_dict 和 kwargs 创建图像处理器时更新参数,
        例如 `FlavaImageProcessor.from_pretrained(checkpoint, codebook_size=600)`
        """
        image_processor_dict = image_processor_dict.copy()
        if "codebook_size" in kwargs:
            image_processor_dict["codebook_size"] = kwargs.pop("codebook_size")
        if "codebook_crop_size" in kwargs:
        image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size")
        return super().from_dict(image_processor_dict, **kwargs)

    @lru_cache()
    def masking_generator(
        self,
        input_size_patches,
        total_mask_patches,
        mask_group_min_patches,
        mask_group_max_patches,
        mask_group_min_aspect_ratio,
        mask_group_max_aspect_ratio,
    # 返回一个 FlavaMaskingGenerator 实例,使用给定的参数初始化
    ) -> FlavaMaskingGenerator:
        return FlavaMaskingGenerator(
            input_size=input_size_patches,  # 设置输入大小为 input_size_patches
            total_mask_patches=total_mask_patches,  # 设置总掩蔽片段数为 total_mask_patches
            mask_group_min_patches=mask_group_min_patches,  # 设置掩蔽组最小片段数为 mask_group_min_patches
            mask_group_max_patches=mask_group_max_patches,  # 设置掩蔽组最大片段数为 mask_group_max_patches
            mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,  # 设置掩蔽组最小长宽比为 mask_group_min_aspect_ratio
            mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,  # 设置掩蔽组最大长宽比为 mask_group_max_aspect_ratio
        )

    # 从 transformers.models.vit.image_processing_vit.ViTImageProcessor.resize 中复制的函数
    # 用于调整图像大小,使用 BICUBIC 插值算法,接受一个 np.ndarray 格式的图像数据
    def resize(
        self,
        image: np.ndarray,  # 输入的图像数据,格式为 np.ndarray
        size: Dict[str, int],  # 目标大小,以字典形式提供,包含 'height' 和 'width' 键
        resample: PILImageResampling = PILImageResampling.BICUBIC,  # 插值方法,默认为 BICUBIC
        data_format: Optional[Union[str, ChannelDimension]] = None,  # 输出数据格式,可选参数
        input_data_format: Optional[Union[str, ChannelDimension]] = None,  # 输入数据格式,可选参数
        **kwargs,  # 其它可选参数
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.
        """
        size = get_size_dict(size)  # 获取处理后的尺寸字典
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
        output_size = (size["height"], size["width"])  # 设置输出图像的尺寸
        return resize(
            image,
            size=output_size,  # 调整图像大小至指定尺寸
            resample=resample,  # 使用指定的重采样滤波器,默认为双三次插值
            data_format=data_format,  # 设置输出图像的通道顺序格式
            input_data_format=input_data_format,  # 设置输入图像的通道顺序格式,默认从输入图像推断
            **kwargs,
        )

    def map_pixels(self, image: np.ndarray) -> np.ndarray:
        """
        Maps pixel values of an image using a specific constant.

        Args:
            image (`np.ndarray`):
                Input image.

        Returns:
            `np.ndarray`: Processed image with mapped pixel values.
        """
        return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS

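    # Note: map_pixels linearly squeezes [0, 1] into [LOGIT_LAPLACE_EPS, 1 - LOGIT_LAPLACE_EPS] = [0.1, 0.9],
    # the logit-Laplace preprocessing expected by the DALL-E-style image codebook; for example,
    # map_pixels(np.array([0.0, 0.5, 1.0])) gives array([0.1, 0.5, 0.9]).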
    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_map_pixels: bool = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[ChannelDimension] = None,
    ) -> np.ndarray:
        """Preprocesses a single image."""

        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        # All transformations expect numpy arrays.
        # 将图像转换为 numpy 数组
        image = to_numpy_array(image)

        if is_scaled_image(image) and do_rescale:
            # 如果图像已经缩放并且设定了重新缩放选项,则发出警告
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # 假设所有图像具有相同的通道维度格式
            input_data_format = infer_channel_dimension_format(image)

        if do_resize:
            # 如果需要调整大小,则调用 resize 方法
            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

        if do_center_crop:
            # 如果需要中心裁剪,则调用 center_crop 方法
            image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

        if do_rescale:
            # 如果需要重新缩放,则调用 rescale 方法
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        if do_normalize:
            # 如果需要标准化,则调用 normalize 方法
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        if do_map_pixels:
            # 如果需要像素映射,则调用 map_pixels 方法
            image = self.map_pixels(image)

        if data_format is not None:
            # 如果指定了数据格式,则将图像转换为该格式的通道维度格式
            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
        # 返回预处理后的图像
        return image
    # 定义一个方法用于预处理输入的图像数据,包含多个参数用于控制不同的预处理步骤和参数设置
    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[Dict[str, int]] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        # Mask 相关参数
        return_image_mask: Optional[bool] = None,
        input_size_patches: Optional[int] = None,
        total_mask_patches: Optional[int] = None,
        mask_group_min_patches: Optional[int] = None,
        mask_group_max_patches: Optional[int] = None,
        mask_group_min_aspect_ratio: Optional[float] = None,
        mask_group_max_aspect_ratio: Optional[float] = None,
        # Codebook 相关参数
        return_codebook_pixels: Optional[bool] = None,
        codebook_do_resize: Optional[bool] = None,
        codebook_size: Optional[Dict[str, int]] = None,
        codebook_resample: Optional[int] = None,
        codebook_do_center_crop: Optional[bool] = None,
        codebook_crop_size: Optional[Dict[str, int]] = None,
        codebook_do_rescale: Optional[bool] = None,
        codebook_rescale_factor: Optional[float] = None,
        codebook_do_map_pixels: Optional[bool] = None,
        codebook_do_normalize: Optional[bool] = None,
        codebook_image_mean: Optional[Iterable[float]] = None,
        codebook_image_std: Optional[Iterable[float]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ):

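# Illustrative end-to-end usage of the processor defined in this file (not part of the library source; assumes
# Pillow, numpy, and the vision extras are installed, and that `preprocess` keeps its documented defaults):
from PIL import Image
import numpy as np
from transformers import FlavaImageProcessor

processor = FlavaImageProcessor()
image = Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))
outputs = processor(image, return_image_mask=True, return_codebook_pixels=True, return_tensors="np")
print(sorted(outputs.keys()))
# Expected keys given the flags above: ['bool_masked_pos', 'codebook_pixel_values', 'pixel_values']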
.\models\flava\modeling_flava.py

# coding=utf-8
# Copyright notice and license information:
# Licensed under the Apache License, Version 2.0; you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.

""" PyTorch FLAVA model. """

# 导入所需的库和模块
import collections
import math
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

# 从外部模块导入特定的函数和类
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
# 导入 FLAVA 相关配置类
from .configuration_flava import (
    FlavaConfig,
    FlavaImageCodebookConfig,
    FlavaImageConfig,
    FlavaMultimodalConfig,
    FlavaTextConfig,
)

# 获取 logger 对象用于记录日志信息
logger = logging.get_logger(__name__)

# 用于文档的预训练模型检查点名称
_CHECKPOINT_FOR_DOC = "facebook/flava-full"

# 用于图像代码簿文档的预训练模型检查点名称
_CHECKPOINT_FOR_CODEBOOK_DOC = "facebook/flava-image-codebook"
# 图像模型配置类的文档字符串
_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC = "FlavaImageConfig"
# 文本模型配置类的文档字符串
_CONFIG_CLASS_FOR_TEXT_MODEL_DOC = "FlavaTextConfig"
# 多模态模型配置类的文档字符串
_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC = "FlavaMultimodalConfig"
# 预期的图像输出形状
_EXPECTED_IMAGE_OUTPUT_SHAPE = [1, 197, 768]

# FLAVA 预训练模型的模型存档列表
FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/flava-full",
    # 可以在 https://huggingface.co/models?filter=flava 查看所有 FLAVA 模型
]
# FLAVA 图像代码簿预训练模型的模型存档列表
FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/flava-image-codebook"]
# 对数尺度的最小值
LOGIT_SCALE_CLAMP_MIN = 0
# 对数尺度的最大值
LOGIT_SCALE_CLAMP_MAX = 4.6052
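
# 示意:上面两个常量把可学习的 logit_scale 夹紧在 [0, ln(100) ≈ 4.6052] 区间,
# 从而使对比损失中相似度的放大系数 exp(logit_scale) 始终落在 [1, 100] 之间。
# 下面是一个独立的最小示例(并非本文件源码;初始值取 log(1/0.07) 只是常见做法的假设):
import math
import torch

logit_scale_demo = torch.nn.Parameter(torch.tensor(math.log(1 / 0.07)))
logit_scale_demo.data.clamp_(LOGIT_SCALE_CLAMP_MIN, LOGIT_SCALE_CLAMP_MAX)
print(float(logit_scale_demo.exp()))  # ≈ 14.3,位于 [1, 100] 之内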

# FLAVA 模型可能的配置类别
FlavaPossibleConfigs = Union[FlavaTextConfig, FlavaImageConfig, FlavaMultimodalConfig]

# 数据类,包含 FLAVA 模型的输出,继承自 ModelOutput 类
@dataclass
class FlavaModelOutput(ModelOutput):
    """
    FlavaModel 的输出,包含来自各个编码器的嵌入和输出。

    注意,返回的 `image_embeddings` 和 `text_embeddings` 类似于 Transformer 模型返回的池化输出(pooled output)。
    如果需要用于对比损失或检索的嵌入,请在 `image_embeddings` 和 `text_embeddings` 上使用 FLAVA 模型的
    `image_projection` 和 `text_projection` 层。
    """
    """
    Args:
        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
            The image embeddings which are basically the pooled output of [`FlavaImageModel`].
        image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
            The output of the [`FlavaImageModel`].
        text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
            The text embeddings which are basically the pooled output of [`FlavaTextModel`].
        text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
            The output of the [`FlavaTextModel`].
        multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
            The multimodal embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
        multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
            The output of the [`FlavaMultimodalModel`].
    """

    # 可选的图像嵌入向量,形状为 `(batch_size, output_dim)`,当存在 `pixel_values` 时返回
    image_embeddings: Optional[torch.FloatTensor] = None
    # 可选的图像模型输出,当存在 `pixel_values` 时返回
    image_output: Optional[BaseModelOutputWithPooling] = None
    # 可选的文本嵌入向量,形状为 `(batch_size, output_dim)`,当存在 `input_ids` 时返回
    text_embeddings: Optional[torch.FloatTensor] = None
    # 可选的文本模型输出,当存在 `input_ids` 时返回
    text_output: Optional[BaseModelOutputWithPooling] = None
    # 可选的多模态嵌入向量,形状为 `(batch_size, output_dim)`,当同时存在 `input_ids` 和 `pixel_values`,且 `skip_multimodal_encoder` 为 `None` 或 `False` 时返回
    multimodal_embeddings: Optional[torch.FloatTensor] = None
    # 可选的多模态模型输出,当同时存在 `input_ids` 和 `pixel_values`,且 `skip_multimodal_encoder` 为 `None` 或 `False` 时返回
    multimodal_output: Optional[BaseModelOutputWithPooling] = None

    # 将当前对象转换为元组
    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            # 对于所有键,返回对应的值,除非键是 ["text_output", "image_output", "multimodal_output"] 中的一个,
            # 这些键对应的值需调用其 `to_tuple()` 方法进行转换
            self[k] if k not in ["text_output", "image_output", "multimodal_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
# 定义一个数据类 `FlavaLosses`,用于存储 FLAVA 模型的预训练损失
@dataclass
class FlavaLosses(ModelOutput):
    """Class representing pretraining losses from FLAVA model

    Args:
        mim (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.:
            Masked Image Modeling loss as used in BeIT calculated only for unimodal image data.
        mlm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.:
            Masked Language Modeling loss as used in BERT calculated only for unimodal text data.
        itm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked`, `pixel_values` are present and `itm_weight` > 0.:
            Image Text Matching (ITM) loss calculated for paired image-text data. Note that ITM loss is calculated on
            masked pairs in FLAVA.
        global_contrastive (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.:
            Contrastive loss for image-text similarity similar to CLIP but calculated globally for paired image-text
            data. This is calculated on unmasked images and texts.
        mmm_image (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.:
            Masked Multimodal Modeling loss's image component calculated on paired image-text data.
        mmm_text (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.:
            Masked Multimodal Modeling loss's text component calculated on paired image-text data.
    """

    # 定义各种损失的属性,使用 torch.FloatTensor 类型,均可选
    mim: Optional[torch.FloatTensor] = None
    mlm: Optional[torch.FloatTensor] = None
    itm: Optional[torch.FloatTensor] = None
    global_contrastive: Optional[torch.FloatTensor] = None
    mmm_image: Optional[torch.FloatTensor] = None
    mmm_text: Optional[torch.FloatTensor] = None

    # 定义一个方法,用于检查所有损失属性是否都为 None
    def all_none(self) -> bool:
        # 初始化一个标志位,表示是否所有属性都为 None
        all_none = True
        # 遍历所有损失属性的值
        for v in self.values():
            # 如果某个属性值不为 None,则将标志位置为 False,并跳出循环
            if v is not None:
                all_none = False
                break
        # 返回标志位,表示所有损失属性是否都为 None
        return all_none


# 定义一个数据类 `FlavaForPreTrainingOutput`,用于存储 FLAVA 模型预训练的输出
@dataclass
class FlavaForPreTrainingOutput(ModelOutput):
    """
    Output from FlavaForPreTraining containing embeddings, and outputs from individual encoders.

    Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
    transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
    `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.

    """

    # 定义模型输出的属性,包括损失、损失信息以及图像嵌入
    loss: Optional[torch.FloatTensor] = None
    loss_info: FlavaLosses = None
    image_embeddings: Optional[torch.FloatTensor] = None
    # 定义多个可选的模型输出变量,初始值均为 None
    image_output: Optional[BaseModelOutputWithPooling] = None
    text_embeddings: Optional[torch.FloatTensor] = None
    text_output: Optional[BaseModelOutputWithPooling] = None
    multimodal_embeddings: Optional[torch.FloatTensor] = None
    multimodal_output: Optional[BaseModelOutputWithPooling] = None
    image_masked_embeddings: Optional[torch.FloatTensor] = None
    image_masked_output: Optional[BaseModelOutputWithPooling] = None
    text_masked_embeddings: Optional[torch.FloatTensor] = None
    text_masked_output: Optional[BaseModelOutputWithPooling] = None
    multimodal_masked_embeddings: Optional[torch.FloatTensor] = None
    multimodal_masked_output: Optional[BaseModelOutputWithPooling] = None
    mim_logits: Optional[torch.FloatTensor] = None
    mlm_logits: Optional[torch.FloatTensor] = None
    itm_logits: Optional[torch.FloatTensor] = None
    contrastive_logits_per_image: Optional[torch.FloatTensor] = None
    contrastive_logits_per_text: Optional[torch.FloatTensor] = None
    mmm_image_logits: Optional[torch.FloatTensor] = None
    mmm_text_logits: Optional[torch.FloatTensor] = None

    # 定义方法将对象转换为元组的函数签名
    def to_tuple(self) -> Tuple[Any]:
        # 指定转换输出的顺序列表
        transformer_outputs = [
            "text_output",
            "image_output",
            "multimodal_output",
            "text_masked_output",
            "image_masked_output",
            "multimodal_masked_output",
        ]
        # 返回一个元组,包含对象中指定的属性的值,若属性在 transformer_outputs 中,则调用相应对象的 to_tuple() 方法
        return tuple(self[k] if k not in transformer_outputs else getattr(self, k).to_tuple() for k in self.keys())
# 基于 timm 实现的代码,可以在以下链接找到:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class FlavaImageEmbeddings(nn.Module):
    """
    构建 CLS token、位置和patch embeddings。可选择是否包含 mask token。
    """

    def __init__(self, config: FlavaImageConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        # 确定是否使用 mask token,如果 use_mask_token 为 True 或者 config 中指定了 mask_token,则使用
        use_mask_token = use_mask_token or config.mask_token
        
        # 定义 CLS token,是一个可学习的参数,形状为 (1, 1, hidden_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        
        # 如果使用 mask token,则定义一个可学习的参数作为 mask token,形状同上
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
        
        # 初始化 patch embeddings,使用 PatchEmbeddings 类生成 patch embeddings
        self.patch_embeddings = PatchEmbeddings(
            image_size=config.image_size,
            patch_size=config.patch_size,
            num_channels=config.num_channels,
            embed_dim=config.hidden_size,
        )
        
        # 获取 patch 的数量;位置 embeddings 的长度为 num_patches + 1(额外的一个位置留给 CLS token)
        num_patches = self.patch_embeddings.num_patches
        
        # 定义位置 embeddings,是一个可学习的参数,形状为 (1, num_patches + 1, hidden_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        
        # 定义 dropout 层,用于在训练过程中随机丢弃部分神经元,防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        
        # 保存配置信息
        self.config = config
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        # 计算当前嵌入的图像块数目(npatch)
        npatch = embeddings.shape[1] - 1
        # 获取预训练位置编码的数目(num_pos)
        num_pos = self.position_embeddings.shape[1] - 1
        # 如果图像块数目与位置编码数目相等,并且图像是正方形,则直接返回位置编码
        if npatch == num_pos and height == width:
            return self.position_embeddings
        
        # 获取类别位置编码(第一列)
        class_pos_embed = self.position_embeddings[:, 0]
        # 获取图像块位置编码(除去第一列)
        patch_pos_embed = self.position_embeddings[:, 1:]
        # 获取嵌入的维度
        dim = embeddings.shape[-1]
        # 计算图像中的水平和垂直图块数目
        num_h_patches = height // self.config.patch_size
        num_w_patches = width // self.config.patch_size
        # 添加一个小数以避免插值时的浮点数误差
        num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1
        
        # 对图像块位置编码进行插值操作
        patch_pos_embed = nn.functional.interpolate(
            # 将位置编码重新形状为 4 维张量,并重新排列维度顺序
            patch_pos_embed.reshape(1, int(math.sqrt(num_pos)), int(math.sqrt(num_pos)), dim).permute(0, 3, 1, 2),
            # 设置插值的比例因子,根据图像块数目和位置编码数目的关系
            scale_factor=(num_h_patches / math.sqrt(num_pos), num_w_patches / math.sqrt(num_pos)),
            mode="bicubic",  # 使用双三次插值模式
            align_corners=False,  # 不对齐角落像素
        )
        
        # 检查插值后的图像块位置编码是否与预期的图像块数目相符
        if int(num_h_patches) != patch_pos_embed.shape[-2] or int(num_w_patches) != patch_pos_embed.shape[-1]:
            raise ValueError(
                f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the "
                f"shape of position embedding ({patch_pos_embed.shape[-2], patch_pos_embed.shape[-1]})"
            )
        
        # 调整形状并排列维度以匹配模型输出的要求
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        # 将类别位置编码与插值后的图像块位置编码拼接在一起
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
    # 前向传播:接收像素值(及可选的遮罩位置、是否插值位置编码),返回加入 CLS token 与位置编码后的嵌入张量
    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        # 获取输入张量的维度信息:批大小、通道数、高度、宽度
        batch_size, num_channels, height, width = pixel_values.shape
        # 将像素值传入 patch_embeddings 方法,生成嵌入表示,并可能插入位置编码
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        # 再次获取嵌入张量的维度信息:批大小、序列长度、嵌入维度
        batch_size, seq_len, _ = embeddings.size()
        # 如果存在布尔类型的遮罩位置信息
        if bool_masked_pos is not None:
            # 创建一个与嵌入张量形状相同的 mask_tokens 张量,用于替换遮罩的视觉标记
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # 如果 bool_masked_pos 是三维的,将其展平为二维的
            if bool_masked_pos.dim() == 3:
                bool_masked_pos = bool_masked_pos.view(bool_masked_pos.size(0), -1)
            # 将 bool_masked_pos 转换为与 mask_tokens 相同类型的张量,并将遮罩应用到 embeddings
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # 将 [CLS] 标记添加到嵌入的补丁标记中
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # 如果选择插值位置编码,则对每个 token 添加位置编码
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            # 否则,直接添加预定义的位置编码
            embeddings = embeddings + self.position_embeddings

        # 应用 dropout 操作到嵌入张量
        embeddings = self.dropout(embeddings)

        # 返回嵌入张量作为输出
        return embeddings
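
# 示意:上面 interpolate_pos_encoding 的核心做法是把 (1, num_pos, dim) 的 patch 位置编码还原成
# 二维网格,再用 bicubic 插值缩放到新的 patch 网格。下面是一个独立的小例子(数值均为演示假设;
# 源码用 scale_factor 指定缩放比例,这里为简洁改用 size 指定目标网格,效果一致):
import math
import torch
import torch.nn.functional as F

dim, num_pos = 768, 196                              # 预训练时 14x14 = 196 个 patch
patch_pos_embed = torch.randn(1, num_pos, dim)
grid = int(math.sqrt(num_pos))                       # 14
new_h, new_w = 18, 18                                # 更高分辨率图像对应的 patch 网格

resized = F.interpolate(
    patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2),  # -> (1, dim, 14, 14)
    size=(new_h, new_w),
    mode="bicubic",
    align_corners=False,
)
resized = resized.permute(0, 2, 3, 1).reshape(1, -1, dim)             # -> (1, 18*18, dim)
print(resized.shape)  # torch.Size([1, 324, 768])
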
# Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class PatchEmbeddings(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        num_channels: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()
        # 如果image_size不是可迭代对象,则转换为元组
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size)
        # 如果patch_size不是可迭代对象,则转换为元组
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size)
        # 计算图像被划分成的块数
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        # 使用卷积层将输入图像的每个patch映射到embed_dim维度的特征空间
        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        # 如果不需要插值位置编码,检查输入图像尺寸是否与预期尺寸匹配
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )
        # 使用卷积层对输入图像进行特征提取,并展平和转置维度以适应后续处理
        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return x
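
# 示意:kernel_size = stride = patch_size 的卷积等价于"不重叠切块 + 线性投影"。
# 下面是一个独立的小例子(通道数、patch 大小等数值仅为演示假设):
import torch
from torch import nn

projection_demo = nn.Conv2d(3, 768, kernel_size=16, stride=16)
pixel_values_demo = torch.randn(2, 3, 224, 224)
tokens_demo = projection_demo(pixel_values_demo).flatten(2).transpose(1, 2)
print(tokens_demo.shape)  # torch.Size([2, 196, 768]),其中 196 = (224 // 16) ** 2
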


class FlavaTextEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        # 创建词嵌入,词位置嵌入和令牌类型嵌入
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # 保持变量名与TensorFlow模型的一致性,并且能够加载TensorFlow检查点文件
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 位置嵌入类型,绝对还是相对
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # 注册位置ID张量,用于序列化时导出
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        # 注册令牌类型ID张量,初始化为全零张量
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )
    # 定义一个方法,用于模型的前向传播,接受输入的张量和位置信息
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        # 获取输入张量的形状信息
        input_shape = input_ids.size()
        # 获取序列长度
        seq_length = input_shape[1]

        # 如果未提供位置信息,则使用预设的位置信息
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        # 若未传入 token_type_ids,则使用构造函数中注册的全零缓冲区;
        # 这样在对模型做 tracing 且不传 token_type_ids 时也能正常工作(参见 issue #5664)
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # 使用预设的 token_type_ids,扩展到与输入形状相同的尺寸
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                # 如果未定义 token_type_ids,则创建全零张量,设备与位置信息相同
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        # 将输入的 token IDs 转换为词嵌入向量
        inputs_embeds = self.word_embeddings(input_ids)
        # 根据 token_type_ids 获取对应的 token type embeddings
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # 将输入嵌入向量和 token type embeddings 相加得到最终的 embeddings
        embeddings = inputs_embeds + token_type_embeddings

        # 如果位置嵌入类型为 "absolute",则添加绝对位置嵌入向量
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings

        # 对 embeddings 进行 Layer Normalization 处理
        embeddings = self.LayerNorm(embeddings)
        # 对 embeddings 进行 dropout 处理,用于模型的正则化
        embeddings = self.dropout(embeddings)

        # 返回最终的 embeddings 作为前向传播的输出
        return embeddings
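
# 示意:文本嵌入 = 词嵌入 + token type 嵌入 + (绝对)位置嵌入,再做 LayerNorm 与 dropout。
# 下面是一个独立的小例子(词表大小、隐藏维度等均为演示假设):
import torch
from torch import nn

vocab_size, hidden, max_pos, type_vocab = 100, 8, 16, 2
word_emb = nn.Embedding(vocab_size, hidden, padding_idx=0)
pos_emb = nn.Embedding(max_pos, hidden)
type_emb = nn.Embedding(type_vocab, hidden)

input_ids_demo = torch.tensor([[5, 7, 9, 0]])
position_ids_demo = torch.arange(input_ids_demo.size(1)).unsqueeze(0)
token_type_ids_demo = torch.zeros_like(input_ids_demo)

embeddings_demo = word_emb(input_ids_demo) + type_emb(token_type_ids_demo) + pos_emb(position_ids_demo)
embeddings_demo = nn.LayerNorm(hidden)(embeddings_demo)
print(embeddings_demo.shape)  # torch.Size([1, 4, 8])
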
# 定义自注意力机制的模型类,继承自 nn.Module
class FlavaSelfAttention(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        # 检查隐藏层大小是否是注意力头数的整数倍,如果不是且没有嵌入大小的属性,则引发 ValueError
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        # 初始化注意力头数和每个注意力头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 初始化查询、键、值的线性层
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        # 初始化用于 dropout 的层
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    # 将输入张量转换为分数矩阵的形状
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    # 定义前向传播函数,接受隐藏状态、注意力掩码、头掩码和输出注意力标志
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # Compute mixed query layer using the query projection layer
        mixed_query_layer = self.query(hidden_states)

        # Compute key and value layers by applying projection layers to hidden states
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Compute attention scores by taking the dot product of query and key tensors
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # Scale the attention scores by the square root of the head size
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the provided attention mask to the attention scores
            attention_scores = attention_scores + attention_mask

        # Compute attention probabilities by applying softmax to the attention scores
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # Apply dropout to attention probabilities
        attention_probs = self.dropout(attention_probs)

        # Apply head mask to attention probabilities if provided
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # Compute the context layer by taking the weighted sum of value tensors
        context_layer = torch.matmul(attention_probs, value_layer)

        # Transpose and reshape the context layer to match the required output shape
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # Prepare outputs depending on whether to include attention probabilities
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        # Return the computed outputs
        return outputs
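
# 示意:transpose_for_scores 的形状变换与缩放点积注意力的整体流程。
# 下面是一个独立的小例子(批大小、头数等数值仅为演示假设):
import math
import torch

batch, seq_len, num_heads, head_size = 2, 5, 4, 16
hidden = num_heads * head_size

x = torch.randn(batch, seq_len, hidden)
# (batch, seq_len, hidden) -> (batch, num_heads, seq_len, head_size)
q = k = v = x.view(batch, seq_len, num_heads, head_size).permute(0, 2, 1, 3)

scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_size)  # (batch, heads, seq, seq)
probs = scores.softmax(dim=-1)
context = torch.matmul(probs, v)                                      # (batch, heads, seq, head_size)
context = context.permute(0, 2, 1, 3).reshape(batch, seq_len, hidden)
print(context.shape)  # torch.Size([2, 5, 64])
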
class FlavaSelfOutput(nn.Module):
    """
    The residual connection is defined in FlavaLayer (same as ViTLayer) instead of here (as is the case with other
    models), due to the layernorm applied before each block.
    """

    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        # 线性层,输入输出维度均为 config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 以概率 config.hidden_dropout_prob 进行随机失活
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 线性变换
        hidden_states = self.dense(hidden_states)
        # 随机失活
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class FlavaAttention(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        # 自注意力机制模块
        self.attention = FlavaSelfAttention(config)
        # 输出模块,包括线性变换和随机失活
        self.output = FlavaSelfOutput(config)
        # 被剪枝的注意力头部集合
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        # 找到可剪枝的注意力头部并获取索引
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # 对线性层进行剪枝
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # 更新超参数并存储已剪枝的头部
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # 执行自注意力机制,并返回输出
        self_outputs = self.attention(
            hidden_states, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions
        )

        # 将自注意力的输出传递给输出模块,得到最终的注意力输出
        attention_output = self.output(self_outputs[0], hidden_states)

        # 如果需要输出注意力权重,则将其添加到输出中
        outputs = (attention_output,) + self_outputs[1:]  # 如果输出注意力权重,则添加到输出中
        return outputs


class FlavaIntermediate(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        # 线性层,输入维度为 config.hidden_size,输出维度为 config.intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # 如果 config.hidden_act 是字符串,则使用预定义的激活函数;否则使用配置中定义的激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    # 从 transformers.models.vit.modeling_vit.ViTIntermediate.forward 复制过来的
    # 定义一个方法,用于前向传播计算
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 将输入的隐藏状态通过全连接层进行线性变换
        hidden_states = self.dense(hidden_states)
        # 对线性变换后的隐藏状态应用激活函数(通常是ReLU或类似的函数)
        hidden_states = self.intermediate_act_fn(hidden_states)

        # 返回经过线性变换和激活函数处理后的隐藏状态
        return hidden_states
class FlavaOutput(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        # 创建一个线性层,用于从中间大小映射到隐藏大小
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # 创建一个用于随机失活的层,根据隐藏失活概率
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    # 从transformers.models.vit.modeling_vit.ViTOutput.forward复制过来
    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 通过线性层传播隐藏状态
        hidden_states = self.dense(hidden_states)
        # 应用随机失活到传播后的隐藏状态
        hidden_states = self.dropout(hidden_states)

        # 将传播后的隐藏状态与输入张量相加
        hidden_states = hidden_states + input_tensor

        return hidden_states


class FlavaLayer(nn.Module):
    """这对应于timm实现中的Block类。"""

    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        # 设置用于前馈传递的块大小
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # 序列长度维度设为1
        self.seq_len_dim = 1
        # 初始化自注意力层
        self.attention = FlavaAttention(config)
        # 初始化中间层
        self.intermediate = FlavaIntermediate(config)
        # 初始化输出层
        self.output = FlavaOutput(config)

        # TODO: 检查是否可能使用fp32层归一化
        # 在隐藏大小上使用层归一化,设置epsilon为config中的层归一化epsilon
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # 与 ViT 相同,这里采用 pre-norm:层归一化在自注意力之前应用
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # 如果输出注意力权重,则添加自注意力

        # 第一个残差连接
        hidden_states = attention_output + hidden_states

        # 在ViT中,也在自注意力之后应用层归一化
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # 第二个残差连接在这里完成
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs
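
# 示意:FlavaLayer 与 ViT 一样采用 pre-norm 结构,两处残差连接的顺序如下。
# 下面用恒等映射代替自注意力,仅演示连接顺序(模块与数值均为演示假设):
import torch
from torch import nn

hidden_demo = torch.randn(1, 3, 8)
ln_before, ln_after = nn.LayerNorm(8), nn.LayerNorm(8)
attn_demo = nn.Identity()
mlp_demo = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))

hidden_demo = hidden_demo + attn_demo(ln_before(hidden_demo))  # 第一个残差连接
output_demo = hidden_demo + mlp_demo(ln_after(hidden_demo))    # 第二个残差连接
print(output_demo.shape)  # torch.Size([1, 3, 8])
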


class FlavaEncoder(nn.Module):
    def __init__(self, config: FlavaConfig) -> None:
        super().__init__()
        self.config = config
        # 创建一个由FlavaLayer组成的层列表,列表长度为config中的隐藏层数量
        self.layer = nn.ModuleList([FlavaLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
    # 定义一个前向传播方法,用于处理模型的前向推断过程
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        # 如果需要输出隐藏状态,则初始化一个空元组来存储所有隐藏状态
        all_hidden_states = () if output_hidden_states else None
        # 如果需要输出注意力权重,则初始化一个空元组来存储所有自注意力权重
        all_self_attentions = () if output_attentions else None

        # 遍历每个层次的模块
        for i, layer_module in enumerate(self.layer):
            # 如果需要输出隐藏状态,则将当前隐藏状态加入到所有隐藏状态元组中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 获取当前层的头部掩码
            layer_head_mask = head_mask[i] if head_mask is not None else None

            # 如果启用了梯度检查点且处于训练模式,则通过梯度检查点函数执行当前层的调用
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                # 否则直接调用当前层的前向传播函数
                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

            # 更新隐藏状态为当前层的输出的第一个元素
            hidden_states = layer_outputs[0]

            # 如果需要输出注意力权重,则将当前层的注意力权重加入到所有自注意力权重元组中
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # 最后一个层完成后,如果需要输出隐藏状态,则将最终的隐藏状态加入到所有隐藏状态元组中
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果不需要返回字典形式的输出,则按需返回隐藏状态、所有隐藏状态、所有自注意力权重的元组形式
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        # 否则,返回一个BaseModelOutput对象,包含最终的隐藏状态、所有隐藏状态和所有自注意力权重
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
        )
class FlavaPooler(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs):
        super().__init__()
        # 定义一个全连接层,输入和输出维度都为 config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 定义激活函数为双曲正切函数
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor):
        # 通过简单地选取第一个 token 对应的隐藏状态来"池化"模型
        first_token_tensor = hidden_states[:, 0]
        # 将选取的隐藏状态输入全连接层
        pooled_output = self.dense(first_token_tensor)
        # 将全连接层的输出应用激活函数
        pooled_output = self.activation(pooled_output)
        # 返回池化后的输出张量
        return pooled_output
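
# 示意:Pooler 只取序列第 0 个(CLS)token 的隐藏状态,经一层全连接并用 tanh 激活。
# 下面是一个独立的小例子(维度均为演示假设):
import torch
from torch import nn

hidden_states_demo = torch.randn(2, 197, 768)          # (batch, seq_len, hidden)
dense_demo, act_demo = nn.Linear(768, 768), nn.Tanh()
pooled_demo = act_demo(dense_demo(hidden_states_demo[:, 0]))
print(pooled_demo.shape)  # torch.Size([2, 768])
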


FLAVA_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`{config}`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

FLAVA_INPUTS_DOCSTRING_COMMON = r"""
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

FLAVA_IMAGE_INPUTS_DOCSTRING_BASE = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values of the input images. This tensor represents the image data with dimensions:
            - `batch_size`: Number of images in the batch.
            - `num_channels`: Number of color channels (e.g., 3 for RGB images).
            - `height`: Height of each image.
            - `width`: Width of each image.
            Pixel values can be obtained using an `AutoImageProcessor`. Refer to the documentation
            of [`FlavaImageProcessor.__call__`] for more details.

        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
            Boolean tensor indicating masked positions within each image. Each element:
            - `1`: Indicates the corresponding image patch is masked.
            - `0`: Indicates the corresponding image patch is not masked.

        interpolate_pos_encoding (`bool`, *optional*):
            Optional flag indicating whether to interpolate pre-trained position encodings. If set to `True`,
            the model will interpolate existing position encodings; if `False` or not provided, no interpolation
            will be performed.
"""

FLAVA_IMAGE_INPUTS_DOCSTRING = FLAVA_IMAGE_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON
# 将基础的图像输入文档字符串与通用输入文档字符串相结合,形成完整的图像输入文档字符串

FLAVA_TEXT_INPUTS_DOCSTRING_BASE = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)

        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)
"""
# 文本输入基础文档字符串,包含关于输入IDs和token type IDs的详细说明

FLAVA_TEXT_INPUTS_DOCSTRING = FLAVA_TEXT_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON
# 将基础的文本输入文档字符串与通用输入文档字符串相结合,形成完整的文本输入文档字符串

FLAVA_MULTIMODAL_INPUTS_DOCSTRING = (
    r"""
    Args:
        hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`):
            The concatenated hidden states of unimodal encoders.
"""
    + FLAVA_INPUTS_DOCSTRING_COMMON
)
# 多模态输入文档字符串,描述了隐藏状态的拼接表示以及通用输入信息

FLAVA_MODEL_INPUTS_DOCSTRING_BASE = r"""
    Args:
        skip_multimodal_encoder (*bool*, *optional*):
            Skip any calculations for multimodal encoder. Useful if multimodal encoding is not going to be used.
"""
# 模型输入基础文档字符串,描述了是否跳过多模态编码器的计算的可选参数

FLAVA_MODEL_INPUTS_DOCSTRING = (
    FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
    + FLAVA_TEXT_INPUTS_DOCSTRING_BASE
    + FLAVA_INPUTS_DOCSTRING_COMMON
    + FLAVA_MODEL_INPUTS_DOCSTRING_BASE
)
# 模型输入文档字符串,包含了图像、文本、和通用输入的详细说明以及模型输入基础的描述

FLAVA_PRETRAINING_INPUTS_DOCSTRING = (
    r"""
    Args:
        input_ids_masked (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary. These ones are the masked version of the original task
            to be used with MLM. Indices can be obtained using [`AutoTokenizer`] along with
            [`DataCollatorForMaskedLanguageModeling`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)

"""
    + FLAVA_TEXT_INPUTS_DOCSTRING_BASE
    + FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
    + r"""
        image_attention_mask (`torch.FloatTensor` of shape `({1})`, *optional*):
            Mask to avoid performing attention on padding token indices specifically for images. Mask values selected
            in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)

        skip_unmasked_multimodal_encoder (*bool*, *optional*):
            Skip any calculations for multimodal encoder for unmasked inputs. FLAVA pretraining doesn't need unmasked
            multimodal embeddings or outputs as of now.

        mlm_labels (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
            Labels for computing the left-to-right language and multimodal masked modeling loss (next word
            prediction). Indices should be in `[-100, 0, ..., text_config.vocab_size - 1]` (see `input_ids`
            docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the
            tokens with labels in `[0, ..., text_config.vocab_size - 1]`.

        mim_labels (`torch.LongTensor` of shape `(batch_size, image_num_patches)`, *optional*):
            Labels for computing the image and multimodal masked modeling loss. Indices should be in `[-100, 0, ...,
            image_config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only
            computed for the tokens with labels in `[0, ..., image_config.vocab_size - 1]`. If not passed, they are
            generated automatically using the image codebook assigned to the model. By default, it uses
            [`FlavaImageCodebook`]. See [`FlavaImageCodebook`] to understand how to generate `mim_labels`.

        itm_labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
            Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
            The pairs with 0 will be skipped for calculation of MMM and global contrastive losses as well.

        return_loss (`bool`, *optional*, default to None):
            Whether to return the calculated loss or not.
"""
    + FLAVA_INPUTS_DOCSTRING_COMMON
)
# 预训练输入文档字符串:除文本与图像的基础输入(含用于 MLM 的 input_ids_masked)外,
# 还包含 `image_attention_mask`、`skip_unmasked_multimodal_encoder`、`mlm_labels`、`mim_labels`、
# `itm_labels`、`return_loss` 等预训练专用参数,并在末尾附加通用输入说明

FLAVA_PRETRAINING_START_DOCSTRING_EXTRA = r"""
    Parameters:
        image_codebook ([`nn.Module`]): If passed, the image codebook will be set to this. Otherwise, it will
            be initialized using the image_codebook_config defined in the config as the first parameter.
"""


class FlavaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FlavaConfig
    base_model_prefix = "flava"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # 对于线性层和卷积层,使用正态分布初始化权重
            # 略微不同于 TF 版本,后者使用截断正态分布进行初始化
            # 参考:https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                # 如果存在偏置项,则将其初始化为零
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # 对于嵌入层,使用正态分布初始化权重
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # 如果指定了 padding_idx,则将其对应的权重初始化为零
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # 对于 LayerNorm 层,初始化偏置为零,初始化权重为全1
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@add_start_docstrings(
    "The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.",
    FLAVA_START_DOCSTRING.format(config="FlavaImageConfig"),
)
class FlavaImageModel(FlavaPreTrainedModel):
    config_class = FlavaImageConfig
    # This override allows us to load FlavaImageModel from FlavaModel/FlavaForPreTraining checkpoints.
    base_model_prefix = "flava.image_model"
    main_input_name = "pixel_values"

    def __init__(self, config: FlavaImageConfig, add_pooling_layer: bool = True):
        super().__init__(config)

        self.config = config

        # 初始化模型的各个部分
        self.embeddings = FlavaImageEmbeddings(config)
        self.encoder = FlavaEncoder(config)

        # 初始化 LayerNorm 和 Pooler(如果需要)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = FlavaPooler(config) if add_pooling_layer else None

        # 执行初始化后的操作
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings.patch_embeddings

    def set_input_embeddings(self, value: nn.Module):
        self.embeddings.patch_embeddings = value

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            # 裁剪模型中的注意力头
            self.encoder.layer[layer].attention.prune_heads(heads)
    @add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC,
        modality="vision",
        expected_output=_EXPECTED_IMAGE_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        # 是否输出注意力权重,默认为模型配置中的设置
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 设置输出隐藏状态,默认为模型配置中的输出设置
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 设置是否返回字典,默认为模型配置中的设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            # 如果未提供像素值,则抛出数值错误异常
            raise ValueError("You have to specify pixel_values")

        # 准备头部掩码(如果需要)
        # head_mask 中的 1.0 表示保留该头部
        # attention_probs 的形状为 bsz x n_heads x N x N
        # 输入的 head_mask 形状为 [num_heads] 或 [num_hidden_layers x num_heads]
        # 并且 head_mask 被转换为形状 [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # 将像素值传入嵌入层进行编码
        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        # 将编码后的数据传入编码器进行处理
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 获取编码器的序列输出
        sequence_output = encoder_outputs[0]
        # 序列输出经过 LayerNormalization 处理
        sequence_output = self.layernorm(sequence_output)
        # 如果有池化层,对序列输出进行池化
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            # 如果不返回字典,则返回元组格式的输出
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        # 如果需要返回字典格式的输出,则构造 BaseModelOutputWithPooling 对象
        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
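
# 示意用法:单独加载图像塔并提取隐藏状态(需联网下载权重;检查点名取自上文 _CHECKPOINT_FOR_DOC,
# 输入图像仅为演示用的空白图,实际效果以真实图像为准):
import torch
from PIL import Image
from transformers import AutoImageProcessor, FlavaImageModel

processor = AutoImageProcessor.from_pretrained("facebook/flava-full")
image_model = FlavaImageModel.from_pretrained("facebook/flava-full")

image = Image.new("RGB", (224, 224))
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = image_model(**inputs)
print(outputs.last_hidden_state.shape)  # 预期 [1, 197, 768],与上文 _EXPECTED_IMAGE_OUTPUT_SHAPE 一致
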
# 使用装饰器添加文档字符串,描述这是一个在顶部没有特定头部的原始隐藏状态输出的 FLAVA 文本模型转换器。
# 使用 FLAVA_START_DOCSTRING 格式化字符串,填充 FlavaTextConfig 相关信息。
@add_start_docstrings(
    "The bare FLAVA Text Model transformer outputting raw hidden-states without any specific head on top.",
    FLAVA_START_DOCSTRING.format(config="FlavaTextConfig"),
)
# 定义 FlavaTextModel 类,继承自 FlavaPreTrainedModel 类
class FlavaTextModel(FlavaPreTrainedModel):
    # 指定配置类为 FlavaTextConfig
    config_class = FlavaTextConfig
    # 该前缀允许从 FlavaModel/FlavaForPreTraining 的检查点中加载 FlavaTextModel
    base_model_prefix = "flava.text_model"

    def __init__(self, config: FlavaTextConfig, add_pooling_layer: bool = True):
        # 调用父类的构造方法,传入配置对象
        super().__init__(config)
        # 保存配置对象
        self.config = config

        # 初始化嵌入层对象
        self.embeddings = FlavaTextEmbeddings(config)
        # 初始化编码器对象
        self.encoder = FlavaEncoder(config)

        # 初始化层归一化层,使用配置中的隐藏层大小和层归一化的 epsilon 参数
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 如果设置了添加池化层,则初始化池化层对象
        self.pooler = FlavaPooler(config) if add_pooling_layer else None

        # 执行初始化后的处理
        self.post_init()

    # 获取输入嵌入的方法,返回词嵌入层对象
    def get_input_embeddings(self) -> PatchEmbeddings:
        return self.embeddings.word_embeddings

    # 设置输入嵌入的方法,设置词嵌入层对象为指定的值
    def set_input_embeddings(self, value: nn.Module):
        self.embeddings.word_embeddings = value

    # 对模型的注意力头进行修剪的方法
    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    # 使用装饰器添加模型正向传播的文档字符串,描述输入参数的含义
    @add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
    # 使用示例代码的文档字符串样本,描述模型的检查点、输出类型和配置类
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_CLASS_FOR_TEXT_MODEL_DOC,
    )
    # 模型的正向传播方法定义,接收多个输入参数,返回模型输出
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        # 如果 output_attentions 参数为 None,则使用配置中的 output_attentions 参数
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 如果 output_hidden_states 参数为 None,则使用配置中的 output_hidden_states 参数
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果 return_dict 参数为 None,则使用配置中的 use_return_dict 参数
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 如果 input_ids 为空,则抛出数值错误异常
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # 获取 input_ids 的形状
        input_shape = input_ids.size()

        # 如果 attention_mask 为空,则创建全 1 的注意力掩码,形状与 input_ids 相同
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=input_ids.device)

        # 准备头部掩码(head mask),如果需要的话
        # head_mask 中的 1.0 表示保留对应的头部
        # attention_probs 的形状为 bsz x n_heads x N x N
        # 输入的 head_mask 形状为 [num_heads] 或 [num_hidden_layers x num_heads]
        # 并且 head_mask 被转换为形状 [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        
        # 获取扩展的注意力掩码(extended_attention_mask)
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape, input_ids.device
        )

        # 通过 embeddings 模块生成嵌入输出
        embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )

        # 使用 encoder 模块处理嵌入输出,得到编码器的输出
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取序列输出(sequence_output)
        sequence_output = encoder_outputs[0]
        # 应用 layernorm 层到序列输出上
        sequence_output = self.layernorm(sequence_output)
        # 如果有池化器(pooler),则应用池化器到序列输出上,得到池化输出(pooled_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        # 如果不需要返回字典形式的结果,则返回一个元组
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        # 否则,返回一个 BaseModelOutputWithPooling 对象
        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
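
# 示意用法:单独加载文本塔(得益于 base_model_prefix,可直接复用 facebook/flava-full 的权重;
# 需联网下载,示例文本仅为演示):
import torch
from transformers import AutoTokenizer, FlavaTextModel

tokenizer = AutoTokenizer.from_pretrained("facebook/flava-full")
text_model = FlavaTextModel.from_pretrained("facebook/flava-full")

inputs = tokenizer(["a photo of a cat"], return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = text_model(**inputs)
print(outputs.last_hidden_state.shape, outputs.pooler_output.shape)
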
# 添加文档字符串描述 FLAVA Multimodal Model 类的基本信息和配置格式
@add_start_docstrings(
    "The bare FLAVA Multimodal Model transformer outputting raw hidden-states without any specific head on top.",
    FLAVA_START_DOCSTRING.format(config="FlavaMultimodalConfig"),
)
class FlavaMultimodalModel(FlavaPreTrainedModel):
    # 指定该类使用的配置类为 FlavaMultimodalConfig
    config_class = FlavaMultimodalConfig
    # 定义在加载模型时从 FlavaModel/FlavaForPreTraining 检查点中读取的基础模型前缀
    base_model_prefix = "flava.multimodal_model"
    # 主输入名称为 "hidden_states"
    main_input_name = "hidden_states"

    def __init__(self, config: FlavaMultimodalConfig, add_pooling_layer=True):
        # 调用父类构造函数,初始化模型配置
        super().__init__(config)
        self.config = config
        # 根据配置决定是否使用类别标记 (CLS token)
        self.use_cls_token = self.config.use_cls_token
        if self.use_cls_token:
            # 如果使用类别标记,则初始化一个可学习的张量作为类别标记
            self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        # 初始化编码器,使用 FlavaEncoder 类
        self.encoder = FlavaEncoder(config)

        # 初始化层归一化层,使用指定的层归一化尺寸和 epsilon 值
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 根据参数决定是否添加池化层,如果需要则使用 FlavaPooler 类初始化池化层
        self.pooler = FlavaPooler(config) if add_pooling_layer else None

        # 调用后初始化方法,用于子类中进一步初始化操作
        self.post_init()

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # 遍历需要剪枝的层和对应的注意力头信息,通过调用编码器中的注意力层进行剪枝
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(
        FLAVA_MULTIMODAL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
    )
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC,
    )
    # 定义模型的前向传播函数,接收输入的隐藏状态和可选的掩码和掩码头,返回模型输出
    def forward(
        self,
        hidden_states: torch.Tensor,  # 输入的隐藏状态张量(图像与文本隐藏状态拼接后的结果)
        attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码,用于忽略填充位置
        head_mask: Optional[torch.Tensor] = None,  # 可选的头掩码,用于屏蔽指定的注意力头
        output_attentions: Optional[bool] = None,  # 是否输出注意力权重
        output_hidden_states: Optional[bool] = None,  # 是否输出各层隐藏状态
        return_dict: Optional[bool] = None,  # 是否以 ModelOutput 字典格式返回结果
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        # 确定是否输出注意力权重
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 确定是否输出隐藏状态
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 确定是否使用返回字典格式
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 获取隐藏状态张量的维度
        batch_size, seq_length, _ = hidden_states.size()

        # 如果使用CLS token,则扩展并拼接隐藏状态
        if self.use_cls_token:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            hidden_states = torch.cat((cls_tokens, hidden_states), dim=1)
            seq_length += 1

        # 如果未提供注意力掩码,则创建全1的注意力掩码张量
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=hidden_states.device)

        # 准备头部掩码(如果需要)
        # head_mask中的1.0表示保留对应的头部
        # attention_probs的形状为bsz x n_heads x N x N
        # 输入的head_mask形状为[num_heads]或[num_hidden_layers x num_heads]
        # head_mask被转换为形状[num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        
        # 获取扩展的注意力掩码张量
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states.device
        )

        # 将隐藏状态输入编码器
        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        # 获取编码器的序列输出
        sequence_output = encoder_outputs[0]
        # 应用层归一化
        sequence_output = self.layernorm(sequence_output)
        # 如果存在池化器,则对序列输出进行池化
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        # 如果不需要返回字典,则返回一个元组
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        # 如果需要返回字典,则构造BaseModelOutputWithPooling对象返回
        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
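
# 示意:多模态编码器的输入是图像塔与文本塔的隐藏状态沿序列维拼接的结果
# (在 FlavaModel 中会先经 image_to_mm_projection / text_to_mm_projection 投影到多模态宽度)。
# 下面仅演示拼接本身,形状均为演示假设:
import torch

image_states_demo = torch.randn(2, 197, 768)   # 图像侧隐藏状态(含 CLS)
text_states_demo = torch.randn(2, 77, 768)     # 文本侧隐藏状态(含 CLS)
multimodal_input_demo = torch.cat([image_states_demo, text_states_demo], dim=1)
print(multimodal_input_demo.shape)  # torch.Size([2, 274, 768])
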
# 使用装饰器为 FLAVA 模型类添加文档字符串,描述该模型仅输出原始隐藏状态,没有额外的顶层头部。
# FLAVA_START_DOCSTRING 包含一个格式字符串,用于填充 FlavaConfig 的信息。
@add_start_docstrings(
    "The bare FLAVA Model transformer outputting raw hidden-states without any specific head on top.",
    FLAVA_START_DOCSTRING.format(config="FlavaConfig"),
)
class FlavaModel(FlavaPreTrainedModel):
    # 指定配置类为 FlavaConfig
    config_class = FlavaConfig

    def __init__(self, config: FlavaConfig):
        # 调用父类构造函数,初始化模型
        super().__init__(config)

        # 验证文本配置是否为 FlavaTextConfig 类型,否则引发 ValueError
        if not isinstance(config.text_config, FlavaTextConfig):
            raise ValueError(
                "config.text_config is expected to be of type FlavaTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        # 验证图像配置是否为 FlavaImageConfig 类型,否则引发 ValueError
        if not isinstance(config.image_config, FlavaImageConfig):
            raise ValueError(
                "config.image_config is expected to be of type FlavaImageConfig but is of type"
                f" {type(config.image_config)}."
            )

        # 验证多模态配置是否为 FlavaMultimodalConfig 类型,否则引发 ValueError
        if not isinstance(config.multimodal_config, FlavaMultimodalConfig):
            raise ValueError(
                "config.multimodal_config is expected to be of type FlavaMultimodalConfig but "
                + f"is of type {type(config.multimodal_config)}."
            )

        # 将各配置对象存储为类属性
        text_config = config.text_config
        image_config = config.image_config
        multimodal_config = config.multimodal_config

        # 初始化投影维度、文本隐藏层大小、图像隐藏层大小、多模态隐藏层大小
        self.projection_dim = config.projection_dim
        self.text_hidden_size = text_config.hidden_size
        self.image_hidden_size = image_config.hidden_size
        self.mm_hidden_size = multimodal_config.hidden_size

        # 初始化文本模型、图像模型和多模态模型
        self.text_model = FlavaTextModel(text_config)
        self.image_model = FlavaImageModel(image_config)
        self.multimodal_model = FlavaMultimodalModel(multimodal_config)

        # 初始化图像到投影空间的线性层、文本到投影空间的线性层、logit 缩放参数
        self.image_projection = nn.Linear(self.image_hidden_size, self.projection_dim)
        self.text_projection = nn.Linear(self.text_hidden_size, self.projection_dim)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # 初始化图像到多模态投影空间的线性层、文本到多模态投影空间的线性层
        self.image_to_mm_projection = nn.Linear(self.image_hidden_size, self.mm_hidden_size)
        self.text_to_mm_projection = nn.Linear(self.text_hidden_size, self.mm_hidden_size)

        # 执行初始化权重并进行最终处理
        self.post_init()

    # 使用装饰器为 get_text_features 方法添加文档字符串,描述该方法接受文本输入并返回相关特征
    @add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
    r"""
    Returns:
        text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
        applying the projection layer to the pooled output of [`FlavaTextModel`].
        
    Examples:
    
    ```
    >>> from transformers import AutoProcessor, FlavaModel
    
    >>> model = FlavaModel.from_pretrained("{0}")
    >>> processor = AutoProcessor.from_pretrained("{0}")
    
    >>> inputs = processor(
    ...     text=["a photo of a cat", "a photo of a dog"], max_length=77, padding="max_length", return_tensors="pt"
    ... )
    >>> text_features = model.get_text_features(**inputs)
    ```""".format(_CHECKPOINT_FOR_DOC)
    # 使用预训练模型生成文本特征,通过传入的参数组成输入
    text_outputs = self.text_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    
    # 从文本输出中获取池化后的特征向量(通常是最后一个隐藏状态)
    pooled_output = text_outputs[0]  # last_hidden_state
    # 将池化后的特征向量投影到最终的文本特征空间
    text_features = self.text_projection(pooled_output)
    
    # 返回文本特征向量
    return text_features

    @add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
    def get_image_features(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`FlavaImageModel`].

        Examples:

        ```
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, FlavaModel

        >>> model = FlavaModel.from_pretrained("{0}")
        >>> processor = AutoProcessor.from_pretrained("{0}")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```""".format(_CHECKPOINT_FOR_DOC)
        # 调用图像模型,传入像素值、是否遮蔽位置、注意力掩码、头掩码等参数进行推理
        image_outputs = self.image_model(
            pixel_values=pixel_values,
            bool_masked_pos=bool_masked_pos,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # Take the first element of the image model output, i.e. the last hidden state
        # (named pooled_output here, although no pooling has been applied yet)
        pooled_output = image_outputs[0]  # last_hidden_state
        # 将汇总输出应用于图像投影层,生成图像特征向量
        image_features = self.image_projection(pooled_output)

        # 返回图像特征向量作为模型前向传播的结果
        return image_features

    @add_start_docstrings_to_model_forward(
        FLAVA_MODEL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
    )
    @replace_return_docstrings(output_type=FlavaModelOutput, config_class=FlavaConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
        skip_multimodal_encoder: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: bool = True,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FlavaModelOutput]:


class FlavaImageCodebookResPath(nn.Module):
    def __init__(self, in_size: int, out_size: int, **kwargs):
        super().__init__()
        hid_size = out_size // 4

        # 定义一个有序字典,用于存储网络的层次结构
        path = OrderedDict()
        path["relu_1"] = nn.ReLU()  # 第一个 ReLU 激活函数
        path["conv_1"] = nn.Conv2d(in_size, hid_size, kernel_size=3, padding=1)  # 第一个卷积层
        path["relu_2"] = nn.ReLU()  # 第二个 ReLU 激活函数
        path["conv_2"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)  # 第二个卷积层
        path["relu_3"] = nn.ReLU()  # 第三个 ReLU 激活函数
        path["conv_3"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)  # 第三个卷积层
        path["relu_4"] = nn.ReLU()  # 第四个 ReLU 激活函数
        path["conv_4"] = nn.Conv2d(hid_size, out_size, kernel_size=1, padding=0)  # 第四个卷积层(输出层)

        # 使用有序字典定义的层次结构创建一个顺序容器
        self.path = nn.Sequential(path)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.path(x)


class FlavaImageCodebookBlock(nn.Module):
    def __init__(self, in_size: int, out_size: int, num_layers: int, **kwargs):
        super().__init__()

        # 计算后增益,用于乘以残差路径的输出
        self.post_gain = 1 / (num_layers**2)

        # 如果输入尺寸不等于输出尺寸,使用 1x1 卷积进行维度匹配
        if in_size != out_size:
            self.id_path = nn.Conv2d(in_size, out_size, kernel_size=1, padding=0)
        else:
            self.id_path = nn.Identity()  # 若输入输出尺寸相同,则使用恒等映射

        # 创建 FLAVA 图像编码块的残差路径
        self.res_path = FlavaImageCodebookResPath(in_size, out_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 返回恒等映射加上后增益乘以残差路径的输出
        return self.id_path(x) + self.post_gain * self.res_path(x)


class FlavaImageCodebookLayerGroup(nn.Module):
    def __init__(self, num_blocks: int, num_layers: int, in_size: int, out_size: int, use_pool: bool = True):
        super().__init__()
        blocks = OrderedDict()
        
        # 创建多个 FLAVA 图像编码块的组合
        for i in range(num_blocks):
            if i == 0:
                blocks[f"block_{i+1}"] = FlavaImageCodebookBlock(in_size, out_size, num_layers)
            else:
                blocks[f"block_{i+1}"] = FlavaImageCodebookBlock(out_size, out_size, num_layers)

        # 如果指定使用池化层,则添加最大池化层到块组中
        if use_pool:
            blocks["pool"] = nn.MaxPool2d(kernel_size=2)

        # 创建顺序容器,包含所有创建的块组
        self.group = nn.Sequential(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.group(x)
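
To make the codebook shapes concrete, the short sketch below (hypothetical sizes, randomly initialized weights) pushes a small feature map through one `FlavaImageCodebookBlock` and one `FlavaImageCodebookLayerGroup`: the block changes the channel count while keeping the spatial size, and a group built with `use_pool=True` additionally halves the spatial resolution with its trailing `MaxPool2d(2)`.

```
import torch

block = FlavaImageCodebookBlock(in_size=16, out_size=32, num_layers=8)
group = FlavaImageCodebookLayerGroup(num_blocks=2, num_layers=8, in_size=16, out_size=32, use_pool=True)

x = torch.randn(1, 16, 28, 28)
print(block(x).shape)  # torch.Size([1, 32, 28, 28]) -- channels change, spatial size preserved
print(group(x).shape)  # torch.Size([1, 32, 14, 14]) -- the trailing MaxPool2d(2) halves height and width
```
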


# FLAVA image codebook model, inspired by DALL-E's original encoder. It outputs raw hidden states that can be used to
# generate image tokens for an image based on DALL-E's vocabulary, which in turn serve as labels for MIM.
# Use `get_codebook_indices` to get the image tokens for an image.
@add_start_docstrings(
    """
    FLAVA's image codebook model inspired from DALL-E's original encoder. Outputs raw hidden states and can be used to
    generate image tokens for an image based on DALL-E's vocab. Used to generate labels for MIM. Use
    `get_codebook_indices` to get image tokens for an image.
    """,
    FLAVA_START_DOCSTRING.format(config="FlavaImageCodebookConfig"),
)
class FlavaImageCodebook(FlavaPreTrainedModel):
    base_model_prefix = ""
    config_class = FlavaImageCodebookConfig
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False

    def __init__(
        self,
        config: FlavaImageCodebookConfig,
        **kwargs: Any,
    ):
        super().__init__(config)  # 调用父类构造函数,初始化模型配置

        self.config = config  # 将配置信息存储到对象属性中
        self.num_groups = config.num_groups  # 设置组数
        self.input_channels = config.input_channels  # 设置输入通道数
        self.num_blocks_per_group = config.num_blocks_per_group  # 设置每组中的块数
        self.hidden_size = config.hidden_size  # 设置隐藏层大小
        self.vocab_size = config.vocab_size  # 设置词汇表大小

        num_layers = self.num_groups * self.num_blocks_per_group  # 计算总层数

        output_blocks = OrderedDict()
        output_blocks["relu"] = nn.ReLU()  # 添加ReLU激活函数到输出块
        output_blocks["conv"] = nn.Conv2d(8 * self.hidden_size, self.vocab_size, kernel_size=1, padding=0)  # 添加卷积层到输出块

        blocks = OrderedDict()
        blocks["input"] = nn.Conv2d(self.input_channels, 1 * self.hidden_size, kernel_size=7, padding=3)  # 添加输入卷积层到块
        blocks["group_1"] = FlavaImageCodebookLayerGroup(
            self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 1 * self.hidden_size
        )  # 添加第一个图像码书层组
        blocks["group_2"] = FlavaImageCodebookLayerGroup(
            self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 2 * self.hidden_size
        )  # 添加第二个图像码书层组
        blocks["group_3"] = FlavaImageCodebookLayerGroup(
            self.num_blocks_per_group, num_layers, 2 * self.hidden_size, 4 * self.hidden_size
        )  # 添加第三个图像码书层组
        blocks["group_4"] = FlavaImageCodebookLayerGroup(
            self.num_blocks_per_group, num_layers, 4 * self.hidden_size, 8 * self.hidden_size, use_pool=False
        )  # 添加第四个图像码书层组,并指定不使用池化操作
        blocks["output"] = nn.Sequential(output_blocks)  # 添加输出块到块序列

        self.blocks = nn.Sequential(blocks)  # 构建模型的块序列

        self.post_init()  # 执行后初始化操作

        if self.config.freeze:
            for param in self.parameters():
                param.requires_grad = False  # 如果配置要求冻结模型,则设置所有参数不需要梯度计算

    def get_codebook_indices(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
                `return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.

        Examples:
        ```
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoImageProcessor, FlavaImageCodebook

        >>> model = FlavaImageCodebook.from_pretrained("{0}")
        >>> image_processor = AutoImageProcessor.from_pretrained("{0}")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
        >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)

        >>> outputs = model.get_codebook_indices(**inputs)
        ```
        """.format(_CHECKPOINT_FOR_CODEBOOK_DOC)
        z_logits = self.blocks(pixel_values)  # run the pixel values through the block stack to get logits
        return torch.argmax(z_logits, axis=1)  # argmax over the vocabulary dimension gives the codebook indices

    def get_codebook_probs(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Run the pixel values through the block stack to get logits over the codebook vocabulary
        z_logits = self.blocks(pixel_values)
        # Apply a softmax over the vocabulary dimension to obtain a probability distribution
        return nn.Softmax(dim=1)(z_logits)

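Putting the pieces together, `get_codebook_indices` and `get_codebook_probs` only differ in how they reduce the `(batch, vocab_size, h, w)` logits produced by the block stack: argmax yields discrete token ids used as MIM labels, softmax yields per-position distributions. The sketch below uses a deliberately tiny, hypothetical configuration with random weights just to trace shapes; the released checkpoint is much larger (`hidden_size=256`, `vocab_size=8192`), and the FLAVA processor produces 112x112 codebook pixels by default.

```
import torch
from transformers import FlavaImageCodebookConfig

config = FlavaImageCodebookConfig(num_groups=4, num_blocks_per_group=1, hidden_size=8, vocab_size=32)
codebook = FlavaImageCodebook(config)

pixels = torch.randn(1, 3, 112, 112)
indices = codebook.get_codebook_indices(pixels)
probs = codebook.get_codebook_probs(pixels)
print(indices.shape)  # torch.Size([1, 14, 14])     -- three MaxPool2d(2) stages: 112 / 2**3
print(probs.shape)    # torch.Size([1, 32, 14, 14]) -- one probability per codebook entry and position
```
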
# 定义一个类,用于处理 Flava 模型的预测头部变换
class FlavaPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 创建一个全连接层,输入和输出维度都是 config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 根据配置选择激活函数,存储在 transform_act_fn 中
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        # 创建一个 LayerNorm 层,归一化隐藏状态的维度为 config.hidden_size
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    # 前向传播函数,对输入的隐藏状态进行变换操作
    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)  # 全连接层变换
        hidden_states = self.transform_act_fn(hidden_states)  # 应用激活函数
        hidden_states = self.LayerNorm(hidden_states)  # LayerNorm 归一化
        return hidden_states


# 定义一个类,用于 Flava 模型的蒙版预测头部
class FlavaMaskedPredictionHead(nn.Module):
    def __init__(self, config, weight=None):
        super().__init__()
        self.config = config
        # 创建 FlavaPredictionHeadTransform 实例,用于变换隐藏状态
        self.transform = FlavaPredictionHeadTransform(config)
        # 创建一个线性层,将隐藏状态映射到词汇表大小,无偏置
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))  # 创建一个偏置参数
        if weight is not None:
            self.decoder.weight = weight

        # 将 decoder 的偏置参数与 self.bias 关联,以便在 resize_token_embeddings 时正确调整大小
        self.decoder.bias = self.bias

    # 前向传播函数,对输入进行预测头部的操作
    def forward(self, x):
        x = self.transform(x)  # 应用变换操作
        x = self.decoder(x)  # 使用线性层进行预测
        return x
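
The optional `weight` argument lets the decoder share its weight matrix with the word embeddings (the standard MLM weight-tying trick), while the bias remains a separate head-owned parameter so it can be resized together with the vocabulary. A small sketch with made-up sizes:

```
import torch
from transformers import FlavaTextConfig

config = FlavaTextConfig(hidden_size=32, vocab_size=100)
word_embeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size)

# Tie the prediction head's decoder to the embedding matrix.
head = FlavaMaskedPredictionHead(config, weight=word_embeddings.weight)
print(head.decoder.weight is word_embeddings.weight)  # True -- same underlying parameter

hidden = torch.randn(2, 5, config.hidden_size)
print(head(hidden).shape)  # torch.Size([2, 5, 100]) -- logits over the vocabulary
```
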


# 定义一个类,用于 Flava 模型的 ITM 头部
class FlavaITMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # 创建 FlavaPooler 实例,用于池化隐藏状态
        self.pooler = FlavaPooler(config)
        # 创建一个线性层,将池化后的隐藏状态映射到 2 个输出类别(用于 ITM 任务)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    # 前向传播函数,对输入进行 ITM 头部的操作
    def forward(self, x):
        x = self.pooler(x)  # 应用池化操作
        x = self.seq_relationship(x)  # 使用线性层进行序列关系预测
        return x


# 定义一个类,用于 Flava 模型的全局对比头部
class FlavaGlobalContrastiveHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # 存储全局背景传播对比的配置
        self.global_backprop_contrastive = config.global_backprop_contrastive
    # 定义一个前向传播函数,接收图片嵌入、文本嵌入和logit缩放因子作为输入参数
    def forward(self, image_embeddings, text_embeddings, logit_scale):
        # 计算温度参数,使用logit缩放因子的指数值
        temperature = torch.exp(logit_scale)
        
        # 检查当前环境是否支持并初始化了分布式训练
        if not torch.distributed.is_available() or not torch.distributed.is_initialized():
            # 如果未启用分布式训练,则生成标签张量,从0到图片嵌入的数量
            labels = torch.arange(image_embeddings.size(0), device=image_embeddings.device)
            # 将当前批次的图片嵌入数据存储在列表中
            image_embeddings_all = [image_embeddings]
            # 将当前批次的文本嵌入数据存储在列表中
            text_embeddings_all = [text_embeddings]
        else:
            # 获取当前批次的本地大小
            local_batch_size = image_embeddings.size(0)
            # 获取分布式训练的总节点数
            world_size = torch.distributed.get_world_size()

            if self.global_backprop_contrastive:
                # 如果启用全局反向传播对比,则使用分布式函数收集所有工作节点的图片和文本嵌入
                image_embeddings_all = torch.distributed.nn.functional.all_gather(image_embeddings)
                text_embeddings_all = torch.distributed.nn.functional.all_gather(text_embeddings)
            else:
                # 如果未启用全局反向传播对比,则为每个工作节点创建一个零张量列表
                image_embeddings_all = [torch.zeros_like(text_embeddings) for _ in range(world_size)]
                text_embeddings_all = [torch.zeros_like(image_embeddings) for _ in range(world_size)]
                # 使用分布式函数收集所有工作节点的图片嵌入数据
                torch.distributed.all_gather(image_embeddings_all, image_embeddings)
                # 使用分布式函数收集所有工作节点的文本嵌入数据
                torch.distributed.all_gather(text_embeddings_all, text_embeddings)

            # 为每个本地批次生成对应的标签,考虑当前节点的排名和本地批次大小
            labels = local_batch_size * torch.distributed.get_rank() + torch.arange(
                local_batch_size, device=image_embeddings.device
            )

        # 将收集到的所有图片嵌入数据拼接成一个张量
        image_embeddings_all = torch.cat(image_embeddings_all)
        # 将收集到的所有文本嵌入数据拼接成一个张量
        text_embeddings_all = torch.cat(text_embeddings_all)

        # 计算图片嵌入与所有文本嵌入的点积,并乘以温度参数
        logits_per_image = torch.matmul(image_embeddings, text_embeddings_all.transpose(0, 1)) * temperature
        # 计算文本嵌入与所有图片嵌入的点积,并乘以温度参数
        logits_per_text = torch.matmul(text_embeddings, image_embeddings_all.transpose(0, 1)) * temperature

        # 返回计算得到的图片logits、文本logits以及相应的标签
        return logits_per_image, logits_per_text, labels
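
In the non-distributed branch this head reduces to the familiar CLIP-style contrastive setup: every image in the batch is scored against every text, the diagonal entries are the positives, and a symmetric cross-entropy can be taken over rows and columns. A standalone sketch of that case (random, pre-normalized embeddings and a made-up batch size; `2.6592` is the usual `log(1/0.07)`-style initial logit scale):

```
import torch
import torch.nn.functional as F

batch, dim = 4, 8
image_embeddings = F.normalize(torch.randn(batch, dim), dim=-1)
text_embeddings = F.normalize(torch.randn(batch, dim), dim=-1)
logit_scale = torch.tensor(2.6592)

temperature = torch.exp(logit_scale)
logits_per_image = image_embeddings @ text_embeddings.t() * temperature  # (batch, batch)
labels = torch.arange(batch)                                             # positives on the diagonal

loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_image.t(), labels)) / 2
print(logits_per_image.shape, round(loss.item(), 3))
```
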
# 使用装饰器为 FLAVA 预训练模型添加文档字符串,描述模型输出损失、嵌入、logits 和变换器输出。
@add_start_docstrings(
    """
    The FLAVA model for pretraining which outputs losses, embeddings, logits and transformer outputs.
    """,
    FLAVA_START_DOCSTRING.format(config="FlavaConfig") + FLAVA_PRETRAINING_START_DOCSTRING_EXTRA,
)
class FlavaForPreTraining(FlavaPreTrainedModel):
    # 这些键与 xxx.bias 相关联
    _tied_weights_keys = [
        "mmm_text_head.decoder.bias",
        "mmm_image_head.decoder.bias",
        "mlm_head.decoder.bias",
        "mim_head.decoder.bias",
    ]

    def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None):
        # 调用父类构造函数初始化模型
        super().__init__(config)
        # 创建 FLAVA 模型
        self.flava = FlavaModel(config)

        # 设置图像码书,如果未提供且配置指定则初始化图像码书
        self.image_codebook = image_codebook
        if self.image_codebook is None and config.init_codebook:
            self.image_codebook = FlavaImageCodebook(config.image_codebook_config)

        # 根据文本和图像编码器配置创建遮蔽头,以确保有正确的词汇表
        self.mim_head = FlavaMaskedPredictionHead(config.image_config)
        self.mlm_head = FlavaMaskedPredictionHead(config.text_config)
        self.itm_head = FlavaITMHead(config)
        self.mmm_image_head = FlavaMaskedPredictionHead(config.image_config)
        self.mmm_text_head = FlavaMaskedPredictionHead(config.text_config)
        self.global_contrastive_head = FlavaGlobalContrastiveHead(config)

        # 设置图像和文本词汇表大小
        self.image_vocab_size = config.image_config.vocab_size
        self.text_vocab_size = config.text_config.vocab_size
        # 设置 MLM、MIM、全局对比损失权重
        self.mlm_weight = config.mlm_weight
        self.mim_weight = config.mim_weight
        self.global_contrastive_weight = config.global_contrastive_weight
        # 设置交叉熵忽略索引和 ITM 权重
        self.ce_ignore_index = config.ce_ignore_index
        self.itm_weight = config.itm_weight
        # 设置 MMM 图像和文本权重
        self.mmm_image_weight = config.mmm_image_weight
        self.mmm_text_weight = config.mmm_text_weight
        # 设置是否跳过未遮蔽的多模态编码器
        self.skip_unmasked_multimodal_encoder = config.skip_unmasked_multimodal_encoder

        # 执行初始化后操作
        self.post_init()

    def _resize_to_2d(self, x: torch.Tensor):
        # 如果输入张量维度大于 2,则展平为二维张量
        if x.dim() > 2:
            x = x.view(x.size(0), -1)
        return x

    @add_start_docstrings_to_model_forward(
        # 添加模型 forward 方法的输入文档字符串,描述输入参数的形状
        FLAVA_PRETRAINING_INPUTS_DOCSTRING.format("batch_size, text_seq_len", "batch_size, image_num_patches")
    )
    @replace_return_docstrings(output_type=FlavaForPreTrainingOutput, config_class=FlavaConfig)
    # Forward pass of the pretraining model; it accepts many inputs and all of them are optional
    def forward(
        self,
        # 输入的token IDs序列,用于文本输入
        input_ids: Optional[torch.LongTensor] = None,
        # 掩码后的输入token IDs序列,用于MLM任务
        input_ids_masked: Optional[torch.LongTensor] = None,
        # 图像的像素值,用于图像输入
        pixel_values: Optional[torch.FloatTensor] = None,
        # 用于编码图像的码本像素值
        codebook_pixel_values: Optional[torch.FloatTensor] = None,
        # 注意力掩码,用于指示哪些token是padding的
        attention_mask: Optional[torch.Tensor] = None,
        # token类型IDs,用于BERT类型模型
        token_type_ids: Optional[torch.Tensor] = None,
        # 布尔掩码,指示哪些位置是被掩盖的
        bool_masked_pos: Optional[torch.Tensor] = None,
        # 位置IDs,用于指定token的位置信息
        position_ids: Optional[torch.LongTensor] = None,
        # 图像注意力掩码,指示图像中哪些部分需要注意力
        image_attention_mask: Optional[torch.Tensor] = None,
        # 是否跳过未掩盖的多模态编码器
        skip_unmasked_multimodal_encoder: bool = None,
        # MLM任务的标签
        mlm_labels: Optional[torch.Tensor] = None,
        # MIM任务的标签
        mim_labels: Optional[torch.Tensor] = None,
        # ITM任务的标签
        itm_labels: Optional[torch.Tensor] = None,
        # 是否输出注意力权重
        output_attentions: Optional[bool] = None,
        # 是否输出隐藏状态
        output_hidden_states: bool = True,
        # 是否返回字典形式的输出
        return_dict: Optional[bool] = None,
        # 是否返回损失值
        return_loss: Optional[bool] = None,

.\models\flava\processing_flava.py

# 设置文件编码为UTF-8
# 版权声明
# 根据Apache License, Version 2.0 (许可证)的规定,除非符合许可证的条款,否则不得使用此文件
# 可以在以下网址获取许可证的副本
# http://www.apache.org/licenses/LICENSE-2.0
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是基于“按原样”分发,
# 没有任何明示或暗示的保证或条件。请参阅许可证获取特定语言的权限和限制

"""
Image/Text processor class for FLAVA
"""

# 导入模块
import warnings
from typing import List, Optional, Union

# 导入自定义模块
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType

# 定义FlavaProcessor类,并继承ProcessorMixin
class FlavaProcessor(ProcessorMixin):
    r"""
    Constructs a FLAVA processor which wraps a FLAVA image processor and a FLAVA tokenizer into a single processor.

    [`FlavaProcessor`] offers all the functionalities of [`FlavaImageProcessor`] and [`BertTokenizerFast`]. See the
    [`~FlavaProcessor.__call__`] and [`~FlavaProcessor.decode`] for more information.

    Args:
        image_processor ([`FlavaImageProcessor`], *optional*): The image processor is a required input.
        tokenizer ([`BertTokenizerFast`], *optional*): The tokenizer is a required input.
    """

    # 定义类属性
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "FlavaImageProcessor"
    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")

    # 定义初始化方法
    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        feature_extractor = None
        # 如果kwargs中包含feature_extractor,则发出警告,此参数将在v5中被弃用
        if "feature_extractor" in kwargs:
            warnings.warn(
                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
                " instead.",
                FutureWarning,
            )
            feature_extractor = kwargs.pop("feature_extractor")

        # 如果image_processor未传入,则使用feature_extractor
        image_processor = image_processor if image_processor is not None else feature_extractor
        # 如果image_processor仍然为None,则抛出数值错误
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        # 如果tokenizer未传入,则抛出数值错误
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        # 调用父类初始化方法,传入image_processor和tokenizer
        super().__init__(image_processor, tokenizer)
        # 初始化current_processor属性为image_processor
        self.current_processor = self.image_processor
    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_image_mask: Optional[bool] = None,
        return_codebook_pixels: Optional[bool] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ):
        """
        This method uses [`FlavaImageProcessor.__call__`] method to prepare image(s) for the model, and
        [`BertTokenizerFast.__call__`] to prepare text for the model.

        Please refer to the docstring of the above two methods for more information.
        """

        # 检查是否同时未指定文本和图像,若是,则抛出数值错误异常
        if text is None and images is None:
            raise ValueError("You have to specify either text or images. Both cannot be none.")

        # 如果存在文本,则使用 tokenizer 处理文本数据
        if text is not None:
            encoding = self.tokenizer(
                text=text,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                pad_to_multiple_of=pad_to_multiple_of,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
                return_tensors=return_tensors,
                **kwargs,
            )
        
        # 如果存在图像,则使用 image_processor 处理图像数据
        if images is not None:
            image_features = self.image_processor(
                images,
                return_image_mask=return_image_mask,
                return_codebook_pixels=return_codebook_pixels,
                return_tensors=return_tensors,
                **kwargs,
            )

        # 如果同时存在文本和图像,则将图像特征更新到文本编码中并返回结果
        if text is not None and images is not None:
            encoding.update(image_features)
            return encoding
        # 如果仅存在文本,则直接返回文本编码结果
        elif text is not None:
            return encoding
        # 如果仅存在图像,则创建一个 BatchEncoding 对象返回图像特征
        else:
            return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
    # 将所有参数转发给 BertTokenizerFast 的 `~PreTrainedTokenizer.batch_decode` 方法,并返回结果
    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    # 将所有参数转发给 BertTokenizerFast 的 `~PreTrainedTokenizer.decode` 方法,并返回结果
    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    # 返回模型输入的名称列表,这里使用了去重操作
    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    # 返回特征提取器的类别,已被标记为废弃,建议使用 `image_processor_class` 替代
    @property
    def feature_extractor_class(self):
        warnings.warn(
            "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
            FutureWarning,
        )
        return self.image_processor_class

    # 返回特征提取器,已被标记为废弃,建议使用 `image_processor` 替代
    @property
    def feature_extractor(self):
        warnings.warn(
            "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
            FutureWarning,
        )
        return self.image_processor
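
A typical end-to-end use of the processor feeds both modalities at once and asks for the extra MIM inputs. The checkpoint name and the printed key list below are illustrative; the exact keys depend on the tokenizer and image-processor settings.

```
import requests
from PIL import Image
from transformers import FlavaProcessor

processor = FlavaProcessor.from_pretrained("facebook/flava-full")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(
    text=["a photo of two cats"],
    images=image,
    return_codebook_pixels=True,   # extra pixel tensor for the image codebook (MIM labels)
    return_image_mask=True,        # bool_masked_pos for masked image modeling
    padding="max_length",
    max_length=77,
    return_tensors="pt",
)
print(sorted(inputs.keys()))
# e.g. ['attention_mask', 'bool_masked_pos', 'codebook_pixel_values', 'input_ids', 'pixel_values', 'token_type_ids']
```
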

.\models\flava\__init__.py

# 版权声明及导入必要的类型检查
# Meta Platforms 作者和 The HuggingFace Team 版权声明
# Apache License, Version 2.0 版权许可,可以在指定条件下使用此文件
# 如果未按许可条件使用,可能会出现限制和法律责任
from typing import TYPE_CHECKING

# 导入异常处理相关依赖
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available

# 定义导入结构,包含需要导入的模块和类
_import_structure = {
    "configuration_flava": [
        "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "FlavaConfig",
        "FlavaImageCodebookConfig",
        "FlavaImageConfig",
        "FlavaMultimodalConfig",
        "FlavaTextConfig",
    ],
}

# 检查是否存在视觉处理相关的依赖,若不存在则引发异常
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 添加视觉特征提取相关的导入结构
    _import_structure["feature_extraction_flava"] = ["FlavaFeatureExtractor"]
    _import_structure["image_processing_flava"] = ["FlavaImageProcessor"]
    _import_structure["processing_flava"] = ["FlavaProcessor"]

# 检查是否存在 Torch 相关的依赖,若不存在则引发异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 添加模型相关的导入结构
    _import_structure["modeling_flava"] = [
        "FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
        "FlavaForPreTraining",
        "FlavaImageCodebook",
        "FlavaImageModel",
        "FlavaModel",
        "FlavaMultimodalModel",
        "FlavaPreTrainedModel",
        "FlavaTextModel",
    ]

# 如果是类型检查阶段,则导入具体模块
if TYPE_CHECKING:
    from .configuration_flava import (
        FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        FlavaConfig,
        FlavaImageCodebookConfig,
        FlavaImageConfig,
        FlavaMultimodalConfig,
        FlavaTextConfig,
    )

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .feature_extraction_flava import FlavaFeatureExtractor
        from .image_processing_flava import FlavaImageProcessor
        from .processing_flava import FlavaProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flava import (
            FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
            FlavaForPreTraining,
            FlavaImageCodebook,
            FlavaImageModel,
            FlavaModel,
            FlavaMultimodalModel,
            FlavaPreTrainedModel,
            FlavaTextModel,
        )

else:
    # 如果不是类型检查阶段,则直接导入 sys 模块
    import sys
    # 将当前模块注册到 sys.modules 中,使用 LazyModule 进行延迟加载
    sys.modules[__name__] = _LazyModule(
        __name__,  # 当前模块的名称
        globals()["__file__"],  # 当前模块的文件路径
        _import_structure,  # 导入结构的定义
        module_spec=__spec__  # 当前模块的规范对象
    )

.\models\fnet\configuration_fnet.py

"""
FNet model configuration

This module defines the configuration for the FNet model, specifying how to instantiate
different variants of the model architecture based on provided arguments.

"""

# Import necessary modules from Hugging Face library
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# Get logger instance for logging messages related to this module
logger = logging.get_logger(__name__)

# Map of pretrained FNet model configurations with their respective URLs
FNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/fnet-base": "https://huggingface.co/google/fnet-base/resolve/main/config.json",
    "google/fnet-large": "https://huggingface.co/google/fnet-large/resolve/main/config.json",
    # See all FNet models at https://huggingface.co/models?filter=fnet
}

class FNetConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`FNetModel`]. It is used to instantiate an FNet
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the FNet
    [google/fnet-base](https://huggingface.co/google/fnet-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
"""
    # 模型类型设定为 "fnet"
    model_type = "fnet"
    # 初始化函数,用于创建一个新的对象实例
    def __init__(
        self,
        vocab_size=32000,  # 设置词汇表大小,默认为32000
        hidden_size=768,  # 设置隐藏层大小,默认为768
        num_hidden_layers=12,  # 设置隐藏层数,默认为12
        intermediate_size=3072,  # 设置中间层大小,默认为3072
        hidden_act="gelu_new",  # 设置隐藏层激活函数,默认为"gelu_new"
        hidden_dropout_prob=0.1,  # 设置隐藏层dropout概率,默认为0.1
        max_position_embeddings=512,  # 设置最大位置嵌入大小,默认为512
        type_vocab_size=4,  # 设置类型词汇表大小,默认为4
        initializer_range=0.02,  # 设置初始化范围,默认为0.02
        layer_norm_eps=1e-12,  # 设置层归一化的epsilon值,默认为1e-12
        use_tpu_fourier_optimizations=False,  # 是否使用TPU Fourier优化,默认为False
        tpu_short_seq_length=512,  # TPU短序列长度,默认为512
        pad_token_id=3,  # PAD标记的token id,默认为3
        bos_token_id=1,  # 开始序列标记的token id,默认为1
        eos_token_id=2,  # 结束序列标记的token id,默认为2
        **kwargs,
    ):
        # 调用父类的初始化方法,传入PAD、BOS、EOS标记的token id以及其他关键字参数
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
    
        # 初始化对象的各个属性,用传入的参数或者默认值
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.type_vocab_size = type_vocab_size
        self.layer_norm_eps = layer_norm_eps
        self.use_tpu_fourier_optimizations = use_tpu_fourier_optimizations
        self.tpu_short_seq_length = tpu_short_seq_length
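
As a quick illustration of how this configuration is consumed, the sketch below builds the default (fnet-base-like) configuration and a smaller hypothetical variant, then instantiates a randomly initialized `FNetModel` from the latter:

```
from transformers import FNetConfig, FNetModel

config = FNetConfig()  # defaults mirror google/fnet-base (12 layers, hidden size 768)

small_config = FNetConfig(num_hidden_layers=4, hidden_size=256, intermediate_size=1024)
model = FNetModel(small_config)  # randomly initialized, no pretrained weights
print(sum(p.numel() for p in model.parameters()))
```
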

.\models\fnet\convert_fnet_original_flax_checkpoint_to_pytorch.py

# Import the required libraries and modules
import argparse  # 用于解析命令行参数

import torch  # 导入PyTorch库
from flax.training.checkpoints import restore_checkpoint  # 从Flax库中导入恢复检查点的函数

from transformers import FNetConfig, FNetForPreTraining  # 导入FNet模型的配置和预训练类
from transformers.utils import logging  # 导入日志记录工具

# 设置日志输出级别为信息
logging.set_verbosity_info()

def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, fnet_config_file, save_path):
    # 使用FNetConfig类从JSON文件加载FNet的配置
    config = FNetConfig.from_json_file(fnet_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    # 根据配置初始化FNetForPreTraining模型
    fnet_pretraining_model = FNetForPreTraining(config)

    # 从Flax的检查点中恢复模型参数
    checkpoint_dict = restore_checkpoint(flax_checkpoint_path, None)
    pretrained_model_params = checkpoint_dict["target"]

    # 初始化新的状态字典,用于存储PyTorch模型的参数
    state_dict = fnet_pretraining_model.state_dict()

    # 处理嵌入层的参数转换

    # 位置编码
    position_ids = state_dict["fnet.embeddings.position_ids"]
    new_state_dict = {"fnet.embeddings.position_ids": position_ids}

    # 单词嵌入
    new_state_dict["fnet.embeddings.word_embeddings.weight"] = torch.tensor(
        pretrained_model_params["encoder"]["embedder"]["word"]["embedding"]
    )

    # 位置嵌入
    new_state_dict["fnet.embeddings.position_embeddings.weight"] = torch.tensor(
        pretrained_model_params["encoder"]["embedder"]["position"]["embedding"][0]
    )

    # 类型嵌入
    new_state_dict["fnet.embeddings.token_type_embeddings.weight"] = torch.tensor(
        pretrained_model_params["encoder"]["embedder"]["type"]["embedding"]
    )

    # 投影层的权重和偏置
    new_state_dict["fnet.embeddings.projection.weight"] = torch.tensor(
        pretrained_model_params["encoder"]["embedder"]["hidden_mapping_in"]["kernel"]
    ).T
    new_state_dict["fnet.embeddings.projection.bias"] = torch.tensor(
        pretrained_model_params["encoder"]["embedder"]["hidden_mapping_in"]["bias"]
    )

    # LayerNorm层的权重和偏置
    new_state_dict["fnet.embeddings.LayerNorm.weight"] = torch.tensor(
        pretrained_model_params["encoder"]["embedder"]["layer_norm"]["scale"]
    )
    new_state_dict["fnet.embeddings.LayerNorm.bias"] = torch.tensor(
        pretrained_model_params["encoder"]["embedder"]["layer_norm"]["bias"]
    )

    # 处理编码器层的参数转换
    # Loop over every hidden layer, loading its weights and biases from the pretrained parameters
    for layer in range(config.num_hidden_layers):
        # Load the LayerNorm weight of the current layer's Fourier output
        new_state_dict[f"fnet.encoder.layer.{layer}.fourier.output.LayerNorm.weight"] = torch.tensor(
            pretrained_model_params["encoder"][f"encoder_{layer}"]["mixing_layer_norm"]["scale"]
        )
        # Load the LayerNorm bias of the current layer's Fourier output
        new_state_dict[f"fnet.encoder.layer.{layer}.fourier.output.LayerNorm.bias"] = torch.tensor(
            pretrained_model_params["encoder"][f"encoder_{layer}"]["mixing_layer_norm"]["bias"]
        )

        # Load the current layer's intermediate dense weight, transposed
        new_state_dict[f"fnet.encoder.layer.{layer}.intermediate.dense.weight"] = torch.tensor(
            pretrained_model_params["encoder"][f"feed_forward_{layer}"]["intermediate"]["kernel"]
        ).T
        # Load the current layer's intermediate dense bias
        new_state_dict[f"fnet.encoder.layer.{layer}.intermediate.dense.bias"] = torch.tensor(
            pretrained_model_params["encoder"][f"feed_forward_{layer}"]["intermediate"]["bias"]
        )

        # Load the current layer's output dense weight, transposed
        new_state_dict[f"fnet.encoder.layer.{layer}.output.dense.weight"] = torch.tensor(
            pretrained_model_params["encoder"][f"feed_forward_{layer}"]["output"]["kernel"]
        ).T
        # Load the current layer's output dense bias
        new_state_dict[f"fnet.encoder.layer.{layer}.output.dense.bias"] = torch.tensor(
            pretrained_model_params["encoder"][f"feed_forward_{layer}"]["output"]["bias"]
        )

        # Load the current layer's output LayerNorm weight
        new_state_dict[f"fnet.encoder.layer.{layer}.output.LayerNorm.weight"] = torch.tensor(
            pretrained_model_params["encoder"][f"encoder_{layer}"]["output_layer_norm"]["scale"]
        )
        # Load the current layer's output LayerNorm bias
        new_state_dict[f"fnet.encoder.layer.{layer}.output.LayerNorm.bias"] = torch.tensor(
            pretrained_model_params["encoder"][f"encoder_{layer}"]["output_layer_norm"]["bias"]
        )

    # 加载池化层的 dense 权重,并转置
    new_state_dict["fnet.pooler.dense.weight"] = torch.tensor(pretrained_model_params["encoder"]["pooler"]["kernel"]).T
    # 加载池化层的 dense 偏置
    new_state_dict["fnet.pooler.dense.bias"] = torch.tensor(pretrained_model_params["encoder"]["pooler"]["bias"])

    # 加载预测层的 transform dense 权重,并转置
    new_state_dict["cls.predictions.transform.dense.weight"] = torch.tensor(
        pretrained_model_params["predictions_dense"]["kernel"]
    ).T
    # 加载预测层的 transform dense 偏置
    new_state_dict["cls.predictions.transform.dense.bias"] = torch.tensor(
        pretrained_model_params["predictions_dense"]["bias"]
    )
    # 加载预测层的 transform LayerNorm 权重
    new_state_dict["cls.predictions.transform.LayerNorm.weight"] = torch.tensor(
        pretrained_model_params["predictions_layer_norm"]["scale"]
    )
    # 加载预测层的 transform LayerNorm 偏置
    new_state_dict["cls.predictions.transform.LayerNorm.bias"] = torch.tensor(
        pretrained_model_params["predictions_layer_norm"]["bias"]
    )
    
    # 加载预测层的 decoder 权重
    new_state_dict["cls.predictions.decoder.weight"] = torch.tensor(
        pretrained_model_params["encoder"]["embedder"]["word"]["embedding"]
    )
    # 加载预测层的 decoder 偏置
    new_state_dict["cls.predictions.decoder.bias"] = torch.tensor(
        pretrained_model_params["predictions_output"]["output_bias"]
    )
    # 加载预测层的 bias
    new_state_dict["cls.predictions.bias"] = torch.tensor(pretrained_model_params["predictions_output"]["output_bias"])

    # Seq Relationship Layers
    # 使用预训练模型参数中的输出核和偏置,创建新的张量并赋给新状态字典的键"cls.seq_relationship.weight"
    new_state_dict["cls.seq_relationship.weight"] = torch.tensor(
        pretrained_model_params["classification"]["output_kernel"]
    )
    # 使用预训练模型参数中的输出偏置,创建新的张量并赋给新状态字典的键"cls.seq_relationship.bias"
    new_state_dict["cls.seq_relationship.bias"] = torch.tensor(
        pretrained_model_params["classification"]["output_bias"]
    )

    # 加载新状态字典到预训练模型中
    fnet_pretraining_model.load_state_dict(new_state_dict)

    # 打印信息,指示正在将预训练模型保存到指定路径
    print(f"Saving pretrained model to {save_path}")

    # 将预训练模型保存到指定路径
    fnet_pretraining_model.save_pretrained(save_path)
if __name__ == "__main__":
    # 如果当前脚本作为主程序运行

    parser = argparse.ArgumentParser()
    # 创建命令行参数解析器对象

    # Required parameters
    parser.add_argument(
        "--flax_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    # 添加必需的命令行参数 --flax_checkpoint_path,指定 TensorFlow 检查点路径

    parser.add_argument(
        "--fnet_config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained FNet model. \n"
            "This specifies the model architecture."
        ),
    )
    # 添加必需的命令行参数 --fnet_config_file,指定预训练 FNet 模型的配置 JSON 文件路径
    # 该文件用于指定模型的架构

    parser.add_argument("--save_path", default=None, type=str, required=True, help="Path to the output model.")
    # 添加必需的命令行参数 --save_path,指定输出模型的路径

    args = parser.parse_args()
    # 解析命令行参数并将其存储在 args 变量中

    convert_flax_checkpoint_to_pytorch(args.flax_checkpoint_path, args.fnet_config_file, args.save_path)
    # 调用 convert_flax_checkpoint_to_pytorch 函数,传递命令行参数中指定的路径信息
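
For reference, the same conversion can also be driven directly from Python; the three arguments mirror the CLI flags defined above, and the paths here are placeholders that must point to real files:

```
# Placeholder paths -- replace with an actual Flax checkpoint, FNet config JSON and output directory.
convert_flax_checkpoint_to_pytorch(
    flax_checkpoint_path="/path/to/flax_checkpoint",
    fnet_config_file="/path/to/fnet_config.json",
    save_path="/path/to/pytorch_dump",
)
```
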

.\models\fnet\modeling_fnet.py

# 设置源代码文件的编码格式为UTF-8
# 版权声明,2021年由Google Research和HuggingFace Inc.团队保留所有权利
# 根据Apache许可证2.0版(“许可证”)许可,除非符合许可证的使用,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:http://www.apache.org/licenses/LICENSE-2.0
# 除非适用法律要求或书面同意,否则根据“许可证”分发的软件是基于“原样”提供的,不提供任何形式的明示或暗示担保或条件。
# 有关特定语言的权限,请参阅许可证。

""" PyTorch FNet model."""

# 导入警告模块
import warnings
# 导入dataclass用于数据类
from dataclasses import dataclass
# 导入partial函数用于创建偏函数
from functools import partial
# 导入类型提示相关模块
from typing import Optional, Tuple, Union

# 导入PyTorch相关模块
import torch
# 导入PyTorch中的checkpoint功能
import torch.utils.checkpoint
# 导入PyTorch中的神经网络模块
from torch import nn
# 导入PyTorch中的损失函数
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# 判断是否存在SciPy库,用于后续可能的特定操作
from ...utils import is_scipy_available

# 如果SciPy库可用,则导入linalg模块
if is_scipy_available():
    from scipy import linalg

# 导入激活函数映射表
from ...activations import ACT2FN
# 导入模型输出相关类
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    ModelOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
# 导入模型工具函数和预训练模型基类
from ...modeling_utils import PreTrainedModel
# 导入PyTorch工具函数,用于前向计算时的分块应用
from ...pytorch_utils import apply_chunking_to_forward
# 导入通用工具函数,包括日志记录等
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
# 导入FNet模型的配置类
from .configuration_fnet import FNetConfig

# 获取日志记录器
logger = logging.get_logger(__name__)

# 用于文档的检查点和配置信息
_CHECKPOINT_FOR_DOC = "google/fnet-base"
_CONFIG_FOR_DOC = "FNetConfig"

# 预训练模型的存档列表
FNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/fnet-base",
    "google/fnet-large",
    # 更多FNet模型详见https://huggingface.co/models?filter=fnet
]

# 从https://github.com/google-research/google-research/blob/master/f_net/fourier.py适配而来
def _two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
    """Applies 2D matrix multiplication to 3D input arrays."""
    # 获取序列长度
    seq_length = x.shape[1]
    # 裁剪矩阵以匹配序列长度
    matrix_dim_one = matrix_dim_one[:seq_length, :seq_length]
    # 将输入张量转换为复数类型
    x = x.type(torch.complex64)
    # 执行张量乘法操作
    return torch.einsum("bij,jk,ni->bnk", x, matrix_dim_two, matrix_dim_one)


# 从https://github.com/google-research/google-research/blob/master/f_net/fourier.py适配而来
def two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
    """Applies 2D matrix multiplication to 3D input arrays."""
    # 调用内部函数_two_dim_matmul执行操作
    return _two_dim_matmul(x, matrix_dim_one, matrix_dim_two)
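
The einsum string `"bij,jk,ni->bnk"` is a batched sandwich product: each example is left-multiplied by the sequence-length DFT matrix and right-multiplied by the hidden-size DFT matrix. A quick check with random complex matrices (sizes are arbitrary):

```
import torch

batch, seq, hidden = 2, 5, 3
x = torch.randn(batch, seq, hidden, dtype=torch.complex64)
m_seq = torch.randn(seq, seq, dtype=torch.complex64)           # plays the role of matrix_dim_one
m_hidden = torch.randn(hidden, hidden, dtype=torch.complex64)  # plays the role of matrix_dim_two

einsum_out = torch.einsum("bij,jk,ni->bnk", x, m_hidden, m_seq)
matmul_out = m_seq @ x @ m_hidden  # (seq, seq) @ (batch, seq, hidden) @ (hidden, hidden), broadcast over batch
print(torch.allclose(einsum_out, matmul_out, atol=1e-4))  # True
```
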


# 从https://github.com/google-research/google-research/blob/master/f_net/fourier.py适配而来
def fftn(x):
    """
    Applies n-dimensional Fast Fourier Transform (FFT) to input array.

    Args:
        x: Input n-dimensional array.

    Returns:
        n-dimensional Fourier transform of input n-dimensional array.
    """
    # Start from the input and apply an FFT along each non-batch axis in the loop below
    out = x
    # Iterate in reverse over every axis of x except the batch axis (axis 0)
    for axis in reversed(range(x.ndim)[1:]):  # We don't need to apply FFT to last axis
        # 对张量 out 在指定的轴上应用 FFT 变换
        out = torch.fft.fft(out, axis=axis)
    # 返回应用完 FFT 后的张量 out
    return out
class FNetEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        # 初始化词嵌入层,用于将词汇索引映射为隐藏状态向量,支持填充索引
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # 初始化位置嵌入层,用于将位置索引映射为隐藏状态向量
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # 初始化标记类型嵌入层,用于将标记类型索引映射为隐藏状态向量
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        # 初始化层归一化层,用于归一化隐藏状态向量,保持与 TensorFlow 模型变量名的一致性以便加载 TensorFlow 检查点文件
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # NOTE: This is the project layer and will be needed. The original code allows for different embedding and different model dimensions.
        # 初始化投影层,用于将隐藏状态向量投影到另一个隐藏状态空间
        self.projection = nn.Linear(config.hidden_size, config.hidden_size)
        # 初始化丢弃层,用于在训练过程中随机丢弃部分隐藏状态向量,防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        # 注册缓冲区 position_ids,用于存储位置索引,作为持久化数据不会随模型参数保存
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

        # 注册缓冲区 token_type_ids,用于存储标记类型索引,初始化为全零
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )
    # 定义模型的前向传播函数,接收输入的标识符、标记类型ID、位置ID和嵌入输入
    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        # 如果传入了input_ids,则获取其形状
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            # 否则,获取inputs_embeds的形状,排除最后一个维度
            input_shape = inputs_embeds.size()[:-1]

        # 获取序列的长度
        seq_length = input_shape[1]

        # 如果未提供position_ids,则使用模型中注册的缓冲区,截取到seq_length长度的部分
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        # 设置token_type_ids为模型构造函数中注册的缓冲区,通常是全零,用于在不传递token_type_ids时帮助用户追踪模型,解决问题#5664
        if token_type_ids is None:
            # 如果模型具有token_type_ids属性,则使用它的缓冲区值
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                # 否则创建一个全零的tensor作为token_type_ids,类型为长整型,位于与self.position_ids设备相同的设备上
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        # 如果未提供inputs_embeds,则使用word_embeddings对input_ids进行嵌入
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        
        # 根据token_type_ids获取token type embeddings
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # 计算最终的嵌入向量,包括输入的嵌入和token type embeddings
        embeddings = inputs_embeds + token_type_embeddings

        # 根据position_ids获取位置嵌入
        position_embeddings = self.position_embeddings(position_ids)
        
        # 将位置嵌入加到当前的嵌入向量中
        embeddings += position_embeddings
        
        # 对嵌入向量进行LayerNormalization处理
        embeddings = self.LayerNorm(embeddings)
        
        # 将嵌入向量投影到最终的输出空间
        embeddings = self.projection(embeddings)
        
        # 对投影后的向量进行dropout操作,用于防止过拟合
        embeddings = self.dropout(embeddings)
        
        # 返回最终的嵌入向量作为前向传播的结果
        return embeddings
# 定义 FNetBasicFourierTransform 类,继承自 nn.Module
class FNetBasicFourierTransform(nn.Module):
    # 初始化方法
    def __init__(self, config):
        super().__init__()
        # 调用 _init_fourier_transform 方法进行初始化
        self._init_fourier_transform(config)

    # 初始化傅里叶变换方法
    def _init_fourier_transform(self, config):
        # 如果配置指示不使用 TPU 傅里叶优化
        if not config.use_tpu_fourier_optimizations:
            # 使用 torch.fft.fftn 作为傅里叶变换的部分函数,指定变换维度为 (1, 2)
            self.fourier_transform = partial(torch.fft.fftn, dim=(1, 2))
        # 如果配置中最大位置嵌入小于等于 4096
        elif config.max_position_embeddings <= 4096:
            # 检查是否有 SciPy 库可用
            if is_scipy_available():
                # 注册隐藏大小的 DFT(离散傅里叶变换)矩阵为缓冲区
                self.register_buffer(
                    "dft_mat_hidden", torch.tensor(linalg.dft(config.hidden_size), dtype=torch.complex64)
                )
                # 注册序列长度的 DFT 矩阵为缓冲区
                self.register_buffer(
                    "dft_mat_seq", torch.tensor(linalg.dft(config.tpu_short_seq_length), dtype=torch.complex64)
                )
                # 使用自定义的两个维度矩阵乘法作为傅里叶变换的部分函数
                self.fourier_transform = partial(
                    two_dim_matmul, matrix_dim_one=self.dft_mat_seq, matrix_dim_two=self.dft_mat_hidden
                )
            else:
                # 如果没有找到 SciPy 库,则记录警告并使用 fftn 作为傅里叶变换
                logging.warning(
                    "SciPy is needed for DFT matrix calculation and is not found. Using TPU optimized fast fourier"
                    " transform instead."
                )
                self.fourier_transform = fftn
        else:
            # 如果不满足上述条件,则使用 fftn 作为傅里叶变换
            self.fourier_transform = fftn

    # 前向传播方法
    def forward(self, hidden_states):
        # 输出通过傅里叶变换后的实部
        outputs = self.fourier_transform(hidden_states).real
        return (outputs,)
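
On the default (non-TPU) path the mixing step is literally `torch.fft.fftn` over the sequence and hidden dimensions with only the real part kept; there are no learnable parameters in this sub-layer, which is exactly what lets FNet drop self-attention. A minimal sketch of what `forward` computes in that default configuration:

```
import torch

hidden_states = torch.randn(2, 6, 4)                     # (batch, seq_len, hidden_size)
mixed = torch.fft.fftn(hidden_states, dim=(1, 2)).real   # same as the default self.fourier_transform plus .real
print(mixed.shape)                                        # torch.Size([2, 6, 4]) -- shape is preserved
```
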


# 定义 FNetBasicOutput 类,继承自 nn.Module
class FNetBasicOutput(nn.Module):
    # 初始化方法
    def __init__(self, config):
        super().__init__()
        # 初始化 LayerNorm 层
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    # 前向传播方法
    def forward(self, hidden_states, input_tensor):
        # 对输入张量和隐藏状态进行 LayerNorm 处理
        hidden_states = self.LayerNorm(input_tensor + hidden_states)
        return hidden_states


# 定义 FNetFourierTransform 类,继承自 nn.Module
class FNetFourierTransform(nn.Module):
    # 初始化方法
    def __init__(self, config):
        super().__init__()
        # 创建 FNetBasicFourierTransform 实例作为 self 属性
        self.self = FNetBasicFourierTransform(config)
        # 创建 FNetBasicOutput 实例作为 output 属性
        self.output = FNetBasicOutput(config)

    # 前向传播方法
    def forward(self, hidden_states):
        # 调用 self 实例的前向传播方法
        self_outputs = self.self(hidden_states)
        # 将 self 的输出与隐藏状态作为输入,调用 output 实例的前向传播方法
        fourier_output = self.output(self_outputs[0], hidden_states)
        # 返回输出元组
        outputs = (fourier_output,)
        return outputs


# 从 transformers.models.bert.modeling_bert.BertIntermediate 复制并修改为 FNetIntermediate 类
class FNetIntermediate(nn.Module):
    # 初始化方法
    def __init__(self, config):
        super().__init__()
        # 使用线性层将隐藏大小转换为中间大小
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # 如果隐藏激活函数是字符串,则从 ACT2FN 字典获取对应的激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            # 否则使用配置中的激活函数
            self.intermediate_act_fn = config.hidden_act
    # 定义一个方法 `forward`,接收一个名为 `hidden_states` 的张量参数,并返回一个张量
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 将输入张量通过全连接层 `self.dense` 进行线性变换
        hidden_states = self.dense(hidden_states)
        # 对线性变换后的张量应用激活函数 `self.intermediate_act_fn`
        hidden_states = self.intermediate_act_fn(hidden_states)
        # 返回经过线性变换和激活函数处理后的张量结果
        return hidden_states
# 从 transformers.models.bert.modeling_bert.BertOutput 复制代码,并将 Bert 替换为 FNet
class FNetOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 创建一个全连接层,将中间大小的特征向量映射到隐藏大小
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # LayerNorm 层,对隐藏状态进行归一化
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Dropout 层,用于随机丢弃隐藏状态中的一部分特征,以防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 全连接层计算
        hidden_states = self.dense(hidden_states)
        # 应用 Dropout
        hidden_states = self.dropout(hidden_states)
        # LayerNorm 和原始输入的残差连接
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class FNetLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 用于分块前馈的块大小
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # 序列长度所在的维度
        self.seq_len_dim = 1  # The dimension which has the sequence length
        # Fourier 变换层
        self.fourier = FNetFourierTransform(config)
        # 中间层
        self.intermediate = FNetIntermediate(config)
        # 输出层
        self.output = FNetOutput(config)

    def forward(self, hidden_states):
        # Fourier 变换的输出
        self_fourier_outputs = self.fourier(hidden_states)
        fourier_output = self_fourier_outputs[0]

        # 将前馈应用到每个块
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, fourier_output
        )

        outputs = (layer_output,)

        return outputs

    def feed_forward_chunk(self, fourier_output):
        # 中间层的输出
        intermediate_output = self.intermediate(fourier_output)
        # 输出层的输出
        layer_output = self.output(intermediate_output, fourier_output)
        return layer_output
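
`apply_chunking_to_forward` trades memory for compute: with a positive chunk size it splits the inputs along the given dimension, applies the feed-forward to each chunk, and concatenates the results; with chunk size 0 it simply calls the function once. Because the feed-forward acts on every sequence position independently, both paths produce the same output, as the small check below illustrates (toy sizes, a plain linear layer standing in for the FNet feed-forward):

```
import torch
from transformers.pytorch_utils import apply_chunking_to_forward

linear = torch.nn.Linear(4, 4)

def feed_forward(x):
    return linear(x)

x = torch.randn(2, 8, 4)                                    # (batch, seq_len, hidden)
full = apply_chunking_to_forward(feed_forward, 0, 1, x)     # chunk size 0: single call
chunked = apply_chunking_to_forward(feed_forward, 2, 1, x)  # 4 chunks of length 2 along dim 1
print(torch.allclose(full, chunked, atol=1e-6))             # True
```
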


class FNetEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # 多层 FNetLayer 堆叠
        self.layer = nn.ModuleList([FNetLayer(config) for _ in range(config.num_hidden_layers)])
        # 是否启用梯度检查点
        self.gradient_checkpointing = False

    def forward(self, hidden_states, output_hidden_states=False, return_dict=True):
        # 是否输出所有隐藏状态
        all_hidden_states = () if output_hidden_states else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # 如果启用梯度检查点,使用梯度检查点函数进行前向传播
                layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states)
            else:
                # 否则直接调用层的前向传播
                layer_outputs = layer_module(hidden_states)

            hidden_states = layer_outputs[0]

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        # 返回模型输出,包括最后一个隐藏状态和所有隐藏状态(如果输出所有隐藏状态)
        return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
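
补充说明:上面的梯度检查点(gradient checkpointing)是以计算换显存:前向时不保存中间激活,反向传播时重新计算一次前向。下面是一个脱离本模型的极简示意(假设性示例,假设使用较新版本的 PyTorch,直接调用 `torch.utils.checkpoint.checkpoint`):

```
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# 前向时不缓存 layer 内部的中间激活,反向传播时重新计算一次前向
y = checkpoint(layer, x, use_reentrant=False)
loss = y.sum()
loss.backward()
print(x.grad.shape)  # torch.Size([4, 16])
```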
# 从 transformers.models.bert.modeling_bert.BertPooler 复制代码,并将 Bert 替换为 FNet
class FNetPooler(nn.Module):
    # 初始化方法,接受一个config对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 创建一个全连接层,输入和输出大小为config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 激活函数选择为双曲正切函数
        self.activation = nn.Tanh()

    # 前向传播方法,接受一个形状为[batch_size, sequence_length, hidden_size]的张量作为输入,
    # 返回一个形状为[batch_size, hidden_size]的张量作为输出
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 从隐藏状态张量中取出每个样本的第一个token的隐藏状态
        first_token_tensor = hidden_states[:, 0]
        # 将第一个token的隐藏状态传入全连接层进行线性变换
        pooled_output = self.dense(first_token_tensor)
        # 对线性变换后的结果应用激活函数
        pooled_output = self.activation(pooled_output)
        # 返回激活后的结果作为最终的输出张量
        return pooled_output
# 从transformers.models.bert.modeling_bert.BertPredictionHeadTransform复制代码,将Bert->FNet
class FNetPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化一个全连接层,输入和输出大小都是config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 如果config.hidden_act是字符串,则使用ACT2FN字典中对应的激活函数;否则直接使用config.hidden_act作为激活函数
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        # 初始化LayerNorm层,输入大小为config.hidden_size,设置eps为config.layer_norm_eps
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 通过全连接层进行线性变换
        hidden_states = self.dense(hidden_states)
        # 应用激活函数变换
        hidden_states = self.transform_act_fn(hidden_states)
        # 应用LayerNorm进行归一化
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


# 复制自transformers.models.bert.modeling_bert.BertLMPredictionHead,将Bert->FNet
class FNetLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化预测头部的变换层
        self.transform = FNetPredictionHeadTransform(config)

        # 输出权重与输入嵌入相同,但每个token有一个仅输出的偏置
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)

        # 初始化偏置参数
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # 将decoder的偏置设置为初始化的偏置参数
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        # 输入经过变换层
        hidden_states = self.transform(hidden_states)
        # 经过线性层得到预测分数
        hidden_states = self.decoder(hidden_states)
        return hidden_states

    def _tie_weights(self):
        # 如果权重断开连接(在TPU上或者调整偏置大小时),重新绑定偏置
        self.bias = self.decoder.bias
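
补充说明:这里的权重共享(weight tying)指 LM 头的 `decoder.weight` 与输入词嵌入共用同一份参数。下面是一个与具体模型无关的极简示意(假设性示例,非 FNet 源码),展示共享之后两者指向同一块存储:

```
import torch.nn as nn

vocab_size, hidden_size = 100, 16
word_embeddings = nn.Embedding(vocab_size, hidden_size)
decoder = nn.Linear(hidden_size, vocab_size, bias=True)

# 共享权重:decoder 的权重矩阵与词嵌入矩阵指向同一个 Parameter
decoder.weight = word_embeddings.weight

print(decoder.weight.data_ptr() == word_embeddings.weight.data_ptr())  # True
```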


# 复制自transformers.models.bert.modeling_bert.BertOnlyMLMHead,将Bert->FNet
class FNetOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化MLM头部的预测
        self.predictions = FNetLMPredictionHead(config)

    def forward(self, sequence_output):
        # 使用预测层进行预测
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


# 复制自transformers.models.bert.modeling_bert.BertOnlyNSPHead,将Bert->FNet
class FNetOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化NSP头部的序列关系预测层
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        # 使用线性层计算序列关系分数
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


# 复制自transformers.models.bert.modeling_bert.BertPreTrainingHeads,将Bert->FNet
class FNetPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化预测头部和序列关系头部
        self.predictions = FNetLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        # 分别计算预测分数和序列关系分数
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


# 继承自PreTrainedModel的FNet预训练模型
class FNetPreTrainedModel(PreTrainedModel):
    """
    FNet预训练模型基类
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FNetConfig
    base_model_prefix = "fnet"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        # 如果 module 是 nn.Linear 类型
        if isinstance(module, nn.Linear):
            # 使用正态分布初始化权重,均值为 0.0,标准差为配置中的初始化范围
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 注意:原始代码中偏置的初始化和权重相同
            if module.bias is not None:
                # 将偏置数据初始化为零
                module.bias.data.zero_()
        # 如果 module 是 nn.Embedding 类型
        elif isinstance(module, nn.Embedding):
            # 使用正态分布初始化权重,均值为 0.0,标准差为配置中的初始化范围
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # 如果有 padding_idx,则将对应位置的权重初始化为零
                module.weight.data[module.padding_idx].zero_()
        # 如果 module 是 nn.LayerNorm 类型
        elif isinstance(module, nn.LayerNorm):
            # 将偏置数据初始化为零
            module.bias.data.zero_()
            # 将权重数据初始化为 1.0
            module.weight.data.fill_(1.0)
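
补充说明:`_init_weights` 通常由基类在 `post_init()` 阶段通过 `nn.Module.apply` 递归应用到每个子模块。下面是一个脱离 transformers 的简化示意(假设性示例),演示这种按模块类型分派初始化的写法:

```
import torch.nn as nn

def init_weights(module, initializer_range=0.02):
    # 按模块类型分派:Linear/Embedding 用正态分布,LayerNorm 置为 (weight=1, bias=0)
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

model = nn.Sequential(nn.Embedding(10, 8), nn.LayerNorm(8), nn.Linear(8, 2))
model.apply(init_weights)  # apply 会对每个子模块递归调用 init_weights
```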
    """
    FNetForPreTrainingOutput 类定义了预训练模型的输出类型,继承自 ModelOutput。

    Args:
        loss (torch.FloatTensor, 可选): 当提供 `labels` 时返回,表示总损失,包括掩码语言建模损失和下一个序列预测(分类)损失,形状为 `(1,)`。
        prediction_logits (torch.FloatTensor): 语言建模头部的预测分数,即 SoftMax 之前的每个词汇标记的分数,形状为 `(batch_size, sequence_length, config.vocab_size)`。
        seq_relationship_logits (torch.FloatTensor): 下一个序列预测(分类)头部的预测分数,即 SoftMax 之前的 True/False 连续性的分数,形状为 `(batch_size, 2)`。
        hidden_states (tuple(torch.FloatTensor), 可选): 当 `output_hidden_states=True` 被传递或 `config.output_hidden_states=True` 时返回,包含模型每一层的隐藏状态,形状为 `(batch_size, sequence_length, hidden_size)`。

    """

FNET_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`FNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

FNET_INPUTS_DOCSTRING = r"""
    输入参数说明:
    # 接受输入的索引序列,表示输入序列中的单词在词汇表中的索引
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        # 表示段落标记索引,用于指示输入的第一部分和第二部分。索引值为 0 或 1:
        # - 0 对应于*句子 A* 的标记,
        # - 1 对应于*句子 B* 的标记。
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        # 表示输入序列中每个单词的位置索引,用于位置嵌入。索引值选择在范围 `[0, config.max_position_embeddings - 1]` 内。
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)

        # 选项参数,可以选择直接传递嵌入表示而不是传递 `input_ids`。如果您想要比模型内部嵌入查找矩阵更多控制权,则这是有用的。
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        # 是否返回所有层的隐藏状态。有关更多细节,请查看返回张量中的 `hidden_states`。
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        # 是否返回 [`~utils.ModelOutput`] 而不是普通元组。
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare FNet Model transformer outputting raw hidden-states without any specific head on top.",
    FNET_START_DOCSTRING,
)
class FNetModel(FNetPreTrainedModel):
    """

    The model can behave as an encoder, following the architecture described in [FNet: Mixing Tokens with Fourier
    Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon.

    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        # 初始化 FNetEmbeddings 对象,用于处理模型的嵌入层
        self.embeddings = FNetEmbeddings(config)
        
        # 初始化 FNetEncoder 对象,用于处理模型的编码器层
        self.encoder = FNetEncoder(config)

        # 如果需要添加池化层,则初始化 FNetPooler 对象,否则为 None
        self.pooler = FNetPooler(config) if add_pooling_layer else None

        # 初始化权重并进行最终处理
        self.post_init()

    def get_input_embeddings(self):
        # 返回嵌入层的单词嵌入
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # 设置嵌入层的单词嵌入
        self.embeddings.word_embeddings = value

    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 如果同时指定了 input_ids 和 inputs_embeds,则抛出数值错误
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        # 如果指定了 input_ids
        elif input_ids is not None:
            # 获取 input_ids 的形状
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        # 如果指定了 inputs_embeds
        elif inputs_embeds is not None:
            # 获取 inputs_embeds 的形状,排除最后一维
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
        else:
            # 如果既没有指定 input_ids 也没有指定 inputs_embeds,则抛出数值错误
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # 如果配置中启用了 TPU Fourier 优化,并且序列长度小于等于 4096,并且配置的 tpu_short_seq_length 不等于当前序列长度
        if (
            self.config.use_tpu_fourier_optimizations
            and seq_length <= 4096
            and self.config.tpu_short_seq_length != seq_length
        ):
            # 抛出数值错误,提示需要设置正确的 tpu_short_seq_length
            raise ValueError(
                "The `tpu_short_seq_length` in FNetConfig should be set equal to the sequence length being passed to"
                " the model when using TPU optimizations."
            )

        # 根据是否存在 input_ids 来确定设备
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # 如果 token_type_ids 未指定
        if token_type_ids is None:
            # 如果 embeddings 拥有 token_type_ids 属性
            if hasattr(self.embeddings, "token_type_ids"):
                # 从 embeddings 中获取 token_type_ids,并截取到序列长度的部分,然后扩展到整个 batch
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                # 如果 embeddings 没有 token_type_ids 属性,则创建一个全零的 tensor
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # 使用 embeddings 进行前向传播
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # 使用 encoder 进行编码器的前向传播
        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取编码器输出的序列输出
        sequence_output = encoder_outputs[0]

        # 如果存在 pooler,则使用 pooler 对序列输出进行池化
        pooler_output = self.pooler(sequence_output) if self.pooler is not None else None

        # 如果不要求返回字典形式的输出,则返回元组形式的结果
        if not return_dict:
            return (sequence_output, pooler_output) + encoder_outputs[1:]

        # 否则,返回一个带池化的 BaseModelOutputWithPooling 对象
        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooler_output,
            hidden_states=encoder_outputs.hidden_states,
        )
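
补充一个 `FNetModel` 的调用示意(示例用法,检查点名称 `google/fnet-base` 取自本文后面的示例,调用方式与之相同):

```
from transformers import AutoTokenizer, FNetModel
import torch

tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
model = FNetModel.from_pretrained("google/fnet-base")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# last_hidden_state 形状为 (batch_size, sequence_length, hidden_size)
print(outputs.last_hidden_state.shape)
# pooler_output 形状为 (batch_size, hidden_size)
print(outputs.pooler_output.shape)
```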
"""
FNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
sentence prediction (classification)` head.
"""
# 导入所需的函数和类,用于添加文档字符串
@add_start_docstrings(
    """
    FNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    FNET_START_DOCSTRING,
)
# 定义 FNetForPreTraining 类,继承自 FNetPreTrainedModel
class FNetForPreTraining(FNetPreTrainedModel):
    # 定义用于权重共享的关键键列表
    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

    # 初始化函数
    def __init__(self, config):
        # 调用父类的初始化函数
        super().__init__(config)

        # 初始化 FNetModel 模型
        self.fnet = FNetModel(config)
        # 初始化 FNetPreTrainingHeads 模型
        self.cls = FNetPreTrainingHeads(config)

        # 调用后处理函数,初始化权重并应用最终处理
        self.post_init()

    # 获取输出嵌入层的函数
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    # 设置输出嵌入层的函数
    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    # 前向传播函数
    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=FNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        next_sentence_label: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FNetForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring) Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.
        kwargs (`Dict[str, any]`, optional, defaults to *{}*):
            Used to hide legacy arguments that have been deprecated.

        Returns:
            Returns an instance of `FNetForPreTrainingOutput` containing various outputs from the model.

        Example:

        ```
        >>> from transformers import AutoTokenizer, FNetForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
        >>> model = FNetForPreTraining.from_pretrained("google/fnet-base")
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```"""
        # Determine whether to use the return_dict mode based on the input argument or default configuration
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward pass through the FNet model with specified inputs and configurations
        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Extract the relevant outputs from the FNet model's output
        sequence_output, pooled_output = outputs[:2]

        # Perform classification on the extracted outputs
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        # Initialize total_loss to None
        total_loss = None

        # Compute total loss if both labels and next_sentence_label are provided
        if labels is not None and next_sentence_label is not None:
            # Define the CrossEntropyLoss criterion
            loss_fct = CrossEntropyLoss()
            
            # Calculate the masked language modeling (MLM) loss
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            
            # Calculate the next sentence prediction (NSP) loss
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            
            # Aggregate total loss from MLM and NSP losses
            total_loss = masked_lm_loss + next_sentence_loss

        # Return the appropriate outputs based on the return_dict flag
        if not return_dict:
            # If return_dict is False, return a tuple of outputs
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output
        else:
            # If return_dict is True, return FNetForPreTrainingOutput object
            return FNetForPreTrainingOutput(
                loss=total_loss,
                prediction_logits=prediction_scores,
                seq_relationship_logits=seq_relationship_score,
                hidden_states=outputs.hidden_states,
            )
# 使用装饰器为类添加文档字符串,描述了其带有语言建模头部的 FNet 模型
@add_start_docstrings("""FNet Model with a `language modeling` head on top.""", FNET_START_DOCSTRING)
class FNetForMaskedLM(FNetPreTrainedModel):
    # 指定了共享权重的键列表
    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        # 初始化 FNet 模型和仅包含 MLM 头部的组件
        self.fnet = FNetModel(config)
        self.cls = FNetOnlyMLMHead(config)

        # 执行后期初始化,包括权重初始化和最终处理
        self.post_init()

    def get_output_embeddings(self):
        # 返回 MLM 头部的预测解码器
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        # 设置 MLM 头部的预测解码器的新嵌入
        self.cls.predictions.decoder = new_embeddings

    # 为 forward 方法添加文档字符串,描述了其输入和输出以及一些示例代码
    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        # 如果没有指定 return_dict,则使用配置中的设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 使用 FNet 模型进行前向传播
        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 提取序列输出和预测分数
        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            # 计算掩蔽语言建模损失
            loss_fct = CrossEntropyLoss()  # -100 索引表示填充令牌
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # 如果不需要返回字典,则输出结果元组
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        # 如果需要返回字典,则返回 MaskedLMOutput 对象
        return MaskedLMOutput(loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states)
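
补充一个掩码填充的调用示意(假设性示例,沿用 `google/fnet-base` 检查点;mask 标记通过 `tokenizer.mask_token` 获取,避免硬编码具体符号):

```
from transformers import AutoTokenizer, FNetForMaskedLM
import torch

tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
model = FNetForMaskedLM.from_pretrained("google/fnet-base")

text = f"The capital of France is {tokenizer.mask_token}."
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# 找到 mask 的位置,取该位置得分最高的 token 作为预测
mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_id = logits[0, mask_index].argmax(dim=-1)
print(tokenizer.decode(predicted_id))
```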


# 使用装饰器为类添加文档字符串,描述了其带有下一句预测分类头部的 FNet 模型
@add_start_docstrings(
    """FNet Model with a `next sentence prediction (classification)` head on top.""",
    FNET_START_DOCSTRING,
)
class FNetForNextSentencePrediction(FNetPreTrainedModel):
    # 初始化函数,接受一个配置参数 config
    def __init__(self, config):
        # 调用父类的初始化方法,传入配置参数 config
        super().__init__(config)

        # 创建 FNetModel 对象,使用给定的配置参数 config
        self.fnet = FNetModel(config)
        # 创建 FNetOnlyNSPHead 对象,使用给定的配置参数 config
        self.cls = FNetOnlyNSPHead(config)

        # 调用本类中的 post_init 方法,用于初始化权重并应用最终处理
        self.post_init()

    # 前向传播函数,添加了一些文档字符串用于描述函数的输入和输出
    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
        ) -> Union[Tuple, NextSentencePredictorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Returns:

        Example:

        ```
        >>> from transformers import AutoTokenizer, FNetForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
        >>> model = FNetForNextSentencePrediction.from_pretrained("google/fnet-base")
        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```"""

        # 如果 kwargs 中包含 `next_sentence_label`,则发出警告并使用其值作为 labels
        if "next_sentence_label" in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("next_sentence_label")

        # 决定是否返回字典格式的输出,如果未指定则使用配置中的默认值
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 使用 FNet 模型进行下一句预测任务的计算
        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从 FNet 输出的第二个元素中提取池化后的输出,用于后续分类任务
        pooled_output = outputs[1]

        # 将池化后的输出传递给分类器,得到下一句关系的分数
        seq_relationship_scores = self.cls(pooled_output)

        # 初始化下一句预测的损失为 None
        next_sentence_loss = None
        # 如果 labels 不为 None,则计算损失
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

        # 根据 return_dict 决定返回的结果格式
        if not return_dict:
            output = (seq_relationship_scores,) + outputs[2:]
            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

        # 使用 NextSentencePredictorOutput 类封装并返回结果
        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
        )
@add_start_docstrings(
    """
    FNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    FNET_START_DOCSTRING,
)
class FNetForSequenceClassification(FNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # 初始化时从配置中获取标签数量
        self.num_labels = config.num_labels
        # 初始化一个FNet模型实例
        self.fnet = FNetModel(config)

        # Dropout层,使用配置中的隐藏层dropout概率
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 分类器,线性层,输入大小为配置中的隐藏层大小,输出大小为标签数量
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并进行最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # 初始化 return_dict 变量,如果 return_dict 不为 None,则使用给定的值,否则使用 self.config.use_return_dict 的值
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 self.fnet 进行前向传播,获取模型输出
        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从模型输出中获取 pooled_output(通常是 CLS token 的输出)
        pooled_output = outputs[1]
        # 对 pooled_output 应用 dropout 操作,用于防止过拟合
        pooled_output = self.dropout(pooled_output)
        # 使用分类器 self.classifier 对 pooled_output 进行分类预测,得到 logits
        logits = self.classifier(pooled_output)

        # 初始化损失为 None
        loss = None
        # 如果 labels 不为 None,则计算损失
        if labels is not None:
            # 如果 self.config.problem_type 未定义,则根据 num_labels 自动定义问题类型
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # 根据问题类型计算损失
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    # 对于单标签回归任务,计算均方误差损失
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    # 对于多标签回归任务,计算均方误差损失
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                # 对于单标签分类任务,使用交叉熵损失函数
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                # 对于多标签分类任务,使用带 logits 的二元交叉熵损失函数
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        # 如果 return_dict 为 False,则按顺序返回 logits 和额外的 hidden states
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 如果 return_dict 为 True,则返回 SequenceClassifierOutput 类型的对象,包括损失、logits 和 hidden states
        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
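
补充一个序列分类的调用示意(假设性示例:`num_labels=2` 仅作演示,新初始化的分类头在未微调前输出没有实际意义):

```
from transformers import AutoTokenizer, FNetForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
model = FNetForSequenceClassification.from_pretrained("google/fnet-base", num_labels=2)

inputs = tokenizer("This movie was great!", return_tensors="pt")
labels = torch.tensor([1])  # 假设 1 表示正向类别

outputs = model(**inputs, labels=labels)
print(outputs.loss)                   # 交叉熵损失
print(outputs.logits.argmax(dim=-1))  # 预测的类别索引
```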
@add_start_docstrings(
    """
    FNet Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    FNET_START_DOCSTRING,
)
class FNetForMultipleChoice(FNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # 初始化 FNet 模型
        self.fnet = FNetModel(config)
        # Dropout 层,用于随机失活以防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 分类器线性层,将 FNet 输出映射到单一数值(用于二元分类)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # 初始化权重并进行后续处理
        self.post_init()

    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # 如果没有指定 return_dict,则使用配置中的默认设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 获取输入中第二维的大小作为选择项的数量
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # 如果输入的input_ids不为None,则将其视图重新调整为二维张量
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None

        # 如果token_type_ids不为None,则将其视图重新调整为二维张量
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None

        # 如果position_ids不为None,则将其视图重新调整为二维张量
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None

        # 如果inputs_embeds不为None,则将其视图重新调整为三维张量
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        # 使用给定的输入调用模型的前向传播函数fnet,返回输出结果
        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从模型输出中获取池化后的输出,一般在位置1
        pooled_output = outputs[1]

        # 对池化后的输出应用dropout操作
        pooled_output = self.dropout(pooled_output)

        # 将dropout后的输出通过分类器得到logits(对数概率)
        logits = self.classifier(pooled_output)

        # 将logits重新调整为二维张量,以匹配选择项的数量
        reshaped_logits = logits.view(-1, num_choices)

        # 初始化损失值为None
        loss = None

        # 如果提供了标签labels,则计算交叉熵损失
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        # 如果不需要返回字典形式的输出,则按指定格式返回结果
        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]  # 将预测值和可能的其他输出组合成元组
            return ((loss,) + output) if loss is not None else output  # 如果有损失值则将其包含在返回结果中

        # 如果需要返回字典形式的输出,则构建MultipleChoiceModelOutput对象返回
        return MultipleChoiceModelOutput(loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states)
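
补充说明:多项选择任务的输入形状为 `(batch_size, num_choices, sequence_length)`,上面的 forward 会先展平成 `(batch_size * num_choices, sequence_length)` 送入编码器,再把每个候选的得分重排回 `(batch_size, num_choices)`。下面是一个构造输入的示意(假设性示例,未微调的模型输出仅作演示):

```
from transformers import AutoTokenizer, FNetForMultipleChoice
import torch

tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
model = FNetForMultipleChoice.from_pretrained("google/fnet-base")

prompt = "France has a bread culture."
choice0 = "It is eaten with a fork."
choice1 = "It is often eaten with butter."

# 对同一个 prompt 与每个候选分别编码,再增加一个 num_choices 维
encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}  # (1, 2, seq_len)

outputs = model(**inputs, labels=torch.tensor([1]))
print(outputs.logits.shape)  # torch.Size([1, 2])
```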
"""
FNet Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
"""
# 导入必要的库函数
@add_start_docstrings(
    """
    FNet Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    FNET_START_DOCSTRING,
)
# 定义 FNetForTokenClassification 类,继承自 FNetPreTrainedModel
class FNetForTokenClassification(FNetPreTrainedModel):
    
    # 初始化方法,接收一个配置对象 config
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)
        # 设置类别数量为配置中的 num_labels
        self.num_labels = config.num_labels

        # 初始化 FNetModel 对象,并保存在 self.fnet 中
        self.fnet = FNetModel(config)

        # 使用配置中的隐藏层 dropout 概率创建 Dropout 层
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 创建线性分类器层,将隐藏层输出映射到 num_labels 维度
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并进行最终处理
        self.post_init()

    # 定义 forward 方法,处理输入并返回结果
    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # forward 方法的详细文档字符串
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        # 确定是否使用返回字典
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 self.fnet 对象的 forward 方法,处理输入数据
        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从输出中获取序列输出
        sequence_output = outputs[0]

        # 应用 dropout 层
        sequence_output = self.dropout(sequence_output)
        # 将序列输出传入分类器层,得到 logits
        logits = self.classifier(sequence_output)

        # 初始化损失为 None
        loss = None
        # 如果 labels 不为 None,则计算分类损失
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # 使用交叉熵计算标记分类损失(索引为 -100 的标签默认被忽略)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # 如果不使用返回字典,则返回元组形式的输出
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 如果使用返回字典,则创建 TokenClassifierOutput 对象并返回
        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
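
补充一个标记分类(如 NER)的调用示意(假设性示例,`num_labels=5` 及标签含义仅作演示):

```
from transformers import AutoTokenizer, FNetForTokenClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
model = FNetForTokenClassification.from_pretrained("google/fnet-base", num_labels=5)

inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (1, seq_len, 5)

# 每个 token 取得分最高的标签索引
predicted_ids = logits.argmax(dim=-1)
print(predicted_ids.shape)  # torch.Size([1, seq_len])
```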
# 带有抽取式问答(如 SQuAD)跨度分类头的 FNet 模型:在隐藏状态输出之上加一个线性层,
# 用于计算 `span start logits` 和 `span end logits`
@add_start_docstrings(
    """
    FNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    FNET_START_DOCSTRING,
)
class FNetForQuestionAnswering(FNetPreTrainedModel):
    # 初始化函数,接受一个配置对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法,传入配置对象
        super().__init__(config)

        # 从配置对象中获取标签数目并存储在实例变量中
        self.num_labels = config.num_labels

        # 创建一个FNetModel的实例并存储在实例变量self.fnet中
        self.fnet = FNetModel(config)

        # 创建一个线性层用于输出,输入尺寸为配置对象中的隐藏层大小,输出尺寸为标签数目
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # 调用自定义的初始化方法,用于初始化权重并进行最终处理
        # 在此方法中可能包含权重初始化和其他必要的处理步骤
        self.post_init()

    # 前向传播函数,用于模型的前向计算
    @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        # Determine if we should use the return_dict from the config or from the function argument
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward pass through the transformer network with given inputs
        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Extract the final sequence output from the transformer network
        sequence_output = outputs[0]

        # Generate logits from the sequence output using the question answering head
        logits = self.qa_outputs(sequence_output)
        
        # Split logits into start and end logits
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        # Calculate total loss if start_positions and end_positions are provided
        if start_positions is not None and end_positions is not None:
            # If the batch size is greater than 1, squeeze extra dimensions
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Clamp positions to be within the valid range of sequence length
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # Define loss function and compute start and end position losses
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        # If return_dict is False, prepare outputs as tuple
        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        # If return_dict is True, return QuestionAnsweringModelOutput
        return QuestionAnsweringModelOutput(
            loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states
        )
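
补充一个抽取式问答的调用示意(假设性示例):推理时分别对 `start_logits` 和 `end_logits` 取 argmax 得到答案片段的起止位置,再用 tokenizer 解码;未经微调的 QA 头输出没有参考价值。

```
from transformers import AutoTokenizer, FNetForQuestionAnswering
import torch

tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
model = FNetForQuestionAnswering.from_pretrained("google/fnet-base")

question = "Where do I live?"
context = "My name is Sarah and I live in London."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# 取 start/end logits 的 argmax 作为答案片段的起止 token 位置
start = outputs.start_logits.argmax(dim=-1).item()
end = outputs.end_logits.argmax(dim=-1).item()
answer_ids = inputs["input_ids"][0, start : end + 1]
print(tokenizer.decode(answer_ids))
```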