Transformers Source Code Analysis (116)

.\models\vilt\modeling_vilt.py

# coding=utf-8
# Copyright 2022 NAVER AI Labs and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch ViLT model."""

import collections.abc  # 导入抽象基类集合
import math  # 导入数学库
from dataclasses import dataclass  # 导入数据类装饰器
from typing import List, Optional, Tuple, Union  # 导入类型提示

import torch  # 导入 PyTorch
import torch.utils.checkpoint  # 导入 PyTorch 检查点工具
from torch import nn  # 导入 PyTorch 神经网络模块
from torch.nn import CrossEntropyLoss  # 导入交叉熵损失函数

from ...activations import ACT2FN  # 导入激活函数映射
from ...modeling_outputs import (  # 导入模型输出类
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    ModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel  # 导入预训练模型工具
from ...pytorch_utils import (  # 导入 PyTorch 工具函数
    find_pruneable_heads_and_indices,
    meshgrid,
    prune_linear_layer,
)
from ...utils import (  # 导入工具函数
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_vilt import ViltConfig  # 导入 ViLT 配置

logger = logging.get_logger(__name__)  # 获取日志记录器

_CONFIG_FOR_DOC = "ViltConfig"  # 用于文档的配置类名
_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm"  # 用于文档的检查点名称

VILT_PRETRAINED_MODEL_ARCHIVE_LIST = [  # ViLT 预训练模型存档列表
    "dandelin/vilt-b32-mlm",
    # 查看所有 ViLT 模型 https://huggingface.co/models?filter=vilt
]


@dataclass
class ViltForImagesAndTextClassificationOutput(ModelOutput):
    """
    [`ViltForImagesAndTextClassification`] 的输出类。
    """
    # 定义函数参数和返回类型的文档字符串,描述了该函数可以接受的参数和可能的返回值类型
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            分类(如果config.num_labels==1,则为回归)损失值。
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            分类(如果config.num_labels==1,则为回归)得分(SoftMax之前)。
        hidden_states (`List[tuple(torch.FloatTensor)]`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            `torch.FloatTensor`的列表(每个图像-文本对一个,每个元组包含嵌入的输出+每层输出)的元组,形状为`(batch_size, sequence_length, hidden_size)`。
            模型在每一层输出的隐藏状态加上初始嵌入的输出。
        attentions (`List[tuple(torch.FloatTensor)]`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            `torch.FloatTensor`的列表(每个图像-文本对一个,每个元组包含注意力权重的输出)的元组,形状为`(batch_size, num_heads, sequence_length, sequence_length)`。
            注意力softmax之后的注意力权重,用于计算自注意力头中的加权平均值。
    """

    # 定义可能为None的损失值变量,类型为`torch.FloatTensor`
    loss: Optional[torch.FloatTensor] = None
    # 定义必须存在的logits变量,类型为`torch.FloatTensor`
    logits: torch.FloatTensor = None
    # 定义可能为None的隐藏状态列表变量,每个元素为`torch.FloatTensor`的元组列表
    hidden_states: Optional[List[Tuple[torch.FloatTensor]]] = None
    # 定义可能为None的注意力权重列表变量,每个元素为`torch.FloatTensor`的元组列表
    attentions: Optional[List[Tuple[torch.FloatTensor]]] = None
class ViltEmbeddings(nn.Module):
    """
    Construct the text and patch embeddings.

    Text embeddings are equivalent to BERT embeddings.

    Patch embeddings are equivalent to ViT embeddings.
    """

    def __init__(self, config):
        super().__init__()

        # text embeddings
        self.text_embeddings = TextEmbeddings(config)
        # patch embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.patch_embeddings = ViltPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        # modality type (text/patch) embeddings
        self.token_type_embeddings = nn.Embedding(config.modality_type_vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        pixel_values,
        pixel_mask,
        inputs_embeds,
        image_embeds,
        image_token_type_idx=1,
    ):
        # PART 1: text embeddings
        text_embeds = self.text_embeddings(
            input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        # PART 2: patch embeddings (with interpolated position encodings)
        if image_embeds is None:
            # Generate visual embeddings and masks from pixel values
            image_embeds, image_masks, patch_index = self.visual_embed(
                pixel_values, pixel_mask, max_image_length=self.config.max_image_length
            )
        else:
            # Flatten pixel masks
            image_masks = pixel_mask.flatten(1)

        # PART 3: add modality type embeddings
        # 0 indicates text, 1 indicates image, 2 is optionally used when a second image is provided (NLVR2)
        if image_token_type_idx is None:
            image_token_type_idx = 1
        # Add token type embeddings to text embeddings
        text_embeds = text_embeds + self.token_type_embeddings(
            torch.zeros_like(attention_mask, dtype=torch.long, device=text_embeds.device)
        )
        # Add token type embeddings to image embeddings
        image_embeds = image_embeds + self.token_type_embeddings(
            torch.full_like(image_masks, image_token_type_idx, dtype=torch.long, device=text_embeds.device)
        )

        # PART 4: concatenate text and image embeddings
        embeddings = torch.cat([text_embeds, image_embeds], dim=1)
        # Concatenate attention masks and image masks
        masks = torch.cat([attention_mask, image_masks], dim=1)

        return embeddings, masks
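
To make the modality-fusion step above concrete, here is a minimal, self-contained sketch (toy dimensions and random tensors, not real ViLT weights) of how the token-type (modality) embeddings are added and how the text and image sequences are concatenated:

```python
import torch
from torch import nn

# Toy dimensions: 1 example, 5 text tokens, 3 image patches, hidden size 4
hidden_size, text_len, image_len = 4, 5, 3
token_type_embeddings = nn.Embedding(2, hidden_size)  # modality ids: 0 = text, 1 = image

text_embeds = torch.randn(1, text_len, hidden_size)
image_embeds = torch.randn(1, image_len, hidden_size)
attention_mask = torch.ones(1, text_len, dtype=torch.long)
image_masks = torch.ones(1, image_len, dtype=torch.long)

# Text positions receive modality id 0, image patches receive modality id 1
text_embeds = text_embeds + token_type_embeddings(torch.zeros_like(attention_mask))
image_embeds = image_embeds + token_type_embeddings(torch.full_like(image_masks, 1))

# Concatenate along the sequence dimension, exactly as ViltEmbeddings.forward does
embeddings = torch.cat([text_embeds, image_embeds], dim=1)
masks = torch.cat([attention_mask, image_masks], dim=1)
print(embeddings.shape, masks.shape)  # torch.Size([1, 8, 4]) torch.Size([1, 8])
```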

class TextEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    # 初始化函数,接受一个配置参数 config
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 创建一个词嵌入层,用于将词汇索引映射为隐藏状态向量
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # 创建一个位置嵌入层,用于将位置索引映射为隐藏状态向量
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # 创建一个标记类型嵌入层,用于将标记类型索引映射为隐藏状态向量
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # 创建一个 LayerNorm 层,用于标准化隐藏状态向量
        # 注意:这里 LayerNorm 的命名方式与 TensorFlow 的模型变量保持一致,以便能够加载任何 TensorFlow 的检查点文件
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 创建一个 Dropout 层,用于在训练过程中随机丢弃隐藏状态向量的部分内容,以防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        
        # 设置位置嵌入类型,默认为绝对位置编码
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # 注册一个缓冲区,用于存储位置索引的张量,这个张量在序列化时会被导出
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        # 注册一个缓冲区,用于存储标记类型索引的张量,初始值为全零
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    # 前向传播函数,接受多个输入参数,根据输入计算模型的输出
    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        # 如果输入的 input_ids 不为 None,则获取其形状
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            # 否则,获取 inputs_embeds 的形状(排除最后一维)
            input_shape = inputs_embeds.size()[:-1]

        # 获取序列长度,即输入数据的第二个维度大小
        seq_length = input_shape[1]

        # 如果 position_ids 为 None,则使用预先注册的位置索引张量 self.position_ids
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        # 如果 token_type_ids 为 None,则使用预先注册的标记类型索引张量 self.token_type_ids
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # 获取并扩展预先注册的 token_type_ids 到与输入形状相匹配的张量
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                # 如果未注册 token_type_ids,则创建一个全零张量,与输入形状相匹配
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        # 如果 inputs_embeds 为 None,则通过 word_embeddings 层将 input_ids 映射为词嵌入向量
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        
        # 根据 token_type_ids 获取标记类型嵌入向量
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # 将词嵌入向量和标记类型嵌入向量相加,得到最终的嵌入向量
        embeddings = inputs_embeds + token_type_embeddings
        
        # 如果位置编码方式为绝对位置编码,则添加位置嵌入向量到最终的嵌入向量中
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        
        # 对最终的嵌入向量进行 LayerNorm 标准化处理
        embeddings = self.LayerNorm(embeddings)
        # 对标准化后的向量应用 Dropout,以防止过拟合
        embeddings = self.dropout(embeddings)
        # 返回最终的嵌入向量作为模型的输出
        return embeddings
    """
    Image to Patch Embedding.
    """

    # 初始化函数,设置类的初始状态
    def __init__(self, config):
        super().__init__()
        # 从配置中获取图像大小和patch大小
        image_size, patch_size = config.image_size, config.patch_size
        # 从配置中获取通道数和隐藏层大小
        num_channels, hidden_size = config.num_channels, config.hidden_size

        # 确保image_size和patch_size是可迭代对象,若不是则转为元组
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        # 计算patch的数量
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        # 设置类的成员变量
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        # 使用卷积层进行投影,将图像转换为patch embeddings
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    # 前向传播函数,定义了数据从输入到输出的流程
    def forward(self, pixel_values):
        # 获取输入张量的尺寸信息
        batch_size, num_channels, height, width = pixel_values.shape
        # 如果输入通道数与配置中的通道数不匹配,则抛出数值错误异常
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # 确定目标数据类型为投影层权重的数据类型
        target_dtype = self.projection.weight.dtype
        # 对输入张量进行投影操作,并转换为目标数据类型
        x = self.projection(pixel_values.to(dtype=target_dtype))
        # 返回投影后的张量作为输出
        return x
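
The `Conv2d` above is what turns an image into a sequence of patch tokens: with kernel size and stride both equal to the patch size, each output spatial position corresponds to one non-overlapping patch. A quick shape check with toy numbers (rather than the ViLT-B/32 defaults of 384×384 inputs, 32×32 patches and hidden size 768):

```python
import torch
from torch import nn

# Toy patch projection: 3-channel 32x32 image, 16x16 patches, hidden size 8
projection = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=16, stride=16)

pixel_values = torch.randn(2, 3, 32, 32)
patches = projection(pixel_values)
print(patches.shape)  # torch.Size([2, 8, 2, 2]) -> a 2x2 grid of patches per image

# The embeddings module later flattens this spatial grid into a token sequence
patch_tokens = patches.flatten(2).transpose(1, 2)
print(patch_tokens.shape)  # torch.Size([2, 4, 8])
```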


class ViltSelfAttention(nn.Module):
    # 初始化函数,设置自注意力模块的初始状态
    def __init__(self, config):
        super().__init__()
        # 如果隐藏层大小不能被注意力头数整除,并且配置中没有嵌入大小属性,则抛出数值错误异常
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        # 设置注意力头数和每个头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 定义查询、键、值的线性映射层,并考虑是否使用偏置
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        # 定义dropout层,用于注意力概率的dropout
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    # 将输入张量重塑为注意力分数计算所需的形状
    def transpose_for_scores(self, x):
        # 计算新的张量形状,以便符合多头注意力的要求
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        # 重新排列张量的维度,使得多头注意力能够并行计算
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
    # 此方法用于前向传播,计算注意力机制中的上下文表示
    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        # 通过self.query对隐藏状态进行查询,生成混合的查询层
        mixed_query_layer = self.query(hidden_states)

        # 通过self.key对隐藏状态进行键的变换,并进行得分计算
        key_layer = self.transpose_for_scores(self.key(hidden_states))

        # 通过self.value对隐藏状态进行值的变换,并进行得分计算
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # 对混合的查询层再进行得分计算
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # 通过点积计算"查询"和"键"之间的原始注意力分数
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # 根据注意力头的大小对注意力分数进行缩放
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # 如果提供了注意力掩码,则应用它
        if attention_mask is not None:
            # 注意力掩码是预先计算好的,适用于BertModel的forward()函数中的所有层
            attention_scores = attention_scores + attention_mask

        # 将注意力分数归一化为概率分布
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # 使用dropout函数对注意力概率进行随机失活处理
        attention_probs = self.dropout(attention_probs)

        # 如果需要,对注意力概率进行头部掩码处理
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # 计算加权和得到上下文表示层
        context_layer = torch.matmul(attention_probs, value_layer)

        # 对上下文表示层进行维度变换和重塑
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # 准备输出,根据需要包含注意力分布
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        # 返回计算结果
        return outputs
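
The shape bookkeeping in `transpose_for_scores` and the scaled dot-product itself can be checked in isolation; below is a minimal sketch with toy dimensions (not tied to the real config, and using a single tensor as query, key and value for brevity):

```python
import math
import torch

batch, seq_len, hidden_size, num_heads = 2, 6, 8, 2
head_size = hidden_size // num_heads  # 4

hidden_states = torch.randn(batch, seq_len, hidden_size)

# transpose_for_scores: (batch, seq, hidden) -> (batch, heads, seq, head_size)
x = hidden_states.view(batch, seq_len, num_heads, head_size).permute(0, 2, 1, 3)

# Scaled dot-product attention over the last two dimensions
scores = torch.matmul(x, x.transpose(-1, -2)) / math.sqrt(head_size)
probs = scores.softmax(dim=-1)            # (batch, heads, seq, seq)
context = torch.matmul(probs, x)          # (batch, heads, seq, head_size)

# Merge the heads back to (batch, seq, hidden), as done at the end of forward()
context = context.permute(0, 2, 1, 3).reshape(batch, seq_len, hidden_size)
print(context.shape)  # torch.Size([2, 6, 8])
```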
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vilt
class ViltSelfOutput(nn.Module):
    """
    The residual connection is defined in ViltLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ViltConfig) -> None:
        super().__init__()
        # 定义一个全连接层,输入和输出的维度都是 config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 定义一个 dropout 层,根据给定的隐藏状态概率随机将输入置零
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 将输入的隐藏状态通过全连接层映射到同一维度
        hidden_states = self.dense(hidden_states)
        # 对映射后的隐藏状态进行 dropout 操作
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class ViltAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化自注意力层和自输出层,都使用给定的配置参数
        self.attention = ViltSelfAttention(config)
        self.output = ViltSelfOutput(config)
        # 初始化一个空集合,用于存储被修剪的注意力头
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        # 根据给定的头部列表找到可修剪的头部和对应的索引
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # 修剪线性层
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # 更新超参数并存储修剪的头部
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        # 通过自注意力层处理隐藏状态和相关的掩码信息
        self_outputs = self.attention(hidden_states, attention_mask, head_mask, output_attentions)

        # 使用自输出层将注意力层的输出与原始隐藏状态相加
        attention_output = self.output(self_outputs[0], hidden_states)

        # 如果输出注意力信息,则将其添加到输出元组中
        outputs = (attention_output,) + self_outputs[1:]
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Vilt
class ViltIntermediate(nn.Module):
    def __init__(self, config: ViltConfig) -> None:
        super().__init__()
        # 定义一个全连接层,输入维度为 config.hidden_size,输出维度为 config.intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # 如果隐藏激活函数是字符串,则根据字符串映射到相应的激活函数,否则直接使用给定的激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 将输入的隐藏状态通过全连接层映射到 intermediate_size 的维度
        hidden_states = self.dense(hidden_states)
        # 将映射后的隐藏状态通过中间激活函数处理
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Vilt
class ViltOutput(nn.Module):
    def __init__(self, config: ViltConfig) -> None:
        super().__init__()
        # 定义一个全连接层,将中间大小的特征转换为隐藏大小
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # 定义一个 dropout 层,用于随机断开神经元连接,防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 将输入的隐藏状态通过全连接层映射到隐藏大小的空间
        hidden_states = self.dense(hidden_states)
        # 对映射后的结果进行 dropout 处理
        hidden_states = self.dropout(hidden_states)

        # 将处理后的隐藏状态与输入张量相加作为最终输出
        hidden_states = hidden_states + input_tensor

        return hidden_states


class ViltLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config):
        super().__init__()
        # 设置用于分块前馈的块大小
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # 序列长度的维度索引
        self.seq_len_dim = 1
        # 初始化自注意力层、中间层和输出层
        self.attention = ViltAttention(config)
        self.intermediate = ViltIntermediate(config)
        self.output = ViltOutput(config)
        # ViLT 中的 layernorm 在自注意力之前应用
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # ViLT 中的 layernorm 也在自注意力之后应用
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        # 对输入的隐藏状态应用 layernorm,并传入自注意力层进行处理
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # 如果输出注意力权重,添加自注意力的输出

        # 第一个残差连接:将自注意力的输出与原始隐藏状态相加
        hidden_states = attention_output + hidden_states.to(attention_output.device)

        # 在 ViLT 中,layernorm 也在自注意力之后应用
        layer_output = self.layernorm_after(hidden_states)
        # 经过中间层的处理
        layer_output = self.intermediate(layer_output)

        # 第二个残差连接:将中间层的输出与原始隐藏状态传入输出层
        layer_output = self.output(layer_output, hidden_states)

        # 将最终层的输出添加到输出集合中
        outputs = (layer_output,) + outputs

        return outputs


class ViltEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # 创建多层 ViltLayer 构成的层列表
        self.layer = nn.ModuleList([ViltLayer(config) for _ in range(config.num_hidden_layers)])
        # 默认关闭梯度检查点
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        ):
            # 如果不需要输出隐藏状态,则初始化为空元组;否则设为 None
            all_hidden_states = () if output_hidden_states else None
            # 如果不需要输出注意力权重,则初始化为空元组;否则设为 None
            all_self_attentions = () if output_attentions else None
        
            # 遍历 Transformer 模型的每一层
            for i, layer_module in enumerate(self.layer):
                # 如果需要输出隐藏状态,则累加当前隐藏状态到 all_hidden_states
                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)
                
                # 获取当前层的头部遮罩,如果未提供则为 None
                layer_head_mask = head_mask[i] if head_mask is not None else None
                
                # 如果启用渐变检查点并且处于训练模式下
                if self.gradient_checkpointing and self.training:
                    # 通过渐变检查点功能调用当前层模块,获取层的输出
                    layer_outputs = self._gradient_checkpointing_func(
                        layer_module.__call__,
                        hidden_states,
                        attention_mask,
                        layer_head_mask,
                        output_attentions,
                    )
                else:
                    # 否则直接调用当前层模块,获取层的输出
                    layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
    
                # 更新隐藏状态为当前层的输出的第一个元素
                hidden_states = layer_outputs[0]
    
                # 如果需要输出注意力权重,则累加当前层的注意力权重到 all_self_attentions
                if output_attentions:
                    all_self_attentions = all_self_attentions + (layer_outputs[1],)
    
            # 如果需要输出隐藏状态,则最后将当前隐藏状态加入 all_hidden_states
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
    
            # 如果不返回字典形式的输出,则按顺序返回隐藏状态、所有隐藏状态和所有注意力权重的非空元组
            if not return_dict:
                return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
            # 否则以 BaseModelOutput 类的形式返回结果,包含最终隐藏状态、所有隐藏状态和所有注意力权重
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                attentions=all_self_attentions,
            )
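
Since `gradient_checkpointing` defaults to `False` here, it is normally switched on through the `PreTrainedModel` API rather than by touching the encoder directly. A hedged sketch (randomly initialized model, no pretrained weights, behavior as of recent transformers versions):

```python
from transformers import ViltConfig, ViltModel

model = ViltModel(ViltConfig())          # random weights, just for illustration
model.gradient_checkpointing_enable()    # recompute activations in backward to save memory
print(model.encoder.gradient_checkpointing)  # True

model.gradient_checkpointing_disable()
print(model.encoder.gradient_checkpointing)  # False
```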
class ViltPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 设置模型的配置类
    config_class = ViltConfig
    # 模型基础名称前缀
    base_model_prefix = "vilt"
    # 支持梯度检查点
    supports_gradient_checkpointing = True
    # 不需要分割的模块列表
    _no_split_modules = ["ViltEmbeddings", "ViltSelfAttention"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # 如果是线性层或卷积层,使用正态分布初始化权重
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                # 如果有偏置,则将偏置初始化为零
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # 如果是嵌入层,使用正态分布初始化权重
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # 如果有填充索引,则将对应位置的权重初始化为零
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # 如果是LayerNorm层,将偏置初始化为零,权重初始化为1.0
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


# ViLT模型的起始文档字符串
VILT_START_DOCSTRING = r"""
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`ViltConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# ViLT模型输入文档字符串(空白)
VILT_INPUTS_DOCSTRING = r"""
"""

# ViLT图像和文本分类输入文档字符串(空白)
VILT_IMAGES_AND_TEXT_CLASSIFICATION_INPUTS_DOCSTRING = r"""
"""

# 添加起始文档字符串注释到ViltModel类
@add_start_docstrings(
    "The bare ViLT Model transformer outputting raw hidden-states without any specific head on top.",
    VILT_START_DOCSTRING,
)
class ViltModel(ViltPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        # 初始化嵌入层和编码器
        self.embeddings = ViltEmbeddings(config)
        self.encoder = ViltEncoder(config)

        # LayerNorm层,用于归一化隐藏层输出
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        
        # 如果需要添加汇聚层,则初始化汇聚器
        self.pooler = ViltPooler(config) if add_pooling_layer else None

        # 初始化权重并应用最终处理
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.text_embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.text_embeddings.word_embeddings = value
    # 修剪模型的注意力头
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # 遍历需要修剪的层和对应需要修剪的注意力头
        for layer, heads in heads_to_prune.items():
            # 对编码器中的特定层的注意力头进行修剪
            self.encoder.layer[layer].attention.prune_heads(heads)
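
A hedged usage sketch for head pruning (random weights, arbitrary head choice): `PreTrainedModel.prune_heads` takes a `{layer_index: [head indices]}` dict and routes to the `_prune_heads` method above, which in turn calls `ViltAttention.prune_heads` shown earlier.

```python
from transformers import ViltConfig, ViltModel

model = ViltModel(ViltConfig())      # 12 attention heads per layer by default
model.prune_heads({0: [0, 2]})       # remove heads 0 and 2 of the first encoder layer

layer0_attention = model.encoder.layer[0].attention
print(layer0_attention.attention.num_attention_heads)  # 10
print(layer0_attention.pruned_heads)                   # {0, 2}
```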

    # 将输入参数添加到模型的文档字符串
    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
    # 替换返回值的文档字符串
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    # 定义模型的前向传播方法
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        image_token_type_idx: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
class ViltPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 使用全连接层进行线性变换,输入和输出维度都是 config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        
        # 根据配置选择激活函数,如果配置中指定的是字符串形式的激活函数,则使用对应的函数,否则直接使用配置中的函数
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        
        # 应用 Layer Normalization 进行归一化处理,参数包括隐藏状态的维度和层归一化的 epsilon 值
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        # 将隐藏状态通过全连接层进行线性变换
        hidden_states = self.dense(hidden_states)
        
        # 应用选择的激活函数进行非线性变换
        hidden_states = self.transform_act_fn(hidden_states)
        
        # 对变换后的隐藏状态应用 Layer Normalization 进行归一化
        hidden_states = self.LayerNorm(hidden_states)
        
        return hidden_states

class ViltMLMHead(nn.Module):
    # 初始化函数,用于初始化模型对象
    def __init__(self, config, weight=None):
        # 调用父类的初始化方法
        super().__init__()
        # 将配置参数保存到对象的属性中
        self.config = config
        # 创建一个 ViltPredictionHeadTransform 的实例,并保存到对象的属性中
        self.transform = ViltPredictionHeadTransform(config)
        # 创建一个线性层,用于模型的解码器,指定输入大小为 config.hidden_size,输出大小为 config.vocab_size,且没有偏置项
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # 创建一个可学习的偏置项,大小为 config.vocab_size
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # 如果给定了预训练权重 weight,则将其赋值给解码器的权重
        if weight is not None:
            self.decoder.weight = weight

        # 为了确保偏置项能够正确地在调整 token embeddings 时被重新调整大小,需要在这里建立两者之间的链接
        self.decoder.bias = self.bias

    # 前向传播函数,接收输入 x 并返回模型的输出 x
    def forward(self, x):
        # 对输入 x 应用预测头变换
        x = self.transform(x)
        # 使用解码器对变换后的 x 进行解码
        x = self.decoder(x)
        # 返回解码后的输出 x
        return x
@add_start_docstrings(
    """
    Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
    token) for visual question answering, e.g. for VQAv2.
    """,
    VILT_START_DOCSTRING,
)
class ViltForQuestionAnswering(ViltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vilt = ViltModel(config)

        # Classifier head
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 2),  # Linear layer to expand hidden size
            nn.LayerNorm(config.hidden_size * 2),  # Layer normalization
            nn.GELU(),  # GELU activation function
            nn.Linear(config.hidden_size * 2, config.num_labels),  # Final linear layer for classification
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
            Labels for computing the visual question answering loss. This tensor must be either a one-hot encoding of
            all answers that are applicable for a given example in the batch, or a soft encoding indicating which
            answers are applicable, where 1.0 is the highest score.

        Returns:
            Depending on `return_dict`, returns either a `SequenceClassifierOutput` or a tuple containing logits and optionally other outputs.

        Examples:

        ```
        >>> from transformers import ViltProcessor, ViltForQuestionAnswering
        >>> import requests
        >>> from PIL import Image

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "How many cats are there?"

        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        >>> model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

        >>> # prepare inputs
        >>> encoding = processor(image, text, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**encoding)
        >>> logits = outputs.logits
        >>> idx = logits.argmax(-1).item()
        >>> print("Predicted answer:", model.config.id2label[idx])
        Predicted answer: 2
        ```"""

        # Determine whether to use the return_dict provided or the class attribute for return settings
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Perform forward pass through the VILT model
        outputs = self.vilt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Extract the pooler_output from the outputs based on return_dict setting
        pooler_output = outputs.pooler_output if return_dict else outputs[1]

        # Pass the pooler_output through the classifier layer to obtain logits
        logits = self.classifier(pooler_output)

        # Initialize loss variable
        loss = None

        # Calculate loss if labels are provided
        if labels is not None:
            # Move labels tensor to the same device as logits for compatibility
            labels = labels.to(logits.device)
            # Compute binary cross entropy loss scaled by number of labels
            loss = nn.functional.binary_cross_entropy_with_logits(logits, labels) * labels.shape[1]
            # Rescaling by the number of answer candidates follows the soft-target VQA
            # objective used by the original ViLT implementation.

        # Prepare output based on return_dict flag
        if not return_dict:
            # If return_dict is False, prepare tuple output
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # If return_dict is True, prepare SequenceClassifierOutput object
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
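
The VQA head is trained with soft answer targets (see the `labels` description above), which is why the loss is `binary_cross_entropy_with_logits` scaled by the number of answer candidates rather than a plain cross-entropy. A toy illustration of that computation with made-up logits and targets:

```python
import torch
import torch.nn.functional as F

# 2 examples, 4 candidate answers, soft target scores in [0, 1]
logits = torch.randn(2, 4)
labels = torch.tensor([[0.0, 1.0, 0.3, 0.0],
                       [1.0, 0.0, 0.0, 0.0]])

# Mean BCE over all entries, rescaled by the number of answer candidates so the
# loss is effectively summed over answers and averaged over examples.
loss = F.binary_cross_entropy_with_logits(logits, labels) * labels.shape[1]
print(loss.item())
```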
# 使用装饰器为类添加文档字符串,描述了该类的作用和功能,以及适用的应用场景(图片到文本或文本到图片的检索)
@add_start_docstrings(
    """
    Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
    token) for image-to-text or text-to-image retrieval, e.g. MSCOCO and F30K.
    """,
    VILT_START_DOCSTRING,
)
# 定义 ViltForImageAndTextRetrieval 类,继承自 ViltPreTrainedModel
class ViltForImageAndTextRetrieval(ViltPreTrainedModel):
    
    # 初始化方法,接受一个 config 对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)

        # 创建 ViltModel 的实例,并保存到 self.vilt 属性中
        self.vilt = ViltModel(config)

        # 分类器头部,使用线性层将最终隐藏状态([CLS] token)映射到单一输出维度
        self.rank_output = nn.Linear(config.hidden_size, 1)

        # 初始化权重并进行最终处理
        self.post_init()

    # 使用装饰器为 forward 方法添加文档字符串,描述了该方法的输入参数及其作用
    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
    # 使用装饰器替换返回值的文档字符串,指定输出类型为 SequenceClassifierOutput,配置类为 _CONFIG_FOR_DOC
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    # forward 方法,处理模型的前向传播逻辑
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels are currently not supported.

        Returns:
            Depending on `return_dict` flag:
                - If `return_dict` is False, returns a tuple containing `logits` and additional outputs.
                - If `return_dict` is True, returns a `SequenceClassifierOutput` object containing `loss`, `logits`, `hidden_states`, and `attentions`.

        Examples:

        ```
        >>> from transformers import ViltProcessor, ViltForImageAndTextRetrieval
        >>> import requests
        >>> from PIL import Image

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]

        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-coco")
        >>> model = ViltForImageAndTextRetrieval.from_pretrained("dandelin/vilt-b32-finetuned-coco")

        >>> # forward pass
        >>> scores = dict()
        >>> for text in texts:
        ...     # prepare inputs
        ...     encoding = processor(image, text, return_tensors="pt")
        ...     outputs = model(**encoding)
        ...     scores[text] = outputs.logits[0, :].item()
        ```
        """
        # Determine whether to use the return_dict flag or the model's default configuration
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Perform the forward pass through the VILT model
        outputs = self.vilt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Select the pooler output based on whether return_dict is True or False
        pooler_output = outputs.pooler_output if return_dict else outputs[1]

        # Generate logits using the rank_output method
        logits = self.rank_output(pooler_output)

        # Initialize loss as None
        loss = None

        # Handle labels if provided (currently raises NotImplementedError)
        if labels is not None:
            # Move labels to the device where logits are located
            labels = labels.to(logits.device)
            raise NotImplementedError("Training is not yet supported.")

        # Return the output based on whether return_dict is True or False
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # Return a SequenceClassifierOutput object containing loss, logits, hidden_states, and attentions
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    Vilt Model transformer with a classifier head on top for natural language visual reasoning, e.g. NLVR2.
    """,
    VILT_IMAGES_AND_TEXT_CLASSIFICATION_INPUTS_DOCSTRING,
)
class ViltForImagesAndTextClassification(ViltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vilt = ViltModel(config)

        # Classifier head
        num_images = config.num_images
        # 定义分类器,包括线性层、LayerNorm和GELU激活函数
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size * num_images, config.hidden_size * num_images),
            nn.LayerNorm(config.hidden_size * num_images),
            nn.GELU(),
            nn.Linear(config.hidden_size * num_images, config.num_labels),
        )

        # Initialize weights and apply final processing
        # 初始化权重并进行最终处理
        self.post_init()
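
The classifier above expects `config.hidden_size * config.num_images` features because, in the NLVR2 setup, the model is run once per image with the same text and the pooled `[CLS]` vectors are concatenated before classification. A toy sketch of just that head (random vectors standing in for the real pooled outputs):

```python
import torch
from torch import nn

hidden_size, num_images, num_labels = 4, 2, 3  # toy sizes
classifier = nn.Sequential(
    nn.Linear(hidden_size * num_images, hidden_size * num_images),
    nn.LayerNorm(hidden_size * num_images),
    nn.GELU(),
    nn.Linear(hidden_size * num_images, num_labels),
)

# One pooled [CLS] vector per image-text pass, concatenated along the feature axis
pooled_per_image = [torch.randn(1, hidden_size) for _ in range(num_images)]
logits = classifier(torch.cat(pooled_per_image, dim=-1))
print(logits.shape)  # torch.Size([1, 3])
```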

    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ViltForImagesAndTextClassificationOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,



@add_start_docstrings(
    """
    ViLT Model with a token classification head on top (a linear layer on top of the final hidden-states of the text
    tokens) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    VILT_START_DOCSTRING,
)
class ViltForTokenClassification(ViltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        # 初始化 ViLT 模型,不添加池化层
        self.vilt = ViltModel(config, add_pooling_layer=False)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 分类器是一个线性层,输出维度为 config.num_labels
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        # 初始化权重并进行最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[TokenClassifierOutput, Tuple[torch.FloatTensor]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Returns:
            Either a `TokenClassifierOutput` containing loss, logits, hidden states, and attentions,
            or a tuple with logits and optional hidden states and attentions.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Pass inputs to the VILT model for processing
        outputs = self.vilt(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Determine the size of the text input
        text_input_size = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Apply dropout to the sequence output
        sequence_output = self.dropout(sequence_output)
        
        # Classify tokens using the classifier layer
        logits = self.classifier(sequence_output[:, :text_input_size])

        loss = None
        if labels is not None:
            # Calculate the cross-entropy loss
            loss_fct = CrossEntropyLoss()
            # Move labels to the same device as logits
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # If return_dict is False, return a tuple of logits and optionally hidden states and attentions
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # If return_dict is True, return a TokenClassifierOutput object
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

.\models\vilt\processing_vilt.py

# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for ViLT.
"""

import warnings
from typing import List, Optional, Union

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType


class ViltProcessor(ProcessorMixin):
    r"""
    Constructs a ViLT processor which wraps a BERT tokenizer and ViLT image processor into a single processor.

    [`ViltProcessor`] offers all the functionalities of [`ViltImageProcessor`] and [`BertTokenizerFast`]. See the
    docstring of [`~ViltProcessor.__call__`] and [`~ViltProcessor.decode`] for more information.

    Args:
        image_processor (`ViltImageProcessor`, *optional*):
            An instance of [`ViltImageProcessor`]. The image processor is a required input.
        tokenizer (`BertTokenizerFast`, *optional*):
            An instance of [`BertTokenizerFast`]. The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "ViltImageProcessor"
    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        # Check if 'feature_extractor' is provided in kwargs; deprecated warning
        feature_extractor = None
        if "feature_extractor" in kwargs:
            warnings.warn(
                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
                " instead.",
                FutureWarning,
            )
            feature_extractor = kwargs.pop("feature_extractor")

        # Use 'feature_extractor' if 'image_processor' is not provided; raise error if neither are provided
        image_processor = image_processor if image_processor is not None else feature_extractor
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        # Initialize the processor mixin with image processor and tokenizer
        super().__init__(image_processor, tokenizer)
        # Set current_processor to image_processor
        self.current_processor = self.image_processor
    def __call__(
        self,
        images,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchEncoding:
        """
        This method uses [`ViltImageProcessor.__call__`] method to prepare image(s) for the model, and
        [`BertTokenizerFast.__call__`] to prepare text for the model.

        Please refer to the docstring of the above two methods for more information.
        """
        # 使用 tokenizer 处理文本,准备输入编码
        encoding = self.tokenizer(
            text=text,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            return_tensors=return_tensors,
            **kwargs,
        )
        # 使用 image_processor 处理图像,准备输入编码
        encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
        # 将图像编码信息更新到文本编码信息中
        encoding.update(encoding_image_processor)

        # 返回合并后的编码信息
        return encoding
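
A short usage sketch of the combined processor (the checkpoint name follows the examples earlier in this file; the image download requires network access):

```python
import requests
from PIL import Image
from transformers import ViltProcessor

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

encoding = processor(image, "a photo of two cats", return_tensors="pt")
# Text fields come from the tokenizer, image fields from the image processor
print(sorted(encoding.keys()))
# ['attention_mask', 'input_ids', 'pixel_mask', 'pixel_values', 'token_type_ids']
```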

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        # 调用 tokenizer 的 batch_decode 方法进行批量解码
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        # 调用 tokenizer 的 decode 方法进行解码
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        # 获取 tokenizer 和 image_processor 的模型输入名称列表,并去重
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
    # 属性装饰器,发出警告,提示用户 `feature_extractor_class` 属性即将在 v5 版本中移除,建议使用 `image_processor_class` 属性代替
    @property
    def feature_extractor_class(self):
        warnings.warn(
            "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
            FutureWarning,
        )
        # 返回当前对象的 `image_processor_class` 属性
        return self.image_processor_class

    # 属性装饰器,发出警告,提示用户 `feature_extractor` 属性即将在 v5 版本中移除,建议使用 `image_processor` 属性代替
    @property
    def feature_extractor(self):
        warnings.warn(
            "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
            FutureWarning,
        )
        # 返回当前对象的 `image_processor` 属性
        return self.image_processor

.\models\vilt\__init__.py

# 导入必要的模块和函数
from typing import TYPE_CHECKING

# 导入自定义的异常类,用于处理可选依赖不可用的情况
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available

# 定义模块的导入结构,包括各个子模块的名称和导入项
_import_structure = {"configuration_vilt": ["VILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViltConfig"]}

# 检查视觉功能是否可用,若不可用则抛出自定义异常
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,则添加视觉特征提取、图像处理和处理模块到导入结构中
    _import_structure["feature_extraction_vilt"] = ["ViltFeatureExtractor"]
    _import_structure["image_processing_vilt"] = ["ViltImageProcessor"]
    _import_structure["processing_vilt"] = ["ViltProcessor"]

# 检查是否可用的 PyTorch 包,若不可用则抛出自定义异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,则添加模型相关的各类模块到导入结构中
    _import_structure["modeling_vilt"] = [
        "VILT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "ViltForImageAndTextRetrieval",
        "ViltForImagesAndTextClassification",
        "ViltForTokenClassification",
        "ViltForMaskedLM",
        "ViltForQuestionAnswering",
        "ViltLayer",
        "ViltModel",
        "ViltPreTrainedModel",
    ]

# 如果是类型检查阶段,则导入具体的配置和模型类
if TYPE_CHECKING:
    from .configuration_vilt import VILT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViltConfig

    # 检查视觉功能是否可用,若可用则导入相应的特征提取、图像处理和处理模块
    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .feature_extraction_vilt import ViltFeatureExtractor
        from .image_processing_vilt import ViltImageProcessor
        from .processing_vilt import ViltProcessor

    # 检查 PyTorch 包是否可用,若可用则导入模型相关的各类模块
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_vilt import (
            VILT_PRETRAINED_MODEL_ARCHIVE_LIST,
            ViltForImageAndTextRetrieval,
            ViltForImagesAndTextClassification,
            ViltForMaskedLM,
            ViltForQuestionAnswering,
            ViltForTokenClassification,
            ViltLayer,
            ViltModel,
            ViltPreTrainedModel,
        )

# 如果不是类型检查阶段,则设置当前模块为懒加载模块
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
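
Because of the `_LazyModule` registration above, user code imports the public names from `transformers` as usual; the heavy submodules are only materialized when the corresponding attribute is first accessed. A minimal illustration:

```python
import transformers

config = transformers.ViltConfig()       # accessing the name triggers the lazy import of configuration_vilt
model = transformers.ViltModel(config)   # this access pulls in the torch-backed modeling_vilt module
print(config.hidden_size)                # 768 by default
```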

.\models\vipllava\configuration_vipllava.py

# 定义模块的版权信息和许可协议
# coding=utf-8
# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" VipLlava model configuration"""

# 引入警告模块
import warnings

# 从 transformers 包中引入预训练配置类 PretrainedConfig
from ...configuration_utils import PretrainedConfig
# 从 transformers.utils 中引入日志记录功能
from ...utils import logging
# 从 transformers.modeling_auto 中引入配置映射
from ..auto import CONFIG_MAPPING

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义预训练模型配置文件的映射
VIPLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "ybelkada/vip-llava-7b-hf": "https://huggingface.co/llava-hf/vip-llava-7b-hf/resolve/main/config.json",
}

# 定义 VipLlavaConfig 类,继承自 PretrainedConfig 类
class VipLlavaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`VipLlavaForConditionalGeneration`]. It is used to instantiate an
    VipLlava model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the VipLlava-9B.

    e.g. [ybelkada/vip-llava-7b-hf](https://huggingface.co/ybelkada/vip-llava-7b-hf)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`VipLlavaVisionConfig`,  *optional*):
            Custom vision config or dict
        text_config (`Union[AutoConfig, dict]`, *optional*):
            The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        image_token_index (`int`, *optional*, defaults to 32000):
            The image token index to encode the image prompt.
        projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The activation function used by the multimodal projector.
        projector_layernorm_eps (`float`, *optional*, defaults to 1e-05):
            The layer norm epsilon of the projector layernorm
        vision_feature_layers (`List[int]`, *optional*, defaults to `[-2, -5, -8, -11, 6]`):
            The list of layers to select the vision features from.

    Example:

    ```
    >>> from transformers import VipLlavaForConditionalGeneration, VipLlavaConfig, CLIPVisionConfig, LlamaConfig

    >>> # Initializing a CLIP-vision config
    >>> vision_config = CLIPVisionConfig()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # 初始化一个 VipLlava vipllava-7b 风格的配置
    >>> configuration = VipLlavaConfig(vision_config, text_config)

    >>> # 使用 vipllava-7b 风格的配置初始化一个模型
    >>> model = VipLlavaForConditionalGeneration(configuration)

    >>> # 获取模型的配置信息
    >>> configuration = model.config
    ```"""
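补充一个示意:上面文档字符串中列出的参数(如 `image_token_index`、`vision_feature_layers` 等)都可以在构造配置时显式覆盖,下面的取值仅为演示用的假设值:

from transformers import CLIPVisionConfig, LlamaConfig, VipLlavaConfig

vision_config = CLIPVisionConfig()
text_config = LlamaConfig()

# 显式覆盖部分默认参数(取值仅作演示)
configuration = VipLlavaConfig(
    vision_config=vision_config,
    text_config=text_config,
    image_token_index=32000,
    projector_layernorm_eps=1e-5,
    vision_feature_layers=[-2, -5, -8, -11, 6],
)
print(configuration.image_token_index, configuration.projector_layernorm_eps)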

.\models\vipllava\convert_vipllava_weights_to_hf.py

# 导入 argparse 库,用于处理命令行参数
import argparse

# 导入 torch 库
import torch
# 从 huggingface_hub 库导入 hf_hub_download 函数
from huggingface_hub import hf_hub_download

# 从 transformers 库导入以下类和函数
from transformers import (
    AddedToken,
    AutoConfig,
    AutoTokenizer,
    CLIPImageProcessor,
    LlavaProcessor,
    VipLlavaConfig,
    VipLlavaForConditionalGeneration,
)

# 定义一个字典映射,用于修改模型权重的键名
KEYS_TO_MODIFY_MAPPING = {
    "model.vision_tower.": "",
    "model.mm_projector": "multi_modal_projector",
    "model": "model.model",
    "vision_model.model": "vision_model",
    "lm_head": "language_model.lm_head",
    "model.model": "language_model.model",
    "multi_modal_projector.0": "multi_modal_projector.linear_1",
    "multi_modal_projector.2": "multi_modal_projector.linear_2",
    "final_linear.0": "linear_1",
    "final_linear.2": "linear_2",
    "multi_modal_projector.clip_layernorm": "multi_modal_projector.projector_layernorm",
}

# 定义函数,将旧版权重字典转换为适合 HF 的新版权重字典
def convert_state_dict_to_hf(state_dict):
    new_state_dict = {}
    for key, value in state_dict.items():
        # 如果键以 ".inv_freq" 结尾,则跳过不处理
        if key.endswith(".inv_freq"):
            continue
        # 遍历键名映射字典,替换对应的键名
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)
        # 更新新版权重字典
        new_state_dict[key] = value
    return new_state_dict


# 定义函数,将 vipllava_llama 模型转换为适合 HF 的模型
def convert_vipllava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
    # 设置默认张量数据类型为 float16
    torch.set_default_dtype(torch.float16)
    
    # 从预训练模型 ID 加载文本配置
    text_config = AutoConfig.from_pretrained(text_model_id)

    # 从预训练模型 ID 加载分词器,并添加特殊标记
    tokenizer = AutoTokenizer.from_pretrained(text_model_id)
    tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
    tokenizer.add_special_tokens({"pad_token": "<pad>"})

    # 从预训练模型 ID 加载图像处理器
    image_processor = CLIPImageProcessor.from_pretrained(vision_model_id)

    # 创建 LlavaProcessor 对象,用于处理文本和图像
    processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

    # 使用 VipLlavaConfig 配置对象,配置 ViPLlava 模型
    config = VipLlavaConfig(text_config=text_config)
    config.pad_token_id = 32001

    # 在 meta 设备上创建 VipLlavaForConditionalGeneration 模型
    with torch.device("meta"):
        model = VipLlavaForConditionalGeneration(config)

    # 出于性能考虑,词嵌入大小将被填充(pad)到 64 的倍数
    pad_shape = 64

    # 下载并加载旧版权重字典路径
    state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict_7b.bin")
    state_dict = torch.load(state_dict_path, map_location="cpu")  # 在 CPU 上加载权重字典
    state_dict = convert_state_dict_to_hf(state_dict)  # 转换权重字典为适合 HF 的格式
    # 使用给定的状态字典加载模型的状态,严格匹配模型参数,同时允许分配参数
    model.load_state_dict(state_dict, strict=True, assign=True)

    # 获取模型语言模型的词嵌入权重,用于计算均值和协方差
    pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data
    # 计算词嵌入的均值向量
    mu = torch.mean(pre_expansion_embeddings, dim=0).float()
    # 获取词嵌入矩阵的行数,用于计算协方差
    n = pre_expansion_embeddings.size()[0]
    # 计算词嵌入矩阵的协方差矩阵
    sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
    # 创建一个多元正态分布对象,以 mu 为均值,sigma 为协方差矩阵
    dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)

    # 调整模型的词嵌入层:在原词汇表基础上新增 2 个 token(<image> 和 <pad>),并按 pad_shape 对齐
    model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape)
    # 为词嵌入矩阵的扩展部分生成样本,并替换模型中的词嵌入权重
    model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack(
        tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0]))),
        dim=0,
    )
    # 为语言模型头部的权重矩阵的扩展部分生成样本,并替换模型中的权重
    model.language_model.lm_head.weight.data[32000:] = torch.stack(
        tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))),
        dim=0,
    )

    # 将模型推送到指定的 Hub 输出路径
    model.push_to_hub(output_hub_path)
    # 将处理器对象推送到指定的 Hub 输出路径
    processor.push_to_hub(output_hub_path)
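上面“先用已有词嵌入估计均值与协方差、再从多元正态分布为新增 token 采样初始向量”的做法,可以用下面这个与具体模型无关的小例子来说明(张量尺寸均为假设值):

import torch

torch.manual_seed(0)

# 假设已有 100 个 token、维度为 16 的词嵌入矩阵
old_embeddings = torch.randn(100, 16)

# 估计均值与(缩放后的)协方差,构造多元正态分布
mu = old_embeddings.mean(dim=0)
centered = old_embeddings - mu
sigma = centered.T @ centered / old_embeddings.size(0)
dist = torch.distributions.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)

# 为 2 个新增 token(例如 <image> 和 <pad>)采样初始向量并拼接
new_rows = torch.stack([dist.sample() for _ in range(2)], dim=0)
expanded = torch.cat([old_embeddings, new_rows], dim=0)
print(expanded.shape)  # torch.Size([102, 16])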
# 主程序入口函数
def main():
    # 创建命令行参数解析器对象
    parser = argparse.ArgumentParser()
    
    # 添加命令行参数:text_model_id,用于指定文本模型的 Hub 地址
    parser.add_argument(
        "--text_model_id",
        help="Hub location of the text model",
    )
    
    # 添加命令行参数:vision_model_id,用于指定视觉模型的 Hub 地址
    parser.add_argument(
        "--vision_model_id",
        help="Hub location of the vision model",
    )
    
    # 添加命令行参数:output_hub_path,用于指定转换后模型在 Hub 上的位置
    parser.add_argument(
        "--output_hub_path",
        help="Location on the hub of the converted model",
    )
    
    # 添加命令行参数:old_state_dict_id,用于指定原始模型状态字典的 Hub 地址
    # 需要注意文件名应为 `model_state_dict.bin`
    parser.add_argument(
        "--old_state_dict_id",
        help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
    )
    
    # 解析命令行参数
    args = parser.parse_args()
    
    # 调用函数 convert_vipllava_llama_to_hf,将参数传递给该函数
    convert_vipllava_llama_to_hf(
        args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id
    )


# 如果当前脚本被直接执行,则调用主函数 main()
if __name__ == "__main__":
    main()
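该脚本通过命令行运行,下面给出一个假设的调用方式(各参数值均为占位,需替换为实际的 Hub 路径):

# python convert_vipllava_weights_to_hf.py \
#     --text_model_id <文本模型的 Hub 路径> \
#     --vision_model_id <视觉模型(CLIP)的 Hub 路径> \
#     --output_hub_path <转换后模型要推送到的 Hub 仓库> \
#     --old_state_dict_id <存放原始 state dict 的 Hub 仓库>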

.\models\vipllava\modeling_vipllava.py

# coding=utf-8
# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch VipLlava model."""

# 导入必要的库和模块
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

# 导入 Hugging Face Transformers 中的预训练模型基类和其他必要组件
from ... import PreTrainedModel
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_outputs import ModelOutput
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM

# 导入 VipLlava 模型的配置类
from .configuration_vipllava import VipLlavaConfig

# 获取 logger 对象用于日志记录
logger = logging.get_logger(__name__)

# 文档中显示的配置对象名称
_CONFIG_FOR_DOC = "VipLlavaConfig"

# 预训练模型的存档列表,包括一个示例
VIPLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "llava-hf/vip-llava-7b-hf",
    # See all VipLlava models at https://huggingface.co/models?filter=vipllava
]

# 定义一个数据类,用于表示 VipLlava 模型的自回归语言模型输出及过去状态
@dataclass
# 从 Idefics 模型中复制的类,作为 VipLlava 模型的输出基类
class VipLlavaCausalLMOutputWithPast(ModelOutput):
    """
    Base class for VipLlava causal language model (or autoregressive) outputs.
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
            sequence_length, hidden_size)`.

            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
    """

    # Optional loss value for language modeling
    loss: Optional[torch.FloatTensor] = None
    # Predicted logits for each token in the batch
    logits: torch.FloatTensor = None
    # Cached key and value states for speeding up sequential decoding
    past_key_values: Optional[List[torch.FloatTensor]] = None
    # Hidden states of the model at each layer's output and optional initial embeddings
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Attention weights after softmax, used for self-attention computation
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Hidden states produced by the vision encoder for image embeddings
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@add_start_docstrings(
    "The bare VipLlava Model outputting raw hidden-states without any specific head on top.",
    VIPLLAVA_START_DOCSTRING,
)
# 为 VipLlavaPreTrainedModel 类添加文档字符串,描述其作为 VipLlava 模型的基础预训练模型的输出为原始隐藏状态,没有特定的输出头部。

# 从 PreTrainedModel 类继承,定义 VipLlavaPreTrainedModel 类
class VipLlavaPreTrainedModel(PreTrainedModel):
    # 指定配置类为 VipLlavaConfig
    config_class = VipLlavaConfig
    # 模型的基础名称前缀
    base_model_prefix = "model"
    # 支持梯度检查点
    supports_gradient_checkpointing = True
    # 不分割的模块列表
    _no_split_modules = ["VipLlavaVisionAttention"]
    # 跳过键设备放置
    _skip_keys_device_placement = "past_key_values"
    # 支持 Flash Attention 2
    _supports_flash_attn_2 = True
    # 初始化模型权重的方法,用于对传入的模块进行权重初始化
    def _init_weights(self, module):
        # 注意: 这个迁移版本的 VipLlava 不适用于从头训练,只能用于推理和微调。
        # 因此,适当的初始化权重代码已经被移除。原始代码库位于 https://github.com/haotian-liu/LLaVA/tree/main/vipllava,可以用于训练目的。

        # 根据配置获取初始化标准差
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        # 如果模块具有类嵌入(class_embedding)属性,则对其进行标准正态分布初始化
        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        # 如果模块是线性层(nn.Linear)或二维卷积层(nn.Conv2d),则对权重进行标准正态分布初始化,
        # 如果有偏置,则将偏置初始化为零
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        
        # 如果模块是嵌入层(nn.Embedding),则对权重进行标准正态分布初始化,
        # 如果定义了填充索引(padding_idx),则将该索引处的权重初始化为零
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self):
        """
        检索语言模型的属性,检查模型是否支持 SDPA(Scaled Dot-Product Attention,缩放点积注意力)。
        """
        return self.language_model._supports_sdpa
# 定义模型文档字符串,用于描述 VIPLLAVA 模型的输入
VIPLLAVA_INPUTS_DOCSTRING = r"""
"""


@add_start_docstrings(
    """The VIPLLAVA model which consists of a vision backbone and a language model.""",
    VIPLLAVA_START_DOCSTRING,
)
# 从 transformers.models.llava.modeling_llava.LlavaForConditionalGeneration 复制而来,将 LLAVA 改为 VIPLLAVA,Llava 改为 VipLlava
class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
    def __init__(self, config: VipLlavaConfig):
        super().__init__(config)
        # 初始化视觉塔模型,使用从配置中获取的视觉配置
        self.vision_tower = AutoModel.from_config(config.vision_config)

        # 初始化多模态投影器
        self.multi_modal_projector = VipLlavaMultiModalProjector(config)
        # 获取文本配置中的词汇表大小作为模型的词汇表大小
        self.vocab_size = config.text_config.vocab_size
        # 初始化语言模型,使用从配置中获取的文本配置和注意力实现方式
        self.language_model = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        # 如果配置中定义了 pad_token_id,则使用配置中的值;否则使用 -1
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        # 执行初始化后的处理
        self.post_init()

    def get_input_embeddings(self):
        # 获取语言模型的输入嵌入层
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        # 设置语言模型的输入嵌入层
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        # 获取语言模型的输出嵌入层
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        # 设置语言模型的输出嵌入层
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        # 设置语言模型的解码器
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        # 获取语言模型的解码器
        return self.language_model.get_decoder()

    def tie_weights(self):
        # 绑定语言模型的权重
        return self.language_model.tie_weights()

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        # 调整语言模型的 token 嵌入层大小,并更新模型配置中的词汇表大小
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    # 忽略复制
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layers: Optional[List[int]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        pass  # 具体的前向实现较长,此处文档省略,完整逻辑参见源码中的 forward 实现

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
    ):
        # 如果传入的过去键值不为 None,则处理缓存相关逻辑
        if past_key_values is not None:
            # 如果过去键值是 Cache 类型,则获取序列长度和已见标记数
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
            else:
                # 否则,假设过去键值的第一个元素的第一个维度是 token 的形状的长度
                cache_length = past_length = past_key_values[0][0].shape[2]

            # 保留未处理的 token:
            # 1 - 如果 attention_mask 的长度超过 input_ids 的长度,则处理仅作为缓存传递的情况
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - 如果 past_length 小于 input_ids 的长度,则 input_ids 包含所有输入 token,可以基于 past_length 丢弃 input_ids
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - 否则(past_length >= input_ids.shape[1]),假设 input_ids 只有未处理的 token
            elif self.config.image_token_index in input_ids:
                input_ids = input_ids[:, input_ids.shape[1] - 1 :]
            
            # 如果缓存已见 token 数超过其容量限制,那么缓存有一个大小限制。丢弃较早的 attention 值,因为它们对应的值不是输入的一部分。
            if cache_length < past_length and attention_mask is not None:
                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]

        position_ids = kwargs.get("position_ids", None)
        # 如果 attention_mask 不为 None 且 position_ids 为 None,则在批量生成时动态创建 position_ids
        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # 如果传入 inputs_embeds,则仅在第一代步骤中使用它们
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # 更新 model_inputs 字典
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
            }
        )
        return model_inputs
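上面“按 past_length 截断 input_ids、再用 attention_mask.cumsum(-1) - 1 动态构造 position_ids”的逻辑,可以用一个与模型无关的小例子演示(张量取值均为假设):

import torch

# 假设缓存中已处理了 4 个 token,而本次又传入了完整的 6 个 token
past_length = 4
input_ids = torch.tensor([[11, 12, 13, 14, 15, 16]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])

# 情况 2:past_length < input_ids.shape[1],只保留尚未处理的部分
if past_length < input_ids.shape[1]:
    input_ids = input_ids[:, past_length:]
print(input_ids)  # tensor([[15, 16]])

# 由 attention_mask 构造 position_ids,填充位置置 1
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# 存在缓存时,只取与当前 input_ids 对齐的最后几列
position_ids = position_ids[:, -input_ids.shape[1]:]
print(position_ids)  # tensor([[4, 5]])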

    # 重排序缓存的内部方法委托给语言模型的 _reorder_cache 方法
    def _reorder_cache(self, *args, **kwargs):
        return self.language_model._reorder_cache(*args, **kwargs)

.\models\vipllava\__init__.py

# 版权声明和许可信息,指明代码版权和许可协议
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 引入类型检查模块
from typing import TYPE_CHECKING

# 导入可选依赖未可用的异常
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# 模块导入结构定义
_import_structure = {"configuration_vipllava": ["VIPLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", "VipLlavaConfig"]}

# 检查是否可用 torch
try:
    if not is_torch_available():
        # 若不可用,则抛出可选依赖不可用异常
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # 处理可选依赖不可用的情况
    pass
else:
    # 若可用 torch,则添加以下模块到导入结构中
    _import_structure["modeling_vipllava"] = [
        "VIPLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
        "VipLlavaForConditionalGeneration",
        "VipLlavaPreTrainedModel",
    ]

# 如果进行类型检查
if TYPE_CHECKING:
    # 从 configuration_vipllava 模块导入特定符号
    from .configuration_vipllava import VIPLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, VipLlavaConfig

    # 再次检查是否可用 torch
    try:
        if not is_torch_available():
            # 若不可用,则抛出可选依赖不可用异常
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 处理可选依赖不可用的情况
        pass
    else:
        # 若可用 torch,则从 modeling_vipllava 模块导入特定符号
        from .modeling_vipllava import (
            VIPLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
            VipLlavaForConditionalGeneration,
            VipLlavaPreTrainedModel,
        )

# 如果不进行类型检查
else:
    # 导入 sys 模块
    import sys

    # 动态设置当前模块为 LazyModule,延迟加载模块
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
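在 PyTorch 可用的情况下,上述导入结构最终对外暴露 `VipLlavaForConditionalGeneration` 等类。下面是一个推理用法的示意(检查点名取自上文的存档列表;提示词模板做了简化,实际格式以模型卡为准):

import requests
from PIL import Image

from transformers import AutoProcessor, VipLlavaForConditionalGeneration

model_id = "llava-hf/vip-llava-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = VipLlavaForConditionalGeneration.from_pretrained(model_id)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "###Human: <image>\nWhat is shown in this image?###Assistant:"  # 简化后的提示词,仅作示意

inputs = processor(text=prompt, images=image, return_tensors="pt")
generate_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])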

.\models\vision_encoder_decoder\configuration_vision_encoder_decoder.py

# 引入需要的模块和类
from typing import TYPE_CHECKING, Any, Mapping, Optional, OrderedDict
# 引入版本控制的模块
from packaging import version
# 引入预训练配置基类(VisionEncoderDecoderConfig 继承自它)
from ...configuration_utils import PretrainedConfig
# 引入 ONNX 导出配置基类(文件末尾的各 OnnxConfig 子类继承自它)
from ...onnx import OnnxConfig
# 引入日志记录工具
from ...utils import logging
# 引入自动配置模块
from ..auto.configuration_auto import AutoConfig

# 如果是类型检查阶段,则导入必要的类型
if TYPE_CHECKING:
    from ... import PreTrainedTokenizerBase, TensorType

# 获取全局日志记录器
logger = logging.get_logger(__name__)


class VisionEncoderDecoderConfig(PretrainedConfig):
    r"""
    [`VisionEncoderDecoderConfig`] 是配置类,用于存储 [`VisionEncoderDecoderModel`] 的配置信息。
    用于根据指定的参数实例化一个 Vision-Encoder-Text-Decoder 模型,定义编码器和解码器的配置。

    配置对象继承自 [`PretrainedConfig`],用于控制模型的输出。查阅 [`PretrainedConfig`] 的文档获取更多信息。

    Args:
        kwargs (*optional*):
            关键字参数的字典。特别是:

                - **encoder** ([`PretrainedConfig`], *optional*) -- 定义编码器配置的配置对象实例。
                - **decoder** ([`PretrainedConfig`], *optional*) -- 定义解码器配置的配置对象实例。

    Examples:

    ```
    >>> from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

    >>> # 初始化 ViT 和 BERT 风格的配置
    >>> config_encoder = ViTConfig()
    >>> config_decoder = BertConfig()

    >>> config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)

    >>> # 初始化一个 ViTBert 模型(具有随机权重),从 ViT 和 google-bert/bert-base-uncased 风格的配置开始
    >>> model = VisionEncoderDecoderModel(config=config)

    >>> # 访问模型配置
    >>> config_encoder = model.config.encoder
    >>> config_decoder = model.config.decoder
    >>> # 将解码器配置设置为 causal lm
    >>> config_decoder.is_decoder = True
    >>> config_decoder.add_cross_attention = True

    >>> # 保存模型,包括其配置
    >>> model.save_pretrained("my-model")

    >>> # 从预训练文件夹加载模型和配置
    >>> encoder_decoder_config = VisionEncoderDecoderConfig.from_pretrained("my-model")
    >>> model = VisionEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
    ```"""



    # 定义模型类型为视觉编码-解码器
    model_type = "vision-encoder-decoder"
    # 标记该类为组合类
    is_composition = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # 检查是否传入了编码器和解码器的配置,否则抛出异常
        if "encoder" not in kwargs or "decoder" not in kwargs:
            raise ValueError(
                f"A configuraton of type {self.model_type} cannot be instantiated because "
                f"not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
            )

        # 弹出并获取编码器配置和模型类型
        encoder_config = kwargs.pop("encoder")
        encoder_model_type = encoder_config.pop("model_type")
        # 弹出并获取解码器配置和模型类型
        decoder_config = kwargs.pop("decoder")
        decoder_model_type = decoder_config.pop("model_type")

        # 根据编码器配置创建自动配置对象
        self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
        # 根据解码器配置创建自动配置对象
        self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
        # 标记该模型为编码-解码器结构
        self.is_encoder_decoder = True

    @classmethod
    def from_encoder_decoder_configs(
        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
        """
        从预训练的编码器模型配置和解码器模型配置实例化一个 `VisionEncoderDecoderConfig`(或其派生类)。

        返回:
            [`VisionEncoderDecoderConfig`]: 配置对象的一个实例
        """
        # 记录日志信息,设置解码器配置为True和添加交叉注意力机制为True
        logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

        # 返回使用编码器和解码器配置实例化的类的实例
        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
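`from_encoder_decoder_configs` 的效果可以用下面的小例子直观验证(沿用文档字符串中的 ViT + BERT 组合):

from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig

config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(ViTConfig(), BertConfig())

# 解码器配置被自动标记为 decoder,并开启了交叉注意力
print(config.is_encoder_decoder)           # True
print(config.decoder.is_decoder)           # True
print(config.decoder.add_cross_attention)  # True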
class VisionEncoderDecoderEncoderOnnxConfig(OnnxConfig):
    # 定义 Torch ONNX 的最低版本要求为 1.11
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # 返回输入规范化的顺序字典,定义了各个输入的维度信息
        return OrderedDict(
            [
                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
            ]
        )

    @property
    def atol_for_validation(self) -> float:
        # 返回用于验证的绝对误差容限
        return 1e-4

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        # 返回输出规范化的顺序字典,定义了各个输出的维度信息
        return OrderedDict({"last_hidden_state": {0: "batch", 1: "encoder_sequence"}})


class VisionEncoderDecoderDecoderOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # 返回输入规范化的顺序字典,定义了各个公共输入的维度信息
        common_inputs = OrderedDict()
        common_inputs["input_ids"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
        common_inputs["attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
        common_inputs["encoder_hidden_states"] = {0: "batch", 1: "encoder_sequence"}

        return common_inputs

    def generate_dummy_inputs(
        self,
        tokenizer: "PreTrainedTokenizerBase",
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
        import torch

        common_inputs = OrderedDict()

        # 调用父类方法生成虚拟输入
        dummy_input = super().generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        # 提取 input_ids 的 batch 和 encoder_sequence 的值
        batch, encoder_sequence = dummy_input["input_ids"].shape
        # 创建 encoder_hidden_states 的形状 (batch, encoder_sequence, encoder_hidden_size) 的零张量
        encoder_hidden_states_shape = (batch, encoder_sequence, self._config.encoder_hidden_size)
        common_inputs["input_ids"] = dummy_input.pop("input_ids")
        common_inputs["attention_mask"] = dummy_input.pop("attention_mask")
        common_inputs["encoder_hidden_states"] = torch.zeros(encoder_hidden_states_shape)

        return common_inputs


class VisionEncoderDecoderOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> None:
        # 空实现,表示没有特定的输入定义
        pass

    def get_encoder_config(self, encoder_config: PretrainedConfig) -> OnnxConfig:
        r"""
        返回用于 `VisionEncoderDecoder` 模型的 ONNX 编码器配置。

        Args:
            encoder_config (`PretrainedConfig`):
                导出到 ONNX 时使用的编码器模型配置。

        Returns:
            [`VisionEncoderDecoderEncoderOnnxConfig`]: ONNX 配置对象的实例
        """
        return VisionEncoderDecoderEncoderOnnxConfig(encoder_config)

    def get_decoder_config(
        self, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, feature: str = "default"
        # 返回用于 `VisionEncoderDecoder` 模型的 ONNX 解码器配置
    ) -> OnnxConfig:
        r"""
        Returns ONNX decoder config for `VisionEncoderDecoder` model.

        Args:
            encoder_config (`PretrainedConfig`):
                The encoder model's configuration to use when exporting to ONNX.
            decoder_config (`PretrainedConfig`):
                The decoder model's configuration to use when exporting to ONNX
            feature (`str`, *optional*):
                The type of feature to export the model with.

        Returns:
            [`VisionEncoderDecoderDecoderOnnxConfig`]: An instance of the ONNX configuration object.
        """
        # 设置解码器配置的隐藏状态大小为编码器配置的隐藏状态大小
        decoder_config.encoder_hidden_size = encoder_config.hidden_size
        # 返回一个包含解码器配置和特征的 ONNX 配置对象实例
        return VisionEncoderDecoderDecoderOnnxConfig(decoder_config, feature)
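导出 ONNX 时,编码器与解码器分别使用上面的两个配置类。下面是一个取得这两个 ONNX 配置并查看其输入规范的小示意(ViT + BERT 的组合方式为假设,仅作演示):

from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig
from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import (
    VisionEncoderDecoderOnnxConfig,
)

config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(ViTConfig(), BertConfig())

onnx_config = VisionEncoderDecoderOnnxConfig(config)
encoder_onnx = onnx_config.get_encoder_config(config.encoder)
decoder_onnx = onnx_config.get_decoder_config(config.encoder, config.decoder, feature="default")

print(encoder_onnx.inputs)  # pixel_values 及其动态维度
print(decoder_onnx.inputs)  # input_ids / attention_mask / encoder_hidden_states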

.\models\vision_encoder_decoder\modeling_flax_vision_encoder_decoder.py

# 设定文件编码为 UTF-8
# 版权声明和许可证明,使用 Apache License, Version 2.0
#
# 导入必要的库和模块
import os
from typing import Optional, Tuple, Union

import flax.linen as nn  # 导入 Flax 的神经网络模块
import jax  # 导入 JAX,用于自动微分和加速计算
import jax.numpy as jnp  # 导入 JAX 的 NumPy 接口
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # 导入 Flax 的冻结字典相关功能
from flax.traverse_util import flatten_dict, unflatten_dict  # 导入 Flax 的字典扁平化和还原功能
from jax import lax  # 导入 JAX 的低级别 API
from jax.random import PRNGKey  # 导入 JAX 的随机数生成器 PRNGKey

# 导入模型输出相关的类
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
# 导入 FlaxPreTrainedModel 类,用于定义 Flax 模型的基类
from ...modeling_flax_utils import FlaxPreTrainedModel
# 导入文档字符串相关的工具函数和日志记录
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
# 导入自动配置类 AutoConfig
from ..auto.configuration_auto import AutoConfig
# 导入自动化模型类,包括通用模型和用于因果语言建模的模型
from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
# 导入视觉编码解码配置类 VisionEncoderDecoderConfig
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig

# 获取日志记录器
logger = logging.get_logger(__name__)

# 文档中使用的配置类名称
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"

VISION_ENCODER_DECODER_START_DOCSTRING = r"""
    This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model
    as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via
    [`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`]
    function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
    generative task, like image captioning.

    The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
    tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
    Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
    Zhou, Wei Li, Peter J. Liu.

    Additionally, in [TrOCR: Transformer-based Optical Character Recognition with Pre-trained
    Models](https://arxiv.org/abs/2109.10282) it is shown how leveraging large pretrained vision models for optical
    character recognition (OCR) yields a significant performance improvement.

    After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any
    other models (see the examples for more information).
"""
    # 这个模型继承自 `FlaxPreTrainedModel`。查看超类文档以了解库实现的通用方法,如下载或保存模型、调整输入嵌入的大小、修剪头部等。

    # 这个模型还是一个 Flax Linen [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) 子类。
    # 将其用作常规的 Flax Module,并参考 Flax 文档以获取与一般使用和行为相关的所有信息。

    # Parameters:
    #     config ([`VisionEncoderDecoderConfig`]): 模型配置类,包含模型的所有参数。
    #         使用配置文件初始化不会加载与模型关联的权重,只加载配置。查看 [`~FlaxPreTrainedModel.from_pretrained`] 方法以加载模型权重。
    #     dtype (`jax.numpy.dtype`, *optional*, 默认为 `jax.numpy.float32`):
    #         计算的数据类型。可以是 `jax.numpy.float32`、`jax.numpy.float16`(在 GPU 上)和 `jax.numpy.bfloat16`(在 TPU 上)之一。
    #
    #         这可用于在 GPU 或 TPU 上启用混合精度训练或半精度推断。如果指定,则所有计算将使用给定的 `dtype` 执行。
    #
    #         **注意,这只指定计算的数据类型,不影响模型参数的数据类型。**
    #
    #         如果您希望更改模型参数的数据类型,请参阅 [`~FlaxPreTrainedModel.to_fp16`] 和 [`~FlaxPreTrainedModel.to_bf16`]。
"""

VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using the vision model's image processor. For example, using
            [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] for details.
        decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)
        decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        decoder_position_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
            range `[0, config.decoder.max_position_embeddings - 1]`.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
"""

VISION_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using the vision model's image processor. For example, using
            [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
"""

VISION_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
    Args:
        decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
            Indices of decoder input sequence tokens in the vocabulary. These tokens are generated by the model during
            decoding based on the provided `decoder_start_token_id`.
        decoder_start_token_id (`int`):
            The id of the token to start decoding with. This is usually the beginning-of-sequence token.
        encoder_outputs (`Union[FlaxBaseModelOutput, Tuple[jnp.ndarray]]`):
            Tuple comprising various elements depending on the configuration and inputs: logits as a jnp.ndarray of
            shape `(batch_size, sequence_length, vocab_size)`, hidden_states as a tuple of length `num_layers` with
            each element being a jnp.ndarray of shape `(batch_size, sequence_length, hidden_size)`, attentions as a
            tuple of length `num_layers` with each element being a jnp.ndarray of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`, and others.
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            1's in positions corresponding to input tokens to ignore and 0's in positions corresponding to input tokens
            to attend to. It's used to mask pad tokens in input sentences. It's also used to indicate the position of
            input tokens.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
"""
    # 定义函数参数和其类型,以下是函数的解释说明
    Args:
        decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            解码器输入序列标记的索引,大小为`(batch_size, target_sequence_length)`,可选。
            可以使用[`PreTrainedTokenizer`]获取这些索引。详见[`PreTrainedTokenizer.encode`]和[`PreTrainedTokenizer.__call__`]。
            [什么是解码器输入 ID?](../glossary#decoder-input-ids)
            如果使用了 `past_key_values`,则可选地只需输入最后的 `decoder_input_ids`(参见 `past_key_values`)。
            对于序列到序列的训练,应提供 `decoder_input_ids`。如果没有提供 `decoder_input_ids`,模型将通过将 `input_ids` 向右移动创建此张量,用于去噪预训练。
    encoder_outputs (`tuple(tuple(jnp.ndarray)`):
        元组由 (`last_hidden_state`, *可选*: `hidden_states`, *可选*: `attentions`) 组成。
        `last_hidden_state` 大小为 `(batch_size, sequence_length, hidden_size)`,*可选*,是编码器最后一层的隐藏状态输出序列。用于解码器的交叉注意力。
    decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
        默认行为:生成一个张量,忽略 `decoder_input_ids` 中的填充标记。因果蒙版也将默认使用。
    decoder_position_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
        每个解码器输入序列标记在位置嵌入中的位置索引。选择范围为 `[0, config.decoder.max_position_embeddings - 1]`。
    past_key_values (`Dict[str, jnp.ndarray]`, *optional*, 由 `init_cache` 返回或在传递先前的 `past_key_values` 时返回):
        预计算的隐藏状态字典(键和值在注意力块中)。可用于快速自回归解码。预计算的键和值隐藏状态的形状为 *[batch_size, max_length]*。
    output_attentions (`bool`, *optional*):
        是否返回所有注意力层的注意力张量。有关更多详细信息,请参见返回的张量下的 `attentions`。
    output_hidden_states (`bool`, *optional*):
        是否返回所有层的隐藏状态。有关更多详细信息,请参见返回的张量下的 `hidden_states`。
    return_dict (`bool`, *optional*):
        如果设置为 `True`,模型将返回 [`~utils.FlaxCausalLMOutputWithCrossAttentions`] 而不是普通元组。
"""
# 定义一个 Flax 模型类,用于视觉编码器解码器
class FlaxVisionEncoderDecoderModule(nn.Module):
    # 模型配置信息,包括编码器和解码器配置
    config: VisionEncoderDecoderConfig
    # 默认数据类型为 JAX 的 float32 类型
    dtype: jnp.dtype = jnp.float32

    # 模型设置函数,用于初始化模型
    def setup(self):
        # 获取编码器和解码器的配置信息
        encoder_config = self.config.encoder
        decoder_config = self.config.decoder

        # 从 `modeling_hybrid_clip.py` 中复制代码,并进行修改
        # 导入模型映射表,用于根据配置选择相应的编码器和解码器模块
        from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING

        # 根据编码器配置选择对应的编码器模块类
        encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class
        # 根据解码器配置选择对应的解码器模块类
        decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class

        # 使用选定的编码器模块和解码器模块初始化实例
        self.encoder = encoder_module(encoder_config, dtype=self.dtype)
        self.decoder = decoder_module(decoder_config, dtype=self.dtype)

        # 如果编码器的隐藏状态大小与解码器不同,并且解码器的交叉注意力隐藏大小为 None
        # 则需要进行编码器到解码器投影
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            # 定义一个全连接层,用于编码器到解码器的投影
            self.enc_to_dec_proj = nn.Dense(
                self.decoder.config.hidden_size,
                kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
                dtype=self.dtype,
            )
        else:
            # 否则不需要投影,设为 None
            self.enc_to_dec_proj = None

    # 获取编码器模块的方法
    def _get_encoder_module(self):
        return self.encoder

    # 获取投影层模块的方法
    def _get_projection_module(self):
        return self.enc_to_dec_proj

    # 获取解码器模块的方法
    def _get_decoder_module(self):
        return self.decoder

    # 模型调用函数,用于执行模型推理或训练
    def __call__(
        self,
        pixel_values,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
        ):
            # 调用编码器(Encoder)模型,传入像素值、是否输出注意力权重、隐藏状态等参数,返回编码器的输出
            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                deterministic=deterministic,
            )

            # 获取编码器的隐藏状态
            encoder_hidden_states = encoder_outputs[0]

            # 如果存在编码器到解码器的投影层,则将编码器隐藏状态投影到解码器空间
            if self.enc_to_dec_proj is not None:
                encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

            # 显式设置编码器的注意力掩码,全为1的矩阵
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

            # 调用解码器(Decoder)模型,传入解码器输入的token IDs、注意力掩码、位置 IDs、编码器隐藏状态及注意力掩码等参数,返回解码器的输出
            decoder_outputs = self.decoder(
                input_ids=decoder_input_ids,
                attention_mask=decoder_attention_mask,
                position_ids=decoder_position_ids,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                deterministic=deterministic,
            )

            # 如果不返回字典形式的输出,则将解码器输出和编码器输出拼接起来返回
            if not return_dict:
                return decoder_outputs + encoder_outputs

            # 返回经过FlaxSeq2SeqLMOutput包装后的解码器输出和编码器输出
            return FlaxSeq2SeqLMOutput(
                logits=decoder_outputs.logits,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )
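当编码器与解码器的隐藏维度不一致且未设置 `cross_attention_hidden_size` 时,`enc_to_dec_proj` 负责把编码器输出投影到解码器维度。下面用一个独立的 Flax 小例子示意这一投影(维度数值为假设):

import jax
import jax.numpy as jnp
import flax.linen as nn


class EncToDecProj(nn.Module):
    decoder_hidden_size: int = 512

    @nn.compact
    def __call__(self, encoder_hidden_states):
        # 与上文一致:用一个 Dense 层把编码器维度映射到解码器维度
        return nn.Dense(
            self.decoder_hidden_size,
            kernel_init=jax.nn.initializers.normal(0.02),
        )(encoder_hidden_states)


# 假设编码器输出维度为 768,解码器需要 512
encoder_hidden_states = jnp.ones((1, 197, 768))
proj = EncToDecProj()
params = proj.init(jax.random.PRNGKey(0), encoder_hidden_states)
projected = proj.apply(params, encoder_hidden_states)
print(projected.shape)  # (1, 197, 512)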
# 使用装饰器将以下类添加文档字符串,该文档字符串与VISION_ENCODER_DECODER_START_DOCSTRING相符
@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
# 定义FlaxVisionEncoderDecoderModel类,继承自FlaxPreTrainedModel
class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
    # 类的文档字符串,描述了FlaxVisionEncoderDecoderModel的通用模型类特性
    r"""
    [`FlaxVisionEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
    with the module (flax.nn.Module) of one of the base vision model classes of the library as encoder module and
    another one as decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method
    for the encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder.
    """

    # 类属性,指定配置类为VisionEncoderDecoderConfig
    config_class = VisionEncoderDecoderConfig
    # 类属性,指定基础模型的前缀为"vision_encoder_decoder"
    base_model_prefix = "vision_encoder_decoder"
    # 类属性,主输入的名称为"pixel_values"
    main_input_name = "pixel_values"
    # 类属性,模块类为FlaxVisionEncoderDecoderModule
    module_class = FlaxVisionEncoderDecoderModule

    # 初始化方法
    def __init__(
        self,
        config: VisionEncoderDecoderConfig,
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # 如果_do_init为False,则引发值错误,要求初始化为True
        if not _do_init:
            raise ValueError(
                "`FlaxVisionEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
            )

        # 如果未提供input_shape,则根据config.encoder的num_channels和image_size设置默认输入形状
        if input_shape is None:
            num_channels = getattr(config.encoder, "num_channels", 3)
            input_shape = (
                (1, config.encoder.image_size, config.encoder.image_size, num_channels),
                (1, 1),
            )

        # 如果decoder的cross_attention_hidden_size不为None,则验证其与encoder的hidden_size是否相等
        if config.decoder.cross_attention_hidden_size is not None:
            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
                raise ValueError(
                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
                    f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
                    f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
                    " `config.encoder.hidden_size`."
                )

        # 使用给定的配置和其他参数实例化模块类,得到module对象
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # 调用父类(FlaxPreTrainedModel)的初始化方法,传递配置、模块对象、输入形状、种子、数据类型和是否初始化标志
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
    # 初始化权重方法,使用给定的随机数生成器 rng,输入形状 input_shape 和可选的参数字典 params,返回冻结的参数字典
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 解包输入形状为编码器和解码器的输入形状
        encoder_input_shape, decoder_input_shape = input_shape

        # 初始化输入张量
        pixel_values = jnp.zeros(encoder_input_shape, dtype=self.dtype)  # 初始化像素值张量
        decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")  # 初始化解码器输入的 ID 张量
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)  # 初始化解码器注意力掩码张量

        # 检查批处理大小是否一致
        batch_size, _, _, _ = pixel_values.shape
        decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
        if not decoder_batch_size == batch_size:
            raise ValueError(
                f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder "
                f"and {decoder_batch_size} for decoder."
            )

        # 创建解码器位置 ID 张量,广播到与解码器批处理大小和序列长度相匹配
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
        )

        # 拆分随机数生成器 rng,以用于参数和 dropout
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # 使用模块的初始化方法初始化随机参数
        random_params = self.module.init(
            rngs,
            pixel_values,
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
        )["params"]

        # 如果提供了参数字典 params,则用随机初始化的参数覆盖缺失的参数键
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            # 冻结并返回参数字典
            return freeze(unflatten_dict(params))
        else:
            # 否则,直接返回随机初始化的参数字典
            return random_params
    @add_start_docstrings(VISION_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def encode(
        self,
        pixel_values: jnp.ndarray,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Args:
            pixel_values (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
                Pixel values of the input images.
            output_attentions (`bool`, *optional*):
                Whether to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether to return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
            train (`bool`, *optional*, defaults to `False`):
                Whether the model is in training mode (enables dropout).
            params (`dict`, *optional*):
                Parameters to use for the forward pass instead of `self.params`.
            dropout_rng (`PRNGKey`, *optional*):
                Random number generator key for dropout.
        Returns:

        Example:

        ```
        >>> from transformers import AutoImageProcessor, FlaxVisionEncoderDecoderModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

        >>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "google/vit-base-patch16-224-in21k", "openai-community/gpt2"
        ... )

        >>> pixel_values = image_processor(images=image, return_tensors="np").pixel_values
        >>> encoder_outputs = model.encode(pixel_values)
        ```
        """
        # Determine whether to output attentions, hidden states, and return as a dictionary based on inputs or default model config
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Transpose pixel_values from channel last format to channel first format as expected by FlaxViTModel
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle random number generator states for dropout if specified
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # Define a function to perform the forward pass through the encoder module
        def _encoder_forward(module, pixel_values, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(pixel_values, **kwargs)

        # Apply the model's forward pass method with specified parameters and options
        outputs = self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=self.dtype),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

        # If return_dict is True, wrap outputs in FlaxBaseModelOutput format
        if return_dict:
            outputs = FlaxBaseModelOutput(
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Return the processed outputs
        return outputs

    @add_start_docstrings(VISION_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # 在此方法中解码模型的输出
        # 参数说明:
        # - decoder_input_ids: 解码器输入的标识符
        # - encoder_outputs: 编码器的输出
        # - decoder_attention_mask: 解码器的注意力掩码,可选
        # - decoder_position_ids: 解码器的位置标识符,可选
        # - past_key_values: 过去的键值对,用于缓存解码器状态,可选
        # - output_attentions: 是否输出注意力权重,可选
        # - output_hidden_states: 是否输出隐藏状态,可选
        # - return_dict: 是否返回字典形式的输出,可选
        # - train: 是否为训练模式
        # - params: 额外的参数字典,可选
        # - dropout_rng: 随机数生成器密钥,用于dropout操作,可选
        pass  # 实际的解码逻辑需要根据具体模型来实现

    @add_start_docstrings_to_model_forward(VISION_ENCODER_DECODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def __call__(
        self,
        pixel_values: jnp.ndarray,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # 调用模型,实现前向传播
        # 参数说明同上述的decode方法
        pass  # 实际的前向传播逻辑需要根据具体模型来实现

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        # 准备用于生成的输入数据
        # 参数说明:
        # - decoder_input_ids: 解码器的输入标识符
        # - max_length: 生成的最大长度
        # - decoder_attention_mask: 解码器的注意力掩码,可选
        # - encoder_outputs: 编码器的输出,可选
        # - **kwargs: 其他关键字参数
        # 初始化缓存
        batch_size, seq_length = decoder_input_ids.shape
        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)

        # 创建扩展的注意力掩码,用于生成
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": decoder_position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # 更新生成过程中的输入
        # 参数说明:
        # - model_outputs: 模型的输出,包含过去的键值对
        # - model_kwargs: 模型调用的关键字参数
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs

    @classmethod
    # 定义一个类方法,用于从预训练的编码器和解码器模型中加载模型
    def from_encoder_decoder_pretrained(
        cls,
        # 可选参数:编码器预训练模型的名称或路径,可以是字符串或操作系统路径
        encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        # 可选参数:解码器预训练模型的名称或路径,可以是字符串或操作系统路径
        decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        # *model_args: 可变位置参数列表,用于接收额外的模型参数
        *model_args,
        # **kwargs: 可变关键字参数,用于接收额外的关键字参数
        **kwargs,

.\models\vision_encoder_decoder\modeling_tf_vision_encoder_decoder.py

# File is UTF-8 encoded; the copyright notice and Apache-2.0 license header are omitted here.
# Imports needed by the code shown below (regular expressions, deprecation warnings, typing helpers, NumPy, TensorFlow)
import re
import warnings
from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

# 导入配置工具相关模块和类
from ...configuration_utils import PretrainedConfig
# 导入 TF 模型输出相关模块和类
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
# 导入 TF 实用工具相关模块和类
from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, get_initializer, keras, unpack_inputs
# 导入 TensorFlow 实用工具,包括形状处理函数
from ...tf_utils import shape_list
# 导入通用工具模块,包括模型输出、文档字符串处理、日志记录等功能
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
# 导入自动配置相关类
from ..auto.configuration_auto import AutoConfig
# 导入 TensorFlow 自动化模型相关类
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
# 导入视觉编码器-解码器配置类
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig

# 获取日志记录器对象
logger = logging.get_logger(__name__)

# 文档字符串中用到的配置名称
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"

# 弃用警告信息,提醒版本更新带来的变更
DEPRECATION_WARNING = (
    "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the"
    " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if"
    " fine-tuning a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the"
    " labels, no need to pass them yourself anymore."
)

# 视觉编码器-解码器类的起始文档字符串,详细说明其功能和用法
VISION_ENCODER_DECODER_START_DOCSTRING = r"""
    This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model
    as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via
    [`~TFAutoModel.from_pretrained`] function and the decoder is loaded via [`~TFAutoModelForCausalLM.from_pretrained`]
    function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
    generative task, like image captioning.

    The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
    tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
    Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
    Zhou, Wei Li, Peter J. Liu.

    Additionally, in [TrOCR: Transformer-based Optical Character Recognition with Pre-trained
    Models](https://arxiv.org/abs/2109.10282) it is shown how leveraging large pretrained vision models for optical
    character recognition (OCR) yields a significant performance improvement.

    After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any
    other models (see the examples for more information).

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a
    regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`VisionEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
"""


# Copied from transformers.models.encoder_decoder.modeling_tf_encoder_decoder.shift_tokens_right
# 将输入的 token 向右移动一位,用于生成 decoder 的输入序列
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
    # 检查 pad_token_id 是否为 None,如果是则抛出数值错误异常
    if pad_token_id is None:
        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
    pad_token_id = tf.cast(pad_token_id, input_ids.dtype)

    # 检查 decoder_start_token_id 是否为 None,如果是则抛出数值错误异常
    if decoder_start_token_id is None:
        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
    decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)

    # 创建一个形状为 (batch_size, 1) 的张量,用 decoder_start_token_id 填充
    start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
    # 将 start_tokens 和 input_ids 的前 n-1 列拼接起来,形成向右移动后的输入序列
    shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
    # 将 labels 中可能存在的 -100 值替换为 pad_token_id
    shifted_input_ids = tf.where(
        shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
    )

    # 确保 shifted_input_ids 中的值大于等于 0,并添加调试信息
    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))

    # 确保断言操作被调用,通过将结果包装在 identity 操作中
    with tf.control_dependencies([assert_gte0]):
        shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids
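# Illustrative sketch (not part of the original file): how shift_tokens_right turns labels into decoder inputs.
# It prepends `decoder_start_token_id`, drops the last label, and replaces any remaining -100
# (the loss-ignore index) with `pad_token_id`.
#
#     >>> import tensorflow as tf
#     >>> labels = tf.constant([[5, 6, 7, -100]], dtype=tf.int32)
#     >>> shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2)
#     <tf.Tensor: ... numpy=array([[2, 5, 6, 7]], dtype=int32)>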


@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
# TFVisionEncoderDecoderModel 是一个通用的模型类,用于将库中的一个基本视觉模型类作为编码器,另一个基本模型类作为解码器
class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
    r"""
    [`TFVisionEncoderDecoderModel`] 是一个通用模型类,当使用 [`~TFAutoModel.from_pretrained`] 类方法为编码器创建一个基本视觉模型类,
    并使用 [`~TFAutoModelForCausalLM.from_pretrained`] 类方法为解码器创建另一个基本模型类时,它将被实例化为一个转换器架构。
    """

    config_class = VisionEncoderDecoderConfig  # 配置类为 VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"  # 基础模型前缀为 "vision_encoder_decoder"
    load_weight_prefix = "tf_vision_encoder_decoder_model"  # 加载权重前缀为 "tf_vision_encoder_decoder_model"
    main_input_name = "pixel_values"  # 主输入名称为 "pixel_values"

    # 初始化函数,接受配置、编码器和解码器作为参数
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[TFPreTrainedModel] = None,
        decoder: Optional[TFPreTrainedModel] = None,
        ):
            # 检查配置是否为 None,并且编码器和解码器必须同时提供,否则抛出数值错误异常
            if config is None and (encoder is None or decoder is None):
                raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
            # 如果配置为 None,则从提供的编码器和解码器配置创建 VisionEncoderDecoderConfig 对象
            if config is None:
                config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
            else:
                # 如果提供的配置不是 self.config_class 类型,则抛出数值错误异常
                if not isinstance(config, self.config_class):
                    raise ValueError(f"config: {config} has to be of type {self.config_class}")

            # 如果解码器配置中的交叉注意力隐藏大小不为 None
            if config.decoder.cross_attention_hidden_size is not None:
                # 检查解码器的交叉注意力隐藏大小是否等于编码器的隐藏大小,否则抛出数值错误异常
                if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
                    raise ValueError(
                        "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
                        f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
                        f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
                        " `config.encoder.hidden_size`."
                    )

            # 使用给定的配置初始化父类
            super().__init__(config)

            # 如果编码器为 None,则从配置创建 TFAutoModel 对象,并命名为 "encoder"
            if encoder is None:
                encoder = TFAutoModel.from_config(config.encoder, name="encoder")

            # 如果解码器为 None,则从配置创建 TFAutoModelForCausalLM 对象,并命名为 "decoder"
            if decoder is None:
                decoder = TFAutoModelForCausalLM.from_config(config.decoder, name="decoder")

            # 将编码器和解码器设置为类的属性
            self.encoder = encoder
            self.decoder = decoder

            # 如果编码器的配置与类的配置不同,发出警告信息
            if self.encoder.config.to_dict() != self.config.encoder.to_dict():
                logger.warning(
                    f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
                    f" {self.config.encoder}"
                )
            # 如果解码器的配置与类的配置不同,发出警告信息
            if self.decoder.config.to_dict() != self.config.decoder.to_dict():
                logger.warning(
                    f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
                    f" {self.config.decoder}"
                )

            # 确保各模型的配置与共享配置保持同步
            self.encoder.config = self.config.encoder
            self.decoder.config = self.config.decoder

            # Raise an error if the encoder has an output-embedding (LM) head, which an encoder should not have
            if self.encoder.get_output_embeddings() is not None:
                raise ValueError(
                    f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
                )

    @property
    def input_signature(self):
        # 获取视觉编码器的配置
        vision_config = self.config.encoder
        # 检查是否存在额外的视觉配置,如果有则使用它
        if hasattr(vision_config, "vision_config"):
            vision_config = vision_config.vision_config
        # 检查视觉配置中是否定义了图像尺寸,如果没有则使用输入尺寸作为默认值
        if hasattr(vision_config, "image_size"):
            image_size = vision_config.image_size
        else:
            image_size = vision_config.input_size
        # 返回输入签名字典,包括像素值和解码器输入 ID 的 TensorSpec
        return {
            "pixel_values": tf.TensorSpec(
                shape=(
                    None,
                    vision_config.num_channels,
                    image_size,
                    image_size,
                ),
                dtype=tf.float32,
            ),
            "decoder_input_ids": tf.TensorSpec(shape=(None, None), dtype=tf.int32, name="decoder_input_ids"),
        }
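    # Illustrative sketch (assumption: a ViT encoder with `image_size=224` and `num_channels=3`; `model` is a
    # hypothetical TFVisionEncoderDecoderModel instance): the signature above then advertises NCHW float32 images
    # plus int32 decoder token ids.
    #
    #     >>> sig = model.input_signature
    #     >>> sig["pixel_values"].shape, sig["decoder_input_ids"].shape
    #     (TensorShape([None, 3, 224, 224]), TensorShape([None, None]))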

    def get_encoder(self):
        # 返回当前对象的编码器
        return self.encoder

    def get_decoder(self):
        # 返回当前对象的解码器
        return self.decoder

    def get_input_embeddings(self):
        # 返回编码器的输入嵌入
        return self.encoder.get_input_embeddings()

    def get_output_embeddings(self):
        # 返回解码器的输出嵌入
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        # 设置解码器的输出嵌入
        return self.decoder.set_output_embeddings(new_embeddings)

    def tf_to_pt_weight_rename(self, tf_weight):
        # 根据不同的情况,重命名 TensorFlow 到 PyTorch 的权重名称
        # 这是为了解决 TensorFlow 和 PyTorch 模型结构不完全对齐的问题
        encoder_model_type = self.config.encoder.model_type
        if "encoder" in tf_weight and "decoder" not in tf_weight:
            return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),)
        else:
            return (tf_weight,)

    @classmethod
    def from_encoder_decoder_pretrained(
        cls,
        encoder_pretrained_model_name_or_path: str = None,
        decoder_pretrained_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ):
        # 从预训练的编码器和解码器模型构建一个新的对象
        # 这是一个类方法,用于初始化对象
        pass  # Placeholder for method implementation
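    # Illustrative usage sketch (the checkpoints are examples, not mandated by this file): pairing a ViT encoder
    # with a GPT-2 decoder. The decoder's cross-attention weights are newly initialized and should be fine-tuned,
    # e.g. for image captioning.
    #
    #     >>> from transformers import TFVisionEncoderDecoderModel
    #     >>> model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    #     ...     "google/vit-base-patch16-224-in21k", "openai-community/gpt2"
    #     ... )
    #     >>> model.save_pretrained("vit-gpt2")
    #     >>> model = TFVisionEncoderDecoderModel.from_pretrained("vit-gpt2")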

    @unpack_inputs
    @add_start_docstrings_to_model_forward(
        VISION_ENCODER_DECODER_INPUTS_DOCSTRING.format("batch_size, sequence_length")
    )
    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: Optional[Union[np.ndarray, tf.Tensor]] = None,
        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
        **kwargs,
    ):
        # `call` is the TF forward method (Keras models have no `forward`). Its parameters:
        # - pixel_values: encoder image inputs; decoder_input_ids / decoder_attention_mask: decoder text inputs
        # - encoder_outputs / past_key_values / decoder_inputs_embeds: precomputed encoder states, cached key/values,
        #   and pre-embedded decoder inputs
        # - labels: targets for the language-modeling loss
        # - use_cache / output_attentions / output_hidden_states / return_dict / training: standard behaviour flags
        pass  # The actual forward logic is omitted in this walkthrough

    def serving_output(self, output):
        # 如果配置指定使用缓存,则提取输出中的 past_key_values 的第二个元素作为 pkv,否则设为 None
        pkv = tf.tuple(output.past_key_values)[1] if self.config.decoder.use_cache else None
        # 如果配置要求输出解码器隐藏状态,则转换输出中的 decoder_hidden_states 为 Tensor,否则设为 None
        dec_hs = (
            tf.convert_to_tensor(output.decoder_hidden_states) if self.config.decoder.output_hidden_states else None
        )
        # 如果配置要求输出解码器注意力权重,则转换输出中的 decoder_attentions 为 Tensor,否则设为 None
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.decoder.output_attentions else None
        # 如果配置要求输出编码器隐藏状态,则转换输出中的 encoder_hidden_states 为 Tensor,否则设为 None
        enc_hs = (
            tf.convert_to_tensor(output.encoder_hidden_states) if self.config.encoder.output_hidden_states else None
        )
        # 如果配置要求输出编码器注意力权重,则转换输出中的 encoder_attentions 为 Tensor,否则设为 None
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.encoder.output_attentions else None
        # 如果配置要求输出交叉注意力权重,并且输出中有 cross_attentions,则转换输出中的 cross_attentions 为 Tensor,否则设为 None
        cross_attns = (
            tf.convert_to_tensor(output.cross_attentions)
            if self.config.decoder.output_attentions and output.cross_attentions is not None
            else None
        )

        # 返回 TFSeq2SeqLMOutput 类的实例,包括输出的逻辑 logits、缓存的 past_key_values、解码器隐藏状态、解码器注意力权重、
        # 编码器最后的隐藏状态、编码器隐藏状态、编码器注意力权重和交叉注意力权重
        return TFSeq2SeqLMOutput(
            logits=output.logits,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
            cross_attentions=cross_attns,
        )

    # `prepare_inputs_for_generation` assembles the decoder-side inputs used at each generation step:
    # the current input_ids, cached past_key_values, attention masks, the use_cache flag and the encoder outputs.
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        # 准备解码器的输入,使用当前的输入 ID 和过去的键值对
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
        # 获取解码器的注意力掩码,如果存在的话
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        # 获取过去的键值对
        past_key_values = decoder_inputs.get("past_key_values")
        # 构建输入字典,包括像素值(传递以确保 Keras.layer.__call__ 正常工作)、注意力掩码、解码器的注意力掩码、解码器的输入 ID
        input_dict = {
            "pixel_values": None,  # 需要传递以确保 Keras.layer.__call__ 正常工作
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            # TODO (joao): 在生成重构完成后,应该不再需要 `TFBaseModelOutput` 包装器
            "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }
        # 返回构建好的输入字典
        return input_dict

    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
        # 根据标签准备解码器的输入 ID,右移标签以适应解码器的输入要求
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def resize_token_embeddings(self, *args, **kwargs):
        # 抛出未实现错误,因为不支持通过 TFVisionEncoderDecoderModel 直接调整嵌入层大小
        raise NotImplementedError(
            "Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported. "
            "Please use the respective methods of the wrapped objects (model.decoder.resize_token_embeddings(...))"
        )

    def build(self, input_shape=None):
        # 如果已经构建过,则直接返回
        if self.built:
            return
        self.built = True
        # 如果存在 enc_to_dec_proj 属性,则构建它的计算图
        if getattr(self, "enc_to_dec_proj", None) is not None:
            with tf.name_scope(self.enc_to_dec_proj.name):
                self.enc_to_dec_proj.build([None, None, self.encoder.config.hidden_size])
        # 如果存在 encoder 属性,则构建它的计算图
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        # 如果存在 decoder 属性,则构建它的计算图
        if getattr(self, "decoder", None) is not None:
            with tf.name_scope(self.decoder.name):
                self.decoder.build(None)

.\models\vision_encoder_decoder\modeling_vision_encoder_decoder.py

# 设置文件的编码格式为 UTF-8

# 版权声明,指明版权归 HuggingFace Inc. 团队所有
# 根据 Apache 许可证 2.0 版本,除非符合许可证要求,否则不得使用此文件
# 可在以下链接获取许可证副本:http://www.apache.org/licenses/LICENSE-2.0
# 除非适用法律要求或书面同意,否则软件按"原样"分发,无任何担保或条件
# 请参阅许可证了解具体权限和限制

""" 用于支持 Vision-Encoder-Text-Decoder 结构的类"""

# 引入模块
import gc  # Python 垃圾回收模块
import os  # 操作系统模块
import tempfile  # 临时文件和目录模块
from typing import Optional, Tuple, Union  # 引入类型提示

import torch  # 引入 PyTorch 模块
from torch import nn  # 引入 PyTorch 中的神经网络模块
from torch.nn import CrossEntropyLoss  # 引入交叉熵损失函数

# 引入 Transformers 库中的一些模块和函数
from ...configuration_utils import PretrainedConfig  # 引入预训练配置相关函数
from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput  # 引入基础模型输出和 Seq2SeqLM 输出
from ...modeling_utils import PreTrainedModel  # 引入预训练模型基类
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings  # 引入辅助函数和日志记录函数
from ..auto.configuration_auto import AutoConfig  # 引入自动配置函数
from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM  # 引入自动模型加载函数
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig  # 引入视觉编码解码器配置类


# 从 Transformers 库中的 encoder_decoder 模块中复制的函数
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    将输入的 token 向右移动一个位置。
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)  # 创建一个与输入形状相同的全零张量
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()  # 将输入向右移动一个位置
    if decoder_start_token_id is None:
        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
    shifted_input_ids[:, 0] = decoder_start_token_id  # 设置起始 token

    if pad_token_id is None:
        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
    # 将标签中可能存在的 -100 值替换为 `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
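# Illustrative sketch (not part of the original file): -100 positions in the labels (ignored by the loss)
# become pad tokens once shifted into decoder inputs.
#
#     >>> import torch
#     >>> labels = torch.tensor([[5, 6, -100, -100]])
#     >>> shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2)
#     tensor([[2, 5, 6, 0]])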


logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"

VISION_ENCODER_DECODER_START_DOCSTRING = r"""
    此类可用于初始化一个图像到文本序列模型,其中编码器是任何预训练的视觉自编码模型,解码器是任何预训练的文本自回归模型。
    编码器通过 [`~AutoModel.from_pretrained`] 函数加载,解码器通过 [`~AutoModelForCausalLM.from_pretrained`] 函数加载。
    交叉注意力层会自动添加到解码器中,并应在下游生成任务(如图像字幕)中进行微调。

    初始化序列到序列模型时使用预训练检查点的有效性
    tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
    Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
    Zhou, Wei Li, Peter J. Liu.



    Additionally, in [TrOCR: Transformer-based Optical Character Recognition with Pre-trained
    Models](https://arxiv.org/abs/2109.10282) it is shown how leveraging large pretrained vision models for optical
    character recognition (OCR) yields a significant performance improvement.



    After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any
    other models (see the examples for more information).



    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)



    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.



    Parameters:
        config ([`VisionEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
"""

@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
class VisionEncoderDecoderModel(PreTrainedModel):
    r"""
    [`VisionEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with
    one of the base vision model classes of the library as encoder and another one as decoder when created with the
    :meth*~transformers.AutoModel.from_pretrained* class method for the encoder and
    :meth*~transformers.AutoModelForCausalLM.from_pretrained* class method for the decoder.
    """

    # 设置配置类为 VisionEncoderDecoderConfig
    config_class = VisionEncoderDecoderConfig
    # 指定基础模型前缀为 "vision_encoder_decoder"
    base_model_prefix = "vision_encoder_decoder"
    # 主输入名称为 "pixel_values"
    main_input_name = "pixel_values"
    # 支持梯度检查点
    supports_gradient_checkpointing = True

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        # Raise an error if neither a config nor both an encoder and a decoder are provided
        if config is None and (encoder is None or decoder is None):
            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
        if config is None:
            # 如果未提供配置,则从编码器和解码器的配置中创建视觉编码器解码器配置
            config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        else:
            # 如果提供的配置不是预期的配置类类型,则抛出数值错误
            if not isinstance(config, self.config_class):
                raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        if config.decoder.cross_attention_hidden_size is not None:
            # 如果解码器配置中指定了交叉注意力的隐藏大小
            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
                # 则要求解码器的交叉注意力隐藏大小必须与编码器的隐藏大小相等,否则抛出数值错误
                raise ValueError(
                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
                    f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
                    f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
                    " `config.encoder.hidden_size`."
                )

        # 初始化配置,确保输入和输出嵌入不被绑定
        config.tie_word_embeddings = False
        # 调用父类初始化方法,传入配置
        super().__init__(config)

        if encoder is None:
            # 如果未提供编码器,则从配置中创建自动模型
            encoder = AutoModel.from_config(config.encoder)

        if decoder is None:
            # 如果未提供解码器,则从配置中创建自动因果语言模型
            decoder = AutoModelForCausalLM.from_config(config.decoder)

        # 将编码器和解码器存储在实例变量中
        self.encoder = encoder
        self.decoder = decoder

        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
            # 如果编码器的配置不等于共享配置,则记录警告信息
            logger.warning(
                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
                f" {self.config.encoder}"
            )
        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
            # 如果解码器的配置不等于共享配置,则记录警告信息
            logger.warning(
                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
                f" {self.config.decoder}"
            )

        # 确保各自模型的配置引用了共享的配置,以便配置的更新能够同步
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        # 如果编码器和解码器的隐藏大小不相等且解码器未指定交叉注意力隐藏大小,则需要对编码器输出进行投影以适配解码器
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)

        if self.encoder.get_output_embeddings() is not None:
            # 如果编码器具有语言模型头部,则抛出数值错误
            raise ValueError(
                f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
            )

    def get_encoder(self):
        # 返回存储的编码器模型
        return self.encoder

    def get_decoder(self):
        # 返回存储的解码器模型
        return self.decoder
    # 返回当前模型的解码器的输出嵌入层
    def get_output_embeddings(self):
        return self.decoder.get_output_embeddings()

    # 设置当前模型的解码器的输出嵌入层为新的嵌入层
    def set_output_embeddings(self, new_embeddings):
        return self.decoder.set_output_embeddings(new_embeddings)

    # 从预训练的编码器和解码器模型名或路径创建一个模型实例
    @classmethod
    def from_encoder_decoder_pretrained(
        cls,
        encoder_pretrained_model_name_or_path: str = None,
        decoder_pretrained_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ):
        pass

    # 前向传播函数,执行模型的正向运算
    @add_start_docstrings_to_model_forward(VISION_ENCODER_DECODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        pass

    # 根据标签准备解码器的输入标识,用于生成序列的输入
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    # 准备生成阶段的输入,整理输入数据用于模型生成
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        input_dict = {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }
        return input_dict

    # 调整标记嵌入的大小(目前未实现)
    def resize_token_embeddings(self, *args, **kwargs):
        raise NotImplementedError(
            "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
            " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
        )

    # 重新排序缓存中的过去键值对,用于Beam搜索中的解码器缓存重排
    def _reorder_cache(self, past_key_values, beam_idx):
        # 在这里执行解码器缓存的重新排序
        return self.decoder._reorder_cache(past_key_values, beam_idx)
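# Illustrative end-to-end sketch (the TrOCR checkpoint and the image file name are examples, not part of this file):
# a trained VisionEncoderDecoderModel performs OCR through `generate`.
#
#     >>> from PIL import Image
#     >>> from transformers import TrOCRProcessor, VisionEncoderDecoderModel
#     >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
#     >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
#     >>> image = Image.open("handwritten_line.png").convert("RGB")
#     >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
#     >>> generated_ids = model.generate(pixel_values)
#     >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]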

.\models\vision_encoder_decoder\__init__.py

# 版权声明和许可证信息,保留所有权利
#
# 根据 Apache 许可证版本 2.0 授权使用此文件;
# 除非符合许可证的规定,否则您不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则按“原样”分发本软件,
# 不附带任何形式的担保或条件,无论是明示的还是默示的。
# 有关具体语言的权限,请参阅许可证。
#

# 从 typing 模块导入 TYPE_CHECKING 类型提示
from typing import TYPE_CHECKING

# 从 utils 模块导入所需的类和函数
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_tf_available,
    is_torch_available,
)

# 定义模块导入结构
_import_structure = {
    "configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig", "VisionEncoderDecoderOnnxConfig"]
}

# 检查是否有 torch 可用,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,则将 VisionEncoderDecoderModel 导入到模块导入结构中
    _import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]

# 检查是否有 TensorFlow 可用,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,则将 TFVisionEncoderDecoderModel 导入到模块导入结构中
    _import_structure["modeling_tf_vision_encoder_decoder"] = ["TFVisionEncoderDecoderModel"]

# 检查是否有 Flax 可用,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,则将 FlaxVisionEncoderDecoderModel 导入到模块导入结构中
    _import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]

# 如果是 TYPE_CHECKING 模式
if TYPE_CHECKING:
    # 从 configuration_vision_encoder_decoder 模块中导入特定配置类
    from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig, VisionEncoderDecoderOnnxConfig

    # 检查是否有 torch 可用,如果不可用则抛出 OptionalDependencyNotAvailable 异常
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 如果可用,则从 modeling_vision_encoder_decoder 模块中导入 VisionEncoderDecoderModel 类
        from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel

    # 检查是否有 TensorFlow 可用,如果不可用则抛出 OptionalDependencyNotAvailable 异常
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 如果可用,则从 modeling_tf_vision_encoder_decoder 模块中导入 TFVisionEncoderDecoderModel 类
        from .modeling_tf_vision_encoder_decoder import TFVisionEncoderDecoderModel

    # 检查是否有 Flax 可用,如果不可用则抛出 OptionalDependencyNotAvailable 异常
    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 如果可用,则从 modeling_flax_vision_encoder_decoder 模块中导入 FlaxVisionEncoderDecoderModel 类
        from .modeling_flax_vision_encoder_decoder import FlaxVisionEncoderDecoderModel

# 如果不是 TYPE_CHECKING 模式
else:
    # 导入 sys 模块
    import sys

    # 将当前模块定义为延迟加载模块的 LazyModule 对象
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
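# Illustrative sketch (not part of the original file): thanks to `_LazyModule`, importing the package itself is
# cheap, and the heavy framework-specific classes are only imported on first attribute access.
#
#     >>> from transformers.models import vision_encoder_decoder as ved   # no torch/tf/flax import happens yet
#     >>> ved.VisionEncoderDecoderModel                                   # the PyTorch implementation is loaded here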

.\models\vision_text_dual_encoder\configuration_vision_text_dual_encoder.py

# 设置文件编码为 UTF-8

# 导入必要的模块和类
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..chinese_clip.configuration_chinese_clip import ChineseCLIPVisionConfig
from ..clip.configuration_clip import CLIPVisionConfig
from ..siglip.configuration_siglip import SiglipVisionConfig

# 获取日志记录器对象
logger = logging.get_logger(__name__)

# 定义不同视觉模型配置类的映射关系
VISION_MODEL_CONFIGS = {
    "clip_vision_model": CLIPVisionConfig,
    "chinese_clip_vision_model": ChineseCLIPVisionConfig,
    "siglip_vision_model": SiglipVisionConfig,
}

# VisionTextDualEncoderConfig 类继承自 PretrainedConfig 类,用于存储 VisionTextDualEncoderModel 的配置信息
class VisionTextDualEncoderConfig(PretrainedConfig):
    r"""
    [`VisionTextDualEncoderConfig`] 是一个配置类,用于存储 [`VisionTextDualEncoderModel`] 的配置信息。
    根据指定的参数实例化 [`VisionTextDualEncoderModel`] 模型,定义了文本模型和视觉模型的配置。

    配置对象继承自 [`PretrainedConfig`],可用于控制模型的输出。更多信息请参阅 [`PretrainedConfig`] 的文档。

    Args:
        projection_dim (`int`, *optional*, defaults to 512):
            文本和视觉投影层的维度。
        logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
            *logit_scale* 参数的初始值。默认值按照原始 CLIP 实现使用。
        kwargs (*optional*):
            字典形式的关键字参数。

    Examples:

    ```
    >>> from transformers import ViTConfig, BertConfig, VisionTextDualEncoderConfig, VisionTextDualEncoderModel

    >>> # 初始化 BERT 和 ViT 的配置
    >>> config_vision = ViTConfig()
    >>> config_text = BertConfig()

    >>> config = VisionTextDualEncoderConfig.from_vision_text_configs(config_vision, config_text, projection_dim=512)

    >>> # 初始化一个带有随机权重的 BERT 和 ViT 模型
    >>> model = VisionTextDualEncoderModel(config=config)

    >>> # 访问模型配置
    >>> config_vision = model.config.vision_config
    >>> config_text = model.config.text_config

    >>> # 保存模型及其配置
    >>> model.save_pretrained("vit-bert")

    >>> # Load the model and config back from the pretrained folder
    >>> vision_text_config = VisionTextDualEncoderConfig.from_pretrained("vit-bert")
    >>> model = VisionTextDualEncoderModel.from_pretrained("vit-bert", config=vision_text_config)
    ```"""



    # 设定模型类型为“vision-text-dual-encoder”
    model_type = "vision-text-dual-encoder"
    # 表示这个类是一个复合类
    is_composition = True

    def __init__(self, projection_dim=512, logit_scale_init_value=2.6592, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)

        # 检查是否提供了视觉配置参数
        if "vision_config" not in kwargs:
            raise ValueError("`vision_config` can not be `None`.")
        
        # 检查是否提供了文本配置参数
        if "text_config" not in kwargs:
            raise ValueError("`text_config` can not be `None`.")
        
        # 弹出并获取视觉配置参数
        vision_config = kwargs.pop("vision_config")
        # 弹出并获取文本配置参数
        text_config = kwargs.pop("text_config")
        
        # 获取视觉模型类型
        vision_model_type = vision_config.pop("model_type")
        # 获取文本模型类型
        text_model_type = text_config.pop("model_type")
        
        # 根据视觉模型类型获取对应的配置类
        vision_config_class = VISION_MODEL_CONFIGS.get(vision_model_type)
        # 如果找到了对应的配置类,则使用提供的视觉配置参数实例化它
        if vision_config_class is not None:
            self.vision_config = vision_config_class(**vision_config)
        # 否则,根据视觉模型类型和参数自动创建一个配置实例
        else:
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
            # 如果这个配置实例本身有一个名为`vision_config`的属性,则将其设置为当前实例的`vision_config`
            if hasattr(self.vision_config, "vision_config"):
                self.vision_config = self.vision_config.vision_config
        
        # 根据文本模型类型和参数自动创建一个文本配置实例
        self.text_config = AutoConfig.for_model(text_model_type, **text_config)
        
        # 设置投影维度参数
        self.projection_dim = projection_dim
        # 设置对数尺度初始化值参数
        self.logit_scale_init_value = logit_scale_init_value



    @classmethod
    def from_vision_text_configs(cls, vision_config: PretrainedConfig, text_config: PretrainedConfig, **kwargs):
        """
        从视觉模型配置和文本模型配置实例化一个[`VisionTextDualEncoderConfig`](或其派生类)。

        Args:
            vision_config (PretrainedConfig): 视觉模型配置的实例
            text_config (PretrainedConfig): 文本模型配置的实例
            **kwargs: 其他参数

        Returns:
            VisionTextDualEncoderConfig: 配置对象的一个实例
        """
        return cls(vision_config=vision_config.to_dict(), text_config=text_config.to_dict(), **kwargs)



.\models\vision_text_dual_encoder\modeling_flax_vision_text_dual_encoder.py

# coding=utf-8
# 版权所有 2021 年 HuggingFace Inc. 团队。保留所有权利。
#
# 根据 Apache 许可证 2.0 版本(“许可证”)获得许可;
# 除非符合许可证,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,软件
# 均在“按原样”基础上分发,无论是明示的还是暗示的。
# 有关特定语言的权限,请参阅许可证。
""" Flax VisionTextDualEncoder model."""

# 从 typing 模块导入必要的类型
from typing import Optional, Tuple

# 导入 Flax 的 linen 模块作为 nn
import flax.linen as nn
# 导入 JAX 库,并将其别名为 jax
import jax
# 导入 JAX 的 numpy 模块,并将其别名为 jnp
import jax.numpy as jnp
# 导入 flax.core.frozen_dict 模块中的 FrozenDict、freeze 和 unfreeze 方法
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
# 导入 flax.traverse_util 模块中的 flatten_dict 和 unflatten_dict 方法
from flax.traverse_util import flatten_dict, unflatten_dict

# 从模块中导入一些函数和类
from ...modeling_flax_utils import FlaxPreTrainedModel, append_replace_return_docstrings, overwrite_call_docstring
from ...utils import add_start_docstrings, logging
# 从 auto 模块中导入 AutoConfig 类
from ..auto.configuration_auto import AutoConfig
# 从 auto 模块中导入 FLAX_MODEL_MAPPING 和 FlaxAutoModel 类
from ..auto.modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel
# 从 clip 模块中导入 FlaxCLIPOutput 和 FlaxCLIPVisionModel 类
from ..clip.modeling_flax_clip import FlaxCLIPOutput, FlaxCLIPVisionModel
# 从 configuration_vision_text_dual_encoder 模块中导入 VisionTextDualEncoderConfig 类
from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig

# 获取日志记录器
logger = logging.get_logger(__name__)

# 用于文档的配置
_CONFIG_FOR_DOC = "VisionTextDualEncoderConfig"

# VisionTextDualEncoder 类的文档字符串,包含详细说明和用法示例
VISION_TEXT_DUAL_ENCODER_START_DOCSTRING = r"""
    This class can be used to initialize a vision-text dual encoder model with any pretrained vision autoencoding model
    as the vision encoder and any pretrained text model as the text encoder. The vision and text encoders are loaded
    via the [`~FlaxAutoModel.from_pretrained`] method. The projection layers are automatically added to the model and
    should be fine-tuned on a downstream task, like contrastive image-text modeling.

    In [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://arxiv.org/abs/2111.07991) it is shown how
    leveraging pre-trained (locked/frozen) image and text model for contrastive learning yields significant improvment
    on new zero-shot vision tasks such as image classification or retrieval.

    After such a Vision-Text-Dual-Encoder model has been trained/fine-tuned, it can be saved/loaded just like any other
    models (see the examples for more information).

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

     This model is also a
     [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it
     as a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
     behavior.

    Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`VisionTextDualEncoderConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified, all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""
# String constant holding the docstring that documents the inputs of the dual-encoder models.
VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING = r"""
Args:
    input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
        输入序列标记的索引,可以通过 AutoTokenizer 获取。默认情况下将忽略填充部分。
        
        [什么是输入 ID?](../glossary#input-ids)
    attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
        遮罩,用于避免在填充标记索引上执行注意力。遮罩值选择在 `[0, 1]`:
        
        - 1 表示 **未遮罩** 的标记,
        - 0 表示 **已遮罩** 的标记。
        
        [什么是注意力遮罩?](../glossary#attention-mask)
    position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
        每个输入序列标记在位置嵌入中的位置索引。选择范围在 `[0, config.max_position_embeddings - 1]`。
        
        [什么是位置 ID?](../glossary#position-ids)
    pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        像素值。默认情况下将忽略填充部分。可以使用图像处理器获取像素值(例如,如果使用 ViT 作为编码器,应使用 `AutoImageProcessor`)。
        
        [ViTImageProcessor.__call__] 获取更多细节。
    output_attentions (`bool`, *optional*):
        是否返回所有注意力层的注意力张量。详见返回的张量中的 `attentions` 获取更多细节。
    output_hidden_states (`bool`, *optional*):
        是否返回所有层的隐藏状态。详见返回的张量中的 `hidden_states` 获取更多细节。
    return_dict (`bool`, *optional*):
        是否返回 `~utils.ModelOutput` 而不是普通元组。
"""
class FlaxVisionTextDualEncoderModule(nn.Module):
    config: VisionTextDualEncoderConfig
    dtype: jnp.dtype = jnp.float32
    # 设置函数的初始化操作,准备模型和参数配置
    def setup(self):
        # 从配置对象中获取视觉模型和文本模型的配置信息
        vision_config = self.config.vision_config
        text_config = self.config.text_config

        # 设置视觉嵌入维度和文本嵌入维度
        self.vision_embed_dim = vision_config.hidden_size
        self.text_embed_dim = text_config.hidden_size
        self.projection_dim = self.config.projection_dim

        # 根据视觉模型和文本模型的配置选择相应的模型类
        vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class
        text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class

        # 初始化视觉模型和文本模型
        self.vision_model = vision_module(vision_config, dtype=self.dtype)
        self.text_model = text_module(text_config, dtype=self.dtype)

        # 初始化视觉和文本的投影层
        self.visual_projection = nn.Dense(
            self.projection_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(0.02),
            use_bias=False,
        )
        self.text_projection = nn.Dense(
            self.projection_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(0.02),
            use_bias=False,
        )

        # 初始化logit的缩放系数,并将其作为模型的参数
        self.logit_scale = self.param(
            "logit_scale", lambda _, shape: jnp.ones(shape) * self.config.logit_scale_init_value, []
        )

    # 定义模型的调用方法,处理输入并返回模型的输出
    def __call__(
        self,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        ):
            # 如果 return_dict 不为 None,则使用给定的 return_dict;否则使用对象自身的配置中的 return_dict
            return_dict = return_dict if return_dict is not None else self.config.return_dict

        # 使用视觉模型处理像素值,获取视觉输出,包括注意力权重和隐藏状态
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 使用文本模型处理输入,获取文本输出,包括注意力权重和隐藏状态
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从视觉输出中获取图像嵌入,并通过投影层进行处理
        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        # 从文本输出中获取文本嵌入,并通过投影层进行处理
        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # 对图像嵌入进行标准化处理
        image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
        # 对文本嵌入进行标准化处理
        text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)

        # 使用余弦相似度计算文本和图像嵌入之间的逻辑相似性得分
        logit_scale = jnp.exp(self.logit_scale)
        logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
        logits_per_image = logits_per_text.T

        # 如果 return_dict 为 False,则返回包含多个输出的元组
        if not return_dict:
            return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)

        # 如果 return_dict 为 True,则返回自定义的输出对象 FlaxCLIPOutput
        return FlaxCLIPOutput(
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
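# Illustrative sketch with toy numbers (not part of the original file): the __call__ above computes a CLIP-style
# similarity, i.e. the cosine similarity between L2-normalized embeddings, scaled by exp(logit_scale).
#
#     >>> import jax.numpy as jnp
#     >>> text_embeds = jnp.array([[3.0, 4.0]])                # (n_text, dim)
#     >>> image_embeds = jnp.array([[1.0, 0.0], [0.6, 0.8]])   # (n_image, dim)
#     >>> text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
#     >>> image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
#     >>> logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * jnp.exp(2.6592)
#     >>> logits_per_text    # ≈ [[8.57, 14.28]]; the second image matches the text better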
# 将文本和视觉输入编码成嵌入向量的模型,继承自FlaxPreTrainedModel
@add_start_docstrings(VISION_TEXT_DUAL_ENCODER_START_DOCSTRING)
class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel):
    # 使用VisionTextDualEncoderConfig作为配置类
    config_class = VisionTextDualEncoderConfig
    # 使用FlaxVisionTextDualEncoderModule作为模块类
    module_class = FlaxVisionTextDualEncoderModule

    def __init__(
        self,
        config: VisionTextDualEncoderConfig,
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # 如果不初始化,则抛出错误
        if not _do_init:
            raise ValueError(
                "`FlaxVisionTextDualEncoderModel` cannot be created without initializing, `_do_init` must be `True`."
            )

        # 如果未提供输入形状,则使用默认的输入形状
        if input_shape is None:
            input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))

        # 创建模块实例
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # 调用父类的初始化方法,传入配置、模块、输入形状、种子和数据类型
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 初始化输入张量
        input_ids = jnp.zeros(input_shape[0], dtype="i4")
        # 生成位置编码
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
        # 生成token类型编码,默认为1
        token_type_ids = jnp.ones_like(input_ids)
        # 生成注意力掩码,默认为全1
        attention_mask = jnp.ones_like(input_ids)

        # 生成像素值,使用正态分布随机数初始化
        pixel_values = jax.random.normal(rng, input_shape[1])

        # 分割随机数生成器,用于参数和dropout
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # 初始化模块,获取随机参数
        random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)[
            "params"
        ]

        # 如果提供了参数,则使用提供的参数替换缺失的随机参数
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def __call__(
        self,
        input_ids,
        pixel_values,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ):
            # 如果 output_attentions 不为 None,则使用其值;否则使用配置中的默认值
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            # 如果 output_hidden_states 不为 None,则使用其值;否则使用配置中的默认值
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            # 如果 return_dict 不为 None,则使用其值;否则使用配置中的默认值
            return_dict = return_dict if return_dict is not None else self.config.return_dict

            # 转置像素值数组,调整维度顺序
            pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

            # 如果 position_ids 为 None,则创建一个与 input_ids 最后一个维度广播兼容的数组
            if position_ids is None:
                position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

            # 如果 token_type_ids 为 None,则创建一个与 input_ids 形状相同的全零数组
            if token_type_ids is None:
                token_type_ids = jnp.zeros_like(input_ids)

            # 如果 attention_mask 为 None,则创建一个与 input_ids 形状相同的全一数组
            if attention_mask is None:
                attention_mask = jnp.ones_like(input_ids)

            # 处理任何需要的伪随机数发生器 PRNG
            rngs = {}
            if dropout_rng is not None:
                rngs["dropout"] = dropout_rng

            # 调用 self.module.apply 方法,传递相关参数进行模型应用
            return self.module.apply(
                {"params": params or self.params},
                jnp.array(input_ids, dtype="i4"),
                jnp.array(pixel_values, dtype=jnp.float32),
                jnp.array(attention_mask, dtype="i4"),
                jnp.array(position_ids, dtype="i4"),
                jnp.array(token_type_ids, dtype="i4"),
                not train,
                output_attentions,
                output_hidden_states,
                return_dict,
                rngs=rngs,
            )

        # 定义一个方法 get_text_features,接受多个参数,用于获取文本特征
        def get_text_features(
            self,
            input_ids,
            attention_mask=None,
            position_ids=None,
            token_type_ids=None,
            params: dict = None,
            dropout_rng: jax.random.PRNGKey = None,
            train=False,
        ):
        r"""
        Args:
            input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)

        Returns:
            text_features (`jnp.ndarray` of shape `(batch_size, output_dim`): The text embeddings obtained by applying
            the projection layer to the pooled output of text model.
        """
        # 如果未提供 position_ids 参数,则使用 input_ids 的长度广播生成位置 IDs
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # 如果未提供 token_type_ids 参数,则生成与 input_ids 形状相同的全零张量
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # 如果未提供 attention_mask 参数,则生成与 input_ids 形状相同的全一张量
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # 处理可能需要的任何伪随机数发生器(PRNG)
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
            # 调用文本模型获取文本输出
            text_outputs = module.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
                deterministic=deterministic,
            )
            # 从文本输出中获取汇聚输出
            pooled_output = text_outputs[1]
            # 应用文本投影层获得文本特征
            text_features = module.text_projection(pooled_output)
            return text_features

        # 调用模块的 apply 方法来应用参数和输入数据进行前向计算
        return self.module.apply(
            {"params": params or self.params},  # 提供模型参数
            jnp.array(input_ids, dtype="i4"),  # 输入的序列 token IDs
            jnp.array(attention_mask, dtype="i4"),  # 输入的注意力掩码
            jnp.array(position_ids, dtype="i4"),  # 输入的位置 IDs
            jnp.array(token_type_ids, dtype="i4"),  # 输入的 token 类型 IDs
            not train,  # 是否是推理模式
            method=_get_features,  # 调用的方法来获取特征
            rngs=rngs,  # 伪随机数发生器的字典
        )

    def get_image_features(
        self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
        ):
        r"""
        Args:
            pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
                Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
                using [`ImageFeatureExtractionMixin`]. See [`ImageFeatureExtractionMixin.__call__`] for details.

        Returns:
            image_features (`jnp.ndarray` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of vision model.
        """

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # 定义一个内部函数,用于从视觉模型中提取特征
        def _get_features(module, pixel_values, deterministic):
            # 调用视觉模型,传入像素值和确定性参数,获取视觉模型的输出
            vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
            # 提取汇总输出(通常是第二个输出)
            pooled_output = vision_outputs[1]  # pooled_output
            # 将汇总输出应用于视觉投影层,获取图像特征
            image_features = module.visual_projection(pooled_output)
            return image_features

        # 调用当前对象所包含的模块的 apply 方法,将参数和数据传入视觉模型处理函数
        return self.module.apply(
            {"params": params or self.params},  # 使用给定的参数或对象的参数
            jnp.array(pixel_values, dtype=jnp.float32),  # 将像素值转换为 jax 数组
            not train,  # 确定是否为训练模式
            method=_get_features,  # 指定处理方法为 _get_features 函数
            rngs=rngs,  # 传入任何可能需要的随机数生成器
        )
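    # Illustrative usage sketch (checkpoints follow the example in the docstring below; they are not mandated by
    # these methods): projecting text into the shared embedding space; `get_image_features` works analogously on
    # `pixel_values` produced by an image processor.
    #
    #     >>> from transformers import AutoTokenizer, FlaxVisionTextDualEncoderModel
    #     >>> model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
    #     ...     "google/vit-base-patch16-224", "google-bert/bert-base-uncased"
    #     ... )
    #     >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    #     >>> inputs = tokenizer(["a photo of a cat"], return_tensors="np", padding=True)
    #     >>> text_features = model.get_text_features(inputs.input_ids, attention_mask=inputs.attention_mask)
    #     >>> text_features.shape    # (batch_size, projection_dim)
    #     (1, 512)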

    @classmethod
    def from_vision_text_pretrained(
        cls,
        vision_model_name_or_path: str = None,
        text_model_name_or_path: str = None,
        *model_args,
        **kwargs,
# 定义 VisionTextDualEncoderModel 的文档字符串,包含函数的返回值和示例用法
VISION_TEXT_DUAL_ENCODER_MODEL_DOCSTRING = r"""
    Returns:

    Examples:

    ```
    >>> from PIL import Image
    >>> import requests
    >>> import jax
    >>> from transformers import (
    ...     FlaxVisionTextDualEncoderModel,
    ...     VisionTextDualEncoderProcessor,
    ...     AutoImageProcessor,
    ...     AutoTokenizer,
    ... )

    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
    >>> processor = VisionTextDualEncoderProcessor(image_processor, tokenizer)
    >>> model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
    ...     "google/vit-base-patch16-224", "google-bert/bert-base-uncased"
    ... )

    >>> # contrastive training
    >>> urls = [
    ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
    ...     "https://farm3.staticflickr.com/2674/5850229113_4fe05d5265_z.jpg",
    ... ]
    >>> images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
    >>> inputs = processor(
    ...     text=["a photo of a cat", "a photo of a dog"], images=images, return_tensors="np", padding=True
    ... )
    >>> outputs = model(
    ...     input_ids=inputs.input_ids,
    ...     attention_mask=inputs.attention_mask,
    ...     pixel_values=inputs.pixel_values,
    ... )
    >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score

    >>> # save and load from pretrained
    >>> model.save_pretrained("vit-bert")
    >>> model = FlaxVisionTextDualEncoderModel.from_pretrained("vit-bert")

    >>> # inference
    >>> outputs = model(**inputs)
    >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
    >>> probs = jax.nn.softmax(logits_per_image, axis=1)  # we can take the softmax to get the label probabilities
    ```
"""

# 调用 overwrite_call_docstring 函数,用于替换 FlaxVisionTextDualEncoderModel 类的文档字符串
overwrite_call_docstring(
    FlaxVisionTextDualEncoderModel,
    VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING + VISION_TEXT_DUAL_ENCODER_MODEL_DOCSTRING,
)

# 调用 append_replace_return_docstrings 函数,用于附加和替换 FlaxVisionTextDualEncoderModel 类的返回值文档字符串
append_replace_return_docstrings(
    FlaxVisionTextDualEncoderModel, output_type=FlaxCLIPOutput, config_class=_CONFIG_FOR_DOC
)

.\models\vision_text_dual_encoder\modeling_tf_vision_text_dual_encoder.py

# 定义一个文本与视觉双编码模型的文档字符串,用于说明如何初始化并使用预训练的视觉和文本编码器。
VISION_TEXT_DUAL_ENCODER_START_DOCSTRING = r"""
    This class can be used to initialize a vision-text dual encoder model with any pretrained vision autoencoding model
    as the vision encoder and any pretrained text model as the text encoder. The vision and text encoders are loaded
    via the [`~TFAutoModel.from_pretrained`] method. The projection layers are automatically added to the model and
    should be fine-tuned on a downstream task, like contrastive image-text modeling.

    In [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://arxiv.org/abs/2111.07991) it is shown how
    leveraging pre-trained (locked/frozen) image and text models for contrastive learning yields significant improvement
    on new zero-shot vision tasks such as image classification or retrieval.

    After such a Vision-Text-Dual-Encoder model has been trained/fine-tuned, it can be saved/loaded just like any other
    models (see the examples for more information).

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.).

    This model is also a Keras [Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a
    regular Keras Model and refer to the TF documentation for all matter related to general usage and behavior.
"""
    Parameters:
        config ([`VisionEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
Args:
    input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
        输入序列 token 的索引,位于词汇表中。默认情况下会忽略填充部分。
        可以使用 `PreTrainedTokenizer` 获取索引。详见 `PreTrainedTokenizer.encode` 和 `PreTrainedTokenizer.__call__`。

        [什么是输入 ID?](../glossary#input-ids)
    attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
        避免在填充 token 索引上执行注意力的掩码。掩码值在 `[0, 1]` 范围内:

        - 对于 **未被掩码** 的 token,为 1,
        - 对于 **被掩码** 的 token,为 0。

        [什么是注意力掩码?](../glossary#attention-mask)
    position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
        每个输入序列 token 在位置嵌入中的位置索引。选择范围在 `[0, config.max_position_embeddings - 1]` 内。

        [什么是位置 ID?](../glossary#position-ids)
    output_attentions (`bool`, *optional*):
        是否返回所有注意力层的注意力张量。查看返回张量下的 `attentions` 获取更多细节。
    output_hidden_states (`bool`, *optional*):
        是否返回所有层的隐藏状态。查看返回张量下的 `hidden_states` 获取更多细节。
    return_dict (`bool`, *optional*):
        是否返回 [`~utils.ModelOutput`] 而不是普通元组。
"""

VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default. Indices can be
            obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
            for details.
        attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range
            `[0, config.max_position_embeddings - 1]`.
        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
            using an image processor (e.g. [`AutoImageProcessor`] if you use ViT as the encoder). See
            [`ViTImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under the
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under the returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# 从 transformers.models.clip.modeling_tf_clip.contrastive_loss 复制的对比损失函数定义
def contrastive_loss(logits: tf.Tensor) -> tf.Tensor:
    # 计算稀疏分类交叉熵损失的均值,用于对比损失的计算
    return tf.math.reduce_mean(
        keras.metrics.sparse_categorical_crossentropy(
            y_true=tf.range(shape_list(logits)[0]), y_pred=logits, from_logits=True
        )
    )


# 从 transformers.models.clip.modeling_tf_clip.clip_loss 复制的 CLIP 损失函数定义
def clip_loss(similarity: tf.Tensor) -> tf.Tensor:
    # 计算标题和图像的对比损失,由对比损失函数 contrastive_loss 计算
    caption_loss = contrastive_loss(similarity)
    # 转置相似度矩阵并计算图像和标题的对比损失,同样使用 contrastive_loss 函数
    image_loss = contrastive_loss(tf.transpose(similarity))
    # 返回标题损失和图像损失的平均值作为 CLIP 损失
    return (caption_loss + image_loss) / 2.0
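To make the two helpers above concrete, here is a minimal, self-contained sketch of the same symmetric loss computed directly with TensorFlow ops on a hypothetical 3x3 similarity matrix (the matching image-text pairs sit on the diagonal):

```
import tensorflow as tf

# hypothetical similarity logits for 3 image-text pairs; the diagonal holds the matching pairs
similarity = tf.constant([[4.0, 0.1, 0.2],
                          [0.3, 3.5, 0.1],
                          [0.2, 0.4, 5.0]])

labels = tf.range(tf.shape(similarity)[0])  # target index i for row i
caption_loss = tf.reduce_mean(
    tf.keras.losses.sparse_categorical_crossentropy(labels, similarity, from_logits=True)
)
image_loss = tf.reduce_mean(
    tf.keras.losses.sparse_categorical_crossentropy(labels, tf.transpose(similarity), from_logits=True)
)
print(float((caption_loss + image_loss) / 2.0))  # small, since the diagonal already dominates
```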


# Dual encoder model class, documented via the @add_start_docstrings decorator with VISION_TEXT_DUAL_ENCODER_START_DOCSTRING
@add_start_docstrings(VISION_TEXT_DUAL_ENCODER_START_DOCSTRING)
class TFVisionTextDualEncoderModel(TFPreTrainedModel):
    # 指定模型配置类
    config_class = VisionTextDualEncoderConfig
    # 指定基础模型前缀
    base_model_prefix = "vision_text_dual_encoder"
    # 指定加载权重前缀
    load_weight_prefix = "tf_vision_text_dual_encoder_model"

    def __init__(
        self,
        config: Optional[VisionTextDualEncoderConfig] = None,
        vision_model: Optional[TFPreTrainedModel] = None,
        text_model: Optional[TFPreTrainedModel] = None,
    ):
        # 如果未提供配置且视觉模型或文本模型任一未提供,则引发 ValueError
        if config is None and (vision_model is None or text_model is None):
            raise ValueError("Either a configuration or an vision and a text model has to be provided")

        # 如果未提供配置,则从视觉和文本模型的配置中创建 VisionTextDualEncoderConfig
        if config is None:
            config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_model.config, text_model.config)
        else:
            # 如果提供的配置不是 VisionTextDualEncoderConfig 类型,则引发 ValueError
            if not isinstance(config, self.config_class):
                raise ValueError(f"config: {config} has to be of type {self.config_class}")

        # 使用配置初始化父类
        super().__init__(config)

        # 如果未提供视觉模型,则根据配置创建适当的视觉模型
        if vision_model is None:
            if isinstance(config.vision_config, CLIPVisionConfig):
                vision_model = TFCLIPVisionModel.from_config(config.vision_config, name="vision_model")
            else:
                vision_model = TFAutoModel.from_config(config.vision_config, name="vision_model")

        # 如果未提供文本模型,则根据配置创建适当的文本模型
        if text_model is None:
            text_model = TFAutoModel.from_config(config.text_config, name="text_model")

        # 分别设置视觉模型和文本模型
        self.vision_model = vision_model
        self.text_model = text_model

        # 确保各模型的配置引用共享配置,以保持配置更新同步
        self.vision_model.config = self.config.vision_config
        self.text_model.config = self.config.text_config

        # 设置视觉嵌入维度、文本嵌入维度和投影维度
        self.vision_embed_dim = config.vision_config.hidden_size
        self.text_embed_dim = config.text_config.hidden_size
        self.projection_dim = config.projection_dim

        # 定义视觉和文本的投影层,不使用偏置项
        self.visual_projection = keras.layers.Dense(self.projection_dim, use_bias=False, name="visual_projection")
        self.text_projection = keras.layers.Dense(self.projection_dim, use_bias=False, name="text_projection")

        # The logit_scale weight is created later in build(), so it starts out as None
        self.logit_scale = None

        # 设置模型配置
        self.config = config
    # 在构建方法中构建模型,确保命名正确
    def build(self, input_shape=None):
        # 如果已经构建过,则直接返回,避免重复构建
        if self.built:
            return
        # 设置标志表示模型已经构建
        self.built = True
        # 使用常量初始化器设置logit_scale权重,shape为(1,)
        initializer = keras.initializers.Constant(self.config.logit_scale_init_value)
        self.logit_scale = self.add_weight(shape=(1,), initializer=initializer, name="logit_scale")

        # 如果存在visual_projection属性,则构建它并设置命名空间
        if getattr(self, "visual_projection", None) is not None:
            with tf.name_scope(self.visual_projection.name):
                self.visual_projection.build([None, None, self.vision_embed_dim])
        
        # 如果存在text_projection属性,则构建它并设置命名空间
        if getattr(self, "text_projection", None) is not None:
            with tf.name_scope(self.text_projection.name):
                self.text_projection.build([None, None, self.text_embed_dim])
        
        # 设置vision_model的命名空间并构建其模型
        with tf.name_scope(self.vision_model.name):
            self.vision_model.build(None)
        
        # 设置text_model的命名空间并构建其模型
        with tf.name_scope(self.text_model.name):
            self.text_model.build(None)

    # 将TensorFlow的权重名称转换为PyTorch风格的权重名称
    def tf_to_pt_weight_rename(self, tf_weight):
        # 如果权重名称中包含"vision_model",则根据不同情况进行重命名处理
        if "vision_model" in tf_weight:
            if tf_weight.count("vision_model") == 1:
                return (re.sub(r"vision_model\..*?\.", "vision_model.", tf_weight),)
            elif tf_weight.count("vision_model") == 2:
                return (re.sub(r"vision_model\..*?\.vision_model", "vision_model.vision_model", tf_weight),)
            else:
                raise ValueError(
                    f"Unexpected weight name {tf_weight}. Please file an issue on the"
                    " Transformers repo to let us know about this error!"
                )
        # 如果权重名称中包含"text_model",则进行相应的重命名处理
        elif "text_model" in tf_weight:
            return (re.sub(r"text_model\..*?\.", "text_model.", tf_weight),)
        # 如果以上条件都不符合,则返回原始的权重名称
        else:
            return (tf_weight,)
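To see what the renaming rules above actually do, here is a standalone illustration with made-up weight names (the real names depend on the wrapped TF models; everything below is hypothetical, and the tuple wrapping and the double-scope branch are dropped for brevity):

```
import re


def rename(tf_weight: str) -> str:
    # Same regexes as tf_to_pt_weight_rename above, simplified to return a plain string
    if "vision_model" in tf_weight and tf_weight.count("vision_model") == 1:
        return re.sub(r"vision_model\..*?\.", "vision_model.", tf_weight)
    if "text_model" in tf_weight:
        return re.sub(r"text_model\..*?\.", "text_model.", tf_weight)
    return tf_weight


# The extra TF wrapper scope right after the prefix is stripped:
print(rename("vision_model.tf_vit_model.encoder.kernel"))      # -> vision_model.encoder.kernel
print(rename("text_model.tf_bert_model.pooler.dense.kernel"))  # -> text_model.pooler.dense.kernel
```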

    # Add the forward docstring via add_start_docstrings_to_model_forward with VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING
    @add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:
            text_features (`tf.Tensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by applying
            the projection layer to the pooled output of [`TFCLIPTextModel`].

        Examples:

        ```
        >>> from transformers import TFVisionTextDualEncoderModel, AutoTokenizer

        >>> model = TFVisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", from_pt=True)
        >>> tokenizer = AutoTokenizer.from_pretrained("clip-italian/clip-italian")

        >>> inputs = tokenizer(["una foto di un gatto", "una foto di un cane"], padding=True, return_tensors="np")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        # 使用 self.text_model 处理输入,获取文本输出
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从文本输出中获取池化后的输出
        pooled_output = text_outputs[1]
        # 使用 self.text_projection 对池化输出进行投影,得到文本特征
        text_features = self.text_projection(pooled_output)

        return text_features

    @add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:
            image_features (`tf.Tensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by applying
            the projection layer to the pooled output of [`TFCLIPVisionModel`].

        Examples:

        ```
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import TFVisionTextDualEncoderModel, AutoImageProcessor

        >>> model = TFVisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", from_pt=True)
        >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = image_processor(images=image, return_tensors="np")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
        # 使用 self.vision_model 处理输入,获取视觉输出
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从视觉输出中获取池化后的输出
        pooled_output = vision_outputs[1]  # pooled_output
        # 使用 self.visual_projection 对池化输出进行投影,得到图像特征
        image_features = self.visual_projection(pooled_output)

        return image_features

    @unpack_inputs
    @add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFCLIPOutput, config_class=_CONFIG_FOR_DOC)
    # 定义一个方法 `call`,用于执行模型推理或训练过程的输入处理和参数设置
    def call(
        self,
        input_ids: tf.Tensor | None = None,  # 输入文本的token IDs张量,默认为None
        pixel_values: tf.Tensor | None = None,  # 输入图像的像素值张量,默认为None
        attention_mask: tf.Tensor | None = None,  # 注意力掩码张量,默认为None
        position_ids: tf.Tensor | None = None,  # 位置编码张量,默认为None
        return_loss: Optional[bool] = None,  # 是否返回损失张量的布尔值,可选,默认为None
        token_type_ids: tf.Tensor | None = None,  # token类型 IDs 张量,默认为None
        output_attentions: Optional[bool] = None,  # 是否输出注意力张量的布尔值,可选,默认为None
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态张量的布尔值,可选,默认为None
        return_dict: Optional[bool] = None,  # 是否返回字典格式的输出结果的布尔值,可选,默认为None
        training: bool = False,  # 是否为训练模式的布尔值,默认为False
    ):
        
    # 类方法,用于从预训练的视觉-文本模型加载模型
    @classmethod
    def from_vision_text_pretrained(
        cls,
        vision_model_name_or_path: str = None,  # 视觉模型名称或路径的字符串,默认为None
        text_model_name_or_path: str = None,  # 文本模型名称或路径的字符串,默认为None
        *model_args,  # 模型参数的位置参数
        **kwargs,  # 模型参数的关键字参数
    ):
        
    # 属性方法,返回构建网络所需的虚拟输入数据字典
    @property
    def dummy_inputs(self):
        """
        Dummy inputs to build the network.

        Returns:
            `Dict[str, tf.Tensor]`: The dummy inputs.
        """
        # 使用预定义的虚拟输入数据构建输入文本的token IDs张量
        input_ids = tf.constant(DUMMY_INPUTS, dtype=tf.int32)
        batch_size, seq_len = input_ids.shape

        # 使用随机生成的虚拟输入数据构建输入图像的像素值张量
        VISION_DUMMY_INPUTS = tf.random.uniform(
            shape=(
                batch_size,
                self.config.vision_config.num_channels,
                self.config.vision_config.image_size,
                self.config.vision_config.image_size,
            ),
            dtype=tf.float32,
        )
        pixel_values = tf.constant(VISION_DUMMY_INPUTS)
        # 构建并返回包含虚拟输入数据的字典
        dummy = {"pixel_values": pixel_values, "input_ids": input_ids}
        return dummy
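A hedged sketch of how `dummy_inputs` is typically exercised: running the Keras model once on it forces every layer (including the projections and `logit_scale`) to be built before saving. The checkpoint names mirror the examples elsewhere in this document and are only illustrative.

```
>>> from transformers import TFVisionTextDualEncoderModel

>>> model = TFVisionTextDualEncoderModel.from_vision_text_pretrained(
...     "google/vit-base-patch16-224", "google-bert/bert-base-uncased"
... )
>>> _ = model(model.dummy_inputs)  # one forward pass builds all weights
>>> model.save_pretrained("vit-bert-tf")
```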

.\models\vision_text_dual_encoder\modeling_vision_text_dual_encoder.py

# 设置编码格式为 UTF-8

# 版权声明和许可协议,表明此代码的版权和许可情况
# 版权所有 2021 年 HuggingFace Inc. 团队。保留所有权利。
# 根据 Apache 许可证 2.0 版本授权,除非遵守许可证,否则不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则按“原样”分发本软件,
# 不提供任何明示或暗示的保证或条件。有关详细信息,请参阅许可证。

""" PyTorch VisionTextDualEncoder model. """

# 导入必要的库
from typing import Optional, Tuple, Union

import torch
from torch import nn

# 导入模型相关的实用函数和类
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings

# 自动导入相关配置类
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_auto import AutoModel

# 导入与 CLIP 相关的模型和配置
from ..clip.modeling_clip import CLIPOutput, CLIPVisionConfig, CLIPVisionModel

# 导入 VisionTextDualEncoder 模型的配置类
from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig

# 获取日志记录器
logger = logging.get_logger(__name__)

# 文档用变量,指定 VisionTextDualEncoderConfig 的字符串
_CONFIG_FOR_DOC = "VisionTextDualEncoderConfig"

# VisionTextDualEncoder 类的文档字符串,提供了关于该类的详细信息和用法示例
VISION_TEXT_DUAL_ENCODER_START_DOCSTRING = r"""
    This class can be used to initialize a vision-text dual encoder model with any pretrained vision autoencoding model
    as the vision encoder and any pretrained text model as the text encoder. The vision and text encoders are loaded
    via the [`~AutoModel.from_pretrained`] method. The projection layers are automatically added to the model and
    should be fine-tuned on a downstream task, like contrastive image-text modeling.

    In [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://arxiv.org/abs/2111.07991) it is shown how
    leveraging pre-trained (locked/frozen) image and text models for contrastive learning yields significant improvement
    on new zero-shot vision tasks such as image classification or retrieval.

    After such a Vision-Text-Dual-Encoder model has been trained/fine-tuned, it can be saved/loaded just like any other
    models (see the examples for more information).

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.
    Parameters:
        config ([`VisionTextDualEncoderConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING = r"""
Args:
    input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
        输入序列标记在词汇表中的索引。默认情况下,将忽略填充标记。

        可以使用[`PreTrainedTokenizer`]获取索引。参见[`PreTrainedTokenizer.encode`]和[`PreTrainedTokenizer.__call__`]了解详情。

        [什么是输入 ID?](../glossary#input-ids)
    attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
        遮盖机制,用于避免在填充标记索引上执行注意力操作。遮盖值选择在 `[0, 1]`:

        - 1 表示**未遮盖**的标记,
        - 0 表示**遮盖**的标记。

        [什么是注意力遮盖?](../glossary#attention-mask)
    position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        输入序列中每个标记的位置索引,用于位置嵌入。选择范围在 `[0, config.max_position_embeddings - 1]`。

        [什么是位置 ID?](../glossary#position-ids)
    output_attentions (`bool`, *optional*):
        是否返回所有注意力层的注意力张量。查看返回张量中的 `attentions` 获取更多细节。
    output_hidden_states (`bool`, *optional*):
        是否返回所有层的隐藏状态。查看返回张量中的 `hidden_states` 获取更多细节。
    return_dict (`bool`, *optional*):
        是否返回 [`~utils.ModelOutput`] 而不是简单的元组。
"""

VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            # 输入序列标记的索引,用于表示词汇表中的每个标记。默认情况下会忽略填充。
            # 可以使用 `AutoTokenizer` 获取这些索引。参见 `PreTrainedTokenizer.encode` 和 `PreTrainedTokenizer.__call__` 获取详情。
            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            # 避免在填充的标记索引上执行注意力操作的掩码。掩码的值选择在 `[0, 1]` 之间:
            # - 对于 **未屏蔽的** 标记,设为 1,
            # - 对于 **屏蔽的** 标记,设为 0。
            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            # 每个输入序列标记在位置嵌入中的位置索引。选择范围在 `[0, config.max_position_embeddings - 1]` 之间。
            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            # 像素值。默认情况下会忽略填充。可以通过图像处理器获取像素值(例如,如果使用 ViT 作为编码器,应使用 `AutoImageProcessor`)。
            # 详见 `ViTImageProcessor.__call__` 获取详情。
        return_loss (`bool`, *optional*):
            # 是否返回对比损失。
        output_attentions (`bool`, *optional*):
            # 是否返回所有注意力层的注意力张量。返回的张量中的 `attentions` 字段提供更多细节。
        output_hidden_states (`bool`, *optional*):
            # 是否返回所有层的隐藏状态。返回的张量中的 `hidden_states` 字段提供更多细节。
        return_dict (`bool`, *optional*):
            # 是否返回 `~utils.ModelOutput` 而不是普通元组。
"""

# Copied from transformers.models.clip.modeling_clip.contrastive_loss
# Contrastive loss: cross-entropy of the similarity logits against the diagonal (matching-pair) targets
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


"""
Copied from transformers.models.clip.modeling_clip.clip_loss
定义 CLIP 损失函数,输入为相似性张量,输出为损失值
"""
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    # 计算文本和图像的对比损失
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    # 返回文本和图像损失的平均值
    return (caption_loss + image_loss) / 2.0
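For intuition, a small self-contained check of what `clip_loss` computes, using a hypothetical 3x3 similarity matrix whose diagonal holds the matching pairs:

```
import torch
import torch.nn.functional as F

similarity = torch.tensor([[4.0, 0.1, 0.2],
                           [0.3, 3.5, 0.1],
                           [0.2, 0.4, 5.0]])

targets = torch.arange(similarity.size(0))             # row i should match column i
caption_loss = F.cross_entropy(similarity, targets)    # text -> image direction
image_loss = F.cross_entropy(similarity.t(), targets)  # image -> text direction
print((caption_loss + image_loss) / 2.0)               # same value clip_loss(similarity) returns
```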


"""
@add_start_docstrings(VISION_TEXT_DUAL_ENCODER_START_DOCSTRING)
双编码器模型,结合视觉和文本输入进行编码
"""
class VisionTextDualEncoderModel(PreTrainedModel):
    config_class = VisionTextDualEncoderConfig
    base_model_prefix = "vision_text_dual_encoder"

    def __init__(
        self,
        config: Optional[VisionTextDualEncoderConfig] = None,
        vision_model: Optional[PreTrainedModel] = None,
        text_model: Optional[PreTrainedModel] = None,
    ):
        if config is None and (vision_model is None or text_model is None):
            raise ValueError("Either a configuration or an vision and a text model has to be provided")

        if config is None:
            # 如果未提供配置,则从视觉和文本模型的配置创建配置对象
            config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_model.config, text_model.config)
        else:
            if not isinstance(config, self.config_class):
                raise ValueError(f"config: {config} has to be of type {self.config_class}")

        # 使用父类初始化模型
        super().__init__(config)

        # 如果未提供视觉模型,则根据配置创建默认的视觉模型
        if vision_model is None:
            if isinstance(config.vision_config, CLIPVisionConfig):
                vision_model = CLIPVisionModel(config.vision_config)
            else:
                vision_model = AutoModel.from_config(config.vision_config)

        # 如果未提供文本模型,则根据配置创建默认的文本模型
        if text_model is None:
            text_model = AutoModel.from_config(config.text_config)

        # 将创建的视觉模型和文本模型保存到当前对象中
        self.vision_model = vision_model
        self.text_model = text_model

        # 确保各个模型的配置对象与共享的配置对象同步更新
        self.vision_model.config = self.config.vision_config
        self.text_model.config = self.config.text_config

        # 设置视觉和文本嵌入的维度和投影维度
        self.vision_embed_dim = config.vision_config.hidden_size
        self.text_embed_dim = config.text_config.hidden_size
        self.projection_dim = config.projection_dim

        # 定义视觉和文本的线性投影层
        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)

        # 初始化 logits 缩放参数
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
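The `forward` method itself is not reproduced in this excerpt. As a rough sketch (not a verbatim copy of the implementation) of how the pieces initialized above combine in CLIP-style dual encoders, the projected embeddings are L2-normalized and scaled by `logit_scale.exp()` before the symmetric loss defined earlier is applied:

```
import torch

# hypothetical projected embeddings for a batch of 4 pairs with projection_dim = 512
image_embeds = torch.randn(4, 512)
text_embeds = torch.randn(4, 512)
logit_scale = torch.tensor(2.6592)  # a typical init value, log(1 / 0.07), as used by CLIP

image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale.exp()
logits_per_image = logits_per_text.t()
# clip_loss(logits_per_text), defined above, would then give the training objective
```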

    @add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Note: the docstring and body of get_text_features are omitted in this excerpt; like the TF version above,
        # it runs self.text_model on the inputs and projects the pooled output through self.text_projection.
    @add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_VISION_INPUTS_DOCSTRING)
    # 使用装饰器添加模型前向传播的文档字符串,文档字符串定义了输入参数和返回结果的形状和含义
    def get_image_features(
        self,
        pixel_values=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPVisionModel`].

        Examples:

        ```
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import VisionTextDualEncoderModel, AutoImageProcessor

        >>> model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian")
        >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
        # 使用视觉模型处理像素值,返回视觉特征,可以控制是否输出注意力和隐藏状态,并选择返回形式
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从视觉输出的第二个元素中获取池化输出作为特征表示
        pooled_output = vision_outputs[1]  # pooled_output
        # 将池化输出应用于视觉投影层,得到最终的图像特征表示
        image_features = self.visual_projection(pooled_output)

        # 返回图像特征表示
        return image_features
    # 定义一个类方法 `forward`,用于模型的前向传播
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,  # 输入的 token IDs,类型为长整型张量,可选
        pixel_values: Optional[torch.FloatTensor] = None,  # 输入的像素值,类型为浮点张量,可选
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码张量,类型为张量,可选
        position_ids: Optional[torch.LongTensor] = None,  # 位置 IDs,类型为长整型张量,可选
        return_loss: Optional[bool] = None,  # 是否返回损失值,类型为布尔值,可选
        token_type_ids: Optional[torch.LongTensor] = None,  # Token 类型 IDs,类型为长整型张量,可选
        output_attentions: Optional[bool] = None,  # 是否输出注意力权重,类型为布尔值,可选
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态,类型为布尔值,可选
        return_dict: Optional[bool] = None,  # 是否返回字典格式的输出,类型为布尔值,可选
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # 目前不支持复合模型的快速初始化
        kwargs["_fast_init"] = False
        # 调用父类的 `from_pretrained` 方法,并传递所有的位置参数和关键字参数
        return super().from_pretrained(*args, **kwargs)

    @classmethod
    def from_vision_text_pretrained(
        cls,
        vision_model_name_or_path: str = None,  # name or path of the pretrained vision model
        text_model_name_or_path: str = None,  # name or path of the pretrained text model
        *model_args,  # additional positional arguments passed through to the underlying models
        **kwargs,  # additional keyword arguments passed through to the underlying models
    ):

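For completeness, a hedged usage sketch of `from_vision_text_pretrained` on the PyTorch class, mirroring the Flax example earlier in this document (checkpoint names are illustrative):

```
>>> from transformers import VisionTextDualEncoderModel

>>> model = VisionTextDualEncoderModel.from_vision_text_pretrained(
...     "google/vit-base-patch16-224", "google-bert/bert-base-uncased"
... )
>>> model.save_pretrained("vit-bert")
>>> model = VisionTextDualEncoderModel.from_pretrained("vit-bert")
```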
.\models\vision_text_dual_encoder\processing_vision_text_dual_encoder.py

# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for VisionTextDualEncoder
"""

import warnings

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding


class VisionTextDualEncoderProcessor(ProcessorMixin):
    r"""
    Constructs a VisionTextDualEncoder processor which wraps an image processor and a tokenizer into a single
    processor.

    [`VisionTextDualEncoderProcessor`] offers all the functionalities of [`AutoImageProcessor`] and [`AutoTokenizer`].
    See the [`~VisionTextDualEncoderProcessor.__call__`] and [`~VisionTextDualEncoderProcessor.decode`] for more
    information.

    Args:
        image_processor ([`AutoImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`PreTrainedTokenizer`], *optional*):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        # Deprecated feature_extractor handling
        feature_extractor = None
        if "feature_extractor" in kwargs:
            warnings.warn(
                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
                " instead.",
                FutureWarning,
            )
            feature_extractor = kwargs.pop("feature_extractor")

        # Set image_processor to feature_extractor if not provided separately
        image_processor = image_processor if image_processor is not None else feature_extractor
        if image_processor is None:
            raise ValueError("You have to specify an image_processor.")
        if tokenizer is None:
            raise ValueError("You have to specify a tokenizer.")

        # Initialize the processor with image_processor and tokenizer
        super().__init__(image_processor, tokenizer)
        # Set the current_processor to image_processor
        self.current_processor = self.image_processor
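A hedged usage sketch of the processor: wrap a concrete image processor and tokenizer, then let one object prepare both modalities. The `__call__` method is not shown in this excerpt, and the checkpoint names are illustrative.

```
>>> from transformers import AutoImageProcessor, AutoTokenizer, VisionTextDualEncoderProcessor

>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> processor = VisionTextDualEncoderProcessor(image_processor, tokenizer)

>>> inputs = processor(text=["a photo of a cat"], return_tensors="pt", padding=True)
>>> print(processor.model_input_names)  # union of the tokenizer and image processor input names
```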

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to VisionTextDualEncoderTokenizer's
        [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
        """
        # Delegate batch_decode to the underlying tokenizer
        return self.tokenizer.batch_decode(*args, **kwargs)
    # 将所有参数转发给 VisionTextDualEncoderTokenizer 的 `PreTrainedTokenizer.decode` 方法。
    # 请参考该方法的文档字符串获取更多信息。
    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    # 返回模型输入的名称列表,合并并去除重复的 Tokenizer 和图像处理器的输入名称。
    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    # 返回图像处理器的类。
    # 发出警告,提示`feature_extractor_class`将在v5中移除,建议使用`image_processor_class`代替。
    @property
    def feature_extractor_class(self):
        warnings.warn(
            "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
            FutureWarning,
        )
        return self.image_processor_class

    # 返回图像处理器。
    # 发出警告,提示`feature_extractor`将在v5中移除,建议使用`image_processor`代替。
    @property
    def feature_extractor(self):
        warnings.warn(
            "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
            FutureWarning,
        )
        return self.image_processor