Transformers Source Code Analysis (91)

.\models\pvt_v2\__init__.py

# coding=utf-8
# Copyright 2023 authors: Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan,
# Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
    is_vision_available,
)

# Define the import structure
_import_structure = {
    "configuration_pvt_v2": ["PVT_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "PvtV2Config"],
}

try:
    # Raise OptionalDependencyNotAvailable if torch is not installed
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Without torch, the modeling module is simply not registered
    pass
else:
    # torch is available: register the modeling classes in the import structure
    _import_structure["modeling_pvt_v2"] = [
        "PVT_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
        "PvtV2ForImageClassification",
        "PvtV2Model",
        "PvtV2PreTrainedModel",
        "PvtV2Backbone",
    ]


if TYPE_CHECKING:
    # During static type checking, import the symbols directly
    from .configuration_pvt_v2 import PVT_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, PvtV2Config

    try:
        # Raise OptionalDependencyNotAvailable if torch is not installed
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # Without torch, skip the modeling imports
        pass
    else:
        # torch is available: import the modeling classes into the namespace
        from .modeling_pvt_v2 import (
            PVT_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
            PvtV2Backbone,
            PvtV2ForImageClassification,
            PvtV2Model,
            PvtV2PreTrainedModel,
        )

else:
    import sys

    # At runtime, replace this module with a _LazyModule that resolves the
    # import structure lazily, on first attribute access
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
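To see the effect of this lazy structure from the user side, here is a minimal usage sketch (assuming a `transformers` release that ships PVT-v2). Importing the config resolves only `configuration_pvt_v2`; the torch-dependent `modeling_pvt_v2` is untouched until a modeling class is requested:

```python
# Minimal sketch of lazy imports: the config import does not pull in torch.
from transformers import PvtV2Config

config = PvtV2Config()
print(config.model_type)  # "pvt_v2"

# Only this import would trigger the torch-dependent module (and raise a
# helpful error if torch were missing):
# from transformers import PvtV2Model
```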

.\models\qdqbert\configuration_qdqbert.py

# coding=utf-8
# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
QDQBERT model configuration
"""

# Import the PretrainedConfig base class and the logging utility
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# Module-level logger
logger = logging.get_logger(__name__)

# Map from model name to the URL of its configuration file
QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google-bert/bert-base-uncased": "https://huggingface.co/google-bert/bert-base-uncased/resolve/main/config.json",
    # QDQBERT models can be loaded from any BERT checkpoint, available at https://huggingface.co/models?filter=bert
}

# QDQBERT configuration class, inheriting from PretrainedConfig
class QDQBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`QDQBertModel`]. It is used to instantiate an
    QDQBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the BERT
    [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """

    # Model type identifier
    model_type = "qdqbert"
    def __init__(
        self,
        vocab_size=30522,  # vocabulary size
        hidden_size=768,  # hidden size
        num_hidden_layers=12,  # number of hidden layers
        num_attention_heads=12,  # number of attention heads
        intermediate_size=3072,  # feed-forward intermediate size
        hidden_act="gelu",  # hidden-layer activation function
        hidden_dropout_prob=0.1,  # dropout probability for hidden layers
        attention_probs_dropout_prob=0.1,  # dropout probability for attention weights
        max_position_embeddings=512,  # maximum number of position embeddings
        type_vocab_size=2,  # token type vocabulary size
        initializer_range=0.02,  # stddev used for weight initialization
        layer_norm_eps=1e-12,  # epsilon for LayerNorm
        use_cache=True,  # whether to cache key/values
        pad_token_id=1,  # padding token id
        bos_token_id=0,  # beginning-of-sequence token id
        eos_token_id=2,  # end-of-sequence token id
        **kwargs,
    ):
        # Pass the special token ids and any extra kwargs to the parent class
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        # Store the model hyperparameters on the instance
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.type_vocab_size = type_vocab_size
        self.layer_norm_eps = layer_norm_eps
        self.use_cache = use_cache
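A minimal usage sketch (assuming a `transformers` version that still ships QDQBERT; the model was later deprecated): the defaults reproduce a bert-base-uncased-like architecture, and any field can be overridden at construction time. The smaller variant below is purely illustrative.

```python
# Minimal sketch: default config vs. a hypothetical smaller variant.
from transformers import QDQBertConfig

config = QDQBertConfig()  # bert-base-uncased-like defaults
small = QDQBertConfig(
    hidden_size=384,
    num_hidden_layers=6,
    num_attention_heads=6,
    intermediate_size=1536,
)
print(config.hidden_size, small.hidden_size)  # 768 384
```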

.\models\qdqbert\modeling_qdqbert.py

# coding=utf-8
# Copyright 2021 NVIDIA Corporation and The HuggingFace Team.
# Copyright (c) 2018-2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PyTorch QDQBERT model.
"""


import math  # math utilities
import os  # filesystem helpers
import warnings  # deprecation warnings
from typing import Dict, List, Optional, Tuple, Union  # type hints

import torch  # PyTorch
import torch.utils.checkpoint  # gradient checkpointing support
from torch import nn  # neural network modules
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss  # loss functions

from ...activations import ACT2FN  # activation function registry
from ...modeling_outputs import (  # model output dataclasses
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel  # pretrained model base class
from ...pytorch_utils import (  # PyTorch helper functions
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from ...utils import (  # shared utilities
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_pytorch_quantization_available,
    logging,
    replace_return_docstrings,
    requires_backends,
)
from .configuration_qdqbert import QDQBertConfig  # QDQBERT model configuration

logger = logging.get_logger(__name__)  # module-level logger

# soft dependency
if is_pytorch_quantization_available():  # check whether pytorch_quantization can be imported
    try:
        from pytorch_quantization import nn as quant_nn  # quantized nn layers (QuantLinear, ...)
        from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizer  # fake-quantization module
    except OSError:
        logger.error(
            "QDQBERT model are not usable since `pytorch_quantization` can't be loaded. Please try to reinstall it"
            " following the instructions here:"
            " https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
        )

_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"  # checkpoint name used in example docstrings
_CONFIG_FOR_DOC = "QDQBertConfig"  # config name used in example docstrings

QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [  # list of available pretrained checkpoints
    "google-bert/bert-base-uncased",
    # See all BERT models at https://huggingface.co/models?filter=bert
]


def load_tf_weights_in_qdqbert(model, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re  # regular expressions
        import numpy as np  # NumPy
        import tensorflow as tf  # TensorFlow
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    # Resolve the TensorFlow checkpoint path to an absolute path
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Load the weight variables from the TF checkpoint
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []

    # Iterate over the variable names and shapes
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    # Iterate over the loaded names and arrays
    for name, array in zip(names, arrays):
        # Split the variable name into its scope components
        name = name.split("/")

        # Skip optimizer state and bookkeeping variables that the pretrained model does not need
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue

        # Start traversal at the model object itself
        pointer = model

        # Walk down the module hierarchy following the name components
        for m_name in name:
            # Names like "layer_0" are split into the scope name and its index
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            # Map TF scope names onto the corresponding PyTorch attributes
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue

            # If the name carried an index (e.g. "layer_0"), descend into it
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        # Names ending in "_embeddings" point at an embedding's weight
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)  # TF kernels are transposed relative to PyTorch

        # Verify that the checkpoint array matches the parameter shape
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise

        # Copy the NumPy array into the PyTorch parameter
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)

    # Return the converted PyTorch model
    return model
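The name-traversal logic above is easiest to see on a concrete checkpoint path. A self-contained sketch of just the splitting step (not library code; the path string is illustrative):

```python
# Standalone sketch of the name-splitting step used above.
import re

def split_tf_name(tf_name: str):
    parts = []
    for m_name in tf_name.split("/"):
        if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
            # "layer_0" -> ["layer", "0", ""]; element 1 is the index used
            # to descend into e.g. model.encoder.layer[0]
            parts.append(re.split(r"_(\d+)", m_name))
        else:
            parts.append([m_name])
    return parts

print(split_tf_name("bert/encoder/layer_0/attention/output/kernel"))
# [['bert'], ['encoder'], ['layer', '0', ''], ['attention'], ['output'], ['kernel']]
```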
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert -> QDQBert
class QDQBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        # Word embedding table: vocab_size x hidden_size, with a padding index
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # Absolute position embedding table
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # Token type (segment) embedding table
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stay consistent with the TensorFlow
        # model variable name, so any TF checkpoint file can be loaded
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Dropout on the summed embeddings
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Position embedding type, "absolute" by default
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # Non-persistent buffer with position ids 0..max_position_embeddings-1, shape (1, max_position_embeddings)
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        # Non-persistent buffer of all-zero token type ids with the same shape as position_ids
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if input_ids is not None:
            # Derive the input shape from input_ids when given
            input_shape = input_ids.size()
        else:
            # Otherwise derive it from inputs_embeds, dropping the hidden dimension
            input_shape = inputs_embeds.size()[:-1]

        # Sequence length is the second dimension of the input shape
        seq_length = input_shape[1]

        # If no position_ids were passed, slice the registered buffer to the needed window
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Default token_type_ids to the registered all-zeros buffer; keeping this
        # buffered value helps users trace the model without passing token_type_ids
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # Expand the buffered zeros to match the batch and sequence dimensions
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                # Fall back to a freshly created all-zeros tensor
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        # Look up the word embeddings if they were not passed in directly
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # Look up the token type embeddings
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # Sum the word and token type embeddings
        embeddings = inputs_embeds + token_type_embeddings

        # For absolute position embeddings, add the position embeddings as well
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings

        # Normalize, then apply dropout
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        # Return the final embedding tensor
        return embeddings
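The forward pass is just three embedding lookups, a sum, LayerNorm and dropout. A minimal standalone sketch with bert-base-like sizes (the token ids are arbitrary examples):

```python
# Standalone sketch of the embedding sum with bert-base-like sizes.
import torch
import torch.nn as nn

vocab, hidden, max_pos, n_types = 30522, 768, 512, 2
word = nn.Embedding(vocab, hidden, padding_idx=0)
position = nn.Embedding(max_pos, hidden)
token_type = nn.Embedding(n_types, hidden)
norm = nn.LayerNorm(hidden, eps=1e-12)

input_ids = torch.tensor([[101, 2023, 2003, 102]])  # (batch=1, seq=4)
position_ids = torch.arange(input_ids.size(1)).unsqueeze(0)
token_type_ids = torch.zeros_like(input_ids)

embeddings = word(input_ids) + token_type(token_type_ids) + position(position_ids)
print(norm(embeddings).shape)  # torch.Size([1, 4, 768])
```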
class QDQBertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Quantized linear projection of the attention output
        self.dense = quant_nn.QuantLinear(config.hidden_size, config.hidden_size)

        # LayerNorm over the hidden states
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Dropout on the projected hidden states
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Quantizer for the local branch of the residual add
        self.add_local_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)

        # Quantizer for the residual branch of the residual add
        self.add_residual_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)

    def forward(self, hidden_states, input_tensor):
        # Project the attention output through the quantized linear layer
        hidden_states = self.dense(hidden_states)
        # Apply dropout to reduce overfitting
        hidden_states = self.dropout(hidden_states)
        # Quantize both inputs to the residual add
        add_local = self.add_local_input_quantizer(hidden_states)
        add_residual = self.add_residual_input_quantizer(input_tensor)
        # Add the quantized branches and normalize
        hidden_states = self.LayerNorm(add_local + add_residual)
        # Return the resulting hidden states
        return hidden_states
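The `TensorQuantizer` modules perform fake quantization (quantize, then immediately dequantize), so the network sees quantization error while remaining in floating point. A rough pure-PyTorch sketch of the symmetric int8 case; this is an assumption-laden illustration, not the pytorch_quantization implementation (which adds calibration, per-channel axes, learned ranges, etc.):

```python
# Rough sketch of symmetric int8 fake quantization (quantize-dequantize).
import torch

def fake_quant_int8(x: torch.Tensor) -> torch.Tensor:
    amax = x.abs().max().clamp(min=1e-8)       # per-tensor dynamic range
    scale = amax / 127.0
    return torch.clamp((x / scale).round(), -127, 127) * scale

x = torch.randn(2, 4)
print((x - fake_quant_int8(x)).abs().max())    # small quantization error
```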
# Based on transformers.models.bert.modeling_bert.BertAttention with Bert -> QDQBert
class QDQBertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Self-attention and its output projection
        self.self = QDQBertSelfAttention(config)
        self.output = QDQBertSelfOutput(config)
        # Set of attention heads that have already been pruned
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        # Find the prunable heads and the indices of the weights to keep
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune the linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update the hyperparameters and record the pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        # Run self-attention
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # Project the attention result and add the residual
        attention_output = self.output(self_outputs[0], hidden_states)
        # Append the attention weights to the outputs if they were requested
        outputs = (attention_output,) + self_outputs[1:]
        return outputs
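`prune_linear_layer` is the workhorse of `prune_heads`: it rebuilds a linear layer keeping only the rows (or columns, with `dim=1`) listed in `index`. A minimal sketch on a plain `nn.Linear`:

```python
# Minimal sketch of prune_linear_layer on a plain nn.Linear.
import torch
from transformers.pytorch_utils import prune_linear_layer

layer = torch.nn.Linear(8, 8)
index = torch.tensor([0, 1, 2, 3])           # keep the first 4 output units
pruned = prune_linear_layer(layer, index)    # dim=0: prune output features
print(pruned.weight.shape)                   # torch.Size([4, 8])
```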


class QDQBertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Quantized linear layer
        self.dense = quant_nn.QuantLinear(config.hidden_size, config.intermediate_size)
        # Intermediate activation function, resolved from a string if needed
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        # Project the hidden states through the quantized linear layer
        hidden_states = self.dense(hidden_states)
        # Apply the intermediate activation function
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class QDQBertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Quantized linear layer mapping intermediate_size back to hidden_size
        self.dense = quant_nn.QuantLinear(config.intermediate_size, config.hidden_size)

        # LayerNorm over the hidden states
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Dropout with probability config.hidden_dropout_prob, to reduce overfitting
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Quantize the inputs to the residual add, using the default input quantization descriptor
        self.add_local_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)
        self.add_residual_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)

    def forward(self, hidden_states, input_tensor):
        # Project through the quantized linear layer
        hidden_states = self.dense(hidden_states)

        # Apply dropout
        hidden_states = self.dropout(hidden_states)

        # Quantize both inputs to the residual add
        add_local = self.add_local_input_quantizer(hidden_states)
        add_residual = self.add_residual_input_quantizer(input_tensor)

        # Add the quantized branches and normalize
        hidden_states = self.LayerNorm(add_local + add_residual)

        # Return the resulting hidden states
        return hidden_states
# Based on transformers.models.bert.modeling_bert.BertLayer with Bert -> QDQBert
class QDQBertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dimension along which the sequence length lives
        self.seq_len_dim = 1
        # Self-attention block
        self.attention = QDQBertAttention(config)
        # Whether this layer is part of a decoder
        self.is_decoder = config.is_decoder
        # Whether to add cross-attention
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # Cross-attention only makes sense for decoder layers
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            # Cross-attention block
            self.crossattention = QDQBertAttention(config)
        # Feed-forward intermediate and output blocks
        self.intermediate = QDQBertIntermediate(config)
        self.output = QDQBertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        # The cached key/values for decoder uni-directional self-attention sit at positions 1 and 2 of past_key_value
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # Run self-attention over the hidden states
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        # The attention output is the first element
        attention_output = self_attention_outputs[0]

        # If this is a decoder, the last output is the self-attention cache
        if self.is_decoder:
            # Keep everything between the attention output and the cache
            outputs = self_attention_outputs[1:-1]
            # The present key/value cache for this layer
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        # For decoder layers attending to encoder hidden states
        if self.is_decoder and encoder_hidden_states is not None:
            # Cross-attention layers must have been instantiated
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # The cached cross-attention key/values sit at positions 3 and 4 of past_key_value
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            # Run cross-attention between the self-attention output and the encoder hidden states
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            # The cross-attention output is the first element
            attention_output = cross_attention_outputs[0]
            # Append the cross-attention weights (everything but the first and last elements) if requested
            outputs = outputs + cross_attention_outputs[1:-1]

            # Append the cross-attention cache to the present key/value
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # Run the feed-forward block on the attention output
        layer_output = self.feed_forward_chunk(attention_output)
        # Prepend the feed-forward output to the other outputs
        outputs = (layer_output,) + outputs

        # For decoders, return the attention key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Intermediate projection and activation
        intermediate_output = self.intermediate(attention_output)
        # Output projection with a residual add back onto the attention output
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
# Based on transformers.models.bert.modeling_bert.BertEncoder with Bert -> QDQBert
class QDQBertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config  # keep the configuration around
        # Stack of QDQBertLayer modules, one per hidden layer
        self.layer = nn.ModuleList([QDQBertLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False  # gradient checkpointing is off by default

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        # Collect hidden states per layer only if they were requested, else None
        all_hidden_states = () if output_hidden_states else None
        # Collect self-attention weights only if requested
        all_self_attentions = () if output_attentions else None
        # Collect cross-attention weights only if requested and configured
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        # Collect the per-layer key/value cache only when caching is enabled
        next_decoder_cache = () if use_cache else None
        # Iterate over the layers
        for i, layer_module in enumerate(self.layer):
            # Record the hidden states entering this layer if requested
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Per-layer head mask, if one was provided
            layer_head_mask = head_mask[i] if head_mask is not None else None
            # Per-layer cached key/values, if provided
            past_key_value = past_key_values[i] if past_key_values is not None else None

            # With gradient checkpointing enabled during training
            if self.gradient_checkpointing and self.training:
                # Caching is incompatible with checkpointing; warn and disable it
                if use_cache:
                    logger.warning_once(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
                # Run the layer inside a gradient checkpoint
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                # Otherwise call the layer directly
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            # The layer's first output becomes the new hidden states
            hidden_states = layer_outputs[0]
            # When caching, the layer's last output is its key/value cache
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            # Accumulate attention weights if requested
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # And cross-attention weights, when configured
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # Record the final hidden states if requested
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Without return_dict, return the non-None values as a tuple, in order
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        # Otherwise return a structured output with cache and cross-attentions
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
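The cache the encoder threads through its layers follows the usual `past_key_values` convention: one `(key, value)` pair per layer, each of shape `(batch, num_heads, seq_len, head_dim)`. Sketched with toy tensors:

```python
# Toy illustration of the past_key_values layout consumed above.
import torch

num_layers, batch, heads, seq, head_dim = 2, 1, 12, 4, 64
past_key_values = tuple(
    (torch.zeros(batch, heads, seq, head_dim), torch.zeros(batch, heads, seq, head_dim))
    for _ in range(num_layers)
)
print(len(past_key_values), past_key_values[0][0].shape)
# 2 torch.Size([1, 12, 4, 64])
```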
# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert -> QDQBert
class QDQBertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dense layer mapping hidden_size to hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Tanh activation
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # "Pool" the model by simply taking the hidden state of the first token
        first_token_tensor = hidden_states[:, 0]
        # Project it through the dense layer
        pooled_output = self.dense(first_token_tensor)
        # Apply the tanh activation
        pooled_output = self.activation(pooled_output)
        return pooled_output
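First-token ("[CLS]") pooling amounts to one slice followed by a dense layer and tanh, sketched with toy shapes:

```python
# Toy illustration of first-token pooling.
import torch
import torch.nn as nn

hidden_states = torch.randn(2, 5, 768)        # (batch, seq_len, hidden)
dense, activation = nn.Linear(768, 768), nn.Tanh()
pooled = activation(dense(hidden_states[:, 0]))
print(pooled.shape)                           # torch.Size([2, 768])
```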


# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert -> QDQBert
class QDQBertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dense layer mapping hidden_size to hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Activation function, given either as a string or as a callable
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        # LayerNorm with the configured epsilon
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Dense projection
        hidden_states = self.dense(hidden_states)
        # Activation
        hidden_states = self.transform_act_fn(hidden_states)
        # LayerNorm
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


# Based on transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert -> QDQBert
class QDQBertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Transform the hidden states before decoding
        self.transform = QDQBertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Link the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states
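The `decoder.bias = self.bias` link matters because `resize_token_embeddings` replaces `decoder`; keeping `bias` as the module's own parameter and re-attaching it keeps the two in sync. Weight tying itself is plain parameter sharing, sketched below:

```python
# Sketch of input/output embedding weight tying as plain parameter sharing.
import torch.nn as nn

hidden, vocab = 768, 30522
word_embeddings = nn.Embedding(vocab, hidden)
decoder = nn.Linear(hidden, vocab, bias=False)
decoder.weight = word_embeddings.weight            # same Parameter object
print(decoder.weight is word_embeddings.weight)    # True
```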


# Based on transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert -> QDQBert
class QDQBertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Prediction head for masked language modeling
        self.predictions = QDQBertLMPredictionHead(config)

    def forward(self, sequence_output):
        # Compute the prediction scores over the vocabulary
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


# Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert -> QDQBert
class QDQBertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Binary classifier: hidden_size in, 2 classes out
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        # Compute the next-sentence relationship score from the pooled output
        seq_relationship_score = self.seq_relationship(pooled_output)
        # Return the score
        return seq_relationship_score
# Based on transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert -> QDQBert
class QDQBertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Masked language modeling prediction head
        self.predictions = QDQBertLMPredictionHead(config)
        # Next-sentence prediction head with 2 output classes
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        # MLM scores from the sequence output
        prediction_scores = self.predictions(sequence_output)
        # NSP scores from the pooled output
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


# Based on transformers.models.bert.modeling_bert.BertPreTrainedModel with Bert -> QDQBert
class QDQBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Use QDQBertConfig as the configuration class
    config_class = QDQBertConfig
    # Use load_tf_weights_in_qdqbert to load TF weights
    load_tf_weights = load_tf_weights_in_qdqbert
    # Prefix for the base model's parameter names
    base_model_prefix = "bert"
    # Gradient checkpointing is supported
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version, which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
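`_init_weights` is applied module-by-module: `PreTrainedModel.post_init` ultimately walks the module tree the way `nn.Module.apply` does. A simplified standalone sketch of that mechanism:

```python
# Simplified sketch: nn.Module.apply visits every submodule, so each
# Linear/LayerNorm receives its own initialization.
import torch.nn as nn

def init_weights(module, std=0.02):  # std mirrors config.initializer_range
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

net = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
net.apply(init_weights)  # recursive traversal over both submodules
```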


QDQBERT_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`QDQBertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

QDQBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in
            `[0, 1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range
            `[0, config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    "The bare QDQBERT Model transformer outputting raw hidden-states without any specific head on top.",
    QDQBERT_START_DOCSTRING,
)
class QDQBertModel(QDQBertPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer: bool = True):
        requires_backends(self, "pytorch_quantization")
        super().__init__(config)
        self.config = config

        # Embedding layer
        self.embeddings = QDQBertEmbeddings(config)
        # Encoder stack
        self.encoder = QDQBertEncoder(config)

        # Optional pooling layer
        self.pooler = QDQBertPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Return the word embeddings from the embeddings layer
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # Replace the word embeddings in the embeddings layer
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            # Prune the requested heads in the attention block of each listed layer
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,  # input token ids
        attention_mask: Optional[torch.FloatTensor] = None,  # attention mask
        token_type_ids: Optional[torch.LongTensor] = None,  # token type ids
        position_ids: Optional[torch.LongTensor] = None,  # position ids
        head_mask: Optional[torch.FloatTensor] = None,  # head mask
        inputs_embeds: Optional[torch.FloatTensor] = None,  # precomputed input embeddings
        encoder_hidden_states: Optional[torch.FloatTensor] = None,  # encoder hidden states (for cross-attention)
        encoder_attention_mask: Optional[torch.FloatTensor] = None,  # encoder attention mask
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,  # cached key/values
        use_cache: Optional[bool] = None,  # whether to return a key/value cache
        output_attentions: Optional[bool] = None,  # whether to return attention weights
        output_hidden_states: Optional[bool] = None,  # whether to return all hidden states
        return_dict: Optional[bool] = None,  # whether to return a ModelOutput instead of a tuple
    ):
        ...  # body omitted in this excerpt
# Decorator documenting the model's purpose: QDQBERT with a language modeling head for CLM fine-tuning
@add_start_docstrings(
    """QDQBERT Model with a `language modeling` head on top for CLM fine-tuning.""", QDQBERT_START_DOCSTRING
)
class QDQBertLMHeadModel(QDQBertPreTrainedModel):
    # Keys of the parameters whose weights are tied together
    _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        # Warn if the configuration is not marked as a decoder
        if not config.is_decoder:
            logger.warning("If you want to use `QDQBertLMHeadModel` as a standalone, add `is_decoder=True.`")

        # Backbone QDQBertModel without the pooling layer
        self.bert = QDQBertModel(config, add_pooling_layer=False)

        # Language modeling head
        self.cls = QDQBertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # Return the decoder of the prediction head as the output embeddings
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        # Replace the decoder of the prediction head
        self.cls.predictions.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.LongTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        ...  # body omitted in this excerpt

    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.LongTensor],
        past_key_values=None,
        attention_mask: Optional[torch.Tensor] = None,
        **model_kwargs,
    ):
        # Shape of the input ids
        input_shape = input_ids.shape

        # If no attention mask was given, attend to every position
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # With cached key/values, only the new tokens need to be fed to the model
        if past_key_values is not None:
            # Length of the cached prefix
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input id; keep
            # whatever extends beyond the cached prefix
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default behavior: keep only the final input id
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        # Return the trimmed input_ids together with the attention mask and cache
        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

    # Reorder the cached key/values to follow the beam search indices
    def _reorder_cache(self, past_key_values, beam_idx):
        reordered_past = ()

        # Reorder the cache of every layer
        for layer_past in past_key_values:
            reordered_past += (
                # Select the batch entries given by beam_idx, on the cache's device
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )

        # Return the reordered cache
        return reordered_past
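During beam search, batch row `i` of every cached tensor must be replaced by the row of the beam it continues; `index_select(0, beam_idx)` does exactly that. A toy illustration for a single cached tensor:

```python
# Toy illustration of cache reordering for beam search.
import torch

key = torch.arange(6.0).view(3, 1, 2, 1)   # (batch=3, heads=1, seq=2, dim=1)
beam_idx = torch.tensor([2, 0, 2])         # beams 0 and 2 both continue old beam 2
print(key.index_select(0, beam_idx).flatten(1))
# tensor([[4., 5.],
#         [0., 1.],
#         [4., 5.]])
```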
@add_start_docstrings("""QDQBERT Model with a `language modeling` head on top.""", QDQBERT_START_DOCSTRING)
class QDQBertForMaskedLM(QDQBertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            # 如果配置为解码器,警告用户使用双向自注意力时需将 `config.is_decoder` 设为 False
            logger.warning(
                "If you want to use `QDQBertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # 使用配置初始化 QDQBERT 模型,禁用添加池化层
        self.bert = QDQBertModel(config, add_pooling_layer=False)
        # 初始化 MLM 头部
        self.cls = QDQBertOnlyMLMHead(config)

        # 初始化权重并进行最终处理
        self.post_init()

    def get_output_embeddings(self):
        # 返回 MLM 头部的解码器权重
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        # 设置 MLM 头部的解码器权重
        self.cls.predictions.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        # 初始化 return_dict 变量,如果 return_dict 参数非空则使用其值,否则使用 self.config.use_return_dict 的值
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 BERT 模型进行前向传播
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取 BERT 输出的序列输出
        sequence_output = outputs[0]
        
        # 使用分类层进行预测得分的计算
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        # 如果 labels 参数不为空,则计算 masked language modeling 的损失
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # 定义交叉熵损失函数,用于计算损失
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        # 如果 return_dict 为 False,则返回 tuple 类型的输出
        if not return_dict:
            output = (prediction_scores,) + outputs[2:]  # 构建输出元组
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        # 如果 return_dict 为 True,则返回 MaskedLMOutput 对象
        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids: torch.LongTensor, attention_mask: Optional[torch.FloatTensor] = None, **model_kwargs
    ):
        # 获取输入张量的形状和有效的 batch 大小
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # 如果配置文件中的 pad_token_id 为空,则抛出 ValueError 异常
        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        # 扩展 attention_mask,增加一个全零列
        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        
        # 创建一个全为 pad_token_id 的 dummy_token 张量,并将其连接到 input_ids 后面
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        # 返回包含输入张量和 attention_mask 的字典
        return {"input_ids": input_ids, "attention_mask": attention_mask}
@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top.""",
    QDQBERT_START_DOCSTRING,
)
class QDQBertForNextSentencePrediction(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Backbone QDQBertModel (with the pooling layer)
        self.bert = QDQBertModel(config)
        # Next-sentence prediction head
        self.cls = QDQBertOnlyNSPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, NextSentencePredictorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Returns:
            If `return_dict=True`, returns a `NextSentencePredictorOutput` object containing:
                - loss (`torch.FloatTensor`, *optional*): Next sentence prediction loss.
                - logits (`torch.FloatTensor` of shape `(batch_size, 2)`): Scores for next sentence prediction.
                - hidden_states (`Optional[Tuple[torch.FloatTensor]]`): Tuple of hidden states at each layer of the model.
                - attentions (`Optional[Tuple[torch.FloatTensor]]`): Tuple of attention tensors at each layer of the model.

            If `return_dict=False`, returns a tuple:
                - next_sentence_loss (`Optional[torch.FloatTensor]`): Next sentence prediction loss.
                - seq_relationship_scores (`torch.FloatTensor`): Scores for next sentence prediction.
                - hidden_states (`Optional[Tuple[torch.FloatTensor]]`): Tuple of hidden states at each layer of the model.
                - attentions (`Optional[Tuple[torch.FloatTensor]]`): Tuple of attention tensors at each layer of the model.

        Example:

        ```
        >>> from transformers import AutoTokenizer, QDQBertForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        >>> model = QDQBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```

        """

        # If `next_sentence_label` was passed via kwargs, warn and use it as `labels`
        if "next_sentence_label" in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("next_sentence_label")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        seq_relationship_scores = self.cls(pooled_output)

        next_sentence_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

        if not return_dict:
            output = (seq_relationship_scores,) + outputs[2:]
            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForSequenceClassification(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels  # number of labels from the config
        self.config = config

        self.bert = QDQBertModel(config)  # initialize QDQBertModel with the given config
        self.dropout = nn.Dropout(config.hidden_dropout_prob)  # dropout probability from the config
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)  # linear classification head: hidden_size -> num_labels
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,  # input token IDs
        attention_mask: Optional[torch.FloatTensor] = None,  # attention mask
        token_type_ids: Optional[torch.LongTensor] = None,  # token type IDs
        position_ids: Optional[torch.LongTensor] = None,  # position IDs
        head_mask: Optional[torch.FloatTensor] = None,  # head mask
        inputs_embeds: Optional[torch.FloatTensor] = None,  # precomputed input embeddings
        labels: Optional[torch.LongTensor] = None,  # classification/regression labels
        output_attentions: Optional[bool] = None,  # whether to return attention weights
        output_hidden_states: Optional[bool] = None,  # whether to return hidden states
        return_dict: Optional[bool] = None,  # whether to return a ModelOutput dict
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # Use the passed `return_dict` if it is not None, otherwise fall back to `self.config.use_return_dict`
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the inputs through the BERT model and collect its outputs
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Take the pooled representation from the BERT outputs
        pooled_output = outputs[1]

        # Apply dropout to the pooled representation
        pooled_output = self.dropout(pooled_output)

        # Feed the pooled representation to the classifier to obtain the logits
        logits = self.classifier(pooled_output)

        # The loss defaults to None
        loss = None

        # Compute the loss if labels were provided
        if labels is not None:
            # If the problem type is unset, infer it from num_labels and the label dtype
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # Compute the loss matching the problem type
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()  # mean-squared-error loss
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()  # cross-entropy loss
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()  # binary cross-entropy on the logits
                loss = loss_fct(logits, labels)

        # If return_dict is False, return a plain tuple
        if not return_dict:
            output = (logits,) + outputs[2:]  # the logits plus the remaining outputs
            return ((loss,) + output) if loss is not None else output

        # If return_dict is True, return a SequenceClassifierOutput
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
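
How `problem_type` selects the loss can be reproduced outside the model; a minimal sketch with dummy tensors (the label dtype and `num_labels` drive the dispatch, not the model itself):

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

logits = torch.randn(4, 3)

# regression: num_labels == 1, float targets
reg_loss = MSELoss()(torch.randn(4), torch.randn(4))

# single-label classification: integer class indices
cls_loss = CrossEntropyLoss()(logits.view(-1, 3), torch.tensor([0, 2, 1, 0]).view(-1))

# multi-label classification: float multi-hot targets with the same shape as the logits
multi_loss = BCEWithLogitsLoss()(logits, torch.randint(0, 2, (4, 3)).float())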


# The decorator attaches the class docstring: a multiple-choice classification head on top of the BERT pooled output
@add_start_docstrings(
    """
    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForMultipleChoice(QDQBertPreTrainedModel):
    def __init__(self, config):
        # Call the parent class constructor
        super().__init__(config)

        # Initialize the BERT model
        self.bert = QDQBertModel(config)
        # Dropout layer to reduce overfitting
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Linear classification head; input size is the hidden size, output size is 1 (one score per choice)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    # Decorators documenting the inputs of the forward method and adding a code sample
    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # inputs to the forward pass of the multiple-choice model
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # Use the passed `return_dict` if it is not None, otherwise fall back to the config default
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The number of choices is the size of the second dimension of the inputs
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Flatten the choice dimension into the batch dimension for processing
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        # Run the flattened inputs through the BERT model
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Take the pooled output
        pooled_output = outputs[1]

        # Apply dropout to the pooled output
        pooled_output = self.dropout(pooled_output)
        # Score each (example, choice) pair with the classifier
        logits = self.classifier(pooled_output)
        # Reshape the logits back to [batch_size, num_choices]
        reshaped_logits = logits.view(-1, num_choices)

        # The loss defaults to None
        loss = None
        # If labels were provided, compute the cross-entropy loss over the choices
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        # If return_dict is False, return a plain tuple
        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # If return_dict is True, return a MultipleChoiceModelOutput
        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
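
The reshaping above expects inputs of shape `[batch, num_choices, seq_len]`; a minimal sketch of how such a batch is built (the prompt/choice texts are illustrative):

from transformers import AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
prompt = "The chef put the dish in the oven"
choices = ["and baked it for an hour.", "and the moon orbited faster."]

# encode each (prompt, choice) pair, then add the num_choices dimension
enc = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
inputs = {k: v.unsqueeze(0) for k, v in enc.items()}  # shapes become [1, 2, seq_len]
# model(**inputs, labels=torch.tensor([0])) would then yield logits of shape [1, 2]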


@add_start_docstrings(
    """
    QDQBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForTokenClassification(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Initialize QDQBertModel with provided configuration
        self.bert = QDQBertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Perform forward pass through QDQBertModel
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Apply dropout to the sequence output
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Calculate CrossEntropyLoss if labels are provided
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Prepare output tuple if return_dict is False
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # Return TokenClassifierOutput if return_dict is True
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    QDQBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    QDQBERT_START_DOCSTRING,
)

class QDQBertForQuestionAnswering(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Initialize the BERT model without the pooling layer
        self.bert = QDQBertModel(config, add_pooling_layer=False)
        # Linear layer producing the span-extraction outputs
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # Forward pass, accepting the usual set of model inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        # Use the passed `return_dict` if it is not None, otherwise fall back to `self.config.use_return_dict`
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the inputs through the BERT model; the outputs include sequence_output plus auxiliary values
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Take sequence_output, the hidden states of the model's last layer
        sequence_output = outputs[0]

        # Feed sequence_output to the QA head to obtain start- and end-position logits
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()  # drop the extra dimension and make the tensor contiguous
        end_logits = end_logits.squeeze(-1).contiguous()  # drop the extra dimension and make the tensor contiguous

        total_loss = None
        # If start and end positions were given, compute the loss
        if start_positions is not None and end_positions is not None:
            # On multi-GPU setups an extra dimension may have been added; squeeze it away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Clamp positions outside of the model inputs into the valid range
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # Compute cross-entropy losses for the start and end positions
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        # If return_dict is False, return the logits and the remaining outputs as a tuple
        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]  # everything after sequence_output
            return ((total_loss,) + output) if total_loss is not None else output

        # Return a QuestionAnsweringModelOutput with the loss, the start/end logits, and the auxiliary outputs
        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
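
Assuming `start_logits`/`end_logits` came from a forward pass like the one above and `inputs` from a `tokenizer(question, context, return_tensors="pt")` call, a minimal span-decoding sketch:

import torch

start_idx = int(torch.argmax(start_logits, dim=-1)[0])
end_idx = int(torch.argmax(end_logits, dim=-1)[0])
# decode the highest-scoring span back into text
answer = tokenizer.decode(inputs["input_ids"][0, start_idx : end_idx + 1])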

.\models\qdqbert\__init__.py

# Copyright notice and license information for this file
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TYPE_CHECKING marker for import-time type checking
from typing import TYPE_CHECKING

# Import the optional-dependency exception and the lazy-module helper
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# Define the import structure of the module
_import_structure = {"configuration_qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"]}

# Check whether torch is available; raise OptionalDependencyNotAvailable if it is not
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # torch is available, so extend the import structure with the modeling module
    _import_structure["modeling_qdqbert"] = [
        "QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "QDQBertForMaskedLM",
        "QDQBertForMultipleChoice",
        "QDQBertForNextSentencePrediction",
        "QDQBertForQuestionAnswering",
        "QDQBertForSequenceClassification",
        "QDQBertForTokenClassification",
        "QDQBertLayer",
        "QDQBertLMHeadModel",
        "QDQBertModel",
        "QDQBertPreTrainedModel",
        "load_tf_weights_in_qdqbert",
    ]

# Under type checking, import the actual symbols
if TYPE_CHECKING:
    from .configuration_qdqbert import QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, QDQBertConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_qdqbert import (
            QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            QDQBertForMaskedLM,
            QDQBertForMultipleChoice,
            QDQBertForNextSentencePrediction,
            QDQBertForQuestionAnswering,
            QDQBertForSequenceClassification,
            QDQBertForTokenClassification,
            QDQBertLayer,
            QDQBertLMHeadModel,
            QDQBertModel,
            QDQBertPreTrainedModel,
            load_tf_weights_in_qdqbert,
        )

# Otherwise, set up lazy loading of the module
else:
    import sys

    # Replace the current module with a _LazyModule instance to defer the imports
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
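
A minimal sketch of the behavior this sets up: merely importing the package stays cheap, and the heavy submodule import only happens on first attribute access, handled by `_LazyModule.__getattr__`:

import transformers.models.qdqbert as qdqbert  # cheap: nothing heavy imported yet

config_cls = qdqbert.QDQBertConfig  # first attribute access triggers the real import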

.\models\qwen2\configuration_qwen2.py

# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Qwen2 model configuration"""

from ...configuration_utils import PretrainedConfig  # pretrained configuration base class
from ...utils import logging  # logging utilities


logger = logging.get_logger(__name__)  # module-level logger

QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json",
}


class Qwen2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
    Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    ```
    >>> from transformers import Qwen2Model, Qwen2Config

    >>> # Initializing a Qwen2 style configuration
    >>> configuration = Qwen2Config()

    >>> # Initializing a model from the Qwen2-7B style configuration
    >>> model = Qwen2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```

    model_type = "qwen2"  # 模型类型为 Qwen2
    keys_to_ignore_at_inference = ["past_key_values"]  # 推断时忽略的键列表

    def __init__(
        self,
        vocab_size=151936,  # vocabulary size, defaults to 151936
        hidden_size=4096,  # hidden size, defaults to 4096
        intermediate_size=22016,  # intermediate (MLP) size, defaults to 22016
        num_hidden_layers=32,  # number of hidden layers, defaults to 32
        num_attention_heads=32,  # number of attention heads, defaults to 32
        num_key_value_heads=32,  # number of key/value heads, defaults to 32
        hidden_act="silu",  # hidden activation function, defaults to silu
        max_position_embeddings=32768,  # maximum number of position embeddings, defaults to 32768
        initializer_range=0.02,  # initializer range, defaults to 0.02
        rms_norm_eps=1e-6,  # RMS normalization epsilon, defaults to 1e-6
        use_cache=True,  # whether to use the KV cache, defaults to True
        tie_word_embeddings=False,  # whether to tie the word embeddings, defaults to False
        rope_theta=10000.0,  # RoPE base frequency, defaults to 10000.0
        use_sliding_window=False,  # whether to use sliding-window attention, defaults to False
        sliding_window=4096,  # sliding window size, defaults to 4096
        max_window_layers=28,  # number of bottom layers using sliding-window attention, defaults to 28
        attention_dropout=0.0,  # attention dropout, defaults to 0.0
        **kwargs,  # additional keyword arguments
    ):
        # Store the model hyperparameters
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # For backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Store the number of key/value heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout

        # Call the parent constructor with the remaining arguments
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
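
Setting `num_key_value_heads` below `num_attention_heads` yields grouped-query attention (GQA); a sketch with illustrative (non-checkpoint) sizes:

from transformers import Qwen2Config

config = Qwen2Config(
    hidden_size=1024,
    num_hidden_layers=8,
    num_attention_heads=16,
    num_key_value_heads=4,   # 4 query heads share each KV head
    use_sliding_window=True,
    sliding_window=2048,     # attend to at most 2048 previous tokens
)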

.\models\qwen2\modeling_qwen2.py

# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and on the GPT-NeoX and OPT implementations in this library.
# It has been modified from its original forms to accommodate minor architectural differences compared to the
# models trained by the Meta AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Qwen2 model."""
import inspect  # used to inspect object signatures
import math  # math utilities
import warnings  # warning control
from typing import List, Optional, Tuple, Union  # type hints

import torch  # PyTorch
import torch.nn.functional as F  # PyTorch functional API
import torch.utils.checkpoint  # gradient checkpointing
from torch import nn  # neural-network modules
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss  # loss functions

from ...activations import ACT2FN  # activation-function mapping
from ...cache_utils import Cache, DynamicCache  # cache utilities
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa  # attention-mask helpers
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast  # model output classes
from ...modeling_utils import PreTrainedModel  # base class for pretrained models
from ...utils import (  # utility functions
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_qwen2 import Qwen2Config  # the Qwen2 configuration class


if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func  # flash-attention kernels, if available
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa: padding helpers

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)

logger = logging.get_logger(__name__)  # module-level logger


_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"  # checkpoint name used in the docs
_CONFIG_FOR_DOC = "Qwen2Config"  # configuration name used in the docs

QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "Qwen/Qwen2-7B-beta",
    # See all Qwen2 models at https://huggingface.co/models?filter=qwen2
]


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)  # per-sample count of non-padding tokens
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # indices of the non-padding positions
    max_seqlen_in_batch = seqlens_in_batch.max().item()  # longest sequence in the batch
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # cumulative sequence lengths
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
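
A quick illustration of what these three values look like for a small padded batch:

import torch

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
# indices    -> tensor([0, 1, 2, 4, 5])  positions of real tokens in the flattened batch
# cu_seqlens -> tensor([0, 3, 5])        cumulative lengths with a leading zero
# max_seqlen -> 3                        longest sequence in the batch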


# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # learnable scale, initialized to all ones
        self.variance_epsilon = eps  # epsilon added to the variance for numerical stability

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype  # remember the input dtype
        hidden_states = hidden_states.to(torch.float32)  # compute in float32 for stability
        variance = hidden_states.pow(2).mean(-1, keepdim=True)  # mean of squares over the hidden dimension
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)  # normalize by the root mean square
        return self.weight * hidden_states.to(input_dtype)  # rescale and cast back to the input dtype
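
The forward pass can be checked against the RMSNorm formula y = x / sqrt(mean(x^2) + eps) * weight; since the weight starts as all ones, the manual computation matches exactly:

import torch

x = torch.randn(2, 5, 8)
norm = Qwen2RMSNorm(hidden_size=8, eps=1e-6)
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), manual, atol=1e-6)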


# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim  # rotary dimension
        self.max_position_embeddings = max_position_embeddings  # maximum cached sequence length
        self.base = base  # base for the inverse-frequency computation
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)  # register the inverse frequencies as a buffer

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len  # remember the cached maximum sequence length
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)  # outer product of positions and inverse frequencies
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)  # duplicate the frequencies to cover the full head dimension
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)  # cache the cosines
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)  # cache the sines

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),  # cached cosines up to seq_len
            self.sin_cached[:seq_len].to(dtype=x.dtype),  # cached sines up to seq_len
        )


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]  # first half of the hidden dims
    x2 = x[..., x.shape[-1] // 2 :]  # second half of the hidden dims
    return torch.cat((-x2, x1), dim=-1)  # concatenate the negated second half with the first half
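
A tiny concrete example of the rotation:

import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x))  # tensor([-3., -4.,  1.,  2.])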


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors."""
    Args:
        q (`torch.Tensor`): 查询张量。
        k (`torch.Tensor`): 键张量。
        cos (`torch.Tensor`): 旋转嵌入的余弦部分。
        sin (`torch.Tensor`): 旋转嵌入的正弦部分。
        position_ids (`torch.Tensor`):
            查询和键张量对应的位置索引。例如,当使用 KV 缓存时,可以传递偏移的位置 id。
        unsqueeze_dim (`int`, *optional*, 默认为 1):
            'unsqueeze_dim' 参数指定沿其展开 cos[position_ids] 和 sin[position_ids] 的维度,以便它们可以正确地广播到 q 和 k 的维度。
            例如,cos[position_ids] 和 sin[position_ids] 的形状为 [batch_size, seq_len, head_dim]。
            如果 q 和 k 的形状为 [batch_size, heads, seq_len, head_dim],则设置 unsqueeze_dim=1 使得 cos[position_ids] 和 sin[position_ids] 可以广播到 q 和 k 的形状。
            类似地,如果 q 和 k 的形状为 [batch_size, seq_len, heads, head_dim],则设置 unsqueeze_dim=2。

    Returns:
        `tuple(torch.Tensor)`: 包含使用旋转位置嵌入旋转后的查询和键张量的元组。
"""
    # 按照位置索引从 cos 中选择并展开维度
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    # 按照位置索引从 sin 中选择并展开维度
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    # 计算旋转后的查询嵌入
    q_embed = (q * cos) + (rotate_half(q) * sin)
    # 计算旋转后的键嵌入
    k_embed = (k * cos) + (rotate_half(k) * sin)
    # 返回旋转后的查询和键张量
    return q_embed, k_embed
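
A shape-level sketch of applying RoPE with the rotary embedding defined above (tiny illustrative sizes):

import torch

head_dim, seq_len = 8, 4
rope = Qwen2RotaryEmbedding(head_dim, max_position_embeddings=16)
q = torch.randn(1, 2, seq_len, head_dim)  # [batch, heads, seq, head_dim]
k = torch.randn(1, 2, seq_len, head_dim)
cos, sin = rope(q, seq_len=seq_len)       # each of shape [seq_len, head_dim]
position_ids = torch.arange(seq_len).unsqueeze(0)  # [1, seq_len]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape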
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
class Qwen2MLP(nn.Module):
    # Constructor taking a configuration object
    def __init__(self, config):
        super().__init__()
        # Keep a reference to the configuration
        self.config = config
        # Hidden and intermediate sizes from the config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Gate projection: hidden_size -> intermediate_size, no bias
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # Up projection: hidden_size -> intermediate_size, no bias
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # Down projection: intermediate_size -> hidden_size, no bias
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Activation function selected by name from the config
        self.act_fn = ACT2FN[config.hidden_act]

    # Forward pass taking an input tensor x
    def forward(self, x):
        # Gate the input, apply the activation, multiply by the up projection, then project back down
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
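
This is the SwiGLU-style MLP, down(silu(gate(x)) * up(x)); a shape check with tiny illustrative sizes:

import torch
from transformers import Qwen2Config

cfg = Qwen2Config(hidden_size=32, intermediate_size=64)
mlp = Qwen2MLP(cfg)
x = torch.randn(2, 5, 32)
assert mlp(x).shape == (2, 5, 32)  # the hidden size is preserved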


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    # Unpack the shape of the hidden states
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    # If n_rep is 1, there is nothing to expand
    if n_rep == 1:
        return hidden_states
    # Expand along a new dimension, repeating each KV head n_rep times
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    # Merge the repeat dimension into the head dimension
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
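
For example, expanding 2 KV heads to match 8 attention heads (n_rep = 4):

import torch

kv = torch.randn(1, 2, 5, 16)  # [batch, num_key_value_heads=2, seq, head_dim]
assert repeat_kv(kv, n_rep=4).shape == (1, 8, 5, 16)  # expanded to 8 attention heads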


class Qwen2Attention(nn.Module):
    """
    从 'Attention Is All You Need' 论文中的多头注意力机制修改而来。修改为使用滑动窗口注意力:Longformer
    和 "Generating Long Sequences with Sparse Transformers"。
    """
    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config  # the model configuration
        self.layer_idx = layer_idx  # optional index of this layer

        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size  # hidden size from the config
        self.num_heads = config.num_attention_heads  # number of attention heads
        self.head_dim = self.hidden_size // self.num_heads  # dimension per attention head
        self.num_key_value_heads = config.num_key_value_heads  # number of key/value heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads  # query heads per KV head
        self.max_position_embeddings = config.max_position_embeddings  # maximum number of position embeddings
        self.rope_theta = config.rope_theta  # RoPE base frequency
        self.is_causal = True  # this attention is causal
        self.attention_dropout = config.attention_dropout  # attention dropout probability

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Query projection (with bias)
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        # Key projection (with bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        # Value projection (with bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        # Output projection (no bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Rotary embedding used to rotate the query/key states
        self.rotary_emb = Qwen2RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        ...  # body omitted in this excerpt


class Qwen2FlashAttention2(Qwen2Attention):
    """
    Qwen2 flash attention module, inheriting from `Qwen2Attention`. The module's weights stay untouched; the main
    change is in the forward pass, which must call the public flash-attention API correctly and deal with any
    padding tokens the input may contain. Additionally, sliding window attention is applied only to the bottom
    config.max_window_layers layers.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Remove this once Flash Attention is upgraded to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while bottom-right alignment is needed here;
        # bottom-right is the default for flash_attn>=2.1. This attribute handles the difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Note that with flash_attn<2.1, q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a
        # wrong (top-left aligned) mask.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        # Forward pass of Qwen2FlashAttention2 (body omitted in this excerpt):
        # hidden_states: input hidden states
        # attention_mask: optional attention mask
        # position_ids: optional position ids
        # past_key_value: optional cached key/value states
        # output_attentions: whether to return attention weights
        # use_cache: whether to use the cache
        # **kwargs: additional keyword arguments
        ...

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    ):
        # Flash-attention forward helper (body omitted in this excerpt):
        # query_states / key_states / value_states: projected attention states
        # attention_mask: attention mask
        # query_length: length of the query sequence
        # dropout: dropout probability, defaults to 0.0
        # softmax_scale: optional softmax scaling factor
        # use_sliding_windows: whether to use sliding-window attention, defaults to False
        ...

    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        # Unpack the batch size, KV sequence length, number of heads, and head dimension
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

        # If the KV sequence length does not match the last dimension of the attention mask
        if kv_seq_len != attention_mask.shape[-1]:
            # Trim the attention mask so that it matches the KV sequence length
            attention_mask_num_tokens = attention_mask.shape[-1]
            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]

        # Get the indices of the non-padding tokens, the cumulative sequence lengths, and the longest sequence
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

        # Re-index the key and value layers with the unpadded indices
        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)

        # If the query length equals the KV sequence length
        if query_length == kv_seq_len:
            # Re-index the query layer as well
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        # If the query length is 1
        elif query_length == 1:
            # Squeeze the query layer down to one token per sequence
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, which is inefficient.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # Trim the attention mask to the query length and unpad the query layer
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        # Return the re-indexed query/key/value layers plus the index and length metadata
        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2
class Qwen2SdpaAttention(Qwen2Attention):
    """
    Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from Qwen2Attention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        """
        Override of forward method from Qwen2Attention to adapt to SDPA API.

        Parameters:
        - hidden_states (torch.Tensor): Input tensor to the attention module.
        - attention_mask (Optional[torch.Tensor]): Mask tensor indicating which elements should be attended to.
        - position_ids (Optional[torch.LongTensor]): Tensor containing positional ids.
        - past_key_value (Optional[Cache]): Cached key value pairs from previous computations.
        - output_attentions (bool): Whether to output attention weights.
        - use_cache (bool): Whether to use cached key value pairs for future computations.

        Returns:
        - Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: Tuple containing:
            - torch.Tensor: Output tensor from the attention module.
            - Optional[torch.Tensor]: Attention weights if `output_attentions` is `True`.
            - Optional[Tuple[torch.Tensor]]: Cached key value pairs if `use_cache` is `True`.
        """
        raise NotImplementedError

QWEN2_ATTENTION_CLASSES = {
    "eager": Qwen2Attention,
    "flash_attention_2": Qwen2FlashAttention2,
    "sdpa": Qwen2SdpaAttention,
}

class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )
        # Initialize self attention mechanism based on configuration
        self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        # Initialize MLP layer
        self.mlp = Qwen2MLP(config)

        # Layer normalization for input to the layer
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Layer normalization after attention mechanism
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer, of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of shape `(batch, sequence_length)`,
                where padding elements are indicated by 0
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention tensors of all attention layers. See `attentions` under the
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up
                decoding (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        # Warn if the deprecated `padding_mask` argument was passed; it will be removed in v4.37
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. "
                "Please make sure use `attention_mask` instead.`"
            )

        # Remember the input for the residual connection
        residual = hidden_states

        # Apply the input layer norm
        hidden_states = self.input_layernorm(hidden_states)

        # Self-attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        # Add the residual connection to the attention output
        hidden_states = residual + hidden_states

        # Fully connected block
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        # Assemble the outputs
        outputs = (hidden_states,)

        # Append the attention weights if requested
        if output_attentions:
            outputs += (self_attn_weights,)

        # Append the present key/value states when caching is used
        if use_cache:
            outputs += (present_key_value,)

        return outputs


# QWEN2_START_DOCSTRING is a raw-string docstring describing the model's inheritance and basic usage
QWEN2_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Qwen2Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# The @add_start_docstrings decorator documents this class as the bare model without any task-specific head
@add_start_docstrings(
    "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
    QWEN2_START_DOCSTRING,
)
class Qwen2PreTrainedModel(PreTrainedModel):
    # The configuration class and the prefix used for the base model
    config_class = Qwen2Config
    base_model_prefix = "model"
    # Supports gradient checkpointing
    supports_gradient_checkpointing = True
    # Modules that must not be split across devices
    _no_split_modules = ["Qwen2DecoderLayer"]
    # Keys skipped during device placement
    _skip_keys_device_placement = "past_key_values"
    # Supports Flash Attention 2
    _supports_flash_attn_2 = True
    # Supports scaled dot-product attention (SDPA)
    _supports_sdpa = True
    # Supports the Cache class
    _supports_cache_class = True

    # Weight initialization: normal-initialize linear and embedding layers with the configured range
    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


# QWEN2_INPUTS_DOCSTRING is a raw-string docstring, left empty in this excerpt
QWEN2_INPUTS_DOCSTRING = r"""
"""

# The decorator documents this class: the concrete Transformer decoder built on Qwen2PreTrainedModel
@add_start_docstrings(
    "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
    QWEN2_START_DOCSTRING,
)
class Qwen2Model(Qwen2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]

    Args:
        config: Qwen2Config
    """

    # Constructor taking a Qwen2Config
    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id  # padding token index
        self.vocab_size = config.vocab_size  # vocabulary size

        # Token embedding layer, initialized from the config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Stack of decoder layers, each a Qwen2DecoderLayer
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation  # selected attention implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)  # final RMS norm

        self.gradient_checkpointing = False  # gradient checkpointing disabled by default
        # Initialize weights and apply final processing
        self.post_init()

    # Return the token embedding layer
    def get_input_embeddings(self):
        return self.embed_tokens
    # Set the input embedding layer
    def set_input_embeddings(self, value):
        # Assign the given embedding module to embed_tokens
        self.embed_tokens = value

    # Attach the (shared) QWEN2_INPUTS_DOCSTRING to the forward method
    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,  # input token ids
        attention_mask: Optional[torch.Tensor] = None,  # optional attention mask
        position_ids: Optional[torch.LongTensor] = None,  # optional position ids
        past_key_values: Optional[List[torch.FloatTensor]] = None,  # optional cached key/value states
        inputs_embeds: Optional[torch.FloatTensor] = None,  # optional precomputed input embeddings
        use_cache: Optional[bool] = None,  # whether to use the KV cache
        output_attentions: Optional[bool] = None,  # whether to return attention weights
        output_hidden_states: Optional[bool] = None,  # whether to return hidden states
        return_dict: Optional[bool] = None,  # whether to return a ModelOutput dict
    ):
        ...  # body omitted in this excerpt


class Qwen2ForCausalLM(Qwen2PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # Call the parent constructor with the configuration
        super().__init__(config)
        # Build the base Qwen2Model from the configuration
        self.model = Qwen2Model(config)
        # Vocabulary size
        self.vocab_size = config.vocab_size
        # LM head: a bias-free linear layer mapping hidden states to vocabulary logits
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Return the base model's embed_tokens, which serve as the input embeddings
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        # Replace the base model's embed_tokens with the given value
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        # Return the output embedding layer (lm_head)
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # Replace the output embedding layer (lm_head)
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        # Set the model's decoder
        self.model = decoder

    def get_decoder(self):
        # Return the model's decoder
        return self.model

    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass of the causal LM; the body is omitted in this excerpt.
        """
        pass

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        # If past_key_values were passed, handle them
        if past_key_values is not None:
            # If past_key_values is a Cache instance
            if isinstance(past_key_values, Cache):
                # Read the cached sequence length, the number of seen tokens, and the maximum cache length
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                # Otherwise read the cache length and seen tokens from the legacy tuple format
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the attention_mask is longer than input_ids, some inputs were passed exclusively as cache
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If past_length is smaller than input_ids, input_ids holds all input tokens and the first
            #     past_length tokens can be discarded
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), assume input_ids has only unprocessed tokens

            # If we are about to exceed the maximum cache length, crop the attention mask
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        # Take position_ids from kwargs; if absent but an attention_mask exists, build them on the fly
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # If inputs_embeds are passed, only use them in the first generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # Add position_ids, past_key_values, use_cache, and attention_mask to the model inputs
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
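
The on-the-fly position-id reconstruction above is easy to trace on a left-padded mask:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # two left-padding tokens
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # tensor([[1, 1, 0, 1, 2]]) -- real tokens get positions 0, 1, 2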
    # Reorder the cached past key/value states for beam search
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Start from an empty tuple of reordered past states
        reordered_past = ()
        # Iterate over the layers of the past key/value states
        for layer_past in past_key_values:
            # For each layer, select the past states along the batch dimension according to beam_idx,
            # moving beam_idx to the same device as the states
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        # Return the reordered past key/value states
        return reordered_past
"""
Qwen2模型变换器,顶部带有序列分类头(线性层)。

[`Qwen2ForSequenceClassification`] 使用最后一个标记进行分类,类似其他因果模型(例如GPT-2)的做法。

由于它在最后一个标记上进行分类,因此需要知道最后一个标记的位置。如果配置中定义了 `pad_token_id`,它会找到每行中不是填充标记的最后一个标记。如果没有定义 `pad_token_id`,则简单地取每个批次行中的最后一个值。当传递 `inputs_embeds` 而不是 `input_ids` 时,由于无法猜测填充标记,它也会采取同样的做法(取每个批次行中的最后一个值)。
"""
@add_start_docstrings(
    """
    `forward` 方法用于执行前向传播。

    参数:
    - `input_ids` (torch.LongTensor, optional): 输入的token ID序列.
    - `attention_mask` (torch.Tensor, optional): 注意力遮罩,指示哪些元素是填充的.
    - `position_ids` (torch.LongTensor, optional): 指示每个token的位置ID.
    - `past_key_values` (List[torch.FloatTensor], optional): 过去的键值对,用于缓存计算.
    - `inputs_embeds` (torch.FloatTensor, optional): 替代 `input_ids` 的嵌入表示.
    - `labels` (torch.LongTensor, optional): 分类标签.
    - `use_cache` (bool, optional): 是否使用缓存.
    - `output_attentions` (bool, optional): 是否输出注意力权重.
    - `output_hidden_states` (bool, optional): 是否输出隐藏状态.
    - `return_dict` (bool, optional): 是否返回字典格式的输出.

    返回:
    - 根据模型配置返回不同的输出,包括分类结果、注意力权重等.
    """,
    QWEN2_FORWARD_DOCSTRING,
)
class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen2Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """
        返回模型中的输入嵌入层。
        """
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """
        设置模型的输入嵌入层。
        """
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Qwen2模型的前向传播方法。
        """

.\models\qwen2\tokenization_qwen2.py

# 定义常量,指定文件名字典,包括词汇表文件和合并文件
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",     # 词汇表文件名
    "merges_file": "merges.txt",    # 合并文件名
}

# 定义预训练模型的词汇文件映射
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/vocab.json"},   # 词汇表文件映射
    "merges_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/merges.txt"},  # 合并文件映射
}

# 定义预训练模型最大输入尺寸
MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}   # 模型最大输入尺寸映射

# 定义用于预分词的正则表达式模式
PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
    # 匹配缩写词和单词、数字、非字母数字字符、空白行、空格

@lru_cache()
# 从 transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode 复制而来
def bytes_to_unicode():
    """
    返回 utf-8 字节列表及其与 Unicode 字符的映射表。避免映射到空白字符或控制字符,以避免 BPE 代码错误。

    可逆的 BPE 代码在 Unicode 字符串上工作。这意味着如果要避免 UNK(未知标记),需要在词汇表中包含大量的 Unicode 字符。
    例如,对于约 100 亿个标记的数据集,您大约需要包含 5000 个 Unicode 字符才能获得良好的覆盖率。
    """
    # 定义 Unicode 字节和 Unicode 字符映射表的起始范围
    bs = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    # 返回 Unicode 字节到字符的映射字典
    return dict(zip(bs, cs))
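这个映射把可见 ASCII 字符映射到自身,把空格、控制符等字节映射到 256 以上的码位,例如空格(字节 32)映射为 `Ġ`。示意用法(假设上面的 `bytes_to_unicode` 已定义):

```
m = bytes_to_unicode()
print(m[ord("A")])  # 'A' —— 可见字符映射到自身
print(m[32])        # 'Ġ' —— 空格被映射到 U+0120,避开空白字符
```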


# 从 transformers.models.gpt2.tokenization_gpt2.get_pairs 复制而来
def get_pairs(word):
    """
    返回单词中的符号对集合。

    单词表示为符号元组(符号是长度可变的字符串)。
    """
    # 初始化符号对集合和前一个字符
    pairs = set()
    prev_char = word[0]
    # 遍历单词中的字符,生成符号对集合
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    # 返回符号对集合
    return pairs
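例如,对表示为符号元组的单词 `("l", "o", "w")` 调用该函数,会得到所有相邻符号对的集合:

```
print(get_pairs(("l", "o", "w")))
# {('l', 'o'), ('o', 'w')}
```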


class Qwen2Tokenizer(PreTrainedTokenizer):
    """
    构建 Qwen2 分词器,基于字节级的 Byte-Pair-Encoding,继承自 PreTrainedTokenizer。

    和 GPT2Tokenizer 类似,此分词器经过训练以将空格视为标记的一部分,因此同一个单词在句子开头(没有空格)和其他位置可能会被编码成不同的标记。
    """

    vocab_files_names = VOCAB_FILES_NAMES
    # 从 transformers 库导入的文件名列表,包含词汇文件名
    
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # 预训练模型词汇文件的映射,用于指定各种预训练模型的词汇文件
    
    max_model_input_sizes = MAX_MODEL_INPUT_SIZES
    # 不同模型的最大输入长度限制,以 token 数量计算
    
    model_input_names = ["input_ids", "attention_mask"]
    # 模型输入所需的标记名称列表,包括输入 IDs 和注意力掩码
    # 初始化方法,用于创建一个新的tokenizer对象
    def __init__(
        self,
        vocab_file,  # 词汇文件路径,用于指定词汇表
        merges_file,  # 合并文件路径,用于指定BPE合并规则
        errors="replace",  # 解码错误处理方式,默认替换错误字符
        unk_token="<|endoftext|>",  # 未知标记,默认为特定的结束标记
        bos_token=None,  # 开始标记,如果指定则创建特殊的添加标记对象
        eos_token="<|endoftext|>",  # 结束标记,默认为特定的结束标记
        pad_token="<|endoftext|>",  # 填充标记,默认为特定的结束标记
        clean_up_tokenization_spaces=False,  # 是否清除标记化空格
        split_special_tokens=False,  # 是否拆分特殊标记
        **kwargs,  # 其他关键字参数
    ):
        # 如果bos_token是字符串,则创建一个特殊的添加标记对象,不进行左右剥离
        bos_token = (
            AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(bos_token, str)
            else bos_token
        )
        # 如果eos_token是字符串,则创建一个特殊的添加标记对象,不进行左右剥离
        eos_token = (
            AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(eos_token, str)
            else eos_token
        )
        # 如果unk_token是字符串,则创建一个特殊的添加标记对象,不进行左右剥离
        unk_token = (
            AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(unk_token, str)
            else unk_token
        )
        # 如果pad_token是字符串,则创建一个特殊的添加标记对象,不进行左右剥离
        pad_token = (
            AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(pad_token, str)
            else pad_token
        )

        # 从vocab_file中加载词汇表到self.encoder
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        # 创建self.decoder,将self.encoder的键值对颠倒
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # 设置解码时的错误处理方式
        self.byte_encoder = bytes_to_unicode()  # 使用字节到Unicode的编码器
        # 创建self.byte_decoder,将self.byte_encoder的键值对颠倒
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        bpe_merges = []
        # 从merges_file中读取BPE合并规则,创建bpe_merges列表
        with open(merges_file, encoding="utf-8") as merges_handle:
            for line in merges_handle:
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                bpe_merges.append(tuple(line.split()))
        # 使用BPE合并规则创建self.bpe_ranks字典
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        
        # 注意:缓存可以无限增长,对于长时间运行的进程(特别是没有空格分隔单词的语言文本,如中文),缓存可能会变得非常大。
        # 这不是内存泄漏,但看起来像是。GPT2Tokenizer也有同样的问题,因此我们保持一致。
        self.cache = {}  # 初始化缓存,用于存储tokenization的结果
        
        # 编译预处理的正则表达式模式,用于分隔文本
        self.pat = re.compile(PRETOKENIZE_REGEX)

        # 如果kwargs中包含"add_prefix_space"并且其值为True,则发出警告
        if kwargs.get("add_prefix_space", False):
            logger.warning_once(
                f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
            )

        # 调用父类的初始化方法,设置错误处理方式、开始标记、结束标记、填充标记、未知标记等
        super().__init__(
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            split_special_tokens=split_special_tokens,
            **kwargs,
        )

    @property
    # 返回词汇表大小
    def vocab_size(self) -> int:
        return len(self.encoder)
    # 从 GPT2Tokenizer 类中复制而来,返回词汇表的字典,包括编码器和添加的特殊标记编码器
    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    # 从 GPT2Tokenizer 类中复制而来,执行 BPE(字节对编码)算法,将 token 分解为 BPE tokens
    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # 找出当前最小的 bigram,根据 bpe_ranks 中的排序
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        # 将 tuple 转换为字符串,并缓存结果
        word = " ".join(word)
        self.cache[token] = word
        return word
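`bpe` 的合并循环可以在一个玩具合并表上独立演示(示意代码,不依赖真实的 vocab/merges 文件,`toy_ranks` 为假设的数据):

```
toy_ranks = {("l", "o"): 0, ("lo", "w"): 1}  # 排名越小的符号对越先被合并

def toy_bpe(token):
    word = tuple(token)
    while len(word) > 1:
        pairs = {(word[i], word[i + 1]) for i in range(len(word) - 1)}
        bigram = min(pairs, key=lambda p: toy_ranks.get(p, float("inf")))
        if bigram not in toy_ranks:  # 没有可合并的符号对时停止
            break
        first, second = bigram
        new_word, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)  # 合并该符号对
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
    return " ".join(word)

print(toy_bpe("low"))  # "low":先合并 l+o 得到 ("lo", "w"),再合并 lo+w
```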

    # 从 GPT2Tokenizer 类中复制而来,对文本进行分词处理
    def _tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            # 将 token 转换为 UTF-8 编码的字节,并用 byte_encoder 映射到 unicode 字符串,避免 BPE 的控制标记(在我们的情况下是空格)
            token = "".join(
                self.byte_encoder[b] for b in token.encode("utf-8")
            )
            # 使用 BPE 算法处理 token,将结果拆分为多个 BPE token,并添加到 bpe_tokens 中
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    # 从 GPT2Tokenizer 类中复制而来,将 token 转换为其对应的 id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    # 从 GPT2Tokenizer 类中复制而来,将 id 转换为其对应的 token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index)

    # 从 GPT2Tokenizer 类中复制而来,将 tokens 序列转换为单个字符串
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        text = "".join(tokens)
        # 使用 byte_decoder 将每个字符的字节解码为 UTF-8 字符串,并处理可能的错误
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text
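结合 `bytes_to_unicode` 的逆映射,可以验证解码过程会把 `Ġ` 还原为空格(示意代码,假设上文的 `bytes_to_unicode` 已定义):

```
byte_decoder = {v: k for k, v in bytes_to_unicode().items()}
tokens = ["Hello", "Ġworld"]
text = "".join(tokens)
print(bytearray([byte_decoder[c] for c in text]).decode("utf-8"))  # "Hello world"
```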
    # `spaces_between_special_tokens`默认为True,用于慢速标记器中的_decode,无法在其他地方配置,
    # 但对于Qwen2Tokenizer,它应该默认为False
    def decode(
        self,
        token_ids,
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = False,
        spaces_between_special_tokens: bool = False,
        **kwargs,
    ) -> str:
        # 调用父类方法来解码token_ids为字符串
        return super().decode(
            token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    # 从transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary复制而来
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # 如果保存路径不是目录,则记录错误并返回
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        # 构建词汇表文件名和合并文件名
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        # 将编码器内容以JSON格式写入词汇表文件
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        # 将BPE标记和它们的索引写入合并文件
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file

    # 准备文本进行标记化前的预处理,包括Unicode规范化和传递额外的参数
    def prepare_for_tokenization(self, text, **kwargs):
        text = unicodedata.normalize("NFC", text)
        return (text, kwargs)
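NFC 规范化会把组合字符序列合并成单一码位,保证同一文本进入分词器前只有一种字节表示(示意代码):

```
import unicodedata

decomposed = "e\u0301"  # 'e' + 组合重音符,两个码位
composed = unicodedata.normalize("NFC", decomposed)
print(composed == "\u00e9", len(decomposed), len(composed))  # True 2 1
```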

.\models\qwen2\tokenization_qwen2_fast.py

# coding=utf-8
# 版权所有 2024 年 Qwen 团队、阿里巴巴集团和 HuggingFace 公司。保留所有权利。
#
# 根据 Apache 许可证 2.0 版本(“许可证”)授权;
# 除非符合许可证的规定,否则不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件根据“原样”基础分发,
# 没有任何明示或暗示的保证或条件。
# 有关更多信息,请参阅许可证。

"""Qwen2 的标记化类。"""

from typing import Optional, Tuple

from ...tokenization_utils import AddedToken
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from .tokenization_qwen2 import Qwen2Tokenizer

logger = logging.get_logger(__name__)

# 定义词汇文件的名称映射
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
    "tokenizer_file": "tokenizer.json",
}

# 定义预训练模型所需的词汇文件映射
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/vocab.json"},
    "merges_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/merges.txt"},
    "tokenizer_file": {
        "qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/tokenizer.json"
    },
}

# 定义模型的最大输入尺寸
MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}


class Qwen2TokenizerFast(PreTrainedTokenizerFast):
    """
    构建一个“快速”的 Qwen2 分词器(基于 HuggingFace 的 *tokenizers* 库)。基于字节级的 Byte-Pair-Encoding。

    与 GPT2Tokenizer 类似,此分词器经过训练,将空格视为标记的一部分,因此一个单词在句子开头(没有空格)和其他位置将被编码为不同的标记:

    ```
    >>> from transformers import Qwen2TokenizerFast

    >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
    >>> tokenizer("Hello world")["input_ids"]
    [9707, 1879]

    >>> tokenizer(" Hello world")["input_ids"]
    [21927, 1879]
    ```
    这是预期的行为。

    此分词器继承自 [`PreTrainedTokenizerFast`],其中包含大多数主要方法。用户应参考该超类以获取有关这些方法的更多信息。
    """
    Args:
        vocab_file (`str`, *optional*):
            Path to the vocabulary file.
        merges_file (`str`, *optional*):
            Path to the merges file.
        tokenizer_file (`str`, *optional*):
            Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
            contains everything needed to load the tokenizer.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead. Not applicable to this tokenizer.
        bos_token (`str`, *optional*):
            The beginning of sequence token. Not applicable for this tokenizer.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The token used for padding, for example when batching sequences of different lengths.
    """

    # These variables define certain constants for the tokenizer configuration
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = MAX_MODEL_INPUT_SIZES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = Qwen2Tokenizer

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        **kwargs,
    ):
        """
        Initializes a new instance of the Qwen2TokenizerFast class.
        
        Args:
            vocab_file (str, optional): Path to the vocabulary file.
            merges_file (str, optional): Path to the merges file.
            tokenizer_file (str, optional): Path to tokenizers file.
            unk_token (str, optional, default="<|endoftext|>"): The unknown token.
            bos_token (str, optional): The beginning of sequence token.
            eos_token (str, optional, default="<|endoftext|>"): The end of sequence token.
            pad_token (str, optional, default="<|endoftext|>"): The padding token.
            **kwargs: Additional keyword arguments passed to the base class constructor.
        """
        
        # Set bos_token, eos_token, unk_token, and pad_token as AddedToken objects if they are strings
        bos_token = (
            AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(bos_token, str)
            else bos_token
        )
        eos_token = (
            AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(eos_token, str)
            else eos_token
        )
        unk_token = (
            AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(unk_token, str)
            else unk_token
        )
        pad_token = (
            AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
            if isinstance(pad_token, str)
            else pad_token
        )
        
        # Call the base class constructor with the provided arguments
        super().__init__(
            vocab_file,
            merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )
    # 从 transformers 库中 GPT2TokenizerFast 类的 save_vocabulary 方法复制而来
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # 调用内部的 tokenizer 模块的 save 方法,将模型保存到指定的目录中,并使用给定的前缀作为文件名前缀
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        # 返回保存的文件名组成的元组
        return tuple(files)

.\models\qwen2\__init__.py

# 版权声明和许可信息
# 本代码受 Apache 许可证 2.0 版本保护,详细信息可查阅许可证
# http://www.apache.org/licenses/LICENSE-2.0

# 引入类型检查
from typing import TYPE_CHECKING

# 引入依赖的模块和函数
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_tokenizers_available,
    is_torch_available,
)

# 定义模块的导入结构,包括配置和标记化
_import_structure = {
    "configuration_qwen2": ["QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Qwen2Config"],
    "tokenization_qwen2": ["Qwen2Tokenizer"],
}

# 检查 tokenizers 是否可用,若不可用则引发异常
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,增加快速标记化的导入结构
    _import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]

# 检查 torch 是否可用,若不可用则引发异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,增加模型的导入结构
    _import_structure["modeling_qwen2"] = [
        "Qwen2ForCausalLM",
        "Qwen2Model",
        "Qwen2PreTrainedModel",
        "Qwen2ForSequenceClassification",
    ]

# 如果正在进行类型检查,导入相应的模块和函数
if TYPE_CHECKING:
    from .configuration_qwen2 import QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP, Qwen2Config
    from .tokenization_qwen2 import Qwen2Tokenizer

    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_qwen2_fast import Qwen2TokenizerFast

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_qwen2 import (
            Qwen2ForCausalLM,
            Qwen2ForSequenceClassification,
            Qwen2Model,
            Qwen2PreTrainedModel,
        )

# 如果不是类型检查,将模块定义为延迟加载模块
else:
    import sys

    # 使用 LazyModule 将模块定义为延迟加载模块
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\rag\configuration_rag.py

# 设置文件编码为 UTF-8

# 版权声明和许可协议,此代码版权归 RAG 作者和 HuggingFace Inc. 团队所有,使用 Apache License, Version 2.0 许可
# 详细许可信息可以在 http://www.apache.org/licenses/LICENSE-2.0 获取

# 导入预训练配置类 PretrainedConfig 和辅助函数 add_start_docstrings
from ...configuration_utils import PretrainedConfig
from ...utils import add_start_docstrings

# RAG 配置文档字符串,描述了 RagConfig 类的配置信息
RAG_CONFIG_DOC = r"""
    [`RagConfig`] 存储了 *RagModel* 的配置。配置对象继承自 [`PretrainedConfig`],
    可以用于控制模型的输出。更多信息请参阅 [`PretrainedConfig`] 的文档。
"""

# 使用 add_start_docstrings 函数为 RagConfig 类添加起始文档字符串
@add_start_docstrings(RAG_CONFIG_DOC)
class RagConfig(PretrainedConfig):
    # 指定模型类型为 "rag"
    model_type = "rag"
    # 表示 RagConfig 是由其他组件组合而成
    is_composition = True

    # 构造函数,初始化 RagConfig 类的配置参数
    def __init__(
        self,
        vocab_size=None,
        is_encoder_decoder=True,
        prefix=None,
        bos_token_id=None,
        pad_token_id=None,
        eos_token_id=None,
        decoder_start_token_id=None,
        title_sep=" / ",
        doc_sep=" // ",
        n_docs=5,
        max_combined_length=300,
        retrieval_vector_size=768,
        retrieval_batch_size=8,
        dataset="wiki_dpr",
        dataset_split="train",
        index_name="compressed",
        index_path=None,
        passages_path=None,
        use_dummy_dataset=False,
        reduce_loss=False,
        label_smoothing=0.0,
        do_deduplication=True,
        exclude_bos_score=False,
        do_marginalize=False,
        output_retrieved=False,
        use_cache=True,
        forced_eos_token_id=None,
        dataset_revision=None,
        **kwargs,
    ):
        super().__init__(
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            prefix=prefix,
            vocab_size=vocab_size,
            **kwargs,
        )
        # 调用父类的初始化方法,传入各种模型配置参数和额外的关键字参数

        assert (
            "question_encoder" in kwargs and "generator" in kwargs
        ), "Config has to be initialized with question_encoder and generator config"
        # 断言确保关键字参数中包含 "question_encoder" 和 "generator",否则抛出异常信息

        question_encoder_config = kwargs.pop("question_encoder")
        # 从关键字参数中弹出 "question_encoder" 并赋值给变量 question_encoder_config
        question_encoder_model_type = question_encoder_config.pop("model_type")
        # 从 question_encoder_config 中弹出 "model_type" 并赋值给变量 question_encoder_model_type

        decoder_config = kwargs.pop("generator")
        # 从关键字参数中弹出 "generator" 并赋值给变量 decoder_config
        decoder_model_type = decoder_config.pop("model_type")
        # 从 decoder_config 中弹出 "model_type" 并赋值给变量 decoder_model_type

        from ..auto.configuration_auto import AutoConfig
        # 从自动生成的配置模块中导入 AutoConfig 类

        self.question_encoder = AutoConfig.for_model(question_encoder_model_type, **question_encoder_config)
        # 使用 AutoConfig 根据 question_encoder_model_type 和 question_encoder_config 创建 question_encoder 实例
        self.generator = AutoConfig.for_model(decoder_model_type, **decoder_config)
        # 使用 AutoConfig 根据 decoder_model_type 和 decoder_config 创建 generator 实例

        self.reduce_loss = reduce_loss
        # 将 reduce_loss 参数赋值给实例变量 self.reduce_loss
        self.label_smoothing = label_smoothing
        # 将 label_smoothing 参数赋值给实例变量 self.label_smoothing
        self.exclude_bos_score = exclude_bos_score
        # 将 exclude_bos_score 参数赋值给实例变量 self.exclude_bos_score
        self.do_marginalize = do_marginalize
        # 将 do_marginalize 参数赋值给实例变量 self.do_marginalize

        self.title_sep = title_sep
        # 将 title_sep 参数赋值给实例变量 self.title_sep
        self.doc_sep = doc_sep
        # 将 doc_sep 参数赋值给实例变量 self.doc_sep
        self.n_docs = n_docs
        # 将 n_docs 参数赋值给实例变量 self.n_docs
        self.max_combined_length = max_combined_length
        # 将 max_combined_length 参数赋值给实例变量 self.max_combined_length

        self.dataset = dataset
        # 将 dataset 参数赋值给实例变量 self.dataset
        self.dataset_split = dataset_split
        # 将 dataset_split 参数赋值给实例变量 self.dataset_split
        self.index_name = index_name
        # 将 index_name 参数赋值给实例变量 self.index_name

        self.retrieval_vector_size = retrieval_vector_size
        # 将 retrieval_vector_size 参数赋值给实例变量 self.retrieval_vector_size
        self.retrieval_batch_size = retrieval_batch_size
        # 将 retrieval_batch_size 参数赋值给实例变量 self.retrieval_batch_size
        self.passages_path = passages_path
        # 将 passages_path 参数赋值给实例变量 self.passages_path
        self.index_path = index_path
        # 将 index_path 参数赋值给实例变量 self.index_path
        self.use_dummy_dataset = use_dummy_dataset
        # 将 use_dummy_dataset 参数赋值给实例变量 self.use_dummy_dataset
        self.dataset_revision = dataset_revision
        # 将 dataset_revision 参数赋值给实例变量 self.dataset_revision

        self.output_retrieved = output_retrieved
        # 将 output_retrieved 参数赋值给实例变量 self.output_retrieved

        self.do_deduplication = do_deduplication
        # 将 do_deduplication 参数赋值给实例变量 self.do_deduplication

        self.use_cache = use_cache
        # 将 use_cache 参数赋值给实例变量 self.use_cache

        if self.forced_eos_token_id is None:
            self.forced_eos_token_id = getattr(self.generator, "forced_eos_token_id", None)
        # 如果实例变量 forced_eos_token_id 为 None,则尝试从 generator 中获取 "forced_eos_token_id" 并赋值给它

    @classmethod
    def from_question_encoder_generator_configs(
        cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
        r"""
        Instantiate a [`EncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model configuration and
        decoder model configuration.

        Returns:
            [`EncoderDecoderConfig`]: An instance of a configuration object
        """
        return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs)
        # 使用 question_encoder_config 和 generator_config 的字典形式创建一个 EncoderDecoderConfig 实例,并返回
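下面是该类方法的一个示意用法:由 DPR 问题编码器和 BART 生成器的配置组合出 RagConfig(需要联网下载配置文件,模型名仅作示例):

```
from transformers import AutoConfig, RagConfig

question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_config = AutoConfig.from_pretrained("facebook/bart-large")
rag_config = RagConfig.from_question_encoder_generator_configs(
    question_encoder_config, generator_config, n_docs=5
)
print(rag_config.generator.model_type)  # "bart"
```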

.\models\rag\modeling_rag.py

# 导入必要的库和模块
import copy  # 导入深拷贝模块
from dataclasses import dataclass  # 导入dataclass装饰器
from typing import Callable, List, Optional, Tuple, Union  # 导入类型提示

import torch  # 导入PyTorch库
from torch import nn  # 导入神经网络模块

# 导入配置相关的工具和类
from ...configuration_utils import PretrainedConfig  # 导入预训练配置类
from ...generation import BeamSearchScorer, GenerationConfig, LogitsProcessorList, StoppingCriteriaList  # 导入生成相关类和模块
from ...modeling_outputs import ModelOutput  # 导入模型输出基类
from ...modeling_utils import PreTrainedModel  # 导入预训练模型类
from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings  # 导入工具类和函数

# 导入RAG模型相关配置和检索器
from .configuration_rag import RagConfig  # 导入RAG配置类
from .retrieval_rag import RagRetriever  # 导入RAG检索器类

# 获取日志记录器
logger = logging.get_logger(__name__)

# 用于文档的配置名称
_CONFIG_FOR_DOC = "RagConfig"

@dataclass
class RetrievAugLMMarginOutput(ModelOutput):
    """
    检索增强的边际化模型输出的基类。

    """
    loss: Optional[torch.FloatTensor] = None  # 损失值,可选的浮点张量
    logits: torch.FloatTensor = None  # 对数张量,浮点张量
    doc_scores: torch.FloatTensor = None  # 文档分数,浮点张量
    past_key_values: Optional[List[torch.FloatTensor]] = None  # 过去的键值,可选的浮点张量列表
    retrieved_doc_embeds: Optional[torch.FloatTensor] = None  # 检索到的文档嵌入,可选的浮点张量
    retrieved_doc_ids: Optional[torch.LongTensor] = None  # 检索到的文档ID,可选的长整型张量
    context_input_ids: Optional[torch.LongTensor] = None  # 上下文输入ID,可选的长整型张量
    context_attention_mask: Optional[torch.LongTensor] = None  # 上下文注意力掩码,可选的长整型张量
    question_encoder_last_hidden_state: Optional[torch.FloatTensor] = None  # 问题编码器最后隐藏状态,可选的浮点张量
    question_enc_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 问题编码器隐藏状态元组,可选的浮点张量元组
    question_enc_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 问题编码器注意力元组,可选的浮点张量元组
    generator_enc_last_hidden_state: Optional[torch.FloatTensor] = None  # 生成器编码器最后隐藏状态,可选的浮点张量
    generator_enc_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 生成器编码器隐藏状态元组,可选的浮点张量元组
    generator_enc_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 生成器编码器注意力元组,可选的浮点张量元组
    generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None  # 生成器解码器隐藏状态元组,可选的浮点张量元组
    generator_dec_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 生成器解码器注意力元组,可选的浮点张量元组
    generator_cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None  # 生成器交叉注意力元组,可选的浮点张量元组

@dataclass
class RetrievAugLMOutput(ModelOutput):
    """
    检索增强的语言模型输出基类。

    """
    logits: torch.FloatTensor = None  # 对数张量,浮点张量
    doc_scores: torch.FloatTensor = None  # 文档分数,浮点张量
    past_key_values: Optional[List[torch.FloatTensor]] = None  # 过去的键值,可选的浮点张量列表
    retrieved_doc_embeds: Optional[torch.FloatTensor] = None  # 检索到的文档嵌入,可选的浮点张量
    retrieved_doc_ids: Optional[torch.LongTensor] = None  # 检索到的文档ID,可选的长整型张量
    context_input_ids: Optional[torch.LongTensor] = None  # 上下文输入ID,可选的长整型张量
    context_attention_mask: Optional[torch.LongTensor] = None  # 上下文注意力掩码,可选的长整型张量
    # 定义问题编码器的最后隐藏状态,初始值为None,类型为可选的浮点张量
    question_encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    
    # 定义问题编码器的隐藏状态列表,初始值为None,类型为可选的包含多个浮点张量的元组
    question_enc_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    
    # 定义问题编码器的注意力列表,初始值为None,类型为可选的包含多个浮点张量的元组
    question_enc_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    
    # 定义生成器编码器的最后隐藏状态,初始值为None,类型为可选的浮点张量
    generator_enc_last_hidden_state: Optional[torch.FloatTensor] = None
    
    # 定义生成器编码器的隐藏状态列表,初始值为None,类型为可选的包含多个浮点张量的元组
    generator_enc_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    
    # 定义生成器编码器的注意力列表,初始值为None,类型为可选的包含多个浮点张量的元组
    generator_enc_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    
    # 定义生成器解码器的隐藏状态列表,初始值为None,类型为可选的包含多个浮点张量的元组
    generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    
    # 定义生成器解码器的注意力列表,初始值为None,类型为可选的包含多个浮点张量的元组
    generator_dec_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    
    # 定义生成器交叉注意力列表,初始值为None,类型为可选的包含多个浮点张量的元组
    generator_cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# 定义一个自定义的 RAG 预训练模型类,继承自 PreTrainedModel
class RagPreTrainedModel(PreTrainedModel):
    r"""
    RAG models were released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP
    Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

    RAG is a retriever-augmented model that encapsulates three components: a question encoder, a dataset retriever and
    a generator. The encoder and generator are trainable while the retriever is just an indexed dataset.

    """

    # 指定配置类为 RagConfig
    config_class = RagConfig
    # 指定基础模型的前缀为 "rag"
    base_model_prefix = "rag"

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # 目前复合模型尚不支持快速初始化
        kwargs["_fast_init"] = False
        # 调用父类的 from_pretrained 方法
        return super().from_pretrained(*args, **kwargs)

    @classmethod
    def from_pretrained_question_encoder_generator(
        cls,
        question_encoder_pretrained_model_name_or_path: str = None,
        generator_pretrained_model_name_or_path: str = None,
        retriever: RagRetriever = None,
        **kwargs,
    ):
        pass  # 方法体在原文件中较长,此处省略


# 以下是 RAG 模型的文档字符串定义,描述了模型的结构和使用方法
RAG_START_DOCSTRING = r"""
        RAG is a seq2seq model which encapsulates two core components: a question encoder and a generator. During a forward
        pass, we encode the input with the question encoder and pass it to the retriever to extract relevant context
        documents. The documents are then prepended to the input. Such contextualized inputs is passed to the generator.

        The question encoder can be any *autoencoding* model, preferably [`DPRQuestionEncoder`], and the generator can be
        any *seq2seq* model, preferably [`BartForConditionalGeneration`].

        The model can be initialized with a [`RagRetriever`] for end-to-end generation or used in combination with the
        outputs of a retriever in multiple steps---see examples for more details. The model is compatible any
        *autoencoding* model as the `question_encoder` and any *seq2seq* model with language model head as the `generator`.
        It has been tested with [`DPRQuestionEncoder`] as the `question_encoder` and [`BartForConditionalGeneration`] or
        [`T5ForConditionalGeneration`] as the `generator`.

        This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
        library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
        etc.)

        This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
        Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
        and behavior.
        """
    Args:
        config ([`RagConfig`]):
            模型配置类,包含模型的所有参数。通过配置文件初始化不会加载与模型相关的权重,只加载配置信息。
            若要加载模型权重,请查看 [`~PreTrainedModel.from_pretrained`] 方法。
        question_encoder ([`PreTrainedModel`]):
            编码器模型,与由 `retriever` 封装的 faiss 索引兼容。
        generator ([`PreTrainedModel`]):
            在 RAG 结构中用作生成器的 seq2seq 模型。
        retriever ([`RagRetriever`]):
            检索器类,封装了一个 faiss 索引,用于查询获取当前输入的上下文文档。
"""
"""


RAG_FORWARD_INPUTS_DOCSTRING = r"""
"""


@add_start_docstrings_to_model_forward(RAG_START_DOCSTRING)
class RagModel(RagPreTrainedModel):
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        question_encoder: Optional[PreTrainedModel] = None,
        generator: Optional[PreTrainedModel] = None,
        retriever: Optional[RagRetriever] = None,  # or maybe just use a `set_retriever(...)` method
        **kwargs,
    ):
        assert config is not None or (
            question_encoder is not None and generator is not None
        ), "Either a configuration or an question_encoder and a generator has to be provided."

        if config is None:
            # Constructing a RagConfig object from provided question_encoder and generator configurations
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )
        else:
            assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}"
        super().__init__(config)
        
        if question_encoder is None:
            # If question_encoder is not provided, instantiate a default model using AutoModel
            from ..auto.modeling_auto import AutoModel
            question_encoder = AutoModel.from_config(config.question_encoder)

        if generator is None:
            # If generator is not provided, instantiate a default Seq2SeqLM model using AutoModelForSeq2SeqLM
            from ..auto.modeling_auto import AutoModelForSeq2SeqLM
            generator = AutoModelForSeq2SeqLM.from_config(config.generator)

        self.retriever = retriever
        if self.retriever is not None:
            # Ensure retriever is of type RagRetriever
            assert isinstance(
                retriever, RagRetriever
            ), f"`self.retriever` is of type {type(self.retriever)}, but should be of type `RagRetriever`"
            self.retriever = retriever

        self.question_encoder = question_encoder
        self.generator = generator

        # Initialize context encoder attributes
        self.ctx_encoder = None
        self.context_encoder_training = False

    @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=RetrievAugLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        doc_scores: Optional[torch.FloatTensor] = None,
        context_input_ids: Optional[torch.LongTensor] = None,
        context_attention_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_retrieved: Optional[bool] = None,
        n_docs: Optional[int] = None,
    ):
        """
        Perform a forward pass through the RAG model.

        This method implements specific marginalization for RAG-sequence models.

        Args:
            input_ids (Optional[torch.LongTensor]): Input tensor of token indices.
            attention_mask (Optional[torch.Tensor]): Mask tensor indicating which elements in the input should be attended to.
            encoder_outputs (Optional[Tuple[Tuple[torch.FloatTensor]]]): Outputs of the encoder.
            decoder_input_ids (Optional[torch.LongTensor]): Input tensor for decoder.
            decoder_attention_mask (Optional[torch.BoolTensor]): Mask tensor for decoder attention.
            past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]]): Cached key-values for faster decoding.
            doc_scores (Optional[torch.FloatTensor]): Scores indicating relevance of retrieved documents.
            context_input_ids (Optional[torch.LongTensor]): Tensor of token indices for context.
            context_attention_mask (Optional[torch.LongTensor]): Mask tensor for context attention.
            use_cache (Optional[bool]): Whether to use cached values.
            output_attentions (Optional[bool]): Whether to output attention weights.
            output_hidden_states (Optional[bool]): Whether to output hidden states.
            output_retrieved (Optional[bool]): Whether to output retrieved documents.
            n_docs (Optional[int]): Number of documents to retrieve.

        Returns:
            RetrievAugLMOutput: Object containing the model outputs.
        """
        pass
@add_start_docstrings_to_model_forward(
    """
    A RAG-sequence model implementation. It performs RAG-sequence specific marginalization in the forward pass.
    """,
    RAG_START_DOCSTRING,
)
# 定义一个继承自 RagPreTrainedModel 的类,用于实现 RAG-sequence 生成模型
class RagSequenceForGeneration(RagPreTrainedModel):
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        question_encoder: Optional[PreTrainedModel] = None,
        generator: Optional[PreTrainedModel] = None,
        retriever: Optional[RagRetriever] = None,
        **kwargs,
    ):
        # 断言语句,要求提供配置信息或者问题编码器和生成器的组合之一
        assert config is not None or (
            question_encoder is not None and generator is not None
        ), "Either a configuration or an encoder and a generator has to be provided."

        # 如果未提供配置信息,则根据提供的问题编码器和生成器配置生成一个 RagConfig 对象
        if config is None:
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )
        # 调用父类的初始化方法,传入配置信息
        super().__init__(config)

        # 实例化 RAG 模型,传入配置信息、问题编码器、生成器和检索器
        self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)

    # 设置模型的检索器
    def set_retriever(self, retriever: RagRetriever):
        self.rag.retriever = retriever

    # 设置用于训练的上下文编码器
    def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
        self.rag.context_encoder_training = True
        self.rag.ctx_encoder = ctx_encoder

    # 前向传播方法,接收多个输入参数,详细的参数说明由装饰器 @add_start_docstrings_to_model_forward 和 @replace_return_docstrings 提供
    @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=RetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        context_input_ids: Optional[torch.LongTensor] = None,
        context_attention_mask: Optional[torch.LongTensor] = None,
        doc_scores: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_retrieved: Optional[bool] = None,
        exclude_bos_score: Optional[bool] = None,
        reduce_loss: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        n_docs: Optional[int] = None,
        **kwargs,  # 需要传递给生成过程的额外参数
    ):
        pass  # 前向传播先调用 self.rag(...),再在其输出上计算序列级损失,具体实现此处省略

    # 返回模型的检索器属性
    @property
    def retriever(self):
        return self.rag.retriever

    # 返回模型的生成器属性
    @property
    def generator(self):
        return self.rag.generator

    # 返回模型的问题编码器属性
    @property
    def question_encoder(self):
        return self.rag.question_encoder

    # 使用 torch.no_grad 装饰器,表示该方法不需要计算梯度信息
    @torch.no_grad()
    # 定义一个生成方法,用于生成文本序列。
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,  # 输入序列的索引张量,可以为空
        attention_mask: Optional[torch.LongTensor] = None,  # 注意力掩码张量,可以为空
        context_input_ids: Optional[torch.LongTensor] = None,  # 上下文输入序列的索引张量,可以为空
        context_attention_mask: Optional[torch.LongTensor] = None,  # 上下文输入的注意力掩码张量,可以为空
        doc_scores: Optional[torch.FloatTensor] = None,  # 文档评分张量,可以为空
        do_deduplication: Optional[bool] = None,  # 是否去重,默认为True
        num_return_sequences: Optional[int] = None,  # 返回序列的数量,默认为1
        num_beams: Optional[int] = None,  # Beam搜索中的Beam大小,默认为1
        n_docs: Optional[int] = None,  # 文档数量,可以为空
        **model_kwargs,  # 其他模型相关参数,接收任意关键字参数
    ):
        pass  # generate 的具体实现较长,此处省略

    # 定义一个计算负对数似然(Negative Log-Likelihood,NLL)的方法
    def get_nll(
        self,
        seq_logits,  # 序列的logits,用于计算NLL
        doc_scores,  # 文档评分,用于加权序列NLL
        target,  # 目标序列,用于计算NLL
        reduce_loss=False,  # 是否减少损失,默认为False
        epsilon=0.0,  # 平滑项,用于数值稳定性,默认为0.0
        exclude_bos_score=False,  # 是否排除起始标记得分,默认为False
        n_docs=None,  # 文档数量,可以为空
    ):
        # shift tokens left
        target = torch.cat(
            [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1
        )

        # Determine the number of documents to consider, defaulting to self.config.n_docs if not specified
        n_docs = n_docs if n_docs is not None else self.config.n_docs

        # Determine the beginning of sequence token ID (`bos_token_id`) based on model configuration
        bos_token_id = self.config.bos_token_id or self.config.generator.bos_token_id
        use_bos = bos_token_id is not None and target[:, 0].eq(bos_token_id).all()

        def _mask_pads(ll, smooth_obj):
            # Create a mask for padding tokens in the target sequence
            pad_mask = target.eq(self.config.generator.pad_token_id)
            if pad_mask.any():
                # Apply the mask to log-likelihood and smoothing objective
                ll.masked_fill_(pad_mask, 0.0)
                smooth_obj.masked_fill_(pad_mask, 0.0)
            return ll.squeeze(-1), smooth_obj.squeeze(-1)

        # Compute log softmax over sequence logits and reshape for RAG sequence marginalization
        seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
            seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
        )  # batch_size x n_docs x tgt_len x #vocab_size
        doc_logprobs = nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)

        # RAG-sequence marginalization
        first_token_scores = seq_logprobs[:, :, :1, :]
        second_token_scores = seq_logprobs[:, :, 1:2, :]
        remainder = seq_logprobs[:, :, 2:, :]
        rag_logprobs = torch.cat([first_token_scores, second_token_scores + doc_logprobs, remainder], dim=2)

        # Ensure target tensor matches dimensions of rag_logprobs for indexing
        target = target.unsqueeze(1).unsqueeze(-1).repeat(1, n_docs, 1, 1)
        assert target.dim() == rag_logprobs.dim()

        # Gather log probabilities corresponding to target indices and apply padding mask
        ll = rag_logprobs.gather(dim=-1, index=target)
        smooth_obj = rag_logprobs.sum(dim=-1, keepdim=True)  # total sum of all (normalized) logits

        # Apply padding mask to log-likelihood and smoothing objective
        ll, smooth_obj = _mask_pads(ll, smooth_obj)

        # Sum over tokens to compute loss, optionally excluding beginning of sequence token
        ll = ll[:, :, 1:].sum(2) if exclude_bos_score and use_bos else ll.sum(2)
        smooth_obj = smooth_obj.sum(2)
        ll = ll.logsumexp(1)  # logsumexp over docs
        smooth_obj = smooth_obj.logsumexp(1)

        # Calculate negative log-likelihood (nll) loss and smoothed loss
        nll_loss = -ll
        smooth_loss = -smooth_obj

        # Optionally reduce loss across batches
        if reduce_loss:
            nll_loss = nll_loss.sum()
            smooth_loss = smooth_loss.sum()

        # Compute final loss using nll_loss, smooth_loss, and epsilon for smoothing
        eps_i = epsilon / rag_logprobs.size(-1)
        loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
        return loss

    @staticmethod
    def _cat_and_pad(tensors, pad_token_id):
        # Concatenate tensors into a padded tensor with specified pad_token_id
        output = (
            tensors[0].new(sum([t.shape[0] for t in tensors]), max([t.shape[1] for t in tensors])).fill_(pad_token_id)
        )
        ind = 0
        for t in tensors:
            output[ind : ind + t.shape[0], : t.shape[1]] = t
            ind += t.shape[0]
        return output
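`_cat_and_pad` 把长度不一的若干批次拼接成统一宽度、右侧以 pad_token_id 填充的张量,效果如下(示意代码,直接调用上面的静态方法):

```
import torch

a = torch.tensor([[1, 2, 3]])
b = torch.tensor([[4, 5]])
print(RagSequenceForGeneration._cat_and_pad([a, b], pad_token_id=0))
# tensor([[1, 2, 3],
#         [4, 5, 0]])
```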
"""
一个实现了RAG-token模型的类。在前向传播中执行了RAG-token特定的边缘化操作。
"""
@add_start_docstrings_to_model_forward(
    """
    A RAG-token model implementation. It performs RAG-token specific marginalization in the forward pass.
    """,
    RAG_START_DOCSTRING,
)
class RagTokenForGeneration(RagPreTrainedModel):
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        question_encoder: Optional[PreTrainedModel] = None,
        generator: Optional[PreTrainedModel] = None,
        retriever: Optional[RagRetriever] = None,
        **kwargs,
    ):
        # 断言:确保提供了配置或者问题编码器和生成器的组合
        assert config is not None or (
            question_encoder is not None and generator is not None
        ), "Either a configuration or an encoder and a generator has to be provided."

        # 如果没有提供配置,则根据问题编码器和生成器的配置创建RAG配置对象
        if config is None:
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )

        # 调用父类初始化方法
        super().__init__(config)

        # 实例化RAG模型
        self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)

    # 设置检索器
    def set_retriever(self, retriever: RagRetriever):
        self.rag.retriever = retriever

    # 设置用于训练的上下文编码器
    def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
        self.rag.context_encoder_training = True
        self.rag.ctx_encoder = ctx_encoder

    # 准备生成的输入
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        encoder_outputs=None,
        doc_scores=None,
        n_docs=None,
        **kwargs,
    ):
        # 如果已经定义了过去的键值对,则只使用最后一个decoder_input_ids
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,
            "encoder_outputs": encoder_outputs,
            "doc_scores": doc_scores,
            "context_attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "do_marginalize": True,
            "n_docs": n_docs,
        }

    # 检索器的属性
    @property
    def retriever(self):
        return self.rag.retriever

    # 生成器的属性
    @property
    def generator(self):
        return self.rag.generator

    # 问题编码器的属性
    @property
    def question_encoder(self):
        return self.rag.question_encoder

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""

        def _reorder_stacked(hidden_states, new_order):
            # 计算每个文档的数量
            n_docs = hidden_states.shape[0] // new_order.shape[0]
            # 将隐藏状态重塑为 [batch_size, n_docs, ...] 的形状
            hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
            # 根据新的顺序索引选择隐藏状态
            hidden_states = hidden_states.index_select(0, new_order)
            # 恢复原来的形状
            result = hidden_states.view(-1, *hidden_states.shape[2:])
            return result

        # 初始化重新排序后的缓存
        reordered_past = ()
        # 遍历每一层的缓存
        for layer_past in past_key_values:
            # 对每个缓存状态重新排序,并添加到结果中
            reordered_past += (
                tuple(_reorder_stacked(past_state, beam_idx.to(past_state.device)) for past_state in layer_past),
            )

        return reordered_past

    def marginalize(self, seq_logits, doc_scores, n_docs=None):
        # 如果未提供 n_docs,则使用默认值 self.config.n_docs
        n_docs = n_docs if n_docs is not None else self.config.n_docs

        # 对序列的 logits 进行 log_softmax,并重塑为 [batch_size / n_docs, n_docs, ..., num_labels]
        seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
            seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
        )
        # 对文档分数进行 log_softmax
        doc_logprobs = torch.log_softmax(doc_scores, dim=1)
        # 计算序列 log_probs 和文档 log_probs 的和,并进行 logsumexp 运算
        log_prob_sum = seq_logprobs + doc_logprobs.unsqueeze(-1).unsqueeze(-1)
        return torch.logsumexp(log_prob_sum, dim=1)
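边缘化的结果在每个位置仍是合法的对数概率分布:对词表维度做 logsumexp 应为 0。可以用随机张量验证(示意代码):

```
import torch

batch, n_docs, tgt_len, vocab = 2, 3, 4, 5
seq_logits = torch.randn(batch * n_docs, tgt_len, vocab)
doc_scores = torch.randn(batch, n_docs)

seq_logprobs = torch.log_softmax(seq_logits, dim=-1).view(batch, n_docs, tgt_len, vocab)
doc_logprobs = torch.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)
marginalized = torch.logsumexp(seq_logprobs + doc_logprobs, dim=1)  # [batch, tgt_len, vocab]

# log sum_d p(d) * p(w|d) 对词表求和恒为 1,因此 logsumexp 应接近 0
print(torch.logsumexp(marginalized, dim=-1).abs().max() < 1e-5)  # tensor(True)
```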

    @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=RetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        context_input_ids: Optional[torch.LongTensor] = None,
        context_attention_mask: Optional[torch.LongTensor] = None,
        doc_scores: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_retrieved: Optional[bool] = None,
        do_marginalize: Optional[bool] = None,
        reduce_loss: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        n_docs: Optional[int] = None,
        **kwargs,  # 需要用于生成的其他参数
    ):
        pass  # forward 的具体实现较长,原文件中此处省略

    # generate 使用 torch.no_grad() 装饰,确保生成过程中不计算梯度
    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        context_input_ids: Optional[torch.LongTensor] = None,
        context_attention_mask: Optional[torch.LongTensor] = None,
        doc_scores: Optional[torch.FloatTensor] = None,
        n_docs: Optional[int] = None,
        generation_config: Optional[GenerationConfig] = None,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
        logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
        stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
        **kwargs,
    ):
        """
        Generate function for the model to generate text outputs based on given inputs.
        """
        # Implementation details are encapsulated in the class and not commented here.

    def get_input_embeddings(self):
        """
        Retrieve input embeddings from the RAG generator.
        """
        return self.rag.generator.get_input_embeddings()

    def get_output_embeddings(self):
        """
        Retrieve output embeddings from the RAG generator.
        """
        return self.rag.generator.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """
        Set new output embeddings for the RAG generator.
        """
        return self.rag.generator.set_output_embeddings(new_embeddings)

    def shift_tokens_right(self, input_ids, start_token_id=None):
        """
        Shift input ids one token to the right, and pad with start_token_id.
        """
        if start_token_id is None:
            start_token_id = self.config.decoder_start_token_id
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = start_token_id
        return shifted_input_ids
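`shift_tokens_right` 的效果是把序列整体右移一位,首位填入起始 token(示意代码,start_token_id 取 0 仅作演示):

```
import torch

input_ids = torch.tensor([[5, 6, 7]])
shifted = input_ids.new_zeros(input_ids.shape)
shifted[:, 1:] = input_ids[:, :-1].clone()
shifted[:, 0] = 0  # start_token_id
print(shifted)  # tensor([[0, 5, 6]])
```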

    def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, n_docs=None):
        """
        Calculate negative log likelihood loss for sequence logits and document scores.
        """
        n_docs = n_docs if n_docs is not None else self.config.n_docs

        # Shift tokens left and handle padding
        target = torch.cat(
            [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1
        )

        def _mask_pads(ll, smooth_obj):
            """
            Mask padding tokens in loss calculations.
            """
            pad_mask = target.eq(self.config.generator.pad_token_id)
            if pad_mask.any():
                ll.masked_fill_(pad_mask, 0.0)
                smooth_obj.masked_fill_(pad_mask, 0.0)
            return ll.squeeze(-1), smooth_obj.squeeze(-1)

        # Marginalize logits and calculate log probabilities
        rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs)

        target = target.unsqueeze(-1)
        assert target.dim() == rag_logprobs.dim()

        # Gather log probabilities based on target indices
        ll = rag_logprobs.gather(dim=-1, index=target)
        smooth_obj = rag_logprobs.sum(dim=-1, keepdim=True)  # total sum of all (normalised) logits
        ll, smooth_obj = _mask_pads(ll, smooth_obj)
        ll = ll.sum(1)  # sum over tokens
        smooth_obj = smooth_obj.sum(1)

        # Compute final negative log likelihood loss and smooth loss
        nll_loss = -ll
        smooth_loss = -smooth_obj

        if reduce_loss:
            nll_loss = nll_loss.sum()
            smooth_loss = smooth_loss.sum()

        eps_i = epsilon / rag_logprobs.size(-1)
        loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
        return loss
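
To see the marginalization and the gather step end to end, here is a self-contained sketch with random tensors (shapes chosen arbitrarily; the marginalization is re-implemented inline following the RAG-token formula log p(y_t) = logsumexp_d [log p(d|x) + log p(y_t|x, d)], so this mirrors rather than calls the method above):

import torch
import torch.nn.functional as F

batch_size, n_docs, seq_len, vocab = 2, 3, 4, 7
seq_logits = torch.randn(batch_size * n_docs, seq_len, vocab)
doc_scores = torch.randn(batch_size, n_docs)
target = torch.randint(0, vocab, (batch_size, seq_len))

# Marginalize over documents in log space
seq_logprobs = F.log_softmax(seq_logits, dim=-1).view(batch_size, n_docs, seq_len, vocab)
doc_logprobs = F.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)
rag_logprobs = torch.logsumexp(seq_logprobs + doc_logprobs, dim=1)  # (batch, seq_len, vocab)

# Token-level log likelihood via gather, summed over the sequence as above
ll = rag_logprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
nll_loss = -ll.sum(1)  # one value per example
print(nll_loss.shape)  # torch.Size([2])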

.\models\rag\modeling_tf_rag.py

# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TFRAG model implementation."""


from __future__ import annotations

import copy
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf

# Import the shared modules and helpers
from ...configuration_utils import PretrainedConfig
from ...generation import TFLogitsProcessorList
from ...modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFModelInputType,
    TFPreTrainedModel,
    keras,
    shape_list,
    unpack_inputs,
)
from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_rag import RagConfig
from .retrieval_rag import RagRetriever

# Get the logger
logger = logging.get_logger(__name__)

# Config name used in the documentation
_CONFIG_FOR_DOC = "RagConfig"

# TFRetrievAugLMMarginOutput, derived from ModelOutput
@dataclass
class TFRetrievAugLMMarginOutput(ModelOutput):
    """
    Base class for retriever-augmented marginalized language model outputs.

    """

    loss: tf.Tensor | None = None  # optional loss tensor
    logits: tf.Tensor = None  # prediction logits
    past_key_values: List[tf.Tensor] | None = None  # optional list of cached key/value tensors
    doc_scores: tf.Tensor | None = None  # optional document scores
    retrieved_doc_embeds: tf.Tensor | None = None  # optional embeddings of the retrieved documents
    retrieved_doc_ids: tf.Tensor | None = None  # optional ids of the retrieved documents
    context_input_ids: tf.Tensor | None = None  # optional context input ids
    context_attention_mask: tf.Tensor | None = None  # optional context attention mask
    question_encoder_last_hidden_state: tf.Tensor | None = None  # optional last hidden state of the question encoder
    question_enc_hidden_states: Tuple[tf.Tensor, ...] | None = None  # optional question encoder hidden states
    question_enc_attentions: Tuple[tf.Tensor, ...] | None = None  # optional question encoder attentions
    generator_enc_last_hidden_state: tf.Tensor | None = None  # optional last hidden state of the generator encoder
    generator_enc_hidden_states: Tuple[tf.Tensor, ...] | None = None  # optional generator encoder hidden states
    generator_enc_attentions: Tuple[tf.Tensor, ...] | None = None  # optional generator encoder attentions
    generator_dec_hidden_states: Tuple[tf.Tensor, ...] | None = None  # optional generator decoder hidden states
    generator_dec_attentions: Tuple[tf.Tensor, ...] | None = None  # optional generator decoder attentions

# TFRetrievAugLMOutput, derived from ModelOutput
@dataclass
class TFRetrievAugLMOutput(ModelOutput):
    """
    Base class for retriever-augmented language model outputs (the same fields as above, without the loss).
    """

    logits: tf.Tensor = None  # prediction logits
    past_key_values: List[tf.Tensor] | None = None  # optional list of cached key/value tensors
    doc_scores: tf.Tensor | None = None  # optional document scores
    retrieved_doc_embeds: tf.Tensor | None = None  # optional embeddings of the retrieved documents
    retrieved_doc_ids: tf.Tensor | None = None  # optional ids of the retrieved documents
    context_input_ids: tf.Tensor | None = None  # optional context input ids
    context_attention_mask: tf.Tensor | None = None  # optional context attention mask
    question_encoder_last_hidden_state: tf.Tensor | None = None  # optional last hidden state of the question encoder
    question_enc_hidden_states: Tuple[tf.Tensor, ...] | None = None  # optional question encoder hidden states
    question_enc_attentions: Tuple[tf.Tensor, ...] | None = None  # optional question encoder attentions
    generator_enc_last_hidden_state: tf.Tensor | None = None  # optional last hidden state of the generator encoder
    generator_enc_hidden_states: Tuple[tf.Tensor, ...] | None = None  # optional generator encoder hidden states
    generator_enc_attentions: Tuple[tf.Tensor, ...] | None = None  # optional generator encoder attentions
    generator_dec_hidden_states: Tuple[tf.Tensor, ...] | None = None  # optional generator decoder hidden states
    generator_dec_attentions: Tuple[tf.Tensor, ...] | None = None  # optional generator decoder attentions


# TFRagPreTrainedModel: base class for TF RAG models, derived from TFPreTrainedModel
class TFRagPreTrainedModel(TFPreTrainedModel):
    # Class docstring describing what RAG models are made of
    r"""
    RAG models were released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP
    Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

    RAG is a retriever-augmented model that encapsulates three components: a question encoder, a dataset retriever
    and a generator; the encoder and generator are trainable while the retriever is just an indexed dataset.
    """

    # Configuration class for this model
    config_class = RagConfig
    # Prefix of the base model
    base_model_prefix = "rag"
    # Keys to ignore when weights are missing at load time
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    # Class method that builds an instance from pretrained question encoder and generator checkpoints
    @classmethod
    def from_pretrained_question_encoder_generator(
        cls,
        question_encoder_pretrained_model_name_or_path: str = None,
        generator_pretrained_model_name_or_path: str = None,
        retriever: RagRetriever = None,
        *model_args,
        **kwargs,
    ):
        pass  # method body omitted in this excerpt


RAG_START_DOCSTRING = r"""
    Args:
        config ([`RagConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
        question_encoder ([`TFPreTrainedModel`]):
            An encoder model compatible with the faiss index encapsulated by the `retriever`.
        generator ([`TFPreTrainedModel`]):
            A seq2seq model used as the generator in the RAG architecture.
        retriever ([`RagRetriever`]):
            A retriever class encapsulating a faiss index queried to obtain context documents for current inputs.
"""


RAG_FORWARD_INPUTS_DOCSTRING = r"""
"""


@add_start_docstrings_to_model_forward(RAG_START_DOCSTRING)
class TFRagModel(TFRagPreTrainedModel):
    load_weight_prefix = "tf_rag_model_1"

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        question_encoder: Optional[TFPreTrainedModel] = None,
        generator: Optional[TFPreTrainedModel] = None,
        retriever: Optional[RagRetriever] = None,
        load_weight_prefix: Optional[str] = None,
        **kwargs,
    ):
        assert config is not None or (
            question_encoder is not None and generator is not None
        ), "Either a configuration or a question_encoder and a generator has to be provided."

        if config is None:
            # Build a RagConfig from the question encoder and generator configs
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )
        else:
            assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}"
        super().__init__(config, **kwargs)

        if question_encoder is None:
            # If no question encoder is given, instantiate one from the config via TFAutoModel
            from ..auto.modeling_tf_auto import TFAutoModel

            question_encoder = TFAutoModel.from_config(config.question_encoder, name="question_encoder")

        if generator is None:
            # If no generator is given, instantiate one from the config via TFAutoModelForSeq2SeqLM
            from ..auto.modeling_tf_auto import TFAutoModelForSeq2SeqLM

            load_weight_prefix = load_weight_prefix if load_weight_prefix is not None else self.load_weight_prefix
            generator = TFAutoModelForSeq2SeqLM.from_config(
                config.generator, name="generator", load_weight_prefix=load_weight_prefix + "/generator"
            )

        self.retriever = retriever
        if self.retriever is not None:
            # If a retriever is provided, make sure it is a RagRetriever instance
            assert isinstance(
                retriever, RagRetriever
            ), f"`self.retriever` is of type {type(self.retriever)}, but should be of type `RagRetriever`"
            self.retriever = retriever

        self.question_encoder = question_encoder
        self.generator = generator

    def set_retriever(self, retriever: RagRetriever):
        # Replace the current retriever
        self.retriever = retriever

    @unpack_inputs
    @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFRetrievAugLMOutput, config_class=_CONFIG_FOR_DOC)
    # `call` is the model's forward/inference entry point
    def call(
        self,
        input_ids: TFModelInputType | None = None,  # input token ids
        attention_mask: np.ndarray | tf.Tensor | None = None,  # mask selecting which input positions are attended to
        encoder_outputs: np.ndarray | tf.Tensor | None = None,  # precomputed encoder outputs
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,  # decoder input token ids
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,  # decoder attention mask
        past_key_values: Tuple[Tuple[Union[np.ndarray, tf.Tensor]]] | None = None,  # cached key/value states for fast decoding
        doc_scores: np.ndarray | tf.Tensor | None = None,  # document scores
        context_input_ids: np.ndarray | tf.Tensor | None = None,  # ids of the retrieved context documents
        context_attention_mask: np.ndarray | tf.Tensor | None = None,  # attention mask for the context documents
        use_cache: bool | None = None,  # whether to use the cache to speed up decoding
        output_attentions: bool | None = None,  # whether to return attention weights
        output_hidden_states: bool | None = None,  # whether to return hidden states
        output_retrieved: bool | None = None,  # whether to return the retrieved documents
        n_docs: int | None = None,  # number of documents to retrieve per query
        return_dict: bool | None = None,  # whether to return a dict-style output
        training: bool = False,  # whether the model is in training mode
        **kwargs,  # any additional keyword arguments
    ):
        pass  # method body omitted in this excerpt

    def build(self, input_shape=None):
        # If the model has already been built, return immediately
        if self.built:
            return
        # Mark the model as built
        self.built = True
        # Build the generator submodel under its TensorFlow name scope
        with tf.name_scope(self.generator.name):
            self.generator.build(None)
        # Build the question encoder submodel under its TensorFlow name scope
        with tf.name_scope(self.question_encoder.name):
            self.question_encoder.build(None)
@add_start_docstrings_to_model_forward(
    """
    A TF RAG-token model implementation. It performs RAG-token specific marginalization in the forward pass.
    """,
    RAG_START_DOCSTRING,
)
class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
    load_weight_prefix = "tf_rag_token_for_generation_1/rag"

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        question_encoder: Optional[TFPreTrainedModel] = None,
        generator: Optional[TFPreTrainedModel] = None,
        retriever: Optional[RagRetriever] = None,
        **kwargs,
    ):
        assert config is not None or (
            question_encoder is not None and generator is not None
        ), "Either a configuration or an encoder and a generator has to be provided."

        if config is None:
            # If no config is provided, build a RagConfig from the question encoder and generator configs
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )

        super().__init__(config)

        # Instantiate the underlying RAG model
        self.rag = TFRagModel(
            config=config,
            question_encoder=question_encoder,
            generator=generator,
            retriever=retriever,
            load_weight_prefix=self.load_weight_prefix,
            name="rag",
        )

    def set_retriever(self, retriever: RagRetriever):
        # Replace the retriever of the underlying RAG model
        self.rag.retriever = retriever

    # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_bart.py
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        encoder_outputs=None,
        doc_scores=None,
        n_docs=None,
        **kwargs,
    ):
        if past_key_values is not None:
            # If past key values are defined, only the last decoder_input_ids token is needed
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,
            "encoder_outputs": encoder_outputs,
            "doc_scores": doc_scores,
            "context_attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "do_marginalize": True,
            "n_docs": n_docs,
        }

    @property
    def retriever(self):
        # Retriever of the underlying RAG model
        return self.rag.retriever

    @property
    def generator(self):
        # Generator of the underlying RAG model
        return self.rag.generator

    @property
    def question_encoder(self):
        # Question encoder of the underlying RAG model
        return self.rag.question_encoder

    @staticmethod
    def _gather_beams(nested, beam_indices, batch_axis=0):
        """
        RAG-specific `_gather_beams`: gathers the beam slices indexed by beam_indices into new beam array. If the
        nested tensor has a shape mismatch with the beam indices, then it means it is the cache. In that case, isolates
        and takes care of the extra dimension for ndocs.
        """

        def gather_fn(tensor):
            # Detect whether this tensor is the RAG cache (its leading dim differs from beam_indices)
            is_rag_cache = tensor.shape[0] != beam_indices.shape[0]
            if is_rag_cache:
                # For the cache, recover n_docs and the batch size from the shape mismatch
                n_docs = tensor.shape[0] // beam_indices.shape[0]
                batch_size = beam_indices.shape[0]
                # Reshape into (batch size, num beams, n_docs, ...), the cache layout RAG expects
                tensor = tf.reshape(tensor, (batch_size, -1, n_docs, *tensor.shape[2:]))

            # Gather the selected beams along the beam axis
            gathered_tensor = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1)

            if is_rag_cache:
                # For the cache, reshape back into the layout beam search expects
                gathered_tensor = tf.reshape(gathered_tensor, (batch_size * n_docs, -1, *gathered_tensor.shape[3:]))

            return gathered_tensor

        # Apply gather_fn to every tensor in the nested structure
        return tf.nest.map_structure(gather_fn, nested)
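
The core of `gather_fn` is `tf.gather` with `batch_dims=1`, which reorders the beam axis independently per batch item. A tiny sketch with made-up shapes:

import tensorflow as tf

batch_size, num_beams, hidden = 2, 3, 4
tensor = tf.reshape(tf.range(batch_size * num_beams * hidden, dtype=tf.float32),
                    (batch_size, num_beams, hidden))
beam_indices = tf.constant([[2, 0, 1], [1, 1, 0]])  # new beam order for each batch item

gathered = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1)
# gathered[0] holds tensor[0]'s beams in order [2, 0, 1]; gathered[1] repeats beam 1
print(gathered.shape)  # (2, 3, 4)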

    def marginalize(self, seq_logits, doc_scores, n_docs=None):
        n_docs = n_docs if n_docs is not None else self.config.n_docs

        # RAG-token marginalization
        # Normalize the sequence logits with log_softmax over the vocabulary axis
        seq_logprobs = tf.nn.log_softmax(seq_logits, axis=-1)
        # Reshape to [batch_size // n_docs, n_docs, seq_len, vocab_size]
        seq_logprobs = tf.reshape(seq_logprobs, [seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.shape[-1]])
        # Normalize the document scores with log_softmax over the documents axis
        doc_logprobs = tf.nn.log_softmax(doc_scores, axis=1)
        # Add two trailing singleton dimensions so the doc log-probs broadcast over tokens and vocab
        doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1)
        doc_logprobs = tf.expand_dims(doc_logprobs, axis=-1)  # twice
        # Sum the sequence and document log-probabilities
        log_prob_sum = seq_logprobs + doc_logprobs
        # Marginalize over documents with logsumexp along axis 1
        return tf.reduce_logsumexp(log_prob_sum, axis=1)
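
A small numerical check (random values, one query with two documents) that the logsumexp form really equals the probability-space mixture sum_d p(d|x) p(y|x, d):

import tensorflow as tf

seq_logits = tf.random.normal((2, 3, 5))  # (batch_size * n_docs, seq_len, vocab) with n_docs=2
doc_scores = tf.random.normal((1, 2))

seq_logprobs = tf.reshape(tf.nn.log_softmax(seq_logits, axis=-1), (1, 2, 3, 5))
doc_logprobs = tf.nn.log_softmax(doc_scores, axis=1)[:, :, None, None]

via_logsumexp = tf.reduce_logsumexp(seq_logprobs + doc_logprobs, axis=1)
via_probs = tf.math.log(tf.reduce_sum(tf.exp(seq_logprobs + doc_logprobs), axis=1))
tf.debugging.assert_near(via_logsumexp, via_probs, atol=1e-5)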

    @unpack_inputs
    @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
    # `call` is the forward pass. It takes the encoder inputs (`input_ids`, `attention_mask`),
    # the decoder inputs (`decoder_input_ids`, `decoder_attention_mask`), precomputed
    # `encoder_outputs`, cached `past_key_values`, and the retrieval inputs (`doc_scores`,
    # `context_input_ids`, `context_attention_mask`). Further flags control caching (`use_cache`),
    # the returned outputs (`output_attentions`, `output_hidden_states`, `output_retrieved`,
    # `return_dict`), the number of retrieved documents (`n_docs`), marginalization
    # (`do_marginalize`), the loss (`labels`, `reduce_loss`), and the `training` mode.
    # Arbitrary extra keyword arguments are accepted via `**kwargs`.
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        encoder_outputs: np.ndarray | tf.Tensor | None = None,
        past_key_values: Tuple[Tuple[Union[np.ndarray, tf.Tensor]]] | None = None,
        doc_scores: np.ndarray | tf.Tensor | None = None,
        context_input_ids: np.ndarray | tf.Tensor | None = None,
        context_attention_mask: np.ndarray | tf.Tensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        output_retrieved: bool | None = None,
        n_docs: int | None = None,
        do_marginalize: bool | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        reduce_loss: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
        **kwargs,  # needs kwargs for generation
    ):
        pass  # method body omitted in this excerpt

    # `generate` produces text. It takes `input_ids` and `attention_mask`, the context inputs
    # (`context_input_ids`, `context_attention_mask`), `doc_scores`, `n_docs`, a
    # `generation_config`, a `logits_processor`, and further keyword arguments.
    def generate(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: tf.Tensor | None = None,
        context_input_ids=None,
        context_attention_mask=None,
        doc_scores=None,
        n_docs=None,
        generation_config=None,
        logits_processor=TFLogitsProcessorList(),
        **kwargs,
    ):
        pass  # method body omitted in this excerpt

    def get_input_embeddings(self):
        # Input embeddings of the generator inside the RAG model
        return self.rag.generator.get_input_embeddings()

    def get_output_embeddings(self):
        # Output embeddings of the generator inside the RAG model
        return self.rag.generator.get_output_embeddings()

    # Adapted from the `_shift_right` methods of tf_t5 and tf_bart:
    # shift the input token ids one position to the right and pad with start_token_id
    def shift_tokens_right(self, input_ids, start_token_id=None):
        """Shift input ids one token to the right, and pad with start_token_id"""

        if start_token_id is None:
            start_token_id = self.generator.config.decoder_start_token_id
            assert start_token_id is not None, (
                "self.generator.config.decoder_start_token_id has to be defined. In Rag we commonly use Bart as"
                " generator, see Bart docs for more information"
            )

        pad_token_id = self.generator.config.pad_token_id
        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."

        # Build a (batch_size, 1) tensor filled with start_token_id
        start_tokens = tf.fill((shape_list(input_ids)[0], 1), tf.cast(start_token_id, input_ids.dtype))
        # Prepend start_tokens to all but the last column of input_ids, shifting everything right by one
        shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)

        # Replace any -100 values (the ignore index used in labels) with pad_token_id
        shifted_input_ids = tf.where(
            shifted_input_ids == -100,
            tf.fill(shape_list(shifted_input_ids), tf.cast(pad_token_id, input_ids.dtype)),
            shifted_input_ids,
        )

        # Assert that `shifted_input_ids` only contains non-negative values
        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, shifted_input_ids.dtype))

        # Make sure the assertion op is actually executed by wrapping the result in an identity op
        with tf.control_dependencies([assert_gte0]):
            shifted_input_ids = tf.identity(shifted_input_ids)

        return shifted_input_ids

    # nll stands for 'negative log likelihood'
    def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, n_docs=None):
        n_docs = n_docs if n_docs is not None else self.config.n_docs

        # Shift tokens left (as in the original PyTorch version): drop the first token of each
        # row of `target` and append self.config.generator.pad_token_id at the end
        target = tf.concat(
            [target[:, 1:], tf.fill([target.shape[0], 1], tf.cast(self.config.generator.pad_token_id, target.dtype))],
            axis=1,
        )
        # Marginalize seq_logits over documents using doc_scores to get rag_logprobs
        rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs)
        # Compute the loss; reduce_loss controls whether the per-example losses are summed
        loss = self.hf_compute_loss(target, rag_logprobs, from_logits=True, reduce_loss=reduce_loss)

        return loss

    # Adapted from modeling_tf_bart, with a smooth_loss term added to match the PyTorch version
    def hf_compute_loss(self, labels, y_pred, smooth_epsilon=0.0, from_logits=True, reduce_loss=False):
        """CrossEntropyLoss that ignores pad tokens"""
        # Matt: As written, this loss is not XLA-compatible, and it does some fairly
        #       unusual things that are not straightforward to convert.

        # Sparse categorical cross-entropy over logits, summed rather than averaged
        loss_fn = keras.losses.SparseCategoricalCrossentropy(
            from_logits=True,  # the predictions are logits
            reduction=keras.losses.Reduction.SUM,  # how the per-token losses are aggregated
        )

        if from_logits is False:  # if the predictions are probabilities, convert them to logits
            eps = 1e-9
            y_pred = tf.clip_by_value(y_pred, clip_value_min=eps, clip_value_max=1 - eps)
            y_pred = tf.math.log(y_pred)

        logits = y_pred
        melted_labels = tf.reshape(labels, (-1,))
        # Find the positions that are not padding
        active_loss = tf.not_equal(melted_labels, self.config.generator.pad_token_id)

        # Keep only the logits and labels at non-padding positions
        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[2])), active_loss)
        labels = tf.boolean_mask(melted_labels, active_loss)

        # Cross-entropy part of the loss
        nll_loss = loss_fn(labels, reduced_logits)

        # Label-smoothing part of the loss
        smooth_loss = -tf.reduce_sum(reduced_logits, axis=-1)
        smooth_loss = tf.reduce_sum(smooth_loss)  # like torch's sum-and-squeeze
        eps_i = smooth_epsilon / reduced_logits.shape[-1]

        # Final loss: interpolate between the cross-entropy loss and the smoothing loss
        loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss

        return loss
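
A self-contained toy run of the same masking-and-interpolation logic (pad token id 0 and smoothing 0.1 are arbitrary demo choices; note the inputs here are raw logits, whereas the method above actually receives marginalized log-probs):

import tensorflow as tf
from tensorflow import keras

pad_token_id, smooth_epsilon, vocab = 0, 0.1, 5
labels = tf.constant([[3, 1, 0]])  # the last position is padding
logits = tf.random.normal((1, 3, vocab))

melted = tf.reshape(labels, (-1,))
active = tf.not_equal(melted, pad_token_id)
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, vocab)), active)
kept_labels = tf.boolean_mask(melted, active)

loss_fn = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=keras.losses.Reduction.SUM
)
nll_loss = loss_fn(kept_labels, reduced_logits)
smooth_loss = tf.reduce_sum(-tf.reduce_sum(reduced_logits, axis=-1))
loss = (1.0 - smooth_epsilon) * nll_loss + (smooth_epsilon / vocab) * smooth_loss
print(float(loss))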

    def build(self, input_shape=None):
        # If the model has already been built, return immediately
        if self.built:
            return
        self.built = True
        # If a `rag` submodel exists, build it under its name scope
        if getattr(self, "rag", None) is not None:
            with tf.name_scope(self.rag.name):
                self.rag.build(None)


# Decorator prepending the shared RAG docstring to the model's forward-pass documentation
@add_start_docstrings_to_model_forward(
    """
    A TF RAG-sequence model implementation. It performs RAG-sequence specific marginalization in the forward pass.
    """,
    RAG_START_DOCSTRING,
)
class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
    # Prefix used when loading weights
    load_weight_prefix = "tf_rag_sequence_for_generation_1/rag"

    # Constructor
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        question_encoder: Optional[TFPreTrainedModel] = None,
        generator: Optional[TFPreTrainedModel] = None,
        retriever: Optional[RagRetriever] = None,
        **kwargs,
    ):
        # Require either a config or both a question encoder and a generator
        assert config is not None or (
            question_encoder is not None and generator is not None
        ), "Either a configuration or an encoder and a generator has to be provided."

        # If no config is provided, build a RagConfig from the question encoder and generator configs
        if config is None:
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )

        # Initialize the parent class
        super().__init__(config)

        # Instantiate the underlying RAG model
        self.rag = TFRagModel(
            config=config,
            question_encoder=question_encoder,
            generator=generator,
            retriever=retriever,
            load_weight_prefix=self.load_weight_prefix,
            name="rag",
        )

    def set_retriever(self, retriever: RagRetriever):
        # Replace the retriever of the underlying RAG model
        self.rag.retriever = retriever

    @property
    def retriever(self):
        # Retriever of the underlying RAG model
        return self.rag.retriever

    @property
    def generator(self):
        # Generator of the underlying RAG model
        return self.rag.generator

    @property
    def question_encoder(self):
        # Question encoder of the underlying RAG model
        return self.rag.question_encoder

    # Decorators attaching the shared input docstring and return-type docs to `call`
    @unpack_inputs
    @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        encoder_outputs: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        doc_scores: np.ndarray | tf.Tensor | None = None,
        context_input_ids: np.ndarray | tf.Tensor | None = None,
        context_attention_mask: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_retrieved: Optional[bool] = None,
        n_docs: Optional[int] = None,
        exclude_bos_score: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        reduce_loss: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
        **kwargs,  # needs kwargs for generation
    ):
        # Forward pass of the model
        pass  # method body omitted in this excerpt
    # `get_nll` computes the negative log likelihood. `seq_logits` are the sequence logits,
    # `doc_scores` the document scores, and `target` the target ids. `reduce_loss` controls
    # whether the loss is summed (default False); `epsilon` is the label-smoothing factor;
    # `exclude_bos_score` controls whether the BOS score is excluded (default False);
    # `n_docs` is the number of documents (default None).
    def get_nll(
        self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False, n_docs=None
    ):
        pass  # method body omitted in this excerpt

    # `generate` produces text. It takes `input_ids` and `attention_mask`, the context inputs
    # (`context_input_ids`, `context_attention_mask`), `doc_scores`, `do_deduplication`
    # (defaults to True), `num_return_sequences` (defaults to 1), `num_beams` (defaults to 1),
    # and `n_docs`.
    def generate(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: tf.Tensor | None = None,
        context_input_ids=None,
        context_attention_mask=None,
        doc_scores=None,
        do_deduplication=None,  # defaults to True
        num_return_sequences=None,  # defaults to 1
        num_beams=None,  # defaults to 1
        n_docs=None,
        **model_kwargs,
    ):
        pass  # method body omitted in this excerpt

    # Static helper used by generate() to concatenate and pad lists of tensors
    @staticmethod
    def _cat_and_pad(tensors, pad_token_id):
        # used by generate(): tensors is a (batched) list of (candidates, len); len varies across the batch

        # Initialize a padded tensor of shape (total number of candidates, maximum candidate length)
        new_shape = sum([t.shape[0] for t in tensors]), max([t.shape[1] for t in tensors])
        output = tf.fill(new_shape, pad_token_id)

        # Use tf.Variable, since a regular tensor does not support slice assignment
        output = tf.Variable(output)

        # Copy each input tensor into its slice of `output`
        ind = 0
        for t in tensors:
            output[ind : ind + t.shape[0], : t.shape[1]].assign(t)
            ind += t.shape[0]

        # Convert back to a regular tensor and cast to the element dtype of the first tensor
        output = tf.convert_to_tensor(output)
        return tf.cast(output, tensors[0][0].dtype)
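
For example (pad id 0 chosen arbitrarily), two candidate batches of different lengths are stacked row-wise and right-padded to the longest length:

import tensorflow as tf

tensors, pad_token_id = [tf.constant([[1, 2, 3]]), tf.constant([[4, 5], [6, 7]])], 0

new_shape = sum(t.shape[0] for t in tensors), max(t.shape[1] for t in tensors)
output = tf.Variable(tf.fill(new_shape, pad_token_id))
ind = 0
for t in tensors:
    output[ind : ind + t.shape[0], : t.shape[1]].assign(t)
    ind += t.shape[0]
print(tf.convert_to_tensor(output).numpy())  # [[1 2 3] [4 5 0] [6 7 0]]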

    def build(self, input_shape=None):
        if self.built:
            return  # already built

        self.built = True  # mark the model as built

        # If a `rag` submodel exists, build it under its name scope
        if getattr(self, "rag", None) is not None:
            with tf.name_scope(self.rag.name):
                self.rag.build(None)

.\models\rag\retrieval_rag.py

# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RAG Retriever model implementation.
"""

import os
import pickle
import time
from typing import Iterable, List, Optional, Tuple

import numpy as np

from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends, strtobool
from .configuration_rag import RagConfig
from .tokenization_rag import RagTokenizer

# If the `datasets` library is available, import the relevant helpers
if is_datasets_available():
    from datasets import Dataset, load_dataset, load_from_disk

# If the `faiss` library is available, import it
if is_faiss_available():
    import faiss

# Get the logger
logger = logging.get_logger(__name__)

# Path of the legacy index
LEGACY_INDEX_PATH = "https://storage.googleapis.com/huggingface-nlp/datasets/wiki_dpr/"

class Index:
    """
    A base class for the Indices encapsulated by the [`RagRetriever`].
    """

    def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
        """
        Returns a list of dictionaries, containing titles and text of the retrieved documents.

        Args:
            doc_ids (`np.ndarray` of shape `(batch_size, n_docs)`):
                A tensor of document indices.
        """
        raise NotImplementedError

    def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
        """
        For each query in the batch, retrieves `n_docs` documents.

        Args:
            question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
                An array of query vectors.
            n_docs (`int`):
                The number of docs retrieved per query.

        Returns:
            `np.ndarray` of shape `(batch_size, n_docs)`: A tensor of indices of retrieved documents.
            `np.ndarray` of shape `(batch_size, vector_size)`: A tensor of vector representations of retrieved documents.
        """
        raise NotImplementedError

    def is_initialized(self):
        """
        Returns `True` if index is already initialized.
        """
        raise NotImplementedError

    def init_index(self):
        """
        A function responsible for loading the index into memory. Should be called only once per training run of a RAG
        model. E.g. if the model is trained on multiple GPUs in a distributed setup, only one of the workers will load
        the index.
        """
        raise NotImplementedError
    """
    一个可以从使用 https://github.com/facebookresearch/DPR 构建的文件中反序列化的索引。我们使用该仓库中指定的默认 faiss 索引参数。

    Args:
        vector_size (`int`):
            索引向量的维度。
        index_path (`str`):
            包含与 [`~models.rag.retrieval_rag.LegacyIndex`] 兼容的索引文件的 *目录* 路径。
    """

    # Index file name
    INDEX_FILENAME = "hf_bert_base.hnswSQ8_correct_phi_128.c_index"
    # Passages file name
    PASSAGE_FILENAME = "psgs_w100.tsv.pkl"

    def __init__(self, vector_size, index_path):
        # Mapping from index ids to database ids, initially empty
        self.index_id_to_db_id = []
        # Path to the index files
        self.index_path = index_path
        # Load the passages data
        self.passages = self._load_passages()
        # Dimension of the indexed vectors
        self.vector_size = vector_size
        # The faiss index object
        self.index = None
        # Whether the index has been initialized
        self._index_initialized = False

    def _resolve_path(self, index_path, filename):
        # Determine whether index_path is a local directory or a remote path
        is_local = os.path.isdir(index_path)
        try:
            # Load the file from a URL or from the cache
            resolved_archive_file = cached_file(index_path, filename)
        except EnvironmentError:
            # Raise a descriptive loading error
            msg = (
                f"Can't load '{filename}'. Make sure that:\n\n"
                f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
                f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
            )
            raise EnvironmentError(msg)
        # Log where the file was loaded from (full path if local)
        if is_local:
            logger.info(f"loading file {resolved_archive_file}")
        else:
            logger.info(f"loading file {filename} from cache at {resolved_archive_file}")
        # Return the resolved file path
        return resolved_archive_file

    def _load_passages(self):
        # Log that passages are being loaded from the index path
        logger.info(f"Loading passages from {self.index_path}")
        # Resolve the path of the passages file
        passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
        # If the TRUST_REMOTE_CODE environment variable is unset or False, raise a security error
        if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
            raise ValueError(
                "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
                "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
                "that could have been tampered with. If you already verified the pickle data and decided to use it, "
                "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
            )
        # Load the passages data with pickle
        with open(passages_path, "rb") as passages_file:
            passages = pickle.load(passages_file)
        # Return the loaded passages
        return passages
    def _deserialize_index(self):
        # Log that the index is being loaded from the index path
        logger.info(f"Loading index from {self.index_path}")
        # Resolve the full path of the serialized index file
        resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
        # Read the index from disk with faiss
        self.index = faiss.read_index(resolved_index_path)
        # Resolve the full path of the index metadata file
        resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
        # If the TRUST_REMOTE_CODE environment variable is unset or False, raise a security error about pickle.load
        if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
            raise ValueError(
                "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
                "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
                "that could have been tampered with. If you already verified the pickle data and decided to use it, "
                "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
            )
        # Load the index_id_to_db_id mapping from the metadata file
        with open(resolved_meta_path, "rb") as metadata_file:
            self.index_id_to_db_id = pickle.load(metadata_file)
        # Sanity check: the id mapping must have as many entries as the faiss index
        assert (
            len(self.index_id_to_db_id) == self.index.ntotal
        ), "Deserialized index_id_to_db_id should match faiss index size"
class HFIndexBase(Index):
    # Base class for indices backed by a `datasets.Dataset`

    def __init__(self, vector_size, dataset, index_initialized=False):
        # Store the vector size, the dataset, and the index initialization state
        self.vector_size = vector_size
        self.dataset = dataset
        self._index_initialized = index_initialized
        # Validate the dataset format
        self._check_dataset_format(with_index=index_initialized)
        # Use the numpy format for the embeddings column (float32), keeping all other columns
        dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")

    def _check_dataset_format(self, with_index: bool):
        # Verify the dataset is a datasets.Dataset with the required title, text and embeddings columns
        if not isinstance(self.dataset, Dataset):
            raise ValueError(f"Dataset should be a datasets.Dataset object, but got {type(self.dataset)}")
        if len({"title", "text", "embeddings"} - set(self.dataset.column_names)) > 0:
            raise ValueError(
                "Dataset should be a dataset with the following columns: "
                "title (str), text (str) and embeddings (arrays of dimension vector_size), "
                f"but got columns {self.dataset.column_names}"
            )
        # If an index is expected but the dataset has no embeddings index, raise an error
        if with_index and "embeddings" not in self.dataset.list_indexes():
            raise ValueError(
                "Missing faiss index in the dataset. Make sure you called `dataset.add_faiss_index` to compute it "
                "or `dataset.load_faiss_index` to load one from the disk."
            )

    def init_index(self):
        # Abstract method; must be implemented by subclasses
        raise NotImplementedError()

    def is_initialized(self):
        # Whether the index has been initialized
        return self._index_initialized

    def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
        # Look up the document dicts for the given document ids
        return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])]

    def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
        # Retrieve the n_docs most relevant documents for each question vector,
        # using the dataset's batched search over the embeddings column
        _, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs)
        # Fetch the document rows for the returned ids (negative ids mean no hit)
        docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids]
        # Extract the embeddings of the retrieved documents
        vectors = [doc["embeddings"] for doc in docs]
        # If fewer than n_docs documents were found, pad with zero vectors
        for i in range(len(vectors)):
            if len(vectors[i]) < n_docs:
                vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))])
        # Return the retrieved ids and embeddings
        return np.array(ids), np.array(vectors)  # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
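
The zero-padding branch matters when faiss returns fewer than `n_docs` hits for a query; a minimal numpy sketch of that case:

import numpy as np

n_docs, vector_size = 3, 4
hits = np.random.rand(2, vector_size)  # only two documents were found for this query

padded = np.vstack([hits, np.zeros((n_docs - len(hits), vector_size))])
print(padded.shape)  # (3, 4) -- the missing document embedding is an all-zero row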


class CanonicalHFIndex(HFIndexBase):
    """
    A wrapper around an instance of [`~datasets.Datasets`]. If `index_path` is set to `None`, we load the pre-computed
    index available with the [`~datasets.arrow_dataset.Dataset`], otherwise, we load the index from the indicated path
    on disk.
    """
    # CanonicalHFIndex 类,继承自 HFIndexBase 类,是对 datasets.Datasets 的封装,支持加载预先计算的索引或从磁盘加载索引
    """
    Args:
        vector_size (`int`): the dimension of the passages embeddings used by the index
        dataset_name (`str`, optional, defaults to `wiki_dpr`):
            A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids
            with `datasets.list_datasets()`).
        dataset_split (`str`, optional, defaults to `train`):
            Which split of the `dataset` to load.
        index_name (`str`, optional, defaults to `train`):
            The index_name of the index associated with the `dataset`. The index loaded from `index_path` will be saved
            under this name.
        index_path (`str`, optional, defaults to `None`):
            The path to the serialized faiss index on disk.
        use_dummy_dataset (`bool`, optional, defaults to `False`):
            If True, use the dummy configuration of the dataset for tests.
    """

    def __init__(
        self,
        vector_size: int,
        dataset_name: str = "wiki_dpr",
        dataset_split: str = "train",
        index_name: Optional[str] = None,
        index_path: Optional[str] = None,
        use_dummy_dataset: bool = False,
        dataset_revision=None,
    ):
        # Validate that either `index_name` or `index_path` is provided
        if int(index_path is None) + int(index_name is None) != 1:
            raise ValueError("Please provide `index_name` or `index_path`.")
        
        # Initialize instance variables with provided parameters
        self.dataset_name = dataset_name
        self.dataset_split = dataset_split
        self.index_name = index_name
        self.index_path = index_path
        self.use_dummy_dataset = use_dummy_dataset
        self.dataset_revision = dataset_revision
        
        # Log information about dataset loading
        logger.info(f"Loading passages from {self.dataset_name}")
        
        # Load the dataset using Hugging Face datasets library
        dataset = load_dataset(
            self.dataset_name,
            with_index=False,
            split=self.dataset_split,
            dummy=self.use_dummy_dataset,
            revision=self.dataset_revision,
        )
        
        # Call superclass initialization with vector size and loaded dataset
        super().__init__(vector_size, dataset, index_initialized=False)

    def init_index(self):
        # Initialize index based on provided `index_path` or `index_name`
        if self.index_path is not None:
            # Load index from specified file path
            logger.info(f"Loading index from {self.index_path}")
            self.dataset.load_faiss_index("embeddings", file=self.index_path)
        else:
            # Load index associated with `index_name` from dataset
            logger.info(f"Loading index from {self.dataset_name} with index name {self.index_name}")
            
            # Load dataset with embeddings and index
            self.dataset = load_dataset(
                self.dataset_name,
                with_embeddings=True,
                with_index=True,
                split=self.dataset_split,
                index_name=self.index_name,
                dummy=self.use_dummy_dataset,
                revision=self.dataset_revision,
            )
            
            # Set dataset format to numpy for compatibility
            self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
        
        # Mark index initialization as completed
        self._index_initialized = True
class CustomHFIndex(HFIndexBase):
    """
    A wrapper around an instance of [`datasets.Datasets`]. The dataset and the index are both loaded from the
    indicated paths on disk.

    Args:
        vector_size (`int`): the dimension of the passages embeddings used by the index
        dataset_path (`str`):
            The path to the serialized dataset on disk. The dataset should have 3 columns: title (str), text (str) and
            embeddings (arrays of dimension vector_size)
        index_path (`str`):
            The path to the serialized faiss index on disk.
    """

    def __init__(self, vector_size: int, dataset, index_path=None):
        # Store the vector size and the dataset; the index counts as initialized when no index_path is given
        super().__init__(vector_size, dataset, index_initialized=index_path is None)
        self.index_path = index_path

    @classmethod
    def load_from_disk(cls, vector_size, dataset_path, index_path):
        # Load the dataset and the index from disk and return a CustomHFIndex instance
        logger.info(f"Loading passages from {dataset_path}")
        if dataset_path is None or index_path is None:
            raise ValueError(
                "Please provide `dataset_path` and `index_path` after calling `dataset.save_to_disk(dataset_path)` "
                "and `dataset.get_index('embeddings').save(index_path)`."
            )
        dataset = load_from_disk(dataset_path)
        return cls(vector_size=vector_size, dataset=dataset, index_path=index_path)

    def init_index(self):
        # If the index is not initialized yet, load the faiss index into the dataset and mark it as initialized
        if not self.is_initialized():
            logger.info(f"Loading index from {self.index_path}")
            self.dataset.load_faiss_index("embeddings", file=self.index_path)
            self._index_initialized = True
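
A hedged end-to-end sketch of preparing the on-disk artifacts that `load_from_disk` expects (the paths and the 8-dim embeddings are placeholders; requires the `datasets` and `faiss` packages, and imports `CustomHFIndex` from this module):

import numpy as np
from datasets import Dataset
from transformers.models.rag.retrieval_rag import CustomHFIndex

dataset = Dataset.from_dict({
    "title": ["t1", "t2"],
    "text": ["first passage", "second passage"],
    "embeddings": [np.random.rand(8).tolist(), np.random.rand(8).tolist()],
})
dataset.add_faiss_index(column="embeddings")
dataset.get_index("embeddings").save("my_index.faiss")  # -> index_path
dataset.drop_index("embeddings")  # save_to_disk cannot serialize an attached index
dataset.save_to_disk("my_dataset")  # -> dataset_path

index = CustomHFIndex.load_from_disk(
    vector_size=8, dataset_path="my_dataset", index_path="my_index.faiss"
)
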

# The RagRetriever class implements the document-retrieval logic
class RagRetriever:

    # The constructor takes the config, the tokenizers of the question encoder and the generator,
    # an index object, and a flag controlling whether retrieval is initialized right away
    def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None, init_retrieval=True):
        # Whether retrieval should be initialized immediately
        self._init_retrieval = init_retrieval
        # Require the necessary backend libraries
        requires_backends(self, ["datasets", "faiss"])
        # Call the parent constructor
        super().__init__()
        # Use the provided index, or build one from the config
        self.index = index or self._build_index(config)
        # Tokenizer of the generator
        self.generator_tokenizer = generator_tokenizer
        # Tokenizer of the question encoder
        self.question_encoder_tokenizer = question_encoder_tokenizer

        # Number of documents to retrieve
        self.n_docs = config.n_docs
        # Batch size used during retrieval
        self.batch_size = config.retrieval_batch_size

        # Store the config
        self.config = config
        # Initialize retrieval if requested
        if self._init_retrieval:
            self.init_retrieval()

        # Tokenizer of the context encoder, initially unset
        self.ctx_encoder_tokenizer = None
        # Whether to return tokenized documents, initially False
        self.return_tokenized_docs = False

    # Build an index object from the given config
    @staticmethod
    def _build_index(config):
        # Legacy index requested: return a LegacyIndex
        if config.index_name == "legacy":
            return LegacyIndex(
                config.retrieval_vector_size,
                config.index_path or LEGACY_INDEX_PATH,
            )
        # Custom index requested: load it from disk and return a CustomHFIndex
        elif config.index_name == "custom":
            return CustomHFIndex.load_from_disk(
                vector_size=config.retrieval_vector_size,
                dataset_path=config.passages_path,
                index_path=config.index_path,
            )
        # Otherwise, return a CanonicalHFIndex
        else:
            return CanonicalHFIndex(
                vector_size=config.retrieval_vector_size,
                dataset_name=config.dataset,
                dataset_split=config.dataset_split,
                index_name=config.index_name,
                index_path=config.index_path,
                use_dummy_dataset=config.use_dummy_dataset,
                dataset_revision=config.dataset_revision,
            )

    # Load a retriever instance from a pretrained model or path
    @classmethod
    def from_pretrained(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
        requires_backends(cls, ["datasets", "faiss"])
        # Load the config, falling back to the pretrained model's config if none is given
        config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
        # Load the RAG tokenizer
        rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
        # Grab the question encoder and generator tokenizers
        question_encoder_tokenizer = rag_tokenizer.question_encoder
        generator_tokenizer = rag_tokenizer.generator
        # If an indexed dataset is provided, force the custom index and wrap it in a CustomHFIndex
        if indexed_dataset is not None:
            config.index_name = "custom"
            index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
        # Otherwise, build the index from the config
        else:
            index = cls._build_index(config)
        # Return a retriever built from the config
        return cls(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            index=index,
        )

    # Save the retriever's configuration, index, and tokenizers to the given directory
    def save_pretrained(self, save_directory):
        # If the current index is a CustomHFIndex
        if isinstance(self.index, CustomHFIndex):
            # If no index path is configured, save the faiss index to a default path
            if self.config.index_path is None:
                index_path = os.path.join(save_directory, "hf_dataset_index.faiss")
                self.index.dataset.get_index("embeddings").save(index_path)
                self.config.index_path = index_path
            # If no passages path is configured, save the dataset to a default path
            if self.config.passages_path is None:
                passages_path = os.path.join(save_directory, "hf_dataset")
                # Temporarily detach the faiss index, since the current `datasets` version
                # does not support save_to_disk with an attached index
                faiss_index = self.index.dataset._indexes.pop("embeddings")
                self.index.dataset.save_to_disk(passages_path)
                self.index.dataset._indexes["embeddings"] = faiss_index
                self.config.passages_path = passages_path
        # Save the config to the directory
        self.config.save_pretrained(save_directory)
        # Build a RAG tokenizer from the two tokenizers and save it to the directory
        rag_tokenizer = RagTokenizer(
            question_encoder=self.question_encoder_tokenizer,
            generator=self.generator_tokenizer,
        )
        rag_tokenizer.save_pretrained(save_directory)
    def init_retrieval(self):
        """
        Retriever initialization function. It loads the index into memory.
        """

        # Log that retrieval is being initialized
        logger.info("initializing retrieval")
        # Load the index into memory
        self.index.init_index()

    def postprocess_docs(self, docs, input_strings, prefix, n_docs, return_tensors=None):
        r"""
        Postprocessing retrieved `docs` and combining them with `input_strings`.

        Args:
            docs  (`dict`):
                Retrieved documents.
            input_strings (`str`):
                Input strings decoded by `preprocess_query`.
            prefix (`str`):
                Prefix added at the beginning of each input, typically used with T5-based models.

        Return:
            `tuple(tensors)`: a tuple consisting of two elements: contextualized `input_ids` and a compatible
            `attention_mask`.
        """

        def cat_input_and_doc(doc_title, doc_text, input_string, prefix):
            # TODO(Patrick): if we train more RAG models, I want to put the input first to take advantage of effortless truncation
            # TODO(piktus): better handling of truncation
            # Strip a leading double quote from the document title
            if doc_title.startswith('"'):
                doc_title = doc_title[1:]
            # Strip a trailing double quote from the document title
            if doc_title.endswith('"'):
                doc_title = doc_title[:-1]
            # Default the prefix to an empty string
            if prefix is None:
                prefix = ""
            # Assemble prefix + title + title_sep + text + doc_sep + input, collapsing double spaces
            out = (prefix + doc_title + self.config.title_sep + doc_text + self.config.doc_sep + input_string).replace(
                "  ", " "
            )
            return out

        # Build the list of RAG input strings, one per (query, document) pair
        rag_input_strings = [
            cat_input_and_doc(
                docs[i]["title"][j],
                docs[i]["text"][j],
                input_strings[i],
                prefix,
            )
            for i in range(len(docs))  # over queries
            for j in range(n_docs)  # over the documents retrieved for each query
        ]

        # Batch-encode the assembled strings with the generator tokenizer
        contextualized_inputs = self.generator_tokenizer.batch_encode_plus(
            rag_input_strings,
            max_length=self.config.max_combined_length,  # maximum combined length
            return_tensors=return_tensors,  # which tensor type to return
            padding="max_length",  # pad to max_length
            truncation=True,  # truncate anything beyond max_length
        )

        # Return the encoded input ids and attention mask
        return contextualized_inputs["input_ids"], contextualized_inputs["attention_mask"]
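
Concretely, assuming the default separators from `RagConfig` (`title_sep=" / "`, `doc_sep=" // "`), `cat_input_and_doc` produces strings of this shape:

prefix, title_sep, doc_sep = "", " / ", " // "
doc_title = "Paris"
doc_text = "Paris is the capital of France."
input_string = "what is the capital of france"

out = (prefix + doc_title + title_sep + doc_text + doc_sep + input_string).replace("  ", " ")
print(out)  # Paris / Paris is the capital of France. // what is the capital of france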

    def _chunk_tensor(self, t: Iterable, chunk_size: int) -> List[Iterable]:
        # Slice the input into chunks of at most chunk_size elements
        return [t[i : i + chunk_size] for i in range(0, len(t), chunk_size)]
    def _main_retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, np.ndarray]:
        # Split the query vectors into chunks of the configured batch size
        question_hidden_states_batched = self._chunk_tensor(question_hidden_states, self.batch_size)
        ids_batched = []
        vectors_batched = []
        for question_hidden_states in question_hidden_states_batched:
            # Record the start time
            start_time = time.time()
            # Fetch the ids and embeddings of the top n_docs documents for each query vector
            ids, vectors = self.index.get_top_docs(question_hidden_states, n_docs)
            # Log the index search time and the current batch size
            logger.debug(
                f"index search time: {time.time() - start_time} sec, batch size {question_hidden_states.shape}"
            )
            # Extend the batch-level lists with the retrieved ids and vectors
            ids_batched.extend(ids)
            vectors_batched.extend(vectors)
        return (
            np.array(ids_batched),
            np.array(vectors_batched),
        )  # shapes (batch_size, n_docs) and (batch_size, n_docs, d)

    def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, np.ndarray, List[dict]]:
        """
        Retrieves documents for the specified `question_hidden_states`.

        Args:
            question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            n_docs (`int`):
                The number of docs retrieved per query.

        Return:
            `Tuple[np.ndarray, np.ndarray, List[dict]]`: A tuple with the following objects:

            - **retrieved_doc_embeds** (`np.ndarray` of shape `(batch_size, n_docs, dim)`) -- The retrieval
              embeddings of the retrieved docs per query.
            - **doc_ids** (`np.ndarray` of shape `(batch_size, n_docs)`) -- The ids of the documents in the index.
            - **doc_dicts** (`List[dict]`): The `retrieved_doc_embeds` examples per query.
        """

        # Get the document ids and retrieval embeddings with _main_retrieve
        doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
        # Return the embeddings, the document ids, and the corresponding document dicts
        return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
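
Typical usage, sketched (downloads a small dummy index on first run and needs `datasets` and `faiss` installed; the 768-dim query vectors are random placeholders for real question-encoder outputs):

import numpy as np
from transformers import RagRetriever

retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
question_hidden_states = np.random.rand(2, 768).astype(np.float32)

retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(question_hidden_states, n_docs=5)
print(retrieved_doc_embeds.shape, doc_ids.shape)  # (2, 5, 768) (2, 5)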

    def set_ctx_encoder_tokenizer(self, ctx_encoder_tokenizer: PreTrainedTokenizer):
        # Used during end-to-end retriever training to set the context encoder tokenizer
        self.ctx_encoder_tokenizer = ctx_encoder_tokenizer
        self.return_tokenized_docs = True

    def __call__(
        self,
        question_input_ids: List[List[int]],
        question_hidden_states: np.ndarray,
        prefix=None,
        n_docs=None,
        return_tensors=None,
    ):
        pass  # method body omitted in this excerpt