Transformers Source Code Walkthrough (Part 16)

.\models\bert\__init__.py

# Import the TYPE_CHECKING constant from the typing module
from typing import TYPE_CHECKING

# Import the required helpers, availability checks, and exception class from ...utils
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_tensorflow_text_available,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
)

# _import_structure maps each submodule to the list of names it exports
_import_structure = {
    "configuration_bert": ["BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BertConfig", "BertOnnxConfig"],
    "tokenization_bert": ["BasicTokenizer", "BertTokenizer", "WordpieceTokenizer"],
}

# Check whether the tokenizers library is installed; if not, OptionalDependencyNotAvailable is raised
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If tokenizers is available, register the tokenization_bert_fast module in _import_structure
    _import_structure["tokenization_bert_fast"] = ["BertTokenizerFast"]

# Check whether torch is installed; if not, OptionalDependencyNotAvailable is raised
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If torch is available, register the modeling_bert module in _import_structure
    _import_structure["modeling_bert"] = [
        "BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "BertForMaskedLM",
        "BertForMultipleChoice",
        "BertForNextSentencePrediction",
        "BertForPreTraining",
        "BertForQuestionAnswering",
        "BertForSequenceClassification",
        "BertForTokenClassification",
        "BertLayer",
        "BertLMHeadModel",
        "BertModel",
        "BertPreTrainedModel",
        "load_tf_weights_in_bert",
    ]

# Check whether TensorFlow is installed; if not, OptionalDependencyNotAvailable is raised
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If TensorFlow is available, register the modeling_tf_bert module in _import_structure
    _import_structure["modeling_tf_bert"] = [
        "TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFBertEmbeddings",
        "TFBertForMaskedLM",
        "TFBertForMultipleChoice",
        "TFBertForNextSentencePrediction",
        "TFBertForPreTraining",
        "TFBertForQuestionAnswering",
        "TFBertForSequenceClassification",
        "TFBertForTokenClassification",
        "TFBertLMHeadModel",
        "TFBertMainLayer",
        "TFBertModel",
        "TFBertPreTrainedModel",
    ]

# Check whether TensorFlow Text is installed; if not, OptionalDependencyNotAvailable is raised
try:
    if not is_tensorflow_text_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If TensorFlow Text is available, register the tokenization_bert_tf module in _import_structure
    _import_structure["tokenization_bert_tf"] = ["TFBertTokenizer"]

# Check whether Flax is installed; if not, OptionalDependencyNotAvailable is raised
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If Flax is available, register the modeling_flax_bert module in _import_structure
    _import_structure["modeling_flax_bert"] = [
        "FlaxBertForCausalLM",                   # FlaxBert用于因果语言建模的模型类
        "FlaxBertForMaskedLM",                   # FlaxBert用于遮蔽语言建模的模型类
        "FlaxBertForMultipleChoice",             # FlaxBert用于多选题的模型类
        "FlaxBertForNextSentencePrediction",     # FlaxBert用于下一句预测的模型类
        "FlaxBertForPreTraining",                # FlaxBert用于预训练的模型类
        "FlaxBertForQuestionAnswering",          # FlaxBert用于问答的模型类
        "FlaxBertForSequenceClassification",     # FlaxBert用于序列分类的模型类
        "FlaxBertForTokenClassification",        # FlaxBert用于标记分类的模型类
        "FlaxBertModel",                         # FlaxBert模型的基础模型类
        "FlaxBertPreTrainedModel",               # FlaxBert预训练模型的基础模型类
    ]
# When type checking, import everything eagerly so static analyzers can resolve the names
if TYPE_CHECKING:
    # Import the BERT configuration classes
    from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig
    # Import the BERT tokenizer classes
    from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer

    # Check whether tokenizers is available; raise OptionalDependencyNotAvailable if not
    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # If available, import the fast BERT tokenizer
        from .tokenization_bert_fast import BertTokenizerFast

    # Check whether torch is available; raise OptionalDependencyNotAvailable if not
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # If available, import the PyTorch BERT models and helpers
        from .modeling_bert import (
            BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            BertForMaskedLM,
            BertForMultipleChoice,
            BertForNextSentencePrediction,
            BertForPreTraining,
            BertForQuestionAnswering,
            BertForSequenceClassification,
            BertForTokenClassification,
            BertLayer,
            BertLMHeadModel,
            BertModel,
            BertPreTrainedModel,
            load_tf_weights_in_bert,
        )

    # Check whether TensorFlow is available; raise OptionalDependencyNotAvailable if not
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # If available, import the TensorFlow BERT models
        from .modeling_tf_bert import (
            TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFBertEmbeddings,
            TFBertForMaskedLM,
            TFBertForMultipleChoice,
            TFBertForNextSentencePrediction,
            TFBertForPreTraining,
            TFBertForQuestionAnswering,
            TFBertForSequenceClassification,
            TFBertForTokenClassification,
            TFBertLMHeadModel,
            TFBertMainLayer,
            TFBertModel,
            TFBertPreTrainedModel,
        )

    # Check whether tensorflow_text is available; raise OptionalDependencyNotAvailable if not
    try:
        if not is_tensorflow_text_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # If available, import the TensorFlow BERT tokenizer
        from .tokenization_bert_tf import TFBertTokenizer

    # Check whether flax is available; raise OptionalDependencyNotAvailable if not
    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # If available, import the Flax BERT models
        from .modeling_flax_bert import (
            FlaxBertForCausalLM,
            FlaxBertForMaskedLM,
            FlaxBertForMultipleChoice,
            FlaxBertForNextSentencePrediction,
            FlaxBertForPreTraining,
            FlaxBertForQuestionAnswering,
            FlaxBertForSequenceClassification,
            FlaxBertForTokenClassification,
            FlaxBertModel,
            FlaxBertPreTrainedModel,
        )

# At runtime (not type checking)
else:
    import sys

    # Replace this module with a _LazyModule so submodules are only imported on first attribute access
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
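To see the lazy mechanism in action, here is a minimal sketch (assuming a standard `transformers` install with torch available; the `_LazyModule` class lives in `transformers.utils.import_utils`):

import transformers.models.bert as bert_pkg

# The bound name is the _LazyModule instance installed by the __init__ above
print(type(bert_pkg))  # should print the _LazyModule class, not a plain module

# Accessing a light attribute only loads configuration_bert
config = bert_pkg.BertConfig()

# Accessing BertModel is what finally triggers the torch-backed modeling_bert import
model_cls = bert_pkg.BertModel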

.\models\bertweet\tokenization_bertweet.py

# Standard-library and third-party imports
import html  # HTML entity encoding and decoding
import os  # Operating-system interaction
import re  # Regular expressions
from shutil import copyfile  # File copying
from typing import List, Optional, Tuple  # Type hints

import regex  # The `regex` package, a superset of `re` with extra features

# Import the PreTrainedTokenizer base class and the logging utilities
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging

# Module-level logger
logger = logging.get_logger(__name__)

# Names of the vocabulary and BPE merges files
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.txt",
    "merges_file": "bpe.codes",
}

# Map from pretrained model identifiers to the URLs of their vocabulary files
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "vinai/bertweet-base": "https://huggingface.co/vinai/bertweet-base/resolve/main/vocab.txt",
    },
    "merges_file": {
        "vinai/bertweet-base": "https://huggingface.co/vinai/bertweet-base/resolve/main/bpe.codes",
    },
}

# Maximum input sizes (in tokens) of the pretrained models
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "vinai/bertweet-base": 128,
}

def get_pairs(word):
    """
    Return the set of symbol pairs in a word.

    A word is represented as a tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char

    return pairs
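As a quick example, here are the symbol pairs of the word "low" in the tuple form that the bpe() method further down prepares (with the "</w>" end-of-word marker appended):

print(get_pairs(("l", "o", "w</w>")))
# {('l', 'o'), ('o', 'w</w>')}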


class BertweetTokenizer(PreTrainedTokenizer):
    """
    Construct a BERTweet tokenizer, which uses Byte-Pair Encoding.

    This tokenizer inherits from PreTrainedTokenizer, which contains most of the main methods. Users should refer
    to that superclass for more information regarding those methods.
    """

    # Class attributes describing the vocabulary files and pretrained resources
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    
    # Constructor: set up the vocabulary, BPE merges, special tokens, and normalization options
    def __init__(
        self,
        vocab_file,  # Path to the vocabulary file
        merges_file,  # Path to the BPE merges file
        normalization=False,  # Whether to apply tweet normalization before tokenizing (default False)
        bos_token="<s>",  # Beginning-of-sequence token used during pretraining (default "<s>")
        eos_token="</s>",  # End-of-sequence token (default "</s>")
        sep_token="</s>",  # Separator token used when building inputs from multiple sequences (default "</s>")
        cls_token="<s>",  # Classifier token, the first token when doing sequence classification (default "<s>")
        unk_token="<unk>",  # Unknown token substituted for out-of-vocabulary symbols (default "<unk>")
        pad_token="<pad>",  # Padding token used when batching sequences of different lengths (default "<pad>")
        mask_token="<mask>",  # Mask token used for masked language modeling (default "<mask>")
        **kwargs,  # Additional keyword arguments
    ):
        try:
            from emoji import demojize  # Try to import demojize from the emoji package
            self.demojizer = demojize  # On success, store demojize as the demojizer
        except ImportError:
            logger.warning(
                "emoji is not installed, thus not converting emoticons or emojis into text. Install emoji: pip3"
                " install emoji==0.6.0"
            )
            self.demojizer = None  # On failure, log a warning and disable emoji handling

        self.vocab_file = vocab_file  # Path to the vocabulary file
        self.merges_file = merges_file  # Path to the merges file

        self.encoder = {}  # Token-to-id mapping
        self.encoder[str(bos_token)] = 0  # Special token bos_token gets id 0
        self.encoder[str(pad_token)] = 1  # Special token pad_token gets id 1
        self.encoder[str(eos_token)] = 2  # Special token eos_token gets id 2
        self.encoder[str(unk_token)] = 3  # Special token unk_token gets id 3

        self.add_from_file(vocab_file)  # Load the remaining vocabulary entries from vocab_file

        self.decoder = {v: k for k, v in self.encoder.items()}  # Id-to-token mapping (inverse of encoder)

        with open(merges_file, encoding="utf-8") as merges_handle:
            merges = merges_handle.read().split("\n")[:-1]  # Read the merges file, dropping the trailing empty line
        merges = [tuple(merge.split()[:-1]) for merge in merges]  # Each line is "<sym1> <sym2> <count>"; keep only the symbol pair
        self.bpe_ranks = dict(zip(merges, range(len(merges))))  # Rank of each BPE merge (lower rank = higher priority)
        self.cache = {}  # Cache of already-computed BPE segmentations

        self.normalization = normalization  # Whether to normalize tweets before BPE
        self.tweetPreprocessor = TweetTokenizer()  # TweetTokenizer used for pre-tokenization
        self.special_puncts = {"’": "'", "…": "..."}  # Mapping of special punctuation to ASCII equivalents

        super().__init__(  # Call the parent constructor with the tokenizer settings
            normalization=normalization,
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            **kwargs,
        )
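A short usage sketch (assuming the `vinai/bertweet-base` checkpoint is reachable and the optional `emoji` package is installed; the exact subword split depends on the checkpoint's vocabulary):

from transformers import BertweetTokenizer

tokenizer = BertweetTokenizer.from_pretrained("vinai/bertweet-base", normalization=True)
# With normalization on, mentions become @USER and URLs become HTTPURL before BPE runs
print(tokenizer.tokenize("@remy check https://huggingface.co !"))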
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        # If the token list already has special tokens, delegate to the superclass method
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # If there are no sequence pairs (token_ids_1 is None), add special tokens around token_ids_0
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        
        # For sequence pairs, add special tokens around both token_ids_0 and token_ids_1
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
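The mask depends only on the lengths of the two inputs; a minimal sketch of both layouts:

# Single sequence:  <s> tok tok </s>
print([1] + [0] * 2 + [1])  # [1, 0, 0, 1]
# Sequence pair:    <s> A A </s> </s> B B B </s>
print([1] + [0] * 2 + [1, 1] + [0] * 3 + [1])  # [1, 0, 0, 1, 1, 0, 0, 0, 1]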

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. BERTweet does
        not make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """

        # Define special tokens for separation and classification
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        # If there are no sequence pairs, return a list of zeros of length equal to cls + token_ids_0 + sep
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        
        # For sequence pairs, return a list of zeros of length equal to cls + token_ids_0 + sep + sep + token_ids_1 + sep
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    @property
    def vocab_size(self):
        # Return the size of the vocabulary, which is the length of the encoder dictionary
        return len(self.encoder)

    def get_vocab(self):
        # Return the combined dictionary of encoder and added_tokens_encoder
        return dict(self.encoder, **self.added_tokens_encoder)
    def bpe(self, token):
        # Return the cached result if this token has been segmented before
        if token in self.cache:
            return self.cache[token]

        # Represent the token as a tuple of symbols
        word = tuple(token)
        # Append "</w>" to the last symbol to mark the end of the word
        word = tuple(list(word[:-1]) + [word[-1] + "</w>"])
        # Collect all adjacent symbol pairs
        pairs = get_pairs(word)

        # If there are no pairs, return the token unchanged
        if not pairs:
            return token

        # Repeatedly merge the highest-priority pair until no known merge remains
        while True:
            # Find the pair with the best (lowest) merge rank
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            # Stop once the best pair is not in the learned merge table
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            # Walk through the word, merging every occurrence of the pair
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` no longer occurs: copy the rest of the word and stop
                    new_word.extend(word[i:])
                    break
                else:
                    # Copy the symbols before the next occurrence of `first`
                    new_word.extend(word[i:j])
                    i = j

                # If `first` is immediately followed by `second`, merge them into one symbol
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    # Otherwise copy the current symbol and advance
                    new_word.append(word[i])
                    i += 1
            # The merged word becomes the new working word
            new_word = tuple(new_word)
            word = new_word
            # A single symbol means nothing is left to merge
            if len(word) == 1:
                break
            else:
                # Otherwise recompute the symbol pairs and continue
                pairs = get_pairs(word)

        # Join the symbols with the "@@ " continuation marker and strip the trailing "</w>"
        word = "@@ ".join(word)
        word = word[:-4]
        # Cache and return the result
        self.cache[token] = word
        return word
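To make the merge loop concrete, here is a toy trace using a hypothetical two-entry merge table (the real ranks come from bpe.codes):

# Hypothetical merge ranks: merge ('l','o') first, then ('lo','w</w>')
ranks = {("l", "o"): 0, ("lo", "w</w>"): 1}
word = ("l", "o", "w</w>")
while True:
    pairs = get_pairs(word)
    if not pairs:
        break
    bigram = min(pairs, key=lambda p: ranks.get(p, float("inf")))
    if bigram not in ranks:
        break
    first, second = bigram
    merged, i = [], 0
    while i < len(word):
        # Merge every adjacent (first, second) occurrence into one symbol
        if i < len(word) - 1 and (word[i], word[i + 1]) == (first, second):
            merged.append(first + second)
            i += 2
        else:
            merged.append(word[i])
            i += 1
    word = tuple(merged)
    if len(word) == 1:
        break
print(word)  # ('low</w>',) -- rendered by bpe() as "low" after "@@ " joining and stripping "</w>"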

    def _tokenize(self, text):
        """Tokenize a string."""
        # If tweet normalization is enabled, normalize the text before applying BPE
        if self.normalization:
            text = self.normalizeTweet(text)

        split_tokens = []
        # Split the text into whitespace-delimited words
        words = re.findall(r"\S+\n?", text)
        for token in words:
            # Apply BPE to each word and split the result on spaces into subword tokens
            split_tokens.extend(list(self.bpe(token).split(" ")))
        return split_tokens

    def normalizeTweet(self, tweet):
        """
        Normalize a raw Tweet
        """
        # Replace special punctuation with ASCII equivalents
        for punct in self.special_puncts:
            tweet = tweet.replace(punct, self.special_puncts[punct])

        # Pre-tokenize the tweet with the TweetTokenizer
        tokens = self.tweetPreprocessor.tokenize(tweet)
        # Normalize each token and join the tokens back with spaces
        normTweet = " ".join([self.normalizeToken(token) for token in tokens])

        # Re-attach common English contractions that the tokenizer split apart
        normTweet = (
            normTweet.replace("cannot ", "can not ")
            .replace("n't ", " n't ")
            .replace("n 't ", " n't ")
            .replace("ca n't", "can't")
            .replace("ai n't", "ain't")
        )
        normTweet = (
            normTweet.replace("'m ", " 'm ")
            .replace("'re ", " 're ")
            .replace("'s ", " 's ")
            .replace("'ll ", " 'll ")
            .replace("'d ", " 'd ")
            .replace("'ve ", " 've ")
        )
        normTweet = (
            normTweet.replace(" p . m .", "  p.m.")
            .replace(" p . m ", " p.m ")
            .replace(" a . m .", " a.m.")
            .replace(" a . m ", " a.m ")
        )

        return " ".join(normTweet.split())
    # Normalize a single token: user mentions, URLs, special punctuation, and emoji
    def normalizeToken(self, token):
        lowercased_token = token.lower()
        # Map user mentions (tokens starting with "@") to "@USER"
        if token.startswith("@"):
            return "@USER"
        # Map URLs (tokens whose lowercase form starts with "http" or "www") to "HTTPURL"
        elif lowercased_token.startswith("http") or lowercased_token.startswith("www"):
            return "HTTPURL"
        # Single-character tokens
        elif len(token) == 1:
            # Replace special punctuation with its ASCII equivalent
            if token in self.special_puncts:
                return self.special_puncts[token]
            # If the emoji demojizer is available, convert emoji to their text aliases
            if self.demojizer is not None:
                return self.demojizer(token)
            else:
                return token
        # Everything else passes through unchanged
        else:
            return token
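Expected behavior, for illustration (the emoji alias depends on the installed emoji package version):

# normalizeToken("@remy")          -> "@USER"
# normalizeToken("https://t.co/x") -> "HTTPURL"
# normalizeToken("’")              -> "'"
# normalizeToken("😊")             -> ":smiling_face_with_smiling_eyes:" (if emoji is installed)
# normalizeToken("hello")          -> "hello"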

    # Convert a token (str) to an id using the vocabulary
    def _convert_token_to_id(self, token):
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    # Convert an id (int) back to a token (str) using the vocabulary
    def _convert_id_to_token(self, index):
        return self.decoder.get(index, self.unk_token)

    # Convert a sequence of tokens back into a single string
    def convert_tokens_to_string(self, tokens):
        out_string = " ".join(tokens).replace("@@ ", "").strip()
        return out_string
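For example, undoing the "@@ " continuation markers produced by bpe():

pieces = ["trans@@", "form@@", "ers", "rock"]
print(" ".join(pieces).replace("@@ ", "").strip())  # 'transformers rock'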

    # Save the vocabulary and merges files to a directory
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Bail out with an error if the target is not a directory
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        # Build the output paths for the vocabulary and merges files
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        out_merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        # Copy the vocabulary file if the source exists and differs from the target path
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        # Otherwise serialize the in-memory vocabulary model to the target path (note: this branch
        # references `self.sp_model`, which BertweetTokenizer does not define; it appears to be
        # carried over from a SentencePiece-based tokenizer)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        # Copy the merges file if it differs from the target path
        if os.path.abspath(self.merges_file) != os.path.abspath(out_merge_file):
            copyfile(self.merges_file, out_merge_file)

        return out_vocab_file, out_merge_file
    def add_from_file(self, f):
        """
        Load a pre-existing dictionary from a text file and add its symbols to this instance.
        """
        # If f is a path string, open the file and recurse with the file handle
        if isinstance(f, str):
            try:
                with open(f, "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except FileNotFoundError as fnfe:
                # Re-raise if the file does not exist
                raise fnfe
            except UnicodeError:
                # Raise if the file has an unexpected encoding
                raise Exception(f"Incorrect encoding detected in {f}, please rebuild the dataset")
            return

        # Read all lines from the file handle
        lines = f.readlines()
        for lineTmp in lines:
            # Strip leading and trailing whitespace
            line = lineTmp.strip()
            # Each line must be "<token> <count>"; find the last space separating them
            idx = line.rfind(" ")
            if idx == -1:
                raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
            # Everything before the last space is the token
            word = line[:idx]
            # Assign the next available id to the token
            self.encoder[word] = len(self.encoder)
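The expected file format is fairseq-style, one "<token> <count>" entry per line (the counts below are illustrative, not from the real checkpoint). A vocab.txt beginning with

    the 1061396
    of 593677
    and 416629

would assign the ids 4, 5, and 6, since ids 0-3 are already taken by the bos/pad/eos/unk tokens seeded in __init__.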
# Natural Language Toolkit: Twitter Tokenizer
#
# Copyright (C) 2001-2020 NLTK Project
# Author: Christopher Potts <cgpotts@stanford.edu>
#         Ewan Klein <ewan@inf.ed.ac.uk> (modifications)
#         Pierpaolo Pantone <> (modifications)
# URL: http://nltk.org/
# For license information, see LICENSE.TXT
#

"""
Twitter-aware tokenizer, designed to be flexible and easy to adapt to new domains and tasks. The basic logic is this:

1. The tuple regex_strings defines a list of regular expression strings.

2. The regex_strings strings are put, in order, into a compiled regular expression object called word_re.

3. The tokenization is done by word_re.findall(s), where s is the user-supplied string, inside the tokenize() method of
   the class Tokenizer.

4. When instantiating Tokenizer objects, there is a single option: preserve_case. By default, it is set to True. If it
   is set to False, then the tokenizer will lowercase everything except for emoticons.

"""


######################################################################
#
# import regex  # https://github.com/nltk/nltk/issues/2409
# import html
#
######################################################################
# The following strings are components in the regular expression
# that is used for tokenizing. It's important that phone_number
# appears first in the final regex (since it can contain whitespace).
# It also could matter that tags comes after emoticons, due to the
# possibility of having text like
#
#     <:| and some text >:)
#
# Most importantly, the final element should always be last, since it
# does a last ditch whitespace-based tokenization of whatever is left.

# ToDo: Update with http://en.wikipedia.org/wiki/List_of_emoticons ?

# This particular element is used in a couple ways, so we define it
# with a name:
# docstyle-ignore
EMOTICONS = r"""
    (?:
      [<>]?                           # optional opening angle bracket
      [:;=8]                          # eyes
      [\-o\*\']?                      # optional nose
      [\)\]\(\[dDpP/\:\}\{@\|\\]      # mouth
      |
      [\)\]\(\[dDpP/\:\}\{@\|\\]      # mouth
      [\-o\*\']?                      # optional nose
      [:;=8]                          # eyes
      [<>]?                           # optional closing angle bracket
      |
      <3                               # heart
    )"""

# URL pattern due to John Gruber, modified by Tom Winzig. See
# https://gist.github.com/winzig/8894715
# docstyle-ignore
URLS = r"""            # Capture 1: entire matched URL
  (?:
  https?:                     # URL protocol and colon
    (?:
      /{1,3}                     # 1-3 slashes
      |                         #   or
      [a-z0-9%]                     # Single letter or digit or '%'
                                       # (Trying not to match e.g. "URI::Escape")
    )
    |                         #   or
                                       # looks like domain name followed by a slash:
    [a-z0-9.\-]+[.]
    (?:[a-z]{2,13})
    /
  )
  (?:                         # One or more:
    [^\s()<>{}\[\]]+                 # Run of non-space, non-()<>{}[]
    |                         #   or

    \(
      [^\s()<>{}\[\]]+
    \)
  )+
  (?:                         # End with:
    \(
      [^\s()<>{}\[\]]+
    \)
    |                         #   or
    [^\s`!()\[\]{};:'".,<>?«»“”‘’]
  )
"""

# The above pattern defines URLs using a regex for tokenization purposes,
# covering various formats and components typically found in URLs.
    \([^\s()]*?\([^\s()]+\)[^\s()]*?\) # 匹配具有平衡括号的一级深度的表达式:(...(...)...)
    |
    \([^\s]+?\)                # 匹配非递归的平衡括号表达式:(...)
  )+                          # 上述两种模式可以出现一次或多次,即匹配多个括号嵌套或单个括号
  (?:                          # 结尾处可以是以下模式之一:
    \([^\s()]*?\([^\s()]+\)[^\s()]*?\) # 匹配具有平衡括号的一级深度的表达式:(...(...)...)
    |
    \([^\s]+?\)                # 匹配非递归的平衡括号表达式:(...)
    |                          # 或者
    [^\s`!()\[\]{};:'".,<>?«»“”‘’]    # 不是空格或特定的标点字符
  )
  |                          # 或者,用于匹配裸域名:
  (?:
    (?<!@)                    # 前面不是 @,避免在电子邮件地址中匹配例如 "foo@_gmail.com_"
    [a-z0-9]+
    (?:[.\-][a-z0-9]+)*
    [.]
    (?:[a-z]{2,13})
    \b
    /?
    (?!@)                    # 后面不是 @,避免在电子邮件地址中匹配例如 "foo.na" 在 "foo.na@example.com" 中
  )


这段代码是一个正则表达式模式,用于匹配具有特定形式的括号结构和裸域名。
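A quick sanity check of the pattern (a sketch using the `regex` package imported above; `URL_RE` is just a name for this demo, and the expected matches are shown for illustration):

URL_RE = regex.compile(URLS, regex.VERBOSE | regex.I | regex.UNICODE)
print(URL_RE.findall("docs at https://huggingface.co/vinai/bertweet-base and nltk.org"))
# ['https://huggingface.co/vinai/bertweet-base', 'nltk.org']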
# Regular expression patterns for the different token types: URLs, phone numbers,
# ASCII emoticons, HTML tags, ASCII arrows, Twitter usernames, Twitter hashtags,
# email addresses, and the remaining word types
REGEXPS = (
    URLS,  # URLs
    r"""
    (?:
      (?:            # (international)
        \+?[01]
        [ *\-.\)]*
      )?
      (?:            # (area code)
        [\(]?
        \d{3}
        [ *\-.\)]*
      )?
      \d{3}          # exchange
      [ *\-.\)]*
      \d{4}          # base
    )""",  # Phone numbers
    EMOTICONS,  # ASCII emoticons
    r"""<[^>\s]+>""",  # HTML tags
    r"""[\-]+>|<[\-]+""",  # ASCII arrows
    r"""(?:@[\w_]+)""",  # Twitter usernames
    r"""(?:\#+[\w_]+[\w\'_\-]*[\w_]+)""",  # Twitter hashtags
    r"""[\w.+-]+@[\w-]+\.(?:[\w-]\.?)+[\w-]""",  # Email addresses
    r"""
    (?:[^\W\d_](?:[^\W\d_]|['\-_])+[^\W\d_]) # Words with apostrophes or dashes
    |
    (?:[+\-]?\d+[,/.:-]\d+[+\-]?)  # Numbers, including fractions and decimals
    |
    (?:[\w_]+)                     # Words without apostrophes or dashes
    |
    (?:\.(?:\s*\.){1,})            # Ellipsis dots
    |
    (?:\S)                         # Everything else that isn't whitespace
    """,  # The remaining word types
)

######################################################################
# This is the core tokenizing regex:

# Combine all REGEXPS patterns into one big alternation
WORD_RE = regex.compile(r"""(%s)""" % "|".join(REGEXPS), regex.VERBOSE | regex.I | regex.UNICODE)

# HANG_RE matches runs of four or more identical non-alphanumeric characters
HANG_RE = regex.compile(r"([^a-zA-Z0-9])\1{3,}")

# EMOTICON_RE matches emoticons on their own
EMOTICON_RE = regex.compile(EMOTICONS, regex.VERBOSE | regex.I | regex.UNICODE)

# ENT_RE matches HTML entities, for conversion to Unicode characters
ENT_RE = regex.compile(r"&(#?(x?))([^&;\s]+);")
# The docstring example of `_replace_html_entities` (the helper that converts HTML
# entities matched by ENT_RE into Unicode characters; its body is omitted in this
# excerpt) shows the intended behavior:
#
#     >>> from nltk.tokenize.casual import _replace_html_entities
#     >>> _replace_html_entities(b"Price: &pound;100")
#     'Price: \xa3100'
#     >>> print(_replace_html_entities(b"Price: &pound;100"))
#     Price: £100
class TweetTokenizer:
    r"""
    Examples:

    ```
    >>> # Tokenizer for tweets.
    >>> from nltk.tokenize import TweetTokenizer

    >>> tknzr = TweetTokenizer()
    >>> s0 = "This is a cooool #dummysmiley: :-) :-P <3 and some arrows < > -> <--"
    >>> tknzr.tokenize(s0)
    ['This', 'is', 'a', 'cooool', '#dummysmiley', ':', ':-)', ':-P', '<3', 'and', 'some', 'arrows', '<', '>', '->', '<--']

    >>> # Examples using *strip_handles* and *reduce_len parameters*:
    >>> tknzr = TweetTokenizer(strip_handles=True, reduce_len=True)
    >>> s1 = "@remy: This is waaaaayyyy too much for you!!!!!!"
    >>> tknzr.tokenize(s1)
    [':', 'This', 'is', 'waaayyy', 'too', 'much', 'for', 'you', '!', '!', '!']
    ```"""

    def __init__(self, preserve_case=True, reduce_len=False, strip_handles=False):
        # Initialize the TweetTokenizer with options to preserve case, reduce elongated words, and strip handles.
        self.preserve_case = preserve_case
        self.reduce_len = reduce_len
        self.strip_handles = strip_handles

    def tokenize(self, text):
        """
        Tokenize a given text into a list of words.

        Args:
            text: str

        Returns:
            list(str): A list of tokens extracted from the text.
        """
        # Fix HTML character entities before tokenization
        text = _replace_html_entities(text)
        # Remove Twitter handles if strip_handles is enabled
        if self.strip_handles:
            text = remove_handles(text)
        # Reduce elongated words to their base form if reduce_len is enabled
        if self.reduce_len:
            text = reduce_lengthening(text)
        # Replace problematic sequences of characters for safe tokenization
        safe_text = HANG_RE.sub(r"\1\1\1", text)
        # Tokenize the text using a regular expression for word boundaries
        words = WORD_RE.findall(safe_text)
        # Adjust word case unless it is part of an emoticon to preserve emoticon capitalization
        if not self.preserve_case:
            words = [x if EMOTICON_RE.search(x) else x.lower() for x in words]
        return words


######################################################################
# Normalization Functions
######################################################################

def reduce_lengthening(text):
    """
    Reduce repeated character sequences of length 3 or greater to sequences of length 3.

    Args:
        text: str

    Returns:
        str: Text with reduced elongations.
    """
    pattern = regex.compile(r"(.)\1{2,}")
    return pattern.sub(r"\1\1\1", text)
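For example:

print(reduce_lengthening("waaaaayyyy too coooool!!!!!"))
# 'waaayyy too coool!!!'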


def remove_handles(text):
    """
    Remove Twitter username handles from text.

    Args:
        text: str

    Returns:
        str: Text with removed handles replaced by spaces.
    """
    pattern = regex.compile(
        r"(?<![A-Za-z0-9_!@#\$%&*])@(([A-Za-z0-9_]){20}(?!@))|(?<![A-Za-z0-9_!@#\$%&*])@(([A-Za-z0-9_]){1,19})(?![A-Za-z0-9_]*@)"
    )
    # Substitute handles with ' ' to ensure correct tokenization around removed handles
    return pattern.sub(" ", text)
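For example (each handle is replaced by a single space, so the neighboring tokens stay separated; output shown for illustration):

print(remove_handles("@remy: This is waaaaayyyy too much for you @another_user!"))
# ' : This is waaaaayyyy too much for you  !'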


######################################################################
# Tokenization Function
######################################################################

def casual_tokenize(text, preserve_case=True, reduce_len=False, strip_handles=False):
    """
    Tokenize a text string using casual tokenization rules.

    Args:
        text: str
        preserve_case: bool, optional (default=True)
            Whether to preserve case in tokens.
        reduce_len: bool, optional (default=False)
            Whether to reduce elongated words.
        strip_handles: bool, optional (default=False)
            Whether to remove Twitter handles.

    Returns:
        list(str): A list of tokens extracted from the text based on specified rules.
    """
    # Wrap a TweetTokenizer configured with the given options and tokenize the text
    return TweetTokenizer(preserve_case=preserve_case, reduce_len=reduce_len, strip_handles=strip_handles).tokenize(
        text
    )
###############################################################################

# Define a helper function `calculate_total` that takes a single argument `items`
def calculate_total(items):
    # Running total
    total = 0
    # Add each element of `items` to the running total
    for item in items:
        total += item
    # Return the accumulated total
    return total

.\models\bertweet\__init__.py

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import TYPE_CHECKING for static type checking
from typing import TYPE_CHECKING

# Import the lazy-module helper
from ...utils import _LazyModule

# Define the module's import structure
_import_structure = {"tokenization_bertweet": ["BertweetTokenizer"]}

# When type checking, import BertweetTokenizer eagerly
if TYPE_CHECKING:
    from .tokenization_bertweet import BertweetTokenizer

# At runtime, replace this module with a lazy module
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\bert_generation\configuration_bert_generation.py

# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""  BertGeneration model configuration"""

# Import the base class PretrainedConfig from configuration_utils module
from ...configuration_utils import PretrainedConfig

# Define a new class BertGenerationConfig that inherits from PretrainedConfig
class BertGenerationConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BertGenerationPreTrainedModel`]. It is used to
    instantiate a BertGeneration model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the BertGeneration
    [google/bert_for_seq_generation_L-24_bbc_encoder](https://huggingface.co/google/bert_for_seq_generation_L-24_bbc_encoder)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Examples:

    ```
    >>> from transformers import BertGenerationConfig, BertGenerationEncoder

    >>> # Initializing a BertGeneration config
    >>> configuration = BertGenerationConfig()

    >>> # Initializing a model (with random weights) from the config
    >>> model = BertGenerationEncoder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    # Identifier under which this configuration is registered (e.g. with AutoConfig)
    model_type = "bert-generation"

    # Define the constructor (__init__) method for initializing an instance of BertGenerationConfig
    def __init__(
        self,
        vocab_size=50358,  # Size of the vocabulary used by the model
        hidden_size=1024,  # Dimensionality of the encoder layers and the pooler layer
        num_hidden_layers=24,  # Number of hidden layers in the Transformer encoder
        num_attention_heads=16,  # Number of attention heads for each attention layer in the Transformer encoder
        intermediate_size=4096,  # Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder
        hidden_act="gelu",  # The activation function to be used in the hidden layers
        hidden_dropout_prob=0.1,  # The dropout probability for all fully connected layers in the embeddings, encoder, and pooler
        attention_probs_dropout_prob=0.1,  # The dropout ratio for the attention probabilities
        max_position_embeddings=512,  # The maximum sequence length that this model might ever be used with
        initializer_range=0.02,  # The standard deviation of the truncated_normal_initializer for initializing all weight matrices
        layer_norm_eps=1e-12,  # The epsilon used by the layer normalization layers
        pad_token_id=0,  # The token id for padding
        bos_token_id=2,  # The token id for the beginning of sentence token
        eos_token_id=1,  # The token id for the end of sentence token
        position_embedding_type="absolute",  # Type of position embedding to use
        use_cache=True,  # Whether to use an output cache
        **kwargs,  # Additional keyword arguments for future expansion
    ):
        # Call the parent constructor, forwarding the special token ids and any extra kwargs
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        # Store the model hyperparameters on the configuration instance
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache

.\models\bert_generation\modeling_bert_generation.py

# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model specific for generation."""

import math
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_bert_generation import BertGenerationConfig

# Module-level logger
logger = logging.get_logger(__name__)

# Checkpoint and config names referenced in the docstrings
_CHECKPOINT_FOR_DOC = "google/bert_for_seq_generation_L-24_bbc_encoder"
_CONFIG_FOR_DOC = "BertGenerationConfig"


# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->BertGeneration
class BertGenerationSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dense projection back to the hidden size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Layer normalization over the hidden states
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Dropout for regularization
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # Dense projection
        hidden_states = self.dense(hidden_states)
        # Apply dropout
        hidden_states = self.dropout(hidden_states)
        # Residual connection followed by layer normalization
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->BertGeneration
class BertGenerationSelfAttention(nn.Module):
    # Constructor: takes the model config and an optional position-embedding type
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # The hidden size must be divisible by the number of attention heads
        # (unless the config defines a separate embedding_size)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        # Number of attention heads and size of each head
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Linear projections for queries, keys, and values
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # Dropout applied to the attention probabilities
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # Position-embedding type, defaulting to absolute embeddings
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        # For relative position embeddings, create the distance-embedding table
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        # Whether this attention module is used inside a decoder
        self.is_decoder = config.is_decoder

    # Reshape a tensor from [batch, seq_len, hidden] to [batch, heads, seq_len, head_size]
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # New shape: split the last dimension into (num_heads, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        # Move the head dimension in front of the sequence dimension
        return x.permute(0, 2, 1, 3)

    # Forward pass (the body is elided in this excerpt)
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
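A standalone shape sketch of transpose_for_scores, using the bert_for_seq_generation sizes (hidden_size=1024, num_attention_heads=16, head size 64):

import torch

x = torch.randn(2, 7, 1024)         # [batch=2, seq_len=7, hidden=1024]
x = x.view(2, 7, 16, 64)            # split the hidden dimension into 16 heads of size 64
print(x.permute(0, 2, 1, 3).shape)  # torch.Size([2, 16, 7, 64]) -> [batch, heads, seq, head_size]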
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BertGeneration
class BertGenerationAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # Self-attention module
        self.self = BertGenerationSelfAttention(config, position_embedding_type=position_embedding_type)
        # Output module applied on top of self-attention
        self.output = BertGenerationSelfOutput(config)
        # Indices of attention heads that have already been pruned
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        # Find the prunable heads and the corresponding weight indices
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune the linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update the hyperparameters and record the pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # Run self-attention, then feed its output through the output block
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # The output block adds the residual connection back to hidden_states
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attention weights if they were requested
        return outputs


# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->BertGeneration
class BertGenerationIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dense projection from the hidden size up to the intermediate size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Resolve the intermediate activation function from the config
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Project up to the intermediate size
        hidden_states = self.dense(hidden_states)
        # Apply the activation function
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->BertGeneration
class BertGenerationOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Dense projection from the intermediate size back down to the hidden size
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # Layer normalization over the hidden size, with the configured epsilon
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Dropout for regularization
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # Project down to the hidden size
        hidden_states = self.dense(hidden_states)
        # Apply dropout
        hidden_states = self.dropout(hidden_states)
        # Residual connection followed by layer normalization
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        # Return the processed hidden states
        return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->BertGeneration
class BertGenerationLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Chunk size used for the chunked feed-forward pass
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # The sequence-length dimension is dimension 1
        self.seq_len_dim = 1
        # Self-attention block
        self.attention = BertGenerationAttention(config)
        # Whether this layer is used inside a decoder
        self.is_decoder = config.is_decoder
        # Whether to add a cross-attention block
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # Cross-attention only makes sense in a decoder
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            # Cross-attention block with absolute position embeddings
            self.crossattention = BertGenerationAttention(config, position_embedding_type="absolute")
        # Feed-forward blocks
        self.intermediate = BertGenerationIntermediate(config)
        self.output = BertGenerationOutput(config)

    # Forward pass
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # If past_key_value is given, its first two entries are the cached self-attention key/values
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        # Run self-attention on hidden_states
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )

        # The first element is the self-attention output
        attention_output = self_attention_outputs[0]

        # In a decoder, the last element of self_attention_outputs is the self-attention cache tuple
        if self.is_decoder:
            # Everything between the first and last element goes into outputs
            outputs = self_attention_outputs[1:-1]
            # The last element is the present key/value cache
            present_key_value = self_attention_outputs[-1]
        else:
            # In an encoder, keep everything after the first element (attention weights, if requested)
            outputs = self_attention_outputs[1:]

        # Cache for cross-attention, if any
        cross_attn_present_key_value = None

        # Decoder with encoder hidden states: run cross-attention
        if self.is_decoder and encoder_hidden_states is not None:
            # Cross-attention layers must have been configured
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # The last two entries of past_key_value are the cached cross-attention key/values
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None

            # Run cross-attention over the encoder hidden states
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )

            # The first element is the cross-attention output
            attention_output = cross_attention_outputs[0]

            # Append cross-attention weights (if requested), excluding the cache element
            outputs = outputs + cross_attention_outputs[1:-1]

            # Append the cross-attention cache to the self-attention cache
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value
        
        # Apply the (optionally chunked) feed-forward pass to the attention output
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )

        # Prepend the layer output to the remaining outputs
        outputs = (layer_output,) + outputs

        # In a decoder, append the attention key/value cache as the last element
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Run the attention output through the intermediate (feed-forward) layer
        intermediate_output = self.intermediate(attention_output)
        # Combine the intermediate output with the attention output in the output layer
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
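What apply_chunking_to_forward does, in miniature: because the feed-forward acts on every position independently, the input can be sliced along the sequence dimension, processed slice by slice, and concatenated, trading peak memory for extra kernel launches. A minimal sketch (not the library implementation; `chunked` is a hypothetical helper):

import torch

def chunked(forward_fn, chunk_size, chunk_dim, tensor):
    # chunk_size == 0 means chunking is disabled
    if chunk_size == 0:
        return forward_fn(tensor)
    # Process each slice independently, then stitch the results back together
    slices = tensor.split(chunk_size, dim=chunk_dim)
    return torch.cat([forward_fn(s) for s in slices], dim=chunk_dim)

out = chunked(torch.nn.Identity(), chunk_size=2, chunk_dim=1, tensor=torch.randn(1, 6, 4))
print(out.shape)  # torch.Size([1, 6, 4])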
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->BertGeneration
class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Keep a reference to the config
        self.config = config
        # Stack of config.num_hidden_layers BertGenerationLayer modules
        self.layer = nn.ModuleList([BertGenerationLayer(config) for _ in range(config.num_hidden_layers)])
        # Gradient checkpointing is off by default
        self.gradient_checkpointing = False

    # Forward pass
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        # 如果不需要输出隐藏状态,则初始化为空元组;否则设为 None,准备存储每层的隐藏状态
        all_hidden_states = () if output_hidden_states else None
        # 如果不需要输出注意力权重,则初始化为空元组;否则设为 None,准备存储每层的自注意力权重
        all_self_attentions = () if output_attentions else None
        # 如果不需要输出交叉注意力权重或没有配置交叉注意力,则初始化为空元组;否则设为 None,准备存储每层的交叉注意力权重
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        # 如果启用了梯度检查点且在训练模式下
        if self.gradient_checkpointing and self.training:
            # 如果同时设置了使用缓存,则给出警告并强制设置 `use_cache=False`
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # 如果需要使用缓存,则初始化为空元组;否则设为 None,准备存储下一个解码器缓存
        next_decoder_cache = () if use_cache else None
        # Iterate over the decoder layers
        for i, layer_module in enumerate(self.layer):
            # Record the hidden states entering this layer, if requested
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Per-layer head mask and cached key/values, when provided
            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Run the layer inside a checkpoint so activations are recomputed in backward
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                # Otherwise call the layer directly
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            # The layer's first output becomes the next hidden states
            hidden_states = layer_outputs[0]
            # The key/value cache is always the last element of the layer outputs
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            # Attention weights sit at index 1 (self) and index 2 (cross, when configured)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # Add the final hidden states after the last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Without return_dict, pack the non-None values into a plain tuple
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        # Otherwise wrap them in a BaseModelOutputWithPastAndCrossAttentions
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
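To see how the two return styles line up, here is a minimal sketch with a tiny, hypothetical configuration; with `return_dict=False` the same non-None values come back as a plain tuple:

```python
import torch
from transformers import BertGenerationConfig, BertGenerationEncoder

config = BertGenerationConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64
)
model = BertGenerationEncoder(config).eval()   # eval() disables dropout for reproducibility
input_ids = torch.randint(0, 100, (1, 5))

out = model(input_ids, output_hidden_states=True, return_dict=True)
print(out.last_hidden_state.shape)   # torch.Size([1, 5, 32])
print(len(out.hidden_states))        # 3: the embeddings plus one entry per layer

tup = model(input_ids, output_hidden_states=True, return_dict=False)
print(torch.equal(tup[0], out.last_hidden_state))  # True
```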
def load_tf_weights_in_bert_generation(
    model, tf_hub_path, model_class, is_encoder_named_decoder=False, is_encoder=False
):
    try:
        # TensorFlow, TF Hub and tensorflow_text are required to read the checkpoint
        import numpy as np
        import tensorflow.compat.v1 as tf
        import tensorflow_hub as hub
        import tensorflow_text  # noqa: F401

        # Disable TensorFlow eager execution
        tf.disable_eager_execution()
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # Load the module from TF Hub and initialize its global variables
    # (the rest of the weight-copying logic is omitted in this excerpt)
    tf_model = hub.Module(tf_hub_path)
    init = tf.global_variables_initializer()

class BertGenerationEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        # Word and position embedding tables
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # self.LayerNorm is not snake-cased so that TensorFlow checkpoint variable names still match
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Buffer of position ids of shape (1, max_position_embeddings), not saved in the state dict
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        # Default position ids: a contiguous slice offset by the cached sequence length
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Look up word embeddings unless embeddings were passed in directly
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)

        # Sum the two embeddings, then normalize and apply dropout
        embeddings = inputs_embeds + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
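A quick sketch (hypothetical tiny sizes) of the embedding forward: note there is no token_type embedding in this model, and `past_key_values_length` shifts the position ids during cached decoding:

```python
import torch
from transformers import BertGenerationConfig
from transformers.models.bert_generation.modeling_bert_generation import BertGenerationEmbeddings

config = BertGenerationConfig(vocab_size=100, hidden_size=16, max_position_embeddings=32)
emb = BertGenerationEmbeddings(config).eval()

out = emb(torch.tensor([[5, 6, 7]]))                            # positions 0, 1, 2
print(out.shape)                                                # torch.Size([1, 3, 16])

out_step = emb(torch.tensor([[8]]), past_key_values_length=3)   # position 3
print(out_step.shape)                                           # torch.Size([1, 1, 16])
```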


class BertGenerationPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertGenerationConfig
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Linear layers: normal init with std = initializer_range, zero bias
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # Embeddings: normal init, and zero out the padding row if one is defined
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm: bias 0, weight 1
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
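`post_init()` eventually walks the module tree with `nn.Module.apply`, invoking `_init_weights` on every submodule. A simplified standalone sketch of that traversal (the `init_fn` and toy network are illustrative):

```python
import torch.nn as nn

def init_fn(module):
    # Mirrors the nn.Linear branch of _init_weights above
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()

net = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
net.apply(init_fn)  # calls init_fn on every submodule, then on net itself
```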
BERT_GENERATION_START_DOCSTRING = r"""
    
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`BertGenerationConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

BERT_GENERATION_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range
            `[0, config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# BertGenerationEncoder: the bare BertGeneration transformer outputting raw hidden states
# without any task-specific head on top
@add_start_docstrings(
    "The bare BertGeneration model transformer outputting raw hidden-states without any specific head on top.",
    BERT_GENERATION_START_DOCSTRING,
)
class BertGenerationEncoder(BertGenerationPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Embedding layer and transformer encoder stack
        self.embeddings = BertGenerationEmbeddings(config)
        self.encoder = BertEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
    
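A small usage sketch for head pruning (tiny hypothetical config): `PreTrainedModel.prune_heads` records the pruned heads in the config and dispatches to the `_prune_heads` hook above:

```python
from transformers import BertGenerationConfig, BertGenerationEncoder

config = BertGenerationConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64
)
model = BertGenerationEncoder(config)

# Prune heads 0 and 1 in layer 0, and head 2 in layer 1
model.prune_heads({0: [0, 1], 1: [2]})
print(model.encoder.layer[0].attention.self.num_attention_heads)  # 2
print(model.encoder.layer[1].attention.self.num_attention_heads)  # 3
```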
    """
    """
        # 为模型正向传播函数添加文档字符串
        @add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
        @add_code_sample_docstrings(
            checkpoint=_CHECKPOINT_FOR_DOC,
            output_type=BaseModelOutputWithPastAndCrossAttentions,
            config_class=_CONFIG_FOR_DOC,
        )
    ```
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        # Encoder states and mask are only used when the model acts as a cross-attending decoder
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        # Cached key/values from previous decoding steps, plus the flag controlling caching
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        # Forward pass (body omitted in this excerpt)


class BertGenerationOnlyLMHead(nn.Module):
    # Language-modeling head for BertGeneration: a single linear projection to the vocabulary
    def __init__(self, config):
        super().__init__()
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
        # Keep a separate learnable bias and share it with the decoder projection
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        # Project hidden states to vocabulary logits
        logits = self.decoder(hidden_states)
        return logits

    def _tie_weights(self):
        # Re-link the two weights if they get disconnected (on TPU or when the bias is resized)
        self.bias = self.decoder.bias


@add_start_docstrings(
    """BertGeneration Model with a `language modeling` head on top for CLM fine-tuning.""",
    BERT_GENERATION_START_DOCSTRING,
)
class BertGenerationDecoder(BertGenerationPreTrainedModel):
    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        # Warn when used standalone without decoder mode enabled
        if not config.is_decoder:
            logger.warning("If you want to use `BertGenerationDecoder` as a standalone, add `is_decoder=True.`")

        # Encoder backbone plus the language-modeling head defined above
        self.bert = BertGenerationEncoder(config)
        self.lm_head = BertGenerationOnlyLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings
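A usage sketch for the decoder, adapted from the example in the library documentation (downloads the `google/bert_for_seq_generation_L-24_bbc_encoder` checkpoint):

```python
import torch
from transformers import AutoTokenizer, BertGenerationConfig, BertGenerationDecoder

checkpoint = "google/bert_for_seq_generation_L-24_bbc_encoder"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
config = BertGenerationConfig.from_pretrained(checkpoint)
config.is_decoder = True  # avoids the standalone-decoder warning above
model = BertGenerationDecoder.from_pretrained(checkpoint, config=config)

inputs = tokenizer("Hello, my dog is cute", return_token_type_ids=False, return_tensors="pt")
outputs = model(**inputs)
prediction_logits = outputs.logits  # shape: (batch, sequence_length, vocab_size)
```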

    @add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        # Forward pass supporting cached decoding and LM labels (body omitted in this excerpt)

    # Prepare the inputs consumed at each step of generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape

        # Default to attending to every position when no mask is given
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # When a cache is present, only the tokens not covered by it need to be fed in
        if past_key_values is not None:
            # Sequence length already stored in the cache (dim 2 of the first key tensor)
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the new tokens; otherwise drop the cached prefix
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default behavior: keep only the final input ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
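A worked sketch of the cropping rule: with a cache covering 3 positions and 4 accumulated input IDs, only the newest ID is fed back in:

```python
import torch

input_ids = torch.tensor([[5, 6, 7, 8]])
past_length = 3  # sequence length already stored in past_key_values

if input_ids.shape[1] > past_length:
    remove_prefix_length = past_length
else:
    remove_prefix_length = input_ids.shape[1] - 1

print(input_ids[:, remove_prefix_length:])  # tensor([[8]])
```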

    # Reorder the cached key/values so each beam keeps the cache of the beam it came from
    def _reorder_cache(self, past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            # index_select along the batch dimension with the beam indices
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
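A worked sketch of the reordering: every cached tensor of shape `(batch * num_beams, heads, seq, head_dim)` is re-indexed along dim 0 by `beam_idx`:

```python
import torch

# One layer with a key and a value tensor for 3 beams; values encode the beam index
layer_past = tuple(torch.arange(3.0).view(3, 1, 1, 1) for _ in range(2))
past_key_values = (layer_past,)
beam_idx = torch.tensor([2, 2, 0])  # beams 0 and 1 both continue from old beam 2

reordered = tuple(
    tuple(state.index_select(0, beam_idx) for state in layer) for layer in past_key_values
)
print(reordered[0][0].flatten().tolist())  # [2.0, 2.0, 0.0]
```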

.\models\bert_generation\tokenization_bert_generation.py

# coding=utf-8
# (the license header of the original file is omitted in this excerpt)

# Standard library imports
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

# sentencepiece implements the underlying subword model
import sentencepiece as spm

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging

# Module-level logger
logger = logging.get_logger(__name__)

# Name of the vocabulary file expected on disk
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

# Map from pretrained model identifiers to their hosted vocabulary files
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "bert_for_seq_generation": (
            "https://huggingface.co/google/bert_for_seq_generation_L-24_bbc_encoder/resolve/main/spiece.model"
        ),
    }
}

# Maximum input lengths (position embedding sizes) of the pretrained models
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"bert_for_seq_generation": 512}

class BertGenerationTokenizer(PreTrainedTokenizer):
    """
    Construct a BertGeneration tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The begin of sequence token.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        sep_token (`str`, *optional*, defaults to `"<::::>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.
            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

    """
    # File names, URL maps and size limits consumed by the PreTrainedTokenizer machinery
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # Token IDs prepended to every sequence (empty for this tokenizer)
    prefix_tokens: List[int] = []
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        sep_token="<::::>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """Instantiate the tokenizer and load the SentencePiece model.

        Args:
            vocab_file (str): Path to the SentencePiece file containing the vocabulary.
            bos_token / eos_token / unk_token / pad_token / sep_token (str, optional):
                Special tokens, defaulting to "<s>", "</s>", "<unk>", "<pad>" and "<::::>".
            sp_model_kwargs (dict, optional): Keyword arguments forwarded to
                `SentencePieceProcessor.__init__()`, e.g. to enable subword regularization.
        """
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        self.vocab_file = vocab_file

        # Create the SentencePiece processor and load the vocabulary file into it
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        # Let the base class register the special tokens
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            sep_token=sep_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Size of the SentencePiece vocabulary."""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Return a dict mapping each token to its id, including added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab
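A minimal usage sketch (assumes network access to the Hub; the printed pieces are illustrative):

```python
from transformers import BertGenerationTokenizer

tokenizer = BertGenerationTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
tokens = tokenizer.tokenize("Hello world")        # SentencePiece pieces, e.g. ['▁Hello', '▁world']
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokenizer.decode(ids))                      # "Hello world"
```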

    def __getstate__(self):
        """Drop the (unpicklable) SentencePiece processor when serializing."""
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        """Recreate the SentencePiece processor when deserializing."""
        self.__dict__ = d

        # For backward compatibility with pickles that predate sp_model_kwargs
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a string into a list of SentencePiece pieces."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Convert a token (str) to an id using the SentencePiece model."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Convert an id (int) back to its token (str)."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Convert a sequence of tokens back into a single string."""
        current_sub_tokens = []
        out_string = ""
        for token in tokens:
            # Make sure special tokens are not decoded through the SentencePiece model
            if token in self.all_special_tokens:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string.strip()
    # Save the vocabulary to a directory, optionally prefixing the file name
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # The target must be an existing directory
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        # Destination path, with the optional file name prefix prepended
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # Copy the original model file when it exists and differs from the destination
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        # Otherwise serialize the in-memory SentencePiece model to the destination
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

.\models\bert_generation\__init__.py

# Type-checking helper
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_torch_available

# Import structure: maps submodules to the names they export, for lazy loading
_import_structure = {"configuration_bert_generation": ["BertGenerationConfig"]}

# The tokenizer is only exposed when sentencepiece is installed
try:
    if not is_sentencepiece_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_bert_generation"] = ["BertGenerationTokenizer"]

# The modeling classes are only exposed when torch is installed
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_bert_generation"] = [
        "BertGenerationDecoder",
        "BertGenerationEncoder",
        "BertGenerationPreTrainedModel",
        "load_tf_weights_in_bert_generation",
    ]

# Under static type checking, perform the real imports so type checkers see the symbols
if TYPE_CHECKING:
    from .configuration_bert_generation import BertGenerationConfig

    try:
        if not is_sentencepiece_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_bert_generation import BertGenerationTokenizer

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_bert_generation import (
            BertGenerationDecoder,
            BertGenerationEncoder,
            BertGenerationPreTrainedModel,
            load_tf_weights_in_bert_generation,
        )

# At runtime, replace this module with a _LazyModule that imports on first attribute access
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
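A simplified sketch of the lazy-import idea (this is not the actual `_LazyModule` implementation): a module object that resolves exported names to on-demand submodule imports:

```python
import importlib
import types

class DemoLazyModule(types.ModuleType):
    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported symbol to the submodule that defines it
        self._symbol_to_module = {
            symbol: module for module, symbols in import_structure.items() for symbol in symbols
        }

    def __getattr__(self, name):
        module_name = self._symbol_to_module.get(name)
        if module_name is None:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        # The heavy import only happens the first time the attribute is requested
        module = importlib.import_module("." + module_name, self.__name__)
        return getattr(module, name)
```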

.\models\bert_japanese\tokenization_bert_japanese.py

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""


import collections
import copy
import os
import unicodedata
from typing import Any, Dict, List, Optional, Tuple

from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...utils import is_sentencepiece_available, is_sudachi_projection_available, logging


# sentencepiece is optional; fall back to None when it is missing
if is_sentencepiece_available():
    import sentencepiece as spm
else:
    spm = None

logger = logging.get_logger(__name__)

# Expected vocabulary file names (wordpiece vocab and SentencePiece model)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "spm_file": "spiece.model"}

# Marker SentencePiece uses for word-initial pieces
SPIECE_UNDERLINE = "▁"

# Map from pretrained Japanese BERT checkpoints to their hosted vocabulary files
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "cl-tohoku/bert-base-japanese": "https://huggingface.co/cl-tohoku/bert-base-japanese/resolve/main/vocab.txt",
        "cl-tohoku/bert-base-japanese-whole-word-masking": (
            "https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/resolve/main/vocab.txt"
        ),
        "cl-tohoku/bert-base-japanese-char": (
            "https://huggingface.co/cl-tohoku/bert-base-japanese-char/resolve/main/vocab.txt"
        ),
        "cl-tohoku/bert-base-japanese-char-whole-word-masking": (
            "https://huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/resolve/main/vocab.txt"
        ),
    }
}

# Maximum input lengths (position embedding sizes) of the pretrained checkpoints
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "cl-tohoku/bert-base-japanese": 512,
    "cl-tohoku/bert-base-japanese-whole-word-masking": 512,
    "cl-tohoku/bert-base-japanese-char": 512,
    "cl-tohoku/bert-base-japanese-char-whole-word-masking": 512,
}

# Default tokenizer construction arguments for each pretrained checkpoint
PRETRAINED_INIT_CONFIGURATION = {
    "cl-tohoku/bert-base-japanese": {
        "do_lower_case": False,
        "word_tokenizer_type": "mecab",
        "subword_tokenizer_type": "wordpiece",
    },
    "cl-tohoku/bert-base-japanese-whole-word-masking": {
        "do_lower_case": False,
        "word_tokenizer_type": "mecab",
        "subword_tokenizer_type": "wordpiece",
    },
    "cl-tohoku/bert-base-japanese-char": {
        "do_lower_case": False,
        "word_tokenizer_type": "mecab",
        "subword_tokenizer_type": "character",
    },
    "cl-tohoku/bert-base-japanese-char-whole-word-masking": {
        "do_lower_case": False,
        "word_tokenizer_type": "mecab",
        "subword_tokenizer_type": "character",
    },
}


# Copied from transformers.models.bert.tokenization_bert.load_vocab
def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        # One token per line; the line number becomes the token id
        token = token.rstrip("\n")
        vocab[token] = index
    return vocab


# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens

class BertJapaneseTokenizer(PreTrainedTokenizer):
    r"""
    Construct a BERT tokenizer for Japanese text.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
    to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to a one-wordpiece-per-line vocabulary file.
        spm_file (`str`, *optional*):
            Path to [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm or .model
            extension) that contains the vocabulary.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether to lower case the input. Only has an effect when do_basic_tokenize=True.
        do_word_tokenize (`bool`, *optional*, defaults to `True`):
            Whether to do word tokenization.
        do_subword_tokenize (`bool`, *optional*, defaults to `True`):
            Whether to do subword tokenization.
        word_tokenizer_type (`str`, *optional*, defaults to `"basic"`):
            Type of word tokenizer. Choose from ["basic", "mecab", "sudachi", "jumanpp"].
        subword_tokenizer_type (`str`, *optional*, defaults to `"wordpiece"`):
            Type of subword tokenizer. Choose from ["wordpiece", "character", "sentencepiece",].
        mecab_kwargs (`dict`, *optional*):
            Dictionary passed to the `MecabTokenizer` constructor.
        sudachi_kwargs (`dict`, *optional*):
            Dictionary passed to the `SudachiTokenizer` constructor.
        jumanpp_kwargs (`dict`, *optional*):
            Dictionary passed to the `JumanppTokenizer` constructor.
    """

    # Class-level tables consumed by the PreTrainedTokenizer machinery
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # Constructor: wires together the word and subword tokenizers and the special tokens
    # (body omitted in this excerpt)
    def __init__(
        self,
        vocab_file,
        spm_file=None,
        do_lower_case=False,
        do_word_tokenize=True,
        do_subword_tokenize=True,
        word_tokenizer_type="basic",
        subword_tokenizer_type="wordpiece",
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        mecab_kwargs=None,
        sudachi_kwargs=None,
        jumanpp_kwargs=None,
        **kwargs,
    ):
    
    @property
    def do_lower_case(self):
        # Exposed for backward compatibility; backed by self.lower_case
        return self.lower_case

    # Serialization support: some word tokenizers hold unpicklable native objects
    def __getstate__(self):
        state = dict(self.__dict__)
        # mecab, sudachi and jumanpp tokenizers cannot be pickled; drop them here
        if self.word_tokenizer_type in ["mecab", "sudachi", "jumanpp"]:
            del state["word_tokenizer"]
        return state

    # Rebuild the word tokenizer from its type and kwargs when unpickling
    def __setstate__(self, state):
        self.__dict__ = state
        if self.word_tokenizer_type == "mecab":
            self.word_tokenizer = MecabTokenizer(
                do_lower_case=self.do_lower_case, never_split=self.never_split, **(self.mecab_kwargs or {})
            )
        elif self.word_tokenizer_type == "sudachi":
            self.word_tokenizer = SudachiTokenizer(
                do_lower_case=self.do_lower_case, never_split=self.never_split, **(self.sudachi_kwargs or {})
            )
        elif self.word_tokenizer_type == "jumanpp":
            self.word_tokenizer = JumanppTokenizer(
                do_lower_case=self.do_lower_case, never_split=self.never_split, **(self.jumanpp_kwargs or {})
            )

    # Two-stage tokenization: word segmentation, then subword splitting
    def _tokenize(self, text):
        if self.do_word_tokenize:
            # Word-level segmentation; special tokens are never split
            tokens = self.word_tokenizer.tokenize(text, never_split=self.all_special_tokens)
        else:
            tokens = [text]

        if self.do_subword_tokenize:
            # Split each word into subword units
            split_tokens = [sub_token for token in tokens for sub_token in self.subword_tokenizer.tokenize(token)]
        else:
            split_tokens = tokens

        return split_tokens

    @property
    def vocab_size(self):
        # sentencepiece keeps its own vocabulary; otherwise use the wordpiece vocab
        if self.subword_tokenizer_type == "sentencepiece":
            return len(self.subword_tokenizer.sp_model)
        return len(self.vocab)

    def get_vocab(self):
        if self.subword_tokenizer_type == "sentencepiece":
            # Build the token -> id map from the sentencepiece model
            vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
            vocab.update(self.added_tokens_encoder)
            return vocab
        # Otherwise merge the wordpiece vocab with any added tokens
        return dict(self.vocab, **self.added_tokens_encoder)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        if self.subword_tokenizer_type == "sentencepiece":
            return self.subword_tokenizer.sp_model.PieceToId(token)
        # Fall back to the unk token for out-of-vocabulary tokens
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        if self.subword_tokenizer_type == "sentencepiece":
            return self.subword_tokenizer.sp_model.IdToPiece(index)
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        if self.subword_tokenizer_type == "sentencepiece":
            return self.subword_tokenizer.sp_model.decode(tokens)
        # Join wordpieces, removing the " ##" continuation markers
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    # Build model inputs with special tokens from one or two lists of token IDs
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            # Single sequence: [CLS] X [SEP]
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        # Pair of sequences: [CLS] A [SEP] B [SEP]
        return cls + token_ids_0 + sep + token_ids_1 + sep
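A worked sketch of the two layouts (hypothetical IDs: `[CLS]`=2, `[SEP]`=3):

```python
cls_id, sep_id = 2, 3
a, b = [11, 12], [21, 22, 23]

single = [cls_id] + a + [sep_id]               # [2, 11, 12, 3]
pair = [cls_id] + a + [sep_id] + b + [sep_id]  # [2, 11, 12, 3, 21, 22, 23, 3]
print(single, pair)
```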

    # Mask marking which positions of a built sequence are special tokens
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            # If tokens already have special tokens, delegate to parent class method
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is not None:
            # For sequence pair: add special tokens around both sequences
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        # For single sequence: add special tokens around the sequence
        return [1] + ([0] * len(token_ids_0)) + [1]

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]  # Separator token ID
        cls = [self.cls_token_id]  # Classification token ID
        if token_ids_1 is None:
            # If no second sequence, return token type IDs for the first sequence only
            return len(cls + token_ids_0 + sep) * [0]
        # Return token type IDs for both sequences concatenated with special tokens
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
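A worked sketch of the arithmetic (hypothetical sequence lengths 2 and 3):

```python
len_a, len_b = 2, 3
token_type_ids = [0] * (1 + len_a + 1) + [1] * (len_b + 1)  # [CLS] A [SEP] | B [SEP]
print(token_type_ids)  # [0, 0, 0, 0, 1, 1, 1, 1]
```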
    # Save the vocabulary (wordpiece text file or SentencePiece model) to a directory
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if os.path.isdir(save_directory):
            if self.subword_tokenizer_type == "sentencepiece":
                # Destination path for the SentencePiece model file
                vocab_file = os.path.join(
                    save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["spm_file"]
                )
            else:
                # Destination path for the plain-text wordpiece vocabulary
                vocab_file = os.path.join(
                    save_directory,
                    (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
                )
        else:
            # If the target is not a directory, treat it directly as the output file name
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory

        if self.subword_tokenizer_type == "sentencepiece":
            # Serialize the in-memory SentencePiece model as binary
            with open(vocab_file, "wb") as writer:
                content_spiece_model = self.subword_tokenizer.sp_model.serialized_model_proto()
                writer.write(content_spiece_model)
        else:
            # Write one token per line, ordered by index, warning on non-consecutive indices
            with open(vocab_file, "w", encoding="utf-8") as writer:
                index = 0
                for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                    if index != token_index:
                        logger.warning(
                            f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                            " Please check that the vocabulary is not corrupted!"
                        )
                        index = token_index
                    writer.write(token + "\n")
                    index += 1
        return (vocab_file,)

class MecabTokenizer:
    """Runs basic tokenization with MeCab morphological parser."""

    # Constructor (body omitted in this excerpt); configures lower-casing, normalization,
    # the MeCab dictionary ("ipadic" by default) and extra MeCab options
    def __init__(
        self,
        do_lower_case=False,
        never_split=None,
        normalize_text=True,
        mecab_dic: Optional[str] = "ipadic",
        mecab_option: Optional[str] = None,
    ):

    # Segment text into surface-form tokens using MeCab
    def tokenize(self, text, never_split=None, **kwargs):
        """Tokenizes a piece of text."""
        # Apply NFKC unicode normalization when enabled
        if self.normalize_text:
            text = unicodedata.normalize("NFKC", text)

        # Merge the instance-level and call-level never_split lists
        never_split = self.never_split + (never_split if never_split is not None else [])
        tokens = []

        for word in self.mecab(text):
            # Surface form of the current morpheme
            token = word.surface

            # Lowercase unless the token is protected by never_split
            if self.do_lower_case and token not in never_split:
                token = token.lower()

            tokens.append(token)

        return tokens
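A usage sketch (requires the optional `fugashi` and `ipadic` packages; the segmentation shown is illustrative):

```python
from transformers.models.bert_japanese.tokenization_bert_japanese import MecabTokenizer

tokenizer = MecabTokenizer(do_lower_case=False, mecab_dic="ipadic")
print(tokenizer.tokenize("こんにちは、世界。"))  # e.g. ['こんにちは', '、', '世界', '。']
```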


class SudachiTokenizer:
    """Runs basic tokenization with Sudachi morphological parser."""

    def __init__(
        self,
        do_lower_case=False,
        never_split=None,
        normalize_text=True,
        trim_whitespace=False,
        sudachi_split_mode="A",
        sudachi_config_path=None,
        sudachi_resource_dir=None,
        sudachi_dict_type="core",
        sudachi_projection=None,
    ):
        """
        Constructs a SudachiTokenizer.

        Args:
            **do_lower_case**: (*optional*) boolean (default False)
                Whether to lowercase the input.
            **never_split**: (*optional*) list of str
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
            **normalize_text**: (*optional*) boolean (default True)
                Whether to apply unicode normalization to text before tokenization.
            **trim_whitespace**: (*optional*) boolean (default False)
                Whether to trim all whitespace, tab, newline from tokens.
            **sudachi_split_mode**: (*optional*) string
                Split mode of sudachi, choose from `["A", "B", "C"]`.
            **sudachi_config_path**: (*optional*) string
                Path to Sudachi configuration file.
            **sudachi_resource_dir**: (*optional*) string
                Directory containing Sudachi resources.
            **sudachi_dict_type**: (*optional*) string
                Dictionary type of Sudachi, choose from `["small", "core", "full"]`.
            **sudachi_projection**: (*optional*) string
                Word projection mode of Sudachi, choose from `["surface", "normalized", "reading", "dictionary", "dictionary_and_surface", "normalized_and_surface", "normalized_nouns"]`.
        """

        self.do_lower_case = do_lower_case
        self.never_split = never_split if never_split is not None else []
        self.normalize_text = normalize_text
        self.trim_whitespace = trim_whitespace

        try:
            from sudachipy import dictionary, tokenizer
        except ImportError:
            raise ImportError(
                "You need to install sudachipy to use SudachiTokenizer. "
                "See https://github.com/WorksApplications/SudachiPy for installation."
            )

        # Map the string split mode onto Sudachi's enum
        if sudachi_split_mode == "A":
            self.split_mode = tokenizer.Tokenizer.SplitMode.A
        elif sudachi_split_mode == "B":
            self.split_mode = tokenizer.Tokenizer.SplitMode.B
        elif sudachi_split_mode == "C":
            self.split_mode = tokenizer.Tokenizer.SplitMode.C
        else:
            raise ValueError("Invalid sudachi_split_mode is specified.")

        self.projection = sudachi_projection

        # Build the Sudachi dictionary from the configured paths and dictionary type
        sudachi_dictionary = dictionary.Dictionary(
            config_path=sudachi_config_path, resource_dir=sudachi_resource_dir, dict=sudachi_dict_type
        )

        # `projection` requires sudachipy>=0.6.8; fail early when it is requested but unavailable
        if is_sudachi_projection_available():
            self.sudachi = sudachi_dictionary.create(self.split_mode, projection=self.projection)
        elif self.projection is not None:
            raise ImportError("You need to install sudachipy>=0.6.8 to specify `projection` field in sudachi_kwargs.")
        else:
            self.sudachi = sudachi_dictionary.create(self.split_mode)
    # Tokenizes a piece of text based on the specified tokenizer settings.
    def tokenize(self, text, never_split=None, **kwargs):
        """Tokenizes a piece of text."""
        # Normalize text if enabled to ensure consistent representation
        if self.normalize_text:
            text = unicodedata.normalize("NFKC", text)

        # Combine the default never_split tokens with any user-provided ones
        never_split = self.never_split + (never_split if never_split is not None else [])
        # Initialize an empty list to store tokens
        tokens = []

        # Iterate over tokens returned by the sudachi tokenizer
        for word in self.sudachi.tokenize(text):
            # Retrieve the surface form (actual text) of the token
            token = word.surface()

            # Convert token to lowercase if specified and not in the never_split list
            if self.do_lower_case and token not in never_split:
                token = token.lower()

            # Trim whitespace from tokens if specified
            if self.trim_whitespace:
                # Skip tokens that are completely whitespace
                if token.strip() == "":
                    continue
                else:
                    # Remove leading and trailing whitespace
                    token = token.strip()

            # Add processed token to the list of tokens
            tokens.append(token)

        # Return the list of tokens
        return tokens
class JumanppTokenizer:
    """Runs basic tokenization with jumanpp morphological parser."""

    def __init__(
        self,
        do_lower_case=False,
        never_split=None,
        normalize_text=True,
        trim_whitespace=False,
    ):
        """
        Constructs a JumanppTokenizer.

        Args:
            **do_lower_case**: (*optional*) boolean (default False)
                Whether to lowercase the input.
            **never_split**: (*optional*) list of str
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
            **normalize_text**: (*optional*) boolean (default True)
                Whether to apply unicode normalization to text before tokenization.
            **trim_whitespace**: (*optional*) boolean (default False)
                Whether to trim all whitespace, tab, newline from tokens.
        """

        self.do_lower_case = do_lower_case
        self.never_split = never_split if never_split is not None else []
        self.normalize_text = normalize_text
        self.trim_whitespace = trim_whitespace

        try:
            import rhoknp
        except ImportError:
            raise ImportError(
                "You need to install rhoknp to use JumanppTokenizer. "
                "See https://github.com/ku-nlp/rhoknp for installation."
            )

        # Handle to the Juman++ morphological parser
        self.juman = rhoknp.Jumanpp()

    def tokenize(self, text, never_split=None, **kwargs):
        """Tokenizes a piece of text."""
        if self.normalize_text:
            text = unicodedata.normalize("NFKC", text)  # apply Unicode normalization if enabled

        text = text.strip()  # strip leading and trailing whitespace from the text

        never_split = self.never_split + (never_split if never_split is not None else [])  # merge the instance-level and caller-provided never_split lists
        tokens = []

        for mrph in self.juman.apply_to_sentence(text).morphemes:
            token = mrph.text  # surface text of each morpheme in the analysis

            if self.do_lower_case and token not in never_split:
                token = token.lower()  # lowercase the token unless it is protected by never_split

            if self.trim_whitespace:
                if token.strip() == "":
                    continue  # skip tokens made up entirely of whitespace
                else:
                    token = token.strip()  # strip leading and trailing whitespace from the token

            tokens.append(token)  # add the processed token to the output list

        return tokens
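
    # Usage sketch (annotation, not part of the original file), assuming the
    # rhoknp package and a local Juman++ binary are installed:
    #
    #     tokenizer = JumanppTokenizer(trim_whitespace=True)
    #     tokens = tokenizer.tokenize("外国人参政権")
    #     # e.g. ['外国', '人', '参政', '権'] (analyzer-dependent)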


class CharacterTokenizer:
    """Runs Character tokenization."""

    def __init__(self, vocab, unk_token, normalize_text=True):
        """
        Constructs a CharacterTokenizer.

        Args:
            **vocab**:
                Vocabulary object.
            **unk_token**: str
                A special symbol for out-of-vocabulary token.
            **normalize_text**: (`optional`) boolean (default True)
                Whether to apply unicode normalization to text before tokenization.
        """
        self.vocab = vocab  # the vocabulary object
        self.unk_token = unk_token  # special symbol used for out-of-vocabulary characters
        self.normalize_text = normalize_text  # whether to apply Unicode normalization before tokenizing; defaults to True

    def tokenize(self, text):
        """
        Tokenizes a piece of text into a list of characters.

        For example, `input = "apple"` will return `["a", "p", "p", "l", "e"]`.

        Args:
            text: A single token or whitespace separated tokens.
                This should have already been passed through *BasicTokenizer*.

        Returns:
            A list of characters.
        """
        # Apply NFKC Unicode normalization if enabled
        if self.normalize_text:
            text = unicodedata.normalize("NFKC", text)

        # Collect the output tokens here
        output_tokens = []
        # Iterate over each character in the input text
        for char in text:
            # Substitute the unknown token for characters missing from the vocabulary
            if char not in self.vocab:
                output_tokens.append(self.unk_token)
                continue

            # Otherwise keep the character itself
            output_tokens.append(char)

        # Return the final list of characters
        return output_tokens
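
# Usage sketch (annotation, not part of the original file): the vocab only needs
# membership checks, so a plain set stands in for the real vocabulary dict here.
#
#     tok = CharacterTokenizer(vocab={"a", "p", "l", "e"}, unk_token="[UNK]")
#     tok.tokenize("apple")  # ['a', 'p', 'p', 'l', 'e']
#     tok.tokenize("apex")   # ['a', 'p', 'e', '[UNK]']
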
# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
class BasicTokenizer(object):
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
            the full context of the words, such as contractions.
    """

    def __init__(
        self,
        do_lower_case=True,                    # whether to lowercase the input; defaults to True
        never_split=None,                      # tokens that must never be split; defaults to None
        tokenize_chinese_chars=True,           # whether to tokenize Chinese characters; defaults to True
        strip_accents=None,                    # whether to strip accents; by default follows the lowercasing option
        do_split_on_punc=True,                 # whether to split on punctuation; defaults to True
    ):
        if never_split is None:
            never_split = []                   # treat a missing never_split argument as an empty list
        self.do_lower_case = do_lower_case     # store the lowercasing option
        self.never_split = set(never_split)    # store never_split as a set for fast membership checks
        self.tokenize_chinese_chars = tokenize_chinese_chars  # store the CJK tokenization option
        self.strip_accents = strip_accents     # store the accent-stripping option
        self.do_split_on_punc = do_split_on_punc  # store the punctuation-splitting option

    # Basic tokenization of a piece of text; for sub-word tokenization, see WordpieceTokenizer
    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*)
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of token not to split.
        """
        # If a never_split list is given, union it with the instance-level set so those tokens stay intact
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        # Remove invalid characters and clean up whitespace
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. It is now applied to the English models as well, but since the
        # English models were not trained on any Chinese data and generally do not
        # contain Chinese data (English Wikipedia does contain some Chinese words),
        # it does not matter in practice.
        if self.tokenize_chinese_chars:
            # Surround CJK characters with whitespace when Chinese-character tokenization is enabled
            text = self._tokenize_chinese_chars(text)
        # Unicode-normalize the text so that identical characters share a single representation
        unicode_normalized_text = unicodedata.normalize("NFC", text)
        # Split the normalized text on whitespace to obtain the initial tokens
        orig_tokens = whitespace_tokenize(unicode_normalized_text)
        # Collect the split tokens here
        split_tokens = []
        # Process each original token
        for token in orig_tokens:
            # Lowercasing and accent stripping only apply to tokens not protected by never_split
            if token not in never_split:
                if self.do_lower_case:
                    # Lowercase the token
                    token = token.lower()
                    if self.strip_accents is not False:
                        # Strip accents unless explicitly disabled
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    # Strip accents when requested without lowercasing
                    token = self._run_strip_accents(token)
            # Split the token on punctuation and extend the result list
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        # Re-split the joined tokens on whitespace to obtain the final output tokens
        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    # Remove accents from text
    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        # Decompose the text into NFD form so combining marks become separate code points
        text = unicodedata.normalize("NFD", text)
        output = []
        # Inspect every character of the decomposed text
        for char in text:
            # Look up the character's Unicode category
            cat = unicodedata.category(char)
            # Skip nonspacing marks ("Mn"), i.e. the combining accents
            if cat == "Mn":
                continue
            # Keep all other characters
            output.append(char)
        # Join the remaining characters back into a string
        return "".join(output)
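
    # Worked example (annotation): NFD decomposition splits "é" into "e" plus a
    # combining acute accent (Unicode category "Mn"), which the loop above drops:
    #
    #     self._run_strip_accents("héllo")  # -> "hello"
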
    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        # Return the text unchanged when punctuation splitting is disabled or the text is protected by never_split
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            # A punctuation character becomes its own list and forces the next character to start a new word
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                # Otherwise append the character to the current word, opening a new list if needed
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        # Join each sub-list of characters back into a string
        return ["".join(x) for x in output]
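
    # Worked example (annotation): punctuation becomes standalone tokens and the
    # surrounding characters are regrouped into words:
    #
    #     self._run_split_on_punc("hello!world")  # -> ['hello', '!', 'world']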

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            # Pad CJK characters with a space on both sides
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                # Copy non-CJK characters through unchanged
                output.append(char)
        # Join the output list into a single string
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # Check whether the given code point falls within the CJK character ranges
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            # Skip invalid characters and control characters
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            # Replace whitespace characters with a single space; keep everything else
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        # Join the cleaned characters into a single string
        return "".join(output)
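
# Usage sketch (annotation, not part of the original file):
#
#     basic = BasicTokenizer(do_lower_case=True)
#     basic.tokenize("Hello, World!")  # ['hello', ',', 'world', '!']
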
# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        # Store the vocabulary, unknown token, and per-word character limit
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens.
        """
        # Collect the output wordpiece tokens here
        output_tokens = []
        # Iterate over the whitespace-tokenized input
        for token in whitespace_tokenize(text):
            # Split the token into its characters
            chars = list(token)
            # Replace overly long tokens with the unknown token
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            # Greedily match the longest vocabulary entry at each position
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    # Accept the substring once it appears in the vocabulary
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            # Fall back to the unknown token if any part of the word could not be matched
            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
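
    # Usage sketch (annotation, not part of the original file): membership checks
    # are all the algorithm needs, so a set stands in for the real vocabulary.
    #
    #     wp = WordpieceTokenizer(vocab={"un", "##aff", "##able"}, unk_token="[UNK]")
    #     wp.tokenize("unaffable")  # ['un', '##aff', '##able']
    #     wp.tokenize("xyzzy")      # ['[UNK]']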


class SentencepieceTokenizer(object):
    """
    Runs sentencepiece tokenization. Based on transformers.models.albert.tokenization_albert.AlbertTokenizer.
    """

    def __init__(
        self,
        vocab,
        unk_token,
        do_lower_case=False,
        remove_space=True,
        keep_accents=True,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Initialize the SentencepieceTokenizer with a vocabulary file, unknown token, and optional settings
        self.vocab = vocab
        self.unk_token = unk_token
        self.do_lower_case = do_lower_case
        self.remove_space = remove_space
        self.keep_accents = keep_accents

        # Fall back to an empty dict when no SentencePiece kwargs are given
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        # Create the SentencePieceProcessor and load the vocabulary file
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab)

    # Preprocess text: optionally collapse whitespace, normalize quote styles,
    # strip accents, and lowercase
    def preprocess_text(self, inputs):
        if self.remove_space:
            # Strip surrounding whitespace and collapse internal runs into single spaces
            outputs = " ".join(inputs.strip().split())
        else:
            # Otherwise keep the input text as-is
            outputs = inputs
        # Normalize LaTeX-style double quotes to standard double quotes
        outputs = outputs.replace("``", '"').replace("''", '"')

        if not self.keep_accents:
            # Decompose to NFKD and drop combining characters to strip accents
            outputs = unicodedata.normalize("NFKD", outputs)
            outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
        if self.do_lower_case:
            # Lowercase the text if requested
            outputs = outputs.lower()

        return outputs

    # Tokenize text with the SentencePiece model
    def tokenize(self, text):
        """
        Tokenizes text by sentencepiece. Based on [SentencePiece](https://github.com/google/sentencepiece).
        Tokenization needs the given vocabulary.

        Args:
            text: A string needs to be tokenized.

        Returns:
            A list of sentencepiece tokens.
        """
        # Preprocess the input text first
        text = self.preprocess_text(text)
        # Encode the text with the SentencePiece model, returning string pieces
        pieces = self.sp_model.encode(text, out_type=str)
        new_pieces = []
        for piece in pieces:
            # Special-case pieces ending in a comma preceded by a digit (e.g. "9,")
            if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
                # Re-encode the piece without the trailing comma and the SentencePiece underline
                cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
                # Fix up a spurious leading underline on the re-encoded pieces
                if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                    if len(cur_pieces[0]) == 1:
                        cur_pieces = cur_pieces[1:]
                    else:
                        cur_pieces[0] = cur_pieces[0][1:]
                # Re-attach the comma as its own piece
                cur_pieces.append(piece[-1])
                new_pieces.extend(cur_pieces)
            else:
                # Ordinary pieces are kept unchanged
                new_pieces.append(piece)

        return new_pieces
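
    # Usage sketch (annotation, not part of the original file), assuming
    # "spiece.model" is a trained SentencePiece model file (hypothetical path):
    #
    #     sp_tok = SentencepieceTokenizer(vocab="spiece.model", unk_token="<unk>")
    #     sp_tok.tokenize("Hello world")  # e.g. ['▁Hello', '▁world'] (model-dependent)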

.\models\bert_japanese\__init__.py

# Import the TYPE_CHECKING helper
from typing import TYPE_CHECKING

# Import _LazyModule for lazy module loading
from ...utils import _LazyModule

# Declare the import structure: the public classes exposed by the tokenization_bert_japanese module
_import_structure = {"tokenization_bert_japanese": ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"]}

# During static type checking
if TYPE_CHECKING:
    # Import the concrete classes so type checkers can resolve them
    from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer

# At runtime
else:
    # Import sys so the module entry can be replaced
    import sys

    # Replace this module with a _LazyModule so submodules are only loaded on first access
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\bigbird_pegasus\configuration_bigbird_pegasus.py

# coding=utf-8
# Copyright Google Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BigBirdPegasus model configuration"""

# Import OrderedDict and typing helpers
from collections import OrderedDict
from typing import Any, Mapping, Optional

# Import PreTrainedTokenizer from the package root
from ... import PreTrainedTokenizer

# Import the PretrainedConfig base class
from ...configuration_utils import PretrainedConfig

# Import the ONNX configuration classes
from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast

# Import the compute_effective_axis_dimension helper
from ...onnx.utils import compute_effective_axis_dimension

# Import utility functions and classes
from ...utils import TensorType, is_torch_available, logging

# Module-level logger
logger = logging.get_logger(__name__)

# Map from pretrained BigBirdPegasus checkpoints to their config file URLs
BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/bigbird-pegasus-large-arxiv": (
        "https://huggingface.co/google/bigbird-pegasus-large-arxiv/resolve/main/config.json"
    ),
    "google/bigbird-pegasus-large-pubmed": (
        "https://huggingface.co/google/bigbird-pegasus-large-pubmed/resolve/main/config.json"
    ),
    "google/bigbird-pegasus-large-bigpatent": (
        "https://huggingface.co/google/bigbird-pegasus-large-bigpatent/resolve/main/config.json"
    ),
    # See all BigBirdPegasus models at https://huggingface.co/models?filter=bigbird_pegasus
}

class BigBirdPegasusConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BigBirdPegasusModel`]. It is used to
    instantiate a BigBirdPegasus model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a configuration similar to that of the BigBirdPegasus
    [google/bigbird-pegasus-large-arxiv](https://huggingface.co/google/bigbird-pegasus-large-arxiv) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Example:

    ```
    >>> from transformers import BigBirdPegasusConfig, BigBirdPegasusModel

    >>> # Initializing a BigBirdPegasus bigbird-pegasus-base style configuration
    >>> configuration = BigBirdPegasusConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = BigBirdPegasusModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "bigbird_pegasus"

    # Keys to ignore at inference time
    keys_to_ignore_at_inference = ["past_key_values"]
    # Map common attribute names onto the names this configuration actually uses
    attribute_map = {
        "num_attention_heads": "encoder_attention_heads",
        "hidden_size": "d_model",
        "attention_probs_dropout_prob": "attention_dropout",
    }

    # Build a new configuration object
    def __init__(
        self,
        vocab_size=96103,  # vocabulary size
        max_position_embeddings=4096,  # maximum number of position embeddings
        encoder_layers=16,  # number of encoder layers
        encoder_ffn_dim=4096,  # dimension of the encoder feed-forward layers
        encoder_attention_heads=16,  # number of encoder attention heads
        decoder_layers=16,  # number of decoder layers
        decoder_ffn_dim=4096,  # dimension of the decoder feed-forward layers
        decoder_attention_heads=16,  # number of decoder attention heads
        encoder_layerdrop=0.0,  # encoder LayerDrop probability
        decoder_layerdrop=0.0,  # decoder LayerDrop probability
        use_cache=True,  # whether to use the key/value cache
        is_encoder_decoder=True,  # whether this is an encoder-decoder model
        activation_function="gelu_new",  # activation function
        d_model=1024,  # model (hidden) dimension
        dropout=0.1,  # global dropout probability
        attention_dropout=0.0,  # dropout probability on attention weights
        activation_dropout=0.0,  # dropout probability inside the activations
        init_std=0.02,  # standard deviation for weight initialization
        decoder_start_token_id=2,  # id of the decoder start token
        classifier_dropout=0.0,  # classifier dropout probability
        scale_embedding=True,  # whether to scale embeddings
        pad_token_id=0,  # id of the padding token
        bos_token_id=2,  # id of the beginning-of-sequence token
        eos_token_id=1,  # id of the end-of-sequence token
        attention_type="block_sparse",  # attention type, only used in the encoder
        block_size=64,  # block size for block-sparse attention
        num_random_blocks=3,  # number of random blocks per row
        use_bias=False,  # whether the attention projections use a bias
        **kwargs,  # additional keyword arguments
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers  # expose the encoder layer count as num_hidden_layers
        self.scale_embedding = scale_embedding  # if True, embeddings are scaled by sqrt(d_model)

        # Extra BigBird-specific configuration parameters
        self.attention_type = attention_type
        self.block_size = block_size
        self.num_random_blocks = num_random_blocks
        self.use_bias = use_bias

        # Call the parent initializer with the special-token and encoder-decoder settings
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            **kwargs,
        )

# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig
class BigBirdPegasusOnnxConfig(OnnxSeq2SeqConfigWithPast):
    # The inputs property returns the mapping of model input names to dynamic axes
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Configure the common inputs according to the task
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )

            # With past key values, the decoder inputs change shape
            if self.use_past:
                common_inputs["decoder_input_ids"] = {0: "batch"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
            else:
                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}

            # Register the past key/value inputs when they are used
            if self.use_past:
                self.fill_with_past_key_values_(common_inputs, direction="inputs")
        elif self.task == "causal-lm":
            # Causal language modeling case (still marked as a TODO upstream)
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )
            # With past key values, add one past key/value pair per encoder layer
            if self.use_past:
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        else:
            # Default case: expose both encoder and decoder inputs
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
                ]
            )

        return common_inputs

    # The outputs property returns the mapping of model output names to dynamic axes
    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        # Configure the common outputs according to the task
        if self.task in ["default", "seq2seq-lm"]:
            common_outputs = super().outputs
        else:
            common_outputs = super(OnnxConfigWithPast, self).outputs
            # With past key values, add one present key/value pair per encoder layer
            if self.use_past:
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        return common_outputs

    # Generate dummy inputs for the "default" and "seq2seq-lm" tasks
    def _generate_dummy_inputs_for_default_and_seq2seq_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # Generate encoder inputs using dummy data for sequence classification and question answering
        encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        # Determine decoder sequence length based on whether past information is used
        decoder_seq_length = seq_length if not self.use_past else 1
        
        # Generate decoder inputs using dummy data, adjusted for sequence length and pairing
        decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, decoder_seq_length, is_pair, framework
        )
        
        # Prefix decoder input names and create a dictionary for decoder inputs
        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
        
        # Combine encoder and decoder inputs into a common inputs dictionary
        common_inputs = dict(**encoder_inputs, **decoder_inputs)

        # Handle the case where past information is used
        if self.use_past:
            # Check if PyTorch is available; if not, raise an error
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            
            # Extract batch size and encoder sequence length from common inputs
            batch, encoder_seq_length = common_inputs["input_ids"].shape
            
            # Determine decoder sequence length and attention heads from model configuration
            decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
            num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
            
            # Define shapes for encoder and decoder past key values
            encoder_shape = (
                batch,
                num_encoder_attention_heads,
                encoder_seq_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )
            decoder_past_length = decoder_seq_length + 3
            decoder_shape = (
                batch,
                num_decoder_attention_heads,
                decoder_past_length,
                self._config.hidden_size // num_decoder_attention_heads,
            )

            # Expand decoder attention mask to accommodate past information
            common_inputs["decoder_attention_mask"] = torch.cat(
                [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
            )

            # Initialize past key values list for storing past states
            common_inputs["past_key_values"] = []

            # Determine the shared and extra layer counts between encoder and decoder
            num_encoder_layers, num_decoder_layers = self.num_layers
            min_num_layers = min(num_encoder_layers, num_decoder_layers)
            max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers

            # Determine the remaining side (encoder or decoder) for past key values
            remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"

            # Populate past key values with zero-initialized tensors for each layer
            for _ in range(min_num_layers):
                common_inputs["past_key_values"].append(
                    (
                        torch.zeros(decoder_shape),
                        torch.zeros(decoder_shape),
                        torch.zeros(encoder_shape),
                        torch.zeros(encoder_shape),
                    )
                )

            # TODO: test this.
            # Extend past key values with zero-initialized tensors for additional layers
            shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
            for _ in range(min_num_layers, max_num_layers):
                common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
        
        # Return the finalized common inputs dictionary
        return common_inputs

    def _generate_dummy_inputs_for_causal_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # Reuse the dummy-input generator for sequence classification / question answering
        common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        # Handle the case where past key values are used
        if self.use_past:
            # PyTorch is required to build the dummy past tensors
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch

            # Read the batch size and sequence length from the dummy input_ids
            batch, seqlen = common_inputs["input_ids"].shape

            # Make the past sequence two tokens longer than the input
            past_key_values_length = seqlen + 2

            # Number of encoder layers and attention heads from the model configuration
            num_encoder_layers, _ = self.num_layers
            num_encoder_attention_heads, _ = self.num_attention_heads

            # Shape of each past key/value tensor
            past_shape = (
                batch,
                num_encoder_attention_heads,
                past_key_values_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )

            # Match the dtype of the existing attention mask
            mask_dtype = common_inputs["attention_mask"].dtype

            # Extend the attention mask so it also covers the past positions
            common_inputs["attention_mask"] = torch.cat(
                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
            )

            # Zero-initialized past key/value placeholders, one pair per layer
            common_inputs["past_key_values"] = [
                (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
            ]

        # Return the assembled dummy inputs
        return common_inputs

    def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # Copied from OnnxConfig.generate_dummy_inputs; super(OnnxConfigWithPast, self).generate_dummy_inputs
        # was deliberately not used, for code clarity.
        # If the dynamic axis is -1, export with a fixed batch dimension of 2 samples to avoid ONNX optimizations
        batch_size = compute_effective_axis_dimension(
            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
        )

        # If the dynamic axis is -1, export with a fixed sequence length of 8 tokens to avoid ONNX optimizations
        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
        seq_length = compute_effective_axis_dimension(
            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
        )

        # Generate the dummy inputs from the effective batch size and sequence length
        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
        common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
        return common_inputs

    # Generate dummy inputs, dispatching on the task; returns a dict of common inputs
    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # For the "default" and "seq2seq-lm" tasks, use the encoder-decoder generator
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )
        # For the "causal-lm" task, use the decoder-only generator
        elif self.task == "causal-lm":
            common_inputs = self._generate_dummy_inputs_for_causal_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )
        # All remaining tasks use the sequence classification / question answering generator
        else:
            common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )

        # Return the generated dummy inputs
        return common_inputs

    # Flatten past key values, choosing the parent implementation by task
    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
        # For "default" and "seq2seq-lm", defer to the seq2seq-with-past parent class
        if self.task in ["default", "seq2seq-lm"]:
            flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
        # Otherwise defer to the plain with-past parent implementation
        else:
            flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
                flattened_output, name, idx, t
            )

.\models\bigbird_pegasus\convert_bigbird_pegasus_tf_to_pytorch.py

# Import the required libraries and modules
import argparse  # command-line argument parsing
from typing import Dict  # Dict type hint

import tensorflow as tf  # TensorFlow, used to read the checkpoint
import torch  # PyTorch, used to build the converted model
from tqdm import tqdm  # progress bars

from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration  # Transformers classes

# Name-mapping patterns (TensorFlow -> Hugging Face) applied to the start of every key
INIT_COMMON = [
    ("/", "."),
    ("layer_", "layers."),
    ("kernel", "weight"),
    ("beta", "bias"),
    ("gamma", "weight"),
    ("pegasus", "model"),
]

# Name-mapping patterns (TensorFlow -> Hugging Face) applied to the end of every key
END_COMMON = [
    (".output.dense", ".fc2"),
    ("intermediate.LayerNorm", "final_layer_norm"),
    ("intermediate.dense", "fc1"),
]

# Name-mapping patterns for decoder weights: INIT_COMMON, then decoder-specific pairs, then END_COMMON
DECODER_PATTERNS = (
    INIT_COMMON
    + [
        ("attention.self.LayerNorm", "self_attn_layer_norm"),
        ("attention.output.dense", "self_attn.out_proj"),
        ("attention.self", "self_attn"),
        ("attention.encdec.LayerNorm", "encoder_attn_layer_norm"),
        ("attention.encdec_output.dense", "encoder_attn.out_proj"),
        ("attention.encdec", "encoder_attn"),
        ("key", "k_proj"),
        ("value", "v_proj"),
        ("query", "q_proj"),
        ("decoder.LayerNorm", "decoder.layernorm_embedding"),
    ]
    + END_COMMON
)

# Name-mapping patterns for all remaining (non-decoder) weights
REMAINING_PATTERNS = (
    INIT_COMMON
    + [
        ("embeddings.word_embeddings", "shared.weight"),
        ("embeddings.position_embeddings", "embed_positions.weight"),
        ("attention.self.LayerNorm", "self_attn_layer_norm"),
        ("attention.output.dense", "self_attn.output"),
        ("attention.self", "self_attn.self"),
        ("encoder.LayerNorm", "encoder.layernorm_embedding"),
    ]
    + END_COMMON
)

# Keys that are skipped entirely during conversion
KEYS_TO_IGNORE = [
    "encdec/key/bias",
    "encdec/query/bias",
    "encdec/value/bias",
    "self/key/bias",
    "self/query/bias",
    "self/value/bias",
    "encdec_output/dense/bias",
    "attention/output/dense/bias",
]

def rename_state_dict_key(k, patterns):
    # Apply each (tf_name, hf_name) substitution in order to the key k
    for tf_name, hf_name in patterns:
        k = k.replace(tf_name, hf_name)
    return k
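
# Worked example (annotation, not part of the original file): the substitutions
# are applied in order, rewriting a TF checkpoint key step by step.
#
#     k = "pegasus/decoder/layer_0/attention/self/query/kernel"
#     rename_state_dict_key(k, DECODER_PATTERNS)
#     # -> "model.decoder.layers.0.self_attn.q_proj.weight"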

def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration:
    # Build a BigBirdPegasusConfig from the supplied overrides
    cfg = BigBirdPegasusConfig(**config_update)
    # Instantiate the PyTorch model from that config
    torch_model = BigBirdPegasusForConditionalGeneration(cfg)
    # Use the freshly initialized state dict as the shape reference
    state_dict = torch_model.state_dict()
    # Mapping from converted key names to converted tensors
    mapping = {}

    # Separate the decoder weights ...
    decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")}
    # ... from everything else
    remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")}
    # Convert the decoder weights, showing a "tf -> hf conversion" progress bar
    for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"):
        # Skip keys whose suffix appears in KEYS_TO_IGNORE
        conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
        if any(conditions):
            continue
        # Rename the key k according to the decoder patterns
        patterns = DECODER_PATTERNS
        new_k = rename_state_dict_key(k, patterns)
        # The renamed key must exist in the PyTorch state dict
        if new_k not in state_dict:
            raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
        # Transpose dense/query/key/value matrices, which TF stores transposed relative to PyTorch
        if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
            v = v.T
        # Store the converted tensor under the new key
        mapping[new_k] = torch.from_numpy(v)
        # Sanity-check that the converted tensor matches the expected shape
        assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"

    # Convert the remaining weights, showing a "tf -> hf conversion" progress bar
    for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"):
        # Skip keys whose suffix appears in KEYS_TO_IGNORE
        conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
        if any(conditions):
            continue
        # Rename the key k according to the remaining patterns
        patterns = REMAINING_PATTERNS
        new_k = rename_state_dict_key(k, patterns)
        # The renamed key must exist in the state dict, except for the shared position embeddings
        if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
            raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
        # Transpose dense/query/key/value matrices, which TF stores transposed relative to PyTorch
        if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
            v = v.T
        # Store the converted tensor under the new key
        mapping[new_k] = torch.from_numpy(v)
        # Sanity-check the shape, again excluding the shared position embeddings
        if k != "pegasus/embeddings/position_embeddings":
            assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"

    # Copy the shared position embeddings to the encoder ...
    mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"]
    # ... and move the original entry to the decoder
    mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight")
    # Load the converted mapping into the model, tolerating missing keys
    missing, extra = torch_model.load_state_dict(mapping, strict=False)
    # Any missing key outside the expected tied/derived weights is an error
    unexpected_missing = [
        k
        for k in missing
        if k
        not in [
            "final_logits_bias",
            "model.encoder.embed_tokens.weight",
            "model.decoder.embed_tokens.weight",
            "lm_head.weight",
        ]
    ]
    # Fail loudly if any unexpected torch keys were left unmatched
    assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
    # Fail loudly if any tf keys were left unmatched
    assert extra == [], f"no matches found for the following tf keys {extra}"
    # Return the fully loaded PyTorch model
    return torch_model
    return torch_model
# Read every weight from a TensorFlow checkpoint and return them as a dict of NumPy arrays
def get_tf_weights_as_numpy(path) -> Dict:
    # List all variables (name, shape) stored at the given checkpoint path
    init_vars = tf.train.list_variables(path)
    # Accumulate the TensorFlow weights here
    tf_weights = {}
    # Variable-name patterns to skip, such as the global step counter
    ignore_name = ["global_step"]
    # Iterate over the variables with a "converting tf checkpoint to dict" progress bar
    for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
        # Skip variables whose name contains any of the ignored patterns
        skip_key = any(pat in name for pat in ignore_name)
        if skip_key:
            continue
        # Load the variable's data from the checkpoint
        array = tf.train.load_variable(path, name)
        # Store the array under its variable name
        tf_weights[name] = array
    # Return the collected TensorFlow weights
    return tf_weights


# Convert a BigBird-Pegasus TensorFlow checkpoint into a saved PyTorch model
def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict):
    # Read the TensorFlow weights as a dict
    tf_weights = get_tf_weights_as_numpy(ckpt_path)
    # Convert them into a PyTorch model, applying the config overrides
    torch_model = convert_bigbird_pegasus(tf_weights, config_update)
    # Save the converted model to the target directory
    torch_model.save_pretrained(save_dir)


# Command-line entry point for the conversion
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Path to the TensorFlow checkpoint
    parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
    # Where to save the converted PyTorch model
    parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.")
    args = parser.parse_args()
    # No config overrides by default
    config_update = {}
    # Run the TensorFlow -> PyTorch conversion
    convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update)

.\models\bigbird_pegasus\modeling_bigbird_pegasus.py

# coding=utf-8
# Copyright 2021 Google Research The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch BigBirdPegasus model."""

import copy  # object copying
import math  # math utilities
from typing import List, Optional, Tuple, Union  # type hints

import numpy as np  # numerical computation
import torch  # PyTorch
from torch import nn  # neural-network building blocks
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss  # loss functions

from ...activations import ACT2FN  # activation-function lookup table
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask  # attention-mask helpers
from ...modeling_outputs import (  # model output classes
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel  # base class for pretrained models
from ...utils import (  # utility functions
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_bigbird_pegasus import BigBirdPegasusConfig  # configuration class

logger = logging.get_logger(__name__)  # module-level logger

_CHECKPOINT_FOR_DOC = "google/bigbird-pegasus-large-arxiv"  # checkpoint referenced in the docs
_CONFIG_FOR_DOC = "BigBirdPegasusConfig"  # config class referenced in the docs
_EXPECTED_OUTPUT_SHAPE = [1, 7, 1024]  # expected output shape in the doc examples

BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST = [  # available pretrained checkpoints
    "google/bigbird-pegasus-large-arxiv",
    "google/bigbird-pegasus-large-pubmed",
    "google/bigbird-pegasus-large-bigpatent",
    # See all BigBirdPegasus models at https://huggingface.co/models?filter=bigbird_pegasus
]
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)  # zero tensor with the same shape as input_ids
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()  # shift each row one position to the right
    shifted_input_ids[:, 0] = decoder_start_token_id  # place the decoder start token in the first column

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")  # pad_token_id must be defined
    # Replace any -100 label placeholders in shifted_input_ids with pad_token_id
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids  # the right-shifted input ids
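
# Worked example (annotation, not part of the original file):
#
#     input_ids = torch.tensor([[5, 6, 7, 1]])
#     shift_tokens_right(input_ids, pad_token_id=0, decoder_start_token_id=2)
#     # tensor([[2, 5, 6, 7]])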


class BigBirdPegasusLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings, embedding_dim)  # initialize the underlying nn.Embedding

    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        # Read the batch size and sequence length from `input_ids_shape`
        bsz, seq_len = input_ids_shape[:2]
        # Position ids run from past_key_values_length to past_key_values_length + seq_len
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        )
        # Delegate to the parent forward to look up the embeddings for these positions
        return super().forward(positions)

# Copied from transformers.models.big_bird.modeling_big_bird.BigBirdSelfAttention with BigBird->BigBirdPegasus
# Standard multi-head self-attention used by BigBirdPegasus (an nn.Module)
class BigBirdPegasusSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # The hidden size must be a multiple of the number of attention heads
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Linear projections for query, key, and value, mapping hidden_size to all_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)

        # Dropout on the attention probabilities, plus a flag marking decoder usage
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.is_decoder = config.is_decoder

    # Reshape a projected tensor into [batch, heads, seq_len, head_size] for attention scoring
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    # Forward pass implementing the self-attention mechanism
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):



# Copied from transformers.models.big_bird.modeling_big_bird.BigBirdBlockSparseAttention with BigBird->BigBirdPegasus
# BigBird's block-sparse attention for BigBirdPegasus (an nn.Module)
class BigBirdPegasusBlockSparseAttention(nn.Module):
    def __init__(self, config, seed=None):
        super().__init__()

        self.max_seqlen = config.max_position_embeddings
        self.seed = seed

        # The hidden size must be a multiple of the number of attention heads
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.num_random_blocks = config.num_random_blocks
        self.block_size = config.block_size

        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Linear projections for query, key, and value, mapping hidden_size to all_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)
    def transpose_for_scores(self, x):
        # Compute the target shape [batch, seq_len, num_heads, head_size]
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        # View into that shape, then permute to [batch, num_heads, seq_len, head_size]
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
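
    # Shape walk-through (annotation): with batch=2, seq_len=8, hidden=16 and 4
    # heads of size 4, a [2, 8, 16] input is viewed as [2, 8, 4, 4] and permuted
    # to [2, 4, 8, 4], i.e. [batch, heads, seq_len, head_size].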

    def forward(
        self,
        hidden_states,
        band_mask=None,
        from_mask=None,
        to_mask=None,
        from_blocked_mask=None,
        to_blocked_mask=None,
        output_attentions=None,
    ):
        # This class is encoder-only; it currently cannot be used in the decoder

        # Read the dimensions of the hidden states
        batch_size, seqlen, _ = hidden_states.size()
        to_seq_length = from_seq_length = seqlen
        from_block_size = to_block_size = self.block_size

        # The query-side sequence length must be a multiple of the block size
        if from_seq_length % from_block_size != 0:
            raise ValueError("Query sided sequence length must be multiple of block size")

        # The key/value-side sequence length must be a multiple of the block size
        if to_seq_length % to_block_size != 0:
            raise ValueError("Key/Value sided sequence length must be multiple of block size")

        # Project and reshape the queries, keys, and values for attention
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Run the BigBird block-sparse attention kernel
        context_layer, attention_probs = self.bigbird_block_sparse_attention(
            query_layer,
            key_layer,
            value_layer,
            band_mask,
            from_mask,
            to_mask,
            from_blocked_mask,
            to_blocked_mask,
            self.num_attention_heads,
            self.num_random_blocks,
            self.attention_head_size,
            from_block_size,
            to_block_size,
            batch_size,
            from_seq_length,
            to_seq_length,
            seed=self.seed,
            plan_from_length=None,
            plan_num_rand_blocks=None,
            output_attentions=output_attentions,
        )

        # Reshape the context layer back to [batch, seq_len, hidden]
        context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1)

        # Include the attention probabilities in the output when requested
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs

    @staticmethod
    def torch_bmm_nd(inp_1, inp_2, ndim=None):
        """快速的多维矩阵乘法"""
        # 使用torch.bmm更快地实现torch.einsum ("bhqk,bhkd->bhqd")的功能
        return torch.bmm(inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:])).view(
            inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 1])
        )

    @staticmethod
    def torch_bmm_nd_transpose(inp_1, inp_2, ndim=None):
        """带转置的快速多维矩阵乘法"""
        # 使用torch.bmm更快地实现torch.einsum ("bhqd,bhkd->bhqk")的功能
        return torch.bmm(
            inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:]).transpose(1, 2)
        ).view(inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 2]))
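
    # Equivalence check (annotation): torch_bmm_nd reproduces
    # torch.einsum("bhqk,bhkd->bhqd") on 4-D inputs when ndim=4:
    #
    #     a = torch.randn(2, 3, 5, 7)
    #     b = torch.randn(2, 3, 7, 4)
    #     out = BigBirdPegasusBlockSparseAttention.torch_bmm_nd(a, b, ndim=4)
    #     torch.allclose(out, torch.einsum("bhqk,bhkd->bhqd", a, b))  # True
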
    def bigbird_block_sparse_attention(
        self,
        query_layer,
        key_layer,
        value_layer,
        band_mask,
        from_mask,
        to_mask,
        from_blocked_mask,
        to_blocked_mask,
        n_heads,
        n_rand_blocks,
        attention_head_size,
        from_block_size,
        to_block_size,
        batch_size,
        from_seq_len,
        to_seq_len,
        seed,
        plan_from_length,
        plan_num_rand_blocks,
        output_attentions,
    ):
        # Implements BigBird's block-sparse attention mechanism.
        # Arguments:
        # - query_layer, key_layer, value_layer: query, key, and value tensors
        # - band_mask: band mask restricting attention to the sliding-window band
        # - from_mask, to_mask: source-side and target-side masks bounding valid attention
        # - from_blocked_mask, to_blocked_mask: blocked masks used for block-wise attention
        # - n_heads: number of attention heads
        # - n_rand_blocks: number of random blocks
        # - attention_head_size: size of each attention head
        # - from_block_size, to_block_size: block sizes on the source and target sides
        # - batch_size: batch size
        # - from_seq_len, to_seq_len: source and target sequence lengths
        # - seed: random seed
        # - plan_from_length: planned source length
        # - plan_num_rand_blocks: planned number of random blocks
        # - output_attentions: whether to return the attention probabilities

    @staticmethod
    def torch_gather_b2(params, indices):
        # This operation is equivalent to tf.gather with batch_dims=2

        if params.shape[:2] != indices.shape[:2]:
            raise ValueError(
                "Make sure that the first two dimensions of params and indices are identical, but"
                f" they are params: {params.shape[:2]} vs. indices: {indices.shape[:2]}"
            )
        num_indices_to_gather = indices.shape[-2] * indices.shape[-1]
        num_indices_to_pick_from = params.shape[2]

        shift = torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)
        indices_shift = torch.div(shift, num_indices_to_gather, rounding_mode="floor") * num_indices_to_pick_from

        flattened_indices = indices.view(-1) + indices_shift
        flattened_params = params.reshape(-1, params.shape[-2], params.shape[-1])

        out_flattened = flattened_params.index_select(0, flattened_indices)

        out = out_flattened.reshape(params.shape[:2] + (num_indices_to_gather,) + params.shape[3:])
        return out
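
    # Shape example (annotation): params [bsz, heads, num_blocks, block_size, head_dim]
    # and indices [bsz, heads, num_windows, n_rand_blocks] produce
    # [bsz, heads, num_windows * n_rand_blocks, block_size, head_dim]:
    #
    #     params = torch.randn(2, 2, 4, 3, 5)
    #     indices = torch.randint(0, 4, (2, 2, 3, 2))
    #     BigBirdPegasusBlockSparseAttention.torch_gather_b2(params, indices).shape
    #     # torch.Size([2, 2, 6, 3, 5])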

    @staticmethod
    def _create_rand_mask_from_inputs(
        from_blocked_mask,
        to_blocked_mask,
        rand_attn,
        num_attention_heads,
        num_rand_blocks,
        batch_size,
        from_seq_length,
        from_block_size,
    ):
        """
        Create 3D attention mask from a 2D tensor mask.

        Args:
            from_blocked_mask: 2D Tensor of shape [batch_size,
                from_seq_length//from_block_size, from_block_size].
                输入的来自的序列被块化后的掩码,形状为 [batch_size, from_seq_length//from_block_size, from_block_size]。
            to_blocked_mask: int32 Tensor of shape [batch_size,
                to_seq_length//to_block_size, to_block_size].
                输入的目标序列被块化后的掩码,形状为 [batch_size, to_seq_length//to_block_size, to_block_size]。
            rand_attn: [batch_size, num_attention_heads,
                from_seq_length//from_block_size-2, num_rand_blocks]
                随机注意力的掩码,形状为 [batch_size, num_attention_heads, from_seq_length//from_block_size-2, num_rand_blocks]。
            num_attention_heads: int. Number of attention heads.
                注意力头的数量。
            num_rand_blocks: int. Number of random chunks per row.
                每行的随机块数。
            batch_size: int. Batch size for computation.
                计算的批次大小。
            from_seq_length: int. length of from sequence.
                输入序列的长度。
            from_block_size: int. size of block in from sequence.
                输入序列中的块大小。

        Returns:
            float Tensor of shape [batch_size, num_attention_heads, from_seq_length//from_block_size-2,
                from_block_size, num_rand_blocks*to_block_size].
            返回形状为 [batch_size, num_attention_heads, from_seq_length//from_block_size-2,
                from_block_size, num_rand_blocks*to_block_size] 的浮点数张量。
        """
        num_windows = from_seq_length // from_block_size - 2
        # number of "window" rows, i.e. query blocks excluding the two global boundary blocks
        rand_mask = torch.stack([p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)])
        # gather the target-side block masks selected by the random attention indices
        rand_mask = rand_mask.view(batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size)
        # combine the source-side block mask with the random mask via einsum
        rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask)
        return rand_mask
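
A quick shape check of the three steps above, with made-up sizes and all-ones masks (standalone tensors rather than the module's real inputs):

```python
import torch

batch_size, num_heads, num_rand_blocks = 2, 4, 3
from_block_size = to_block_size = 8
from_seq_length = to_seq_length = 64
num_windows = from_seq_length // from_block_size - 2  # 6

from_blocked_mask = torch.ones(batch_size, from_seq_length // from_block_size, from_block_size)
to_blocked_mask = torch.ones(batch_size, to_seq_length // to_block_size, to_block_size)
rand_attn = torch.randint(1, to_seq_length // to_block_size - 1,
                          (batch_size, num_heads, num_windows, num_rand_blocks))

rand_mask = torch.stack([p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)])
rand_mask = rand_mask.view(batch_size, num_heads, num_windows, num_rand_blocks * from_block_size)
rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask)
print(rand_mask.shape)  # torch.Size([2, 4, 6, 8, 24])
```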

    @staticmethod
    def _get_rand_attn_plan(from_seq_length, from_block_size, num_rand_blocks):
        """
        Gives the plan of where to put random attention.

        Args:
            from_seq_length: int. length of from sequence.
                输入序列的长度。
            from_block_size: int. size of block in from sequence.
                输入序列中的块大小。
            num_rand_blocks: int. Number of random chunks per row.
                每行的随机块数。

        Returns:
            plan_from_length: ending location of from block plan_num_rand_blocks: number of random ending location for
                each block
            返回计划的输入序列块结束位置和每个块的随机结束位置的计划。
        """
        plan_from_length = []
        plan_num_rand_blocks = []
        if (2 * num_rand_blocks + 5) < (from_seq_length // from_block_size):
            plan_from_length.append(int((2 * num_rand_blocks + 5) * from_block_size))
            plan_num_rand_blocks.append(num_rand_blocks)
            plan_from_length.append(from_seq_length)
            plan_num_rand_blocks.append(0)
        elif (num_rand_blocks + 5) < (from_seq_length // from_block_size):
            plan_from_length.append(int((num_rand_blocks + 5) * from_block_size))
            plan_num_rand_blocks.append(num_rand_blocks // 2)
            plan_from_length.append(from_seq_length)
            plan_num_rand_blocks.append(num_rand_blocks - (num_rand_blocks // 2))
        else:
            plan_from_length.append(from_seq_length)
            plan_num_rand_blocks.append(num_rand_blocks)

        return plan_from_length, plan_num_rand_blocks
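
For intuition, here is a call with typical BigBird sizes; this assumes the static method is reachable on `BigBirdPegasusBlockSparseAttention` (the class the encoder attention below instantiates), imported from the standard transformers module path:

```python
from transformers.models.bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusBlockSparseAttention

# 4096 tokens with block size 64 give 64 blocks; since 2*3 + 5 = 11 < 64, the first branch fires:
# 3 random blocks are planned for rows up to position 704, none beyond that
plan_from_length, plan_num_rand_blocks = BigBirdPegasusBlockSparseAttention._get_rand_attn_plan(
    from_seq_length=4096, from_block_size=64, num_rand_blocks=3
)
print(plan_from_length, plan_num_rand_blocks)  # [704, 4096] [3, 0]
```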

    def _bigbird_block_rand_mask(
        self, from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1
    ):
        """
        Create adjacency list of random attention.

        Args:
            from_seq_length: int. length of from sequence.
            to_seq_length: int. length of to sequence.
            from_block_size: int. size of block in from sequence.
            to_block_size: int. size of block in to sequence.
            num_rand_blocks: int. Number of random chunks per row.
            last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence,
            if positive then num_rand_blocks blocks chosen only up to last_idx.

        Returns:
            adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks
        """
        # using this method when from_seq_length in [1024, 3072, 4096]

        # The number of blocks of from_seq_length and to_seq_length must match, otherwise raise an error
        if from_seq_length // from_block_size != to_seq_length // to_block_size:
            raise ValueError("Error the number of blocks needs to be same!")

        # Create an all-zero array representing the adjacency list of random attention
        rand_attn = np.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32)

        # During inference (not training), return the all-zero random attention adjacency list directly
        if not self.training:
            return rand_attn

        # Create the middle sequence used to sample random block indices
        middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32)
        last = to_seq_length // to_block_size - 1

        # Adjust the index of the last eligible block according to last_idx
        if last_idx > (2 * to_block_size):
            last = (last_idx // to_block_size) - 1

        r = num_rand_blocks  # shorthand for num_rand_blocks

        # Build the random attention adjacency list row by row
        for i in range(1, from_seq_length // from_block_size - 1):
            start = i - 2
            end = i

            if i == 1:
                # First row: sample block indices from the middle sequence, skipping its first two entries
                rand_attn[i - 1, :] = np.random.permutation(middle_seq[2:last])[:r]
            elif i == 2:
                # Second row: sample block indices from the middle sequence, skipping its first three entries
                rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r]
            elif i == from_seq_length // from_block_size - 3:
                # Third-to-last row: sample block indices from the whole middle sequence
                rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
            elif i == from_seq_length // from_block_size - 2:
                # Second-to-last row: sample block indices from the whole middle sequence
                rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
            else:
                if start > last:
                    start = last
                    # If the window start exceeds the last eligible block, sample from the first `start` blocks
                    rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
                elif (end + 1) == last:
                    # If the block right after the window end is the last eligible block, sample from the first `start` blocks
                    rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
                else:
                    # Otherwise, sample from the middle sequence excluding the window [start, end]
                    rand_attn[i - 1, :] = np.random.permutation(
                        np.concatenate((middle_seq[:start], middle_seq[end + 1 : last]))
                    )[:r]

        # Return the generated random attention adjacency list
        return rand_attn
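
The loop is easier to follow on a toy configuration. Below is a simplified standalone re-creation; it skips the `self.training` early exit and the `start > last` / `end + 1 == last` edge branches, which only matter when `last_idx` shrinks the eligible range:

```python
import numpy as np

num_blocks, num_rand_blocks = 8, 2  # e.g. 512 tokens with block size 64
rand_attn = np.zeros((num_blocks - 2, num_rand_blocks), dtype=np.int32)
middle_seq = np.arange(1, num_blocks - 1, dtype=np.int32)  # blocks 1..6; 0 and 7 are global
last = num_blocks - 1

np.random.seed(0)
for i in range(1, num_blocks - 1):
    start, end = i - 2, i
    if i == 1:
        rand_attn[i - 1] = np.random.permutation(middle_seq[2:last])[:num_rand_blocks]
    elif i == 2:
        rand_attn[i - 1] = np.random.permutation(middle_seq[3:last])[:num_rand_blocks]
    elif i >= num_blocks - 3:
        rand_attn[i - 1] = np.random.permutation(middle_seq[:last])[:num_rand_blocks]
    else:
        # exclude the sliding window [start, end] around the current block
        pool = np.concatenate((middle_seq[:start], middle_seq[end + 1 : last]))
        rand_attn[i - 1] = np.random.permutation(pool)[:num_rand_blocks]

print(rand_attn)  # one row of random block ids per non-global query block
```
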
    def _bigbird_block_rand_mask_with_head(
        self,
        from_seq_length,
        to_seq_length,
        from_block_size,
        to_block_size,
        num_heads,
        plan_from_length,
        plan_num_rand_blocks,
        window_block_left=1,
        window_block_right=1,
        global_block_top=1,
        global_block_bottom=1,
        global_block_left=1,
        global_block_right=1,
    ):
        """
        Generates a random mask for BigBird attention with head.

        Args:
            from_seq_length: int. Length of the source sequence.
            to_seq_length: int. Length of the target sequence.
            from_block_size: int. Block size of the source sequence.
            to_block_size: int. Block size of the target sequence.
            num_heads: int. Number of attention heads.
            plan_from_length: list. Planned segment end positions within the source sequence.
            plan_num_rand_blocks: list. Planned number of random blocks for each segment.
            window_block_left: int. Number of blocks of window to the left of a block.
            window_block_right: int. Number of blocks of window to the right of a block.
            global_block_top: int. Number of blocks globally used at the top.
            global_block_bottom: int. Number of blocks globally used at the bottom.
            global_block_left: int. Number of blocks globally used to the left.
            global_block_right: int. Number of blocks globally used to the right.

        Returns:
            Random mask with head for BigBird attention.
        """
        # Implementation of random mask generation for BigBird attention
        pass


    @staticmethod
    def _get_single_block_row_attention(
        block_id,
        to_start_block_id,
        to_end_block_id,
        num_rand_blocks,
        window_block_left=1,
        window_block_right=1,
        global_block_left=1,
        global_block_right=1,
    ):
        """
        For a single row block, get random row attention.

        Args:
            block_id: int. Block ID of the row.
            to_start_block_id: int. Start ID of the target blocks for random attention.
            to_end_block_id: int. End ID of the target blocks for random attention.
            num_rand_blocks: int. Number of random blocks to be selected.
            window_block_left: int. Number of blocks of window to the left of a block.
            window_block_right: int. Number of blocks of window to the right of a block.
            global_block_left: int. Number of blocks globally used to the left.
            global_block_right: int. Number of blocks globally used to the right.

        Returns:
            Array containing the selected random attention vector of size num_rand_blocks.
        """
        # List of to_blocks from which to choose random attention
        to_block_list = np.arange(to_start_block_id, to_end_block_id, dtype=np.int32)
        # Permute the blocks
        perm_block = np.random.permutation(to_block_list)

        # Illegal blocks for the current block id, using window
        illegal_blocks = list(range(block_id - window_block_left, block_id + window_block_right + 1))

        # Add blocks at the start and at the end
        illegal_blocks.extend(list(range(global_block_left)))
        illegal_blocks.extend(list(range(to_end_block_id - global_block_right, to_end_block_id)))

        # The second from_block cannot choose random attention on second last to_block
        if block_id == 1:
            illegal_blocks.append(to_end_block_id - 2)

        # The second last from_block cannot choose random attention on second to_block
        if block_id == to_end_block_id - 2:
            illegal_blocks.append(1)

        selected_random_blocks = []

        for i in range(to_end_block_id - to_start_block_id):
            if perm_block[i] not in illegal_blocks:
                selected_random_blocks.append(perm_block[i])
            if len(selected_random_blocks) == num_rand_blocks:
                break
        return np.array(selected_random_blocks, dtype=np.int32)
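
Since the method is static, it can be invoked directly on the class for a quick check (assuming `BigBirdPegasusBlockSparseAttention` from the transformers modeling module is where it lives; the output depends on NumPy's RNG state):

```python
import numpy as np
from transformers.models.bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusBlockSparseAttention

np.random.seed(42)
# query block 5 over target blocks 1..14: the window {4, 5, 6} and the right global block {14} are illegal
blocks = BigBirdPegasusBlockSparseAttention._get_single_block_row_attention(
    block_id=5, to_start_block_id=1, to_end_block_id=15, num_rand_blocks=3
)
print(blocks)  # 3 block ids drawn from {1, 2, 3, 7, ..., 13}
```
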
# BigBirdPegasusEncoderAttention: attention wrapper used by the encoder, an nn.Module
class BigBirdPegasusEncoderAttention(nn.Module):
    # Constructor takes the model config and an optional seed
    def __init__(self, config, seed=None):
        super().__init__()
        # Store the config and seed on the instance
        self.config = config
        self.seed = seed

        # Read the attention type from the config
        self.attention_type = config.attention_type

        # Pick the attention module matching the configured attention type
        if self.attention_type == "original_full":
            self.self = BigBirdPegasusSelfAttention(config)
        elif self.attention_type == "block_sparse":
            self.self = BigBirdPegasusBlockSparseAttention(config, seed)
        else:
            # Raise a ValueError for any unexpected attention type
            raise ValueError(
                f"attention_type can either be original_full or block_sparse, but is {self.config.attention_type}"
            )

        # Output projection mapping hidden states back to the original dimension
        self.output = nn.Linear(config.hidden_size, config.hidden_size, bias=config.use_bias)

    # Switch the attention type; `value` must be one of the allowed strings
    def set_attention_type(self, value: str):
        # Reject any value outside the allowed list
        if value not in ["original_full", "block_sparse"]:
            raise ValueError(
                f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}"
            )

        # Nothing to do if the attention type is already set to `value`
        if value == self.attention_type:
            return

        # Record the new attention type
        self.attention_type = value

        # Rebuild self.self for the new attention type
        if value == "original_full":
            # copy all weights to the new full attention class
            attn_weights = BigBirdPegasusSelfAttention(self.config)
        else:
            # copy all weights to the new sparse attention class
            attn_weights = BigBirdPegasusBlockSparseAttention(self.config, self.seed)

        # Carry the query, value and key projections over to the new attention module
        attn_weights.query = self.self.query
        attn_weights.value = self.self.value
        attn_weights.key = self.self.key

        # Swap in the new attention module
        self.self = attn_weights

        # Keep attention_type in sync
        self.attention_type = value

        # Put the new module in eval mode when not training
        if not self.training:
            self.self.eval()

    # Forward pass; dispatches to the full or sparse attention module
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        past_key_value=None,
        output_attentions=False,
        band_mask=None,
        from_mask=None,
        to_mask=None,
        from_blocked_mask=None,
        to_blocked_mask=None,
    ):
        # Expand head_mask so it can be multiplied inside the self-attention module
        head_mask = head_mask.reshape(1, -1, 1, 1) if head_mask is not None else None

        # Dispatch based on the configured attention type
        if self.config.attention_type == "original_full":
            # Full attention path
            self_outputs = self.self(
                hidden_states,
                attention_mask,
                head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
            )
        else:
            # Block-sparse attention path
            self_outputs = self.self(
                hidden_states, band_mask, from_mask, to_mask, from_blocked_mask, to_blocked_mask, output_attentions
            )

        # Project the attention output back to the model dimension
        attention_output = self.output(self_outputs[0])

        # Append the attention matrices to the output tuple if requested
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
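
A minimal sketch of swapping the attention type at runtime, assuming both classes are importable from the transformers modeling module (default config values; this demonstrates only the module swap, not a forward pass):

```python
from transformers import BigBirdPegasusConfig
from transformers.models.bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusEncoderAttention

config = BigBirdPegasusConfig()                # attention_type defaults to "block_sparse"
attn = BigBirdPegasusEncoderAttention(config, seed=0)
print(type(attn.self).__name__)                # BigBirdPegasusBlockSparseAttention

attn.set_attention_type("original_full")       # query/key/value projections are carried over
print(type(attn.self).__name__)                # BigBirdPegasusSelfAttention
```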

# Copied from transformers.models.bart.modeling_bart.BartAttention with BartConfig->BigBirdPegasusConfig, Bart->BigBirdPegasusDecoder
class BigBirdPegasusDecoderAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[BigBirdPegasusConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim  # embedding dimension of the attention module
        self.num_heads = num_heads  # number of attention heads
        self.dropout = dropout  # dropout probability
        self.head_dim = embed_dim // num_heads  # dimension of each attention head
        self.config = config  # model configuration

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5  # scaling factor applied to the queries
        self.is_decoder = is_decoder  # whether the module is used as a decoder
        self.is_causal = is_causal  # whether the attention is causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # key projection
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # value projection
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # query projection
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # output projection

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape to (batch_size, num_heads, seq_len, head_dim) and make the tensor contiguous
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,  # input hidden states
        key_value_states: Optional[torch.Tensor] = None,  # key/value states for cross-attention (optional)
        past_key_value: Optional[Tuple[torch.Tensor]] = None,  # cached key/value states (optional)
        attention_mask: Optional[torch.Tensor] = None,  # attention mask (optional)
        layer_head_mask: Optional[torch.Tensor] = None,  # per-layer head mask (optional)
        output_attentions: bool = False,  # whether to return attention weights
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states  # keep the input hidden_states as the residual base

        hidden_states = self.self_attn_layer_norm(hidden_states)  # layer-normalize the input

        # Run the (full or block-sparse) self-attention module on the normalized hidden states
        self_attention_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            head_mask=layer_head_mask,
            output_attentions=output_attentions,
            band_mask=band_mask,
            from_mask=from_mask,
            to_mask=to_mask,
            from_blocked_mask=from_blocked_mask,
            to_blocked_mask=to_blocked_mask,
        )
        hidden_states = self_attention_outputs[0]  # take the attention output as the new hidden states

        # Apply dropout to the hidden states
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # Residual connection: add the input back onto the processed hidden states
        hidden_states = residual + hidden_states

        residual = hidden_states  # new residual base for the feed-forward block

        hidden_states = self.final_layer_norm(hidden_states)  # final layer norm
        hidden_states = self.activation_fn(self.fc1(hidden_states))  # first feed-forward layer plus activation

        hidden_states = self.fc2(hidden_states)  # second feed-forward layer
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)  # dropout

        # Residual connection: add the residual back onto the processed hidden states
        hidden_states = residual + hidden_states

        # Clamp inf/NaN values when running in float16 to keep the activations finite
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)  # pack the final hidden states into the output tuple

        if output_attentions:
            outputs += (self_attention_outputs[1],)  # include the attention weights when requested

        return outputs  # final tuple: hidden states plus optional attentions

    def set_attention_type(self, value: str):
        if value not in ["original_full", "block_sparse"]:  # validate the requested attention type
            raise ValueError(
                f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}"
            )
        # Nothing to do if the attention type is already set correctly
        if value == self.attention_type:
            return
        self.attention_type = value  # record the new attention type
        self.self_attn.set_attention_type(value)  # propagate it to the self-attention module
# BigBirdPegasusDecoderLayer: a single decoder layer
class BigBirdPegasusDecoderLayer(nn.Module):

    # Constructor takes a BigBirdPegasusConfig
    def __init__(self, config: BigBirdPegasusConfig):
        super().__init__()

        # Embedding dimension of the layer
        self.embed_dim = config.d_model

        # Self-attention module of the decoder layer
        self.self_attn = BigBirdPegasusDecoderAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            bias=config.use_bias,
        )

        # Dropout probability
        self.dropout = config.dropout

        # Activation function looked up by name from the config
        self.activation_fn = ACT2FN[config.activation_function]

        # Dropout applied after the activation
        self.activation_dropout = config.activation_dropout

        # Layer norm applied before self-attention
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # Cross-attention over the encoder outputs
        self.encoder_attn = BigBirdPegasusDecoderAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            bias=config.use_bias,
        )

        # Layer norm applied before cross-attention
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # First feed-forward layer
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)

        # Second feed-forward layer
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)

        # Final layer norm of the decoder layer
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    # Forward pass of the decoder layer
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ):
        # body omitted in this walkthrough
        pass


# BigBirdPegasusClassificationHead: classification head placed on top of the model
class BigBirdPegasusClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    # Constructor takes input_dim, inner_dim, num_classes and pooler_dropout
    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        num_classes: int,
        pooler_dropout: float,
    ):
        super().__init__()

        # Dense projection layer
        self.dense = nn.Linear(input_dim, inner_dim)

        # Dropout applied before and after the dense layer
        self.dropout = nn.Dropout(p=pooler_dropout)

        # Final linear projection onto the class logits
        self.out_proj = nn.Linear(inner_dim, num_classes)

    # Forward pass of the classification head
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Apply dropout to the input hidden states
        hidden_states = self.dropout(hidden_states)

        # Dense projection
        hidden_states = self.dense(hidden_states)

        # tanh activation
        hidden_states = torch.tanh(hidden_states)

        # Dropout again
        hidden_states = self.dropout(hidden_states)

        # Final projection onto the logits
        hidden_states = self.out_proj(hidden_states)

        # Return the logits
        return hidden_states
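
For reference, the head maps one pooled representation per sequence to class logits; a small sketch with made-up dimensions, using the class defined above:

```python
import torch

head = BigBirdPegasusClassificationHead(input_dim=1024, inner_dim=1024, num_classes=3, pooler_dropout=0.1)
sentence_repr = torch.randn(2, 1024)  # e.g. the hidden state at the eos token, one per sequence
logits = head(sentence_repr)
print(logits.shape)                   # torch.Size([2, 3])
```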


# BigBirdPegasusPreTrainedModel: base class handling weight init and checkpoint loading
class BigBirdPegasusPreTrainedModel(PreTrainedModel):

    # Configuration class for this model family
    config_class = BigBirdPegasusConfig

    # Prefix of the base model inside checkpoints
    base_model_prefix = "model"

    # Gradient checkpointing is supported
    supports_gradient_checkpointing = True

    # Modules that must not be split across devices
    _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]

    # Keys whose device placement is skipped
    _skip_keys_device_placement = "past_key_values"

    # Initialize module weights with the configured standard deviation, depending on the module type
    def _init_weights(self, module):
        std = self.config.init_std
        # Linear layers
        if isinstance(module, nn.Linear):
            # Initialize the weights from a normal distribution
            module.weight.data.normal_(mean=0.0, std=std)
            # Zero out the bias if present
            if module.bias is not None:
                module.bias.data.zero_()
        # Embedding layers
        elif isinstance(module, nn.Embedding):
            # Initialize the weights from a normal distribution
            module.weight.data.normal_(mean=0.0, std=std)
            # Zero out the padding embedding if a padding index is set
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    # Dummy inputs used for model tests
    def dummy_inputs(self):
        # ID of the padding token
        pad_token = self.config.pad_token_id
        # Input IDs for two example sentences
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        # Dummy input dict with the attention mask and input IDs
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),  # mask out the padding positions
            "input_ids": input_ids,  # the actual input ID sequences
        }
        # Return the dummy inputs
        return dummy_inputs


BIGBIRD_PEGASUS_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`BigBirdPegasusConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

BIGBIRD_PEGASUS_GENERATION_EXAMPLE = r"""
    Summarization example:

    ```
    >>> from transformers import AutoTokenizer, BigBirdPegasusForConditionalGeneration

    >>> model = BigBirdPegasusForConditionalGeneration.from_pretrained("google/bigbird-pegasus-large-arxiv")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv")

    >>> ARTICLE_TO_SUMMARIZE = (
    ...     "The dominant sequence transduction models are based on complex recurrent or convolutional neural "
    ...     "networks in an encoder-decoder configuration. The best performing models also connect the encoder "
    ...     "and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, "
    ...     "based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. "
    ...     "Experiments on two machine translation tasks show these models to be superior in quality "
    ...     "while being more parallelizable and requiring significantly less time to train."
    ... )
    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=4096, return_tensors="pt", truncation=True)

    >>> # Generate Summary
    >>> summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=15)
    >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    'dominant sequence models are based on recurrent or convolutional neural networks .'
    ```
"""

BIGBIRD_PEGASUS_INPUTS_DOCSTRING = r"""
    Placeholder for documenting inputs for BigBirdPegasus models.
"""

BIGBIRD_PEGASUS_STANDALONE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of the input sequence tokens in the vocabulary. Padding tokens are ignored by default.

            Indices can be obtained using `ProphetNetTokenizer`. See `PreTrainedTokenizer.encode` and
            `PreTrainedTokenizer.__call__` for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values are selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`BigBirdPegasusEncoderLayer`].

    Args:
        config: BigBirdPegasusConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.attention_type = config.attention_type  # attention type from the config
        self.block_size = config.block_size  # block size from the config

        self.dropout = config.dropout  # dropout probability
        self.layerdrop = config.encoder_layerdrop  # per-layer drop probability (LayerDrop)

        embed_dim = config.d_model  # embedding dimension
        self.padding_idx = config.pad_token_id  # padding token index
        self.max_source_positions = config.max_position_embeddings  # maximum number of source positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0  # embedding scale factor

        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)  # token embedding layer

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight  # reuse externally provided embedding weights

        self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )  # learned positional embeddings

        self.layers = nn.ModuleList([BigBirdPegasusEncoderLayer(config, seed=i) for i in range(config.encoder_layers)])
        # stack of encoder layers, one seed per layer

        self.layernorm_embedding = nn.LayerNorm(embed_dim)  # layer norm applied to the embeddings

        self.gradient_checkpointing = False  # gradient checkpointing disabled by default
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        pass  # placeholder; the actual forward pass is omitted in this walkthrough

    def set_attention_type(self, value: str):
        if value not in ["original_full", "block_sparse"]:  # validate the requested attention type
            raise ValueError(
                f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}"
            )
        # Nothing to do if the attention type is already set correctly
        if value == self.attention_type:
            return
        self.attention_type = value  # record the new attention type
        for layer in self.layers:
            layer.set_attention_type(value)  # propagate the change to every encoder layer

    @staticmethod
    # Copied from transformers.models.big_bird.modeling_big_bird.BigBirdModel.create_masks_for_block_sparse_attn
    def create_masks_for_block_sparse_attn(attention_mask: torch.Tensor, block_size: int):
        batch_size, seq_length = attention_mask.size()
        # The sequence length must be a multiple of the block size
        if seq_length % block_size != 0:
            raise ValueError(
                f"Sequence length must be multiple of block size, but sequence length is {seq_length}, while block"
                f" size is {block_size}."
            )

        def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask):
            """
            Create 3D attention mask from a 2D tensor mask.

            Args:
                from_blocked_mask: 2D Tensor of shape [batch_size, from_seq_length//from_block_size, from_block_size].
                to_blocked_mask: int32 Tensor of shape [batch_size, to_seq_length//to_block_size, to_block_size].

            Returns:
                float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4, from_block_size,
                3*to_block_size].
            """
            # Concatenate the shifted target-side blocks that make up the sliding window
            exp_blocked_to_pad = torch.cat(
                [to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:, 3:-1]], dim=2
            )
            # Build the band mask with Einstein summation
            band_mask = torch.einsum("blq,blk->blqk", from_blocked_mask[:, 2:-2], exp_blocked_to_pad)
            band_mask.unsqueeze_(1)
            return band_mask

        # Reshape the attention mask into blocks of block_size
        blocked_encoder_mask = attention_mask.view(batch_size, seq_length // block_size, block_size)
        # Build the band mask
        band_mask = create_band_mask_from_inputs(blocked_encoder_mask, blocked_encoder_mask)

        # Build the from-mask and to-mask
        from_mask = attention_mask.view(batch_size, 1, seq_length, 1)
        to_mask = attention_mask.view(batch_size, 1, 1, seq_length)

        return blocked_encoder_mask, band_mask, from_mask, to_mask
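
The resulting shapes for a toy input, calling the static method on the encoder class (512 tokens with block size 64 give 8 blocks, of which 8 - 4 = 4 are band rows):

```python
import torch
from transformers.models.bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusEncoder

attention_mask = torch.ones(2, 512)
blocked, band, from_mask, to_mask = BigBirdPegasusEncoder.create_masks_for_block_sparse_attn(attention_mask, 64)
print(blocked.shape)    # torch.Size([2, 8, 64])
print(band.shape)       # torch.Size([2, 1, 4, 64, 192]), 3 * to_block_size = 192
print(from_mask.shape)  # torch.Size([2, 1, 512, 1])
print(to_mask.shape)    # torch.Size([2, 1, 1, 512])
```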

    def _pad_to_block_size(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
        """A helper function to pad tokens and mask to work with implementation of BigBird block-sparse attention."""
        # padding
        block_size = self.config.block_size
        batch_size, seq_len = hidden_states.shape[:2]

        padding_len = (block_size - seq_len % block_size) % block_size
        if padding_len > 0:
            # Warn once and pad the input IDs and embeddings up to a multiple of the block size
            logger.warning_once(
                f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of "
                f"`config.block_size`: {block_size}"
            )
            pad_id = self.config.pad_token_id
            device = hidden_states.device
            input_ids_padding = torch.ones((batch_size, padding_len), dtype=torch.long, device=device) * pad_id
            inputs_embeds_padding = self.embed_tokens(input_ids_padding)
            hidden_states = torch.cat([hidden_states, inputs_embeds_padding], dim=-2)

            # Pad the attention mask with zeros so the padded positions receive no attention
            attention_mask = nn.functional.pad(
                attention_mask, (0, padding_len), value=0
            )  # no attention on the padding tokens

        return padding_len, hidden_states, attention_mask
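
The padding amount is plain modular arithmetic; a quick check of the formula with illustrative lengths:

```python
block_size = 64
for seq_len in (64, 100, 128, 130):
    padding_len = (block_size - seq_len % block_size) % block_size
    print(seq_len, "->", padding_len)  # 64 -> 0, 100 -> 28, 128 -> 0, 130 -> 62
```
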
class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BigBirdPegasusDecoderLayer`]

    Args:
        config: BigBirdPegasusConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout  # dropout probability
        self.layerdrop = config.decoder_layerdrop  # per-layer drop probability (LayerDrop)
        self.padding_idx = config.pad_token_id  # padding token index
        self.max_target_positions = config.max_position_embeddings  # maximum number of target positions
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0  # embedding scale factor

        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)  # token embedding layer

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight  # reuse externally provided embedding weights

        self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )  # learned positional embeddings

        self.layers = nn.ModuleList([BigBirdPegasusDecoderLayer(config) for _ in range(config.decoder_layers)])  # stack of decoder layers
        self.layernorm_embedding = nn.LayerNorm(config.d_model)  # layer norm applied to the embeddings

        self.gradient_checkpointing = False  # gradient checkpointing disabled by default

        # Initialize weights and apply final processing
        self.post_init()



@add_start_docstrings(
    "The bare BigBirdPegasus Model outputting raw hidden-states without any specific head on top.",
    BIGBIRD_PEGASUS_START_DOCSTRING,
)
class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: BigBirdPegasusConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)  # shared token embedding

        self.encoder = BigBirdPegasusEncoder(config, self.shared)  # encoder using the shared embedding
        self.decoder = BigBirdPegasusDecoder(config, self.shared)  # decoder using the shared embedding

        # Initialize weights and apply final processing
        self.post_init()

    # Return the shared input embedding
    def get_input_embeddings(self):
        return self.shared

    # Replace the shared input embedding and update the encoder and decoder embeddings
    def set_input_embeddings(self, value):
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    # Tie the encoder and decoder embedding weights when the config requests shared word embeddings
    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    # Return the encoder
    def get_encoder(self):
        return self.encoder

    # Return the decoder
    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    # Copied from transformers.models.bart.modeling_bart.BartModel.forward with Bart->BigBirdPegasus
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # body omitted in this walkthrough
        pass


# Add a class docstring describing BigBirdPegasusForConditionalGeneration and its summarization use case
@add_start_docstrings(
    "The BigBirdPegasus Model with a language modeling head. Can be used for summarization.",
    BIGBIRD_PEGASUS_START_DOCSTRING,
)
# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
    # Prefix of the base model inside checkpoints
    base_model_prefix = "model"
    # Weight keys that are tied together; missing tied keys do not raise a warning on load
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
    # Keys that may be missing from a checkpoint without raising an error
    _keys_to_ignore_on_load_missing = ["final_logits_bias"]

    # Constructor takes a BigBirdPegasusConfig
    def __init__(self, config: BigBirdPegasusConfig):
        # Call the parent constructor with the config
        super().__init__(config)
        # The underlying BigBirdPegasusModel
        self.model = BigBirdPegasusModel(config)
        # Buffer for the final logits bias, initialized to zeros of shape (1, num_embeddings)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        # Language modeling head: a bias-free linear layer from d_model to the vocabulary size
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    # Return the encoder of the underlying model
    def get_encoder(self):
        return self.model.get_encoder()

    # Return the decoder of the underlying model
    def get_decoder(self):
        return self.model.get_decoder()

    # Resize the token embeddings and return the new embedding layer
    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
        # Delegate to the parent implementation
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # Resize final_logits_bias to match the new number of tokens
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    # Resize final_logits_bias in place; returns nothing
    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        # Old number of tokens
        old_num_tokens = self.final_logits_bias.shape[-1]
        # Truncate the bias if the vocabulary shrank
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        # Otherwise append zero bias entries for the new tokens
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        # Register the resized buffer
        self.register_buffer("final_logits_bias", new_bias)

    # Return the output embedding layer (lm_head)
    def get_output_embeddings(self):
        return self.lm_head

    # Replace the output embedding layer
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # Attach the inputs docstring to forward
    @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
    # Document the return type as Seq2SeqLMOutput with the model's config class
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    # Append the generation example to the docstring
    @add_end_docstrings(BIGBIRD_PEGASUS_GENERATION_EXAMPLE)
    # Forward pass used for both training and inference
    def forward(
        self,
        input_ids: torch.LongTensor = None,  # input token IDs
        attention_mask: Optional[torch.Tensor] = None,  # attention mask
        decoder_input_ids: Optional[torch.LongTensor] = None,  # decoder input token IDs
        decoder_attention_mask: Optional[torch.LongTensor] = None,  # decoder attention mask
        head_mask: Optional[torch.Tensor] = None,  # encoder head mask
        decoder_head_mask: Optional[torch.Tensor] = None,  # decoder head mask
        cross_attn_head_mask: Optional[torch.Tensor] = None,  # cross-attention head mask
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,  # precomputed encoder outputs
        past_key_values: Optional[List[torch.FloatTensor]] = None,  # cached key/value states
        inputs_embeds: Optional[torch.FloatTensor] = None,  # input embeddings
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,  # decoder input embeddings
        labels: Optional[torch.LongTensor] = None,  # labels for the language modeling loss
        use_cache: Optional[bool] = None,  # whether to use the cache
        output_attentions: Optional[bool] = None,  # whether to return attention weights
        output_hidden_states: Optional[bool] = None,  # whether to return hidden states
        return_dict: Optional[bool] = None,  # whether to return a ModelOutput instead of a tuple
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
            A tuple or a [`Seq2SeqLMOutput`], depending on `return_dict`.

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # If labels are provided, adjust `use_cache` behavior and prepare `decoder_input_ids`
        if labels is not None:
            if use_cache:
                # Warn if `use_cache` is `True` because `labels` are provided; set `use_cache` to `False`
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            # If `decoder_input_ids` are not provided, shift `labels` for decoder inputs
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        # Pass the input arguments to the model for processing
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Generate logits for language modeling head and adjust with final bias
        lm_logits = self.lm_head(outputs[0])
        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

        masked_lm_loss = None
        # Compute masked language modeling loss if labels are provided
        if labels is not None:
            labels = labels.to(lm_logits.device)  # Ensure labels are on the same device as logits
            loss_fct = CrossEntropyLoss()  # Define the loss function
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        # If `return_dict` is `False`, return outputs as a tuple
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        # If `return_dict` is `True`, return structured `Seq2SeqLMOutput`
        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # If past key values are used, trim decoder_input_ids according to the cached length
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to the old behavior: keep only the final ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        # Return the prepared inputs for generation
        return {
            "input_ids": None,  # encoder_outputs is defined, so input_ids are not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        # Shift the labels one position to the right to build the decoder inputs
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
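
What the shift does, on toy labels (`shift_tokens_right` is the module-level helper this file shares with BART; the token ids below are made up):

```python
import torch
from transformers.models.bigbird_pegasus.modeling_bigbird_pegasus import shift_tokens_right

labels = torch.tensor([[5, 6, 7, 1]])
decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2)
print(decoder_input_ids)  # tensor([[2, 5, 6, 7]]) -- start token prepended, last label dropped
```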

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            # Cached cross-attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        # Return the reordered past key values
        return reordered_past
@add_start_docstrings(
    """
    BigBirdPegasus model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g.
    for GLUE tasks.
    """,
    BIGBIRD_PEGASUS_START_DOCSTRING,
)
class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: BigBirdPegasusConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = BigBirdPegasusModel(config)
        self.classification_head = BigBirdPegasusClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classifier_dropout,
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass for the BigBirdPegasusForSequenceClassification model.

        Args:
            input_ids (torch.LongTensor, optional): Input token IDs. Defaults to None.
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            decoder_input_ids (torch.LongTensor, optional): Decoder input token IDs. Defaults to None.
            decoder_attention_mask (torch.LongTensor, optional): Decoder attention mask. Defaults to None.
            head_mask (torch.Tensor, optional): Head mask. Defaults to None.
            decoder_head_mask (torch.Tensor, optional): Decoder head mask. Defaults to None.
            cross_attn_head_mask (torch.Tensor, optional): Cross-attention head mask. Defaults to None.
            encoder_outputs (List[torch.FloatTensor], optional): Encoder outputs. Defaults to None.
            inputs_embeds (torch.FloatTensor, optional): Embedded inputs. Defaults to None.
            decoder_inputs_embeds (torch.FloatTensor, optional): Embedded decoder inputs. Defaults to None.
            labels (torch.LongTensor, optional): Labels for classification. Defaults to None.
            use_cache (bool, optional): Whether to use cache. Defaults to None.
            output_attentions (bool, optional): Whether to output attentions. Defaults to None.
            output_hidden_states (bool, optional): Whether to output hidden states. Defaults to None.
            return_dict (bool, optional): Whether to return a dictionary. Defaults to None.

        Returns:
            Seq2SeqSequenceClassifierOutput or dict: Sequence classification output.
        """
        # Actual implementation omitted in this walkthrough.
        pass


@add_start_docstrings(
    """
    BigBirdPegasus Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    BIGBIRD_PEGASUS_START_DOCSTRING,
)
class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config):
        super().__init__(config)

        # Set number of output labels to 2 for question answering (start and end positions)
        config.num_labels = 2
        self.num_labels = config.num_labels

        self.model = BigBirdPegasusModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # The decorators above attach a code-sample docstring with the relevant checkpoint, output type and config class

    # Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward

    def forward(
        self,
        input_ids: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Forward pass for question answering; the arguments are:
        # input_ids: token IDs of the input sequence
        # attention_mask: mask marking which positions are padding
        # decoder_input_ids: token IDs fed to the decoder
        # decoder_attention_mask: attention mask for the decoder
        # head_mask: mask over the encoder attention heads
        # decoder_head_mask: mask over the decoder attention heads
        # cross_attn_head_mask: mask over the cross-attention heads
        # encoder_outputs: list of precomputed encoder outputs
        # start_positions: token positions where the answer span starts
        # end_positions: token positions where the answer span ends
        # inputs_embeds: embedded inputs
        # decoder_inputs_embeds: embedded decoder inputs
        # use_cache: whether to use the cache
        # output_attentions: whether to return attention weights
        # output_hidden_states: whether to return hidden states
        # return_dict: whether to return a dict-style output
# Copied from transformers.models.pegasus.modeling_pegasus.PegasusDecoderWrapper with Pegasus->BigBirdPegasus
class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    def __init__(self, config):
        super().__init__(config)
        # Instantiate a BigBirdPegasusDecoder as the wrapped decoder
        self.decoder = BigBirdPegasusDecoder(config)

    def forward(self, *args, **kwargs):
        # Delegate to the decoder, passing all arguments through
        return self.decoder(*args, **kwargs)


class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # Deep-copy the config, mark it as a decoder and disable the encoder-decoder mode
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        # Wrap the decoder so pretrained checkpoints load correctly
        self.model = BigBirdPegasusDecoderWrapper(config)

        # Language modeling head projecting hidden states onto the vocabulary
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # The decoder's embed_tokens serves as the input embedding layer
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        # Replace the decoder's embed_tokens
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        # lm_head serves as the output embedding layer
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # Replace lm_head
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        # Swap in the given decoder
        self.model.decoder = decoder

    def get_decoder(self):
        # Return the current decoder
        return self.model.decoder

    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass for the BigBirdPegasusForCausalLM model.
        """
        # 实现模型的前向传播,接受多种输入参数,并返回输出结果
        ...

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
    ):
        """
        Prepare inputs for generation based on the BigBirdPegasusForCausalLM model.
        """
        # 准备生成模型输入的方法,接受多种参数,并返回适用于生成的输入
        ...
    ):
        # 如果模型被用作编码器-解码器模型中的解码器,注意力遮罩会即时创建
        if attention_mask is None:
            # 如果注意力遮罩为空,则创建一个与输入张量形状相同的全为1的张量作为注意力遮罩
            attention_mask = input_ids.new_ones(input_ids.shape)

        if past_key_values:
            # 如果有过去的键值状态,则只保留输入张量的最后一个位置作为当前输入
            input_ids = input_ids[:, -1:]
        # 返回一个包含各种输出和状态的字典
        return {
            "input_ids": input_ids,  # encoder_outputs 已经定义,不再需要 input_ids
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            # Reorder each layer's past states along the batch dimension according to beam_idx,
            # moving the indices to the device of the past states
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        # Return the reordered past states
        return reordered_past