Transformers Source Code Analysis (92)

.\models\rag\tokenization_rag.py

# coding=utf-8
# Declares UTF-8 as the file encoding

# Copyright and license notice

# Import the required standard-library modules and typing helpers
import os
import warnings
from typing import List, Optional

# Import the batch-encoding container, the logging utility, and the RAG configuration
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging
from .configuration_rag import RagConfig

# Get the logger for the current module
logger = logging.get_logger(__name__)


class RagTokenizer:
    def __init__(self, question_encoder, generator):
        # Initialize the RAG tokenizer with a question-encoder tokenizer and a
        # generator tokenizer; the question encoder is the active tokenizer by default
        self.question_encoder = question_encoder
        self.generator = generator
        self.current_tokenizer = self.question_encoder

    def save_pretrained(self, save_directory):
        # Save both underlying tokenizers under the given directory
        if os.path.isfile(save_directory):
            # The destination must be a directory, not a file
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
        # Create the directory; do not raise if it already exists
        os.makedirs(save_directory, exist_ok=True)
        # Save the question-encoder and generator tokenizers into separate subdirectories
        question_encoder_path = os.path.join(save_directory, "question_encoder_tokenizer")
        generator_path = os.path.join(save_directory, "generator_tokenizer")
        self.question_encoder.save_pretrained(question_encoder_path)
        self.generator.save_pretrained(generator_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Load a RAG tokenizer from a pretrained model name or a local path
        # Dynamically import AutoTokenizer
        from ..auto.tokenization_auto import AutoTokenizer

        # Use the config passed in kwargs, or load it from the pretrained model
        config = kwargs.pop("config", None)
        if config is None:
            config = RagConfig.from_pretrained(pretrained_model_name_or_path)

        # Load the question-encoder and generator tokenizers according to the config,
        # each from its own subfolder
        question_encoder = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, config=config.question_encoder, subfolder="question_encoder_tokenizer"
        )
        generator = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, config=config.generator, subfolder="generator_tokenizer"
        )
        return cls(question_encoder=question_encoder, generator=generator)

    def __call__(self, *args, **kwargs):
        # Delegate calls to the currently active tokenizer, so the instance is callable
        return self.current_tokenizer(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        # Batch decoding always goes through the generator tokenizer
        return self.generator.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        # Decoding always goes through the generator tokenizer
        return self.generator.decode(*args, **kwargs)

    def _switch_to_input_mode(self):
        # Make the question encoder the active tokenizer
        self.current_tokenizer = self.question_encoder

    def _switch_to_target_mode(self):
        # Make the generator the active tokenizer
        self.current_tokenizer = self.generator
    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
        tgt_texts: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        padding: str = "longest",
        return_tensors: str = None,
        truncation: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        # `prepare_seq2seq_batch` is deprecated and scheduled for removal in 🤗 Transformers v5;
        # inputs should be prepared with the regular `__call__` and targets under the
        # `with_target_tokenizer` context manager
        warnings.warn(
            "`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the "
            "regular `__call__` method to prepare your inputs and the tokenizer under the `with_target_tokenizer` "
            "context manager to prepare your targets. See the documentation of your specific tokenizer for more "
            "details",
            FutureWarning,
        )

        # Fall back to the active tokenizer's model maximum length if no max_length is given
        if max_length is None:
            max_length = self.current_tokenizer.model_max_length

        # Prepare the inputs via __call__: source texts, special tokens, tensor type,
        # maximum length, padding strategy, and truncation flag
        model_inputs = self(
            src_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            padding=padding,
            truncation=truncation,
            **kwargs,
        )

        # Without target texts, return the model inputs as-is
        if tgt_texts is None:
            return model_inputs

        # Process the target texts; again fall back to the model maximum length if needed
        if max_target_length is None:
            max_target_length = self.current_tokenizer.model_max_length

        # Prepare the labels via __call__ using `text_target`
        labels = self(
            text_target=tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_target_length,
            truncation=truncation,
            **kwargs,
        )

        # Store the tokenized targets' input ids under the "labels" key
        model_inputs["labels"] = labels["input_ids"]

        # Return the final batch: source encodings plus, when given, the target labels
        return model_inputs

.\models\rag\__init__.py

# Import the type-checking flag
from typing import TYPE_CHECKING

# Import the optional-dependency exception, the lazy-module helper, and the backend checks
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available


# Define the module's import structure
_import_structure = {
    "configuration_rag": ["RagConfig"],      # RagConfig from the configuration module
    "retrieval_rag": ["RagRetriever"],       # RagRetriever from the retrieval module
    "tokenization_rag": ["RagTokenizer"],    # RagTokenizer from the tokenization module
}

# Check whether PyTorch is available; raise OptionalDependencyNotAvailable if not
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass  # PyTorch is missing: skip registering the PyTorch modeling module
else:
    # PyTorch is available: register the PyTorch modeling classes in the import structure
    _import_structure["modeling_rag"] = [
        "RagModel",
        "RagPreTrainedModel",
        "RagSequenceForGeneration",
        "RagTokenForGeneration",
    ]

# Same pattern for TensorFlow
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass  # TensorFlow is missing: skip registering the TF modeling module
else:
    # TensorFlow is available: register the TF modeling classes in the import structure
    _import_structure["modeling_tf_rag"] = [
        "TFRagModel",
        "TFRagPreTrainedModel",
        "TFRagSequenceForGeneration",
        "TFRagTokenForGeneration",
    ]


# Under type checking, perform the real imports so static analyzers can see the names
if TYPE_CHECKING:
    from .configuration_rag import RagConfig
    from .retrieval_rag import RagRetriever
    from .tokenization_rag import RagTokenizer

    # Import the PyTorch models only if PyTorch is available
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration

    # Import the TensorFlow models only if TensorFlow is available
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_rag import (
            TFRagModel,
            TFRagPreTrainedModel,
            TFRagSequenceForGeneration,
            TFRagTokenForGeneration,
        )

# At runtime, replace this module with a lazy module so submodules load on demand
else:
    import sys

    # Register the current module as a _LazyModule managing the import structure above
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
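The effect of the `_LazyModule` replacement is that importing `transformers.models.rag` stays cheap: the heavy `modeling_rag`/`modeling_tf_rag` files are only imported when one of their names is first accessed. A simplified sketch of the same idea, using a PEP 562 module-level `__getattr__` instead of the real `_LazyModule` internals (meant to live in a package's `__init__.py`):

import importlib

_import_structure = {"configuration_rag": ["RagConfig"]}

# Invert the mapping: attribute name -> submodule that defines it
_attr_to_module = {attr: mod for mod, attrs in _import_structure.items() for attr in attrs}

def __getattr__(name):
    # Called only when `name` is not found normally; imports the submodule lazily
    if name in _attr_to_module:
        module = importlib.import_module(f".{_attr_to_module[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")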

.\models\realm\configuration_realm.py

# coding=utf-8
# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" REALM model configuration."""

from ...configuration_utils import PretrainedConfig  # Base class for pretrained configurations
from ...utils import logging  # Logging utility


logger = logging.get_logger(__name__)  # Logger for the current module

REALM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/realm-cc-news-pretrained-embedder": (
        "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/config.json"
    ),
    "google/realm-cc-news-pretrained-encoder": (
        "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/config.json"
    ),
    "google/realm-cc-news-pretrained-scorer": (
        "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/config.json"
    ),
    "google/realm-cc-news-pretrained-openqa": (
        "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/config.json"
    ),
    "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/config.json",
    "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/config.json",
    "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/config.json",
    "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/config.json",
    # See all REALM models at https://huggingface.co/models?filter=realm
}

class RealmConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of

    1. [`RealmEmbedder`]
    2. [`RealmScorer`]
    3. [`RealmKnowledgeAugEncoder`]
    4. [`RealmRetriever`]
    5. [`RealmReader`]
    6. [`RealmForOpenQA`]

    It is used to instantiate an REALM model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the REALM
    [google/realm-cc-news-pretrained-embedder](https://huggingface.co/google/realm-cc-news-pretrained-embedder)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Example:

    ```
    >>> from transformers import RealmConfig, RealmEmbedder

    >>> # Initializing a REALM realm-cc-news-pretrained-* style configuration
    >>> configuration = RealmConfig()

    >>> # Initializing a model (with random weights) from the google/realm-cc-news-pretrained-embedder style configuration
    >>> model = RealmEmbedder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
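The configuration's `__init__` body is not shown in this excerpt, but `RealmConfig` follows the usual BERT-family parameter names, so a deliberately tiny configuration for quick experiments might look like this (sizes made up for illustration):

from transformers import RealmConfig

config = RealmConfig(
    hidden_size=128,          # must be divisible by num_attention_heads
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
)
print(config.retriever_proj_size)  # projection size used by the embedder/scorer heads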

.\models\realm\modeling_realm.py

# Import the required libraries and modules
import math  # Math utilities
import os    # Operating-system helpers
from dataclasses import dataclass  # dataclass decorator for the output classes
from typing import Optional, Tuple, Union  # Type hints

import torch  # PyTorch
from torch import nn  # Neural-network building blocks
from torch.nn import CrossEntropyLoss  # Cross-entropy loss

# Import project-internal modules and classes
from ...activations import ACT2FN  # Activation-function registry
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    MaskedLMOutput,
    ModelOutput,
)  # Model-output classes
from ...modeling_utils import PreTrainedModel  # Base class for pretrained models
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer  # PyTorch helpers
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings  # Docstring helpers and logging
from .configuration_realm import RealmConfig  # REALM configuration class

# Get the logger for the current module
logger = logging.get_logger(__name__)
# Checkpoint names used in the documentation
_EMBEDDER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-embedder"
_ENCODER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-encoder"
_SCORER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-scorer"
_CONFIG_FOR_DOC = "RealmConfig"

# Archive list of REALM pretrained checkpoints
REALM_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/realm-cc-news-pretrained-embedder",
    "google/realm-cc-news-pretrained-encoder",
    "google/realm-cc-news-pretrained-scorer",
    "google/realm-cc-news-pretrained-openqa",
    "google/realm-orqa-nq-openqa",
    "google/realm-orqa-nq-reader",
    "google/realm-orqa-wq-openqa",
    "google/realm-orqa-wq-reader",
    # See all REALM models at https://huggingface.co/models?filter=realm
]


def load_tf_weights_in_realm(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re  # Regular expressions, used when mapping variable names
        import numpy as np  # NumPy
        import tensorflow as tf  # TensorFlow
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    # Resolve the absolute path of the TensorFlow checkpoint
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load the weights from the TF checkpoint
    init_vars = tf.train.list_variables(tf_path)  # List the variables in the checkpoint
    names = []
    arrays = []

    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)  # Load one variable from the checkpoint
        names.append(name)
        arrays.append(array)
    # ... (the name-mapping and weight-assignment loop is elided in this excerpt) ...
    return model
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->Realm
class RealmEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        # Word embeddings: vocabulary size x hidden size, with a padding index
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # Position embeddings: maximum number of positions x hidden size
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # Token-type embeddings: type vocabulary size x hidden size
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased so the name matches the TensorFlow model
        # variables and any TensorFlow checkpoint file can be loaded
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Dropout with the configured hidden-layer dropout probability
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Default to absolute position embeddings
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # Buffer holding the position ids (1, max_position_embeddings); contiguous in
        # memory and exported when the model is serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        # Buffer holding the token-type ids, zeros with the same shape as position_ids
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        # Derive the input shape from input_ids if given, otherwise from inputs_embeds
        # (dropping the trailing hidden dimension)
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        # The second dimension of input_shape is the sequence length
        seq_length = input_shape[1]

        # If no position ids were passed, slice them out of the registered buffer
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # If no token-type ids were passed, use the buffer registered in the constructor
        # (all zeros); this helps users when tracing the model without this argument
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # Slice the buffered token-type ids and expand them to the batch size
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                # Otherwise create a zero tensor with the input shape
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        # Embed the input ids if no precomputed embeddings were passed
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        # Look up the token-type embeddings
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # Sum the word and token-type embeddings
        embeddings = inputs_embeds + token_type_embeddings

        # Add position embeddings when using absolute position encoding
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings

        # Normalize with LayerNorm, apply dropout, and return the final embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
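A quick shape check of this module in isolation (a sketch with tiny made-up sizes, assuming a transformers version that still ships REALM):

import torch
from transformers import RealmConfig
from transformers.models.realm.modeling_realm import RealmEmbeddings

config = RealmConfig(hidden_size=64, num_hidden_layers=1, num_attention_heads=4, intermediate_size=128)
emb = RealmEmbeddings(config)

input_ids = torch.randint(0, config.vocab_size, (2, 10))  # (batch_size, seq_length)
out = emb(input_ids=input_ids)
print(out.shape)  # torch.Size([2, 10, 64]) -- (batch_size, seq_length, hidden_size)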
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Realm
class RealmSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # The hidden size must be a multiple of the number of attention heads
        # (unless the config defines an explicit embedding_size)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Linear projections producing the queries, keys, and values
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # Dropout applied to the attention probabilities
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        # For relative position embeddings, initialize the distance embedding table
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    # Reshape the tensor so attention scores can be computed per head
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    # Forward pass of RealmSelfAttention (the body is elided in this excerpt)
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        ...
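The `view` + `permute` in `transpose_for_scores` simply splits the hidden dimension into heads and moves the head axis in front of the sequence axis; a standalone illustration with made-up sizes:

import torch

batch, seq_len, num_heads, head_size = 2, 5, 4, 8
x = torch.randn(batch, seq_len, num_heads * head_size)  # (2, 5, 32)

x = x.view(batch, seq_len, num_heads, head_size)        # (2, 5, 4, 8)
x = x.permute(0, 2, 1, 3)                               # (2, 4, 5, 8)
print(x.shape)  # (batch, num_heads, seq_len, head_size) -- ready for per-head scores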
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Realm
class RealmSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Linear layer transforming the attention output
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Layer normalization
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Dropout
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # Linear projection, dropout, then a residual connection followed by LayerNorm
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Realm
class RealmAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # Self-attention block
        self.self = RealmSelfAttention(config, position_embedding_type=position_embedding_type)
        # Output block applied on top of the self-attention result
        self.output = RealmSelfOutput(config)
        # Set of attention-head indices that have already been pruned
        self.pruned_heads = set()

    # Prune the given attention heads
    def prune_heads(self, heads):
        # Nothing to do for an empty list
        if len(heads) == 0:
            return
        # Compute the prunable heads and the corresponding weight indices
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune the query, key, and value projections, and the output dense layer
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update the head count and total head size, and record the pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # Run self-attention
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # Feed the attention result through the output block (with residual connection)
        attention_output = self.output(self_outputs[0], hidden_states)
        # Assemble the outputs; attention weights are appended if they were requested
        outputs = (attention_output,) + self_outputs[1:]
        return outputs
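Head pruning shrinks the query/key/value projections in place. A sketch of exercising `prune_heads` on a standalone `RealmAttention` (tiny assumed config, same caveat about REALM availability):

from transformers import RealmConfig
from transformers.models.realm.modeling_realm import RealmAttention

config = RealmConfig(hidden_size=64, num_hidden_layers=1, num_attention_heads=4, intermediate_size=128)
attn = RealmAttention(config)
print(attn.self.query.weight.shape)   # torch.Size([64, 64]) -- 4 heads x 16 dims each

attn.prune_heads([0, 2])              # drop two of the four heads
print(attn.self.num_attention_heads)  # 2
print(attn.self.query.weight.shape)   # torch.Size([32, 64])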


# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Realm
class RealmIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dense layer expanding hidden_size to intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Resolve the activation: look it up in ACT2FN if given as a string,
        # otherwise use the callable directly
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Dense projection followed by the activation function
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Realm
class RealmOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dense layer projecting intermediate_size back down to hidden_size
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # LayerNorm over hidden_size with the configured epsilon
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Dropout with the configured probability
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # Dense projection, dropout, then a residual connection followed by LayerNorm
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Realm
class RealmLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Chunk size used for the chunked feed-forward pass
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Dimension of the sequence length in the input tensors (dim 1)
        self.seq_len_dim = 1
        # Attention block
        self.attention = RealmAttention(config)
        # Whether the layer is used as a decoder
        self.is_decoder = config.is_decoder
        # Whether cross-attention is added
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # Cross-attention only makes sense in a decoder
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            # Cross-attention block with absolute position embeddings
            self.crossattention = RealmAttention(config, position_embedding_type="absolute")
        # Intermediate (feed-forward expansion) block
        self.intermediate = RealmIntermediate(config)
        # Output (feed-forward projection) block
        self.output = RealmOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # If a past key/value cache exists, its positions 1 and 2 hold the cached
        # keys/values of the decoder's unidirectional self-attention
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        # Run self-attention
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )

        # First element is the attention output
        attention_output = self_attention_outputs[0]

        # In a decoder, the last output element is the self-attention cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]  # everything except the cached key/value pair
            present_key_value = self_attention_outputs[-1]  # the current key/value pair
        else:
            outputs = self_attention_outputs[1:]  # attention weights, if they were requested

        # Cross-attention key/value pair, initialized to None
        cross_attn_present_key_value = None

        # In a decoder that receives encoder hidden states, run cross-attention
        if self.is_decoder and encoder_hidden_states is not None:
            # The layer must have been built with cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # Positions 3 and 4 of the past key/value cache hold the cross-attention cache
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None

            # Run cross-attention
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )

            # Take the cross-attention output and append its attention weights
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]

            # Append the cross-attention key/value pair to the present cache (positions 3, 4)
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # Apply the feed-forward part in chunks along the sequence dimension
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # In a decoder, return the key/value cache as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # One chunk of the feed-forward pass: intermediate expansion, then output projection
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
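`apply_chunking_to_forward` splits the input along `seq_len_dim`, runs the feed-forward chunk by chunk, and concatenates the results, trading peak memory for extra kernel launches; the output is identical to the unchunked pass. A self-contained check (pure torch plus the helper, not the Realm layer itself):

import torch
from transformers.pytorch_utils import apply_chunking_to_forward

linear = torch.nn.Linear(16, 16)

def ff(hidden_states):
    return linear(hidden_states)

x = torch.randn(2, 12, 16)  # (batch, seq_len, hidden)

full = ff(x)
chunked = apply_chunking_to_forward(ff, 4, 1, x)  # chunk_size=4 along dim 1 (seq_len)
print(torch.allclose(full, chunked, atol=1e-6))   # True -- same result, lower peak memory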
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Realm
class RealmEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Keep the configuration
        self.config = config
        # Stack of RealmLayer modules, one per configured hidden layer
        self.layer = nn.ModuleList([RealmLayer(config) for _ in range(config.num_hidden_layers)])
        # Gradient checkpointing is disabled by default
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        # Accumulators: empty tuples when the corresponding outputs are requested, otherwise None
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        # Gradient checkpointing is incompatible with caching; warn once and disable the cache
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Cache of decoder key/value pairs, only kept when caching is enabled
        next_decoder_cache = () if use_cache else None
        # Iterate over the Transformer layers
        for i, layer_module in enumerate(self.layer):
            # Record the hidden states entering this layer, if requested
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Per-layer head mask and past key/value pair, if provided
            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Run the layer under gradient checkpointing to reduce memory usage
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                # Otherwise call the layer directly
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            # The first element of the layer outputs is the new hidden states
            hidden_states = layer_outputs[0]
            # When caching, the last element is the layer's key/value pair
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            # Collect self-attention (and, if configured, cross-attention) weights
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # Record the final hidden states, if requested
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Without return_dict, return a tuple of the non-None results
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        # Otherwise wrap everything in a BaseModelOutputWithPastAndCrossAttentions
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
class RealmPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dense layer mapping hidden_size to hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Tanh activation
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Pool by taking the hidden state of the first token, projecting it,
        # and applying the activation
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


@dataclass
class RealmEmbedderOutput(ModelOutput):
    """
    Outputs of [`RealmEmbedder`] models.

    Args:
        projected_score (`torch.FloatTensor` of shape `(batch_size, config.retriever_proj_size)`):
            Projected score.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    projected_score: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class RealmScorerOutput(ModelOutput):
    """
    Outputs of [`RealmScorer`] models.

    Args:
        relevance_score (`torch.FloatTensor` of shape `(batch_size, config.num_candidates)`):
            The relevance score of document candidates (before softmax).
        query_score (`torch.FloatTensor` of shape `(batch_size, config.retriever_proj_size)`):
            Query score derived from the query embedder.
        candidate_score (`torch.FloatTensor` of shape `(batch_size, config.num_candidates, config.retriever_proj_size)`):
            Candidate score derived from the embedder.
    """

    relevance_score: torch.FloatTensor = None
    query_score: torch.FloatTensor = None
    candidate_score: torch.FloatTensor = None


@dataclass
class RealmReaderOutput(ModelOutput):
    """
    Outputs of [`RealmReader`] models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
            Total loss.
        retriever_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
            Retriever loss.
        reader_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
            Reader loss.
        retriever_correct (`torch.BoolTensor` of shape `(config.searcher_beam_size,)`, *optional*):
            Whether or not the retriever correctly detects evidence blocks that contain the answer.
        reader_correct (`torch.BoolTensor` of shape `(config.reader_beam_size, num_candidates)`, *optional*):
            Whether or not the reader correctly detects span candidates that contain the answer.
        block_idx (`torch.LongTensor` of shape `()`):
            The index of the retrieved evidence block in which the predicted answer is most likely.
        candidate (`torch.LongTensor` of shape `()`):
            The index of the retrieved span candidate in which the predicted answer is most likely.
        start_pos (`torch.IntTensor` of shape `()`):
            Predicted answer starting position in *RealmReader*'s inputs.
        end_pos (`torch.IntTensor` of shape `()`):
            Predicted answer ending position in *RealmReader*'s inputs.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Hidden states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # All fields default to None; they carry the losses, correctness flags,
    # and predicted positions described above
    loss: torch.FloatTensor = None
    retriever_loss: torch.FloatTensor = None
    reader_loss: torch.FloatTensor = None
    retriever_correct: torch.BoolTensor = None
    reader_correct: torch.BoolTensor = None
    block_idx: torch.LongTensor = None
    candidate: torch.LongTensor = None
    start_pos: torch.int32 = None
    end_pos: torch.int32 = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class RealmForOpenQAOutput(ModelOutput):
    """

    Outputs of [`RealmForOpenQA`] models.

    Args:
        reader_output (`dict`):
            Reader output.
        predicted_answer_ids (`torch.LongTensor` of shape `(answer_sequence_length)`):
            Predicted answer ids.
    """

    reader_output: dict = None  # The reader model's output, as a dictionary
    predicted_answer_ids: torch.LongTensor = None  # The predicted answer ids, as a long tensor


class RealmPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # Dense projection
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]  # Activation resolved from its name
        else:
            self.transform_act_fn = config.hidden_act  # Activation given as a callable
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # LayerNorm

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)  # Dense projection
        hidden_states = self.transform_act_fn(hidden_states)  # Activation
        hidden_states = self.LayerNorm(hidden_states)  # LayerNorm
        return hidden_states


class RealmLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = RealmPredictionHeadTransform(config)  # Prediction-head transform

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # Projection onto the vocabulary

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))  # Output-only bias

        # Need a link between the two so the bias is resized correctly with the token embeddings
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)  # Apply the transform
        hidden_states = self.decoder(hidden_states)  # Project onto the vocabulary
        return hidden_states


class RealmOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = RealmLMPredictionHead(config)  # MLM prediction head

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)  # Score each position over the vocabulary
        return prediction_scores


class RealmScorerProjection(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = RealmLMPredictionHead(config)  # LM prediction head
        self.dense = nn.Linear(config.hidden_size, config.retriever_proj_size)  # Projection to the retriever size
        self.LayerNorm = nn.LayerNorm(config.retriever_proj_size, eps=config.layer_norm_eps)  # LayerNorm

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)  # Project
        hidden_states = self.LayerNorm(hidden_states)  # Normalize
        return hidden_states


class RealmReaderProjection(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Keep the configuration
        self.config = config
        # Dense layer mapping hidden_size to span_hidden_size * 2 (start and end projections)
        self.dense_intermediate = nn.Linear(config.hidden_size, config.span_hidden_size * 2)
        # Dense layer mapping span_hidden_size to a single logit
        self.dense_output = nn.Linear(config.span_hidden_size, 1)
        # LayerNorm over span_hidden_size with the reader-specific epsilon
        self.layer_normalization = nn.LayerNorm(config.span_hidden_size, eps=config.reader_layer_norm_eps)
        # ReLU activation
        self.relu = nn.ReLU()

    # Forward pass: takes the hidden states and a block mask, and returns the reader
    # logits together with the candidate start and end positions
    def forward(self, hidden_states, block_mask):
        def span_candidates(masks):
            """
            Generate span candidates.

            Args:
                masks: <bool> [num_retrievals, max_sequence_len]

            Returns:
                starts: <int32> [num_spans] ends: <int32> [num_spans] span_masks: <int32> [num_retrievals, num_spans]
                whether spans locate in evidence block.
            """
            # Extract the maximum sequence length from the mask shape
            _, max_sequence_len = masks.shape

            # Enumerate all spans of a given width
            def _spans_given_width(width):
                current_starts = torch.arange(max_sequence_len - width + 1, device=masks.device)
                current_ends = torch.arange(width - 1, max_sequence_len, device=masks.device)
                return current_starts, current_ends

            # Collect the start/end lists for every width up to max_span_width
            starts, ends = zip(*(_spans_given_width(w + 1) for w in range(self.config.max_span_width)))

            # Concatenate the lists into single tensors [num_spans]
            starts = torch.cat(starts, 0)
            ends = torch.cat(ends, 0)

            # Index the masks at the start and end positions [num_retrievals, num_spans];
            # a span is valid only when both its endpoints lie inside the evidence block
            start_masks = torch.index_select(masks, dim=-1, index=starts)
            end_masks = torch.index_select(masks, dim=-1, index=ends)
            span_masks = start_masks * end_masks

            return starts, ends, span_masks

        # Turn a 0/1 mask into an additive score that pushes invalid candidates to -inf
        def mask_to_score(mask, dtype=torch.float32):
            return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min

        # Project the hidden states [reader_beam_size, max_sequence_len, span_hidden_size * 2]
        hidden_states = self.dense_intermediate(hidden_states)
        # Split into start and end projections [reader_beam_size, max_sequence_len, span_hidden_size]
        start_projection, end_projection = hidden_states.chunk(2, dim=-1)

        # Generate the span candidates and their validity mask
        candidate_starts, candidate_ends, candidate_mask = span_candidates(block_mask)

        # Gather the projections at the candidate start/end positions and sum them
        # [reader_beam_size, num_candidates, span_hidden_size]
        candidate_start_projections = torch.index_select(start_projection, dim=1, index=candidate_starts)
        candidate_end_projections = torch.index_select(end_projection, dim=1, index=candidate_ends)
        candidate_hidden = candidate_start_projections + candidate_end_projections

        # ReLU, then LayerNorm [reader_beam_size, num_candidates, span_hidden_size]
        candidate_hidden = self.relu(candidate_hidden)
        candidate_hidden = self.layer_normalization(candidate_hidden)
        # Compute one logit per candidate and squeeze the trailing dimension
        # [reader_beam_size, num_candidates]
        reader_logits = self.dense_output(candidate_hidden).squeeze(-1)
        # Add the mask-derived scores so invalid candidates cannot win
        reader_logits += mask_to_score(candidate_mask, dtype=reader_logits.dtype)

        return reader_logits, candidate_starts, candidate_ends
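The span enumeration is easiest to see on a toy mask. A standalone re-implementation of the same logic (hypothetical sizes, independent of the Realm classes):

import torch

max_sequence_len, max_span_width = 5, 2
masks = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.bool)  # one retrieval, 3 valid tokens

starts_list, ends_list = [], []
for width in range(1, max_span_width + 1):
    starts_list.append(torch.arange(max_sequence_len - width + 1))
    ends_list.append(torch.arange(width - 1, max_sequence_len))
starts, ends = torch.cat(starts_list), torch.cat(ends_list)

# A span is valid only if both its start and end tokens fall inside the evidence block
span_masks = masks[:, starts] * masks[:, ends]
print(list(zip(starts.tolist(), ends.tolist())))  # all (start, end) pairs of width 1 and 2
print(span_masks.int().tolist())                  # [[1, 1, 1, 0, 0, 1, 1, 0, 0]]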
# Multi-line docstring stating that the model is a torch.nn.Module subclass; it is
# prepended to the docstrings of the REALM model classes
REALM_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`RealmConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Docstring describing the inputs shared by the REALM forward methods
REALM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert *input_ids* indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class RealmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Configuration class for all REALM models
    config_class = RealmConfig

    # Callable used to load TF-format checkpoints into a Realm model
    load_tf_weights = load_tf_weights_in_realm

    # Prefix naming the base model inside derived models
    base_model_prefix = "realm"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version, which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                # Zero out the bias
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # Normal initialization for embedding weights as well
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Zero out the padding embedding
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm: zero bias, unit weight
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _flatten_inputs(self, *inputs):
        """Flatten inputs' shape to (-1, input_shape[-1])"""
        flattened_inputs = []
        for tensor in inputs:
            if tensor is None:
                # Pass None through unchanged
                flattened_inputs.append(None)
            else:
                input_shape = tensor.shape
                if len(input_shape) > 2:
                    # Collapse all leading dimensions into one: (-1, input_shape[-1])
                    tensor = tensor.view((-1, input_shape[-1]))
                flattened_inputs.append(tensor)

        return flattened_inputs
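In the scorer, candidate tensors arrive with shape `(batch_size, num_candidates, seq_len)`; `_flatten_inputs` collapses the leading dimensions so all candidates run through the encoder as one batch. The core of it is just a `view`:

import torch

batch_size, num_candidates, seq_len = 2, 8, 32
candidate_input_ids = torch.zeros(batch_size, num_candidates, seq_len, dtype=torch.long)

flattened = candidate_input_ids.view(-1, candidate_input_ids.shape[-1])
print(flattened.shape)  # torch.Size([16, 32]) -- (batch_size * num_candidates, seq_len)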


class RealmBertModel(RealmPreTrainedModel):
    """
    Same as the original BertModel but remove docstrings.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        # Embedding, encoder, and (optionally) pooling layers
        self.embeddings = RealmEmbeddings(config)
        self.encoder = RealmEncoder(config)
        self.pooler = RealmPooler(config) if add_pooling_layer else None

        # Weights initialization is mostly managed by other Realm models,
        # but we also have them initialized here to keep a consistency.
        self.post_init()

    def get_input_embeddings(self):
        # The word-embedding layer serves as the input embeddings
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # Replace the word-embedding layer
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # Delegate the pruning of each layer's heads to that layer's attention module
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
    # Forward pass of the model (the body is elided in this excerpt)
    def forward(
        self,
        input_ids=None,                    # Input token ids
        attention_mask=None,               # Attention mask marking which tokens to attend to
        token_type_ids=None,               # Token-type ids distinguishing sentence A from sentence B
        position_ids=None,                 # Position ids for each token
        head_mask=None,                    # Mask enabling/disabling individual attention heads
        inputs_embeds=None,                # Precomputed input embeddings
        encoder_hidden_states=None,        # Hidden states of an encoder (for cross-attention)
        encoder_attention_mask=None,       # Attention mask over the encoder's hidden states
        past_key_values=None,              # Cached key/value pairs for incremental decoding
        use_cache=None,                    # Whether to return a key/value cache
        output_attentions=None,            # Whether to return attention weights
        output_hidden_states=None,         # Whether to return all hidden states
        return_dict=None,                  # Whether to return a ModelOutput instead of a tuple
    ):
        ...
# Attach the start docstring to the RealmEmbedder class
@add_start_docstrings(
    "The embedder of REALM outputting projected score that will be used to calculate relevance score.",
    REALM_START_DOCSTRING,
)
class RealmEmbedder(RealmPreTrainedModel):
    # Keys of weights tied to the input embeddings
    _tied_weights_keys = ["cls.predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        # Backbone BERT-style model and the scorer projection head
        self.realm = RealmBertModel(self.config)
        self.cls = RealmScorerProjection(self.config)
        # Run the usual post-initialization steps
        self.post_init()

    def get_input_embeddings(self):
        # The word-embedding layer serves as the input embeddings
        return self.realm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # Replace the word-embedding layer
        self.realm.embeddings.word_embeddings = value

    # Attach the input docstring and the return-type docstring to the forward method
    @add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=RealmEmbedderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, RealmEmbedderOutput]:
        """
        Forward pass of RealmEmbedder.

        Returns:
            A tuple `(projected_score, hidden_states, attentions)` when `return_dict` is False, otherwise a
            [`RealmEmbedderOutput`] holding `projected_score`, `hidden_states`, and `attentions`.
        """

        # Fall back to the config's default when return_dict is not given
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the RealmBertModel backbone
        realm_outputs = self.realm(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pooled output: [batch_size, hidden_size]
        pooler_output = realm_outputs[1]
        # Projected score via RealmScorerProjection: [batch_size, retriever_proj_size]
        projected_score = self.cls(pooler_output)

        # Return a tuple or a RealmEmbedderOutput depending on return_dict
        if not return_dict:
            return (projected_score,) + realm_outputs[2:4]
        else:
            return RealmEmbedderOutput(
                projected_score=projected_score,
                hidden_states=realm_outputs.hidden_states,
                attentions=realm_outputs.attentions,
            )
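A minimal end-to-end sketch of the embedder (assumes network access to the pretrained checkpoint and a transformers version that ships REALM; the projection size comes from `config.retriever_proj_size`):

import torch
from transformers import AutoTokenizer, RealmEmbedder

tokenizer = AutoTokenizer.from_pretrained("google/realm-cc-news-pretrained-embedder")
embedder = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")

inputs = tokenizer("What is the capital of France?", return_tensors="pt")
with torch.no_grad():
    outputs = embedder(**inputs)
print(outputs.projected_score.shape)  # (batch_size, retriever_proj_size)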
# The scorer of REALM outputs relevance scores for document candidates (before softmax);
# REALM_START_DOCSTRING is the shared docstring prefix defined above
@add_start_docstrings(
    "The scorer of REALM outputting relevance scores representing the score of document candidates (before softmax).",
    REALM_START_DOCSTRING,
)
class RealmScorer(RealmPreTrainedModel):
    r"""
    Args:
        query_embedder ([`RealmEmbedder`]):
            Embedder for input sequences. If not specified, it will use the same embedder as candidate sequences.
    """

    # 初始化方法,接受 config 和可选的 query_embedder 参数
    def __init__(self, config, query_embedder=None):
        super().__init__(config)

        # 创建 RealmEmbedder 对象并赋值给 self.embedder
        self.embedder = RealmEmbedder(self.config)

        # 如果 query_embedder 参数不为 None,则使用该参数作为 query_embedder;否则使用 self.embedder
        self.query_embedder = query_embedder if query_embedder is not None else self.embedder

        # 调用 post_init 方法,用于进一步初始化
        self.post_init()

    # 前向传播方法,使用 add_start_docstrings_to_model_forward 和 replace_return_docstrings 进行文档字符串的处理
    @add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=RealmScorerOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        candidate_input_ids: Optional[torch.LongTensor] = None,
        candidate_attention_mask: Optional[torch.FloatTensor] = None,
        candidate_token_type_ids: Optional[torch.LongTensor] = None,
        candidate_inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
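
# 源码中 forward 的主体此处未收录;其核心是把查询嵌入与候选嵌入做内积得到相关性分数。
# 下面是编者补充的形状示意(张量均为随机假设值,非源码):
import torch

query_score = torch.randn(2, 128)          # [batch_size, retriever_proj_size],假设值
candidate_score = torch.randn(2, 8, 128)   # [batch_size, num_candidates, retriever_proj_size],假设值
relevance_score = torch.einsum("bd,bnd->bn", query_score, candidate_score)
print(relevance_score.shape)               # torch.Size([2, 8])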



# RealmKnowledgeAugEncoder 类的定义,继承自 RealmPreTrainedModel
@add_start_docstrings(
    "The knowledge-augmented encoder of REALM outputting masked language model logits and marginal log-likelihood"
    " loss.",
    REALM_START_DOCSTRING,
)
class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder"]

    # 初始化方法,接受 config 参数
    def __init__(self, config):
        super().__init__(config)
        
        # 创建 RealmBertModel 对象并赋值给 self.realm
        self.realm = RealmBertModel(self.config)
        
        # 创建 RealmOnlyMLMHead 对象并赋值给 self.cls
        self.cls = RealmOnlyMLMHead(self.config)
        
        # 调用 post_init 方法,用于进一步初始化
        self.post_init()

    # 获取输入嵌入层的方法,返回 self.realm.embeddings.word_embeddings
    def get_input_embeddings(self):
        return self.realm.embeddings.word_embeddings

    # 设置输入嵌入层的方法,将 value 赋给 self.realm.embeddings.word_embeddings
    def set_input_embeddings(self, value):
        self.realm.embeddings.word_embeddings = value

    # 获取输出嵌入层的方法,返回 self.cls.predictions.decoder
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    # 设置输出嵌入层的方法,将 new_embeddings 赋给 self.cls.predictions.decoder
    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    # 前向传播方法,使用 add_start_docstrings_to_model_forward 和 replace_return_docstrings 进行文档字符串的处理
    @add_start_docstrings_to_model_forward(
        REALM_INPUTS_DOCSTRING.format("batch_size, num_candidates, sequence_length")
    )
    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    # 定义一个方法 forward,用于模型的前向传播
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,  # 输入的 token IDs,类型为可选的长整型张量
        attention_mask: Optional[torch.FloatTensor] = None,  # 注意力掩码,类型为可选的浮点张量
        token_type_ids: Optional[torch.LongTensor] = None,  # token 类型 IDs,类型为可选的长整型张量
        position_ids: Optional[torch.LongTensor] = None,  # 位置 IDs,类型为可选的长整型张量
        head_mask: Optional[torch.FloatTensor] = None,  # 头部掩码,类型为可选的浮点张量
        inputs_embeds: Optional[torch.FloatTensor] = None,  # 输入嵌入,类型为可选的浮点张量
        relevance_score: Optional[torch.FloatTensor] = None,  # 相关性分数,类型为可选的浮点张量
        labels: Optional[torch.LongTensor] = None,  # 标签,类型为可选的长整型张量
        mlm_mask: Optional[torch.LongTensor] = None,  # MLM 掩码,类型为可选的长整型张量
        output_attentions: Optional[bool] = None,  # 是否输出注意力信息,类型为可选的布尔值
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态信息,类型为可选的布尔值
        return_dict: Optional[bool] = None,  # 是否返回字典形式的输出,类型为可选的布尔值
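
# 该编码器训练时使用边际对数似然(marginal log-likelihood)损失:先对候选文档的相关性分数做
# log-softmax,再与各候选条件下的答案对数概率相加,最后在候选维度上做 logsumexp。
# 下面是编者补充的最小数值示意(张量为假设值,非源码):
import torch
import torch.nn.functional as F

relevance_score = torch.randn(2, 8)    # [batch_size, num_candidates],假设值
answer_logprob = torch.randn(2, 8)     # 各候选条件下答案的(汇总)对数概率,假设值
candidate_logprob = F.log_softmax(relevance_score, dim=-1)
marginal_logprob = torch.logsumexp(candidate_logprob + answer_logprob, dim=-1)
loss = -marginal_logprob.mean()
print(loss)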
# 使用装饰器添加文档字符串到 RealmReader 类,描述其作用为 REALM 的阅读器。
@add_start_docstrings("The reader of REALM.", REALM_START_DOCSTRING)
class RealmReader(RealmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # 初始化函数,继承自 RealmPreTrainedModel,设置类的标签数量
        self.num_labels = config.num_labels

        # 创建 REALM 的 BERT 模型
        self.realm = RealmBertModel(config)
        # 创建仅包含 MLM 头部的模型
        self.cls = RealmOnlyMLMHead(config)
        # 创建用于 Realm 阅读器的投影层
        self.qa_outputs = RealmReaderProjection(config)

        # 执行后续初始化
        self.post_init()

    # 使用装饰器添加文档字符串到 forward 方法,描述其输入和输出
    @add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("reader_beam_size, sequence_length"))
    @replace_return_docstrings(output_type=RealmReaderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        relevance_score: Optional[torch.FloatTensor] = None,
        block_mask: Optional[torch.BoolTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        has_answers: Optional[torch.BoolTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
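
# 阅读器对每个检索块内的候选答案跨度打分,并把该块的相关性(log-softmax 后)加到跨度得分上,
# 最终取全局最高分的跨度。下面是编者补充的选取逻辑示意(张量为假设值,非源码):
import torch

span_logits = torch.randn(8, 16)                          # [reader_beam_size, num_spans],假设值
block_logprob = torch.log_softmax(torch.randn(8), dim=0)  # 各块的相关性对数概率,假设值
total_logits = span_logits + block_logprob.unsqueeze(-1)
best = total_logits.view(-1).argmax().item()
block_idx, span_idx = divmod(best, total_logits.size(1))
print(block_idx, span_idx)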
REALM_FOR_OPEN_QA_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.
            
            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
            
            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            
            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token (should not be used in this model by design).
            
            [What are token type IDs?](../glossary#token-type-ids)
        answer_ids (`list` of shape `(num_answers, answer_length)`, *optional*):
            Answer ids for computing the marginal log-likelihood loss. Indices should be in `[-1, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-1` are ignored (masked), and
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# 为 RealmForOpenQA 添加起始文档字符串:`RealmForOpenQA` 用于端到端的开放域问答,继承自 RealmPreTrainedModel。
@add_start_docstrings(
    "`RealmForOpenQA` for end-to-end open domain question answering.",
    REALM_START_DOCSTRING,
)
class RealmForOpenQA(RealmPreTrainedModel):
    def __init__(self, config, retriever=None):
        """
        初始化方法,用于实例化一个 `RealmForOpenQA` 对象。

        Args:
            config (`PretrainedConfig`): 包含该模型配置信息的配置对象。
            retriever (`Optional`): 用于检索的对象,默认为 `None`。
        """
        super().__init__(config)
        self.embedder = RealmEmbedder(config)  # 实例化一个 `RealmEmbedder` 对象
        self.reader = RealmReader(config)  # 实例化一个 `RealmReader` 对象
        self.register_buffer(
            "block_emb",
            torch.zeros(()).new_empty(
                size=(config.num_block_records, config.retriever_proj_size),
                dtype=torch.float32,
                device=torch.device("cpu"),
            ),
        )
        self.retriever = retriever  # 设置检索器对象

        self.post_init()  # 调用初始化后处理方法

    @property
    def searcher_beam_size(self):
        """
        获取搜索器的 beam size。在训练模式下返回 `config.searcher_beam_size`,
        否则返回 `config.reader_beam_size`。

        Returns:
            `int`: beam size 的大小。
        """
        if self.training:
            return self.config.searcher_beam_size
        return self.config.reader_beam_size

    def block_embedding_to(self, device):
        """
        将 `self.block_emb` 发送到指定的设备。

        Args:
            device (`str` or `torch.device`):
                要发送 `self.block_emb` 的目标设备。
        """
        self.block_emb = self.block_emb.to(device)

    @add_start_docstrings_to_model_forward(REALM_FOR_OPEN_QA_DOCSTRING.format("1, sequence_length"))
    @replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        answer_ids: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        模型的前向传播方法,用于执行推断或训练。

        Args:
            input_ids (`Optional[torch.LongTensor]`):
                输入的 token IDs。
            attention_mask (`Optional[torch.FloatTensor]`, optional):
                注意力掩码张量,默认为 `None`。
            token_type_ids (`Optional[torch.LongTensor]`, optional):
                分段 token IDs,默认为 `None`。
            answer_ids (`Optional[torch.LongTensor]`, optional):
                答案 token IDs,默认为 `None`。
            return_dict (`Optional[bool]`, optional):
                是否返回字典作为输出,默认为 `None`。

        Returns:
            `RealmForOpenQAOutput` 或者是一个字典,包含模型输出的各种信息。
        """

.\models\realm\retrieval_realm.py

# coding=utf-8
# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""REALM Retriever model implementation."""

import os
from typing import Optional, Union

import numpy as np
from huggingface_hub import hf_hub_download

from ... import AutoTokenizer
from ...utils import logging


_REALM_BLOCK_RECORDS_FILENAME = "block_records.npy"


logger = logging.get_logger(__name__)


class ScaNNSearcher:
    """Note that ScaNNSearcher cannot currently be used within the model. In future versions, it might however be included."""

    def __init__(
        self,
        db,
        num_neighbors,
        dimensions_per_block=2,
        num_leaves=1000,
        num_leaves_to_search=100,
        training_sample_size=100000,
    ):
        """Build scann searcher."""
        
        # Import the necessary modules for constructing a SCANN searcher
        from scann.scann_ops.py.scann_ops_pybind import builder as Builder
        
        # Initialize the builder with database and search parameters
        builder = Builder(db=db, num_neighbors=num_neighbors, distance_measure="dot_product")
        
        # Configure the tree parameters
        builder = builder.tree(
            num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=training_sample_size
        )
        
        # Configure scoring parameters
        builder = builder.score_ah(dimensions_per_block=dimensions_per_block)
        
        # Build the searcher object
        self.searcher = builder.build()

    def search_batched(self, question_projection):
        """Perform batched search using the constructed SCANN searcher."""
        
        # Perform batched search and retrieve block IDs
        retrieved_block_ids, _ = self.searcher.search_batched(question_projection.detach().cpu())
        
        # Return retrieved block IDs as int64
        return retrieved_block_ids.astype("int64")
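
# 下面是编者补充的 ScaNNSearcher 假设性用法(需要另行安装 scann 库,参数为演示用的小规模取值,非源码):
import numpy as np
import torch

block_emb = np.random.rand(10000, 128).astype(np.float32)   # 假设的证据块嵌入库
searcher = ScaNNSearcher(
    block_emb, num_neighbors=8, num_leaves=100, num_leaves_to_search=10, training_sample_size=10000
)
question_projection = torch.randn(2, 128)                    # 假设的问题投影向量
retrieved_block_ids = searcher.search_batched(question_projection)
print(retrieved_block_ids.shape)                             # (2, 8)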


class RealmRetriever:
    """The retriever of REALM outputting the retrieved evidence block and whether the block has answers as well as answer
    positions."

        Parameters:
            block_records (`np.ndarray`):
                A numpy array which cantains evidence texts.
            tokenizer ([`RealmTokenizer`]):
                The tokenizer to encode retrieved texts.
    """

    def __init__(self, block_records, tokenizer):
        """Initialize RealmRetriever with block records and tokenizer."""
        
        # Initialize superclass
        super().__init__()
        
        # Store the provided block records
        self.block_records = block_records
        
        # Store the provided tokenizer
        self.tokenizer = tokenizer
    # 定义类的实例方法,用于生成压缩块的输入
    def __call__(self, retrieved_block_ids, question_input_ids, answer_ids, max_length=None, return_tensors="pt"):
        # 从 self.block_records 中按索引提取检索到的块
        retrieved_blocks = np.take(self.block_records, indices=retrieved_block_ids, axis=0)

        # 根据问题输入的 token IDs 解码出文本问题
        question = self.tokenizer.decode(question_input_ids[0], skip_special_tokens=True)

        # 初始化文本列表
        text = []
        text_pair = []

        # 遍历每个检索到的块
        for retrieved_block in retrieved_blocks:
            # 将问题文本添加到 text 列表
            text.append(question)
            # 将检索到的块解码并添加到 text_pair 列表
            text_pair.append(retrieved_block.decode())

        # 使用 tokenizer 处理 text 和 text_pair,进行拼接和填充等预处理
        concat_inputs = self.tokenizer(
            text, text_pair, padding=True, truncation=True, return_special_tokens_mask=True, max_length=max_length
        )

        # 将处理后的输入转换为张量
        concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors)

        # 如果提供了答案 IDs,则调用 block_has_answer 方法计算答案和返回拼接输入的张量
        if answer_ids is not None:
            return self.block_has_answer(concat_inputs, answer_ids) + (concat_inputs_tensors,)
        else:
            # 否则返回空元组和拼接输入的张量
            return (None, None, None, concat_inputs_tensors)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *init_inputs, **kwargs):
        # 如果预训练模型路径是一个目录,则拼接出块记录文件的路径
        if os.path.isdir(pretrained_model_name_or_path):
            block_records_path = os.path.join(pretrained_model_name_or_path, _REALM_BLOCK_RECORDS_FILENAME)
        else:
            # 否则从 Hugging Face Hub 下载模型文件并指定块记录文件名
            block_records_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=_REALM_BLOCK_RECORDS_FILENAME, **kwargs
            )
        # 加载块记录文件为 numpy 数组
        block_records = np.load(block_records_path, allow_pickle=True)

        # 从预训练模型加载 tokenizer
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)

        # 返回当前类的实例,初始化时传入加载的块记录和 tokenizer
        return cls(block_records, tokenizer)

    # 实例方法,用于将块记录和 tokenizer 保存到指定目录
    def save_pretrained(self, save_directory):
        # 保存块记录文件为 numpy 格式
        np.save(os.path.join(save_directory, _REALM_BLOCK_RECORDS_FILENAME), self.block_records)
        # 保存 tokenizer 到指定目录
        self.tokenizer.save_pretrained(save_directory)
    # 检查给定的拼接输入中是否包含答案
    def block_has_answer(self, concat_inputs, answer_ids):
        """check if retrieved_blocks has answers."""
        # 用于存储每个拼接输入是否含有答案的布尔列表
        has_answers = []
        # 用于存储每个拼接输入中所有答案起始位置的列表
        start_pos = []
        # 用于存储每个拼接输入中所有答案结束位置的列表
        end_pos = []
        # 记录每个拼接输入中最多的答案数
        max_answers = 0

        # 遍历每个拼接输入的input_ids
        for input_id in concat_inputs.input_ids:
            # 将input_id转换为Python列表
            input_id_list = input_id.tolist()
            # 查找第一个[SEP]标记的索引位置
            first_sep_idx = input_id_list.index(self.tokenizer.sep_token_id)
            # 查找第二个[SEP]标记的索引位置,限定搜索范围从第一个[SEP]之后开始
            second_sep_idx = first_sep_idx + 1 + input_id_list[first_sep_idx + 1:].index(self.tokenizer.sep_token_id)

            # 初始化存储当前拼接输入答案起始和结束位置的列表
            start_pos.append([])
            end_pos.append([])
            # 遍历每个答案id列表中的答案
            for answer in answer_ids:
                # 在第一个和第二个[SEP]之间查找答案的起始位置
                for idx in range(first_sep_idx + 1, second_sep_idx):
                    if answer[0] == input_id_list[idx]:
                        # 检查是否在当前位置开始的连续序列与答案匹配
                        if input_id_list[idx: idx + len(answer)] == answer:
                            # 将找到的答案起始和结束位置添加到列表中
                            start_pos[-1].append(idx)
                            end_pos[-1].append(idx + len(answer) - 1)

            # 如果当前拼接输入没有找到答案,则记录为False,否则记录为True
            if len(start_pos[-1]) == 0:
                has_answers.append(False)
            else:
                has_answers.append(True)
                # 更新当前拼接输入中最大答案数量
                if len(start_pos[-1]) > max_answers:
                    max_answers = len(start_pos[-1])

        # 对于没有答案的拼接输入,在start_pos和end_pos中填充-1以对齐最大答案数量
        for start_pos_, end_pos_ in zip(start_pos, end_pos):
            if len(start_pos_) < max_answers:
                padded = [-1] * (max_answers - len(start_pos_))
                start_pos_ += padded
                end_pos_ += padded

        # 返回结果:每个拼接输入是否含有答案的布尔列表,每个拼接输入中答案起始和结束位置的列表
        return has_answers, start_pos, end_pos
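
# 下面是编者补充的简化演示,说明上述在两个 [SEP] 之间匹配答案 token 序列的逻辑(token ID 为假设值,非源码):
input_ids = [101, 7, 8, 102, 5, 6, 5, 6, 102]  # 假设 101=[CLS], 102=[SEP]
answer = [5, 6]
first_sep = input_ids.index(102)
second_sep = first_sep + 1 + input_ids[first_sep + 1:].index(102)
starts = [
    idx
    for idx in range(first_sep + 1, second_sep)
    if input_ids[idx : idx + len(answer)] == answer
]
print(starts)  # [4, 6],对应的结束位置为 [5, 7]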

.\models\realm\tokenization_realm.py

# coding=utf-8
# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for REALM."""

# Import necessary libraries
import collections  # 导入 collections 模块
import os  # 导入 os 模块
import unicodedata  # 导入 unicodedata 模块
from typing import List, Optional, Tuple  # 导入类型提示相关的模块

# Import from tokenization_utils
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
# Import from tokenization_utils_base
from ...tokenization_utils_base import BatchEncoding
# Import logging from utils
from ...utils import PaddingStrategy, logging

# Get logger instance for current module
logger = logging.get_logger(__name__)

# Define constant for vocabulary file names
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# Define mapping of pretrained model names to their respective vocabulary file URLs
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/realm-cc-news-pretrained-embedder": (
            "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt"
        ),
        "google/realm-cc-news-pretrained-encoder": (
            "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt"
        ),
        "google/realm-cc-news-pretrained-scorer": (
            "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt"
        ),
        "google/realm-cc-news-pretrained-openqa": (
            "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt"
        ),
        "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
        "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
        "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
        "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
    }
}

# Define sizes of positional embeddings for different pretrained models
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/realm-cc-news-pretrained-embedder": 512,
    "google/realm-cc-news-pretrained-encoder": 512,
    "google/realm-cc-news-pretrained-scorer": 512,
    "google/realm-cc-news-pretrained-openqa": 512,
    "google/realm-orqa-nq-openqa": 512,
    "google/realm-orqa-nq-reader": 512,
    "google/realm-orqa-wq-openqa": 512,
    "google/realm-orqa-wq-reader": 512,
}

# Define initial configurations for different pretrained models
PRETRAINED_INIT_CONFIGURATION = {
    "google/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
    "google/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
    "google/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
    "google/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
}
    # 定义一个字典,包含多个键值对,每个键是字符串,对应的值是一个字典,具有一个布尔型键"do_lower_case",其值为True
    "google/realm-orqa-nq-openqa": {"do_lower_case": True},
    "google/realm-orqa-nq-reader": {"do_lower_case": True},
    "google/realm-orqa-wq-openqa": {"do_lower_case": True},
    "google/realm-orqa-wq-reader": {"do_lower_case": True},
}

# 定义一个函数 load_vocab,用于加载一个词汇文件到一个有序字典中
def load_vocab(vocab_file):
    vocab = collections.OrderedDict()  # 创建一个有序字典对象 vocab
    with open(vocab_file, "r", encoding="utf-8") as reader:  # 打开词汇文件以读取模式,并指定编码为 utf-8
        tokens = reader.readlines()  # 读取文件的所有行并存储在 tokens 列表中
    for index, token in enumerate(tokens):  # 遍历 tokens 列表的索引和元素
        token = token.rstrip("\n")  # 去掉 token 末尾的换行符
        vocab[token] = index  # 将 token 和其索引添加到 vocab 字典中
    return vocab  # 返回加载完成的词汇字典

# 定义一个函数 whitespace_tokenize,用于对文本进行基本的空白符清理和分割
def whitespace_tokenize(text):
    text = text.strip()  # 去除文本两端的空白符
    if not text:  # 如果文本为空
        return []  # 返回空列表
    tokens = text.split()  # 使用空白符对文本进行分割,并存储结果在 tokens 列表中
    return tokens  # 返回分割后的 tokens 列表

# 定义一个类 RealmTokenizer,继承自 PreTrainedTokenizer 类
class RealmTokenizer(PreTrainedTokenizer):
    r"""
    Construct a REALM tokenizer.

    [`RealmTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation splitting and
    wordpiece.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.
    """
    # 定义类的初始化方法,用于初始化一个新的Tokenizer对象
    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs,
    ):
        if not os.path.isfile(vocab_file):
            # 如果给定的词汇文件不存在,则抛出数值错误异常
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = RealmTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        # 加载给定路径下的词汇表文件并存储到实例变量 self.vocab 中
        self.vocab = load_vocab(vocab_file)
        # 使用 collections.OrderedDict 创建从词汇 ID 到词汇的有序映射
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        # 根据参数决定是否执行基础分词
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            # 如果需要进行基础分词,则创建 BasicTokenizer 对象
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        # 使用给定的词汇表和未知标记创建 WordpieceTokenizer 对象
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
        # 调用父类构造函数,初始化实例
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

    @property
    def do_lower_case(self):
        # 返回基础分词器的小写标志
        return self.basic_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        # 返回词汇表的大小(词汇数量)
        return len(self.vocab)

    def get_vocab(self):
        # 返回包含词汇表及其附加标记编码器的字典
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        # 对输入文本进行分词处理,返回分词后的 token 列表
        split_tokens = []
        if self.do_basic_tokenize:
            # 如果需要进行基础分词
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
                # 如果 token 在 never_split 集合中,则直接添加到分词结果列表中
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    # 否则,使用 WordpieceTokenizer 对 token 进行进一步分词,并添加到分词结果列表中
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            # 如果不需要进行基础分词,则直接使用 WordpieceTokenizer 对整个文本进行分词处理
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # 根据词汇表将 token 转换为对应的 ID,如果未找到,则使用未知标记的 ID
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        # 根据词汇表将索引转换为对应的 token,如果索引未找到,则使用未知标记
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # 将 token 列表转换为单个字符串,去除连字符("##"),并去除两端空白
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
        adding special tokens. A REALM sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of input IDs with the appropriate special tokens.
        """
        # If only one sequence is provided, add `[CLS]`, the sequence tokens, and `[SEP]`
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        
        # For sequence pairs, construct `[CLS]`, tokens of first sequence, `[SEP]`, tokens of second sequence, and final `[SEP]`
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep
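
# 作为直观演示,编者补充下面的小例子(token ID 为假设值,非源码):
cls, sep = [101], [102]                    # 假设 cls_token_id=101, sep_token_id=102
token_ids_0, token_ids_1 = [7, 8], [5, 6]
print(cls + token_ids_0 + sep)                      # 单序列: [101, 7, 8, 102]
print(cls + token_ids_0 + sep + token_ids_1 + sep)  # 序列对: [101, 7, 8, 102, 5, 6, 102]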

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        # If the token list already has special tokens, delegate to the base class method
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # Construct a special tokens mask for sequences without existing special tokens
        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A REALM sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs for the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        # Define special tokens for separation and classification
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        
        # If only one sequence is provided, return a mask with zeros for its length
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        
        # For sequence pairs, concatenate tokens with special tokens and create the mask
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
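
# 对应的 token_type_ids 演示(编者补充,token ID 为假设值,非源码):
cls, sep = [101], [102]
token_ids_0, token_ids_1 = [7, 8], [5, 6]
print(len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1])
# 输出: [0, 0, 0, 0, 1, 1, 1]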

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Initialize index for vocabulary items
        index = 0
        
        # Determine the vocabulary file path based on whether save_directory is a directory or a file path
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        
        # Write the vocabulary items to the specified file
        with open(vocab_file, "w", encoding="utf-8") as writer:
            # Iterate over sorted vocabulary items and write them to the file
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                # Check for non-consecutive indices and log a warning if found
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!"
                    )
                    index = token_index
                # Write the token to the file followed by a newline
                writer.write(token + "\n")
                index += 1
        
        # Return the path to the saved vocabulary file
        return (vocab_file,)
# 定义一个名为 BasicTokenizer 的类,用于执行基本的分词操作(如标点符号分割、转换为小写等)。
class BasicTokenizer(object):

    """
    构造一个 BasicTokenizer 实例,用于运行基本的分词操作(如标点符号分割、转换为小写等)。

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            是否在分词时将输入转换为小写。
        never_split (`Iterable`, *optional*):
            在分词过程中永远不会被拆分的 token 集合。仅在 `do_basic_tokenize=True` 时有效。
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            是否分词中文字符。
            
            对于日语,这应该被禁用(参见此 [issue](https://github.com/huggingface/transformers/issues/328))。
        strip_accents (`bool`, *optional*):
            是否去除所有的重音符号。如果未指定此选项,则会根据 `lowercase` 的值(与原始 BERT 相同)来确定。
    """

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
        # 如果 never_split 参数为 None,则将其设置为空列表
        if never_split is None:
            never_split = []
        # 设置是否将输入转换为小写
        self.do_lower_case = do_lower_case
        # 将 never_split 转换为集合,这些 token 在分词时不会被拆分
        self.never_split = set(never_split)
        # 是否分词中文字符
        self.tokenize_chinese_chars = tokenize_chinese_chars
        # 是否去除重音符号
        self.strip_accents = strip_accents
    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
        WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*)
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of token not to split.
        """
        # 如果提供了 never_split 参数,则将其与 self.never_split 取并集,否则使用 self.never_split
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        # 清理文本,处理特殊字符等
        text = self._clean_text(text)

        # 以下内容于2018年11月1日添加,用于多语言和中文模型。
        # 现在也应用于英语模型,但这并不重要,因为英语模型没有在任何中文数据上训练,
        # 通常不包含任何中文数据(尽管词汇表中有些中文词汇,因为英文维基百科中有一些中文词汇)。
        if self.tokenize_chinese_chars:
            # 对中文字符进行特殊处理
            text = self._tokenize_chinese_chars(text)
        # 将文本按空白符分割为原始token
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    # 如果设置为小写,则将token转换为小写
                    token = token.lower()
                    if self.strip_accents is not False:
                        # 如果需要去除重音符号,则执行去除重音符号操作
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    # 如果需要去除重音符号,则执行去除重音符号操作
                    token = self._run_strip_accents(token)
            # 将token根据标点符号进行分割,并扩展到split_tokens中
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        # 将分割后的token重新按空白符合并,并返回
        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens
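
# 下面是编者补充的 BasicTokenizer 行为演示(假设已定义上文的 BasicTokenizer,非源码):
tokenizer = BasicTokenizer(do_lower_case=True)
print(tokenizer.tokenize("Hello, REALM!"))  # ['hello', ',', 'realm', '!']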

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        # 将文本中的重音符号规范化为NFD形式
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            # 如果字符类别为Mn(非间距连字符),则跳过该字符
            if cat == "Mn":
                continue
            output.append(char)
        # 将处理后的字符列表连接成字符串并返回
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        # 如果指定了never_split,并且text在never_split中,则不分割,直接返回
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            # 如果是标点符号,则将其作为一个新的列表项添加到output中
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                # 如果不是标点符号,根据start_new_word标志判断是否创建新的列表项
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        # 将分割后的字符列表重新连接成字符串并返回
        return ["".join(x) for x in output]
    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)  # 获取字符的 Unicode 码点
            if self._is_chinese_char(cp):  # 如果字符是中日韩字符,则在其前后添加空格
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)  # 如果不是中日韩字符,则直接添加字符
        return "".join(output)  # 将处理后的字符列表连接成字符串并返回

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)  # CJK 统一汉字
            or (cp >= 0x3400 and cp <= 0x4DBF)  # CJK 扩展A
            or (cp >= 0x20000 and cp <= 0x2A6DF)  # CJK 扩展B
            or (cp >= 0x2A700 and cp <= 0x2B73F)  # CJK 扩展C
            or (cp >= 0x2B740 and cp <= 0x2B81F)  # CJK 扩展D
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  # CJK 扩展E
            or (cp >= 0xF900 and cp <= 0xFAFF)  # CJK 兼容汉字
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  # CJK 兼容表意文字
        ):  # 判断 Unicode 码点是否在中日韩字符范围内
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)  # 获取字符的 Unicode 码点
            if cp == 0 or cp == 0xFFFD or _is_control(char):  # 如果字符是无效字符或控制字符,则跳过
                continue
            if _is_whitespace(char):  # 如果字符是空白字符,则替换为单个空格
                output.append(" ")
            else:
                output.append(char)  # 否则保留字符
        return "".join(output)  # 将处理后的字符列表连接成字符串并返回
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        # 初始化 WordpieceTokenizer 类,设置词汇表、未知标记和每个单词最大字符数
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens.
        """
        # 初始化输出的 token 列表
        output_tokens = []
        # 对文本进行分词处理,以空白字符为分隔符
        for token in whitespace_tokenize(text):
            chars = list(token)
            # 如果单词长度超过设定的最大字符数,将其视为未知标记
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                # 采用贪婪算法寻找最长匹配的子串
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    # 检查子串是否在词汇表中
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            # 如果无法成功分词,则添加未知标记;否则添加分词结果到输出列表
            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
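
# 下面是编者补充的 WordPiece 贪心最长匹配演示(使用一个玩具词表,假设已定义上文的 WordpieceTokenizer,非源码):
vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
wp = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
print(wp.tokenize("unaffable"))  # ['un', '##aff', '##able']
print(wp.tokenize("xyz"))        # ['[UNK]'],无法匹配时整词回退为未知标记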

.\models\realm\tokenization_realm_fast.py

# coding=utf-8
# 版权 2022 年 REALM 作者和 HuggingFace Inc. 团队所有。
#
# 根据 Apache 许可证 2.0 版本(“许可证”)获得许可;
# 除非符合许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,软件根据“原样”分发,
# 没有任何明示或暗示的担保或条件。
# 有关详细信息,请参阅许可证。
"""REALM 的快速分词类。"""

import json
from typing import List, Optional, Tuple

from tokenizers import normalizers  # 导入 tokenizers 包中的 normalizers 模块

from ...tokenization_utils_base import BatchEncoding  # 导入 BatchEncoding 类
from ...tokenization_utils_fast import PreTrainedTokenizerFast  # 导入 PreTrainedTokenizerFast 类
from ...utils import PaddingStrategy, logging  # 导入 PaddingStrategy 和 logging 类
from .tokenization_realm import RealmTokenizer  # 从当前目录导入 RealmTokenizer 类

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}  # 定义 VOCAB_FILES_NAMES 字典

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/realm-cc-news-pretrained-embedder": (
            "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt"
        ),
        "google/realm-cc-news-pretrained-encoder": (
            "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt"
        ),
        "google/realm-cc-news-pretrained-scorer": (
            "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt"
        ),
        "google/realm-cc-news-pretrained-openqa": (
            "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt"
        ),
        "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
        "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
        "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
        "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
    },
    # "tokenizer_file" 键,存储多个模型名称和对应的 tokenizer 文件的 URL
    "tokenizer_file": {
        # 模型 google/realm-cc-news-pretrained-embedder 的 tokenizer 文件 URL
        "google/realm-cc-news-pretrained-embedder": (
            "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont"
        ),
        # 模型 google/realm-cc-news-pretrained-encoder 的 tokenizer 文件 URL
        "google/realm-cc-news-pretrained-encoder": (
            "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json"
        ),
        # 模型 google/realm-cc-news-pretrained-scorer 的 tokenizer 文件 URL
        "google/realm-cc-news-pretrained-scorer": (
            "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json"
        ),
        # 模型 google/realm-cc-news-pretrained-openqa 的 tokenizer 文件 URL
        "google/realm-cc-news-pretrained-openqa": (
            "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json"
        ),
        # 模型 google/realm-orqa-nq-openqa 的 tokenizer 文件 URL
        "google/realm-orqa-nq-openqa": (
            "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/tokenizer.json"
        ),
        # 模型 google/realm-orqa-nq-reader 的 tokenizer 文件 URL
        "google/realm-orqa-nq-reader": (
            "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/tokenizer.json"
        ),
        # 模型 google/realm-orqa-wq-openqa 的 tokenizer 文件 URL
        "google/realm-orqa-wq-openqa": (
            "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/tokenizer.json"
        ),
        # 模型 google/realm-orqa-wq-reader 的 tokenizer 文件 URL
        "google/realm-orqa-wq-reader": (
            "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/tokenizer.json"
        ),
    },
}

# 定义预训练模型的位置嵌入大小字典,每个模型名称映射到其对应的位置嵌入大小(均为512)
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/realm-cc-news-pretrained-embedder": 512,
    "google/realm-cc-news-pretrained-encoder": 512,
    "google/realm-cc-news-pretrained-scorer": 512,
    "google/realm-cc-news-pretrained-openqa": 512,
    "google/realm-orqa-nq-openqa": 512,
    "google/realm-orqa-nq-reader": 512,
    "google/realm-orqa-wq-openqa": 512,
    "google/realm-orqa-wq-reader": 512,
}

# 定义预训练模型初始化配置字典,每个模型名称映射到其对应的初始化配置字典,这里只设置了一个通用项 do_lower_case=True
PRETRAINED_INIT_CONFIGURATION = {
    "google/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
    "google/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
    "google/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
    "google/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
    "google/realm-orqa-nq-openqa": {"do_lower_case": True},
    "google/realm-orqa-nq-reader": {"do_lower_case": True},
    "google/realm-orqa-wq-openqa": {"do_lower_case": True},
    "google/realm-orqa-wq-reader": {"do_lower_case": True},
}


class RealmTokenizerFast(PreTrainedTokenizerFast):
    r"""
    Construct a "fast" REALM tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.

    [`RealmTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization: punctuation
    splitting and wordpiece.

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.
    """

    # 词汇文件名的映射(如 vocab.txt、tokenizer.json)
    vocab_files_names = VOCAB_FILES_NAMES
    # 定义预训练模型的词汇文件映射,用于加载不同预训练模型的词汇文件
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # 定义预训练模型的初始化配置,可能包含模型的特定参数配置
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    # 定义预训练模型的最大输入长度限制,通常用于限制输入序列的最大长度
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # 定义慢速分词器类,通常用于特定语言或对分词速度要求不高的场景
    slow_tokenizer_class = RealmTokenizer

    # 初始化函数,用于创建一个新的实例对象
    def __init__(
        self,
        vocab_file=None,  # 词汇文件路径,用于加载模型的词汇
        tokenizer_file=None,  # 分词器文件路径,用于加载保存的分词器模型
        do_lower_case=True,  # 是否将输入文本转为小写
        unk_token="[UNK]",  # 未知标记,用于词汇中未出现的词的表示
        sep_token="[SEP]",  # 分隔符标记,用于组合多个序列的标记
        pad_token="[PAD]",  # 填充标记,用于批处理不同长度的序列
        cls_token="[CLS]",  # 分类器标记,用于序列分类任务的开始标记
        mask_token="[MASK]",  # 掩码标记,用于掩码语言模型训练中的预测
        tokenize_chinese_chars=True,  # 是否分词中文字符
        strip_accents=None,  # 是否去除所有的重音符号
        **kwargs,  # 其他可选参数,用于灵活设置
    ):
        # 调用父类的构造函数初始化对象,设置各种参数
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        # 从后端分词器获取规范化器的状态并反序列化
        normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
        # 检查当前对象的参数是否与规范化器的状态匹配,如果不匹配则更新规范化器
        if (
            normalizer_state.get("lowercase", do_lower_case) != do_lower_case
            or normalizer_state.get("strip_accents", strip_accents) != strip_accents
            or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
        ):
            normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
            normalizer_state["lowercase"] = do_lower_case
            normalizer_state["strip_accents"] = strip_accents
            normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
            self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)

        # 设置对象的小写处理标志位
        self.do_lower_case = do_lower_case

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A REALM sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        # 构建带有特殊标记的模型输入序列
        output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]

        if token_ids_1 is not None:
            output += token_ids_1 + [self.sep_token_id]

        return output

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A REALM sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs for the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        # Define the separator and classification token IDs
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        
        # If token_ids_1 is None, return a mask with zeros for only the first sequence
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        
        # Otherwise, concatenate masks for both sequences where the first sequence has 0s and the second has 1s
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the tokenizer's vocabulary to the specified directory.

        Args:
            save_directory (str):
                Directory path where the vocabulary will be saved.
            filename_prefix (str, *optional*):
                Optional prefix for the saved vocabulary files.

        Returns:
            Tuple[str]: Tuple containing the paths to the saved files.
        """
        # Save the tokenizer's model (vocabulary) to the specified directory with an optional filename prefix
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        
        # Return a tuple of file paths that were saved
        return tuple(files)

.\models\realm\__init__.py

# 版权声明和许可声明,说明该文件的版权归 HuggingFace 团队所有,使用 Apache License 2.0 进行许可
#
# 如果不符合许可协议的规定,除非法律另有要求或书面同意,否则不得使用此文件
from typing import TYPE_CHECKING

# 从 utils 模块中导入 OptionalDependencyNotAvailable 异常类和 LazyModule 类
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available

# 定义要导入的模块结构,包括配置、tokenization、modeling 和 retrieval 的相关内容
_import_structure = {
    "configuration_realm": ["REALM_PRETRAINED_CONFIG_ARCHIVE_MAP", "RealmConfig"],
    "tokenization_realm": ["RealmTokenizer"],
}

# 检查是否有 tokenizers 可用,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,则添加 "tokenization_realm_fast" 到导入结构中,包含 "RealmTokenizerFast"
    _import_structure["tokenization_realm_fast"] = ["RealmTokenizerFast"]

# 检查是否有 torch 可用,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,则添加 "modeling_realm" 和 "retrieval_realm" 到导入结构中,包含相关类和函数
    _import_structure["modeling_realm"] = [
        "REALM_PRETRAINED_MODEL_ARCHIVE_LIST",
        "RealmEmbedder",
        "RealmForOpenQA",
        "RealmKnowledgeAugEncoder",
        "RealmPreTrainedModel",
        "RealmReader",
        "RealmScorer",
        "load_tf_weights_in_realm",
    ]
    _import_structure["retrieval_realm"] = ["RealmRetriever"]

# 如果是类型检查模式,则从相应模块中导入所需内容
if TYPE_CHECKING:
    from .configuration_realm import REALM_PRETRAINED_CONFIG_ARCHIVE_MAP, RealmConfig
    from .tokenization_realm import RealmTokenizer

    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_realm_fast import RealmTokenizerFast

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_realm import (
            REALM_PRETRAINED_MODEL_ARCHIVE_LIST,
            RealmEmbedder,
            RealmForOpenQA,
            RealmKnowledgeAugEncoder,
            RealmPreTrainedModel,
            RealmReader,
            RealmScorer,
            load_tf_weights_in_realm,
        )
        from .retrieval_realm import RealmRetriever

# 如果不是类型检查模式,则将当前模块设置为 LazyModule 的实例,导入 _import_structure 中定义的内容
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\reformer\configuration_reformer.py

# File encoding is UTF-8 so that non-ASCII characters are handled correctly
# Copyright notice and license terms governing the use of this code
# Import the pretrained configuration utilities and the logging helper
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# Get the logger for the current module
logger = logging.get_logger(__name__)

# Map from pretrained model names to the download URLs of their configuration files
REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/reformer-crime-and-punishment": (
        "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/config.json"
    ),
    "google/reformer-enwik8": "https://huggingface.co/google/reformer-enwik8/resolve/main/config.json",
}

# Define the ReformerConfig class, which inherits from PretrainedConfig
class ReformerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ReformerModel`]. It is used to instantiate a
    Reformer model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Reformer
    [google/reformer-crime-and-punishment](https://huggingface.co/google/reformer-crime-and-punishment) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Examples:

    ```
    >>> from transformers import ReformerConfig, ReformerModel

    >>> # Initializing a Reformer configuration
    >>> configuration = ReformerConfig()

    >>> # Initializing a Reformer model (with random weights)
    >>> model = ReformerModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
"""

    # Model type identifier
    model_type = "reformer"
    # Output keys to ignore during inference
    keys_to_ignore_at_inference = ["past_buckets_states"]
    # Attribute map; no attributes are remapped for this config
    attribute_map = {}
    def __init__(
        self,
        attention_head_size=64,
        attn_layers=["local", "lsh", "local", "lsh", "local", "lsh"],
        axial_norm_std=1.0,
        axial_pos_embds=True,
        axial_pos_shape=[64, 64],
        axial_pos_embds_dim=[64, 192],
        chunk_size_lm_head=0,
        eos_token_id=2,
        feed_forward_size=512,
        hash_seed=None,
        hidden_act="relu",
        hidden_dropout_prob=0.05,
        hidden_size=256,
        initializer_range=0.02,
        is_decoder=False,
        layer_norm_eps=1e-12,
        local_num_chunks_before=1,
        local_num_chunks_after=0,
        local_attention_probs_dropout_prob=0.05,
        local_attn_chunk_length=64,
        lsh_attn_chunk_length=64,
        lsh_attention_probs_dropout_prob=0.0,
        lsh_num_chunks_before=1,
        lsh_num_chunks_after=0,
        max_position_embeddings=4096,
        num_attention_heads=12,
        num_buckets=None,
        num_hashes=1,
        pad_token_id=0,
        vocab_size=320,
        tie_word_embeddings=False,
        use_cache=True,
        classifier_dropout=None,
        **kwargs,
    ):
        # Seed used for the LSH hashing
        self.hash_seed = hash_seed
        # Vocabulary size
        self.vocab_size = vocab_size
        # Size of each attention head
        self.attention_head_size = attention_head_size
        # Hidden size of the model
        self.hidden_size = hidden_size
        # Number of attention heads
        self.num_attention_heads = num_attention_heads
        # Number of hashing rounds
        self.num_hashes = num_hashes
        # Total number of attention layers
        self.num_hidden_layers = len(attn_layers)
        # Convert the number of buckets to a tuple if it was given as a list
        self.num_buckets = tuple(num_buckets) if isinstance(num_buckets, list) else num_buckets
        # Chunk length for LSH attention
        self.lsh_attn_chunk_length = lsh_attn_chunk_length
        # Chunk length for local attention
        self.local_attn_chunk_length = local_attn_chunk_length
        # Number of chunks after the current chunk for LSH attention
        self.lsh_num_chunks_after = lsh_num_chunks_after
        # Number of chunks before the current chunk for LSH attention
        self.lsh_num_chunks_before = lsh_num_chunks_before
        # Number of chunks after the current chunk for local attention
        self.local_num_chunks_after = local_num_chunks_after
        # Number of chunks before the current chunk for local attention
        self.local_num_chunks_before = local_num_chunks_before
        # Activation function used in the hidden layers
        self.hidden_act = hidden_act
        # Size of the feed-forward layer
        self.feed_forward_size = feed_forward_size
        # Dropout probability for the hidden layers
        self.hidden_dropout_prob = hidden_dropout_prob
        # Dropout probability on the LSH attention probabilities
        self.lsh_attention_probs_dropout_prob = lsh_attention_probs_dropout_prob
        # Dropout probability on the local attention probabilities
        self.local_attention_probs_dropout_prob = local_attention_probs_dropout_prob
        # Maximum number of position embeddings
        self.max_position_embeddings = max_position_embeddings
        # Standard deviation for weight initialization
        self.initializer_range = initializer_range
        # Epsilon used by the layer normalization layers
        self.layer_norm_eps = layer_norm_eps
        # Whether to use axial position embeddings
        self.axial_pos_embds = axial_pos_embds
        # Shape of the axial position embeddings
        self.axial_pos_shape = tuple(axial_pos_shape)
        # Dimensions of the axial position embeddings
        self.axial_pos_embds_dim = tuple(axial_pos_embds_dim)
        # Standard deviation for axial position embedding initialization
        self.axial_norm_std = axial_norm_std
        # Chunk size of the language modeling head
        self.chunk_size_lm_head = chunk_size_lm_head
        # List of attention layer types
        self.attn_layers = attn_layers
        # Whether to use the cache
        self.use_cache = use_cache
        # Dropout rate of the classifier head
        self.classifier_dropout = classifier_dropout
        # Call the parent initializer, forwarding the token and decoder arguments
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_decoder=is_decoder,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
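
For axial position embeddings to work, `axial_pos_shape` must factorize the padded sequence length and `axial_pos_embds_dim` must sum to `hidden_size` (the latter is validated in the model's `AxialPositionEmbeddings` module, shown further below). A minimal sketch of a mutually consistent configuration; the values are illustrative, not prescribed by this file:

```
from transformers import ReformerConfig

# 64 * 64 == max_position_embeddings and 64 + 192 == hidden_size,
# so both axial-embedding constraints are satisfied
config = ReformerConfig(
    hidden_size=256,
    axial_pos_embds=True,
    axial_pos_shape=[64, 64],
    axial_pos_embds_dim=[64, 192],
    max_position_embeddings=4096,
    attn_layers=["local", "lsh", "local", "lsh"],
)
print(config.num_hidden_layers)  # 4, derived from len(attn_layers)
```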

.\models\reformer\convert_reformer_trax_checkpoint_to_pytorch.py

# Import the required libraries and modules
import argparse  # command-line argument parsing
import pickle    # serialization / deserialization of Python objects

import numpy as np   # NumPy, used for array handling
import torch         # PyTorch
from torch import nn  # PyTorch neural-network module

from transformers import ReformerConfig, ReformerModelWithLMHead  # Reformer model classes from transformers
from transformers.utils import logging   # logging utilities

logging.set_verbosity_info()  # set the logging level to info

def set_param(torch_layer, weight, bias=None):
    # Set the parameters of a single layer
    assert torch_layer.weight.shape == weight.shape, f"{torch_layer} layer.weight does not match"
    # Assign the given weight as the layer's weight parameter
    torch_layer.weight = nn.Parameter(weight)
    if bias is not None:
        assert torch_layer.bias.shape == bias.shape, f"{torch_layer} layer.bias does not match"
        # Assign the given bias as the layer's bias parameter
        torch_layer.bias = nn.Parameter(bias)

def set_layer_weights_in_torch_lsh(weights, torch_layer, hidden_size):
    # Set the weights of an LSH (Locality-Sensitive Hashing) attention layer in the Torch model
    np_query_key = np.asarray(weights[0])
    np_value = np.asarray(weights[1])
    np_dense = np.asarray(weights[2])

    # Set the query_key weights of the self-attention module
    set_param(
        torch_layer.self_attention.query_key,
        torch.tensor(np_query_key).transpose(1, 2).contiguous().view(-1, hidden_size),
    )
    # Set the value weights of the self-attention module
    set_param(
        torch_layer.self_attention.value,
        torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size),
    )
    # Set the weights of the output dense layer
    set_param(
        torch_layer.output.dense,
        torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1),
    )

def set_layer_weights_in_torch_local(weights, torch_layer, hidden_size):
    # Set the weights of a local attention layer in the Torch model
    np_query = np.asarray(weights[0])
    np_key = np.asarray(weights[1])
    np_value = np.asarray(weights[2])
    np_dense = np.asarray(weights[3])

    # Set the query weights of the self-attention module
    set_param(
        torch_layer.self_attention.query,
        torch.tensor(np_query).transpose(1, 2).contiguous().view(-1, hidden_size),
    )
    # Set the key weights of the self-attention module
    set_param(
        torch_layer.self_attention.key,
        torch.tensor(np_key).transpose(1, 2).contiguous().view(-1, hidden_size),
    )
    # Set the value weights of the self-attention module
    set_param(
        torch_layer.self_attention.value,
        torch.tensor(np_value).transpose(1, 2).contiguous().view(-1, hidden_size),
    )
    # Set the weights of the output dense layer
    set_param(
        torch_layer.output.dense,
        torch.tensor(np_dense).view(-1, hidden_size).contiguous().transpose(0, 1),
    )

def set_block_weights_in_torch(weights, torch_block, hidden_size):
    # Set the weights of one Reformer block in the Torch model
    # layernorm 1
    layer_norm_1 = weights[0][0][0]
    layer_norm_1_weight = np.asarray(layer_norm_1[0])
    # Convert the bias of layer_norm_1 to a NumPy array
    layer_norm_1_bias = np.asarray(layer_norm_1[1])

    # Set the layer-norm parameters of the attention sublayer
    set_param(
        torch_block.attention.layer_norm,  # layer norm of the attention sublayer
        torch.tensor(layer_norm_1_weight),  # layer-norm weight as a PyTorch tensor
        torch.tensor(layer_norm_1_bias),    # layer-norm bias as a PyTorch tensor
    )

    # Get the attention weights
    attn_weights = weights[0][1]

    # Choose LSH or local attention weights depending on how many weight arrays there are
    if len(attn_weights) < 4:
        set_layer_weights_in_torch_lsh(attn_weights, torch_block.attention, hidden_size)
    else:
        set_layer_weights_in_torch_local(attn_weights, torch_block.attention, hidden_size)

    # Get the intermediate (feed-forward) weights
    intermediate_weights = weights[2][0][1][2]

    # If there are 4 intermediate weight groups, the third one holds the chunked feed-forward weights
    if len(intermediate_weights) == 4:
        intermediate_weights = intermediate_weights[2]

    # Set the weight and bias of the second layer norm
    layer_norm_2_weight = np.asarray(intermediate_weights[0][0])
    layer_norm_2_bias = np.asarray(intermediate_weights[0][1])
    set_param(
        torch_block.feed_forward.layer_norm,  # layer norm of the feed-forward sublayer
        torch.tensor(layer_norm_2_weight),    # layer-norm weight as a PyTorch tensor
        torch.tensor(layer_norm_2_bias),      # layer-norm bias as a PyTorch tensor
    )

    # Set the weight and bias of the intermediate dense layer
    inter_dense_weight = np.asarray(intermediate_weights[1][0])
    inter_dense_bias = np.asarray(intermediate_weights[1][1])
    set_param(
        torch_block.feed_forward.dense.dense,  # intermediate dense layer of the feed-forward sublayer
        torch.tensor(inter_dense_weight).transpose(0, 1).contiguous(),  # transposed weight as a PyTorch tensor
        torch.tensor(inter_dense_bias),        # bias as a PyTorch tensor
    )

    # Set the weight and bias of the feed-forward output layer
    out_dense_weight = np.asarray(intermediate_weights[4][0])
    out_dense_bias = np.asarray(intermediate_weights[4][1])
    set_param(
        torch_block.feed_forward.output.dense,  # output dense layer of the feed-forward sublayer
        torch.tensor(out_dense_weight).transpose(0, 1).contiguous(),  # transposed weight as a PyTorch tensor
        torch.tensor(out_dense_bias),            # bias as a PyTorch tensor
    )
# Copy the given weights into the specified PyTorch model
def set_model_weights_in_torch(weights, torch_model, hidden_size):
    # Get the reformer submodule of the PyTorch model
    torch_model_reformer = torch_model.reformer

    # Get the word embeddings from the weights
    word_embeddings = np.asarray(weights[1])
    # Set the word-embedding parameters
    set_param(
        torch_model_reformer.embeddings.word_embeddings,
        torch.tensor(word_embeddings),
    )

    # If the fourth weight entry (weights[3]) is a tuple, it holds axial position embeddings
    if isinstance(weights[3], tuple):
        # Get the position-embedding module
        position_embeddings = torch_model_reformer.embeddings.position_embeddings
        # Iterate over the axial position-embedding weights
        for emb_idx in range(len(position_embeddings.weights)):
            emb_weights = np.asarray(weights[3][emb_idx][0])
            # Assert that the position-embedding shapes match
            assert (
                position_embeddings.weights[emb_idx].shape == emb_weights.shape
            ), f"{position_embeddings[emb_idx]} emb does not match"
            # Set the position embedding as a trainable parameter
            position_embeddings.weights[emb_idx] = nn.Parameter(torch.tensor(emb_weights))

    # Get the layer weights of the Trax model
    trax_layer_weights = weights[5]
    # Assert that the number of encoder layers matches
    assert len(torch_model_reformer.encoder.layers) * 4 == len(
        trax_layer_weights
    ), "HF and trax model do not have the same number of layers"
    # Iterate over the encoder layers and set their weights
    for layer_idx, layer in enumerate(torch_model_reformer.encoder.layers):
        block_weights = trax_layer_weights[4 * layer_idx : 4 * (layer_idx + 1)]
        set_block_weights_in_torch(block_weights, layer, hidden_size)

    # Set the parameters of the output LayerNorm
    layer_norm_out_weight = np.asarray(weights[7][0])
    layer_norm_out_bias = np.asarray(weights[7][1])
    set_param(
        torch_model_reformer.encoder.layer_norm,
        torch.tensor(layer_norm_out_weight),
        torch.tensor(layer_norm_out_bias),
    )

    # Set the parameters of the output embedding (LM head decoder)
    output_embed_weights = np.asarray(weights[9][0])
    output_embed_bias = np.asarray(weights[9][1])
    set_param(
        torch_model.lm_head.decoder,
        torch.tensor(output_embed_weights).transpose(0, 1).contiguous(),
        torch.tensor(output_embed_bias),
    )


# Convert a Trax checkpoint file into a PyTorch model and save it
def convert_trax_checkpoint_to_pytorch(trax_model_pkl_path, config_file, pytorch_dump_path):
    # Load the Reformer model configuration from the config file
    config = ReformerConfig.from_json_file(config_file)
    print(f"Building PyTorch model from configuration: {config}")
    # Build the PyTorch model from the configuration
    model = ReformerModelWithLMHead(config)

    # Load the weights from the Trax checkpoint file
    with open(trax_model_pkl_path, "rb") as f:
        model_weights = pickle.load(f)["weights"]

    # Copy the loaded weights into the PyTorch model
    set_model_weights_in_torch(model_weights, model, config.hidden_size)

    # Save the converted PyTorch model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required arguments
    parser.add_argument(
        "--trax_model_pkl_path", default=None, type=str, required=True, help="Path to the Trax model pickle file."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained Reformer model. \n"
            "This specifies the model architecture."
        ),
    )
    # Argument specifying where to write the output PyTorch model
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    # Parse the command-line arguments into the args object
    args = parser.parse_args()
    # Call convert_trax_checkpoint_to_pytorch to convert the Trax model into a PyTorch model,
    # using args.trax_model_pkl_path (Trax model path), args.config_file (config path)
    # and args.pytorch_dump_path (output PyTorch model path)
    convert_trax_checkpoint_to_pytorch(args.trax_model_pkl_path, args.config_file, args.pytorch_dump_path)

.\models\reformer\modeling_reformer.py

# _stable_argsort sorts the input vector in a stable way
def _stable_argsort(vector, dim):
    # This function scales the vector so that torch.argsort becomes stable;
    # torch.argsort is not a stable sorting algorithm by default.
    # Build an offset tensor with values 0 .. vector length, used to stabilize the sort
    scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1)
    scale_offset = scale_offset.expand(vector.shape)
    # Scale the input vector along the given dimension so that ties are broken by position
    scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim])
    # Sort the scaled vector along the given dimension with torch.argsort
    return torch.argsort(scaled_vector, dim=dim)
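
A quick way to see why the scaling trick is needed: for equal entries, plain `torch.argsort` makes no ordering guarantee, while the scaled version breaks ties by position. A minimal sketch, with the shape chosen to match the `(batch, heads, seq)` layout the helper expects:

```
import torch

buckets = torch.tensor([[[1, 0, 1, 0, 1]]])  # shape (1, 1, 5), with ties
idx = _stable_argsort(buckets, dim=-1)
print(idx)  # tensor([[[1, 3, 0, 2, 4]]]): equal buckets keep their original order
```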
def _get_least_common_mult_chunk_len(config):
    attn_types = config.attn_layers  # list of attention layer types from the config
    attn_types_set = set(attn_types)  # deduplicate the attention types
    if len(attn_types_set) == 1 and attn_types[0] == "lsh":  # only LSH attention is used
        return config.lsh_attn_chunk_length  # return the LSH attention chunk length
    elif len(attn_types_set) == 1 and attn_types[0] == "local":  # only local attention is used
        return config.local_attn_chunk_length  # return the local attention chunk length
    elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}:  # both LSH and local attention are used
        return np.lcm(config.lsh_attn_chunk_length, config.local_attn_chunk_length)  # least common multiple of both chunk lengths
    else:
        raise NotImplementedError(
            f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select "
            "attn layer types from ['lsh', 'local'] only."
        )  # raise NotImplementedError: only 'lsh' and 'local' attention layers are supported


def _get_min_chunk_len(config):
    attn_types = config.attn_layers  # list of attention layer types from the config
    attn_types_set = set(attn_types)  # deduplicate the attention types
    if len(attn_types_set) == 1 and attn_types[0] == "lsh":  # only LSH attention is used
        return config.lsh_attn_chunk_length  # return the LSH attention chunk length
    elif len(attn_types_set) == 1 and attn_types[0] == "local":  # only local attention is used
        return config.local_attn_chunk_length  # return the local attention chunk length
    elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}:  # both LSH and local attention are used
        return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length)  # minimum of both chunk lengths
    else:
        raise NotImplementedError(
            f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select "
            "attn layer types from ['lsh', 'local'] only."
        )  # raise NotImplementedError: only 'lsh' and 'local' attention layers are supported
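
The least common multiple matters because the input is padded so that every layer can split it into whole chunks: a length divisible by the lcm is divisible by each layer's chunk length. A tiny illustration with assumed (non-default) chunk lengths:

```
import numpy as np

lsh_chunk, local_chunk = 64, 96  # illustrative values, not the library defaults
print(np.lcm(lsh_chunk, local_chunk))  # 192: padding to a multiple of this
                                       # satisfies both layer types at once
```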


class AxialPositionEmbeddings(nn.Module):
    """
    Constructs axial position embeddings. Useful for very long input sequences to save memory and time.
    """

    def __init__(self, config):
        super().__init__()
        self.axial_pos_shape = config.axial_pos_shape  # shape of the axial position embeddings
        self.axial_pos_embds_dim = config.axial_pos_embds_dim  # dimensions of the axial position embeddings
        self.dropout = config.hidden_dropout_prob  # hidden-layer dropout probability

        self.least_common_mult_chunk_length = _get_least_common_mult_chunk_len(config)  # least common multiple of the chunk lengths
        self.weights = nn.ParameterList()  # parameter list holding the axial embedding weights

        if sum(self.axial_pos_embds_dim) != config.hidden_size:  # the axial embedding dims must sum to the hidden size
            raise ValueError(
                f"Make sure that config.axial_pos_embds factors: {self.axial_pos_embds_dim} sum to "
                f"config.hidden_size: {config.hidden_size}"
            )  # raise ValueError if the axial embedding dimensions do not sum to the hidden size

        # create weights
        for axis, axial_pos_embd_dim in enumerate(self.axial_pos_embds_dim):
            # create expanded shapes
            ax_shape = [1] * len(self.axial_pos_shape)  # start from a broadcastable shape of ones
            ax_shape[axis] = self.axial_pos_shape[axis]  # keep the full size only along the current axis
            ax_shape = tuple(ax_shape) + (axial_pos_embd_dim,)  # convert to a tuple and append the embedding dimension

            # create tensor and init
            self.weights.append(nn.Parameter(torch.ones(ax_shape, dtype=torch.float32)))  # create and register the parameter tensor
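
The memory saving comes from broadcasting: each weight tensor varies along only one grid axis, and expanding plus concatenating them covers the full position grid. A sketch with the default shapes (`axial_pos_shape=(64, 64)`, `axial_pos_embds_dim=(64, 192)`, `hidden_size=256`):

```
import torch

w0 = torch.ones(64, 1, 64)    # varies along grid axis 0 only
w1 = torch.ones(1, 64, 192)   # varies along grid axis 1 only

# expand both over the 64 x 64 grid and concatenate the embedding dims
full = torch.cat([w0.expand(64, 64, 64), w1.expand(64, 64, 192)], dim=-1)
print(full.shape)  # torch.Size([64, 64, 256]): one 256-dim embedding per position
# parameters stored: 64*64 + 64*192 = 16384, versus 4096*256 = 1048576 for a full table
```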


class PositionEmbeddings(nn.Module):
    """Constructs conventional position embeddings of shape `[max_pos_embeddings, hidden_size]`."""
    # Initializer
    def __init__(self, config):
        # Call the parent initializer
        super().__init__()
        # Store the hidden-layer dropout probability on the module
        self.dropout = config.hidden_dropout_prob
        # Embedding layer mapping position ids to vectors of hidden size
        self.embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)

    # Forward pass
    def forward(self, position_ids):
        # Convert the position ids into position embeddings
        position_embeddings = self.embedding(position_ids)
        # Apply dropout to the position embeddings; self.training controls train/eval behaviour
        position_embeddings = nn.functional.dropout(position_embeddings, p=self.dropout, training=self.training)
        # Return the position embeddings
        return position_embeddings
class ReformerEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.max_position_embeddings = config.max_position_embeddings  # maximum number of position embeddings
        self.dropout = config.hidden_dropout_prob  # hidden-layer dropout probability

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)  # word embedding layer
        self.position_embeddings = (
            AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config)
        )  # choose axial or conventional position embeddings based on the config

    def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, start_idx_pos_encodings=0):
        if input_ids is not None:
            input_shape = input_ids.size()  # shape of the input ids
            device = input_ids.device  # device of the input ids
        else:
            input_shape = inputs_embeds.size()[:-1]  # shape of the embedded inputs (without the last dim)
            device = inputs_embeds.device  # device of the embedded inputs

        seq_length = input_shape[1]  # sequence length
        if position_ids is None:
            position_ids = torch.arange(
                start_idx_pos_encodings, start_idx_pos_encodings + seq_length, dtype=torch.long, device=device
            )  # create the position ids if none were provided
            position_ids = position_ids.unsqueeze(0).expand(input_shape)  # expand the position ids to the input shape

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)  # look up the input embeddings in the word embedding layer

        if position_ids.shape[-1] > self.max_position_embeddings:
            raise ValueError(
                f"Sequence Length: {position_ids.shape[-1]} has to be less or equal than "
                f"config.max_position_embeddings {self.max_position_embeddings}."
            )  # raise if the position ids exceed the maximum number of position embeddings

        # dropout
        embeddings = nn.functional.dropout(inputs_embeds, p=self.dropout, training=self.training)  # apply dropout

        # add positional embeddings
        position_embeddings = self.position_embeddings(position_ids)  # compute the position embeddings
        embeddings = embeddings + position_embeddings  # add the position embeddings to the word embeddings
        return embeddings


class EfficientAttentionMixin:
    """
    A few utilities for nn.Modules in Reformer, to be used as a mixin.
    """

    def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after):
        """
        Used to implement attention between consecutive chunks.

        Args:
            vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...]
            num_chunks_before: chunks before current chunk to include in attention
            num_chunks_after: chunks after current chunk to include in attention

        Returns:
            tensor of shape [num_chunks, N * chunk_length, ...], where N = (1 + num_chunks_before + num_chunks_after).
        """
        if num_chunks_before == 0 and num_chunks_after == 0:
            return vectors  # nothing to attend to before or after: return the vectors unchanged

        slices = []
        for i in range(-num_chunks_before, num_chunks_after + 1):
            if i == 0:
                slices.append(vectors)  # the center chunk is appended as-is
            else:
                slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2))  # append the chunks rolled by i positions
        return torch.cat(slices, dim=3)  # concatenate all chunk views along the chunk-length dimension
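
The rolled concatenation gives every chunk a contiguous view of itself plus its neighbours, with wrap-around at the sequence boundary. A toy run of the same logic:

```
import torch

# 4 chunks of length 2, batch=1, heads=1; values encode (chunk, position)
vectors = torch.arange(8.0).view(1, 1, 4, 2)
before, after = 1, 0

slices = []
for i in range(-before, after + 1):
    if i == 0:
        slices.append(vectors)
    else:
        slices.append(torch.cat([vectors[:, :, i:, ...], vectors[:, :, :i, ...]], dim=2))
out = torch.cat(slices, dim=3)
print(out[0, 0])
# chunk 0 sees chunk 3 (wrap-around) followed by itself:
# tensor([[6., 7., 0., 1.], [0., 1., 2., 3.], [2., 3., 4., 5.], [4., 5., 6., 7.]])
```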
    # Split the last dimension of x into (num_attn_heads, attn_head_size) and move the head dimension forward
    def _split_hidden_size_dim(self, x, num_attn_heads, attn_head_size):
        new_x_shape = x.size()[:-1] + (num_attn_heads, attn_head_size)
        x = x.view(*new_x_shape)
        return x.transpose(2, 1)

    # Swap the head and sequence dimensions of x, then flatten the heads back into a hidden_size dimension
    def _merge_hidden_size_dims(self, x, num_attn_heads, attn_head_size):
        x = x.permute(0, 2, 1, 3)
        return torch.reshape(x, (x.size()[0], -1, num_attn_heads * attn_head_size))

    # Split the sequence-length dimension of `vectors` into (dim_factor_1, dim_factor_2);
    # for rank-4 inputs an attn_head_size dimension is appended as well
    def _split_seq_length_dim_to(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None):
        batch_size = vectors.shape[0]
        split_dim_shape = (batch_size, num_attn_heads, dim_factor_1, dim_factor_2)

        if len(vectors.shape) == 4:
            return torch.reshape(vectors, split_dim_shape + (attn_head_size,))
        elif len(vectors.shape) == 3:
            return torch.reshape(vectors, split_dim_shape)
        else:
            raise ValueError(f"Input vector rank should be one of [3, 4], but is: {len(vectors.shape)}")
# LSHSelfAttention implements LSH self-attention; it inherits from nn.Module and EfficientAttentionMixin
class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
    # Initializer taking a config argument
    def __init__(self, config):
        super().__init__()
        # Store the config on the instance
        self.config = config

        # Read the relevant hyperparameters from the config
        self.chunk_length = config.lsh_attn_chunk_length  # chunk length for LSH attention
        self.num_hashes = config.num_hashes  # number of hash rounds
        self.num_buckets = config.num_buckets  # number of buckets
        self.num_chunks_before = config.lsh_num_chunks_before  # chunks attended to before the current chunk
        self.num_chunks_after = config.lsh_num_chunks_after  # chunks attended to after the current chunk
        self.hash_seed = config.hash_seed  # hash seed
        self.is_decoder = config.is_decoder  # whether the model is a decoder
        self.max_position_embeddings = config.max_position_embeddings  # maximum number of position embeddings

        self.dropout = config.lsh_attention_probs_dropout_prob  # dropout on the attention probabilities

        self.num_attention_heads = config.num_attention_heads  # number of attention heads
        self.attention_head_size = config.attention_head_size  # size of each attention head
        self.all_head_size = self.num_attention_heads * self.attention_head_size  # total size across all heads
        self.hidden_size = config.hidden_size  # hidden size

        # Projection matrices for query/key and value, without bias
        self.query_key = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False)

        # Register buffers holding the mask values for the different precisions
        self.register_buffer("self_mask_value_float16", torch.tensor(-1e3), persistent=False)
        self.register_buffer("self_mask_value_float32", torch.tensor(-1e5), persistent=False)
        self.register_buffer("mask_value_float16", torch.tensor(-1e4), persistent=False)
        self.register_buffer("mask_value_float32", torch.tensor(-1e9), persistent=False)

    # Forward pass
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        num_hashes=None,
        buckets=None,
        past_buckets_states=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,
    ):
        # Query projection computed per attention head
        def _query_per_attn_head(self, hidden_states):
            # Reshape and transpose the query_key matrix so it can be applied per attention head
            per_head_query_key = self.query_key.weight.reshape(
                self.num_attention_heads, self.attention_head_size, self.hidden_size
            ).transpose(-2, -1)
            # Compute the query vectors with einsum
            query_key_vectors = torch.einsum("balh,ahr->balr", hidden_states, per_head_query_key)
            return query_key_vectors

        # Value projection computed per attention head
        def _value_per_attn_head(self, hidden_states):
            # Reshape and transpose the value matrix so it can be applied per attention head
            per_head_value = self.value.weight.reshape(
                self.num_attention_heads, self.attention_head_size, self.hidden_size
            ).transpose(-2, -1)
            # Compute the value vectors with einsum
            value_vectors = torch.einsum("balh,ahr->balr", hidden_states, per_head_value)
            return value_vectors
    def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buckets, num_hashes):
        # no gradients are needed here
        with torch.no_grad():
            # hash-based sort
            sorted_bucket_idx = _stable_argsort(buckets, dim=-1)

            # create simple indices to scatter with, so the sort can be undone later
            indices = (
                torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device)
                .view(1, 1, -1)
                .expand(sorted_bucket_idx.shape)
            )

            # get the indices that undo the sort
            undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size())
            undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices)

        return sorted_bucket_idx, undo_sorted_bucket_idx

    def _set_num_buckets(self, sequence_length):
        # as recommended in the paper, `num_buckets` should be set to 2 * sequence_length // chunk_length
        num_buckets_pow_2 = (2 * (sequence_length // self.chunk_length)).bit_length() - 1
        # make sure the number of buckets is a power of 2
        num_buckets = 2**num_buckets_pow_2

        # if `num_buckets` is too large, factorize it into two factors
        num_buckets_limit = 2 * max(
            int((self.max_position_embeddings // self.chunk_length) ** (0.5)),
            self.chunk_length,
        )
        if num_buckets > num_buckets_limit:
            num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)]

        logger.warning(f"config.num_buckets is not set. Setting config.num_buckets to {num_buckets}...")

        # set num_buckets on the config so it is saved correctly
        self.config.num_buckets = num_buckets
        self.num_buckets = num_buckets
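
Tracing the arithmetic for an assumed sequence length makes the rounding behaviour concrete:

```
seq_len, chunk_length = 4096, 64  # illustrative values

num_buckets_pow_2 = (2 * (seq_len // chunk_length)).bit_length() - 1  # 128 -> 8 bits -> 7
num_buckets = 2**num_buckets_pow_2                                    # 128, a power of 2
print(num_buckets_pow_2, num_buckets)  # 7 128
# with max_position_embeddings=4096 the limit is 2 * max(8, 64) = 128,
# so 128 is kept as a single factor rather than being split into two
```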

    def _attend(
        self,
        query_vectors,
        key_vectors,
        value_vectors,
        sorted_bucket_idx_per_hash,
        attention_mask,
        head_mask,
        do_standard_self_attention,
        do_cached_attention,
    ):
        # This method implements the attention computation over the given vectors and masks

    def _compute_attn_mask(
        self, query_indices, key_indices, attention_mask, query_key_dot_shape, do_standard_self_attention
    ):
        # This method computes the attention mask from the given indices, mask and flags
        # attention mask for LSH
        if attention_mask is not None:
            # if an attention mask is given, cast it to bool and add a head dimension
            attention_mask = attention_mask.to(torch.bool)[:, None, :]
            if not do_standard_self_attention:
                # for non-standard self-attention, expand the mask to the shape of key_value_bucket_idx
                attention_mask = attention_mask[:, None, :]
                attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
                # extract the attention mask values at the LSH-sorted key_indices
                attention_mask = torch.gather(attention_mask, -1, key_indices)

            # expand the attention mask to the shape of the query-key dot product
            attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dot_shape)

        # Causal mask
        if self.is_decoder is True:
            # for a decoder, build a causal mask that is True where the query index is >= the key index
            causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)

            # if an attention mask exists, multiply it with the causal mask
            if attention_mask is not None:
                attention_mask = causal_mask * attention_mask
            else:
                attention_mask = causal_mask

        # return the final attention mask
        return attention_mask

    def _get_relevant_hid_states_and_buckets(
        self, query_vectors, attention_mask, num_hashes, hidden_states, past_states, past_buckets
    ):
        # Retrieve the hidden states and buckets that are relevant for the given query vectors

    def _expand_to_indices_in_relevant_chunk(self, indices, sequence_length):
        # Expand the given indices to all indices inside the relevant chunk:
        # determine the chunk start and size from the indices, then add the correct chunk offsets via arange

        # compute and expand the chunk start indices
        start_indices_chunk = ((indices[:, -1] // self.chunk_length) - self.num_chunks_before) * self.chunk_length
        total_chunk_size = self.chunk_length * (1 + self.num_chunks_before + self.num_chunks_after)

        expanded_start_indices = start_indices_chunk.unsqueeze(-1).expand(indices.shape[0], total_chunk_size)

        # build the chunk sequence indices; the modulo implements the circular wrap-around
        chunk_sequence_indices = expanded_start_indices + torch.arange(
            total_chunk_size, device=indices.device, dtype=torch.long
        ).unsqueeze(0).expand(indices.shape[0], total_chunk_size)

        chunk_sequence_indices = chunk_sequence_indices.flatten() % sequence_length

        # expand the indices and write the correct sequence positions into the last column
        indices = indices.unsqueeze(1).expand((indices.shape[0], total_chunk_size, -1)).flatten(0, 1).clone()
        indices[:, -1] = chunk_sequence_indices

        return indices

    def _len_and_dim_norm(self, vectors, sqrt_num):
        """
        length and attention head size dim normalization
        """
        # normalize the vectors over the length and attention-head-size dimensions

        # length normalization first
        vectors = self._len_norm(vectors)
        vectors = vectors / sqrt_num
        return vectors

    def _len_norm(self, x, epsilon=1e-6):
        """
        length normalization
        """
        # length normalization

        # compute the mean square
        variance = torch.mean(x**2, -1, keepdim=True)
        # rescale by the inverse square root of the mean square
        norm_x = x * torch.rsqrt(variance + epsilon)
        return norm_x
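
`_len_norm` is just RMS normalization over the last dimension: each vector is rescaled to (approximately) unit root-mean-square. A quick numerical check:

```
import torch

x = torch.tensor([[3.0, 4.0]])
variance = torch.mean(x**2, -1, keepdim=True)   # (9 + 16) / 2 = 12.5
norm_x = x * torch.rsqrt(variance + 1e-6)
print(norm_x)                # tensor([[0.8485, 1.1314]])
print((norm_x**2).mean(-1))  # ~1.0: unit mean square after normalization
```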
    # Private helper that expands `vectors` and `idxs` and gathers across all hashes
    def _gather_by_expansion(self, vectors, idxs, num_hashes):
        # add a trailing dimension to `idxs` and expand it so it matches the attention-head dimension of `vectors`
        expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size)
        # repeat `vectors` num_hashes times along the third dimension so it matches `expanded_idxs`
        vectors = vectors.repeat(1, 1, num_hashes, 1)
        # gather the entries of `vectors` along the third dimension according to `expanded_idxs`
        return torch.gather(vectors, 2, expanded_idxs)
class ReverseSort(Function):
    """
    After chunked attention is applied which sorted clusters, original ordering has to be restored. Since customized
    backward function is used for Reformer, the gradients of the output vectors have to be explicitly sorted here.
    """

    @staticmethod
    def forward(ctx, out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx):
        # save sorted_bucket_idx for backprop
        with torch.no_grad():
            ctx.sorted_bucket_idx = sorted_bucket_idx

            # undo sort to have correct order for next layer
            expanded_undo_sort_indices = undo_sorted_bucket_idx.unsqueeze(-1).expand(out_vectors.shape)
            out_vectors = torch.gather(out_vectors, 2, expanded_undo_sort_indices)
            logits = torch.gather(logits, 2, undo_sorted_bucket_idx)
        return out_vectors, logits

    @staticmethod
    def backward(ctx, grad_out_vectors, grad_logits):
        # get parameters saved in ctx
        sorted_bucket_idx = ctx.sorted_bucket_idx

        expanded_sort_indices = sorted_bucket_idx.unsqueeze(-1).expand(grad_out_vectors.shape)
        # reverse sort of forward
        grad_out_vectors = torch.gather(grad_out_vectors, 2, expanded_sort_indices)
        grad_logits = torch.gather(grad_logits, 2, sorted_bucket_idx)

        # return grad and `None` fillers for last 2 forward args
        return grad_out_vectors, grad_logits, None, None
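
The scatter/gather pair used by `ReverseSort` (and by `_get_sorted_bucket_idx_and_undo_sorted_bucket_idx` above) is easy to verify on a toy tensor; `torch.empty_like` stands in here for the legacy `.new(*size)` call:

```
import torch

values = torch.tensor([[[0.3, 0.1, 0.2]]])
sorted_idx = torch.argsort(values, dim=-1)          # [[[1, 2, 0]]]

# build the inverse permutation via scatter
indices = torch.arange(3).view(1, 1, -1)
undo_idx = torch.empty_like(sorted_idx)
undo_idx.scatter_(-1, sorted_idx, indices)          # [[[2, 0, 1]]]

sorted_vals = torch.gather(values, -1, sorted_idx)  # [[[0.1, 0.2, 0.3]]]
restored = torch.gather(sorted_vals, -1, undo_idx)
print(torch.equal(restored, values))                # True: original order recovered
```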


class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
    def __init__(self, config):
        super().__init__()

        self.num_attention_heads = config.num_attention_heads
        self.chunk_length = config.local_attn_chunk_length
        self.num_chunks_before = config.local_num_chunks_before
        self.num_chunks_after = config.local_num_chunks_after
        self.is_decoder = config.is_decoder
        self.pad_token_id = config.pad_token_id

        self.attention_head_size = config.attention_head_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.hidden_size = config.hidden_size

        # projection matrices
        self.query = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
        self.key = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False)

        self.dropout = config.local_attention_probs_dropout_prob

        # save mask value here
        self.register_buffer("mask_value_float16", torch.tensor(-1e4), persistent=False)
        self.register_buffer("mask_value_float32", torch.tensor(-1e9), persistent=False)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        past_buckets_states=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,
    ):
        """
        Performs the forward pass for local self-attention mechanism.

        Args:
            hidden_states (torch.Tensor): Input embeddings or hidden states.
            attention_mask (torch.Tensor, optional): Mask indicating which elements should be attended to.
            head_mask (torch.Tensor, optional): Mask indicating heads to be masked out.
            past_buckets_states (torch.Tensor, optional): States from previous attention buckets.
            use_cache (bool, optional): Whether to use caching mechanism.
            output_attentions (bool, optional): Whether to output attention scores.

        Returns:
            torch.Tensor: Output embeddings or hidden states.
        """

    def _compute_attn_mask(
        self, query_indices, key_indices, attention_mask, query_key_dots_shape, do_standard_self_attention
    ):
        """
        Computes the attention mask based on query and key indices.

        Args:
            query_indices (torch.Tensor): Indices for queries.
            key_indices (torch.Tensor): Indices for keys.
            attention_mask (torch.Tensor): Attention mask.
            query_key_dots_shape (Tuple): Shape of the query-key dot product.
            do_standard_self_attention (bool): Whether to perform standard self-attention.

        Returns:
            torch.Tensor: Computed attention mask.
        """
        # chunk attention mask and look before and after
        # if an attention mask is given, cast it to bool and add a head dimension for the later operations
        if attention_mask is not None:
            attention_mask = attention_mask.to(torch.bool)[:, None, :]

            # for non-standard self-attention, split the mask into chunks and append the chunks before and after
            if not do_standard_self_attention:
                attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
                attention_mask = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)

            # expand the mask to the shape of the query-key dot product
            attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dots_shape)

        # Causal mask
        # for a decoder, build the causal attention mask
        if self.is_decoder is True:
            causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)

            # if an attention mask exists, multiply it with the causal mask
            if attention_mask is not None:
                attention_mask = causal_mask * attention_mask
            else:
                attention_mask = causal_mask

        # return the final attention mask
        return attention_mask


    @staticmethod
    def _retrieve_relevant_hidden_states(previous_hidden_states, chunk_length, num_chunks_before):
        # compute the start position of the relevant hidden states to retrieve
        start_position = ((previous_hidden_states.shape[1] // chunk_length) - num_chunks_before) * chunk_length
        # return the relevant hidden states from the start position onwards
        return previous_hidden_states[:, start_position:]
class ReformerSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        all_head_size = config.num_attention_heads * config.attention_head_size
        self.dropout = config.hidden_dropout_prob  # dropout probability

        self.dense = nn.Linear(all_head_size, config.hidden_size, bias=False)  # linear layer mapping the attention-head outputs back to the hidden size

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)  # project the hidden states through the linear layer
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)  # regularize with dropout
        return hidden_states  # return the processed hidden states


class ReformerAttention(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.layer_id = layer_id  # index of this layer
        self.attn_layers = config.attn_layers  # list of attention layer types

        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # layer normalization

        # choose the appropriate self-attention mechanism based on the config
        if len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "lsh":
            self.self_attention = LSHSelfAttention(config)  # LSH self-attention
        elif len(set(self.attn_layers)) == 1 and self.attn_layers[0] == "local":
            self.self_attention = LocalSelfAttention(config)  # local self-attention
        elif len(set(self.attn_layers)) == 2 and set(self.attn_layers) == {"lsh", "local"}:
            # if both LSH and local attention are configured, pick the right one for this layer index
            if self.attn_layers[self.layer_id] == "lsh":
                self.self_attention = LSHSelfAttention(config)
            else:
                self.self_attention = LocalSelfAttention(config)
        else:
            # raise NotImplementedError for unsupported attention type configurations
            raise NotImplementedError(
                f"Only attn layer types 'lsh' and 'local' exist, but got `config.attn_layers`: {self.attn_layers}. "
                "Select attn layer types from ['lsh', 'local'] only."
            )

        self.output = ReformerSelfOutput(config)  # output projection of the attention sublayer

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        num_hashes=None,
        past_buckets_states=None,
        use_cache=False,
        orig_sequence_length=None,
        output_attentions=False,
        buckets=None,
    ):
        # apply layer normalization to the hidden states
        hidden_states = self.layer_norm(hidden_states)

        # make sure the cached hidden states are set to None for the backward pass
        if past_buckets_states is not None:
            past_buckets_states_layer = past_buckets_states[self.layer_id]
        else:
            past_buckets_states_layer = None

        # use cached buckets for the backward pass if needed (LSHSelfAttention only)
        self_attention_outputs = self.self_attention(
            hidden_states=hidden_states,
            head_mask=head_mask,
            attention_mask=attention_mask,
            num_hashes=num_hashes,
            past_buckets_states=past_buckets_states_layer,
            use_cache=use_cache,
            output_attentions=output_attentions,
            buckets=buckets,
        )

        # if self_attention_outputs carries a "buckets" attribute, keep it; otherwise set buckets to None
        if hasattr(self_attention_outputs, "buckets"):
            buckets = self_attention_outputs.buckets
        else:
            buckets = None

        # cache the hidden states for future use if requested
        if use_cache:
            if past_buckets_states[self.layer_id][0] is None:
                # padded inputs should not be cached
                past_buckets = (
                    buckets[:, :, :, :orig_sequence_length]
                    if (buckets is not None and orig_sequence_length > 1)
                    else buckets
                )
            else:
                past_buckets = torch.cat([past_buckets_states[self.layer_id][0], buckets], dim=-1)

            if past_buckets_states[self.layer_id][1] is None:
                # padded inputs should not be cached
                past_states = hidden_states[:, :orig_sequence_length]
            else:
                past_states = torch.cat([past_buckets_states[self.layer_id][1], hidden_states], dim=1)

            past_buckets_states[self.layer_id] = (past_buckets, past_states)

        # compute the attention output projection
        attention_output = self.output(self_attention_outputs.hidden_states)

        # return an AttentionOutput object holding the attention results
        return AttentionOutput(
            hidden_states=attention_output,
            attention_probs=self_attention_outputs.attention_probs,
            buckets=buckets,
        )
class ReformerFeedForwardDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dropout = config.hidden_dropout_prob  # hidden-layer dropout probability from the config

        if isinstance(config.hidden_act, str):
            self.act_fn = ACT2FN[config.hidden_act]  # if the activation is given as a string, look it up in the ACT2FN mapping
        else:
            self.act_fn = config.hidden_act  # otherwise use the activation callable from the config directly

        self.dense = nn.Linear(config.hidden_size, config.feed_forward_size)  # linear projection to the feed-forward size

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)  # project the hidden states through the linear layer
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)  # apply dropout
        hidden_states = self.act_fn(hidden_states)  # apply the non-linear activation
        return hidden_states  # return the processed hidden states


class ReformerFeedForwardOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dropout = config.hidden_dropout_prob  # hidden-layer dropout probability from the config

        self.dense = nn.Linear(config.feed_forward_size, config.hidden_size)  # linear projection back to the hidden size

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)  # project the hidden states through the linear layer
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)  # apply dropout
        return hidden_states  # return the processed hidden states


class ChunkReformerFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward  # chunk size of the feed-forward layer
        self.seq_len_dim = 1  # the sequence-length dimension is dimension 1

        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # layer normalization
        self.dense = ReformerFeedForwardDense(config)  # dense sublayer of the feed-forward block
        self.output = ReformerFeedForwardOutput(config)  # output sublayer of the feed-forward block

    def forward(self, attention_output):
        return apply_chunking_to_forward(
            self.forward_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            attention_output,
        )

    def forward_chunk(self, hidden_states):
        hidden_states = self.layer_norm(hidden_states)  # apply layer normalization
        hidden_states = self.dense(hidden_states)  # run the dense sublayer
        return self.output(hidden_states)  # return the output of the feed-forward block

class ReformerLayer(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.attention = ReformerAttention(config, layer_id)  # attention sublayer
        # dropout requires to have the same
        # seed for forward and backward pass
        self.attention_seed = None
        self.feed_forward_seed = None

        self.feed_forward = ChunkReformerFeedForward(config)  # chunked feed-forward sublayer
    def _init_attention_seed(self):
        """
        This function sets a new seed for the attention layer to make dropout deterministic for both forward calls: 1
        normal forward call and 1 forward call in backward to recalculate activations.
        """

        # randomize seeds
        # pick a new seed for the attention layer so dropout is deterministic across both
        # forward calls (the normal forward call and the one replayed during backward)

        # use cuda generator if available
        if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0:
            # GPU
            device_idx = torch.cuda.current_device()
            self.attention_seed = torch.cuda.default_generators[device_idx].seed()
        else:
            # CPU
            self.attention_seed = int(torch.seed() % sys.maxsize)

        # set PyTorch's random seed
        torch.manual_seed(self.attention_seed)

    def _init_feed_forward_seed(self):
        """
        This function sets a new seed for the feed forward layer to make dropout deterministic for both forward calls:
        1 normal forward call and 1 forward call in backward to recalculate activations.
        """
        # randomize seeds
        # pick a new seed for the feed-forward layer so dropout is deterministic across both
        # forward calls (the normal forward call and the one replayed during backward)

        # use cuda generator if available
        if hasattr(torch.cuda, "default_generators") and len(torch.cuda.default_generators) > 0:
            # GPU
            device_idx = torch.cuda.current_device()
            self.feed_forward_seed = torch.cuda.default_generators[device_idx].seed()
        else:
            # CPU
            self.feed_forward_seed = int(torch.seed() % sys.maxsize)

        # set PyTorch's random seed
        torch.manual_seed(self.feed_forward_seed)
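
The point of storing the seed: replaying `torch.manual_seed` before a dropout call reproduces the exact same dropout mask, which is what lets the backward pass recompute identical activations. A quick check:

```
import torch

x = torch.ones(2, 4)
torch.manual_seed(1234)
a = torch.nn.functional.dropout(x, p=0.5, training=True)
torch.manual_seed(1234)   # replay the same seed, as backward_pass does
b = torch.nn.functional.dropout(x, p=0.5, training=True)
print(torch.equal(a, b))  # True: identical dropout masks
```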

    def forward(
        self,
        prev_attn_output,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        num_hashes=None,
        past_buckets_states=None,
        use_cache=False,
        orig_sequence_length=None,
        output_attentions=False,
    ):
        # run this block without gradient tracking
        with torch.no_grad():
            # every forward pass samples a different seed for dropout;
            # the seed is saved so the forward call replayed during backward
            # reproduces exactly the same dropout masks
            if self.training:
                self._init_attention_seed()

            # run the attention computation
            attn_outputs = self.attention(
                hidden_states=hidden_states,
                head_mask=head_mask,
                attention_mask=attention_mask,
                num_hashes=num_hashes,
                past_buckets_states=past_buckets_states,
                use_cache=use_cache,
                orig_sequence_length=orig_sequence_length,
                output_attentions=output_attentions,
            )
            # take the hidden states of the attention output
            attn_output = attn_outputs.hidden_states

            # implement RevNet (see figure 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0)
            # Y_1 = X_1 + f(X_2)
            attn_output = prev_attn_output + attn_output

            # free memory
            del prev_attn_output

            # sample a new seed for the feed-forward dropout as well
            # and save it for the backward pass
            # to guarantee the same dropout masks
            if self.training:
                self._init_feed_forward_seed()

            # Y_2 = X_2 + g(Y_1)
            hidden_states = hidden_states + self.feed_forward(attn_output)

        # return a ReformerOutput object with the attention output, hidden states, attention probabilities and buckets
        return ReformerOutput(
            attn_output=attn_output,
            hidden_states=hidden_states,
            attention_probs=attn_outputs.attention_probs,
            buckets=attn_outputs.buckets,
        )
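
The reversible coupling above means the layer's inputs can be recomputed from its outputs instead of being stored. A minimal sketch with toy `f`/`g` as hypothetical stand-ins for the attention and feed-forward sublayers:

```
import torch

f = lambda x: torch.tanh(x)   # stand-in for the attention sublayer
g = lambda y: torch.relu(y)   # stand-in for the feed-forward sublayer

x1, x2 = torch.randn(3), torch.randn(3)

# forward:  Y_1 = X_1 + f(X_2),  Y_2 = X_2 + g(Y_1)
y1 = x1 + f(x2)
y2 = x2 + g(y1)

# inversion: X_2 = Y_2 - g(Y_1),  X_1 = Y_1 - f(X_2)
x2_rec = y2 - g(y1)
x1_rec = y1 - f(x2_rec)
print(torch.allclose(x1, x1_rec), torch.allclose(x2, x2_rec))  # True True
```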

    def backward_pass(
        self,
        next_attn_output,
        hidden_states,
        grad_attn_output,
        grad_hidden_states,
        attention_mask=None,
        head_mask=None,
        buckets=None,
    ):
        # Implements the backward pass of the reversible residual network.
        # A good blog post on how this works:
        # implements RevNet (see figure 6 in https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0);
        # this code is inspired by https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py

        # assert we are in training mode; call `model.train()` before training `ReformerModel` or its variants
        assert self.training, (
            "If you want to train `ReformerModel` and its variations, make sure to use `model.train()` to put the"
            " model into training mode."
        )

        with torch.enable_grad():
            # require gradients on the next attention output
            next_attn_output.requires_grad = True

            # set the seed to reproduce the correct dropout masks
            torch.manual_seed(self.feed_forward_seed)
            # g(Y_1)
            # recompute the feed-forward output of the next attention output
            res_hidden_states = self.feed_forward(next_attn_output)
            # backpropagate the gradient of g(Y_1) into grad_hidden_states, keeping the graph for later
            res_hidden_states.backward(grad_hidden_states, retain_graph=True)

        with torch.no_grad():
            # X_2 = Y_2 - g(Y_1)
            # recover the hidden states by subtracting g(Y_1)
            hidden_states = hidden_states - res_hidden_states
            # delete res_hidden_states to free memory
            del res_hidden_states

            # accumulate the gradient of next_attn_output into grad_attn_output
            grad_attn_output = grad_attn_output + next_attn_output.grad
            # clear the gradient of next_attn_output for the next round
            next_attn_output.grad = None

        with torch.enable_grad():
            # require gradients on the hidden states
            hidden_states.requires_grad = True

            # set the seed to reproduce the correct dropout masks
            torch.manual_seed(self.attention_seed)
            # f(X_2)
            # recompute the attention output of the hidden states;
            # if buckets is not None, the cached buckets are reused for the backward pass
            output = self.attention(
                hidden_states=hidden_states,
                head_mask=head_mask,
                attention_mask=attention_mask,
                buckets=buckets,
            ).hidden_states
            # backpropagate the gradient of f(X_2) into grad_attn_output, keeping the graph for later
            output.backward(grad_attn_output, retain_graph=True)

        with torch.no_grad():
            # X_1 = Y_1 - f(X_2)
            # recover the attention output by subtracting f(X_2)
            attn_output = next_attn_output - output
            # delete output and next_attn_output to free memory
            del output, next_attn_output

            # accumulate the gradient of hidden_states into grad_hidden_states
            grad_hidden_states = grad_hidden_states + hidden_states.grad
            # clear the gradient of hidden_states for the next round
            hidden_states.grad = None
            # detach hidden_states from the graph so it no longer tracks gradients
            hidden_states = hidden_states.detach()

        # return a ReformerBackwardOutput holding the recovered attn_output and hidden_states
        # together with the accumulated grad_attn_output and grad_hidden_states
        return ReformerBackwardOutput(
            attn_output=attn_output,
            hidden_states=hidden_states,
            grad_attn_output=grad_attn_output,
            grad_hidden_states=grad_hidden_states,
        )
    """
    针对可逆函数的自定义反向传播函数,以防止 PyTorch 执行通常的反向传播。
    通过这种方式确保在前向传播期间不保存内存消耗昂贵的激活值。
    本函数的实现受到 https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py 的启发。
    """

    @staticmethod
    # static forward method of the reversible function
    def forward(
        ctx,
        hidden_states,
        layers,
        attention_mask,
        head_mask,
        num_hashes,
        all_hidden_states,
        all_attentions,
        past_buckets_states,
        use_cache,
        orig_sequence_length,
        output_hidden_states,
        output_attentions,
    ):
        # start with an empty tuple of buckets
        all_buckets = ()

        # split the hidden_states tensor into two halves along the last dimension
        hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1)

        # iterate over the layers and their head masks
        for layer_id, (layer, layer_head_mask) in enumerate(zip(layers, head_mask)):
            # if hidden states should be output, append the current hidden states to all_hidden_states
            if output_hidden_states is True:
                all_hidden_states.append(hidden_states)

            # call the layer's forward pass and collect its outputs
            layer_outputs = layer(
                prev_attn_output=attn_output,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=layer_head_mask,
                num_hashes=num_hashes,
                past_buckets_states=past_buckets_states,
                use_cache=use_cache,
                orig_sequence_length=orig_sequence_length,
                output_attentions=output_attentions,
            )

            # update attn_output and hidden_states
            attn_output = layer_outputs.attn_output
            hidden_states = layer_outputs.hidden_states
            # append this layer's buckets to all_buckets
            all_buckets = all_buckets + (layer_outputs.buckets,)

            # if attentions should be output, append this layer's attention probabilities to all_attentions
            if output_attentions:
                all_attentions.append(layer_outputs.attention_probs)

        # if hidden states should be output, append the final hidden states to all_hidden_states
        if output_hidden_states is True:
            all_hidden_states.append(hidden_states)

        # save attn_output and hidden_states in ctx for the backward pass
        ctx.save_for_backward(attn_output.detach(), hidden_states.detach())
        ctx.layers = layers
        ctx.all_buckets = all_buckets
        ctx.head_mask = head_mask
        ctx.attention_mask = attention_mask

        # concatenate attn_output and hidden_states as the output
        return torch.cat([attn_output, hidden_states], dim=-1)

    @staticmethod
    def backward(ctx, grad_hidden_states):
        # split grad_hidden_states into two halves along the last dimension
        grad_attn_output, grad_hidden_states = torch.chunk(grad_hidden_states, 2, dim=-1)

        # retrieve the tensors saved in ctx during the forward pass
        attn_output, hidden_states = ctx.saved_tensors

        # bundle the tensors into a ReformerBackwardOutput tuple
        output = ReformerBackwardOutput(
            attn_output=attn_output,
            hidden_states=hidden_states,
            grad_attn_output=grad_attn_output,
            grad_hidden_states=grad_hidden_states,
        )

        # free memory by deleting the no-longer-needed references
        del grad_attn_output, grad_hidden_states, attn_output, hidden_states

        # retrieve the remaining backward-pass arguments from ctx
        layers = ctx.layers
        all_buckets = ctx.all_buckets
        head_mask = ctx.head_mask
        attention_mask = ctx.attention_mask

        # run the backward pass layer by layer, in reverse order
        for idx, layer in enumerate(layers[::-1]):
            # pop the last buckets from the stack
            buckets = all_buckets[-1]
            all_buckets = all_buckets[:-1]

            # execute the layer's backward pass
            output = layer.backward_pass(
                next_attn_output=output.attn_output,
                hidden_states=output.hidden_states,
                grad_attn_output=output.grad_attn_output,
                grad_hidden_states=output.grad_hidden_states,
                head_mask=head_mask[len(layers) - idx - 1],
                attention_mask=attention_mask,
                buckets=buckets,
            )

        # assert that all buckets have been consumed after backpropagation
        assert all_buckets == (), "buckets have to be empty after backpropagation"

        # concatenate grad_attn_output and grad_hidden_states along the last dimension
        grad_hidden_states = torch.cat([output.grad_attn_output, output.grad_hidden_states], dim=-1)

        # return the gradient matching forward()'s first argument and None for the remaining arguments
        return grad_hidden_states, None, None, None, None, None, None, None, None, None, None, None
class ReformerEncoder(nn.Module):
    # The Reformer encoder stack, an nn.Module
    def __init__(self, config):
        super().__init__()
        # Initialize from the configuration object

        # Dropout probability
        self.dropout = config.hidden_dropout_prob

        # Stack of ReformerLayer modules
        self.layers = nn.ModuleList([ReformerLayer(config, i) for i in range(config.num_hidden_layers)])
        # Reformer uses RevNets, so the last layer's two output streams are
        # concatenated and layer norm is applied over 2 * hidden_size
        self.layer_norm = nn.LayerNorm(2 * config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        num_hashes=None,
        past_buckets_states=None,
        use_cache=False,
        orig_sequence_length=None,
        output_hidden_states=False,
        output_attentions=False,
    ):
        # Lists that will collect all hidden states and attention weights
        all_hidden_states = []
        all_attentions = []

        # Initialize the cached bucket/state history if none was passed in
        if past_buckets_states is None:
            past_buckets_states = [((None), (None)) for i in range(len(self.layers))]

        # Duplicate the hidden states to form the two streams of the reversible ResNet
        hidden_states = torch.cat([hidden_states, hidden_states], dim=-1)
        # Run the custom reversible function for the forward pass
        hidden_states = _ReversibleFunction.apply(
            hidden_states,
            self.layers,
            attention_mask,
            head_mask,
            num_hashes,
            all_hidden_states,
            all_attentions,
            past_buckets_states,
            use_cache,
            orig_sequence_length,
            output_hidden_states,
            output_attentions,
        )

        # Apply layer norm to the concatenated hidden states
        hidden_states = self.layer_norm(hidden_states)

        # Apply dropout
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # Return a ReformerEncoderOutput with the hidden states, the collected
        # hidden-state and attention lists, and the bucket/state history
        return ReformerEncoderOutput(
            hidden_states=hidden_states,
            all_hidden_states=all_hidden_states,
            all_attentions=all_attentions,
            past_buckets_states=past_buckets_states,
        )
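
Note the width change: the encoder takes hidden states of size hidden_size, duplicates them into two streams, and returns features of size 2 * hidden_size. A quick shape check (toy dimensions of my own choosing):

# --- illustrative sketch, not part of the library source ---
import torch

batch, seq_len, hidden = 2, 8, 16
hidden_states = torch.randn(batch, seq_len, hidden)

stacked = torch.cat([hidden_states, hidden_states], dim=-1)  # (2, 8, 32)
layer_norm = torch.nn.LayerNorm(2 * hidden)                  # matches self.layer_norm above
print(layer_norm(stacked).shape)                             # torch.Size([2, 8, 32])
# --- end sketch ---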


class ReformerOnlyLMHead(nn.Module):
    # Language-modeling head for Reformer, an nn.Module
    def __init__(self, config):
        super().__init__()

        # Reformer uses RevNets, so the last layer's two output streams are
        # concatenated; the head therefore consumes 2 * hidden_size features
        self.seq_len_dim = 1
        self.chunk_size_lm_head = config.chunk_size_lm_head
        # Linear decoder projecting onto the vocabulary
        self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        # Run the forward pass in chunks along the sequence dimension
        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)

    def forward_chunk(self, hidden_states):
        # Project one chunk of hidden states through the decoder
        hidden_states = self.decoder(hidden_states)
        return hidden_states

    def _tie_weights(self):
        # Ties the two bias parameters back together if they were disconnected
        # (on TPU, or after the bias has been resized)
        self.bias = self.decoder.bias
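
`apply_chunking_to_forward` trades memory for compute: instead of projecting the whole (batch, seq_len, 2 * hidden_size) tensor onto the vocabulary at once, it slices the sequence dimension into chunks of `chunk_size_lm_head` and concatenates the results. A simplified stand-in showing the semantics (my own sketch, with toy sizes):

# --- illustrative sketch, not part of the library source ---
import torch

def chunked_forward(forward_fn, chunk_size, seq_len_dim, hidden_states):
    if chunk_size == 0:
        return forward_fn(hidden_states)
    chunks = hidden_states.split(chunk_size, dim=seq_len_dim)
    return torch.cat([forward_fn(c) for c in chunks], dim=seq_len_dim)

decoder = torch.nn.Linear(32, 100)      # stand-in for self.decoder
hidden_states = torch.randn(2, 8, 32)
full = decoder(hidden_states)
chunked = chunked_forward(decoder, chunk_size=4, seq_len_dim=1, hidden_states=hidden_states)
assert torch.allclose(full, chunked, atol=1e-6)
# --- end sketch ---
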
# Abstract base class that handles weight initialization and provides a simple
# interface for downloading and loading pretrained models
class ReformerPreTrainedModel(PreTrainedModel):
    # Associated configuration class
    config_class = ReformerConfig
    # Prefix used for the base model's parameter names
    base_model_prefix = "reformer"

    # Returns a dictionary of dummy inputs
    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)  # tensor of dummy input IDs
        input_mask = torch.tensor(DUMMY_MASK)   # tensor of dummy attention-mask values
        dummy_inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
        }
        return dummy_inputs

    # Initializes the weights of a module
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, AxialPositionEmbeddings):
            # Axial position embeddings: normal init with the axial std
            for weight in module.weights:
                nn.init.normal_(weight, std=self.config.axial_norm_std)
        elif isinstance(module, nn.Embedding):
            # Embeddings: normal init, then zero out the padding index if any
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            # Linear layers: normal init for weights, zeros for biases
            # (slightly different from the TF version, which uses truncated normal)
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm: zero bias, unit weight
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@dataclass
class ReformerModelOutput(ModelOutput):
    """
    Output type of [`ReformerModel`].

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_predict, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.

            `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict`
            corresponds to `sequence_length`.
        past_buckets_states (`List[Tuple(torch.LongTensor, torch.FloatTensor)]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of tuples, one per layer, whose first element is the previous *buckets* of shape `(batch_size,
            num_heads, num_hashes, sequence_length)` and whose second element is the previous *hidden_states* of shape
            `(batch_size, sequence_length, hidden_size)`.

            Contains precomputed buckets and hidden-states that can be used to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    last_hidden_state: torch.FloatTensor = None
    past_buckets_states: Optional[List[Tuple[torch.LongTensor, torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class ReformerModelWithLMHeadOutput(ModelOutput):
    """
    Output type of [`ReformerModelWithLMHead`].

    Args:
        loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, num_predict, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

            `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict`
            corresponds to `sequence_length`.
        past_buckets_states (`List[Tuple(torch.LongTensor, torch.FloatTensor)]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of tuples, one per layer, whose first element is the previous *buckets* of shape `(batch_size,
            num_heads, num_hashes, sequence_length)` and whose second element is the previous *hidden_states* of shape
            `(batch_size, sequence_length, hidden_size)`.

            Contains precomputed buckets and hidden-states that can be used (see `past_buckets_states` input) to speed
            up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer)
            of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Fields of the ReformerModelWithLMHead output data structure
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_buckets_states: Optional[List[Tuple[torch.LongTensor, torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


REFORMER_START_DOCSTRING = r"""
    Reformer was proposed in [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev,
    Łukasz Kaiser, Anselm Levskaya.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`ReformerConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
REFORMER_INPUTS_DOCSTRING = r"""
"""

# Decorator adds a docstring describing the bare Reformer model, which outputs
# raw hidden states without any specific head on top
@add_start_docstrings(
    "The bare Reformer Model transformer outputting raw hidden-states without any specific head on top.",
    REFORMER_START_DOCSTRING,
)
# ReformerModel inherits from ReformerPreTrainedModel
class ReformerModel(ReformerPreTrainedModel):
    def __init__(self, config):
        # Call the parent constructor
        super().__init__(config)
        # Keep the configuration on self.config
        self.config = config
        # There must be at least one hidden layer
        assert (
            self.config.num_hidden_layers > 0
        ), "`config.attn_layers` is empty. Select at least one attn layer from ['lsh', 'local']"

        # Initialize the embeddings and the encoder
        self.embeddings = ReformerEmbeddings(config)
        self.encoder = ReformerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Returns the input embeddings
    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    # Sets the input embeddings
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    # Prunes attention heads of the model
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    # Decorator documents the forward inputs
    @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
    # Decorator adds a code-sample docstring
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=ReformerModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # Forward pass
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        num_hashes: Optional[int] = None,
        past_buckets_states: Optional[List[Tuple[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
    # Helper that pads inputs so the sequence length is a multiple of the chunk length
    def _pad_to_mult_of_chunk_length(
        self,
        input_ids,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        input_shape=None,
        padding_length=None,
        padded_seq_length=None,
        device=None,
        ):
        # Warn that the inputs are padded because the sequence length is not a
        # multiple of `config.chunk_length`
        logger.warning_once(
            f"Input ids are automatically padded from {input_shape[-1]} to {input_shape[-1] + padding_length} to be a "
            f"multiple of `config.chunk_length`: {padded_seq_length}"
        )

        # Build the padding tensor of input ids, filled with pad_token_id
        padded_input_ids = torch.full(
            (input_shape[0], padding_length),
            self.config.pad_token_id,
            device=device,
            dtype=torch.long,
        )

        # Extend `attention_mask`
        if attention_mask is not None:
            # Append a block of zeros (masked positions) to the existing attention mask
            pad_attention_mask = torch.zeros(input_shape[0], padding_length, device=device, dtype=attention_mask.dtype)
            attention_mask = torch.cat([attention_mask, pad_attention_mask], dim=-1)
        else:
            # No mask was given: mark the original positions as attended (ones)
            # and the padded positions as masked (zeros)
            attention_mask = torch.cat(
                [
                    torch.ones(input_shape, device=device, dtype=torch.bool),
                    torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.bool),
                ],
                dim=-1,
            )

        # If input ids were given, append the padding ids and update the input shape
        if input_ids is not None:
            input_ids = torch.cat([input_ids, padded_input_ids], dim=-1)
            input_shape = input_ids.size()

            # Pad the position ids as well, if given
            if position_ids is not None:
                # Build a range from the original length up to the padded length;
                # note that the next line then overwrites it with expanded copies
                # of the original position ids (as written in the source)
                padded_position_ids = torch.arange(input_shape[-1], padded_seq_length, dtype=torch.long, device=device)
                padded_position_ids = position_ids.unsqueeze(0).expand(input_shape[0], padding_length)
                position_ids = torch.cat([position_ids, padded_position_ids], dim=-1)

        # If input embeddings were given, pad them too and update the input shape
        if inputs_embeds is not None:
            # Embed the padding ids (with the position ids) and append along the sequence dimension
            padded_inputs_embeds = self.embeddings(padded_input_ids, position_ids)
            inputs_embeds = torch.cat([inputs_embeds, padded_inputs_embeds], dim=-2)
            input_shape = inputs_embeds.size()

        # Return the padded input_ids, inputs_embeds, attention_mask, position_ids and the updated input_shape
        return input_ids, inputs_embeds, attention_mask, position_ids, input_shape
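
For reference, the relationship among the lengths used above (illustrative values of my own choosing):

# --- illustrative sketch, not part of the library source ---
chunk_length = 64                    # stands in for config.chunk_length
seq_len = 100                        # original input length

padding_length = (chunk_length - seq_len % chunk_length) % chunk_length
padded_seq_length = seq_len + padding_length
print(padding_length, padded_seq_length)   # 28 128
# --- end sketch ---
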
# Decorator adds a docstring describing ReformerModelWithLMHead as a Reformer
# model with a language-modeling head on top
@add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING)
class ReformerModelWithLMHead(ReformerPreTrainedModel):
    # Keys of the LM-head weights that are tied to the input embeddings
    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

    def __init__(self, config):
        # Call the parent constructor
        super().__init__(config)
        # The config must set is_decoder=True
        assert config.is_decoder, "If you want to use `ReformerModelWithLMHead` make sure that `is_decoder=True`."
        # With a causal mask, "local" attention layers must not look at future chunks
        assert "local" not in self.config.attn_layers or config.local_num_chunks_after == 0, (
            "If causal mask is enabled, make sure that `config.local_num_chunks_after` is set to 0 and not"
            f" {config.local_num_chunks_after}."
        )
        # With a causal mask, "lsh" attention layers must not look at future chunks either
        assert "lsh" not in self.config.attn_layers or config.lsh_num_chunks_after == 0, (
            "If causal mask is enabled, make sure that `config.lsh_num_chunks_after` is set to 0 and not"
            f" {config.lsh_num_chunks_after}."
        )

        # Instantiate the base ReformerModel
        self.reformer = ReformerModel(config)
        # Instantiate the LM head
        self.lm_head = ReformerOnlyLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Returns the output embeddings of the LM head
    def get_output_embeddings(self):
        return self.lm_head.decoder

    # Replaces the output embeddings of the LM head with new embeddings
    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    # Forward pass, wrapped with docstring and code-sample decorators
    @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        num_hashes: Optional[int] = None,
        past_buckets_states: Optional[List[Tuple[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
                config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
                labels in `[0, ..., config.vocab_size]`
        """
        # Fall back to self.config.use_return_dict when return_dict is not explicitly set
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the base Reformer model with all the forwarded arguments
        reformer_outputs = self.reformer(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            num_hashes=num_hashes,
            past_buckets_states=past_buckets_states,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        # Take the sequence output from the Reformer outputs
        sequence_output = reformer_outputs[0]
        # Project it through the LM head to obtain logits
        logits = self.lm_head(sequence_output)

        # Loss defaults to None
        loss = None
        # Compute the loss if labels were provided
        if labels is not None:
            # Shift so that tokens < n predict n: drop the last logit and the first label
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Cross-entropy over the flattened predictions
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        # If return_dict is False, return a tuple of logits plus the remaining outputs
        if not return_dict:
            output = (logits,) + reformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # Otherwise return a ReformerModelWithLMHeadOutput carrying the loss and all outputs
        return ReformerModelWithLMHeadOutput(
            loss=loss,
            logits=logits,
            past_buckets_states=reformer_outputs.past_buckets_states,
            hidden_states=reformer_outputs.hidden_states,
            attentions=reformer_outputs.attentions,
        )
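
The shift-by-one in the loss aligns position t's prediction with the token at position t + 1. A toy numerical check (shapes of my own choosing):

# --- illustrative sketch, not part of the library source ---
import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(2, 5, 11)                    # (batch, seq_len, vocab)
labels = torch.randint(0, 11, (2, 5))             # (batch, seq_len)

shift_logits = logits[..., :-1, :].contiguous()   # (2, 4, 11)
shift_labels = labels[..., 1:].contiguous()       # (2, 4)
loss = CrossEntropyLoss()(shift_logits.view(-1, 11), shift_labels.view(-1))
print(loss.item())
# --- end sketch ---
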

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, use_cache=None, num_hashes=None, **kwargs
    ):
        # If a cache already exists, only the last input token is needed
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # Assemble the input dict used during generation
        inputs_dict = {
            "input_ids": input_ids,
            "past_buckets_states": past_key_values,
            "use_cache": use_cache,
            "num_hashes": num_hashes,
        }

        return inputs_dict
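
A hedged usage sketch of generation with this class; `generate` calls `prepare_inputs_for_generation` above at every decoding step (running it requires downloading the public checkpoint):

# --- illustrative sketch, not part of the library source ---
from transformers import ReformerModelWithLMHead, ReformerTokenizer

tok = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment")

inputs = tok("A few months later", return_tensors="pt")
# Once the bucket/state cache is populated, only the newest token is fed back in
out = model.generate(inputs["input_ids"], max_length=32, do_sample=False)
print(tok.decode(out[0]))
# --- end sketch ---
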
    # Reorders the cached bucket/state history for beam search
    def _reorder_cache(self, past_key_values, beam_idx):
        # Collects the reordered buckets and hidden states per layer
        reord_past_buckets_states = []
        # Iterate over the cache of every layer
        for layer_past in past_key_values:
            # Reorder the buckets along the batch dimension if they exist
            if layer_past[0] is not None:
                reord_buckets = layer_past[0].index_select(0, beam_idx.to(layer_past[0].device))
            else:
                reord_buckets = None

            # Reorder this layer's hidden states along the batch dimension
            reord_hidden_states = layer_past[1].index_select(0, beam_idx.to(layer_past[1].device))
            # Store the reordered (buckets, hidden states) pair
            reord_past_buckets_states.append((reord_buckets, reord_hidden_states))

        # Return the reordered cache for all layers
        return reord_past_buckets_states
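
What `index_select` does here, on hypothetical cache tensors for one layer (all sizes are toy values of my own choosing):

# --- illustrative sketch, not part of the library source ---
import torch

buckets = torch.arange(4 * 2 * 1 * 6).view(4, 2, 1, 6)   # (beams, heads, hashes, seq)
hidden = torch.randn(4, 6, 3)                             # (beams, seq, hidden)
beam_idx = torch.tensor([2, 2, 0, 1])                     # beams kept at this step

reord_buckets = buckets.index_select(0, beam_idx)
reord_hidden = hidden.index_select(0, beam_idx)
print(reord_hidden[0].equal(hidden[2]))                   # True
# --- end sketch ---
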
@add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING)
class ReformerForMaskedLM(ReformerPreTrainedModel):
    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)
        # Must not be a decoder: this model uses bi-directional self-attention
        assert not config.is_decoder, (
            "If you want to use `ReformerForMaskedLM` make sure `config.is_decoder=False` for bi-directional"
            " self-attention."
        )
        # Instantiate the base Reformer model and the LM head
        self.reformer = ReformerModel(config)
        self.lm_head = ReformerOnlyLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # Returns the decoder of the LM head
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        # Replaces the decoder of the LM head with new embeddings
        self.lm_head.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        num_hashes: Optional[int] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,



@add_start_docstrings(
    """
    Reformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    REFORMER_START_DOCSTRING,
)
class ReformerForSequenceClassification(ReformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # Keep the number of labels and the config
        self.num_labels = config.num_labels
        self.config = config

        # Instantiate the base Reformer model and the classification head
        self.reformer = ReformerModel(config)
        self.classifier = ReformerClassificationHead(config)
        if config.is_decoder is True:
            # Sequence classification usually wants bi-directional attention
            logger.warning("You might want to disable causal masking for sequence classification")

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        num_hashes: Optional[int] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
class ReformerClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        # The dense layer maps from 2 * hidden_size (the two reversible streams) down to hidden_size
        self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size)

        # Use the classifier dropout if configured, otherwise fall back to the hidden dropout
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)

        # Final projection onto the label space
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, hidden_states, **kwargs):
        # Take the hidden state of the first token <s> (equivalent to [CLS])
        hidden_states = hidden_states[:, 0, :]

        # Dropout for regularization
        hidden_states = self.dropout(hidden_states)

        # Dense transformation
        hidden_states = self.dense(hidden_states)

        # tanh non-linearity
        hidden_states = torch.tanh(hidden_states)

        # Dropout again
        hidden_states = self.dropout(hidden_states)

        # Project to the classification logits
        hidden_states = self.out_proj(hidden_states)
        return hidden_states
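
A quick shape check of the head above, using a hypothetical config stand-in (all names and sizes below are my own):

# --- illustrative sketch, not part of the library source ---
import torch

class Cfg:
    hidden_size = 16
    num_labels = 3
    classifier_dropout = None
    hidden_dropout_prob = 0.1

head = ReformerClassificationHead(Cfg())
# The encoder output is 2 * hidden_size wide because of the two reversible streams
sequence_output = torch.randn(4, 10, 2 * Cfg.hidden_size)
print(head(sequence_output).shape)   # torch.Size([4, 3])
# --- end sketch ---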


@add_start_docstrings(
    """
    Reformer Model with a span classification head on top for extractive question-answering tasks like SQuAD / TriviaQA
    (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    REFORMER_START_DOCSTRING,
)
class ReformerForQuestionAnswering(ReformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Keep the number of labels
        self.num_labels = config.num_labels

        # Base Reformer model; its output is 2 * hidden_size wide because of the reversible residual layers
        self.reformer = ReformerModel(config)

        # Span-prediction output layer: 2 * hidden_size in, num_labels (start/end) out
        self.qa_outputs = nn.Linear(2 * config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        num_hashes: Optional[int] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,

.\models\reformer\tokenization_reformer.py

# coding=utf-8
# Copyright 2020 The Trax Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model Reformer."""

import os
# Operating-system utilities
from shutil import copyfile
# Used to copy the vocabulary file when saving
from typing import Any, Dict, List, Optional, Tuple
# Type hints

import sentencepiece as spm
# The SentencePiece library that backs this tokenizer

from ...tokenization_utils import PreTrainedTokenizer
# Base tokenizer class
from ...utils import logging
# Logging utilities


logger = logging.get_logger(__name__)
# Module-level logger

SPIECE_UNDERLINE = "▁"
# The SentencePiece word-boundary marker

VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
# Mapping of vocabulary file names

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "google/reformer-crime-and-punishment": (
            "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
        )
    }
}
# Download URLs of the pretrained vocabulary files, keyed by checkpoint name

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "google/reformer-crime-and-punishment": 524288,
}
# Positional-embedding sizes (maximum input lengths) of the pretrained checkpoints


class ReformerTokenizer(PreTrainedTokenizer):
    """
    Construct a Reformer tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece) .

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.
    """
    # Reformer tokenizer class, backed by SentencePiece and inheriting from PreTrainedTokenizer
    """
    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        additional_special_tokens (`List[str]`, *optional*, defaults to `[]`):
            Additional special tokens used by the tokenizer.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
    """

    # Class-level constants
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        eos_token="</s>",
        unk_token="<unk>",
        additional_special_tokens=[],
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # Initialize instance attributes; default sp_model_kwargs to an empty dict
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        self.vocab_file = vocab_file
        # Create a SentencePieceProcessor and load the given vocabulary file
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        # Call the parent initializer with the special tokens and sp_model_kwargs
        super().__init__(
            eos_token=eos_token,
            unk_token=unk_token,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

    @property
    def vocab_size(self):
        # Size of the SentencePiece vocabulary
        return self.sp_model.get_piece_size()
    def get_vocab(self) -> Dict[str, int]:
        # Build a vocab dict mapping each token to its ID
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        # Merge in any additionally added tokens
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        # Copy the current state of the object
        state = self.__dict__.copy()
        # Drop the sp_model field so the object can be pickled
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        # Restore the object's state
        self.__dict__ = d

        # For backward compatibility: older pickles may lack sp_model_kwargs
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        # Re-create the sp_model with the stored kwargs
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        # Reload the vocabulary model file into sp_model
        self.sp_model.Load(self.vocab_file)

    def _tokenize(self, text: str) -> List[str]:
        """Takes a string as input and returns a list of word/sub-word tokens."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an ID using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        if index < self.sp_model.get_piece_size():
            token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings) into a single string."""
        current_sub_tokens = []
        out_string = ""
        for token in tokens:
            # Make sure special tokens are not decoded through the sentencepiece model
            if token in self.all_special_tokens:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string.strip()
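
A hedged round-trip example of these helpers (requires downloading the public checkpoint; the exact pieces depend on the vocabulary):

# --- illustrative sketch, not part of the library source ---
from transformers import ReformerTokenizer

tok = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")

tokens = tok._tokenize("Crime and Punishment")        # SentencePiece pieces
ids = [tok._convert_token_to_id(t) for t in tokens]
text = tok.convert_tokens_to_string(tokens)
print(tokens, ids, text)
# --- end sketch ---
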

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # The save directory must exist
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        # Build the output path of the vocabulary file
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # If the current vocab file exists and differs from the target path, copy it over
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        # Otherwise, if the current vocab file is missing, write the serialized
        # sp_model proto to the target path instead
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)
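
A hedged save/reload sketch for this method (again requires the public checkpoint):

# --- illustrative sketch, not part of the library source ---
import tempfile
from transformers import ReformerTokenizer

tok = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
with tempfile.TemporaryDirectory() as d:
    paths = tok.save_vocabulary(d)                    # copies spiece.model into `d`
    reloaded = ReformerTokenizer(vocab_file=paths[0])
    print(reloaded.vocab_size == tok.vocab_size)      # True
# --- end sketch ---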