Transformers Source Code Analysis (74)

.\models\mega\__init__.py

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
# Import TYPE_CHECKING for type-checking-only imports

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)
# Import the lazy-import helpers and availability checks from the package-level utils

_import_structure = {
    "configuration_mega": ["MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegaConfig", "MegaOnnxConfig"],
}
# Base import structure: the Mega configuration module and its public names are always importable

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
# Check whether PyTorch is available; raise OptionalDependencyNotAvailable if it is not
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_mega"] = [
        "MEGA_PRETRAINED_MODEL_ARCHIVE_LIST",
        "MegaForCausalLM",
        "MegaForMaskedLM",
        "MegaForMultipleChoice",
        "MegaForQuestionAnswering",
        "MegaForSequenceClassification",
        "MegaForTokenClassification",
        "MegaModel",
        "MegaPreTrainedModel",
    ]
    # If torch is available, add the Mega modeling classes to the import structure

if TYPE_CHECKING:
    from .configuration_mega import MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP, MegaConfig, MegaOnnxConfig
    # Under static type checking, import the config archive map and config classes directly

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_mega import (
            MEGA_PRETRAINED_MODEL_ARCHIVE_LIST,
            MegaForCausalLM,
            MegaForMaskedLM,
            MegaForMultipleChoice,
            MegaForQuestionAnswering,
            MegaForSequenceClassification,
            MegaForTokenClassification,
            MegaModel,
            MegaPreTrainedModel,
        )
        # If torch is available, import the Mega model classes from the modeling module for type checkers

else:
    import sys
    # Import the sys module

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
    # Outside of type checking, replace this module with a lazy module so submodules are imported on demand
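As a quick aside, here is what the lazy module buys the caller. This is an illustrative sketch, not part of the file, and it assumes a transformers installation that still ships the mega model (plus torch for the modeling import):

# Sketch: lazy imports in action for the mega subpackage.
from transformers.models.mega import MegaConfig   # only configuration_mega is actually imported here

config = MegaConfig()        # config objects never require torch
print(config.model_type)     # "mega"

# The heavy modeling module is imported only when one of its names is first requested:
from transformers.models.mega import MegaModel    # this line triggers the import of modeling_mega (needs torch)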

.\models\megatron_bert\configuration_megatron_bert.py

# File encoding: UTF-8
# Copyright held by NVIDIA CORPORATION and The HuggingFace Inc. team
# Licensed under the Apache License, Version 2.0; this file may not be used except in compliance with the License
# A copy of the License is available at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, the software is distributed on an "AS IS" BASIS, without warranties or conditions of any kind
# See the License for the specific terms governing permissions and limitations

""" MEGATRON_BERT 模型配置"""

# Import the base configuration class PretrainedConfig
from ...configuration_utils import PretrainedConfig
# Import the library's logging utility
from ...utils import logging

# Logger scoped to this module's namespace
logger = logging.get_logger(__name__)

# Archive map of pretrained MEGATRON_BERT configurations; currently empty
MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    # See all MEGATRON_BERT models at https://huggingface.co/models?filter=bert
}


class MegatronBertConfig(PretrainedConfig):
    r"""
    这是用于存储 [`MegatronBertModel`] 配置的配置类。它用于根据指定的参数实例化一个 MEGATRON_BERT 模型,
    定义模型的架构。使用默认值实例化配置将产生类似于 MEGATRON_BERT 
    [nvidia/megatron-bert-uncased-345m](https://huggingface.co/nvidia/megatron-bert-uncased-345m) 架构的配置。

    配置对象继承自 [`PretrainedConfig`],可用于控制模型输出。有关更多信息,请阅读 [`PretrainedConfig`] 的文档。

    Examples:

    ```
    >>> from transformers import MegatronBertConfig, MegatronBertModel

    >>> # 初始化一个 MEGATRON_BERT google-bert/bert-base-uncased 风格的配置
    >>> configuration = MegatronBertConfig()

    >>> # 使用配置初始化一个(带有随机权重)从 google-bert/bert-base-uncased 风格配置的模型
    >>> model = MegatronBertModel(configuration)

    >>> # 访问模型配置
    >>> configuration = model.config
    ```
    """

    # The model type is "megatron-bert"
    model_type = "megatron-bert"

    def __init__(
        self,
        vocab_size=29056,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        **kwargs,
    ):
        # Call the parent constructor, forwarding the padding token id and any remaining keyword arguments
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # Size of the vocabulary
        self.vocab_size = vocab_size
        # Dimensionality of the hidden layers
        self.hidden_size = hidden_size
        # Number of hidden (Transformer) layers
        self.num_hidden_layers = num_hidden_layers
        # Number of attention heads
        self.num_attention_heads = num_attention_heads
        # Activation function used in the hidden layers
        self.hidden_act = hidden_act
        # Size of the intermediate (feed-forward) layer in each Transformer block
        self.intermediate_size = intermediate_size
        # Dropout probability for the hidden layers
        self.hidden_dropout_prob = hidden_dropout_prob
        # Dropout probability applied to the attention probabilities
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        # Maximum number of position embeddings (maximum sequence length)
        self.max_position_embeddings = max_position_embeddings
        # Size of the token-type vocabulary (typically distinguishes sentence A from sentence B)
        self.type_vocab_size = type_vocab_size
        # Standard deviation used for weight initialization
        self.initializer_range = initializer_range
        # Epsilon used by the layer-normalization layers
        self.layer_norm_eps = layer_norm_eps
        # Type of position embedding (absolute or a relative variant)
        self.position_embedding_type = position_embedding_type
        # Whether to cache intermediate key/value states for faster decoding
        self.use_cache = use_cache
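For illustration, the defaults above can be overridden at construction time like any other `PretrainedConfig`; this sketch (not from the file) builds a deliberately small model so the head count still divides the hidden size:

# Sketch: a small MegatronBERT configuration for quick experiments.
from transformers import MegatronBertConfig, MegatronBertModel

small_config = MegatronBertConfig(
    hidden_size=256,          # instead of the 1024 default
    num_hidden_layers=4,      # instead of 24
    num_attention_heads=8,    # must divide hidden_size
    intermediate_size=1024,
)
model = MegatronBertModel(small_config)              # randomly initialized weights
print(sum(p.numel() for p in model.parameters()))    # parameter count of the toy model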

.\models\megatron_bert\convert_megatron_bert_checkpoint.py

# argparse for command-line argument parsing
import argparse
# os for interacting with the filesystem
import os
# re for regular expressions
import re
# zipfile for reading ZIP archives
import zipfile

# PyTorch
import torch

# MegatronBertConfig from the transformers library
from transformers import MegatronBertConfig


def recursive_print(name, val, spaces=0):
    # Recursively print the contents of a (possibly nested) dict of tensors
    # Format the message
    if name is None:
        msg = None
    else:
        fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}"
        msg = fmt.format(name)

    # Print the message and recurse into the value if needed
    if isinstance(val, dict):
        if msg is not None:
            print(msg)
        for k in val.keys():
            recursive_print(k, val[k], spaces + 2)
    elif isinstance(val, torch.Tensor):
        print(msg, ":", val.size())
    else:
        print(msg, ":", val)


def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
    # Permute the layout of `param` to [num_splits * num_heads * hidden_size, :]
    # for compatibility with later versions of NVIDIA Megatron-LM.
    # The inverse operation is performed inside Megatron-LM when reading checkpoints:
    # https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209
    # If `param` is the weight tensor of the self-attention block, the returned tensor still
    # needs to be transposed once more before it can be read by HuggingFace BERT.
    input_shape = param.size()
    # Checkpoint version 1.0:
    if checkpoint_version == 1.0:
        # Version 1.0 stores the parameter with shape [num_heads * hidden_size * num_splits, :]
        saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:]
        # Reshape the parameter to saved_shape
        param = param.view(*saved_shape)
        # Swap dimensions 0 and 2
        param = param.transpose(0, 2)
        # Swap dimensions 1 and 2 and make the result contiguous in memory
        param = param.transpose(1, 2).contiguous()
    # Checkpoint version 2.0 or later:
    elif checkpoint_version >= 2.0:
        # Later versions store the parameter with shape [num_heads * num_splits * hidden_size, :]
        saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:]
        # Reshape the parameter to saved_shape
        param = param.view(*saved_shape)
        # Swap dimensions 0 and 1 and make the result contiguous in memory
        param = param.transpose(0, 1).contiguous()
    # Finally restore the original shape
    param = param.view(*input_shape)
    # Return the re-ordered parameter
    return param
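A tiny worked example helps see what the permutation does. The sizes below are toy values I picked (2 heads, head size 3, fused QKV so num_splits=3), and the call assumes the function defined above is in scope; the shape is preserved while the rows are regrouped:

# Sketch: observe how fix_query_key_value_ordering regroups the rows of a fused QKV weight.
import torch

num_heads, hidden_size_per_head, num_splits = 2, 3, 3          # toy sizes, not real model sizes
rows = num_splits * num_heads * hidden_size_per_head            # 18 rows in the fused projection
param = torch.arange(rows * 4, dtype=torch.float32).view(rows, 4)

reordered = fix_query_key_value_ordering(param.clone(), 2.0, num_splits, num_heads, hidden_size_per_head)
print(param.shape == reordered.shape)    # True: still [18, 4]
print(torch.equal(param, reordered))     # False: rows are now grouped by split first, then head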
# Convert a Megatron-LM checkpoint into a state dict usable by the Transformers framework
def convert_megatron_checkpoint(args, input_state_dict, config):
    # The output model state dict
    output_state_dict = {}

    # Old versions did not store the training args
    ds_args = input_state_dict.get("args", None)
    if ds_args is not None:
        # If the training args are present, copy the relevant settings into the converted config
        config.tokenizer_type = ds_args.tokenizer_type
        config.vocab_size = ds_args.padded_vocab_size
        config.max_position_embeddings = ds_args.max_position_embeddings
        config.hidden_size = ds_args.hidden_size
        config.num_hidden_layers = ds_args.num_layers
        config.num_attention_heads = ds_args.num_attention_heads
        config.intermediate_size = ds_args.ffn_hidden_size if "ffn_hidden_size" in ds_args else 4 * ds_args.hidden_size

    # The number of attention heads
    heads = config.num_attention_heads
    # The hidden size per attention head
    hidden_size_per_head = config.hidden_size // heads
    # The Megatron-LM checkpoint version
    if "checkpoint_version" in input_state_dict.keys():
        checkpoint_version = input_state_dict["checkpoint_version"]
    else:
        checkpoint_version = 0.0

    # The model
    model = input_state_dict["model"]
    # The language model
    lm = model["language_model"]
    # The embeddings
    embeddings = lm["embedding"]

    # The word embeddings
    word_embeddings = embeddings["word_embeddings"]["weight"]
    # Truncate the embedding table to the configured vocabulary size
    word_embeddings = word_embeddings[: config.vocab_size, :]
    # Store the word embeddings
    output_state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings

    # The position embeddings
    pos_embeddings = embeddings["position_embeddings"]["weight"]
    assert pos_embeddings.size(0) == config.max_position_embeddings and pos_embeddings.size(1) == config.hidden_size
    # Store the position embeddings
    output_state_dict["bert.embeddings.position_embeddings.weight"] = pos_embeddings

    # The token-type embeddings
    tokentype_embeddings = embeddings["tokentype_embeddings"]["weight"]
    # Store the token-type embeddings
    output_state_dict["bert.embeddings.token_type_embeddings.weight"] = tokentype_embeddings

    # The transformer block
    transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"]

    # Regular expression used to extract layer names
    layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

    # Simple name mapping from Megatron-LM to Transformers
    megatron_to_transformers = {
        "attention.dense": ".attention.output.dense.",
        "self_attention.dense": ".attention.output.dense.",
        "mlp.dense_h_to_4h": ".intermediate.dense.",
        "mlp.dense_4h_to_h": ".output.dense.",
    }
    # Variable tracking the fused attention query/key/value tensor, initially None
    attention_qkv_weight = None

    # Extract the per-layer parameters and store them in the output state dict

    # Store the final layer-norm weight
    output_state_dict["bert.encoder.ln.weight"] = transformer["final_layernorm.weight"]
    # Store the final layer-norm bias
    output_state_dict["bert.encoder.ln.bias"] = transformer["final_layernorm.bias"]

    # Extract and store the pooler weight and bias
    pooler = lm["pooler"]
    output_state_dict["bert.pooler.dense.weight"] = pooler["dense.weight"]
    output_state_dict["bert.pooler.dense.bias"] = pooler["dense.bias"]

    # Extract and store the LM head weights and biases from Megatron's language-model head
    lm_head = model["lm_head"]

    # The transform matrix weight
    output_state_dict["cls.predictions.transform.dense.weight"] = lm_head["dense.weight"]
    # The transform matrix bias
    output_state_dict["cls.predictions.transform.dense.bias"] = lm_head["dense.bias"]

    # The transform layer-norm weight
    output_state_dict["cls.predictions.transform.LayerNorm.weight"] = lm_head["layernorm.weight"]
    # The transform layer-norm bias
    output_state_dict["cls.predictions.transform.LayerNorm.bias"] = lm_head["layernorm.bias"]

    # For the decoder, replicate the word-embedding weights
    output_state_dict["cls.predictions.decoder.weight"] = word_embeddings
    # Store the LM head bias
    output_state_dict["cls.predictions.bias"] = lm_head["bias"]

    # Extract and store the binary (next-sentence) classifier weights and biases from Megatron
    binary_head = model["binary_head"]

    # Store the sequence-relationship classifier weight
    output_state_dict["cls.seq_relationship.weight"] = binary_head["weight"]
    # Store the sequence-relationship classifier bias
    output_state_dict["cls.seq_relationship.bias"] = binary_head["bias"]

    # Return the final output state dict
    return output_state_dict
# The program's main entry point
def main():
    # Create the argument parser
    parser = argparse.ArgumentParser()
    # Flag for printing the structure of the converted checkpoint
    parser.add_argument("--print-checkpoint-structure", action="store_true")
    # Path to the ZIP file containing the checkpoint
    parser.add_argument("path_to_checkpoint", type=str, help="Path to the ZIP file containing the checkpoint")
    # Optional config file describing the pre-trained model
    parser.add_argument(
        "--config_file",
        default="",
        type=str,
        help="An optional config json file describing the pre-trained model.",
    )
    # Parse the command-line arguments
    args = parser.parse_args()

    # The directory containing the checkpoint file; the converted files are written there
    basename = os.path.dirname(args.path_to_checkpoint)

    # Load the model
    print(f'Extracting PyTorch state dictionary from "{args.path_to_checkpoint}"')
    # If the path ends in .zip, open it with the zipfile module
    if args.path_to_checkpoint.endswith(".zip"):
        with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
            # Open the file inside the archive and load the PyTorch state dict from it
            with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
                input_state_dict = torch.load(pytorch_dict, map_location="cpu")
    else:
        # Otherwise load the PyTorch state dict directly
        input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")

    # Choose the MegatronBertConfig depending on whether a config file was given
    if args.config_file == "":
        # Default to the Megatron-BERT 345m configuration
        config = MegatronBertConfig()
        # Adjust the vocabulary size based on the input state dict
        config.vocab_size = input_state_dict["model"]["lm_head"]["bias"].numel()
    else:
        # Load the MegatronBertConfig from the JSON file
        config = MegatronBertConfig.from_json_file(args.config_file)

    # Run the conversion
    print("Converting")
    output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)

    # If requested, recursively print the structure of the converted state dict
    if args.print_checkpoint_structure:
        recursive_print(None, output_state_dict)

    # Save the config to disk
    print("Saving config")
    config.save_pretrained(basename)

    # Save the converted state dict to disk
    output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
    print(f'Saving checkpoint to "{output_checkpoint_file}"')
    torch.save(output_state_dict, output_checkpoint_file)


if __name__ == "__main__":
    # Call main when the script is executed directly
    main()
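As a rough usage sketch (all paths are placeholders), the script is run as `python convert_megatron_bert_checkpoint.py --print-checkpoint-structure /path/to/checkpoint.zip`, after which the directory containing the archive holds `config.json` and `pytorch_model.bin` and can be loaded like any local checkpoint:

# Sketch: loading the converted checkpoint directory (placeholder path).
from transformers import MegatronBertForPreTraining

model = MegatronBertForPreTraining.from_pretrained("/path/to/checkpoint_dir")
print(model.config.hidden_size, model.config.num_hidden_layers)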

.\models\megatron_bert\modeling_megatron_bert.py

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018-2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch MegatronBERT model."""


import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_megatron_bert import MegatronBertConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "MegatronBertConfig"
_CHECKPOINT_FOR_DOC = "nvidia/megatron-bert-cased-345m"

MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "nvidia/megatron-bert-cased-345m",
    # See all MegatronBERT models at https://huggingface.co/models?filter=megatron_bert
]


def load_tf_weights_in_megatron_bert(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    # Absolute path to the TensorFlow checkpoint file
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load the weights from the TF model
    init_vars = tf.train.list_variables(tf_path)
    # Initialize the lists of variable names and arrays
    names = []
    arrays = []
    # Iterate over the checkpoint variables, each given as a (name, shape) pair
    for name, shape in init_vars:
        # Log the name and shape of the TF weight being loaded
        logger.info(f"Loading TF weight {name} with shape {shape}")
        # Load the variable's data from the checkpoint
        array = tf.train.load_variable(tf_path, name)
        # Record the variable name
        names.append(name)
        # Record the loaded data
        arrays.append(array)

    # Iterate over the variable names and arrays in parallel
    for name, array in zip(names, arrays):
        # Split the variable name on '/'
        name = name.split("/")

        # Skip optimizer-state variables that are not needed for the PyTorch model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            # Log the skipped variable name
            logger.info(f"Skipping {'/'.join(name)}")
            continue

        # Start the pointer at the model object
        pointer = model

        # Walk down the name components
        for m_name in name:
            # If the component looks like "<letters>_<digits>", split it into a scope name and an index
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            # Choose the attribute to follow based on the first part of the name
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                # Try to follow the attribute with the given name; log and skip on failure
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue

            # If the component carried an index, descend into that element
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        # If the last name component ends in "_embeddings", point at its weight
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # TF stores dense kernels transposed relative to PyTorch, so transpose the array
            array = np.transpose(array)

        # Check that the pointer and array shapes match; raise otherwise
        if pointer.shape != array.shape:
            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")

        # Log the PyTorch weight being initialized
        logger.info("Initialize PyTorch weight {}".format(name))
        # Convert the NumPy array to a torch tensor and copy it into the parameter
        pointer.data = torch.from_numpy(array)

    # Return the model with the loaded weights
    return model
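A small standalone sketch (with a made-up TF variable name) shows how the name handling above walks the model: a component such as `layer_5` is split into an attribute name and an index, while `kernel` ends up on the module's `weight` after a transpose:

# Sketch of the name-splitting logic used when mapping TF variable names onto PyTorch modules.
import re

tf_name = "bert/encoder/layer_5/attention/self/query/kernel"    # made-up example name
for m_name in tf_name.split("/"):
    if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
        scope_names = re.split(r"_(\d+)", m_name)                # e.g. ["layer", "5", ""]
    else:
        scope_names = [m_name]
    print(m_name, "->", scope_names[:2])
# "layer_5" becomes the attribute "layer" indexed at 5; "kernel" maps to "weight" (array transposed).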
class MegatronBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        # Word embedding layer: maps token ids to hidden vectors, with support for a padding index
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # Position embedding layer: maps position ids to hidden vectors
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # Token-type embedding layer: maps segment ids to hidden vectors
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file

        # In Megatron, LayerNorm is applied after the first dropout, so no LayerNorm is created here.
        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        # Register a buffer holding position indices from 0 to config.max_position_embeddings - 1
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        # Position embedding type, defaulting to absolute position encoding
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            # If no position ids are provided, slice the registered buffer from
            # past_key_values_length to seq_length + past_key_values_length
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if token_type_ids is None:
            # If no token-type ids are provided, use a zero tensor with the same shape as the input
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            # If no input embeddings are provided, look them up through the word-embedding layer
            inputs_embeds = self.word_embeddings(input_ids)
        # Look up the token-type embeddings
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # Sum the word embeddings and token-type embeddings
        embeddings = inputs_embeds + token_type_embeddings

        if self.position_embedding_type == "absolute":
            # With absolute position encoding, add the absolute position embeddings
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings

        # Megatron BERT moves LayerNorm after the dropout (and into each layer), so it is skipped here.
        # embeddings = self.LayerNorm(embeddings)
        # Apply dropout
        embeddings = self.dropout(embeddings)
        return embeddings
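A quick sanity-check sketch (assuming transformers and torch are installed) runs the embedding module on a toy batch; note that, unlike vanilla BERT, the output is not layer-normalized here:

# Sketch: the embedding module in isolation on a toy batch.
import torch
from transformers import MegatronBertConfig
from transformers.models.megatron_bert.modeling_megatron_bert import MegatronBertEmbeddings

config = MegatronBertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4,
                            intermediate_size=128, vocab_size=1000, max_position_embeddings=32)
embeddings = MegatronBertEmbeddings(config)
input_ids = torch.randint(0, config.vocab_size, (2, 10))    # batch of 2, sequence length 10
out = embeddings(input_ids=input_ids)
print(out.shape)                                            # torch.Size([2, 10, 64])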


# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MegatronBert
class MegatronBertSelfAttention(nn.Module):
    # Constructor for the self-attention module
    def __init__(self, config, position_embedding_type=None):
        # Call the parent constructor
        super().__init__()
        # The hidden size must be divisible by the number of attention heads (unless an embedding_size is set)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        # Number of attention heads and size of each head
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Linear projections for query, key and value
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # Dropout applied to the attention probabilities
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # Position embedding type, defaulting to absolute position embeddings
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        # For relative position embeddings, create the distance-embedding table
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        # Whether this model is used as a decoder
        self.is_decoder = config.is_decoder
# Based on transformers.models.bert.modeling_bert.BertSelfOutput. The LayerNorm is moved into MegatronBertAttention below.
class MegatronBertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dense layer with input and output dimension config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Dropout layer with probability config.hidden_dropout_prob
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        # Apply the dense layer to the incoming hidden states
        hidden_states = self.dense(hidden_states)
        # Apply dropout to the dense output
        hidden_states = self.dropout(hidden_states)
        # Return the residual connection
        return residual + hidden_states


# Based on transformers.models.bert.modeling_bert.BertAttention. A LayerNorm is added (pre-LN).
class MegatronBertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # LayerNorm over config.hidden_size with epsilon config.layer_norm_eps, applied before self-attention
        self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # The MegatronBertSelfAttention layer
        self.self = MegatronBertSelfAttention(config)
        # The MegatronBertSelfOutput layer
        self.output = MegatronBertSelfOutput(config)
        # Set of attention-head indices that have been pruned
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        # Find the prunable heads and their indices
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune the linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update the hyper-parameters and record the pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # Apply LayerNorm to the incoming hidden states
        ln_outputs = self.ln(hidden_states)
        # Pass the normalized output through the self-attention layer
        self_outputs = self.self(
            ln_outputs,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # Combine the self-attention output with the un-normalized hidden states (residual connection)
        attention_output = self.output(self_outputs[0], hidden_states)
        # If attention weights were requested, append them to the outputs
        outputs = (attention_output,) + self_outputs[1:]
        # Return the final output tuple
        return outputs
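The distinctive detail here is the pre-LayerNorm residual: normalization is applied to the sub-layer input, while the residual uses the raw, un-normalized input. A minimal plain-PyTorch sketch of that pattern (with a toy sub-layer standing in for attention):

# Sketch of the pre-LN residual pattern used by MegatronBertAttention.
import torch
from torch import nn

hidden_size = 16
ln = nn.LayerNorm(hidden_size)
sublayer = nn.Linear(hidden_size, hidden_size)   # stands in for self-attention + output projection

x = torch.randn(2, 5, hidden_size)
out = x + sublayer(ln(x))       # pre-LN: normalize the sub-layer input, add the raw residual
# post-LN (vanilla BERT) would instead be: out = ln(x + sublayer(x))
print(out.shape)                # torch.Size([2, 5, 16])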


# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->MegatronBert
class MegatronBertIntermediate(nn.Module):
    # Constructor
    def __init__(self, config):
        # Call the parent constructor
        super().__init__()
        # Dense layer from config.hidden_size to config.intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

        # config.hidden_act may be a string naming an activation or a callable
        if isinstance(config.hidden_act, str):
            # If it is a string, look up the activation function in the ACT2FN mapping
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            # Otherwise use it directly as the activation function
            self.intermediate_act_fn = config.hidden_act

    # Forward pass: takes a torch.Tensor of hidden states and returns a torch.Tensor
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Apply the dense layer to the hidden states
        hidden_states = self.dense(hidden_states)
        # Apply the intermediate activation function
        hidden_states = self.intermediate_act_fn(hidden_states)
        # Return the activated result
        return hidden_states
# Based on transformers.models.bert.modeling_bert.BertOutput. The LayerNorm is moved into MegatronBertLayer below.
class MegatronBertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dense layer mapping the intermediate size back to the hidden size
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # Dropout layer to reduce overfitting
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # Apply the dense projection to the hidden states
        hidden_states = self.dense(hidden_states)
        # Apply dropout to the projected hidden states
        hidden_states = self.dropout(hidden_states)
        # Return the sum of the projected hidden states and the input tensor (residual connection)
        return input_tensor + hidden_states


# Based on transformers.models.bert.modeling_bert.BertLayer. A LayerNorm is added.
class MegatronBertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Chunk size used for chunked feed-forward computation
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Index of the sequence-length dimension
        self.seq_len_dim = 1
        # The attention layer
        self.attention = MegatronBertAttention(config)
        # Whether this layer is used in a decoder
        self.is_decoder = config.is_decoder
        # Whether cross-attention is added
        self.add_cross_attention = config.add_cross_attention
        # Cross-attention requires the layer to be used as part of a decoder model
        if self.add_cross_attention:
            if not self.is_decoder:
                raise TypeError(f"{self} should be used as a decoder model if cross attention is added")
            # The cross-attention layer
            self.crossattention = MegatronBertAttention(config)
        # LayerNorm applied to the hidden states before the feed-forward block
        self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # The BERT intermediate (feed-forward) layer
        self.intermediate = MegatronBertIntermediate(config)
        # The BERT output layer
        self.output = MegatronBertOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # In the forward pass, the hidden states go through attention, LayerNorm, the intermediate
        # layer and the output layer in turn.
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        # If a cached past key/value is provided, take its first two entries for self-attention
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # Run self-attention on the hidden states
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        # The main output of the self-attention block
        attention_output = self_attention_outputs[0]

        # If this layer is a decoder, the last output is the self-attention cache tuple
        if self.is_decoder:
            # Exclude the last element, which is the cache
            outputs = self_attention_outputs[1:-1]
            # Keep the present key/value of this attention block
            present_key_value = self_attention_outputs[-1]
        else:
            # Otherwise the remaining outputs are the self-attention weights
            outputs = self_attention_outputs[1:]  # add self-attention weights

        # The present key/value for cross-attention defaults to None
        cross_attn_present_key_value = None
        # If this is a decoder and encoder hidden states were passed in
        if self.is_decoder and encoder_hidden_states is not None:
            # Check that a cross-attention layer exists
            if not hasattr(self, "crossattention"):
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # The cached cross-attention key/value pair sits at positions 3,4 of the past key/value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            # Run cross-attention over the self-attention output and the encoder states
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            # The main output of the cross-attention block
            attention_output = cross_attention_outputs[0]
            # Append the cross-attention weights to the outputs
            outputs = outputs + cross_attention_outputs[1:-1]

            # Add the cross-attention cache to positions 3,4 of the present key/value
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # Apply the feed-forward block to the attention output, chunked along the sequence dimension if configured
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        # Prepend the layer output to the output tuple
        outputs = (layer_output,) + outputs

        # If this layer is a decoder, append the present key/value as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Apply LayerNorm to the attention output (pre-LN for the feed-forward block)
        ln_output = self.ln(attention_output)
        # Run the intermediate (feed-forward) layer
        intermediate_output = self.intermediate(ln_output)
        # Project back to the hidden size and add the residual
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
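`apply_chunking_to_forward` (the helper imported from `pytorch_utils` at the top of this file) slices the input along the chosen dimension and feeds each chunk through the same function, trading a little compute overhead for lower peak memory. A small standalone sketch with a toy function:

# Sketch: chunked and unchunked calls of apply_chunking_to_forward give the same result.
import torch
from transformers.pytorch_utils import apply_chunking_to_forward

def toy_feed_forward(x):            # stands in for feed_forward_chunk
    return x * 2.0

hidden = torch.randn(2, 8, 4)       # (batch, seq_len, hidden)
full = toy_feed_forward(hidden)
chunked = apply_chunking_to_forward(toy_feed_forward, 4, 1, hidden)   # chunk_size=4 along dim 1
print(torch.allclose(full, chunked))   # True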
# MegatronBertEncoder, a subclass of nn.Module
class MegatronBertEncoder(nn.Module):
    # Constructor taking the config
    def __init__(self, config):
        # Call the parent constructor
        super().__init__()
        # Keep a reference to the config
        self.config = config
        # A list of config.num_hidden_layers MegatronBertLayer instances
        self.layer = nn.ModuleList([MegatronBertLayer(config) for _ in range(config.num_hidden_layers)])

        # The final layer norm. We removed the first LN, moved LN to each hidden layer, and this one
        # is simply the final LN (vanilla BERT has it attached to each hidden layer instead).
        self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    # Forward pass
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->MegatronBert
class MegatronBertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dense layer
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Tanh activation
        self.activation = nn.Tanh()

    # Forward pass
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # "Pool" the model by simply taking the hidden state corresponding to the first token.
        first_token_tensor = hidden_states[:, 0]
        # Apply the dense layer
        pooled_output = self.dense(first_token_tensor)
        # Apply the activation
        pooled_output = self.activation(pooled_output)
        return pooled_output


# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->MegatronBert
class MegatronBertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dense layer
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # If config.hidden_act is a string, look up the activation in the ACT2FN mapping
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        # LayerNorm layer
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    # Forward pass
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Apply the dense layer
        hidden_states = self.dense(hidden_states)
        # Apply the activation
        hidden_states = self.transform_act_fn(hidden_states)
        # Apply LayerNorm
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->MegatronBert
class MegatronBertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = MegatronBertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        # Linear layer mapping hidden states to vocabulary logits, without its own bias.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # A learnable bias parameter of vocabulary size.
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        # Transform the hidden states with the prediction-head transform
        hidden_states = self.transform(hidden_states)
        # Project the transformed hidden states to vocabulary logits
        hidden_states = self.decoder(hidden_states)
        return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->MegatronBert
class MegatronBertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # The MLM prediction head
        self.predictions = MegatronBertLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        # Compute prediction scores from the sequence output
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


# Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert->MegatronBert
class MegatronBertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Linear layer for next-sentence (sequence-relationship) prediction
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        # Compute the sequence-relationship score from the pooled output
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


# Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->MegatronBert
class MegatronBertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        # The MLM prediction head
        self.predictions = MegatronBertLMPredictionHead(config)
        # Linear layer for next-sentence (sequence-relationship) prediction
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        # Compute prediction scores from the sequence output
        prediction_scores = self.predictions(sequence_output)
        # Compute the sequence-relationship score from the pooled output
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


# MegatronBertPreTrainedModel: handles weight initialization and provides a simple interface for
# downloading and loading pretrained models
class MegatronBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Configuration class used by this model
    config_class = MegatronBertConfig
    # Function used to load TensorFlow weights
    load_tf_weights = load_tf_weights_in_megatron_bert
    # Prefix of the base model
    base_model_prefix = "bert"
    # Gradient checkpointing is supported
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Initialize weights from a normal distribution; slightly different from the TF version,
            # which uses a truncated normal. See https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            # Initialize the LayerNorm bias to zero and its weight to one
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            # Initialize the bias of linear layers to zero
            module.bias.data.zero_()


@dataclass
# Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->MegatronBert
class MegatronBertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`MegatronBertForPreTraining`].
    """

    # Optional: the loss, returned when `labels` is provided; a `torch.FloatTensor` of shape `(1,)`
    loss: Optional[torch.FloatTensor] = None
    # Prediction scores of the language-modeling head, of shape `(batch_size, sequence_length, config.vocab_size)`
    prediction_logits: torch.FloatTensor = None
    # Prediction scores of the next-sentence-prediction (classification) head, of shape `(batch_size, 2)`
    seq_relationship_logits: torch.FloatTensor = None
    # Optional: hidden states of every layer, returned when `output_hidden_states=True` or `config.output_hidden_states=True`
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Optional: attention weights, returned when `output_attentions=True` or `config.output_attentions=True`
    attentions: Optional[Tuple[torch.FloatTensor]] = None
# MEGATRON_BERT_START_DOCSTRING documents the model class: it explains that the model inherits from
# PreTrainedModel and describes the configuration parameter.

MEGATRON_BERT_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MegatronBertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# MEGATRON_BERT_INPUTS_DOCSTRING documents the inputs accepted by the model's forward method.

MEGATRON_BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare MegatronBert Model transformer outputting raw hidden-states without any specific head on top.",
    MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertModel(MegatronBertPreTrainedModel):
    """
    MegatronBertModel类继承自MegatronBertPreTrainedModel,代表一个裸的MegatronBert模型,输出没有特定头部的原始隐藏状态。

    这个模型可以作为编码器(只有自注意力)或解码器使用。当作为解码器时,在自注意力层之间会添加一个交叉注意力层,遵循[Attention is
    all you need](https://arxiv.org/abs/1706.03762)中描述的架构,作者包括Ashish Vaswani、Noam Shazeer、Niki Parmar、
    Jakob Uszkoreit、Llion Jones、Aidan N. Gomez、Lukasz Kaiser和Illia Polosukhin。

    要作为解码器使用,需要用`is_decoder`参数设置为`True`来初始化模型配置。要在Seq2Seq模型中使用,需要用`is_decoder`和
    `add_cross_attention`参数都设置为`True`来初始化;此时前向传播期望一个`encoder_hidden_states`作为输入。
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        # Initialize the embeddings and the encoder
        self.embeddings = MegatronBertEmbeddings(config)
        self.encoder = MegatronBertEncoder(config)

        # If add_pooling_layer is True, initialize the pooler
        self.pooler = MegatronBertPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
        base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
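For illustration (a sketch with a small, randomly initialized model), head pruning is normally driven through the public `prune_heads` method of `PreTrainedModel` with a `{layer: [head indices]}` mapping, which ends up calling `_prune_heads` above:

# Sketch: prune two heads from layer 0 and one head from layer 2 of a toy model.
from transformers import MegatronBertConfig, MegatronBertModel

config = MegatronBertConfig(hidden_size=128, num_hidden_layers=4, num_attention_heads=8,
                            intermediate_size=256, vocab_size=1000)
model = MegatronBertModel(config)
model.prune_heads({0: [0, 1], 2: [5]})
print(model.encoder.layer[0].attention.self.num_attention_heads)    # 6
print(model.encoder.layer[2].attention.self.num_attention_heads)    # 7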

    @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
"""
MegatronBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
`next sentence prediction (classification)` head.
"""
# 声明一个 MegatronBertForPreTraining 类,继承自 MegatronBertPreTrainedModel 类
class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
    # 定义一个列表,包含了与权重绑定相关的键值
    _tied_weights_keys = ["cls.predictions.decoder"]

    # 初始化函数,接受配置对象 config 和一个布尔型参数 add_binary_head
    def __init__(self, config, add_binary_head=True):
        # 调用父类的初始化函数
        super().__init__(config)

        # 创建一个 MegatronBertModel 对象
        self.bert = MegatronBertModel(config)
        # 创建一个 MegatronBertPreTrainingHeads 对象
        self.cls = MegatronBertPreTrainingHeads(config)

        # 调用对象的后初始化方法
        self.post_init()

    # 获取输出嵌入的方法
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    # 设置输出嵌入的方法,接受一个新的嵌入张量 new_embeddings
    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    # 前向传播函数,接受多个输入参数,具体功能请参考文档中的说明
    @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        next_sentence_label: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 函数的具体实现由装饰器提供,用于替换文档字符串和返回值的描述

"""
MegatronBert Model with a `language modeling` head on top for CLM fine-tuning.
"""
# 声明一个 MegatronBertForCausalLM 类,继承自 MegatronBertPreTrainedModel 类
class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
    # 定义一个列表,包含了与权重绑定相关的键值
    _tied_weights_keys = ["cls.predictions.decoder"]

    # 初始化函数,接受配置对象 config
    def __init__(self, config):
        # 调用父类的初始化函数
        super().__init__(config)

        # 如果配置对象不是解码器,发出警告信息
        if not config.is_decoder:
            logger.warning("If you want to use `MegatronBertForCausalLM` as a standalone, add `is_decoder=True.`")

        # 创建一个 MegatronBertModel 对象,关闭额外的池化层
        self.bert = MegatronBertModel(config, add_pooling_layer=False)
        # 创建一个 MegatronBertOnlyMLMHead 对象
        self.cls = MegatronBertOnlyMLMHead(config)

        # 调用对象的后初始化方法
        self.post_init()

    # 获取输出嵌入的方法
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    # 设置输出嵌入的方法,接受一个新的嵌入张量 new_embeddings
    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    # 前向传播函数,接受多个输入参数,具体功能请参考文档中的说明
    @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 此方法用于模型的前向传播,接收多个输入参数并返回模型输出
        # 可选参数中包括输入张量、注意力掩码、token类型ID、位置ID等
        # 返回包括预测标签、隐藏状态等,具体返回方式由return_dict参数控制
        pass

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
        # Prepare the inputs used during generation from the input ids, past key/values, etc.
        input_shape = input_ids.shape

        # If no attention mask is given, create an all-ones mask
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # If past key/values are used, cut the input ids down to the part not already covered by the cache
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods may already pass only the last input id
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default behaviour: keep only the final id
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        # Return a dict with the input ids, attention mask and past key/values
        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

    def _reorder_cache(self, past_key_values, beam_idx):
        # Reorder the cached past key/values to follow the beam-search order
        reordered_past = ()
        for layer_past in past_key_values:
            # Reorder the past states of each layer along the batch dimension
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
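A rough usage sketch ties this together: `generate` is what calls `prepare_inputs_for_generation` (and, for beam search, `_reorder_cache`). The model below is tiny and randomly initialized, so the generated tokens are meaningless; exact behaviour of `generate` can vary a little across transformers versions:

# Sketch: greedy generation with a tiny, randomly initialized MegatronBertForCausalLM.
import torch
from transformers import MegatronBertConfig, MegatronBertForCausalLM

config = MegatronBertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4,
                            intermediate_size=128, vocab_size=1000, is_decoder=True)
model = MegatronBertForCausalLM(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 5))
out = model.generate(input_ids, max_new_tokens=5, do_sample=False)
print(out.shape)    # (1, 10): the 5 prompt tokens plus 5 newly generated tokens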
# Decorator adding the class docstring: a MegatronBert model with a language-modeling head on top
@add_start_docstrings("""MegatronBert Model with a `language modeling` head on top.""", MEGATRON_BERT_START_DOCSTRING)
class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
    # Keys whose weights are tied to the input embeddings
    _tied_weights_keys = ["cls.predictions.decoder"]

    # Constructor taking a config object
    def __init__(self, config):
        # Call the parent constructor
        super().__init__(config)

        # If the config is marked as a decoder, warn the user
        if config.is_decoder:
            logger.warning(
                "If you want to use `MegatronBertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # The MegatronBertModel backbone, without the pooling layer
        self.bert = MegatronBertModel(config, add_pooling_layer=False)
        # The MLM-only head
        self.cls = MegatronBertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Return the output embedding layer
    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    # Set the output embedding layer to the given embeddings
    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    # Forward pass with the standard documentation decorators
    @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        # 初始化返回字典,如果未提供则根据配置决定是否使用返回字典
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用BERT模型进行前向传播
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从BERT模型输出中获取序列输出
        sequence_output = outputs[0]

        # 将序列输出传入分类器,得到预测得分
        prediction_scores = self.cls(sequence_output)

        # 初始化masked_lm_loss为None
        masked_lm_loss = None

        # 如果提供了标签,则计算masked language modeling损失
        if labels is not None:
            # 使用交叉熵损失函数,忽略标签为-100的token(padding token)
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        # 如果return_dict为False,则按非字典方式返回结果
        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        # 如果return_dict为True,则按MaskedLMOutput对象方式返回结果
        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        # 获取输入的形状信息和有效的批量大小
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # 如果未定义pad_token_id,则无法进行生成,抛出异常
        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        # 在attention_mask末尾添加一个虚拟token
        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)

        # 创建一个全是pad_token_id的虚拟token张量
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )

        # 将虚拟token添加到输入ids的末尾
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        # 返回生成模型需要的输入字典
        return {"input_ids": input_ids, "attention_mask": attention_mask}
# 使用装饰器添加文档字符串,描述了 MegatronBertForNextSentencePrediction 类的作用及其顶部的文档信息
@add_start_docstrings(
    """MegatronBert Model with a `next sentence prediction (classification)` head on top.""",
    MEGATRON_BERT_START_DOCSTRING,
)
# 定义 MegatronBertForNextSentencePrediction 类,继承自 MegatronBertPreTrainedModel
class MegatronBertForNextSentencePrediction(MegatronBertPreTrainedModel):
    
    # 初始化方法,接受一个 config 对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)

        # 创建 MegatronBertModel 实例,并赋值给 self.bert
        self.bert = MegatronBertModel(config)
        
        # 创建 MegatronBertOnlyNSPHead 实例,并赋值给 self.cls
        self.cls = MegatronBertOnlyNSPHead(config)

        # 初始化权重并应用最终处理
        self.post_init()

    # 使用装饰器添加文档字符串,描述了 forward 方法的输入参数及其作用
    @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    # 替换返回文档字符串,指定输出类型为 NextSentencePredictorOutput,配置类为 _CONFIG_FOR_DOC
    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    # 前向传播方法,接受多个输入参数和 **kwargs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, NextSentencePredictorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:
            
            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.
    
        Returns:
            Depending on `return_dict`:
            - If `return_dict=False`: returns a tuple of `(seq_relationship_scores,) + outputs[2:]`, with the loss
              prepended when `labels` is provided.
            - If `return_dict=True`: returns a `NextSentencePredictorOutput` containing loss, logits, hidden states,
              and attentions.
    
        Example:
        ```
        >>> from transformers import AutoTokenizer, MegatronBertForNextSentencePrediction
        >>> import torch
    
        >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/megatron-bert-cased-345m")
        >>> model = MegatronBertForNextSentencePrediction.from_pretrained("nvidia/megatron-bert-cased-345m")
    
        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
    
        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```"""
    
        if "next_sentence_label" in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("next_sentence_label")
    
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    
        # Pass input tensors through the BERT model to get outputs
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    
        # Get the pooled output from BERT's outputs
        pooled_output = outputs[1]
    
        # Predict next sentence relationship using a classifier layer
        seq_relationship_scores = self.cls(pooled_output)
    
        next_sentence_loss = None
        # Compute loss if labels are provided
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
    
        # Return outputs based on `return_dict` flag
        if not return_dict:
            output = (seq_relationship_scores,) + outputs[2:]
            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
    
        # Return a `NextSentencePredictorOutput` object if `return_dict=True`
        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    MegatronBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForSequenceClassification(MegatronBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # 初始化 Bert 模型和相关组件
        self.bert = MegatronBertModel(config)
        # Dropout 层,用于减少过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 分类器,线性层,将 BERT 输出映射到标签空间
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并进行后续处理
        self.post_init()

    @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # 初始化返回字典,如果未提供则使用配置中的默认值
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 将输入传递给BERT模型进行处理,并获取其输出
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从BERT模型的输出中获取汇聚的输出表示
        pooled_output = outputs[1]

        # 对汇聚的输出表示进行dropout操作
        pooled_output = self.dropout(pooled_output)

        # 将dropout后的输出传递给分类器,得到预测的logits
        logits = self.classifier(pooled_output)

        # 初始化损失为None
        loss = None

        # 如果提供了标签,则计算相应的损失
        if labels is not None:
            # 如果问题类型未定义,则根据标签类型和类数自动推断问题类型
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # 根据问题类型选择合适的损失函数
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        # 如果不需要返回字典,则按照非字典返回格式组织输出
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 如果需要返回字典,则创建SequenceClassifierOutput对象,并返回
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
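
下面用一个独立的小例子(非原文件内容,张量均为随机构造,仅作示意)演示上面三种 `problem_type` 分别对应的损失函数选择。

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

batch_size, num_labels = 4, 3
logits = torch.randn(batch_size, num_labels)

# 单标签分类:labels 为类别索引,使用交叉熵
labels_single = torch.tensor([0, 2, 1, 1])
loss_single = CrossEntropyLoss()(logits.view(-1, num_labels), labels_single.view(-1))

# 多标签分类:labels 为 0/1 浮点张量,使用 BCEWithLogitsLoss
labels_multi = torch.tensor([[1.0, 0.0, 1.0]] * batch_size)
loss_multi = BCEWithLogitsLoss()(logits, labels_multi)

# 回归(num_labels == 1):使用均方误差
logits_reg = torch.randn(batch_size, 1)
labels_reg = torch.randn(batch_size)
loss_reg = MSELoss()(logits_reg.squeeze(), labels_reg.squeeze())

print(loss_single.item(), loss_multi.item(), loss_reg.item())
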
# 定义 MegatronBertForMultipleChoice 类,继承自 MegatronBertPreTrainedModel,用于多项选择任务的 Megatron-BERT 模型
@add_start_docstrings(
    """
    MegatronBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output
    and a softmax) e.g. for RocStories/SWAG tasks.
    """,
    MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForMultipleChoice(MegatronBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # 初始化 Megatron-BERT 模型
        self.bert = MegatronBertModel(config)
        # Dropout 层,用于随机断开神经元连接,防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 分类器,线性层,将 BERT 隐藏层的输出映射到一个值,用于多项选择的分类
        self.classifier = nn.Linear(config.hidden_size, 1)

        # 初始化权重并应用最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(
        MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # 前向传播函数,接收多个输入和控制参数,返回模型输出
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # 根据需要确定是否返回字典格式的输出
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 获取输入张量的选择数
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # 重新调整输入张量的形状,将其视为二维张量
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        # 如果存在输入嵌入,则将其视为三维张量
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        # 使用BERT模型处理输入
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 提取池化后的输出
        pooled_output = outputs[1]

        # 对池化后的输出进行dropout
        pooled_output = self.dropout(pooled_output)
        # 使用分类器预测logits
        logits = self.classifier(pooled_output)
        # 调整logits的形状以匹配选择数
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        # 如果有提供标签,则计算交叉熵损失
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        # 如果不要求返回字典格式的输出,则按元组格式返回结果
        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 如果要求返回字典格式的输出,则创建MultipleChoiceModelOutput对象
        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
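
多项选择任务的关键在于形状变换:输入为 `(batch_size, num_choices, seq_len)`,送入 BERT 前先摊平成 `(batch_size * num_choices, seq_len)`,每个选项得到一个分数后再 reshape 回 `(batch_size, num_choices)` 与标签做交叉熵。下面是一个与具体模型无关的形状演示(非原文件内容,用随机张量和线性层代替 BERT 与分类器)。

import torch
from torch import nn

batch_size, num_choices, seq_len, hidden_size = 2, 4, 8, 16

input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))
flat_input_ids = input_ids.view(-1, input_ids.size(-1))  # (8, 8),即每个选项一条序列

# 假设 BERT 对每条展开后的序列给出一个 pooled 向量
pooled_output = torch.randn(batch_size * num_choices, hidden_size)
classifier = nn.Linear(hidden_size, 1)  # 每个选项映射为一个分数

logits = classifier(pooled_output)              # (8, 1)
reshaped_logits = logits.view(-1, num_choices)  # (2, 4)

labels = torch.tensor([1, 3])  # 每个样本的正确选项下标
loss = nn.CrossEntropyLoss()(reshaped_logits, labels)
print(flat_input_ids.shape, reshaped_logits.shape, loss.item())
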
"""
MegatronBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
for Named-Entity-Recognition (NER) tasks.
"""
@add_start_docstrings(
    """
    MegatronBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForTokenClassification(MegatronBertPreTrainedModel):
    def __init__(self, config):
        """
        Initialize the MegatronBertForTokenClassification model.

        Args:
            config (MegatronBertConfig): Configuration object specifying the model architecture and hyperparameters.
        """
        super().__init__(config)
        self.num_labels = config.num_labels

        # Initialize the MegatronBertModel with pooling layer excluded
        self.bert = MegatronBertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        """
        Forward pass of the MegatronBertForTokenClassification model.

        Args:
            input_ids (torch.LongTensor, optional): Tensor of shape `(batch_size, sequence_length)` containing input token IDs.
            attention_mask (torch.FloatTensor, optional): Tensor of shape `(batch_size, sequence_length)` containing attention masks.
            token_type_ids (torch.LongTensor, optional): Tensor of shape `(batch_size, sequence_length)` containing token type IDs.
            position_ids (torch.LongTensor, optional): Tensor of shape `(batch_size, sequence_length)` containing position IDs.
            head_mask (torch.FloatTensor, optional): Tensor of shape `(batch_size, sequence_length)` containing attention head masks.
            inputs_embeds (torch.FloatTensor, optional): Tensor of shape `(batch_size, sequence_length, hidden_size)` containing precomputed embeddings.
            labels (torch.LongTensor, optional): Tensor of shape `(batch_size, sequence_length)` containing labels for computing token classification loss.
            output_attentions (bool, optional): Whether to output attentions.
            output_hidden_states (bool, optional): Whether to output hidden states.
            return_dict (bool, optional): Whether to return outputs as a dictionary.

        Returns:
            Union[Tuple, TokenClassifierOutput]: Depending on `return_dict`, either a tuple or a `TokenClassifierOutput` object.

        Notes:
            - Labels should be in the range `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Perform the forward pass through MegatronBertModel
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Apply dropout on the output of the BERT model
        sequence_output = self.dropout(sequence_output)
        
        # Pass the modified output through the classifier layer
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Compute the token classification loss
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Prepare output tuple if return_dict is False
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # Return TokenClassifierOutput object if return_dict is True
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
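
标注任务的损失计算把 `(batch_size, seq_len, num_labels)` 的 logits 与 `(batch_size, seq_len)` 的标签各自摊平后交给交叉熵;`CrossEntropyLoss` 默认 `ignore_index=-100`,因此把不参与训练的位置(例如特殊符号或 padding)标成 -100 即可被忽略。下面是一个独立的小演示(非原文件内容,仅作示意)。

import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, num_labels = 2, 5, 3
logits = torch.randn(batch_size, seq_len, num_labels)

# -100 的位置(这里假设对应 [CLS]/[SEP]/padding)不计入损失
labels = torch.tensor([
    [-100, 0, 1, 2, -100],
    [-100, 2, 2, -100, -100],
])

loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
print(loss.item())
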
# 使用自定义的文档字符串描述 MegatronBertForQuestionAnswering 类,它是基于 Megatron-BERT 模型的抽取式问答任务模型,
# 在隐藏状态输出的基础上加上线性层,用于计算 `span start logits` 和 `span end logits`。
@add_start_docstrings(
    """
    MegatronBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)
        # 设置类别数目
        self.num_labels = config.num_labels

        # 初始化 Megatron-BERT 模型,不添加池化层
        self.bert = MegatronBertModel(config, add_pooling_layer=False)
        # QA 输出层,线性层,输入为隐藏状态大小,输出为类别数目
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并进行最终处理
        self.post_init()

    # 使用自定义的文档字符串描述 forward 方法的输入参数和功能
    @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    # 使用代码示例的文档字符串描述 forward 方法的返回值类型和相关配置
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        # 如果 return_dict 未指定,则使用配置中的默认设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 BERT 模型进行前向传播
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取 BERT 输出的序列表示
        sequence_output = outputs[0]

        # 将序列表示传递给 QA 输出层,得到起始位置和结束位置的 logits
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()  # 压缩维度并确保连续存储
        end_logits = end_logits.squeeze(-1).contiguous()  # 压缩维度并确保连续存储

        total_loss = None
        # 如果提供了起始和结束位置,则计算损失
        if start_positions is not None and end_positions is not None:
            # 在多 GPU 情况下,start/end positions 可能带有多余的维度,这里将其压缩掉
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # 忽略超出模型输入范围的位置
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # 使用交叉熵损失函数计算起始和结束位置的损失
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        # 如果不需要返回字典格式的输出,则按原样返回结果
        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]  # 加入额外的输出
            return ((total_loss,) + output) if total_loss is not None else output

        # 返回 QuestionAnsweringModelOutput 格式的结果
        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
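
抽取式问答头对每个 token 输出起点、终点两个分数。下面用随机张量演示 `qa_outputs` 之后的 split/squeeze,以及把越界的标注位置 clamp 到 `ignored_index` 后计算损失的过程(非原文件内容,仅作示意)。

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

batch_size, seq_len, hidden_size = 2, 10, 16
sequence_output = torch.randn(batch_size, seq_len, hidden_size)

qa_outputs = nn.Linear(hidden_size, 2)                 # 起点/终点各一个分数
logits = qa_outputs(sequence_output)                   # (2, 10, 2)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()   # (2, 10)
end_logits = end_logits.squeeze(-1).contiguous()       # (2, 10)

start_positions = torch.tensor([3, 99])                # 99 越界,会被 clamp 成 ignored_index 并被忽略
end_positions = torch.tensor([5, 99])
ignored_index = start_logits.size(1)                   # 10
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2
print(total_loss.item())
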

.\models\megatron_bert\__init__.py

# 引入必要的模块和函数,包括类型检查和依赖检查
from typing import TYPE_CHECKING
# 从相对路径导入自定义的异常和模块加载器
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# 定义模块导入结构的字典,包括配置和模型名称
_import_structure = {
    "configuration_megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
}

# 尝试检查是否存在 Torch 库,如果不存在则引发自定义异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果 Torch 可用,则添加 Megatron BERT 模型相关的模块和类到导入结构中
    _import_structure["modeling_megatron_bert"] = [
        "MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "MegatronBertForCausalLM",
        "MegatronBertForMaskedLM",
        "MegatronBertForMultipleChoice",
        "MegatronBertForNextSentencePrediction",
        "MegatronBertForPreTraining",
        "MegatronBertForQuestionAnswering",
        "MegatronBertForSequenceClassification",
        "MegatronBertForTokenClassification",
        "MegatronBertModel",
        "MegatronBertPreTrainedModel",
    ]

# 如果是类型检查阶段,则执行以下导入
if TYPE_CHECKING:
    # 从相对路径导入 Megatron BERT 的配置和模型定义
    from .configuration_megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig

    # 尝试检查是否存在 Torch 库,如果不存在则引发自定义异常
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 从相对路径导入 Megatron BERT 的模型定义
        from .modeling_megatron_bert import (
            MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            MegatronBertForCausalLM,
            MegatronBertForMaskedLM,
            MegatronBertForMultipleChoice,
            MegatronBertForNextSentencePrediction,
            MegatronBertForPreTraining,
            MegatronBertForQuestionAnswering,
            MegatronBertForSequenceClassification,
            MegatronBertForTokenClassification,
            MegatronBertModel,
            MegatronBertPreTrainedModel,
        )

# 如果不是类型检查阶段,则执行以下懒加载模块的设置
else:
    # 导入 sys 模块用于动态修改当前模块对象
    import sys

    # 使用自定义的惰性模块加载器将当前模块替换为 LazyModule 类的实例
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\megatron_gpt2\checkpoint_reshaping_and_interoperability.py

# 导入必要的库和模块
import argparse  # 导入处理命令行参数的模块
import json  # 导入处理 JSON 数据的模块
import os  # 导入处理操作系统相关功能的模块
import re  # 导入正则表达式模块
import sys  # 导入系统相关的功能模块
import types  # 导入 types 模块,用于操作类型信息

import torch  # 导入 PyTorch 深度学习库

# 导入 transformers 相关模块和类
from transformers import AutoTokenizer, GPT2Config
from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, shard_checkpoint


def add_checkpointing_args(parser):
    # 添加命令行参数:Megatron 代码库的基本目录
    parser.add_argument("--megatron-path", type=str, default=None, help="Base directory of Megatron repository")
    # 添加命令行参数:是否进行 Megatron 到 Transformers 的检查点转换
    parser.add_argument(
        "--convert_checkpoint_from_megatron_to_transformers",
        action="store_true",
        help=(
            "If True, convert a Megatron checkpoint to a Transformers checkpoint. "
            "If False, convert a Transformers checkpoint to a Megatron checkpoint."
        ),
    )
    # 添加命令行参数:待转换的检查点路径
    parser.add_argument(
        "--load_path",
        type=str,
        required=True,
        help="Path to the checkpoint to convert.",
    )
    # 添加命令行参数:转换后保存的检查点路径
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="Path to the converted checkpoint.",
    )
    # 添加命令行参数:是否打印检查点结构
    parser.add_argument("--print-checkpoint-structure", action="store_true")
    return parser


def add_megatron_checkpoint_args(parser):
    # 添加命令行参数:转换后的张量模型并行大小
    parser.add_argument(
        "--target_tensor_model_parallel_size",
        type=int,
        default=1,
        help=(
            "The tensor model parallel size of the converted checkpoint. "
            "Only used when converting a Transformers checkpoint to a Megatron checkpoint."
        ),
    )
    # 添加命令行参数:转换后的管道模型并行大小
    parser.add_argument(
        "--target_pipeline_model_parallel_size",
        type=int,
        default=1,
        help=(
            "The pipeline model parallel size of the converted checkpoint. "
            "Only used when converting a Transformers checkpoint to a Megatron checkpoint."
        ),
    )
    # 添加命令行参数:转换后的数据并行大小
    parser.add_argument(
        "--target_data_parallel_size",
        type=int,
        default=1,
        help=(
            "The data parallel size of the converted checkpoint. "
            "Only used when converting a Transformers checkpoint to a Megatron checkpoint."
        ),
    )
    # 添加命令行参数:转换后的参数数据类型
    parser.add_argument(
        "--target_params_dtype",
        type=str,
        default="fp32",
        help=(
            "The dtype of the converted checkpoint. "
            "Only used when converting a Transformers checkpoint to a Megatron checkpoint."
        ),
    )
    # 添加命令行参数:使得词汇表大小可被此值整除
    parser.add_argument(
        "--make_vocab_size_divisible_by",
        type=int,
        default=128,
        help=(
            "Pad the vocab size to be divisible by this value. "
            "This is added for computational efficiency reasons. "
            "Only used when converting a Transformers checkpoint to a Megatron checkpoint."
        ),
    )
    # 添加命令行参数:使用分布式优化器
    parser.add_argument(
        "--use_distributed_optimizer",
        action="store_true",
        help=(
            "If True, use the distributed optimizer. "
            "Only used when converting a Transformers checkpoint to a Megatron checkpoint."
        ),
    )
    # 返回配置好命令行参数的解析器对象
    return parser
def add_transformers_checkpoint_args(parser):
    """
    添加 Transformers 检查点的参数到解析器中。

    Args:
        parser (ArgumentParser): 解析器对象,用于添加参数

    Returns:
        ArgumentParser: 更新后的解析器对象
    """
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help=(
            "要保存的预训练分词器的名称。如果不是 None,则会保存分词器。"
            "仅在将 Megatron 检查点转换为 Transformers 检查点时使用。"
        ),
    )
    parser.add_argument(
        "--max_shard_size",
        type=str,
        default="10GB",
        help=(
            "在分片之前检查点的最大大小。检查点分片将小于此大小。"
            "如果表示为字符串,需由数字后跟单位(如 `5MB`)组成。"
            "仅在将 Megatron 检查点转换为 Transformers 检查点时使用。"
        ),
    )

    return parser
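
三个 `add_*_args` 函数都是往同一个解析器上叠加参数,`main()` 中依次调用后再 `parse_args()`。下面把一组假设的参数列表直接传给 `parse_args`(非原文件内容,需在本文件的上下文中运行,路径均为示意值),便于理解最终会解析出哪些字段。

import argparse

parser = argparse.ArgumentParser()
parser = add_checkpointing_args(parser)
parser = add_megatron_checkpoint_args(parser)
parser = add_transformers_checkpoint_args(parser)

# 路径与取值均为假设,仅用于演示解析结果
args = parser.parse_args(
    [
        "--convert_checkpoint_from_megatron_to_transformers",
        "--load_path", "/path/to/megatron_ckpt",
        "--save_path", "/path/to/hf_ckpt",
        "--max_shard_size", "5GB",
    ]
)
print(args.convert_checkpoint_from_megatron_to_transformers)  # True
print(args.target_tensor_model_parallel_size)                 # 1(默认值)
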


# "automated" rules 名称映射的简单映射。
megatron_to_transformers = {
    "attention.dense": ".attn.c_proj.",
    "self_attention.dense": ".attn.c_proj.",
    "mlp.dense_h_to_4h": ".mlp.c_fc.",
    "mlp.dense_4h_to_h": ".mlp.c_proj.",
}
# 从 transformers 到 megatron 的反向映射。
transformers_to_megatron = {v[1:-1]: k for k, v in megatron_to_transformers.items()}

tensor_parallel_params = [
    # 在 tp ranks 之间合并的 megatron-lm 层
    "self_attention.query_key_value.weight",
    "self_attention.query_key_value.bias",
    "self_attention.dense.weight",
    "mlp.dense_h_to_4h.weight",
    "mlp.dense_h_to_4h.bias",
    "mlp.dense_4h_to_h.weight",
    # 已弃用
    "attention.query_key_value.weight",
    "attention.query_key_value.bias",
    "attention.dense.weight",
    # 在 tp ranks 之间分割的 transformers 层
    "attn.c_attn.weight",
    "attn.c_attn.bias",
    "attn.c_proj.weight",
    "mlp.c_fc.weight",
    "mlp.c_fc.bias",
    "mlp.c_proj.weight",
]
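
上面的 `megatron_to_transformers` 把 Megatron 的层名片段映射为 Transformers 的层名片段(值两端带点),反向字典用 `v[1:-1]` 去掉首尾的点作为键。下面是一个最小演示(非原文件内容,字典内容为保持自包含而重复了一份);注意当两条 Megatron 键映射到同一个 Transformers 名称时,反向映射只会保留后出现的那一条。

megatron_to_transformers = {
    "attention.dense": ".attn.c_proj.",
    "self_attention.dense": ".attn.c_proj.",
    "mlp.dense_h_to_4h": ".mlp.c_fc.",
    "mlp.dense_4h_to_h": ".mlp.c_proj.",
}
transformers_to_megatron = {v[1:-1]: k for k, v in megatron_to_transformers.items()}

print(transformers_to_megatron)
# {'attn.c_proj': 'self_attention.dense', 'mlp.c_fc': 'mlp.dense_h_to_4h', 'mlp.c_proj': 'mlp.dense_4h_to_h'}
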


def recursive_print(name, val, spaces=0):
    """
    递归打印检查点的结构。此函数源自 `convert_megatron_gpt2_checkpoint.py`。

    Args:
        name (str): 当前张量参数的名称
        val (Tuple(int)): 当前张量参数的形状
        spaces (int): 输出嵌套结构之前的空格数
    """
    # 格式化消息。
    if name is None:
        msg = None
    else:
        fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}"
        msg = fmt.format(name)

    # 打印并递归(如果需要)。
    if isinstance(val, dict):
        if msg is not None:
            print(msg)
        for k in val.keys():
            recursive_print(k, val[k], spaces + 2)
    elif isinstance(val, torch.Tensor):
        print(msg, ":", val.size())
    else:
        print(msg, ":", val)


def megatron_to_transformers_fix_query_key_value_ordering(
    param, checkpoint_version, num_splits, num_heads, hidden_size
):
    """
    重新排列 param 张量的布局,以便与后续版本兼容为 [num_splits * num_heads * hidden_size, :]。

    Args:
        param: 要重新排列的参数张量
        checkpoint_version: 检查点版本
        num_splits: 分片数
        num_heads: 头数
        hidden_size: 隐藏大小
    """
    # 获取输入张量的形状
    input_shape = param.size()
    
    # 根据不同的检查点版本进行张量重排
    if checkpoint_version == 1.0:
        # 版本 1.0 存储格式为 [num_heads * hidden_size * num_splits, :]
        saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:]
        # 重塑张量形状
        param = param.view(*saved_shape)
        # 转置操作,调整张量维度顺序
        param = param.transpose(0, 2)
        param = param.transpose(1, 2).contiguous()
    elif checkpoint_version >= 2.0:
        # 其他版本存储格式为 [num_heads * num_splits * hidden_size, :]
        saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:]
        # 重塑张量形状
        param = param.view(*saved_shape)
        param = param.transpose(0, 1).contiguous()
    
    # 恢复原始张量形状
    param = param.view(*input_shape)
    
    # 返回处理后的张量
    return param
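
下面用一个小张量验证上面函数在 checkpoint 2.0 版本布局下的效果:它把第 0 维按 `(num_heads, num_splits, hidden_size)` 解读并交换前两维,与手工 view + transpose + reshape 等价(非原文件内容,需在本文件的上下文中运行)。

import torch

num_heads, num_splits, hidden_size, cols = 2, 3, 4, 5
param = torch.randn(num_heads * num_splits * hidden_size, cols)

out = megatron_to_transformers_fix_query_key_value_ordering(
    param.clone(), checkpoint_version=2.0, num_splits=num_splits,
    num_heads=num_heads, hidden_size=hidden_size,
)
# 手工等价写法:按 (heads, splits, hidden) 视图后交换前两维再摊平
expected = (
    param.view(num_heads, num_splits, hidden_size, cols)
    .transpose(0, 1)
    .reshape(num_heads * num_splits * hidden_size, cols)
)
print(torch.equal(out, expected))  # True
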
# 将参数张量的布局重新排列,以适应相应版本的 NVIDIA Megatron-LM 检查点。
# 输入形状为 [num_splits * num_heads * hidden_size, :];输出形状在 1.0 及之前版本为
# [num_heads * hidden_size * num_splits, :],在 2.0 及之后版本为 [num_heads * num_splits * hidden_size, :]。
# 如果参数是自注意力块的权重张量,则在调用此函数之前需要已经进行转置。

def transformers_to_megatron_fix_query_key_value_ordering(
    param, checkpoint_version, num_splits, num_heads, hidden_size
):
    """
    Permutes layout of param tensor to the one compatible with respective NVIDIA Megatron-LM chekpoint versions. Input
    is [num_splits * num_heads * hidden_size, :] and output is [num_heads * hidden_size * num_splits, :] for version
    1.0 and [num_heads * num_splits * hidden_size, :] for version 2.0 and later. If param is the weight tensor of the
    self-attention block, the param needs to be already transposed before calling this function.

    Args:
        param (torch.Tensor): the tensor to permute
        checkpoint_version (int): the version of the checkpoint.
        num_splits (int): the number of projections, usually 3 for (Query, Key, Value)
        num_heads (int): the number of attention heads
        hidden_size (int): the hidden size per head
    """

    # 获取输入张量的形状
    input_shape = param.size()
    if checkpoint_version == 1.0:
        # 对于版本 1.0,存储结构为 [num_heads * hidden_size * num_splits, :]
        current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:]
        # 调整张量的形状和顺序以匹配版本 1.0 的要求
        param = param.view(*current_shape)
        param = param.transpose(0, 2)
        param = param.transpose(1, 2).contiguous()
    elif checkpoint_version >= 2.0:
        # 对于版本 2.0 及更高,存储结构为 [num_heads * num_splits * hidden_size, :]
        current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:]
        # 调整张量的形状和顺序以匹配版本 2.0 及更高的要求
        param = param.view(*current_shape)
        param = param.transpose(0, 1).contiguous()
    # 恢复原始张量的形状
    param = param.view(*input_shape)
    return param
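
在 2.0 及之后版本的布局下,这个函数与上面的 `megatron_to_transformers_fix_query_key_value_ordering` 互为逆操作(两者都只是交换对 heads 与 splits 两个维度的解释顺序)。下面做一个粗略的往返校验(非原文件内容,需在本文件的上下文中运行;1.0 版本的布局不在此验证范围内)。

import torch

num_heads, num_splits, hidden_size, cols = 2, 3, 4, 5
x = torch.randn(num_heads * num_splits * hidden_size, cols)

to_hf = megatron_to_transformers_fix_query_key_value_ordering(x.clone(), 2.0, num_splits, num_heads, hidden_size)
back = transformers_to_megatron_fix_query_key_value_ordering(to_hf, 2.0, num_splits, num_heads, hidden_size)
print(torch.equal(back, x))  # True
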


# 从transformers的分片检查点中合并成一个单一检查点。
def merge_transformers_sharded_states(path, num_checkpoints):
    """
    Merge sharded checkpoints from transformers into a single checkpoint.

    Args:
        path (str): the path to the sharded checkpoints
        num_checkpoints (int): the number of checkpoints to merge
    """
    # 创建一个空的状态字典用于存储合并后的状态
    state_dict = {}
    for i in range(1, num_checkpoints + 1):
        # 构建每个分片检查点的完整路径
        checkpoint_path = os.path.join(path, f"pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin")
        # 加载当前分片的检查点到内存中
        current_chunk = torch.load(checkpoint_path, map_location="cpu")
        # 将当前分片的状态字典更新到总的状态字典中
        state_dict.update(current_chunk)
    return state_dict


# 从NVIDIA Megatron-LM检查点中获取分片状态,基于提供的张量并行大小、管道并行大小和管道并行等级。
def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank):
    """
    Get sharded checkpoints from NVIDIA Megatron-LM checkpoint based on the provided tensor parallel size, pipeline
    parallel size and pipeline parallel rank.

    Args:
        args (argparse.Namespace): the arguments to the script
        tp_size (int): the tensor parallel size
        pp_size (int): the pipeline parallel size
        pp_rank (int): the pipeline parallel rank
    """
    # 创建一个空列表来存储张量并行状态字典
    tp_state_dicts = []
    # 遍历指定范围内的整数,生成索引 i,范围从 0 到 tp_size-1
    for i in range(tp_size):
        # 根据流水线并行大小 pp_size 决定子目录命名:pp_size 为 1 时形如 mp_rank_XX,否则形如 mp_rank_XX_YYY
        sub_dir_name = f"mp_rank_{i:02d}" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}"
        
        # 遍历检查点文件名列表,查找存在于文件系统中的第一个检查点文件
        for checkpoint_name in ["model_optim_rng.pt", "model_rng.pt"]:
            # 构建完整的检查点文件路径
            checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name)
            
            # 如果找到该路径对应的文件存在,则跳出当前循环
            if os.path.isfile(checkpoint_path):
                break
        
        # 使用 CPU 加载检查点文件的状态字典,并存储到列表中
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        tp_state_dicts.append(state_dict)
    
    # 返回包含所有加载状态字典的列表
    return tp_state_dicts
# 根据路径从字典中获取元素。如果元素不存在,则递归添加空字典。
def get_element_from_dict_by_path(d, path):
    # 将路径字符串按 "." 分割为列表
    path = path.split(".")
    # 遍历路径中的每个键
    for k in path:
        # 如果当前键不在字典中,将其添加为一个空字典
        if k not in d:
            d[k] = {}
        # 更新字典为当前键对应的值,以便继续下一级路径的处理
        d = d[k]
    # 返回最终路径对应的值,即字典中指定路径的元素
    return d
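
这个辅助函数按 `a.b.c` 这样的点分路径在嵌套字典里逐级下钻,路径上缺失的键会被建成空字典,最终返回最内层元素的引用,因此常被用来"先定位、再就地写入"。下面是一个最小示例(非原文件内容,需在本文件的上下文中运行,"fake-tensor" 只是假设的占位值)。

d = {}
# 路径不存在时会逐级创建空字典,并返回最内层的字典引用
inner = get_element_from_dict_by_path(d, "model.language_model.embedding.word_embeddings")
inner["weight"] = "fake-tensor"  # 实际使用中这里写入的是 torch.Tensor

print(d)
# {'model': {'language_model': {'embedding': {'word_embeddings': {'weight': 'fake-tensor'}}}}}
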


# 将 NVIDIA Megatron-LM 的检查点转换为 HuggingFace Transformers 的检查点
def convert_checkpoint_from_megatron_to_transformers(args):
    """
    Convert NVIDIA Megatron-LM checkpoint to HuggingFace Transformers checkpoint. This handles Megatron checkpoints
    with different tensor parallelism and pipeline parallelism sizes. It saves the converted checkpoint into shards
    using HuggingFace Transformers checkpoint sharding functionality. This greatly extends the functionality of
    `convert_megatron_gpt2_checkpoint.py`

    Args:
        args (argparse.Namespace): the arguments to the script
    """
    # 获取 Megatron-LM 检查点目录下的子目录列表
    sub_dirs = os.listdir(args.load_path)
    # 可能的子目录命名约定
    possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000"]
    # 遍历可能的子目录,寻找符合条件的检查点路径
    for sub_dir in possible_sub_dirs:
        if sub_dir in sub_dirs:
            # 获取子目录下的第一个文件名作为检查点文件名
            rank0_checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir))[0]
            # 构建完整的检查点路径
            rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, rank0_checkpoint_name)
            break
    # 打印加载 Megatron-LM 检查点参数的信息
    print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}")
    # 使用 torch 加载 Megatron-LM 检查点的状态字典
    state_dict = torch.load(rank0_checkpoint_path, map_location="cpu")
    # 从状态字典中获取 Megatron-LM 的参数
    megatron_args = state_dict.get("args", None)
    # 如果未找到 Megatron-LM 参数,则抛出错误
    if megatron_args is None:
        raise ValueError(
            "Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints"
            " containing all the megatron arguments. This is because it loads all config related to model"
            " architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to"
            " manually specify all the details. Please save Megatron-LM checkpoint along with all the megatron"
            " arguments to use this utility."
        )

    # 根据 Megatron-LM 的参数创建 Transformers GPT2 的配置
    if megatron_args is not None:
        # 根据 Megatron-LM 的参数选择激活函数
        if megatron_args.bias_gelu_fusion:
            activation_function = "gelu_fast"
        elif megatron_args.openai_gelu:
            activation_function = "gelu_new"
        else:
            activation_function = "gelu"
    else:
        # 如果未提供 Megatron-LM 参数,默认使用 "gelu_new" 作为激活函数
        activation_function = "gelu_new"
    # 确定词汇表大小
    vocab_size = (
        megatron_args.padded_vocab_size
        if getattr(megatron_args, "orig_vocab_size", None) is None
        else megatron_args.orig_vocab_size
    )
    # 打印词汇表大小
    print(vocab_size)
    # 使用 GPT2Config 类创建配置对象,设置模型的各种参数
    config = GPT2Config(
        vocab_size=vocab_size,  # 词汇表大小
        n_positions=megatron_args.max_position_embeddings,  # 最大位置编码数
        n_embd=megatron_args.hidden_size,  # 隐藏层大小
        n_layer=megatron_args.num_layers,  # 层数
        n_head=megatron_args.num_attention_heads,  # 注意力头数
        n_inner=megatron_args.ffn_hidden_size,  # FeedForward 层的隐藏大小
        activation_function=activation_function,  # 激活函数类型
        resid_pdrop=0.1,  # 残差连接中的丢弃率
        embd_pdrop=0.1,  # 嵌入层中的丢弃率
        attn_pdrop=0.1,  # 注意力层中的丢弃率
        layer_norm_epsilon=1e-5,  # 层归一化的 epsilon 值
        initializer_range=0.02,  # 初始化范围
        summary_type="cls_index",  # 摘要类型
        summary_use_proj=True,  # 是否使用摘要投影
        summary_activation=None,  # 摘要激活函数类型
        summary_proj_to_labels=True,  # 是否将摘要投影到标签
        summary_first_dropout=0.1,  # 摘要层的第一个丢弃率
        scale_attn_weights=True,  # 是否缩放注意力权重
        use_cache=True,  # 是否使用缓存
        bos_token_id=vocab_size - 1,  # 开始标记的 ID
        eos_token_id=vocab_size - 1,  # 结束标记的 ID
        architectures=["GPT2LMHeadModel"],  # 模型架构列表
    )

    # 初始化空的状态字典
    output_state_dict = {}

    # 从状态字典中获取检查点版本,如果没有则默认为 0.0
    checkpoint_version = state_dict.get("checkpoint_version", 0.0)
    
    # 获取 tensor 模型并行的大小和 pipeline 模型并行的大小
    tp_size = megatron_args.tensor_model_parallel_size
    pp_size = megatron_args.pipeline_model_parallel_size
    
    # 设置数据类型为 torch.float32
    dtype = torch.float32
    
    # 编译正则表达式,用于提取层名称
    # 正则表达式用于匹配形式如 layers.(\d+).([a-z0-9_.]+).([a-z]+) 的字符串
    layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

    # 输出信息:转换中
    print("Converting")

    # 输出信息:转换嵌入层
    print("Converting embeddings")
    
    # 获取 Megatron 分片状态字典的 tensor 模型并行数据
    tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, 0)

    # 获取位置嵌入并存储到 output_state_dict 中
    position_embeddings = get_element_from_dict_by_path(
        tp_state_dicts[0], "model.language_model.embedding.position_embeddings.weight"
    )
    output_state_dict["transformer.wpe.weight"] = position_embeddings.to(dtype)

    # 获取单词嵌入并存储到 output_state_dict 中
    word_embeddings = torch.cat(
        [
            get_element_from_dict_by_path(
                tp_state_dicts[tp_rank], "model.language_model.embedding.word_embeddings.weight"
            )
            for tp_rank in range(tp_size)
        ],
        dim=0,
    )
    word_embeddings = word_embeddings[:vocab_size].to(dtype)
    output_state_dict["transformer.wte.weight"] = word_embeddings

    # 输出信息:转换 transformer 层
    print("Converting transformer layers")
    
    # 获取配置中的头数和每个头的隐藏大小
    heads = config.n_head
    hidden_size_per_head = config.n_embd // config.n_head
    n_positions = config.n_positions
    
    # 计算每个 pipeline 模型并行的层数
    num_layers = config.num_hidden_layers // pp_size

    # (原始脚本在此之前还有一段按 pipeline 并行分块、逐层转换 transformer 参数的循环,本文省略;下面的 layer_idx 和 path 即来自该循环)
    # 如果配置的层数与实际转换到的最后一层索引不匹配,则抛出值错误
    if config.n_layer != (layer_idx + 1):
        raise ValueError(f"Expected {config.n_layer} layers but found {layer_idx + 1}")

    # 输出信息:转换最终的 layernorm 层
    print("Converting final layernorm")
    
    # 从 tp_state_dicts 中获取指定路径的参数,并存储到 output_state_dict 中
    params = get_element_from_dict_by_path(tp_state_dicts[0], str(path))
    output_state_dict["transformer.ln_f.weight"] = params["final_layernorm.weight"].to(dtype)
    output_state_dict["transformer.ln_f.bias"] = params["final_layernorm.bias"].to(dtype)

    # 输出信息:转换语言模型头
    print("Converting LM head")
    # 将 word_embeddings 的权重转换为指定的数据类型,并存入输出状态字典中
    output_state_dict["lm_head.weight"] = word_embeddings.to(dtype)

    # 输出转换完成的信息提示
    print("Conversion from Megatron-LM to Transformers is done!")

    # 如果设置了打印检查点结构的选项,则递归打印输出状态字典的结构
    if args.print_checkpoint_structure:
        recursive_print(None, output_state_dict)

    # 根据参数设置 tokenizer 的名称,若未指定则使用默认名称
    # 创建对应的 AutoTokenizer 对象
    if args.tokenizer_name is None:
        tokenizer_name = "openai-community/gpt2"
    else:
        tokenizer_name = args.tokenizer_name

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # 获取 tokenizer 类的名称并存入配置对象中
    tokenizer_class = type(tokenizer).__name__
    config.tokenizer_class = tokenizer_class

    # 保存配置对象到指定路径
    print("Saving config")
    config.save_pretrained(args.save_path)

    # 根据参数保存 tokenizer 到指定路径
    if args.tokenizer_name is not None:
        print(f"Adding {tokenizer_class} tokenizer files")
        tokenizer.save_pretrained(args.save_path)

    # 将输出状态字典分片并存储到文件中
    max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit() else args.max_shard_size
    shards, index = shard_checkpoint(output_state_dict, max_shard_size=max_shard_size)

    # 逐个保存分片后的模型
    for shard_file, shard in shards.items():
        torch.save(shard, os.path.join(args.save_path, shard_file))

    # 若没有分片,则直接输出模型权重文件的保存路径
    if index is None:
        print(f"Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}")
    else:
        # 否则保存分片索引到文件中,并输出详细信息
        save_index_file = os.path.join(args.save_path, WEIGHTS_INDEX_NAME)
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        print(
            f"The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be "
            f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
            f"index located at {save_index_file}."
        )
def convert_checkpoint_from_transformers_to_megatron(args):
    """
    Convert a checkpoint from HuggingFace Transformers to Megatron-LM. This allows converted checkpoints with variable
    tensor parallelism and pipeline parallelism sizes. It takes as input a checkpoint from HuggingFace Transformers
    which can have multiple shards.

    Args:
        args (argparse.Namespace): the arguments to the script

    """
    # 如果保存路径不存在,则创建
    os.makedirs(args.save_path, exist_ok=True)
    
    # 将父目录加入系统路径中以便搜索
    # 在当前文件的上级目录中搜索
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))

    # 如果指定了Megatron路径,则将其作为最高优先级的路径插入系统路径中
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    try:
        # 尝试导入Megatron的tokenizer模块中的_vocab_size_with_padding函数
        from megatron.tokenizer.tokenizer import _vocab_size_with_padding
    except ModuleNotFoundError:
        # 如果导入失败,则输出错误信息并退出程序
        print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
        exit(1)

    # 加载transformers模型的状态字典和配置文件
    sub_dirs = [x for x in os.listdir(args.load_path) if x.startswith("pytorch_model")]
    
    # 如果只有一个子目录,直接加载单一的pytorch_model.bin文件
    if len(sub_dirs) == 1:
        checkpoint_name = "pytorch_model.bin"
        state_dict = torch.load(os.path.join(args.load_path, checkpoint_name), map_location="cpu")
    else:
        # 如果有多个子目录,调用merge_transformers_sharded_states函数合并多个分片的状态
        num_checkpoints = len(sub_dirs) - 1
        state_dict = merge_transformers_sharded_states(args.load_path, num_checkpoints)

    # 从预训练模型路径中加载GPT2Config配置
    config = GPT2Config.from_pretrained(args.load_path)

    # 保存跟踪文件
    tracker_filepath = os.path.join(args.save_path, "latest_checkpointed_iteration.txt")
    with open(tracker_filepath, "w") as f:
        # 写入"release"作为内容
        f.write("release")

    # 创建`release`目录在args.save_path中
    release_dir = os.path.join(args.save_path, "release")
    os.makedirs(release_dir, exist_ok=True)

    # 设置Megatron的参数
    megatron_args = {
        "orig_vocab_size": config.vocab_size,
        "max_position_embeddings": config.n_positions,
        "hidden_size": config.n_embd,
        "num_layers": config.n_layer,
        "num_attention_heads": config.n_head,
        "ffn_hidden_size": config.n_inner,
        "tensor_model_parallel_size": args.target_tensor_model_parallel_size,
        "pipeline_model_parallel_size": args.target_pipeline_model_parallel_size,
        "data_parallel_size": args.target_data_parallel_size,
        "make_vocab_size_divisible_by": args.make_vocab_size_divisible_by,
        "rank": 0,
        "tokenizer_type": "GPT2BPETokenizer",
    }

    # 根据激活函数类型设置相应的Megatron参数
    if config.activation_function == "gelu":
        megatron_args["bias_gelu_fusion"] = False
        megatron_args["openai_gelu"] = False
    elif config.activation_function == "gelu_fast":
        megatron_args["bias_gelu_fusion"] = True
        megatron_args["openai_gelu"] = False
    elif config.activation_function == "gelu_new":
        megatron_args["bias_gelu_fusion"] = False
        megatron_args["openai_gelu"] = True

    # 使用types模块创建命名空间对象margs,并设置其属性为megatron_args中的键值对
    margs = types.SimpleNamespace()
    for k, v in megatron_args.items():
        setattr(margs, k, v)
    # 根据参数设置目标参数的数据类型
    if args.target_params_dtype == "fp16":
        dtype = torch.float16
    elif args.target_params_dtype == "bf16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    setattr(margs, "params_dtype", dtype)  # 将数据类型设置为模型参数对象的属性

    # 保存一个虚拟的优化器状态字典
    dummy_optim_state_dict = {}
    dummy_optim_state_dict["optimizer"] = {
        "step": 0,  # 优化器的步数
        "param_groups": [
            {
                "lr": 0.0,  # 学习率
                "beta1": 0.0,  # Adam优化器的beta1参数
                "beta2": 0.0,  # Adam优化器的beta2参数
                "eps": 0.0,  # Adam优化器的epsilon参数
                "weight_decay": 0.0,  # 权重衰减参数
                "correct_bias": False,  # 是否校正偏置
                "params": [],  # 参数组
            }
        ],
    }

    # 如果使用分布式优化器
    if args.use_distributed_optimizer:
        for i in range(args.target_pipeline_model_parallel_size):
            for j in range(args.target_tensor_model_parallel_size):
                for k in range(args.target_data_parallel_size):
                    if args.target_pipeline_model_parallel_size == 1:
                        checkpoint_dir = f"mp_rank_{j:02d}_{k:03d}"
                    else:
                        checkpoint_dir = f"mp_rank_{j:02d}_{i:03d}_{k:03d}"
                    checkpoint_dir = os.path.join(release_dir, checkpoint_dir)
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    torch.save(
                        dummy_optim_state_dict,
                        os.path.join(checkpoint_dir, "optim.pt"),
                    )

    # 打印信息,开始转换
    print("Converting")

    # 创建一个空列表,用于存储输出的状态字典
    output_state_dict = []

    # 为每个目标张量模型并行大小创建一个空字典
    for i in range(args.target_tensor_model_parallel_size):
        output_state_dict.append({})

    # 处理嵌入层
    print("converting embedding layer")

    # 将位置嵌入和词嵌入转换为指定的数据类型
    pos_embedding = state_dict["transformer.wpe.weight"].to(dtype)
    word_embedding = state_dict["transformer.wte.weight"].to(dtype)

    orig_vocab_size = config.vocab_size
    padded_vocab_size = _vocab_size_with_padding(orig_vocab_size, margs)
    setattr(margs, "padded_vocab_size", padded_vocab_size)

    # 如果原始词汇表大小大于填充后的大小,则裁剪多余的填充部分
    if orig_vocab_size > padded_vocab_size:
        full_word_embed = word_embedding[0:padded_vocab_size, :]
    # 如果原始词汇表大小小于填充后的大小,则扩展嵌入向量以适应填充后的大小
    elif orig_vocab_size < padded_vocab_size:
        padding_size = padded_vocab_size - orig_vocab_size
        full_word_embed = torch.cat((word_embedding, word_embedding[-1].unsqueeze(0).expand(padding_size, -1)))
    # 如果原始词汇表大小等于填充后的大小,则直接使用原始词嵌入
    else:
        full_word_embed = word_embedding

    # 将嵌入向量按照目标张量模型并行大小进行分块
    out_word_embed = torch.chunk(full_word_embed, args.target_tensor_model_parallel_size, dim=0)
    # 遍历目标张量模型并设置位置嵌入和词嵌入权重
    for i in range(args.target_tensor_model_parallel_size):
        # 获取模型状态字典中位置嵌入的路径并更新其权重为指定的位置嵌入
        pos_emb_dict = get_element_from_dict_by_path(
            output_state_dict[i], "model.language_model.embedding.position_embeddings"
        )
        pos_emb_dict["weight"] = pos_embedding

        # 获取模型状态字典中词嵌入的路径并更新其权重为当前输出的词嵌入的克隆
        word_emb_dict = get_element_from_dict_by_path(
            output_state_dict[i], "model.language_model.embedding.word_embeddings"
        )
        word_emb_dict["weight"] = out_word_embed[i].clone()
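
词嵌入的处理分两步:先把词表对齐到 padded_vocab_size(不足时复制最后一行填充,超出时裁剪),再用 torch.chunk 按张量并行度切分。下面用一个很小的张量独立演示同样的逻辑(示意,并非原脚本中 `_vocab_size_with_padding` 的实现):

```python
import torch

word_embedding = torch.arange(20, dtype=torch.float32).reshape(5, 4)  # 5 个词,每个 4 维
orig_vocab_size, padded_vocab_size, tp_size = 5, 8, 2

if orig_vocab_size > padded_vocab_size:
    full_word_embed = word_embedding[:padded_vocab_size, :]           # 裁剪多余的行
elif orig_vocab_size < padded_vocab_size:
    padding_size = padded_vocab_size - orig_vocab_size
    # 与上面脚本一致:复制最后一行来填充
    full_word_embed = torch.cat(
        (word_embedding, word_embedding[-1].unsqueeze(0).expand(padding_size, -1))
    )
else:
    full_word_embed = word_embedding

out_word_embed = torch.chunk(full_word_embed, tp_size, dim=0)
print([t.shape for t in out_word_embed])  # [torch.Size([4, 4]), torch.Size([4, 4])]
```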

    # 转换器层处理
    print("converting transformer layers")

    # 检查注意力头数是否能被目标张量并行大小整除,否则引发数值错误
    if config.num_attention_heads % args.target_tensor_model_parallel_size != 0:
        raise ValueError(
            f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of tensor parallelism"
            f" ({args.target_tensor_model_parallel_size})"
        )

    # 检查隐藏层数是否能被目标管道并行大小整除,否则引发数值错误
    if config.num_hidden_layers % args.target_pipeline_model_parallel_size != 0:
        raise ValueError(
            f"Number of layers ({config.num_hidden_layers}) must be divisible by number of pipeline parallelism"
            f" ({args.target_pipeline_model_parallel_size})"
        )

    # 计算每个管道并行块的转换器层数量
    num_layers = config.num_hidden_layers // args.target_pipeline_model_parallel_size

    # 正则表达式,用于匹配和解析转换器层的名称
    layer_re = re.compile(r"transformer.h\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

    # Transformer模型中的注意力头数
    heads = config.n_head

    # 每个注意力头的隐藏大小
    hidden_size_per_head = config.n_embd // config.n_head
# 定义主函数入口
def main():
    # 创建参数解析器对象
    parser = argparse.ArgumentParser()
    # 向参数解析器添加用于检查点的参数
    parser = add_checkpointing_args(parser)
    # 向参数解析器添加用于 Megatron 检查点的参数
    parser = add_megatron_checkpoint_args(parser)
    # 向参数解析器添加用于 Transformers 检查点的参数
    parser = add_transformers_checkpoint_args(parser)
    # 解析命令行参数
    args = parser.parse_args()
    
    # 如果命令行参数中包含转换 Megatron 到 Transformers 的选项
    if args.convert_checkpoint_from_megatron_to_transformers:
        # 执行 Megatron 到 Transformers 的检查点转换
        convert_checkpoint_from_megatron_to_transformers(args)
    else:
        # 否则执行 Transformers 到 Megatron 的检查点转换
        convert_checkpoint_from_transformers_to_megatron(args)

# 如果该脚本作为主程序运行,则执行 main() 函数
if __name__ == "__main__":
    main()

.\models\megatron_gpt2\convert_megatron_gpt2_checkpoint.py

# 本脚本用到的标准库与 transformers 组件
import argparse
import os
import re
import zipfile

import torch

from transformers import AutoTokenizer, GPT2Config


# 定义递归打印函数,用于打印参数名及其对应的值
def recursive_print(name, val, spaces=0):
    # 格式化消息,根据参数名生成相应格式的字符串,控制输出的对齐
    if name is None:
        msg = None
    else:
        fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}"
        msg = fmt.format(name)

    # 打印消息并递归处理(如果需要的话)
    if isinstance(val, dict):
        if msg is not None:
            print(msg)
        for k in val.keys():
            recursive_print(k, val[k], spaces + 2)
    elif isinstance(val, torch.Tensor):
        # 如果值是 torch.Tensor 类型,则打印参数名、值的尺寸
        print(msg, ":", val.size())
    else:
        # 否则,只打印参数名
        print(msg, ":", val)
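
recursive_print 用点号缩进来展示检查点的嵌套结构,叶子节点若是张量则打印其尺寸。下面是一个独立的调用示例(数据为随意构造,仅作说明,假设上面定义的 recursive_print 已在作用域中):

```python
import torch

state = {
    "checkpoint_version": 3.0,
    "model": {
        "language_model": {
            "embedding": {"word_embeddings": {"weight": torch.zeros(8, 4)}},
        }
    },
}
recursive_print(None, state)
# 每深入一层,名字前的点号前缀就多两个;weight 这一叶子节点会打印 torch.Size([8, 4])
```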


# 对参数张量进行重新排序,以适应后续版本的 NVIDIA Megatron-LM
def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
    # 参数张量的输入形状
    input_shape = param.size()
    # 将布局排列为 [num_splits * num_heads * hidden_size, :],以便与后续版本的 Megatron-LM 兼容
    # 在 Megatron-LM 内部,会执行反向操作来读取检查点
    # 如果 param 是 self-attention 块的权重张量,则返回的张量需要再次转置,以便 HuggingFace GPT2 读取
    # 参考 Megatron-LM 源码中的实现:https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209
    # 如果检查点版本为1.0:
    if checkpoint_version == 1.0:
        # 版本1.0存储的形状是 [num_heads * hidden_size * num_splits, :]
        saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:]
        # 调整张量的形状为 saved_shape
        param = param.view(*saved_shape)
        # 转置操作:交换维度0和2
        param = param.transpose(0, 2)
        # 再次转置操作:交换维度1和2,并确保内存连续性
        param = param.transpose(1, 2).contiguous()
    
    # 如果检查点版本大于等于2.0:
    elif checkpoint_version >= 2.0:
        # 其他版本存储的形状是 [num_heads * num_splits * hidden_size, :]
        saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:]
        # 调整张量的形状为 saved_shape
        param = param.view(*saved_shape)
        # 转置操作:交换维度0和1,并确保内存连续性
        param = param.transpose(0, 1).contiguous()
    
    # 最后将张量的形状调整为 input_shape,并返回处理后的参数张量
    param = param.view(*input_shape)
    return param
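
fix_query_key_value_ordering 只改变 QKV 在第 0 维上的交错方式,输出形状与输入完全相同。下面用一个很小的随机张量验证两个检查点版本下的行为(示意,假设上面定义的函数已在作用域中):

```python
import torch

num_heads, hidden_size_per_head, num_splits = 2, 3, 3            # 3 对应 Q/K/V 三个切分
hidden_size = num_heads * hidden_size_per_head                    # 6
qkv_weight = torch.randn(num_splits * hidden_size, hidden_size)   # Megatron 的 (3*D) x D 布局

for version in (1.0, 2.0):
    out = fix_query_key_value_ordering(qkv_weight, version, num_splits, num_heads, hidden_size_per_head)
    print(version, out.shape)   # 两个版本的输出形状都是 torch.Size([18, 6]),与输入一致
```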
####################################################################################################

# 定义函数用于将 Megatron-LM 模型检查点转换为适用于 Transformers 模型的格式
def convert_megatron_checkpoint(args, input_state_dict, config):
    # 初始化输出状态字典
    output_state_dict = {}

    # 旧版本可能未存储训练参数,检查并获取相关参数
    ds_args = input_state_dict.get("args", None)
    if ds_args is not None:
        # 如果存在训练参数,则根据这些参数设置配置对象的相关属性
        config.vocab_size = ds_args.padded_vocab_size
        config.n_positions = ds_args.max_position_embeddings
        config.n_embd = ds_args.hidden_size
        config.n_layer = ds_args.num_layers
        config.n_head = ds_args.num_attention_heads
        config.n_inner = ds_args.ffn_hidden_size

    # 获取注意力头的数量
    heads = config.n_head
    # 计算每个注意力头的隐藏层大小
    hidden_size_per_head = config.n_embd // config.n_head
    # 获取 Megatron-LM 检查点版本信息
    if "checkpoint_version" in input_state_dict.keys():
        checkpoint_version = input_state_dict["checkpoint_version"]
    else:
        checkpoint_version = 0.0

    # 获取模型对象
    model = input_state_dict["model"]
    # 获取语言模型
    lm = model["language_model"]
    # 获取嵌入层
    embeddings = lm["embedding"]

    # 获取词嵌入
    word_embeddings = embeddings["word_embeddings"]["weight"]
    # 将词嵌入表截断到指定的词汇量大小
    word_embeddings = word_embeddings[: config.vocab_size, :]
    output_state_dict["transformer.wte.weight"] = word_embeddings

    # 获取位置嵌入
    pos_embeddings = embeddings["position_embeddings"]["weight"]
    # 检查位置嵌入的长度与配置中的位置数是否匹配
    n_positions = pos_embeddings.size(0)
    if n_positions != config.n_positions:
        raise ValueError(
            f"pos_embeddings.max_sequence_length={n_positions} and config.n_positions={config.n_positions} don't match"
        )
    # 存储位置嵌入
    output_state_dict["transformer.wpe.weight"] = pos_embeddings

    # 获取变压器层对象,根据是否包含 "transformer" 键来决定
    transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"]

    # 编译用于提取层名称的正则表达式
    layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

    # Megatron-LM 到 Transformers 的简单映射规则
    megatron_to_transformers = {
        "attention.dense": ".attn.c_proj.",
        "self_attention.dense": ".attn.c_proj.",
        "mlp.dense_h_to_4h": ".mlp.c_fc.",
        "mlp.dense_4h_to_h": ".mlp.c_proj.",
    }

    # 提取层信息的准备工作
    # 遍历transformer.items()中的键值对,其中key为层的名称,val为对应的值(通常是权重或偏置)。
    for key, val in transformer.items():
        # 使用正则表达式匹配层的名称。
        m = layer_re.match(key)

        # 如果匹配结果为None,说明这不是一个层,直接跳出循环。
        if m is None:
            break

        # 提取层的索引。
        layer_idx = int(m.group(1))
        # 提取操作的名称。
        op_name = m.group(2)
        # 判断是权重还是偏置。
        weight_or_bias = m.group(3)

        # 构造层的名称。
        layer_name = f"transformer.h.{layer_idx}"

        # 对于layernorm,直接存储layernorm的值。
        if op_name.endswith("layernorm"):
            ln_name = "ln_1" if op_name.startswith("input") else "ln_2"
            output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val

        # 转置QKV矩阵。
        elif (
            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
        ) and weight_or_bias == "weight":
            # 插入一个1x1xDxD的偏置张量。
            causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.float16)).view(
                1, 1, n_positions, n_positions
            )
            output_state_dict[layer_name + ".attn.bias"] = causal_mask

            # 插入一个"虚拟"张量作为masked_bias。
            masked_bias = torch.tensor(-1e4, dtype=torch.float16)
            output_state_dict[layer_name + ".attn.masked_bias"] = masked_bias

            # 调整QKV矩阵的顺序。
            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
            # Megatron存储的是(3*D) x D,但transformers-GPT2需要的是D x (3*D)。
            out_val = out_val.transpose(0, 1).contiguous()
            # 存储。
            output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val

        # 转置偏置。
        elif (
            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
        ) and weight_or_bias == "bias":
            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
            # 存储。形状无变化。
            output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val

        # 转置权重。
        elif weight_or_bias == "weight":
            out_name = megatron_to_transformers[op_name]
            output_state_dict[layer_name + out_name + "weight"] = val.transpose(0, 1)

        # 复制偏置。
        elif weight_or_bias == "bias":
            out_name = megatron_to_transformers[op_name]
            output_state_dict[layer_name + out_name + "bias"] = val

    # 调试断言,确保层数与config.n_layer相符。
    assert config.n_layer == layer_idx + 1

    # 最终的layernorm。
    output_state_dict["transformer.ln_f.weight"] = transformer["final_layernorm.weight"]
    output_state_dict["transformer.ln_f.bias"] = transformer["final_layernorm.bias"]

    # 对于LM头,transformers需要权重矩阵来加权嵌入。
    output_state_dict["lm_head.weight"] = word_embeddings

    # 完成任务!
    # 返回函数中的状态字典作为输出
    return output_state_dict
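
参数名的翻译由 layer_re 正则和 megatron_to_transformers 映射共同完成。下面单独演示这条命名规则(示意,只处理字符串,不涉及张量):

```python
import re

layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
megatron_to_transformers = {
    "attention.dense": ".attn.c_proj.",
    "self_attention.dense": ".attn.c_proj.",
    "mlp.dense_h_to_4h": ".mlp.c_fc.",
    "mlp.dense_4h_to_h": ".mlp.c_proj.",
}

key = "layers.0.mlp.dense_h_to_4h.weight"
layer_idx, op_name, weight_or_bias = layer_re.match(key).groups()
new_key = f"transformer.h.{layer_idx}" + megatron_to_transformers[op_name] + weight_or_bias
print(new_key)  # transformer.h.0.mlp.c_fc.weight
```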
# 定义主函数,程序的入口点
def main():
    # 创建参数解析器
    parser = argparse.ArgumentParser()
    # 添加布尔型参数 --print-checkpoint-structure,用于指定是否打印检查点结构
    parser.add_argument("--print-checkpoint-structure", action="store_true")
    # 添加位置参数 path_to_checkpoint,表示检查点文件的路径(可以是 .zip 文件或直接的 .pt 文件)
    parser.add_argument(
        "path_to_checkpoint",
        type=str,
        help="Path to the checkpoint file (.zip archive or direct .pt file)",
    )
    # 添加可选参数 --config_file,表示可选的配置 JSON 文件,描述预训练模型
    parser.add_argument(
        "--config_file",
        default="",
        type=str,
        help="An optional config json file describing the pre-trained model.",
    )
    # 解析命令行参数
    args = parser.parse_args()

    # 提取检查点所在的目录(注意:虽然变量名为 basename,这里用的是 os.path.dirname,后续转换结果会保存到该目录)
    basename = os.path.dirname(args.path_to_checkpoint)

    # 加载模型
    # 如果检查点路径以 .zip 结尾,则假设其为压缩文件
    print(f"Extracting PyTorch state dictionary from {args.path_to_checkpoint}")
    if args.path_to_checkpoint.endswith(".zip"):
        # 使用 zipfile 库打开 .zip 文件
        with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
            # 打开压缩包中的指定文件 release/mp_rank_00/model_optim_rng.pt
            with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
                # 使用 torch.load 加载 PyTorch 的状态字典
                input_state_dict = torch.load(pytorch_dict, map_location="cpu")
    else:
        # 直接加载 .pt 文件
        input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")

    # 从输入状态字典中获取参数 args
    ds_args = input_state_dict.get("args", None)

    # 读取配置文件,或者默认使用 NVIDIA 发布的模型配置
    if args.config_file == "":
        if ds_args is not None:
            if ds_args.bias_gelu_fusion:
                activation_function = "gelu_fast"
            elif ds_args.openai_gelu:
                activation_function = "gelu_new"
            else:
                activation_function = "gelu"
        else:
            # 在早期版本中可能使用的激活函数
            activation_function = "gelu_new"

        # 明确指定所有参数,以防默认值发生更改
        config = GPT2Config(
            vocab_size=50257,
            n_positions=1024,
            n_embd=1024,
            n_layer=24,
            n_head=16,
            n_inner=4096,
            activation_function=activation_function,
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
            summary_type="cls_index",
            summary_use_proj=True,
            summary_activation=None,
            summary_proj_to_labels=True,
            summary_first_dropout=0.1,
            scale_attn_weights=True,
            use_cache=True,
            bos_token_id=50256,
            eos_token_id=50256,
        )
    else:
        # 从 JSON 文件中加载配置
        config = GPT2Config.from_json_file(args.config_file)

    # 设置模型架构为 "GPT2LMHeadModel"
    config.architectures = ["GPT2LMHeadModel"]

    # 转换模型
    print("Converting")
    output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)

    # 如果指定了 --print-checkpoint-structure 参数,则递归打印转换后状态字典的结构
    if args.print_checkpoint_structure:
        recursive_print(None, output_state_dict)
    # Add tokenizer class info to config
    # 将分词器类信息添加到配置中

    if ds_args is not None:
        # 如果数据集参数不为空,则获取分词器类型
        tokenizer_type = ds_args.tokenizer_type
        
        if tokenizer_type == "GPT2BPETokenizer":
            # 如果分词器类型为"GPT2BPETokenizer",选择使用特定的模型
            tokenizer_model_name = "openai-community/gpt2"
        elif tokenizer_type == "PretrainedFromHF":
            # 如果分词器类型为"PretrainedFromHF",使用数据集参数中指定的模型名称或路径
            tokenizer_model_name = ds_args.tokenizer_name_or_path
        else:
            # 如果分词器类型不被识别,则引发值错误异常
            raise ValueError(f"Unrecognized tokenizer_type {tokenizer_type}")
    else:
        # 如果数据集参数为空,默认使用"openai-community/gpt2"作为模型名称
        tokenizer_model_name = "openai-community/gpt2"

    # 根据模型名称加载分词器
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name)
    # 获取分词器的类名
    tokenizer_class = type(tokenizer).__name__
    # 将分词器类名存储到配置中
    config.tokenizer_class = tokenizer_class

    # 将配置保存到文件中
    print("Saving config")
    config.save_pretrained(basename)

    # 根据参数保存分词器
    print(f"Adding {tokenizer_class} tokenizer files")
    tokenizer.save_pretrained(basename)

    # 将状态字典保存到文件中
    output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
    print(f'Saving checkpoint to "{output_checkpoint_file}"')
    torch.save(output_state_dict, output_checkpoint_file)
# 如果当前脚本作为主程序运行(而不是被导入),则执行 main 函数
if __name__ == "__main__":
    # 调用主函数,程序的入口点
    main()
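
转换完成后,输出目录中会包含 config.json、分词器文件以及 pytorch_model.bin,可以直接用 Transformers 加载验证。下面是一个加载示意(其中的目录路径是假设的示例路径):

```python
from transformers import AutoTokenizer, GPT2LMHeadModel

checkpoint_dir = "path/to/converted_checkpoint"   # 假设:即上面脚本保存输出的目录
model = GPT2LMHeadModel.from_pretrained(checkpoint_dir)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

inputs = tokenizer("Megatron is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```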

.\models\megatron_gpt2\__init__.py

# 版权声明及许可信息,指定了此代码的使用条款和条件
# 版权所有 © 2021 NVIDIA Corporation 和 The HuggingFace Team. 保留所有权利。
#
# 根据 Apache 许可证 2.0 版本("许可证")许可;除非符合许可证的要求,否则不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则依据"原样"分发本软件,
# 没有任何明示或暗示的担保或条件。
# 有关具体语言的详细信息,请参阅许可证。

.\models\mgp_str\configuration_mgp_str.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MGP-STR model configuration"""

# Importing necessary modules from the Transformers library
from ...configuration_utils import PretrainedConfig  # 导入预训练配置类
from ...utils import logging  # 导入日志记录工具

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 预训练模型配置文件映射字典,指定了模型名称到其配置文件的映射关系
MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "alibaba-damo/mgp-str-base": "https://huggingface.co/alibaba-damo/mgp-str-base/resolve/main/config.json",
}

# MgpstrConfig 类,继承自 PretrainedConfig,用于存储 MGP-STR 模型的配置信息
class MgpstrConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`MgpstrModel`]. It is used to instantiate an
    MGP-STR model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the MGP-STR
    [alibaba-damo/mgp-str-base](https://huggingface.co/alibaba-damo/mgp-str-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    # 定义默认的图像大小为 [32, 128]
    Args:
        image_size (`List[int]`, *optional*, defaults to `[32, 128]`):
            The size (resolution) of each image.
        # 定义每个补丁的大小,默认为 4
        patch_size (`int`, *optional*, defaults to 4):
            The size (resolution) of each patch.
        # 定义输入通道数,默认为 3
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        # 定义输出令牌的最大数量,默认为 27
        max_token_length (`int`, *optional*, defaults to 27):
            The max number of output tokens.
        # 定义字符头的类别数量,默认为 38
        num_character_labels (`int`, *optional*, defaults to 38):
            The number of classes for character head .
        # 定义bpe头的类别数量,默认为 50257
        num_bpe_labels (`int`, *optional*, defaults to 50257):
            The number of classes for bpe head .
        # 定义wordpiece头的类别数量,默认为 30522
        num_wordpiece_labels (`int`, *optional*, defaults to 30522):
            The number of classes for wordpiece head .
        # 定义嵌入维度,默认为 768
        hidden_size (`int`, *optional*, defaults to 768):
            The embedding dimension.
        # 定义Transformer编码器中的隐藏层数量,默认为 12
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        # 定义Transformer编码器中每个注意力层的注意头数量,默认为 12
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        # 定义mlp隐藏维度与嵌入维度的比率,默认为 4.0
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            The ratio of mlp hidden dim to embedding dim.
        # 定义是否向查询、键和值添加偏置,默认为 True
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        # 定义模型是否包含蒸馏令牌和头部,如DeiT模型,默认为 False
        distilled (`bool`, *optional*, defaults to `False`):
            Model includes a distillation token and head as in DeiT models.
        # 定义层归一化层使用的 epsilon,默认为 1e-05
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        # 定义所有全连接层的丢弃概率,包括嵌入和编码器,默认为 0.0
        drop_rate (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder.
        # 定义注意力概率的丢弃比率,默认为 0.0
        attn_drop_rate (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        # 定义随机深度的丢弃率,默认为 0.0
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            The stochastic depth rate.
        # 定义是否返回A^3模块注意力的布尔值,默认为 False
        output_a3_attentions (`bool`, *optional*, defaults to `False`):
            Whether or not the model should returns A^3 module attentions.
        # 定义所有权重矩阵初始化时的截断正态分布的标准差,默认为 0.02
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

    Example:

    ```
    >>> from transformers import MgpstrConfig, MgpstrForSceneTextRecognition

    >>> # Initializing a Mgpstr mgp-str-base style configuration
    >>> configuration = MgpstrConfig()

    >>> # Initializing a model (with random weights) from the mgp-str-base style configuration
    >>> model = MgpstrForSceneTextRecognition(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```

    """

    # 设置模型类型为 "mgp-str"
    model_type = "mgp-str"
    # 定义一个初始化函数,初始化一个模型对象
    def __init__(
        self,
        image_size=[32, 128],  # 图像大小,默认为[32, 128]
        patch_size=4,          # 补丁大小,默认为4
        num_channels=3,        # 图像通道数,默认为3
        max_token_length=27,   # 最大标记长度,默认为27
        num_character_labels=38,  # 字符标签数,默认为38
        num_bpe_labels=50257,      # BPE标签数,默认为50257
        num_wordpiece_labels=30522,  # WordPiece标签数,默认为30522
        hidden_size=768,        # 隐藏层大小,默认为768
        num_hidden_layers=12,   # 隐藏层数,默认为12
        num_attention_heads=12,  # 注意力头数,默认为12
        mlp_ratio=4.0,          # MLP(多层感知机)比例,默认为4.0
        qkv_bias=True,          # 是否在QKV转换中使用偏置,默认为True
        distilled=False,        # 是否为蒸馏模型,默认为False
        layer_norm_eps=1e-5,    # 层归一化的epsilon值,默认为1e-5
        drop_rate=0.0,          # dropout比率,默认为0.0
        attn_drop_rate=0.0,     # 注意力dropout比率,默认为0.0
        drop_path_rate=0.0,     # 路径dropout比率,默认为0.0
        output_a3_attentions=False,  # 是否输出A3注意力,默认为False
        initializer_range=0.02,  # 初始化范围,默认为0.02
        **kwargs,               # 其他关键字参数
    ):
        super().__init__(**kwargs)  # 调用父类的初始化方法

        self.image_size = image_size  # 初始化图像大小属性
        self.patch_size = patch_size  # 初始化补丁大小属性
        self.num_channels = num_channels  # 初始化图像通道数属性
        self.max_token_length = max_token_length  # 初始化最大标记长度属性
        self.num_character_labels = num_character_labels  # 初始化字符标签数属性
        self.num_bpe_labels = num_bpe_labels  # 初始化BPE标签数属性
        self.num_wordpiece_labels = num_wordpiece_labels  # 初始化WordPiece标签数属性
        self.hidden_size = hidden_size  # 初始化隐藏层大小属性
        self.num_hidden_layers = num_hidden_layers  # 初始化隐藏层数属性
        self.num_attention_heads = num_attention_heads  # 初始化注意力头数属性
        self.mlp_ratio = mlp_ratio  # 初始化MLP比例属性
        self.distilled = distilled  # 初始化蒸馏模型属性
        self.layer_norm_eps = layer_norm_eps  # 初始化层归一化epsilon属性
        self.drop_rate = drop_rate  # 初始化dropout比率属性
        self.qkv_bias = qkv_bias  # 初始化QKV偏置属性
        self.attn_drop_rate = attn_drop_rate  # 初始化注意力dropout比率属性
        self.drop_path_rate = drop_path_rate  # 初始化路径dropout比率属性
        self.output_a3_attentions = output_a3_attentions  # 初始化是否输出A3注意力属性
        self.initializer_range = initializer_range  # 初始化初始化范围属性
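
根据配置中的 image_size 与 patch_size,可以直接推算出视觉骨干的 patch 网格和序列长度(与后面 modeling_mgp_str.py 中 MgpstrEmbeddings 的计算一致)。一个简单的推算示例:

```python
from transformers import MgpstrConfig

config = MgpstrConfig()                   # 默认 image_size=[32, 128], patch_size=4
grid_size = (config.image_size[0] // config.patch_size, config.image_size[1] // config.patch_size)
num_patches = grid_size[0] * grid_size[1]
print(grid_size, num_patches)             # (8, 32) 256
# 加上 1 个 cls token 后,编码器的序列长度为 257,隐藏维度为 config.hidden_size = 768
```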

.\models\mgp_str\modeling_mgp_str.py

# 设置文件编码为 UTF-8
# 版权声明及保留所有权利给 Alibaba Research 和 HuggingFace Inc. 团队
#
# 根据 Apache 许可证 2.0 版本授权使用本文件
# 除非符合许可证规定,否则不得使用本文件
# 您可以从以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,本软件是基于"原样"提供的,不提供任何形式的担保或条件,
# 包括但不限于,适销性、特定用途适用性和非侵权性担保。
# 有关详细信息,请参阅许可证。

""" PyTorch MGP-STR model."""

import collections.abc
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_mgp_str import MgpstrConfig

# 获取 logger 对象用于记录日志
logger = logging.get_logger(__name__)

# 用于文档的通用说明
_CONFIG_FOR_DOC = "MgpstrConfig"
_TOKENIZER_FOR_DOC = "MgpstrTokenizer"

# 模型检查点的基本说明
_CHECKPOINT_FOR_DOC = "alibaba-damo/mgp-str-base"

# 预训练模型存档列表
MGP_STR_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "alibaba-damo/mgp-str-base",
    # 查看所有 MGP-STR 模型的列表:https://huggingface.co/models?filter=mgp-str
]

# 以下是函数定义和类定义,用于模型中的路径丢弃功能
# 从 transformers.models.beit.modeling_beit.drop_path 中复制的函数
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output
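
drop_path 以样本为单位整条丢弃残差分支,并用 1/keep_prob 缩放保留的样本,使期望值保持不变;推理时则原样返回。下面是一个小演示(示意,假设上面定义的 drop_path 已在作用域中):

```python
import torch

x = torch.ones(4, 3)                               # 4 个样本,每个样本 3 维
out = drop_path(x, drop_prob=0.5, training=True)
print(out)                                         # 每一行要么整体为 0,要么整体为 2.0(即 1/keep_prob)
print(torch.equal(drop_path(x, drop_prob=0.5, training=False), x))  # True:推理时不做任何处理
```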


# 从 transformers.models.beit.modeling_beit.BeitDropPath 中复制的类,将 Beit 改为 Mgpstr
class MgpstrDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
    # 初始化函数,用于创建一个新的对象实例
    def __init__(self, drop_prob: Optional[float] = None) -> None:
        # 调用父类的初始化方法
        super().__init__()
        # 设置对象的属性,用于指定 dropout 的概率
        self.drop_prob = drop_prob

    # 前向传播函数,处理输入的隐藏状态张量
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 调用 drop_path 函数,对隐藏状态进行随机的 dropout 操作
        return drop_path(hidden_states, self.drop_prob, self.training)

    # 提供额外的表示信息,用于描述当前对象的状态
    def extra_repr(self) -> str:
        # 返回一个字符串,表示对象的 dropout 概率
        return "p={}".format(self.drop_prob)
# 定义了一个数据类 `MgpstrModelOutput`,继承自 `ModelOutput` 类,用于表示模型输出结果
@dataclass
class MgpstrModelOutput(ModelOutput):
    
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    Args:
        logits (`tuple(torch.FloatTensor)` of shape `(batch_size, config.num_character_labels)`):
            Tuple of `torch.FloatTensor` containing classification scores (before SoftMax) for characters, bpe, and wordpiece.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` containing hidden states of the model at each layer and optional initial embeddings.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` containing attention weights for each layer after softmax computation.
        a3_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_a3_attentions=True` is passed or when `config.output_a3_attentions=True`):
            Tuple of `torch.FloatTensor` containing attention weights for character, bpe, and wordpiece after softmax computation.
    """

    # logits 包含分类分数,形状为 (batch_size, config.num_character_labels)
    logits: Tuple[torch.FloatTensor] = None
    # hidden_states 包含模型每层的隐藏状态,形状为 (batch_size, sequence_length, hidden_size),可选
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # attentions 包含每层注意力权重,形状为 (batch_size, config.max_token_length, sequence_length, sequence_length),可选
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # a3_attentions 包含字符、bpe、wordpiece 的注意力权重,形状为 (batch_size, config.max_token_length, sequence_length),可选
    a3_attentions: Optional[Tuple[torch.FloatTensor]] = None


class MgpstrEmbeddings(nn.Module):
    """2D Image to Patch Embedding"""
    # 初始化函数,接受一个MgpstrConfig类型的配置对象作为参数
    def __init__(self, config: MgpstrConfig):
        # 调用父类的初始化方法
        super().__init__()
        # 根据配置对象中的image_size属性确定图像大小,若为可迭代对象则直接使用,否则将其转换为元组
        image_size = (
            config.image_size
            if isinstance(config.image_size, collections.abc.Iterable)
            else (config.image_size, config.image_size)
        )
        # 根据配置对象中的patch_size属性确定patch大小,若为可迭代对象则直接使用,否则将其转换为元组
        patch_size = (
            config.patch_size
            if isinstance(config.patch_size, collections.abc.Iterable)
            else (config.patch_size, config.patch_size)
        )
        # 设置对象的image_size属性为确定后的图像大小
        self.image_size = image_size
        # 设置对象的patch_size属性为确定后的patch大小
        self.patch_size = patch_size
        # 根据图像大小和patch大小计算出网格大小,以元组形式保存
        self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
        # 计算总的patch数目
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        # 如果配置对象中指定为精炼模式,则token数目为2,否则为1
        self.num_tokens = 2 if config.distilled else 1

        # 使用nn.Conv2d定义一个投影层,将输入通道数转换为隐藏大小,卷积核大小为patch_size,步长也为patch_size
        self.proj = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)

        # 定义一个可学习的分类token,维度为1x1x隐藏大小,作为分类信息的表示
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        # 定义一个可学习的位置嵌入矩阵,维度为1x(num_patches + num_tokens)x隐藏大小,表示每个patch和token的位置信息
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + self.num_tokens, config.hidden_size))
        # 使用Dropout进行位置嵌入的随机失活,概率为配置对象中指定的drop_rate
        self.pos_drop = nn.Dropout(p=config.drop_rate)

    # 前向传播函数,接受输入的像素值张量,返回嵌入向量张量
    def forward(self, pixel_values):
        # 获取输入像素值张量的形状信息
        batch_size, channel, height, width = pixel_values.shape
        # 检查输入图像的高度和宽度是否与预期的image_size匹配,若不匹配则抛出数值错误异常
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )

        # 将输入的像素值张量通过投影层转换为patch嵌入张量,同时将其展平并转置以适应后续操作
        patch_embeddings = self.proj(pixel_values)
        patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)  # BCHW -> BNC

        # 使用分类token扩展为batch_size份,形状为(batch_size, 1, 隐藏大小)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        # 将分类token与patch嵌入拼接在一起,形状为(batch_size, num_patches + num_tokens, 隐藏大小)
        embedding_output = torch.cat((cls_tokens, patch_embeddings), dim=1)
        # 加上位置嵌入信息,形状为(1, num_patches + num_tokens, 隐藏大小),与embedding_output形状相加
        embedding_output = embedding_output + self.pos_embed
        # 对加和后的embedding_output进行位置嵌入的随机失活
        embedding_output = self.pos_drop(embedding_output)

        # 返回最终的嵌入向量张量
        return embedding_output
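
按默认配置走一遍嵌入层,可以看到张量形状的变化:32x128 的图像被切成 8x32=256 个 patch,再拼上 1 个 cls token。下面是一个形状验证示例(示意;MgpstrEmbeddings 即上面刚定义的类,这里从内部模块路径导入):

```python
import torch
from transformers import MgpstrConfig
from transformers.models.mgp_str.modeling_mgp_str import MgpstrEmbeddings  # 内部模块路径

config = MgpstrConfig()
embeddings = MgpstrEmbeddings(config)

pixel_values = torch.randn(2, 3, 32, 128)   # (batch, channel, height, width)
output = embeddings(pixel_values)
print(output.shape)                          # torch.Size([2, 257, 768]):256 个 patch + 1 个 cls token
```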
class MgpstrMlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(self, config: MgpstrConfig, hidden_features):
        super().__init__()
        hidden_features = hidden_features or config.hidden_size  # 如果未提供 hidden_features,则使用 config 中的 hidden_size
        self.fc1 = nn.Linear(config.hidden_size, hidden_features)  # 第一个全连接层,输入维度为 config.hidden_size,输出维度为 hidden_features
        self.act = nn.GELU()  # GELU 激活函数
        self.fc2 = nn.Linear(hidden_features, config.hidden_size)  # 第二个全连接层,输入维度为 hidden_features,输出维度为 config.hidden_size
        self.drop = nn.Dropout(config.drop_rate)  # Dropout 操作,丢弃率为 config.drop_rate

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)  # 第一个全连接层的前向传播
        hidden_states = self.act(hidden_states)  # 应用 GELU 激活函数
        hidden_states = self.drop(hidden_states)  # Dropout 操作
        hidden_states = self.fc2(hidden_states)  # 第二个全连接层的前向传播
        hidden_states = self.drop(hidden_states)  # 再次应用 Dropout
        return hidden_states  # 返回处理后的隐藏状态


class MgpstrAttention(nn.Module):
    def __init__(self, config: MgpstrConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads  # 注意力头的数量
        head_dim = config.hidden_size // config.num_attention_heads  # 每个注意力头的维度
        self.scale = head_dim ** -0.5  # 缩放因子

        self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)  # QKV 线性变换
        self.attn_drop = nn.Dropout(config.attn_drop_rate)  # Attention Dropout 操作
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)  # 投影层线性变换
        self.proj_drop = nn.Dropout(config.drop_rate)  # Dropout 操作

    def forward(self, hidden_states):
        batch_size, num, channel = hidden_states.shape  # 获取输入张量的形状信息
        qkv = (
            self.qkv(hidden_states)  # 执行 QKV 线性变换
            .reshape(batch_size, num, 3, self.num_heads, channel // self.num_heads)  # 重塑张量形状以便后续处理
            .permute(2, 0, 3, 1, 4)  # 调整维度顺序
        )
        query, key, value = qkv[0], qkv[1], qkv[2]  # 分割 QKV 信息以便后续处理(为了兼容 TorchScript)

        attention_probs = (query @ key.transpose(-2, -1)) * self.scale  # 计算注意力分数
        attention_probs = attention_probs.softmax(dim=-1)  # 对注意力分数进行 softmax 操作
        attention_probs = self.attn_drop(attention_probs)  # 应用 Attention Dropout

        context_layer = (attention_probs @ value).transpose(1, 2).reshape(batch_size, num, channel)  # 计算上下文向量
        context_layer = self.proj(context_layer)  # 应用投影层线性变换
        context_layer = self.proj_drop(context_layer)  # 应用 Dropout 操作
        return (context_layer, attention_probs)  # 返回上下文层及注意力分数


class MgpstrLayer(nn.Module):
    def __init__(self, config: MgpstrConfig, drop_path=None):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # Layer Normalization 操作
        self.attn = MgpstrAttention(config)  # 注意力机制
        self.drop_path = MgpstrDropPath(drop_path) if drop_path is not None else nn.Identity()  # 随机深度路径(用于随机深度)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # Layer Normalization 操作
        mlp_hidden_dim = int(config.hidden_size * config.mlp_ratio)  # MLP 隐藏层维度
        self.mlp = MgpstrMlp(config, mlp_hidden_dim)  # 多层感知机模块
    # 定义模型的前向传播方法,接受隐藏状态作为输入
    def forward(self, hidden_states):
        # 使用 self.attn 对隐藏状态进行自注意力机制计算,经过 self.norm1 归一化处理
        self_attention_outputs = self.attn(self.norm1(hidden_states))
        # 获取自注意力机制的输出结果和中间层输出
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1]

        # 第一个残差连接:将经过注意力计算后的输出加上原始输入隐藏状态,并施加随机丢弃(drop path)
        hidden_states = self.drop_path(attention_output) + hidden_states

        # 第二个残差连接:将经过 self.norm2 归一化后的隐藏状态经过 MLP 处理,再次施加随机丢弃(drop path),然后加上之前的 hidden_states
        layer_output = hidden_states + self.drop_path(self.mlp(self.norm2(hidden_states)))

        # 将最终的层输出和注意力输出组成元组作为最终的输出
        outputs = (layer_output, outputs)
        return outputs
class MgpstrEncoder(nn.Module):
    def __init__(self, config: MgpstrConfig):
        super().__init__()
        # stochastic depth decay rule
        # 根据配置中的drop_path_rate生成随机深度的衰减规则列表
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]

        # 使用MgpstrLayer创建神经网络模型的多层堆叠
        self.blocks = nn.Sequential(
            *[MgpstrLayer(config=config, drop_path=dpr[i]) for i in range(config.num_hidden_layers)]
        )

    def forward(self, hidden_states, output_attentions=False, output_hidden_states=False, return_dict=True):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # 遍历并执行神经网络模型的每个块(layer)
        for _, blk in enumerate(self.blocks):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 执行当前块(layer)的前向传播,更新隐藏状态
            layer_outputs = blk(hidden_states)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # 如果需要输出隐藏状态,则将最终隐藏状态加入到所有隐藏状态的元组中
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果不要求返回字典,则根据需要返回不同的元组或对象
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class MgpstrA3Module(nn.Module):
    def __init__(self, config: MgpstrConfig):
        super().__init__()
        # 初始化层归一化层,用于标准化token的向量表示
        self.token_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 通过卷积操作生成token学习器,用于生成token和注意力权重
        self.tokenLearner = nn.Sequential(
            nn.Conv2d(config.hidden_size, config.hidden_size, kernel_size=(1, 1), stride=1, groups=8, bias=False),
            nn.Conv2d(config.hidden_size, config.max_token_length, kernel_size=(1, 1), stride=1, bias=False),
        )
        # 初始化特征提取器的卷积层,用于生成特征表示
        self.feat = nn.Conv2d(
            config.hidden_size, config.hidden_size, kernel_size=(1, 1), stride=1, groups=8, bias=False
        )
        # 初始化层归一化层,用于标准化特征向量的表示
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        # 标准化token的向量表示
        hidden_states = self.token_norm(hidden_states)
        # 调整张量维度以便进入token学习器
        hidden_states = hidden_states.transpose(1, 2).unsqueeze(-1)
        # 使用token学习器生成token及其注意力权重
        selected = self.tokenLearner(hidden_states)
        selected = selected.flatten(2)
        attentions = F.softmax(selected, dim=-1)

        # 使用特征提取器生成特征表示
        feat = self.feat(hidden_states)
        feat = feat.flatten(2).transpose(1, 2)
        # 使用注意力权重和特征表示计算A3模块的输出
        feat = torch.einsum("...si,...id->...sd", attentions, feat)
        a3_out = self.norm(feat)

        return (a3_out, attentions)
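
A^3 模块把长度为 257 的编码器输出压缩成 max_token_length(默认 27)个读取位置,并返回对应的注意力分布。形状变化可以用下面的示例验证(示意;MgpstrA3Module 即上面刚定义的类,这里从内部模块路径导入):

```python
import torch
from transformers import MgpstrConfig
from transformers.models.mgp_str.modeling_mgp_str import MgpstrA3Module  # 内部模块路径

config = MgpstrConfig()                      # hidden_size=768, max_token_length=27
a3_module = MgpstrA3Module(config)

hidden_states = torch.randn(2, 257, 768)     # 编码器输出:(batch, seq_len, hidden)
a3_out, attentions = a3_module(hidden_states)
print(a3_out.shape)                          # torch.Size([2, 27, 768])
print(attentions.shape)                      # torch.Size([2, 27, 257])
```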


class MgpstrPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 指定配置类和基础模型前缀
    config_class = MgpstrConfig
    base_model_prefix = "mgp_str"
    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        # 如果 module 是 MgpstrEmbeddings 类型的实例
        if isinstance(module, MgpstrEmbeddings):
            # 对 module 的位置嵌入和类别标记进行截断正态分布初始化
            nn.init.trunc_normal_(module.pos_embed, mean=0.0, std=self.config.initializer_range)
            nn.init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
        # 如果 module 是 nn.Linear 或者 nn.Conv2d 类型的实例
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            # 初始化 module 的权重数据为截断正态分布
            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
            # 如果 module 有偏置项,则将其数据置零
            if module.bias is not None:
                module.bias.data.zero_()
        # 如果 module 是 nn.LayerNorm 类型的实例
        elif isinstance(module, nn.LayerNorm):
            # 将 module 的偏置项数据置零
            module.bias.data.zero_()
            # 将 module 的权重数据填充为 1.0
            module.weight.data.fill_(1.0)
# 定义多行字符串,用于描述 MGP-STR 模型的基本信息和使用说明
MGP_STR_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`MgpstrConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# 定义多行字符串,用于描述 MGP-STR 模型前向传播方法的输入参数说明
MGP_STR_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# 使用装饰器 `add_start_docstrings`,为 MgpstrModel 类添加类级别的文档字符串,描述该类为裸模型变换器,输出未经特定头部处理的原始隐藏状态
@add_start_docstrings(
    "The bare MGP-STR Model transformer outputting raw hidden-states without any specific head on top.",
    MGP_STR_START_DOCSTRING,
)
# 定义 MgpstrModel 类,继承自 MgpstrPreTrainedModel 类
class MgpstrModel(MgpstrPreTrainedModel):
    def __init__(self, config: MgpstrConfig):
        # 调用父类构造函数初始化模型
        super().__init__(config)
        # 将配置信息存储在实例变量中
        self.config = config
        # 创建并初始化嵌入层对象
        self.embeddings = MgpstrEmbeddings(config)
        # 创建并初始化编码器对象
        self.encoder = MgpstrEncoder(config)

    # 定义方法用于获取输入嵌入层对象
    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings.proj

    # 使用装饰器 `add_start_docstrings_to_model_forward`,为 forward 方法添加文档字符串,描述其输入参数的详细用法
    @add_start_docstrings_to_model_forward(MGP_STR_INPUTS_DOCSTRING)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # 注意:这里的方法还未完整定义,继续定义在后续的代码中
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
        # 如果 output_attentions 参数为 None,则使用模型配置中的设定
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 如果 output_hidden_states 参数为 None,则使用模型配置中的设定
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果 return_dict 参数为 None,则使用模型配置中的设定
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 如果 pixel_values 为空,则抛出数值错误异常
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # 将像素值 pixel_values 输入到 embeddings 层,得到嵌入输出 embedding_output
        embedding_output = self.embeddings(pixel_values)

        # 将嵌入输出 embedding_output 输入到编码器 encoder 中进行编码
        encoder_outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 如果 return_dict 为 False,则直接返回编码器的输出 encoder_outputs
        if not return_dict:
            return encoder_outputs
        
        # 如果 return_dict 为 True,则封装编码器的输出为 BaseModelOutput 对象并返回
        return BaseModelOutput(
            last_hidden_state=encoder_outputs.last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
# 使用装饰器为类添加文档字符串,描述了该类是一个 MGP-STR 模型转换器,具有三个分类头部,用于场景文本识别 (STR)。
# 该模型在变换编码器输出的基础上添加了三个 A^3 模块和三个线性层。
class MgpstrForSceneTextRecognition(MgpstrPreTrainedModel):
    # 指定配置类为 MgpstrConfig
    config_class = MgpstrConfig
    # 主要输入名称为 "pixel_values"
    main_input_name = "pixel_values"

    def __init__(self, config: MgpstrConfig) -> None:
        # 调用父类的初始化方法
        super().__init__(config)

        # 初始化时从配置中获取标签数目
        self.num_labels = config.num_labels
        # 创建 MGP-STR 模型
        self.mgp_str = MgpstrModel(config)

        # 创建三个不同的 A^3 模块,分别用于字符级别、BPE(Byte Pair Encoding)级别和词片段级别的处理
        self.char_a3_module = MgpstrA3Module(config)
        self.bpe_a3_module = MgpstrA3Module(config)
        self.wp_a3_module = MgpstrA3Module(config)

        # 创建三个线性头部,分别用于字符级别、BPE 级别和词片段级别的分类
        self.char_head = nn.Linear(config.hidden_size, config.num_character_labels)
        self.bpe_head = nn.Linear(config.hidden_size, config.num_bpe_labels)
        self.wp_head = nn.Linear(config.hidden_size, config.num_wordpiece_labels)

    # 使用装饰器为 forward 方法添加输入文档字符串,描述输入参数的含义
    # 并替换返回值的文档字符串为 MgpstrModelOutput 类型和 MgpstrConfig 配置类的描述
    @add_start_docstrings_to_model_forward(MGP_STR_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MgpstrModelOutput, config_class=MgpstrConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_a3_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], MgpstrModelOutput]:
        r"""
        output_a3_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of a3 modules. See `a3_attentions` under returned tensors
            for more detail.

        Returns:
            This function returns either a tuple of torch.FloatTensor or an instance of MgpstrModelOutput.

        Example:

        ```
        >>> from transformers import (
        ...     MgpstrProcessor,
        ...     MgpstrForSceneTextRecognition,
        ... )
        >>> import requests
        >>> from PIL import Image

        >>> # load image from the IIIT-5k dataset
        >>> url = "https://i.postimg.cc/ZKwLg2Gw/367-14.png"
        >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

        >>> processor = MgpstrProcessor.from_pretrained("alibaba-damo/mgp-str-base")
        >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values

        >>> model = MgpstrForSceneTextRecognition.from_pretrained("alibaba-damo/mgp-str-base")

        >>> # inference
        >>> outputs = model(pixel_values)
        >>> out_strs = processor.batch_decode(outputs.logits)
        >>> out_strs["generated_text"]
        '["ticket"]'
        ```

        """
        # Initialize `output_attentions`, `output_hidden_states`, and `return_dict` from the model configuration if the caller did not provide them.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        mgp_outputs = self.mgp_str(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = mgp_outputs[0]

        # Apply attention modules to sequence_output
        char_a3_out, char_attention = self.char_a3_module(sequence_output)
        bpe_a3_out, bpe_attention = self.bpe_a3_module(sequence_output)
        wp_a3_out, wp_attention = self.wp_a3_module(sequence_output)

        # Compute logits using corresponding head modules
        char_logits = self.char_head(char_a3_out)
        bpe_logits = self.bpe_head(bpe_a3_out)
        wp_logits = self.wp_head(wp_a3_out)

        # Aggregate all attention tensors if output_a3_attentions is True
        all_a3_attentions = (char_attention, bpe_attention, wp_attention) if output_a3_attentions else None
        all_logits = (char_logits, bpe_logits, wp_logits)

        # Return either a tuple of outputs or MgpstrModelOutput based on return_dict
        if not return_dict:
            outputs = (all_logits, all_a3_attentions) + mgp_outputs[1:]
            return tuple(output for output in outputs if output is not None)
        return MgpstrModelOutput(
            logits=all_logits,
            hidden_states=mgp_outputs.hidden_states,
            attentions=mgp_outputs.attentions,
            a3_attentions=all_a3_attentions,
        )

.\models\mgp_str\processing_mgp_str.py

# coding=utf-8
import warnings  # 在 __init__ 中对已弃用的 feature_extractor 参数发出警告时用到

from transformers import AutoTokenizer
from transformers.utils import is_torch_available
from transformers.utils.generic import ExplicitEnum

from ...processing_utils import ProcessorMixin

# 检查是否安装了torch,以便条件导入
if is_torch_available():
    import torch

# 枚举不同的解码类型:字符级、BPE、WordPiece
class DecodeType(ExplicitEnum):
    CHARACTER = "char"
    BPE = "bpe"
    WORDPIECE = "wp"

# 支持的注释格式,包括字符级、BPE和WordPiece
SUPPORTED_ANNOTATION_FORMATS = (DecodeType.CHARACTER, DecodeType.BPE, DecodeType.WORDPIECE)

# MGP-STR处理器类,继承自ProcessorMixin
class MgpstrProcessor(ProcessorMixin):
    """
    构建MGP-STR处理器,将图像处理器和MGP-STR分词器封装到一个单独的处理器中。

    [`MgpstrProcessor`] 提供了`ViTImageProcessor`和`MgpstrTokenizer`的所有功能。查看[`~MgpstrProcessor.__call__`]和
    [`~MgpstrProcessor.batch_decode`]获取更多信息。

    Args:
        image_processor (`ViTImageProcessor`, *可选*):
            `ViTImageProcessor`的实例。图像处理器是必需的输入。
        tokenizer ([`MgpstrTokenizer`], *可选*):
            分词器是必需的输入。
    """

    # 类属性定义
    attributes = ["image_processor", "char_tokenizer"]
    image_processor_class = "ViTImageProcessor"
    char_tokenizer_class = "MgpstrTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        # 弃用警告:`feature_extractor`参数将在v5中移除,请使用`image_processor`
        feature_extractor = None
        if "feature_extractor" in kwargs:
            warnings.warn(
                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
                " instead.",
                FutureWarning,
            )
            feature_extractor = kwargs.pop("feature_extractor")

        # 设置图像处理器,如果没有提供则使用`feature_extractor`
        image_processor = image_processor if image_processor is not None else feature_extractor
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        
        # 检查是否提供了分词器,如果没有则引发异常
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        # 初始化MGP-STR处理器实例,设置字符级分词器
        self.char_tokenizer = tokenizer
        # 使用预训练模型创建BPE编码的分词器
        self.bpe_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        # 使用预训练模型创建WordPiece编码的分词器
        self.wp_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

        # 调用父类ProcessorMixin的构造函数,传递图像处理器和分词器
        super().__init__(image_processor, tokenizer)
    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
        """
        当以普通模式使用时,此方法将所有参数转发到 ViTImageProcessor 的 [`~ViTImageProcessor.__call__`] 并返回其输出。
        如果 `text` 不为 `None`,此方法还将 `text` 和 `kwargs` 参数转发到 MgpstrTokenizer 的 [`~MgpstrTokenizer.__call__`] 以编码文本。
        更多信息请参考上述方法的文档字符串。
        """
        if images is None and text is None:
            raise ValueError("You need to specify either an `images` or `text` input to process.")

        if images is not None:
            # 使用图像处理器处理图像输入
            inputs = self.image_processor(images, return_tensors=return_tensors, **kwargs)
        if text is not None:
            # 使用字符标记器编码文本输入
            encodings = self.char_tokenizer(text, return_tensors=return_tensors, **kwargs)

        if text is None:
            return inputs
        elif images is None:
            return encodings
        else:
            # 将标记化的文本作为标签添加到图像输入中
            inputs["labels"] = encodings["input_ids"]
            return inputs
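
综合来看,处理器既可以只处理图像、只编码文本,也可以同时处理并把文本编码结果作为 labels 附在图像输入上。下面是一个调用示意(图片路径为假设的示例):

```python
from PIL import Image
from transformers import MgpstrProcessor

processor = MgpstrProcessor.from_pretrained("alibaba-damo/mgp-str-base")

image = Image.open("scene_text.png").convert("RGB")   # 假设的本地图片路径
inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)                   # 形如 torch.Size([1, 3, 32, 128])

both = processor(images=image, text="ticket", return_tensors="pt")
print(both.keys())                                    # 同时包含 pixel_values 和 labels
```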

    def batch_decode(self, sequences):
        """
        将一组标记 id 的列表转换为字符串列表,通过调用 decode 方法实现。

        Args:
            sequences (`torch.Tensor`):
                标记化输入 id 的列表。

        Returns:
            `Dict[str, any]`: 解码结果的所有输出字典。
                generated_text (`List[str]`): 融合字符、bpe 和 wp 后的最终结果。
                scores (`List[float]`): 融合字符、bpe 和 wp 后的最终分数。
                char_preds (`List[str]`): 字符解码后的句子列表。
                bpe_preds (`List[str]`): bpe 解码后的句子列表。
                wp_preds (`List[str]`): wp 解码后的句子列表。

        此方法将其所有参数转发到 PreTrainedTokenizer 的 [`~PreTrainedTokenizer.batch_decode`]。更多信息请参考此方法的文档字符串。
        """
        char_preds, bpe_preds, wp_preds = sequences
        batch_size = char_preds.size(0)

        # 分别调用 `_decode_helper` 方法解码字符、bpe 和 wp
        char_strs, char_scores = self._decode_helper(char_preds, "char")
        bpe_strs, bpe_scores = self._decode_helper(bpe_preds, "bpe")
        wp_strs, wp_scores = self._decode_helper(wp_preds, "wp")

        final_strs = []
        final_scores = []
        for i in range(batch_size):
            scores = [char_scores[i], bpe_scores[i], wp_scores[i]]
            strs = [char_strs[i], bpe_strs[i], wp_strs[i]]
            max_score_index = scores.index(max(scores))
            final_strs.append(strs[max_score_index])
            final_scores.append(scores[max_score_index])

        out = {}
        out["generated_text"] = final_strs
        out["scores"] = final_scores
        out["char_preds"] = char_strs
        out["bpe_preds"] = bpe_strs
        out["wp_preds"] = wp_strs
        return out
    def _decode_helper(self, pred_logits, format):
        """
        Convert a list of lists of bpe token ids into a list of strings by calling bpe tokenizer.

        Args:
            pred_logits (`torch.Tensor`):
                List of model prediction logits.
            format (`Union[DecoderType, str]`):
                Type of model prediction. Must be one of ['char', 'bpe', 'wp'].
        Returns:
            `tuple`:
                dec_strs(`str`): The decode strings of model prediction.
                conf_scores(`List[float]`): The confidence score of model prediction.
        """
        # 根据不同的解码类型选择相应的解码器和结束标记
        if format == DecodeType.CHARACTER:
            decoder = self.char_decode
            eos_token = 1  # 结束标记为1
            eos_str = "[s]"  # 结束字符串为"[s]"
        elif format == DecodeType.BPE:
            decoder = self.bpe_decode
            eos_token = 2  # 结束标记为2
            eos_str = "#"  # 结束字符串为"#"
        elif format == DecodeType.WORDPIECE:
            decoder = self.wp_decode
            eos_token = 102  # 结束标记为102
            eos_str = "[SEP]"  # 结束字符串为"[SEP]"
        else:
            raise ValueError(f"Format {format} is not supported.")  # 如果格式不支持,则抛出异常

        dec_strs, conf_scores = [], []  # 初始化解码字符串列表和置信度分数列表
        batch_size = pred_logits.size(0)  # 获取批次大小
        batch_max_length = pred_logits.size(1)  # 获取每个样本的最大长度
        _, preds_index = pred_logits.topk(1, dim=-1, largest=True, sorted=True)  # 获取每个位置上概率最大的预测索引
        preds_index = preds_index.view(-1, batch_max_length)[:, 1:]  # 去除开始标记,保留有效预测部分
        preds_str = decoder(preds_index)  # 使用对应解码器对预测索引进行解码成字符串
        preds_max_prob, _ = torch.nn.functional.softmax(pred_logits, dim=2).max(dim=2)  # 获取每个位置上的最大概率及其索引
        preds_max_prob = preds_max_prob[:, 1:]  # 去除开始位置的概率

        # 遍历每个样本
        for index in range(batch_size):
            pred_eos = preds_str[index].find(eos_str)  # 查找结束字符串在预测字符串中的位置
            pred = preds_str[index][:pred_eos]  # 截取到结束字符串前的部分作为最终预测
            pred_index = preds_index[index].cpu().tolist()  # 将预测索引转换为CPU上的列表
            pred_eos_index = pred_index.index(eos_token) if eos_token in pred_index else -1  # 查找结束标记的位置
            pred_max_prob = preds_max_prob[index][: pred_eos_index + 1]  # 获取对应的最大概率
            confidence_score = pred_max_prob.cumprod(dim=0)[-1] if pred_max_prob.nelement() != 0 else 0.0  # 计算置信度分数
            dec_strs.append(pred)  # 将预测字符串添加到结果列表
            conf_scores.append(confidence_score)  # 将置信度分数添加到结果列表

        return dec_strs, conf_scores  # 返回解码字符串列表和置信度分数列表
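
置信度分数取的是截止到结束符之前,各位置最大 softmax 概率的累积乘积的最后一个值。单独演示这一步(示意):

```python
import torch

pred_max_prob = torch.tensor([0.9, 0.8, 0.95])        # 假设:结束符之前各位置的最大概率
confidence_score = pred_max_prob.cumprod(dim=0)[-1]
print(confidence_score)                                # tensor(0.6840),即 0.9 * 0.8 * 0.95
```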

    def char_decode(self, sequences):
        """
        Convert a list of lists of char token ids into a list of strings by calling char tokenizer.

        Args:
            sequences (`torch.Tensor`):
                List of tokenized input ids.
        Returns:
            `List[str]`: The list of char decoded sentences.
        """
        # 使用字符级解码器对字符级标记序列进行解码成字符串
        decode_strs = [seq.replace(" ", "") for seq in self.char_tokenizer.batch_decode(sequences)]
        return decode_strs  # 返回解码后的字符串列表

    def bpe_decode(self, sequences):
        """
        Convert a list of lists of bpe token ids into a list of strings by calling bpe tokenizer.

        Args:
            sequences (`torch.Tensor`):
                List of tokenized input ids.
        Returns:
            `List[str]`: The list of bpe decoded sentences.
        """
        return self.bpe_tokenizer.batch_decode(sequences)  # 使用BPE解码器对BPE级标记序列进行解码成字符串并返回
    def wp_decode(self, sequences):
        """
        Convert a list of lists of word piece token ids into a list of strings by calling word piece tokenizer.

        Args:
            sequences (`torch.Tensor`):
                List of tokenized input ids.
        Returns:
            `List[str]`: The list of wp decoded sentences.
        """
        # 对每个序列进行批量解码,并去除解码后字符串中的空格
        decode_strs = [seq.replace(" ", "") for seq in self.wp_tokenizer.batch_decode(sequences)]
        # 返回解码后的字符串列表
        return decode_strs

.\models\mgp_str\tokenization_mgp_str.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tokenization classes for MGT-STR CHAR.
"""

import json  # 导入json模块,用于处理JSON格式的数据
import os    # 导入os模块,提供了与操作系统交互的功能
from typing import Optional, Tuple   # 导入类型提示相关的模块

from ...tokenization_utils import PreTrainedTokenizer  # 导入预训练分词器的基类
from ...utils import logging   # 导入日志记录模块

logger = logging.get_logger(__name__)   # 获取当前模块的日志记录器对象

VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"}   # 定义词汇表文件名映射字典

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "mgp-str": "https://huggingface.co/alibaba-damo/mgp-str-base/blob/main/vocab.json",
    }
}   # 预训练词汇文件映射,指定了不同预训练模型的词汇文件路径

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"mgp-str": 27}   # 预训练位置嵌入的尺寸映射

class MgpstrTokenizer(PreTrainedTokenizer):
    """
    Construct a MGP-STR char tokenizer.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        unk_token (`str`, *optional*, defaults to `"[GO]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*, defaults to `"[GO]"`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `"[s]"`):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[GO]"`):
            A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
            attention mechanisms or loss computation.
    """

    vocab_files_names = VOCAB_FILES_NAMES   # 设置词汇文件名映射
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP   # 设置预训练词汇文件映射
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES   # 设置预训练位置嵌入的尺寸

    def __init__(self, vocab_file, unk_token="[GO]", bos_token="[GO]", eos_token="[s]", pad_token="[GO]", **kwargs):
        """
        Initialize a tokenizer instance.

        Args:
            vocab_file (`str`):
                Path to the vocabulary file.
            unk_token (`str`, *optional*, defaults to `"[GO]"`):
                The unknown token.
            bos_token (`str`, *optional*, defaults to `"[GO]"`):
                The beginning of sequence token.
            eos_token (`str`, *optional*, defaults to `"[s]"`):
                The end of sequence token.
            pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[GO]"`):
                The padding token used in batching.
            **kwargs:
                Additional keyword arguments passed to the parent class constructor.
        """
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.vocab = json.load(vocab_handle)   # 从指定路径加载词汇表文件,并转换为字典形式
        self.decoder = {v: k for k, v in self.vocab.items()}   # 创建反向词汇表,用于将ID转换为对应的词汇
        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """
        Return the size of the vocabulary.

        Returns:
            int: Number of tokens in the vocabulary.
        """
        return len(self.vocab)   # 返回词汇表中词汇的数量

    def get_vocab(self):
        """
        Get the vocabulary (including any additional tokens).

        Returns:
            dict: A dictionary containing the vocabulary tokens and their IDs.
        """
        vocab = dict(self.vocab).copy()
        vocab.update(self.added_tokens_encoder)
        return vocab   # 返回包含额外token的完整词汇表字典
    # 将文本字符串进行分词处理,返回字符级别的标记列表
    def _tokenize(self, text):
        char_tokens = []  # 初始化一个空列表,用于存储字符级别的标记
        for s in text:
            char_tokens.extend(s)  # 将每个字符作为一个标记加入到列表中
        return char_tokens  # 返回字符级别的标记列表

    # 根据词汇表将标记转换为对应的 ID
    def _convert_token_to_id(self, token):
        return self.vocab.get(token, self.vocab.get(self.unk_token))  # 返回标记对应的 ID,如果标记不存在则使用未知标记的 ID

    # 根据词汇表将 ID 转换为对应的标记
    def _convert_id_to_token(self, index):
        return self.decoder.get(index)  # 返回给定 ID 对应的标记

    # 将词汇表保存到指定的目录中
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):  # 检查保存目录是否存在,如果不存在则记录错误并返回
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return  # 返回空值,表示保存操作未成功

        # 构建词汇表文件的路径,文件名根据可选的前缀和预定义的文件名组成
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # 将词汇表以 JSON 格式写入到文件中
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        return (vocab_file,)  # 返回保存的词汇表文件路径的元组
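
这个 tokenizer 的核心逻辑非常简单:_tokenize 逐字符拆分,_convert_token_to_id 查词表、查不到则退回 unk_token。下面用一个假设的小词表把这两步串起来演示(不依赖真实的 vocab.json 文件):

vocab = {"[GO]": 0, "[s]": 1, "h": 2, "e": 3, "y": 4}   # 假设的词表内容

def tokenize(text):
    # 与上面的 _tokenize 一致:逐字符拆分
    return list(text)

def convert_tokens_to_ids(tokens, unk_token="[GO]"):
    # 与上面的 _convert_token_to_id 一致:查不到的字符退回 unk_token 的 id
    return [vocab.get(t, vocab[unk_token]) for t in tokens]

print(convert_tokens_to_ids(tokenize("hey!")))   # [2, 3, 4, 0],'!' 不在词表中,落到 [GO]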

.\models\mgp_str\__init__.py

# flake8: noqa
# 在此模块中无法只忽略 "F401 '...' imported but unused" 警告而保留其他警告,
# 因此干脆完全跳过对本模块的 lint 检查。

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

# 定义模块结构的导入方式和依赖关系
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# 模块导入结构字典,指定各模块和其对应的导入内容列表
_import_structure = {
    "configuration_mgp_str": ["MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP", "MgpstrConfig"],
    "processing_mgp_str": ["MgpstrProcessor"],
    "tokenization_mgp_str": ["MgpstrTokenizer"],
}

# 检查是否有torch可用,若不可用则抛出OptionalDependencyNotAvailable异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果torch可用,则增加额外的模块导入信息到_import_structure字典中
    _import_structure["modeling_mgp_str"] = [
        "MGP_STR_PRETRAINED_MODEL_ARCHIVE_LIST",
        "MgpstrModel",
        "MgpstrPreTrainedModel",
        "MgpstrForSceneTextRecognition",
    ]

# 如果在类型检查模式下
if TYPE_CHECKING:
    # 从对应模块中导入特定的类或变量
    from .configuration_mgp_str import MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP, MgpstrConfig
    from .processing_mgp_str import MgpstrProcessor
    from .tokenization_mgp_str import MgpstrTokenizer

    # 再次检查torch是否可用,若不可用则抛出异常
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 如果torch可用,则从模型相关模块中导入特定的类或变量
        from .modeling_mgp_str import (
            MGP_STR_PRETRAINED_MODEL_ARCHIVE_LIST,
            MgpstrForSceneTextRecognition,
            MgpstrModel,
            MgpstrPreTrainedModel,
        )
else:
    # 如果不在类型检查模式下,则将当前模块设置为懒加载模块
    import sys

    # 使用_LazyModule类封装当前模块,以实现延迟加载
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
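
这种延迟加载的效果可以用模块级 __getattr__(PEP 562)写一个示意性的草图(仅为说明思路,并非 _LazyModule 的真实实现):只有在真正访问某个名字时才导入对应的子模块。

import importlib

_lazy_targets = {
    "MgpstrConfig": "transformers.models.mgp_str.configuration_mgp_str",
    "MgpstrTokenizer": "transformers.models.mgp_str.tokenization_mgp_str",
}

def __getattr__(name):
    # 第一次访问某个名字时才真正导入对应模块,然后返回其中的属性
    if name in _lazy_targets:
        return getattr(importlib.import_module(_lazy_targets[name]), name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")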

.\models\mistral\configuration_mistral.py

# coding=utf-8
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mistral model configuration"""

from ...configuration_utils import PretrainedConfig  # 导入预训练配置基类
from ...utils import logging  # 导入日志模块


logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象

MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
    "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
}

class MistralConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
    Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.

    [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
    [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    ```
    >>> from transformers import MistralModel, MistralConfig

    >>> # Initializing a Mistral 7B style configuration
    >>> configuration = MistralConfig()

    >>> # Initializing a model from the Mistral 7B style configuration
    >>> model = MistralModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "mistral"  # 模型类型为 mistral
    keys_to_ignore_at_inference = ["past_key_values"]  # 推断阶段要忽略的键列表

    def __init__(
        self,
        vocab_size=32000,  # 词汇表大小,默认为 32000
        hidden_size=4096,  # 隐藏层大小,默认为 4096
        intermediate_size=14336,  # 中间层大小,默认为 14336
        num_hidden_layers=32,  # 隐藏层层数,默认为 32
        num_attention_heads=32,  # 注意力头数,默认为 32
        num_key_value_heads=8,  # 键值头数,默认为 8
        hidden_act="silu",  # 隐藏层激活函数,默认为 "silu"
        max_position_embeddings=4096 * 32,  # 最大位置嵌入数,默认为 4096 * 32
        initializer_range=0.02,  # 初始化范围,默认为 0.02
        rms_norm_eps=1e-6,  # RMS 归一化的 epsilon,默认为 1e-6
        use_cache=True,  # 是否使用缓存,默认为 True
        pad_token_id=None,  # 填充标记的 id,默认为 None
        bos_token_id=1,  # 起始标记的 id,默认为 1
        eos_token_id=2,  # 终止标记的 id,默认为 2
        tie_word_embeddings=False,  # 是否绑定词嵌入,默认为 False
        rope_theta=10000.0,  # ROPE 参数,默认为 10000.0
        sliding_window=4096,  # 滑动窗口大小,默认为 4096
        attention_dropout=0.0,  # 注意力层的 dropout 比率,默认为 0.0
        **kwargs,  # 其他关键字参数
    ):
        # 设置模型的词汇表大小
        self.vocab_size = vocab_size
        # 设置模型的最大位置嵌入数量
        self.max_position_embeddings = max_position_embeddings
        # 设置模型的隐藏层大小
        self.hidden_size = hidden_size
        # 设置模型的中间层大小
        self.intermediate_size = intermediate_size
        # 设置模型的隐藏层数量
        self.num_hidden_layers = num_hidden_layers
        # 设置模型的注意力头数量
        self.num_attention_heads = num_attention_heads
        # 设置模型的滑动窗口大小
        self.sliding_window = sliding_window

        # 为了向后兼容性
        # 如果未提供键值头数量,则使用注意力头数量
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # 设置模型的键值头数量
        self.num_key_value_heads = num_key_value_heads
        # 设置模型的隐藏层激活函数
        self.hidden_act = hidden_act
        # 设置模型的初始化范围
        self.initializer_range = initializer_range
        # 设置模型的RMS归一化的epsilon值
        self.rms_norm_eps = rms_norm_eps
        # 设置模型是否使用缓存
        self.use_cache = use_cache
        # 设置模型的ROPE theta值
        self.rope_theta = rope_theta
        # 设置模型的注意力dropout率
        self.attention_dropout = attention_dropout

        # 调用父类初始化方法,设置模型的特殊标记ID,并传递额外参数
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
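
结合这些默认参数可以看出 Mistral 使用的是 GQA(分组查询注意力):num_attention_heads 个查询头被分成 num_key_value_heads 组,每组共享一份 key/value。下面用一个缩小版配置演示这几个数值之间的关系(参数取值仅为演示用的假设值):

from transformers import MistralConfig

config = MistralConfig(hidden_size=512, num_hidden_layers=4, num_attention_heads=8, num_key_value_heads=2)
head_dim = config.hidden_size // config.num_attention_heads
groups = config.num_attention_heads // config.num_key_value_heads
print(head_dim, groups)   # 64 4:每 4 个查询头共享一组 key/value 头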

.\models\mistral\convert_mistral_weights_to_hf.py

# 导入必要的库和模块
import argparse  # 解析命令行参数的库
import gc  # Python 的垃圾回收模块
import json  # 处理 JSON 格式数据的库
import os  # 提供与操作系统交互的功能
import shutil  # 提供高级文件操作功能
import warnings  # 发出警告的模块

import torch  # 引入 PyTorch 深度学习库

# 从transformers库中导入所需的类和函数
from transformers import (
    LlamaTokenizer,  # LlamaTokenizer 分词器
    MistralConfig,  # Mistral模型的配置类
    MistralForCausalLM,  # 用于生成文本的Mistral模型
)

try:
    from transformers import LlamaTokenizerFast  # 尝试导入快速版LlamaTokenizer

    tokenizer_class = LlamaTokenizerFast  # 如果导入成功,使用快速版分词器
except ImportError as e:
    warnings.warn(e)  # 输出导入错误的警告
    warnings.warn(
        "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
    )
    tokenizer_class = LlamaTokenizer  # 如果导入失败,使用慢速版分词器

"""
示例用法:

python src/transformers/models/mistral/convert_mistral_weights_to_hf.py \
    --input_dir /path/to/downloaded/mistral/weights --model_size 7B --output_dir /output/path
"""

# 将不同模型大小映射到对应的分片数量
NUM_SHARDS = {"7B": 1}


def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    # 计算中间层的尺寸
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
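
这个函数把 8n/3 向上取整到 multiple_of 的倍数,可以用一个具体数字验证(下面直接复用同样的函数体,仅作演示):

def _demo_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)

print(_demo_intermediate_size(4096))   # 11008:int(8*4096/3)=10922,向上取整到 256 的倍数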


def read_json(path):
    # 读取指定路径下的JSON文件内容并返回解析后的Python对象
    with open(path, "r") as f:
        return json.load(f)


def write_json(text, path):
    # 将Python对象text写入到指定路径的JSON文件中
    with open(path, "w") as f:
        json.dump(text, f)


def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
    # 为了向后兼容,检查参数文件是否位于指定路径,若不是,则修改输入基础路径
    if not os.path.isfile(os.path.join(input_base_path, "params.json")):
        input_base_path = os.path.join(input_base_path, model_size)

    # 创建存储模型的目录和临时目录
    os.makedirs(model_path, exist_ok=True)
    tmp_model_path = os.path.join(model_path, "tmp")
    os.makedirs(tmp_model_path, exist_ok=True)

    # 读取参数文件中的参数信息
    params = read_json(os.path.join(input_base_path, "params.json"))
    num_shards = NUM_SHARDS[model_size]

    # 将参数中的滑动窗口大小转换为整数
    sliding_window = int(params["sliding_window"])
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    # 计算每个分片中的注意力头数量
    n_heads_per_shard = n_heads // num_shards
    # 从参数字典中获取维度信息
    dim = params["dim"]
    # 计算每个注意力头的维度
    dims_per_head = dim // n_heads
    # 获取参数中的 "rope_theta",默认为 10000.0
    base = params.get("rope_theta", 10000.0)
    # 计算正弦频率的倒数,用于位置编码
    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
    # 设置最大位置编码长度
    max_position_embeddings = 4096 * 8

    # 如果指定了 tokenizer_path,则初始化并保存 tokenizer
    if tokenizer_path is not None:
        tokenizer = tokenizer_class(tokenizer_path)
        tokenizer.save_pretrained(model_path)
    # 获取词汇表大小,如果未指定 tokenizer_path 则默认为 32000
    vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000

    # 如果参数中包含 "n_kv_heads",则设置键值头的数量
    if "n_kv_heads" in params:
        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
        # 计算每个本地键值头的数量
        num_local_key_value_heads = num_key_value_heads // num_shards
        # 计算键值维度
        key_value_dim = dims_per_head * num_local_key_value_heads
    else:  # 兼容其他检查点
        num_key_value_heads = n_heads
        num_local_key_value_heads = n_heads_per_shard
        key_value_dim = dim

    # 定义用于切片旋转的排列函数
    def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
        return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

    # 打印加载检查点的消息
    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
    # 加载权重
    loaded = [
        torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
        for i in range(num_shards)
    ]
    # 初始化参数计数器和索引字典
    param_count = 0
    index_dict = {"weight_map": {}}
    # 设置模型文件名
    filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
    # 构建状态字典
    state_dict = {
        "model.norm.weight": loaded[0]["norm.weight"],
        "model.embed_tokens.weight": torch.cat([loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1),
        "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
    }

    # 将状态字典的键值对保存到索引字典中,并统计参数数量
    for k, v in state_dict.items():
        index_dict["weight_map"][k] = filename
        param_count += v.numel()
    # 将状态字典保存到临时模型路径
    torch.save(state_dict, os.path.join(tmp_model_path, filename))

    # 写入配置信息到索引字典
    index_dict["metadata"] = {"total_size": param_count * 2}
    write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
    # 创建 Mistral 模型配置
    config = MistralConfig(
        hidden_size=dim,
        intermediate_size=params["hidden_dim"],
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
        num_key_value_heads=num_key_value_heads,
        vocab_size=vocab_size,
        rope_theta=base,
        max_position_embeddings=max_position_embeddings,
        sliding_window=sliding_window,
    )
    # 将模型配置保存到临时模型路径
    config.save_pretrained(tmp_model_path)

    # 释放不再需要的变量,进行内存回收
    del state_dict
    del loaded
    gc.collect()

    # 打印加载模型检查点的消息
    print("Loading the checkpoint in a Mistral model.")
    # 从预训练模型路径加载 Mistral 模型
    model = MistralForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
    # 移除模型配置中的 _name_or_path 属性,避免保存到配置中
    del model.config._name_or_path
    # 设置模型配置中的 Torch 数据类型为 float16
    model.config.torch_dtype = torch.float16
    # 打印保存模型为 Transformers 格式的消息
    print("Saving in the Transformers format.")
    # 使用安全序列化选项保存模型到指定路径
    model.save_pretrained(model_path, safe_serialization=safe_serialization)
    # 递归删除临时模型路径下的所有文件和文件夹
    shutil.rmtree(tmp_model_path)
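
write_model 中定义的 permute 负责把查询/键权重重排成 HF 旋转位置编码实现所期望的布局,它只改变行的排列而不改变数值本身。下面用一个小张量验证(头数与维度均为演示用的假设值):

import torch

def _demo_permute(w, n_heads=4, dim1=16, dim2=16):
    # 与上面的 permute 相同:按 (n_heads, dim1//n_heads//2, 2, dim2) 视图交换中间两维后再展平
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

w = torch.arange(16 * 16, dtype=torch.float32).reshape(16, 16)
permuted = _demo_permute(w)
print(permuted.shape, torch.equal(w.sum(), permuted.sum()))   # torch.Size([16, 16]) True
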
# 定义一个函数用于保存 tokenizer
def write_tokenizer(tokenizer_path, input_tokenizer_path):
    # 打印保存 tokenizer 的信息,包括 tokenizer 类型和保存路径
    print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
    # 根据输入的 tokenizer 路径初始化 tokenizer 对象
    tokenizer = tokenizer_class(input_tokenizer_path)
    # 调用预训练模型的方法保存 tokenizer 到指定路径
    tokenizer.save_pretrained(tokenizer_path)


# 定义主函数
def main():
    # 创建命令行参数解析器
    parser = argparse.ArgumentParser()
    # 添加命令行参数:输入目录,用于存放 Mistral 权重,包含 tokenizer.model 和 model 文件夹
    parser.add_argument(
        "--input_dir",
        help="Location of Mistral weights, which contains tokenizer.model and model folders",
    )
    # 添加命令行参数:模型大小,可以选择 "7B" 或 "tokenizer_only"
    parser.add_argument(
        "--model_size",
        choices=["7B", "tokenizer_only"],
        help="'f' models correspond to the finetuned versions, and are specific to the Mistral2 official release. For more details on Mistral2, checkout the original repo: https://huggingface.co/meta-mistral",
    )
    # 添加命令行参数:输出目录,用于存放 HF 模型和 tokenizer
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model and tokenizer",
    )
    # 添加命令行参数:安全序列化选项,是否使用 `safetensors` 进行保存
    parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
    # 解析命令行参数
    args = parser.parse_args()
    # 构建 tokenizer 的路径,拼接输入目录和 tokenizer 文件名
    spm_path = os.path.join(args.input_dir, "tokenizer.model")
    # 如果模型大小不是 "tokenizer_only",则调用 write_model 函数
    if args.model_size != "tokenizer_only":
        write_model(
            model_path=args.output_dir,
            input_base_path=args.input_dir,
            model_size=args.model_size,
            safe_serialization=args.safe_serialization,
            tokenizer_path=spm_path,
        )
    else:
        # 否则,调用 write_tokenizer 函数保存 tokenizer
        write_tokenizer(args.output_dir, spm_path)


# 如果当前脚本作为主程序运行,则执行主函数 main()
if __name__ == "__main__":
    main()

.\models\mistral\modeling_flax_mistral.py

# Import necessary modules and functions from libraries
from typing import Optional, Tuple  # 导入类型提示相关的模块和函数

import flax.linen as nn  # 导入 Flax 中的 Linen 模块并重命名为 nn
import jax  # 导入 JAX 库
import jax.numpy as jnp  # 导入 JAX 中的 numpy 模块并重命名为 jnp
import numpy as np  # 导入 numpy 库并重命名为 np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # 从 Flax 中导入相关函数
from flax.linen import combine_masks, make_causal_mask  # 从 Flax 的 Linen 模块导入函数
from flax.linen.attention import dot_product_attention_weights  # 从 Flax 的 attention 模块导入函数
from flax.traverse_util import flatten_dict, unflatten_dict  # 从 Flax 的 traverse_util 模块导入函数
from jax import lax  # 导入 JAX 的 lax 模块

# Import specific outputs and utilities from related modules
from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPast,
    FlaxCausalLMOutput,
    FlaxCausalLMOutputWithCrossAttentions,
)
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, logging
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
from .configuration_mistral import MistralConfig  # 导入 MistralConfig 配置类

# Get the logger for this module
logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象

# Constants for documentation purposes
_CONFIG_FOR_DOC = "MistralConfig"  # 用于文档的配置示例名称
_REAL_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"  # 实际检查点的文档示例
_CHECKPOINT_FOR_DOC = "ksmcg/Mistral-tiny"  # 检查点的文档示例

# Start of the model documentation string
MISTRAL_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
    # 参数说明:
    # config ([`MistralConfig`]): 模型配置类,包含模型的所有参数。
    #                          使用配置文件初始化时不会加载模型的权重,只加载配置信息。
    #                          查看 [`~FlaxPreTrainedModel.from_pretrained`] 方法以加载模型权重。
    # dtype (`jax.numpy.dtype`, *optional*, 默认为 `jax.numpy.float32`):
    #      计算使用的数据类型。可以是 `jax.numpy.float32`, `jax.numpy.float16` 或 `jax.numpy.bfloat16` 中的一种。
    #      可用于在 GPU 或 TPU 上启用混合精度训练或半精度推断。如果指定,则所有计算将使用给定的 `dtype` 进行。
    #
    #      **注意,这仅指定计算时的数据类型,不影响模型参数的数据类型。**
    #
    #      如果要更改模型参数的数据类型,请参阅 [`~FlaxPreTrainedModel.to_fp16`] 和 [`~FlaxPreTrainedModel.to_bf16`]。
# 定义了一个文档字符串常量,描述了 `FlaxMistralRMSNorm` 类的输入参数和用法
MISTRAL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
            输入序列标记在词汇表中的索引。默认情况下,提供的填充将被忽略。
            可以使用 [`AutoTokenizer`] 获取索引。详见 [`PreTrainedTokenizer.encode`] 和 [`PreTrainedTokenizer.__call__`]。

            [什么是输入 ID?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            避免在填充标记索引上执行注意力操作的掩码。掩码值在 `[0, 1]` 范围内:

            - 对于 **未被掩码** 的标记,值为 1,
            - 对于 **被掩码** 的标记,值为 0。

            可以使用 [`AutoTokenizer`] 获取索引。详见 [`PreTrainedTokenizer.encode`] 和 [`PreTrainedTokenizer.__call__`]。

            如果使用了 `past_key_values`,可以选择仅输入最后的 `decoder_input_ids`(参见 `past_key_values`)。

            如果要更改填充行为,应阅读 [`modeling_opt._prepare_decoder_attention_mask`] 并根据需求进行修改。详见 [该论文中的图表 1](https://arxiv.org/abs/1910.13461) 获取有关默认策略的更多信息。

            - 1 表示头部 **未被掩码**,
            - 0 表示头部 **被掩码**。
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            每个输入序列标记在位置嵌入中的位置索引。选择范围为 `[0, config.n_positions - 1]`。

            [什么是位置 ID?](../glossary#position-ids)
        past_key_values (`Dict[str, np.ndarray]`, *optional*, 由 `init_cache` 返回或传递先前的 `past_key_values`):
            预计算隐藏状态的字典(键和值在注意力块中)。可用于快速自回归解码。预计算的键和值隐藏状态的形状为 *[batch_size, max_length]*。
        output_attentions (`bool`, *optional*):
            是否返回所有注意力层的注意力张量。详见返回的张量中的 `attentions` 获取更多细节。
        output_hidden_states (`bool`, *optional*):
            是否返回所有层的隐藏状态。详见返回的张量中的 `hidden_states` 获取更多细节。
        return_dict (`bool`, *optional*):
            是否返回 [`~utils.ModelOutput`] 而不是普通的元组。
"""

# 从 `transformers.models.llama.modeling_flax_llama.FlaxLlamaRMSNorm` 复制并修改为 `FlaxMistralRMSNorm`
class FlaxMistralRMSNorm(nn.Module):
    # 类型注解,指定了 `config` 属性的类型为 `MistralConfig`
    config: MistralConfig
    # 默认数据类型为 `jnp.float32`
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # 初始化对象的epsilon属性为配置中的rms_norm_eps值
        self.epsilon = self.config.rms_norm_eps

        # 初始化对象的weight属性,使用param方法生成,传入的lambda函数生成一个形状为hidden_size的全1数组
        self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size)

    # 定义对象的调用方法,接收hidden_states作为参数
    def __call__(self, hidden_states):
        # 将hidden_states转换为JAX支持的float32类型的数组variance
        variance = jnp.asarray(hidden_states, dtype=jnp.float32)
        
        # 对variance中的每个元素求平方
        variance = jnp.power(variance, 2)
        
        # 对variance在最后一个维度上求平均值,并保持维度为1
        variance = variance.mean(-1, keepdims=True)
        
        # 使用JAX的sqrt函数对variance加上epsilon后开方,作为对hidden_states的归一化系数
        # 注意:使用jax.numpy.sqrt代替jax.lax.rsqrt是因为两者的行为不同于torch.rsqrt
        hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon)

        # 返回归一化后的hidden_states乘以对象的weight属性
        return self.weight * jnp.asarray(hidden_states, dtype=self.dtype)
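
RMSNorm 与 LayerNorm 的区别在于只用均方根做缩放、不减均值。下面用几行 numpy 复现上面 __call__ 的计算(输入为假设的随机数据,epsilon 取 1e-6):

import numpy as np

x = np.random.randn(2, 3, 8).astype(np.float32)            # (batch, seq, hidden),假设数据
weight = np.ones(8, dtype=np.float32)                       # 可学习的缩放参数,初始为全 1
variance = (x.astype(np.float32) ** 2).mean(-1, keepdims=True)
normed = weight * (x / np.sqrt(variance + 1e-6))            # 与上面 __call__ 的计算一致
print(normed.shape)                                         # (2, 3, 8)
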
# 从 transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding 复制代码,将 Llama 替换为 Mistral
class FlaxMistralRotaryEmbedding(nn.Module):
    # 使用 MistralConfig 配置信息
    config: MistralConfig
    # 数据类型默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # 计算每个注意力头的维度
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        # 创建正弦和余弦位置编码
        self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim)

    def __call__(self, key, query, position_ids):
        # 根据位置编码获取对应的正弦和余弦值
        sincos = self.sincos[position_ids]
        sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1)

        # 应用旋转位置编码到键和查询张量
        key = apply_rotary_pos_emb(key, sin_pos, cos_pos)
        query = apply_rotary_pos_emb(query, sin_pos, cos_pos)

        # 转换为指定数据类型
        key = jnp.asarray(key, dtype=self.dtype)
        query = jnp.asarray(query, dtype=self.dtype)

        # 返回处理后的键和查询张量
        return key, query


# 从 transformers.models.llama.modeling_flax_llama.FlaxLlamaMLP 复制代码,将 Llama 替换为 Mistral
class FlaxMistralMLP(nn.Module):
    # 使用 MistralConfig 配置信息
    config: MistralConfig
    # 数据类型默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # 获取嵌入维度和内部维度
        embed_dim = self.config.hidden_size
        inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim

        # 初始化内核,并设置激活函数
        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
        self.act = ACT2FN[self.config.hidden_act]

        # 定义门控投影、下游投影和上游投影
        self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
        self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
        self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)

    def __call__(self, hidden_states):
        # 上游投影处理隐藏状态
        up_proj_states = self.up_proj(hidden_states)
        # 使用激活函数处理门控投影的隐藏状态
        gate_states = self.act(self.gate_proj(hidden_states))

        # 应用门控和上游投影到下游投影的隐藏状态
        hidden_states = self.down_proj(up_proj_states * gate_states)
        # 返回处理后的隐藏状态
        return hidden_states
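
这个 MLP 是 SwiGLU 式的门控结构:up_proj(x) 与激活后的 gate_proj(x) 逐元素相乘,再经 down_proj 投回原维度。下面用 numpy 给出一个形状层面的草图(权重为假设的随机矩阵):

import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

hidden, inner = 8, 32
x = np.random.randn(2, 5, hidden)                            # (batch, seq, hidden)
w_gate, w_up = np.random.randn(hidden, inner), np.random.randn(hidden, inner)
w_down = np.random.randn(inner, hidden)

out = ((x @ w_up) * silu(x @ w_gate)) @ w_down               # down_proj(up_proj(x) * act(gate_proj(x)))
print(out.shape)                                             # (2, 5, 8)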


# 从 transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb 复制代码
def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
    # 应用旋转位置编码到张量
    return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)


# 从 transformers.models.llama.modeling_flax_llama.create_sinusoidal_positions 复制代码
def create_sinusoidal_positions(num_pos, dim):
    # 计算逆频率
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")

    # 创建正弦和余弦位置编码
    emb = np.concatenate((freqs, freqs), axis=-1)
    out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
    return jnp.array(out[:, :, :num_pos])


# 从 transformers.models.llama.modeling_flax_llama.rotate_half 复制代码
def rotate_half(tensor):
    """旋转输入张量的一半隐藏维度。"""
    rotate_half_tensor = jnp.concatenate(
        (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
    )
    return rotate_half_tensor
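
把这几个函数串起来就是旋转位置编码(RoPE)的完整流程:按位置算出 sin/cos,再用"乘 cos 加上 rotate_half 乘 sin"旋转查询/键向量。下面是一个小维度的 numpy 演示(head_dim=4、序列长度 6,数据为假设值):

import numpy as np

def _rotate_half(x):
    return np.concatenate((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), axis=-1)

head_dim, seq_len = 4, 6
inv_freq = 1.0 / (10000 ** (np.arange(0, head_dim, 2) / head_dim))   # (head_dim/2,)
freqs = np.einsum("i,j->ij", np.arange(seq_len), inv_freq)           # (seq_len, head_dim/2)
emb = np.concatenate((freqs, freqs), axis=-1)                        # (seq_len, head_dim)
sin_pos, cos_pos = np.sin(emb), np.cos(emb)

q = np.random.randn(seq_len, head_dim)                               # 假设的单头查询向量
q_rot = q * cos_pos + _rotate_half(q) * sin_pos                      # apply_rotary_pos_emb 的核心计算
print(q_rot.shape)                                                   # (6, 4)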


# 定义 FlaxMistralAttention 类,用于注意力机制,未完整复制
class FlaxMistralAttention(nn.Module):
    # 使用 MistralConfig 配置信息
    config: MistralConfig
    # 数据类型默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32
    def setup(self):
        # 从配置中获取参数
        config = self.config
        # 设置隐藏层大小
        self.hidden_size = config.hidden_size
        # 设置注意力头数
        self.num_heads = config.num_attention_heads
        # 计算每个注意力头的维度
        self.head_dim = self.hidden_size // self.num_heads
        # 设置键值头数
        self.num_key_value_heads = config.num_key_value_heads
        # 计算每个键值组的头数
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # 设置最大位置嵌入数
        self.max_position_embeddings = config.max_position_embeddings
        # 判断是否需要在注意力softmax计算中使用fp32精度
        self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
        # 设置rope_theta
        self.rope_theta = config.rope_theta
        
        # 检查隐藏层大小是否可以被注意力头数整除
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        
        # 初始化查询、键、值和输出的线性投影层
        self.q_proj = nn.Dense(self.num_heads * self.head_dim, use_bias=False, dtype=self.dtype)
        self.k_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype)
        self.v_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype)
        self.o_proj = nn.Dense(self.hidden_size, use_bias=False, dtype=self.dtype)
        
        # 创建自回归遮罩
        casual_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
        # 根据滑动窗口大小生成自回归遮罩
        self.causal_mask = jnp.triu(casual_mask, k=-config.sliding_window)
        
        # 初始化旋转嵌入
        self.rotary_emb = FlaxMistralRotaryEmbedding(config, dtype=self.dtype)
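
setup 里的遮罩构造值得单独演示一下:make_causal_mask 先生成下三角的因果遮罩,jnp.triu(..., k=-sliding_window) 再把超出滑动窗口的历史位置遮掉。下面用很小的长度看一下效果(仅作示意):

import jax.numpy as jnp
from flax.linen import make_causal_mask

seq_len, sliding_window = 6, 2
causal = make_causal_mask(jnp.ones((1, seq_len), dtype="bool"), dtype="bool")   # (1, 1, 6, 6) 下三角
windowed = jnp.triu(causal, k=-sliding_window)        # 只保留对角线向下 sliding_window 个位置以内的 1
print(windowed[0, 0].astype(jnp.int32))               # 每一行最多有 sliding_window + 1 个 1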

    def _split_heads(self, hidden_states, num_heads):
        # 将隐藏状态分割成多个头
        return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        # 合并多个头的隐藏状态
        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))

    @nn.compact
    # 从transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache复制而来
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # 检测是否初始化缓存数据
        is_initialized = self.has_variable("cache", "cached_key")
        # 获取或者初始化缓存的 key 和 value,若不存在则创建零张量
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        # 获取或者初始化缓存的索引,若不存在则设置为 0
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # 使用新的 1D 空间切片更新 key 和 value 缓存
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            # 更新缓存中的 key 和 value
            cached_key.value = key
            cached_value.value = value
            # 更新缓存索引,增加已更新的缓存向量数目
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # 用于缓存的自注意力掩码:我们的单个查询位置应仅关注已生成和缓存的 key 位置,而不是剩余的零元素。
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            # 组合现有的掩码和给定的注意力掩码
            attention_mask = combine_masks(pad_mask, attention_mask)
        
        # 返回更新后的 key, value 和注意力掩码
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # 使用 self.q_proj 对隐藏状态进行投影得到查询状态
        query_states = self.q_proj(hidden_states)
        # 使用 self.k_proj 对隐藏状态进行投影得到键状态
        key_states = self.k_proj(hidden_states)
        # 使用 self.v_proj 对隐藏状态进行投影得到值状态
        value_states = self.v_proj(hidden_states)

        # 将查询状态按照头数进行分割
        query_states = self._split_heads(query_states, self.num_heads)
        # 将键状态按照键值头数进行分割
        key_states = self._split_heads(key_states, self.num_key_value_heads)
        # 将值状态按照键值头数进行分割
        value_states = self._split_heads(value_states, self.num_key_value_heads)

        # 使用 rotary_emb 方法对键状态和查询状态进行旋转嵌入
        key_states, query_states = self.rotary_emb(key_states, query_states, position_ids)

        # 获取查询和键的长度
        query_length, key_length = query_states.shape[1], key_states.shape[1]

        # 根据是否有缓存的键来确定掩码的偏移量和最大解码长度
        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            # 创建动态切片的因果掩码
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
            )
        else:
            # 使用预先计算好的因果掩码
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        # 获取批次大小
        batch_size = hidden_states.shape[0]
        # 将因果掩码广播到与注意力掩码相同的形状
        causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
        # 将注意力掩码扩展到与因果掩码相同的形状
        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        # 结合注意力掩码和因果掩码
        attention_mask = combine_masks(attention_mask, causal_mask)

        # 如果有缓存的键或者需要初始化缓存,则将键状态、值状态和注意力掩码拼接到缓存中
        if self.has_variable("cache", "cached_key") or init_cache:
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )
        
        # 将键状态在键值组之间重复以支持并行处理
        key_states = jnp.repeat(key_states, self.num_key_value_groups, axis=2)
        # 将值状态在键值组之间重复以支持并行处理
        value_states = jnp.repeat(value_states, self.num_key_value_groups, axis=2)

        # 创建注意力偏置,根据注意力掩码设置有效和无效区域的偏置值
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
        )

        # 常规的点积注意力计算
        attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            deterministic=deterministic,
            dropout_rate=self.config.attention_dropout,
            dtype=attention_dtype,
        )

        # 如果需要在 float32 中执行 softmax,将注意力权重转换为目标 dtype
        if self.attention_softmax_in_fp32:
            attn_weights = attn_weights.astype(self.dtype)

        # 使用 einsum 执行注意力加权求和操作
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        # 合并多头的结果
        attn_output = self._merge_heads(attn_output)
        # 对输出应用输出投影
        attn_output = self.o_proj(attn_output)

        # 准备输出,包括注意力权重(如果需要)
        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
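
_concatenate_to_cache 的关键一步是 lax.dynamic_update_slice:把当前解码步新算出的 key/value 写进预先分配好的固定长度缓存中 cache_index 对应的位置。下面是去掉 Flax 变量机制后的最小示意(形状均为假设值):

import jax.numpy as jnp
from jax import lax

batch, max_len, heads, head_dim = 1, 8, 2, 4
cached_key = jnp.zeros((batch, max_len, heads, head_dim))    # 预分配的固定长度缓存
new_key = jnp.ones((batch, 1, heads, head_dim))              # 当前解码步新算出的 key
cache_index = 3                                              # 已缓存的 token 数

cached_key = lax.dynamic_update_slice(cached_key, new_key, (0, cache_index, 0, 0))
print(cached_key[0, :, 0, 0])                                # 只有第 cache_index 个位置被写入新值
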
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaDecoderLayer with Llama->Mistral
class FlaxMistralDecoderLayer(nn.Module):
    config: MistralConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # 初始化输入层的 Layer Normalization
        self.input_layernorm = FlaxMistralRMSNorm(self.config, dtype=self.dtype)
        # 初始化自注意力机制
        self.self_attn = FlaxMistralAttention(self.config, dtype=self.dtype)
        # 初始化自注意力后的 Layer Normalization
        self.post_attention_layernorm = FlaxMistralRMSNorm(self.config, dtype=self.dtype)
        # 初始化多层感知机 MLP
        self.mlp = FlaxMistralMLP(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        # 残差连接
        residual = hidden_states
        # 应用输入层的 Layer Normalization
        hidden_states = self.input_layernorm(hidden_states)
        # 应用自注意力机制
        outputs = self.self_attn(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
        )
        # 残差连接
        attn_output = outputs[0]
        hidden_states = residual + attn_output

        # 残差连接
        residual = hidden_states
        # 应用自注意力后的 Layer Normalization
        hidden_states = self.post_attention_layernorm(hidden_states)
        # 应用多层感知机 MLP
        hidden_states = self.mlp(hidden_states)
        # 残差连接
        hidden_states = residual + hidden_states

        return (hidden_states,) + outputs[1:]


# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Mistral, GPT_NEO->MISTRAL, transformer->model
class FlaxMistralPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MistralConfig
    base_model_prefix = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: MistralConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # 初始化模块对象
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # 调用父类初始化方法
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 初始化输入张量
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # 创建与input_ids形状相同的全1张量作为注意力掩码
        attention_mask = jnp.ones_like(input_ids)
        # 根据input_ids的形状广播生成位置编码张量
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        # 拆分随机数生成器rng,生成参数随机数和dropout随机数
        params_rng, dropout_rng = jax.random.split(rng)
        # 存储随机数生成器
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # 使用self.module的初始化方法初始化模型参数,返回未解冻的参数字典
        random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]

        # 如果传入了预训练的参数params,则与随机初始化的参数进行合并
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            # 返回合并后的冻结参数字典
            return freeze(unflatten_dict(params))
        else:
            # 否则返回随机初始化的参数字典
            return random_params

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                用于快速自回归解码的批处理大小,定义了初始化缓存的批处理大小。
            max_length (`int`):
                自回归解码的最大可能长度,定义了初始化缓存的序列长度。
        """
        # 初始化用于检索缓存的输入变量
        input_ids = jnp.ones((batch_size, max_length))
        # 创建与input_ids形状相同的全1张量作为注意力掩码
        attention_mask = jnp.ones_like(input_ids)
        # 根据input_ids的形状广播生成位置编码张量
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # 使用self.module的初始化方法初始化模型变量,设置init_cache=True以初始化缓存
        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        # 返回未解冻的缓存字典
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ):
            # 如果没有显式传入 output_attentions 参数,则使用配置中的设定
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            # 如果没有显式传入 output_hidden_states 参数,则使用配置中的设定
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            # 如果没有显式传入 return_dict 参数,则使用配置中的设定
            return_dict = return_dict if return_dict is not None else self.config.return_dict

            # 获取输入张量的批量大小和序列长度
            batch_size, sequence_length = input_ids.shape

            # 如果未传入 position_ids,则根据序列长度和批量大小广播生成位置 ID
            if position_ids is None:
                if past_key_values is not None:
                    raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

                position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

            # 如果未传入 attention_mask,则创建全为 1 的注意力遮罩
            if attention_mask is None:
                attention_mask = jnp.ones((batch_size, sequence_length))

            # 处理任何需要的伪随机数生成器
            rngs = {}
            if dropout_rng is not None:
                rngs["dropout"] = dropout_rng

            inputs = {"params": params or self.params}

            # 如果传入了 past_key_values,则将其作为 cache 输入到模块中,确保 cache 是可变的
            if past_key_values:
                inputs["cache"] = past_key_values
                mutable = ["cache"]
            else:
                mutable = False

            # 调用模块的 apply 方法进行前向传播
            outputs = self.module.apply(
                inputs,
                jnp.array(input_ids, dtype="i4"),
                jnp.array(attention_mask, dtype="i4"),
                jnp.array(position_ids, dtype="i4"),
                not train,
                False,
                output_attentions,
                output_hidden_states,
                return_dict,
                rngs=rngs,
                mutable=mutable,
            )

            # 如果传入了 past_key_values 并且设置了 return_dict,则将更新后的 cache 添加到模型输出中
            if past_key_values is not None and return_dict:
                outputs, past_key_values = outputs
                outputs["past_key_values"] = unfreeze(past_key_values["cache"])
                return outputs
            # 如果传入了 past_key_values 但未设置 return_dict,则更新 cache 并将其添加到模型输出中
            elif past_key_values is not None and not return_dict:
                outputs, past_key_values = outputs
                outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

            # 返回模型输出
            return outputs
# 从transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection复制而来,将Llama改为Mistral
class FlaxMistralLayerCollection(nn.Module):
    # MistralConfig的实例变量config,dtype默认为jnp.float32
    config: MistralConfig
    dtype: jnp.dtype = jnp.float32

    # 模块初始化方法
    def setup(self):
        # 创建self.config.num_hidden_layers个FlaxMistralDecoderLayer对象列表
        self.blocks = [
            FlaxMistralDecoderLayer(self.config, dtype=self.dtype, name=str(i))
            for i in range(self.config.num_hidden_layers)
        ]

    # 模块调用方法
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
    ):
        # 如果输出attentions,则初始化空元组all_attentions;否则为None
        all_attentions = () if output_attentions else None
        # 如果输出hidden states,则初始化空元组all_hidden_states;否则为None
        all_hidden_states = () if output_hidden_states else None

        # 遍历self.blocks中的每个FlaxMistralDecoderLayer对象
        for block in self.blocks:
            # 如果需要输出hidden states,则将当前hidden_states添加到all_hidden_states元组中
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # 调用block对象进行前向传播,获取layer_outputs
            layer_outputs = block(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
            )
            # 更新hidden_states为block的输出的第一个元素
            hidden_states = layer_outputs[0]

            # 如果需要输出attentions,则将当前层的attentions添加到all_attentions元组中
            if output_attentions:
                all_attentions += (layer_outputs[1],)

        # 输出包含可能为None值的元组outputs,FlaxMistralModule将会过滤掉这些None值
        outputs = (hidden_states, all_hidden_states, all_attentions)

        # 返回outputs作为模块的输出结果
        return outputs


# 从transformers.models.llama.modeling_flax_llama.FlaxLlamaModule复制而来,将Llama改为Mistral
class FlaxMistralModule(nn.Module):
    # MistralConfig的实例变量config,dtype默认为jnp.float32
    config: MistralConfig
    dtype: jnp.dtype = jnp.float32

    # 模块初始化方法
    def setup(self):
        # 设置self.hidden_size为self.config.hidden_size
        self.hidden_size = self.config.hidden_size
        # 使用正态分布初始化embed_tokens的embedding参数
        embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range)
        # 创建nn.Embed对象embed_tokens,用于token的embedding
        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            self.hidden_size,
            embedding_init=embedding_init,
            dtype=self.dtype,
        )
        # 创建FlaxMistralLayerCollection对象self.layers,用于处理层间关系
        self.layers = FlaxMistralLayerCollection(self.config, dtype=self.dtype)
        # 创建FlaxMistralRMSNorm对象self.norm,用于层间正则化
        self.norm = FlaxMistralRMSNorm(self.config, dtype=self.dtype)

    # 模块调用方法
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 返回字典形式的输出结果
        # 输入参数input_ids、attention_mask、position_ids以及其他标志位
    ):
        # 将输入的 token IDs 转换为嵌入表示,数据类型为整数
        input_embeds = self.embed_tokens(input_ids.astype("i4"))

        # 使用 Transformer 层处理输入数据
        outputs = self.layers(
            input_embeds,
            position_ids=position_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取模型输出的隐藏状态
        hidden_states = outputs[0]

        # 对隐藏状态进行规范化处理
        hidden_states = self.norm(hidden_states)

        # 如果需要输出所有隐藏状态,则将当前隐藏状态加入所有隐藏状态列表
        if output_hidden_states:
            all_hidden_states = outputs[1] + (hidden_states,)
            outputs = (hidden_states, all_hidden_states) + outputs[2:]
        else:
            outputs = (hidden_states,) + outputs[1:]

        # 如果不需要返回字典形式的输出,则去除所有值为 None 的项并返回元组
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        # 返回 FlaxBaseModelOutput 对象,包含最后的隐藏状态、所有隐藏状态和注意力值
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )
# 添加文档字符串到 FlaxMistralModel 类,说明其作用是提供裸的 Mistral 模型变换器输出,没有特定的输出头部。
@add_start_docstrings(
    "The bare Mistral Model transformer outputting raw hidden-states without any specific head on top.",
    MISTRAL_START_DOCSTRING,
)
class FlaxMistralModel(FlaxMistralPreTrainedModel):
    # 设置模块类为 FlaxMistralModule
    module_class = FlaxMistralModule


# 向 FlaxMistralModel 类添加调用示例文档字符串,用于样例的调用说明
append_call_sample_docstring(
    FlaxMistralModel,
    _CHECKPOINT_FOR_DOC,
    FlaxBaseModelOutputWithPast,
    _CONFIG_FOR_DOC,
    real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)


# 从 transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule 复制代码,并将 Llama 更改为 Mistral
class FlaxMistralForCausalLMModule(nn.Module):
    config: MistralConfig  # 定义配置为 MistralConfig 类型
    dtype: jnp.dtype = jnp.float32  # 设置数据类型为 jnp.float32,默认为 float32

    def setup(self):
        # 使用配置和数据类型创建 FlaxMistralModule 模型
        self.model = FlaxMistralModule(self.config, dtype=self.dtype)
        # 创建 LM 头部,是一个全连接层,用于语言建模任务
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 调用模型进行前向传播
        outputs = self.model(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从模型输出中提取隐藏状态
        hidden_states = outputs[0]
        # 计算语言建模的 logits
        lm_logits = self.lm_head(hidden_states)

        # 如果不返回字典,则返回一个元组,包含 lm_logits 和其他输出
        if not return_dict:
            return (lm_logits,) + outputs[1:]

        # 返回 FlaxCausalLMOutput 对象,包含 logits、隐藏状态和注意力信息
        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)


# 添加文档字符串到 FlaxMistralForCausalLM 类,说明其作用是在 Mistral 模型变换器上方增加语言建模头部(线性层)
@add_start_docstrings(
    """
    The Mistral Model transformer with a language modeling head (linear layer) on top.
    """,
    MISTRAL_START_DOCSTRING,
)
# 从 transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM 复制代码,并将 GPTJ 更改为 Mistral
class FlaxMistralForCausalLM(FlaxMistralPreTrainedModel):
    module_class = FlaxMistralForCausalLMModule
    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # initializing the cache
        # 获取输入的批量大小和序列长度
        batch_size, seq_length = input_ids.shape

        # 使用初始化方法初始化过去的键值对
        past_key_values = self.init_cache(batch_size, max_length)

        # 由于 Mistral 使用因果遮罩,超出 input_ids.shape[-1] 且小于 cache_length 的位置本来就会被遮罩掉,
        # 所以这里可以直接创建一个静态的注意力遮罩,这样对编译更友好、效率更高
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            # 根据给定的注意力遮罩计算位置ID
            position_ids = attention_mask.cumsum(axis=-1) - 1
            # 动态更新静态的注意力遮罩,将attention_mask的值复制进去
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            # 如果没有给定注意力遮罩,则使用默认的位置ID
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        # 返回准备好的输入字典
        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # 更新生成过程中的输入参数
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        # 更新位置ID,将当前位置向后移动一步
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
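
prepare_inputs_for_generation 用 attention_mask.cumsum(-1) - 1 推位置 id,这对左侧填充特别有用:填充位置得到负值,真实 token 从 0 开始连续编号;同时把遮罩扩展成 max_length 的静态形状以利于编译。小例子(遮罩为假设数据):

import jax.numpy as jnp
from jax import lax

attention_mask = jnp.array([[0, 0, 1, 1, 1]])                # 左侧两个位置是 padding
position_ids = attention_mask.cumsum(axis=-1) - 1
print(position_ids)                                          # [[-1 -1  0  1  2]]

max_length = 8
extended = jnp.ones((1, max_length), dtype="i4")
extended = lax.dynamic_update_slice(extended, attention_mask.astype("i4"), (0, 0))
print(extended)                                              # [[0 0 1 1 1 1 1 1]]
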
# 调用函数 `append_call_sample_docstring`,用于向指定类添加示例文档字符串。
# 第一个参数 `FlaxMistralForCausalLM`:目标类,将在其上添加示例文档字符串。
# 第二个参数 `_CHECKPOINT_FOR_DOC`:用作示例文档字符串中的检查点的常量或路径。
# 第三个参数 `FlaxCausalLMOutputWithCrossAttentions`:示例文档字符串中的输出类。
# 第四个参数 `_CONFIG_FOR_DOC`:用作示例文档字符串中的配置的常量或路径。
# 关键字参数 `real_checkpoint=_REAL_CHECKPOINT_FOR_DOC`:用于指定示例文档字符串中真实检查点的常量或路径。
append_call_sample_docstring(
    FlaxMistralForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutputWithCrossAttentions,
    _CONFIG_FOR_DOC,
    real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)