Transformers-源码解析-五十四-

Transformers 源码解析(五十四)

.\models\gpt2\tokenization_gpt2.py

# 设置脚本的编码格式为UTF-8

# 引入必要的模块和函数
import json  # 导入用于 JSON 操作的模块
import os  # 导入用于操作系统功能的模块
from functools import lru_cache  # 导入 lru_cache 装饰器,用于缓存函数调用结果
from typing import List, Optional, Tuple  # 导入类型提示相关的类和函数

import regex as re  # 导入 regex 模块,用于正则表达式操作

# 导入 tokenization_utils 模块中的 AddedToken 和 PreTrainedTokenizer 类
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
# 导入 utils 模块中的 logging 对象
from ...utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义词汇文件的名称常量
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",  # 词汇表 JSON 文件的名称
    "merges_file": "merges.txt",  # 合并文件的名称
}

# 预训练模型的词汇文件映射
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "openai-community/gpt2": "https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json",
        "openai-community/gpt2-medium": "https://huggingface.co/openai-community/gpt2-medium/resolve/main/vocab.json",
        "openai-community/gpt2-large": "https://huggingface.co/openai-community/gpt2-large/resolve/main/vocab.json",
        "openai-community/gpt2-xl": "https://huggingface.co/openai-community/gpt2-xl/resolve/main/vocab.json",
        "distilbert/distilgpt2": "https://huggingface.co/distilbert/distilgpt2/resolve/main/vocab.json",
    },
    "merges_file": {
        "openai-community/gpt2": "https://huggingface.co/openai-community/gpt2/resolve/main/merges.txt",
        "openai-community/gpt2-medium": "https://huggingface.co/openai-community/gpt2-medium/resolve/main/merges.txt",
        "openai-community/gpt2-large": "https://huggingface.co/openai-community/gpt2-large/resolve/main/merges.txt",
        "openai-community/gpt2-xl": "https://huggingface.co/openai-community/gpt2-xl/resolve/main/merges.txt",
        "distilbert/distilgpt2": "https://huggingface.co/distilbert/distilgpt2/resolve/main/merges.txt",
    },
}

# 预训练位置嵌入的尺寸映射
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "openai-community/gpt2": 1024,
    "openai-community/gpt2-medium": 1024,
    "openai-community/gpt2-large": 1024,
    "openai-community/gpt2-xl": 1024,
    "distilbert/distilgpt2": 1024,
}

# 使用 lru_cache 装饰器缓存结果的函数,将字节转换为 Unicode 字符
@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
    characters the bpe code barfs on.

    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
    """
    # 此函数返回一个 utf-8 字节列表和映射到 Unicode 字符串的字典,避免映射到空白或控制字符,以免 BPE 算法出错
    # 生成一个字典,将 UTF-8 字节与 Unicode 字符之间建立映射关系
    bs = (
        # ASCII 可见字符的 Unicode 码点范围
        list(range(ord("!"), ord("~") + 1)) +
        # 特殊符号的 Unicode 码点范围
        list(range(ord("¡"), ord("¬") + 1)) +
        # 特殊符号的 Unicode 码点范围
        list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]  # 复制 bs 列表到 cs 列表
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)  # 将不在 bs 中的字节添加到 bs 列表中
            cs.append(2**8 + n)  # 同时将其映射到 cs 列表中,通过添加 256 + n 的方式
            n += 1
    cs = [chr(n) for n in cs]  # 将 cs 列表中的整数转换为对应的 Unicode 字符
    return dict(zip(bs, cs))  # 返回由 bs 和 cs 列表组成的字典,表示 UTF-8 字节到 Unicode 字符的映射关系
def get_pairs(word):
    """
    Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    # Initialize an empty set to store symbol pairs
    pairs = set()
    # Initialize prev_char with the first symbol of the word
    prev_char = word[0]
    # Iterate through each character in the word starting from the second character
    for char in word[1:]:
        # Add a tuple representing the pair (prev_char, char) to the pairs set
        pairs.add((prev_char, char))
        # Update prev_char to the current character for the next iteration
        prev_char = char
    # Return the set of symbol pairs
    return pairs


class GPT2Tokenizer(PreTrainedTokenizer):
    """
    Construct a GPT-2 tokenizer. Based on byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```
    >>> from transformers import GPT2Tokenizer

    >>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
    >>> tokenizer("Hello world")["input_ids"]
    [15496, 995]

    >>> tokenizer(" Hello world")["input_ids"]
    [18435, 995]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        pad_token (`str`, *optional*):
            The token used for padding, for example when batching sequences of different lengths.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (GPT2 tokenizer detect beginning of words by the preceding space).
        add_bos_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial beginning of sentence token to the input. This allows to treat the leading
            word just as any other word.
    """
    # The GPT2Tokenizer class provides methods to tokenize text based on GPT-2 model's byte-level BPE approach.
    # It handles various tokenization scenarios including handling spaces and special tokens like BOS and EOS.
    
    def __init__(
        self,
        vocab_file: str,
        merges_file: str,
        errors: str = "replace",
        unk_token: str = "<|endoftext|>",
        bos_token: str = "<|endoftext|>",
        eos_token: str = "<|endoftext|>",
        pad_token: Optional[str] = None,
        add_prefix_space: bool = False,
        add_bos_token: bool = False,
    ):
        # Initialize the tokenizer by loading vocabulary and merges information
        pass

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Save the tokenizer's vocabulary and merges information to the specified directory
        pass

    def encode(
        self,
        text: Union[str, List[str], List[int]],
        text_pair: Optional[Union[str, List[str], List[int]]] = None,
        add_special_tokens: bool = True,
        is_split_into_words: bool = False,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        stride: int = 0,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchEncoding:
        # Tokenize a single sequence or a pair of sequences and return the token IDs
        pass

    def decode(self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False) -> str:
        # Convert token IDs back to text
        pass

    def decode_batch(self, token_ids_batch: Union[List[int], List[List[int]]], **kwargs) -> List[str]:
        # Convert batches of token IDs back to text
        pass

    def convert_tokens_to_string(self, tokens: Union[int, List[int]]) -> str:
        # Convert token IDs or list of token IDs to a single string
        pass

    def convert_ids_to_tokens(self, ids: Union[int, List[int]]) -> Union[str, List[str]]:
        # Convert token IDs or list of token IDs to token strings
        pass

    def prepare_for_model(
        self,
        ids: Union[int, List[int]],
        pair_ids: Optional[Union[int, List[int]]] = None,
        max_length: Optional[int] = None,
        add_special_tokens: bool = True,
        stride: int = 0,
        truncation_strategy: Union[TruncationStrategy, str] = "longest_first",
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchEncoding:
        # Preprocess IDs and pair IDs for input to the model
        pass
    # 初始化变量 vocab_files_names,使用预定义的常量 VOCAB_FILES_NAMES
    vocab_files_names = VOCAB_FILES_NAMES
    # 初始化变量 pretrained_vocab_files_map,使用预定义的常量 PRETRAINED_VOCAB_FILES_MAP
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # 初始化变量 max_model_input_sizes,使用预定义的常量 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # 初始化变量 model_input_names,包含固定的字符串列表
    model_input_names = ["input_ids", "attention_mask"]

    # 定义类的初始化方法
    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        unk_token="<|endoftext|>",
        bos_token="<|endoftext|>",
        eos_token="<|endoftext|>",
        pad_token=None,
        add_prefix_space=False,
        add_bos_token=False,
        **kwargs,
    ):
        # 如果 bos_token 是字符串,则将其转换为 AddedToken 对象
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        # 如果 eos_token 是字符串,则将其转换为 AddedToken 对象
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        # 如果 unk_token 是字符串,则将其转换为 AddedToken 对象
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        # 如果 pad_token 是字符串,则将其转换为 AddedToken 对象
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token

        # 设置类属性 add_bos_token
        self.add_bos_token = add_bos_token

        # 使用 UTF-8 编码打开 vocab_file 文件,并加载其中的 JSON 数据到 self.encoder 中
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        # 根据 self.encoder 创建反向映射 self.decoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        # 设置错误处理策略
        self.errors = errors  # how to handle errors in decoding
        # 初始化 bytes_to_unicode 函数,并将其赋值给 self.byte_encoder
        self.byte_encoder = bytes_to_unicode()
        # 根据 self.byte_encoder 创建反向映射 self.byte_decoder
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # 使用 UTF-8 编码打开 merges_file 文件,并读取其中的 BPE 合并规则
        with open(merges_file, encoding="utf-8") as merges_handle:
            bpe_merges = merges_handle.read().split("\n")[1:-1]
        # 将 BPE 合并规则列表转换为元组,并创建 BPE 合并规则到索引的映射 self.bpe_ranks
        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        # 初始化缓存字典 self.cache
        self.cache = {}
        # 设置是否在前缀后添加空格的属性 self.add_prefix_space
        self.add_prefix_space = add_prefix_space

        # 编译正则表达式模式 self.pat,用于识别文本中的各种标记和空格
        # 添加 re.IGNORECASE 选项,以便处理大小写版本的缩略词合并
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", re.IGNORECASE)

        # 调用父类的初始化方法,传递额外的参数和关键字参数
        super().__init__(
            errors=errors,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            add_bos_token=add_bos_token,
            **kwargs,
        )

    # 定义属性方法 vocab_size,返回 self.encoder 的长度
    @property
    def vocab_size(self):
        return len(self.encoder)

    # 定义方法 get_vocab,返回包含 self.encoder 和 self.added_tokens_encoder 的字典
    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)
    # 根据给定的 token 查询缓存,如果已经缓存则直接返回结果
    if token in self.cache:
        return self.cache[token]

    # 将 token 转换为元组形式,每个字符作为一个元素
    word = tuple(token)

    # 获取 token 中所有可能的字符对
    pairs = get_pairs(word)

    # 如果不存在字符对,则直接返回原始 token
    if not pairs:
        return token

    # 循环处理直到无法继续拆分
    while True:
        # 找到权重最小的字符对
        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))

        # 如果找到的字符对不在权重表中,则跳出循环
        if bigram not in self.bpe_ranks:
            break

        # 将 word 按照找到的字符对进行拆分和合并
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
            except ValueError:
                new_word.extend(word[i:])
                break
            else:
                new_word.extend(word[i:j])
                i = j

            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1

        # 更新 word 为新的拆分后的元组形式
        new_word = tuple(new_word)
        word = new_word

        # 如果最终只剩一个元素,结束循环
        if len(word) == 1:
            break
        else:
            # 否则继续获取新的字符对
            pairs = get_pairs(word)

    # 将最终处理后的 word 转换为字符串形式
    word = " ".join(word)

    # 将处理后的结果加入缓存中
    self.cache[token] = word

    # 返回最终处理后的字符串形式的 token
    return word

# 根据给定的 token_ids_0 和 token_ids_1 (可选)构建带有特殊 token 的输入
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    # 如果需要添加 bos_token,则设置开始的特殊 token_ids
    if self.add_bos_token:
        bos_token_ids = [self.bos_token_id]
    else:
        bos_token_ids = []

    # 将 token_ids_0 添加到输出列表中
    output = bos_token_ids + token_ids_0

    # 如果 token_ids_1 为空,则直接返回构建好的输出
    if token_ids_1 is None:
        return output

    # 否则将 token_ids_1 也添加到输出列表中,并在两个句子之间添加 bos_token_ids
    return output + bos_token_ids + token_ids_1

# 返回一个特殊 token 的掩码,用于标识输入中哪些位置已经包含了特殊 token
def get_special_tokens_mask(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        从没有添加特殊标记的令牌列表中提取序列 ID。当使用分词器的 `prepare_for_model` 或 `encode_plus` 方法添加特殊标记时调用此方法。

        Args:
            token_ids_0 (`List[int]`):
                ID 列表。
            token_ids_1 (`List[int]`, *optional*):
                可选的第二个 ID 列表,用于序列对。
            already_has_special_tokens (`bool`, *optional*, 默认为 `False`):
                标识令牌列表是否已经包含模型的特殊标记。

        Returns:
            `List[int]`: 一个整数列表,取值为 [0, 1]:1 表示特殊标记,0 表示序列标记。
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if not self.add_bos_token:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0))
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))

    def _tokenize(self, text):
        """对字符串进行分词。"""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = "".join(
                self.byte_encoder[b] for b in token.encode("utf-8")
            )  # 将所有字节映射为 Unicode 字符串,避免 BPE 的控制令牌(在本例中为空格)
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def _convert_token_to_id(self, token):
        """使用词汇表将令牌(str)转换为 ID。"""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """使用词汇表将索引(整数)转换为令牌(str)。"""
        return self.decoder.get(index)

    def convert_tokens_to_string(self, tokens):
        """将令牌序列(字符串)转换为单个字符串。"""
        text = "".join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text
    # 将词汇表保存到指定目录
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # 检查保存目录是否存在,如果不存在则记录错误并返回
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        
        # 构建词汇表文件路径
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        
        # 构建合并文件路径
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        # 写入词汇表到vocab_file中
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        # 写入合并信息到merge_file中
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            # 遍历并按照token_index排序写入BPE merges信息
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # 如果BPE合并索引不连续,记录警告
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        # 返回保存的文件路径
        return vocab_file, merge_file

    # 为进行分词准备文本
    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
        # 如果文本已经被分成单词或者需要在文本前添加前缀空格,则在文本前添加空格
        if is_split_into_words or add_prefix_space:
            text = " " + text
        # 返回处理后的文本和参数
        return (text, kwargs)

    # 默认的聊天模板
    @property
    def default_chat_template(self):
        """
        A simple chat template that ignores role information and just concatenates messages with EOS tokens.
        """
        # 如果没有为这个分词器定义聊天模板,记录警告并返回默认模板
        logger.warning_once(
            "\nNo chat template is defined for this tokenizer - using the default template "
            f"for the {self.__class__.__name__} class. If the default is not appropriate for "
            "your model, please set `tokenizer.chat_template` to an appropriate template. "
            "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
        )
        return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"

.\models\gpt2\tokenization_gpt2_fast.py

# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""

# 导入必要的模块
import json  # 导入处理 JSON 格式的模块
from typing import Optional, Tuple  # 导入类型提示相关的模块

from tokenizers import pre_tokenizers  # 从 tokenizers 包中导入预处理器相关功能

from ...tokenization_utils_base import BatchEncoding  # 导入批量编码类
from ...tokenization_utils_fast import PreTrainedTokenizerFast  # 导入快速预训练分词器类
from ...utils import logging  # 导入日志记录工具
from .tokenization_gpt2 import GPT2Tokenizer  # 从当前目录中导入 GPT2Tokenizer 类

# 获取日志记录器对象
logger = logging.get_logger(__name__)

# 定义用于存储文件名的常量字典
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}

# 预训练模型的词汇文件映射字典
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "openai-community/gpt2": "https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json",
        "openai-community/gpt2-medium": "https://huggingface.co/openai-community/gpt2-medium/resolve/main/vocab.json",
        "openai-community/gpt2-large": "https://huggingface.co/openai-community/gpt2-large/resolve/main/vocab.json",
        "openai-community/gpt2-xl": "https://huggingface.co/openai-community/gpt2-xl/resolve/main/vocab.json",
        "distilbert/distilgpt2": "https://huggingface.co/distilbert/distilgpt2/resolve/main/vocab.json",
    },
    "merges_file": {
        "openai-community/gpt2": "https://huggingface.co/openai-community/gpt2/resolve/main/merges.txt",
        "openai-community/gpt2-medium": "https://huggingface.co/openai-community/gpt2-medium/resolve/main/merges.txt",
        "openai-community/gpt2-large": "https://huggingface.co/openai-community/gpt2-large/resolve/main/merges.txt",
        "openai-community/gpt2-xl": "https://huggingface.co/openai-community/gpt2-xl/resolve/main/merges.txt",
        "distilbert/distilgpt2": "https://huggingface.co/distilbert/distilgpt2/resolve/main/merges.txt",
    },
    "tokenizer_file": {
        "openai-community/gpt2": "https://huggingface.co/openai-community/gpt2/resolve/main/tokenizer.json",
        "openai-community/gpt2-medium": "https://huggingface.co/openai-community/gpt2-medium/resolve/main/tokenizer.json",
        "openai-community/gpt2-large": "https://huggingface.co/openai-community/gpt2-large/resolve/main/tokenizer.json",
        "openai-community/gpt2-xl": "https://huggingface.co/openai-community/gpt2-xl/resolve/main/tokenizer.json",
        "distilbert/distilgpt2": "https://huggingface.co/distilbert/distilgpt2/resolve/main/tokenizer.json",
    },
}

# 预训练位置嵌入大小映射字典
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    # 创建一个字典,包含不同的模型名称作为键,每个模型的默认大小(1024)作为值
    "openai-community/gpt2": 1024,
    "openai-community/gpt2-medium": 1024,
    "openai-community/gpt2-large": 1024,
    "openai-community/gpt2-xl": 1024,
    "distilbert/distilgpt2": 1024,
}

class GPT2TokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" GPT-2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
    Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```
    >>> from transformers import GPT2TokenizerFast

    >>> tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
    >>> tokenizer("Hello world")["input_ids"]
    [15496, 995]

    >>> tokenizer(" Hello world")["input_ids"]
    [18435, 995]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
    the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`, *optional*):
            Path to the vocabulary file.
        merges_file (`str`, *optional*):
            Path to the merges file.
        tokenizer_file (`str`, *optional*):
            Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
            contains everything needed to load the tokenizer.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (GPT2 tokenizer detect beginning of words by the preceding space).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = GPT2Tokenizer

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token="<|endoftext|>",
        eos_token="<|endoftext|>",
        add_prefix_space=False,
        **kwargs,
    ):
        """
        Initialize the GPT2TokenizerFast class.

        Args:
            vocab_file (`str`, *optional*):
                Path to the vocabulary file.
            merges_file (`str`, *optional*):
                Path to the merges file.
            tokenizer_file (`str`, *optional*):
                Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
                contains everything needed to load the tokenizer.
            unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
                The unknown token.
            bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
                The beginning of sequence token.
            eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
                The end of sequence token.
            add_prefix_space (`bool`, *optional*, defaults to `False`):
                Whether or not to add an initial space to the input. This allows to treat the leading word just as any
                other word. (GPT2 tokenizer detect beginning of words by the preceding space).
            **kwargs:
                Additional arguments passed to the superclass.
        """
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )
    ):
        super().__init__(
            vocab_file,
            merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

        self.add_bos_token = kwargs.pop("add_bos_token", False)

        # 获取当前的预处理器状态
        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
        # 检查预处理器是否需要添加前缀空格,如果需要则更新其状态
        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
            pre_tok_state["add_prefix_space"] = add_prefix_space
            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)

        self.add_prefix_space = add_prefix_space

    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
        # 检查是否已经将输入分割为单词
        is_split_into_words = kwargs.get("is_split_into_words", False)
        # 如果需要使用预分词的输入,确保实例化时设置了 add_prefix_space=True
        assert self.add_prefix_space or not is_split_into_words, (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
            "to use it with pretokenized inputs."
        )

        return super()._batch_encode_plus(*args, **kwargs)

    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
        # 检查是否已经将输入分割为单词
        is_split_into_words = kwargs.get("is_split_into_words", False)

        # 如果需要使用预分词的输入,确保实例化时设置了 add_prefix_space=True
        assert self.add_prefix_space or not is_split_into_words, (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
            "to use it with pretokenized inputs."
        )

        return super()._encode_plus(*args, **kwargs)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # 保存词汇表到指定的目录,返回保存的文件名
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)

    @property
    # 从 transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template 复制而来
    def default_chat_template(self):
        """
        A simple chat template that ignores role information and just concatenates messages with EOS tokens.
        """
        # 若没有为这个分词器定义聊天模板,则使用默认的模板,并记录警告信息
        logger.warning_once(
            "\nNo chat template is defined for this tokenizer - using the default template "
            f"for the {self.__class__.__name__} class. If the default is not appropriate for "
            "your model, please set `tokenizer.chat_template` to an appropriate template. "
            "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
        )
        return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"

.\models\gpt2\tokenization_gpt2_tf.py

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
        """Creates TFGPT2Tokenizer from pretrained GPT2Tokenizer

        Args:
            pretrained_model_name_or_path (Union[str, os.PathLike]): Path to pretrained model

        Examples:

        ```
        from transformers import TFGPT2Tokenizer

        tf_tokenizer = TFGPT2Tokenizer.from_pretrained("openai-community/gpt2")
        ```
        """
        # 使用给定的模型名或路径加载预训练的GPT2Tokenizer对象
        tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
        # 使用加载的GPT2Tokenizer对象创建TFGPT2Tokenizer对象
        return cls.from_tokenizer(tokenizer, *init_inputs, **kwargs)
    # 从配置信息创建 TFGPT2Tokenizer 的类方法
    def from_config(cls, config):
        """Creates TFGPT2Tokenizer from configurations

        Args:
            config (Dict): Dictionary with keys such as stated in `get_config`.
        """
        # 使用传入的配置参数创建 TFGPT2Tokenizer 实例并返回
        return cls(**config)

    # 返回当前实例的配置信息字典
    def get_config(self):
        return {
            "vocab": self.vocab,                # 返回词汇表
            "merges": self.merges,              # 返回合并信息
            "max_length": self.max_length,      # 返回最大长度
            "pad_token_id": self.pad_token_id,  # 返回填充标记的ID
        }

    # 对输入的文本进行处理,生成模型的输入
    def call(self, x, max_length: int = None):
        # 使用 TensorFlow Tokenizer 处理输入文本得到输入的ID
        input_ids = self.tf_tokenizer(x)
        # 创建一个全为1的注意力掩码
        attention_mask = tf.ones_like(input_ids)

        if self.pad_token_id is not None:
            # 如果存在填充标记ID,则将输入ID填充至最大长度
            max_length = max_length if max_length is not None else self.max_length

            if max_length is not None:
                # 使用 pad_model_inputs 函数填充输入ID和注意力掩码
                input_ids, attention_mask = pad_model_inputs(
                    input_ids, max_seq_length=max_length, pad_value=self.pad_token_id
                )

        # 返回注意力掩码和填充后的输入ID作为字典形式的结果
        return {"attention_mask": attention_mask, "input_ids": input_ids}

.\models\gpt2\__init__.py

# 引入类型检查工具,用于类型检查
from typing import TYPE_CHECKING

# 引入必要的依赖模块和函数
# 从 utils 模块中导入必要的异常类、延迟加载模块、可用性检查函数等
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_keras_nlp_available,
    is_tensorflow_text_available,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
)

# 定义模块导入结构字典,包含各模块所需导入的类或函数列表
_import_structure = {
    "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
    "tokenization_gpt2": ["GPT2Tokenizer"],
}

# 尝试导入 tokenizers 模块,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,将 tokenization_gpt2_fast 模块添加到导入结构中
    _import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]

# 尝试导入 torch 模块,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,将 modeling_gpt2 模块添加到导入结构中
    _import_structure["modeling_gpt2"] = [
        "GPT2_PRETRAINED_MODEL_ARCHIVE_LIST",
        "GPT2DoubleHeadsModel",
        "GPT2ForQuestionAnswering",
        "GPT2ForSequenceClassification",
        "GPT2ForTokenClassification",
        "GPT2LMHeadModel",
        "GPT2Model",
        "GPT2PreTrainedModel",
        "load_tf_weights_in_gpt2",
    ]

# 尝试导入 tensorflow 模块,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,将 modeling_tf_gpt2 模块添加到导入结构中
    _import_structure["modeling_tf_gpt2"] = [
        "TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFGPT2DoubleHeadsModel",
        "TFGPT2ForSequenceClassification",
        "TFGPT2LMHeadModel",
        "TFGPT2MainLayer",
        "TFGPT2Model",
        "TFGPT2PreTrainedModel",
    ]

# 尝试导入 keras_nlp 模块,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_keras_nlp_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,将 tokenization_gpt2_tf 模块添加到导入结构中
    _import_structure["tokenization_gpt2_tf"] = ["TFGPT2Tokenizer"]

# 尝试导入 flax 模块,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,将 modeling_flax_gpt2 模块添加到导入结构中
    _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]

# 如果在类型检查模式下,导入所需的类型定义
if TYPE_CHECKING:
    from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
    from .tokenization_gpt2 import GPT2Tokenizer

    try:
        # 在类型检查模式下,检查 tokenizers 模块的可用性
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    # 如果前面导入失败,则尝试导入 GPT2TokenizerFast
    else:
        from .tokenization_gpt2_fast import GPT2TokenizerFast

    try:
        # 检查是否存在 torch 库,如果不存在则抛出 OptionalDependencyNotAvailable 异常
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 如果依赖不可用,则忽略此部分代码块
        pass
    else:
        # 导入相关的 GPT-2 模型和相关类
        from .modeling_gpt2 import (
            GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
            GPT2DoubleHeadsModel,
            GPT2ForQuestionAnswering,
            GPT2ForSequenceClassification,
            GPT2ForTokenClassification,
            GPT2LMHeadModel,
            GPT2Model,
            GPT2PreTrainedModel,
            load_tf_weights_in_gpt2,
        )

    try:
        # 检查是否存在 TensorFlow 库,如果不存在则抛出 OptionalDependencyNotAvailable 异常
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 如果依赖不可用,则忽略此部分代码块
        pass
    else:
        # 导入相关的 TensorFlow 版本的 GPT-2 模型和相关类
        from .modeling_tf_gpt2 import (
            TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFGPT2DoubleHeadsModel,
            TFGPT2ForSequenceClassification,
            TFGPT2LMHeadModel,
            TFGPT2MainLayer,
            TFGPT2Model,
            TFGPT2PreTrainedModel,
        )

    try:
        # 检查是否存在 keras_nlp 库,如果不存在则抛出 OptionalDependencyNotAvailable 异常
        if not is_keras_nlp_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 如果依赖不可用,则忽略此部分代码块
        pass
    else:
        # 导入 TensorFlow 版本的 GPT-2 的 tokenizer
        from .tokenization_gpt2_tf import TFGPT2Tokenizer

    try:
        # 检查是否存在 flax 库,如果不存在则抛出 OptionalDependencyNotAvailable 异常
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 如果依赖不可用,则忽略此部分代码块
        pass
    else:
        # 导入 Flax 版本的 GPT-2 模型和相关类
        from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
else:
    # 如果不处于前述条件分支,则执行以下操作

    import sys
    # 导入系统模块,用于操作 Python 运行时环境

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
    # 将当前模块名注册到 sys.modules 中,以 LazyModule 的形式,支持按需导入的模块加载策略

.\models\gptj\configuration_gptj.py

# coding=utf-8
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" GPT-J model configuration"""

# 引入 OrderedDict 类用于创建有序字典,以及其他必要的类型和模块
from collections import OrderedDict
from typing import Any, List, Mapping, Optional

# 引入 Hugging Face 库中的一些模块和函数
from ... import PreTrainedTokenizer, TensorType, is_torch_available
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast, PatchingSpec
from ...utils import logging

# 获取日志记录器对象,用于记录和输出日志信息
logger = logging.get_logger(__name__)

# 定义 GPT-J 预训练模型配置文件的存档映射字典
GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "EleutherAI/gpt-j-6B": "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/config.json",
    # 可在 https://huggingface.co/models?filter=gpt_j 查看所有 GPT-J 模型
}

# 定义 GPTJConfig 类,用于存储 GPT-J 模型的配置信息
class GPTJConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GPTJModel`]. It is used to instantiate a GPT-J
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the GPT-J
    [EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6B) architecture. Configuration objects inherit from
    [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.
    """
    # 定义模型类型为 GPT-J
    model_type = "gptj"
    
    # 创建属性映射字典,将配置参数名映射到对应的 GPT-J 模型配置参数名
    attribute_map = {
        "max_position_embeddings": "n_positions",     # 最大序列长度映射到 n_positions
        "hidden_size": "n_embd",                      # 隐藏大小映射到 n_embd
        "num_attention_heads": "n_head",              # 注意力头数映射到 n_head
        "num_hidden_layers": "n_layer",               # 隐藏层数映射到 n_layer
    }
    # 定义一个初始化函数,用于初始化一个Transformer模型的参数和设置
    def __init__(
        self,
        vocab_size=50400,                        # 词汇表大小,默认为50400
        n_positions=2048,                        # 最大位置编码数,默认为2048
        n_embd=4096,                             # 嵌入层维度,默认为4096
        n_layer=28,                              # Transformer层数,默认为28层
        n_head=16,                               # 自注意力机制中头数,默认为16
        rotary_dim=64,                           # 旋转注意力机制的维度,默认为64
        n_inner=None,                            # Transformer内部层的维度,默认为None
        activation_function="gelu_new",          # 激活函数类型,默认为"gelu_new"
        resid_pdrop=0.0,                          # 残差连接的dropout概率,默认为0.0
        embd_pdrop=0.0,                           # 嵌入层的dropout概率,默认为0.0
        attn_pdrop=0.0,                           # 注意力层的dropout概率,默认为0.0
        layer_norm_epsilon=1e-5,                  # Layer Norm层的epsilon,默认为1e-5
        initializer_range=0.02,                   # 参数初始化的范围,默认为0.02
        use_cache=True,                           # 是否使用缓存,默认为True
        bos_token_id=50256,                       # 开始词的token id,默认为50256
        eos_token_id=50256,                       # 结束词的token id,默认为50256
        tie_word_embeddings=False,                # 是否绑定词嵌入,默认为False
        **kwargs,                                 # 其他关键字参数
    ):
        self.vocab_size = vocab_size               # 初始化词汇表大小属性
        self.n_positions = n_positions             # 初始化最大位置编码数属性
        self.n_embd = n_embd                       # 初始化嵌入层维度属性
        self.n_layer = n_layer                     # 初始化Transformer层数属性
        self.n_head = n_head                       # 初始化自注意力机制头数属性
        self.n_inner = n_inner                     # 初始化Transformer内部层维度属性
        self.rotary_dim = rotary_dim               # 初始化旋转注意力机制维度属性
        self.activation_function = activation_function  # 初始化激活函数类型属性
        self.resid_pdrop = resid_pdrop             # 初始化残差连接的dropout概率属性
        self.embd_pdrop = embd_pdrop               # 初始化嵌入层的dropout概率属性
        self.attn_pdrop = attn_pdrop               # 初始化注意力层的dropout概率属性
        self.layer_norm_epsilon = layer_norm_epsilon  # 初始化Layer Norm层的epsilon属性
        self.initializer_range = initializer_range  # 初始化参数初始化范围属性
        self.use_cache = use_cache                 # 初始化是否使用缓存属性
    
        self.bos_token_id = bos_token_id           # 初始化开始词的token id属性
        self.eos_token_id = eos_token_id           # 初始化结束词的token id属性
    
        # 调用父类的初始化方法,传递开始词的token id、结束词的token id和是否绑定词嵌入的参数
        super().__init__(
            bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
        )
# 从transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig复制而来的配置类GPTJOnnxConfig,
# 继承自OnnxConfigWithPast。
class GPTJOnnxConfig(OnnxConfigWithPast):
    
    # 初始化方法,接受以下参数:
    # - config: 预训练配置对象
    # - task: 任务名称,默认为"default"
    # - patching_specs: 补丁规格列表,可选参数,默认为None
    # - use_past: 是否使用过去键值,布尔类型,默认为False
    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: List[PatchingSpec] = None,
        use_past: bool = False,
    ):
        # 调用父类的初始化方法,传递配置对象、任务名称、补丁规格列表和是否使用过去键值
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
        
        # 如果配置对象的pad_token_id属性不存在
        if not getattr(self._config, "pad_token_id", None):
            # 设置pad_token_id为0(默认值)
            self._config.pad_token_id = 0

    # 输入属性,返回一个字典,表示常见的输入格式
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # 创建一个有序字典,包含输入ids的批次和序列索引
        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        
        # 如果使用过去键值
        if self.use_past:
            # 填充输入字典,包括过去键值的方向
            self.fill_with_past_key_values_(common_inputs, direction="inputs")
            # 添加注意力遮罩,考虑过去序列和当前序列
            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
        else:
            # 添加默认的注意力遮罩,仅考虑当前序列
            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

        # 返回构建好的输入字典
        return common_inputs

    # 层数属性,返回配置对象的层数
    @property
    def num_layers(self) -> int:
        return self._config.n_layer

    # 注意力头数属性,返回配置对象的注意力头数
    @property
    def num_attention_heads(self) -> int:
        return self._config.n_head

    # 生成虚拟输入方法,接受以下参数:
    # - tokenizer: 预训练分词器对象
    # - batch_size: 批次大小,整数,默认为-1
    # - seq_length: 序列长度,整数,默认为-1
    # - is_pair: 是否是成对输入,布尔类型,默认为False
    # - framework: 框架类型,可选参数,默认为None
    ) -> Mapping[str, Any]:
        # 调用父类方法生成通用的虚拟输入数据
        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        # 根据模型前向方法的输入顺序,重新排序输入数据
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})

        # 如果需要使用过去的键值(past_keys)
        if self.use_past:
            # 检查是否有安装 PyTorch,否则抛出错误
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch

                batch, seqlen = common_inputs["input_ids"].shape
                # 计算过去键值的长度,比序列长度多两个
                past_key_values_length = seqlen + 2
                past_shape = (
                    batch,
                    self.num_attention_heads,
                    past_key_values_length,
                    self._config.hidden_size // self.num_attention_heads,
                )
                # 为每个层生成空的过去键值对
                ordered_inputs["past_key_values"] = [
                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
                ]

        # 将通用的注意力掩码添加到排序后的输入中
        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]

        # 如果需要使用过去的键值(past_keys)
        if self.use_past:
            # 获取掩码的数据类型并为过去的键值对添加新的掩码
            mask_dtype = ordered_inputs["attention_mask"].dtype
            ordered_inputs["attention_mask"] = torch.cat(
                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
            )

        # 返回最终排序后的输入字典
        return ordered_inputs

    @property
    def default_onnx_opset(self) -> int:
        # 返回默认的 ONNX 操作集版本号
        return 13

.\models\gptj\modeling_flax_gptj.py

from functools import partial
from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_gptj import GPTJConfig
    # 定义函数 `ultimate`,接收配置参数和数据类型参数
    Parameters:
        config ([`GPTJConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

GPTJ_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
            auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


def create_sinusoidal_positions(num_pos, dim):
    # 计算频率因子
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    # 计算正弦和余弦位置编码
    sinusoid_inp = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
    sin, cos = np.sin(sinusoid_inp), np.cos(sinusoid_inp)

    # 计算需要填充的数量
    sentinel = dim // 2 + dim % 2
    # 创建输出数组
    out = np.zeros((num_pos, dim))
    # 填充正弦和余弦值
    out[:, 0:sentinel] = sin
    out[:, sentinel:] = cos

    return jnp.array(out)


def rotate_every_two(tensor):
    # 旋转张量中的每两个元素
    rotate_half_tensor = jnp.stack((-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1)
    rotate_half_tensor = rotate_half_tensor.reshape(rotate_half_tensor.shape[:-2] + (-1,))
    return rotate_half_tensor


def apply_rotary_pos_emb(tensor, sincos):
    sin_pos, cos_pos = sincos
    # 扩展正弦和余弦位置编码
    sin_pos = sin_pos[:, :, None, :].repeat(2, 3)
    cos_pos = cos_pos[:, :, None, :].repeat(2, 3)
    # 应用旋转位置编码
    return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)


class FlaxGPTJAttention(nn.Module):
    config: GPTJConfig
    dtype: jnp.dtype = jnp.float32
    # 定义一个布尔类型的实例变量 causal,默认为 True,表示是否使用因果注意力
    causal: bool = True
    # 定义一个布尔类型的实例变量 is_cross_attention,默认为 False,表示是否进行交叉注意力

    # 初始化方法,在创建类实例时被调用,用于配置模型参数和初始化一些变量
    def setup(self):
        # 获取配置信息
        config = self.config
        # 设置嵌入维度为隐藏大小
        self.embed_dim = config.hidden_size
        # 设置注意力头数为配置中指定的注意力头数
        self.num_heads = config.num_attention_heads
        # 计算每个注意力头的维度
        self.head_dim = self.embed_dim // self.num_heads

        # 设置旋转维度为配置中指定的旋转维度,如果未指定则使用嵌入维度
        self.rotary_dim = config.rotary_dim

        # 创建偏函数 dense,用于创建全连接层
        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=False,
            dtype=self.dtype,
            # 使用正态分布初始化权重,范围由配置中的 initializer_range 指定
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        # 创建查询、键、值投影层
        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        # 创建输出投影层
        self.out_proj = dense()

        # 设置残差连接的 dropout 层,丢弃率由配置中的 resid_pdrop 指定
        self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)

        # 创建因果注意力掩码,形状为 (1, max_position_embeddings),用于屏蔽未来信息
        self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")

        # 计算位置编码维度,如果未指定旋转维度,则使用嵌入维度
        pos_embd_dim = self.rotary_dim or self.embed_dim
        # 创建正弦位置编码
        self.embed_positions = create_sinusoidal_positions(config.max_position_embeddings, pos_embd_dim)

    # 将隐藏状态按注意力头分割
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    # 将分割的注意力头合并
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    # 使用 JAX 提供的 @nn.compact 装饰器,表示这是一个使用了 JAX 的紧凑型层定义
    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # 检测是否通过缺少现有缓存数据来初始化。
        is_initialized = self.has_variable("cache", "cached_key")
        # 获取缓存的键,如果不存在则初始化为形状和类型与输入相同的零数组
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        # 获取缓存的值,如果不存在则初始化为形状和类型与输入相同的零数组
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        # 获取缓存索引,如果不存在则初始化为值为0的整数数组
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            # 获取批次维度、最大长度、注意力头数和每个头部深度的缓存键形状
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # 使用新的一维空间切片更新键和值缓存
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            # 更新缓存索引,增加已更新的缓存向量数量
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # 对于缓存的解码器自注意力,生成因果掩码:我们的单个查询位置只应注意到已生成和缓存的键位置,而不是剩余的零元素。
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            # 合并填充掩码和输入的注意力掩码
            attention_mask = combine_masks(pad_mask, attention_mask)
        # 返回更新后的键、值和注意力掩码
        return key, value, attention_mask
class FlaxGPTJMLP(nn.Module):
    # GPTJConfig 类型的配置对象
    config: GPTJConfig
    # 中间层大小
    intermediate_size: int
    # 数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 初始化模块
    def setup(self):
        # 嵌入维度为隐藏大小
        embed_dim = self.config.hidden_size
        # 使用正态分布初始化核
        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)

        # 输入全连接层
        self.fc_in = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init)
        # 输出全连接层
        self.fc_out = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init)

        # 激活函数选择
        self.act = ACT2FN[self.config.activation_function]
        # 丢弃率为 config 中的 resid_pdrop
        self.dropout = nn.Dropout(rate=self.config.resid_pdrop)

    # 调用模块
    def __call__(self, hidden_states, deterministic: bool = True):
        # 输入经过输入全连接层
        hidden_states = self.fc_in(hidden_states)
        # 应用激活函数
        hidden_states = self.act(hidden_states)
        # 输入经过输出全连接层
        hidden_states = self.fc_out(hidden_states)
        # 使用 dropout 进行正则化
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxGPTJBlock(nn.Module):
    # GPTJConfig 类型的配置对象
    config: GPTJConfig
    # 数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 初始化模块
    def setup(self):
        # 隐藏大小
        hidden_size = self.config.hidden_size
        # 内部维度为 n_inner 或者 4 * hidden_size
        inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size

        # 层归一化
        self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
        # 自注意力机制
        self.attn = FlaxGPTJAttention(self.config, dtype=self.dtype)

        # 多层感知机
        self.mlp = FlaxGPTJMLP(self.config, inner_dim, dtype=self.dtype)

    # 调用模块
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        # 残差连接
        residual = hidden_states
        # 层归一化
        hidden_states = self.ln_1(hidden_states)
        # 自注意力机制输出
        attn_outputs = self.attn(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]

        # 前馈网络的隐藏状态
        feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
        # 残差连接
        hidden_states = attn_output + feed_forward_hidden_states + residual

        return (hidden_states,) + attn_outputs[1:]


class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # GPTJConfig 类型的配置类
    config_class = GPTJConfig
    # 基础模型前缀为 "transformer"
    base_model_prefix = "transformer"
    # 模块类
    module_class: nn.Module = None

    # 初始化方法
    def __init__(
        self,
        config: GPTJConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # 使用给定的配置和类型初始化模块
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 初始化输入张量
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # 创建与input_ids相同形状的全1注意力掩码
        attention_mask = jnp.ones_like(input_ids)
        # 创建位置编码,广播以匹配input_ids的形状
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        # 划分随机数生成器rng为参数rngs和dropout_rng
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            # 如果配置中包含跨注意力机制,初始化编码器隐藏状态和注意力掩码
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
            encoder_attention_mask = attention_mask
            # 使用模型模块初始化
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                position_ids,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            # 否则仅使用输入初始化模型模块
            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)

        # 获取随机初始化的模型参数
        random_params = module_init_outputs["params"]

        if params is not None:
            # 如果提供了预先定义的参数,将随机参数和预定义参数扁平化处理并合并
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            # 返回冻结后的合并参数
            return freeze(unflatten_dict(params))
        else:
            # 否则直接返回随机初始化的参数
            return random_params

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                用于快速自回归解码的批量大小,定义了初始化缓存的批处理大小。
            max_length (`int`):
                自回归解码的最大可能长度,定义了初始化缓存的序列长度。
        """
        # 初始化用于检索缓存的输入变量
        input_ids = jnp.ones((batch_size, max_length))
        # 创建与input_ids相同形状的全1注意力掩码
        attention_mask = jnp.ones_like(input_ids)
        # 创建位置编码,广播以匹配input_ids的形状
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # 使用模型模块初始化,设置init_cache标志以初始化缓存
        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        # 返回初始化的缓存
        return init_variables["cache"]

    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ):
            # 检查是否需要输出注意力权重,若未指定则使用配置中的默认设置
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            # 检查是否需要输出隐藏状态,若未指定则使用配置中的默认设置
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            # 检查是否需要返回字典形式的输出,若未指定则使用配置中的默认设置
            return_dict = return_dict if return_dict is not None else self.config.return_dict

            # 获取输入张量的批量大小和序列长度
            batch_size, sequence_length = input_ids.shape

            # 如果未提供位置编码(position_ids),且已传递过去的键值(past_key_values)不为空,则抛出错误
            if position_ids is None:
                if past_key_values is not None:
                    raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

                # 使用序列长度创建广播位置编码(position_ids)
                position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

            # 如果未提供注意力遮罩(attention_mask),则创建全为1的注意力遮罩
            if attention_mask is None:
                attention_mask = jnp.ones((batch_size, sequence_length))

            # 处理可能存在的伪随机数发生器
            rngs = {}
            if dropout_rng is not None:
                rngs["dropout"] = dropout_rng

            # 构建输入参数字典
            inputs = {"params": params or self.params}

            # 如果已传递过去的键值(past_key_values)不为空,则初始化缓存
            if past_key_values:
                inputs["cache"] = past_key_values
                mutable = ["cache"]
            else:
                mutable = False

            # 应用模型的前向传播
            outputs = self.module.apply(
                inputs,
                jnp.array(input_ids, dtype="i4"),
                jnp.array(attention_mask, dtype="i4"),
                jnp.array(position_ids, dtype="i4"),
                not train,
                False,
                output_attentions,
                output_hidden_states,
                return_dict,
                rngs=rngs,
                mutable=mutable,
            )

            # 如果已传递过去的键值(past_key_values)不为空且需要返回字典形式的输出(return_dict),则添加更新后的缓存到模型输出中
            if past_key_values is not None and return_dict:
                outputs, past_key_values = outputs
                outputs["past_key_values"] = unfreeze(past_key_values["cache"])
                return outputs
            # 如果已传递过去的键值(past_key_values)不为空且不需要返回字典形式的输出(return_dict),则更新输出元组中的缓存信息
            elif past_key_values is not None and not return_dict:
                outputs, past_key_values = outputs
                outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

            # 返回模型的输出
            return outputs
# 定义一个名为 FlaxGPTJBlockCollection 的类,继承自 nn.Module
class FlaxGPTJBlockCollection(nn.Module):
    # 类属性 config,类型为 GPTJConfig,dtype 默认为 jnp.float32
    config: GPTJConfig
    dtype: jnp.dtype = jnp.float32

    # 定义 setup 方法,用于初始化模块
    def setup(self):
        # 创建一个包含多个 FlaxGPTJBlock 实例的列表 self.blocks
        self.blocks = [
            FlaxGPTJBlock(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
        ]

    # 定义 __call__ 方法,使对象可以像函数一样调用
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 如果输出注意力矩阵,则初始化一个空元组 all_attentions
        all_attentions = () if output_attentions else None
        # 如果输出隐藏状态,则初始化一个空元组 all_hidden_states
        all_hidden_states = () if output_hidden_states else None

        # 遍历 self.blocks 中的每个 FlaxGPTJBlock 实例
        for block in self.blocks:
            # 如果需要输出隐藏状态,则将当前隐藏状态添加到 all_hidden_states
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # 调用 block 对象处理当前的 hidden_states 等参数,返回 layer_outputs
            layer_outputs = block(
                hidden_states,
                attention_mask,
                position_ids=position_ids,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
            )
            # 更新 hidden_states 为当前层处理后的输出
            hidden_states = layer_outputs[0]

            # 如果需要输出注意力矩阵,则将当前层的注意力矩阵添加到 all_attentions
            if output_attentions:
                all_attentions += (layer_outputs[1],)

        # 构建输出元组 outputs,包含最终的 hidden_states、all_hidden_states 和 all_attentions
        # 其中 all_attentions 可能包含 None 值,会在 FlaxGPTJModule 中进行过滤处理
        outputs = (hidden_states, all_hidden_states, all_attentions)

        # 返回最终的输出结果
        return outputs


# 定义一个名为 FlaxGPTJModule 的类,继承自 nn.Module
class FlaxGPTJModule(nn.Module):
    # 类属性 config,类型为 GPTJConfig,dtype 默认为 jnp.float32
    config: GPTJConfig
    dtype: jnp.dtype = jnp.float32

    # 定义 setup 方法,用于初始化模块
    def setup(self):
        # 初始化 self.embed_dim 为 config.hidden_size
        self.embed_dim = self.config.hidden_size

        # 创建一个 nn.Embed 实例 self.wte,用于词嵌入
        self.wte = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            # 使用正态分布初始化词嵌入权重
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        # 创建一个 nn.Dropout 实例 self.dropout,用于词嵌入后的 dropout
        self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
        # 创建一个 FlaxGPTJBlockCollection 实例 self.h,用于处理隐藏层
        self.h = FlaxGPTJBlockCollection(self.config, dtype=self.dtype)
        # 创建一个 nn.LayerNorm 实例 self.ln_f,用于最终的 Layer Normalization
        self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

    # 定义 __call__ 方法,使对象可以像函数一样调用
    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 实际执行的内容在 FlaxGPTJBlockCollection 的 __call__ 方法中完成,此处不作进一步解释
        pass
        ):
        # 使用 self.wte 将输入的整数数组转换为嵌入向量,数据类型为 'i4'
        input_embeds = self.wte(input_ids.astype("i4"))

        # 对输入的嵌入向量应用 dropout,根据 deterministic 参数决定是否确定性地应用
        hidden_states = self.dropout(input_embeds, deterministic=deterministic)

        # 将处理后的隐藏状态传入 self.h 进行处理,接收多个命名参数
        outputs = self.h(
            hidden_states,
            attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从输出中获取第一个元素作为新的隐藏状态
        hidden_states = outputs[0]

        # 对新的隐藏状态应用 LayerNormalization,self.ln_f 是一个层标准化层
        hidden_states = self.ln_f(hidden_states)

        # 如果设置了 output_hidden_states 标志,将所有隐藏状态存储到 all_hidden_states 中
        if output_hidden_states:
            all_hidden_states = outputs[1] + (hidden_states,)
            outputs = (hidden_states, all_hidden_states) + outputs[2:]
        else:
            outputs = (hidden_states,) + outputs[1:]

        # 如果 return_dict 是 False,则返回 outputs 中不为 None 的所有元素
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        # 如果 return_dict 是 True,则返回 FlaxBaseModelOutput 对象,包含隐藏状态、所有隐藏状态和注意力
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )
# 用于添加起始文档字符串的装饰器函数,描述了FlaxGPTJModel类的基本功能和输出信息
@add_start_docstrings(
    "The bare GPTJ Model transformer outputting raw hidden-states without any specific head on top.",
    GPTJ_START_DOCSTRING,
)
# 将示例调用文档字符串添加到FlaxGPTJModel类中,包含了模型检查点、输出配置信息和配置输出的样本
append_call_sample_docstring(
    FlaxGPTJModel,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutput,
    _CONFIG_FOR_DOC,
)

# 定义一个用于语言建模的Flax模块类,依赖于GPTJConfig配置,并设置了数据类型为32位浮点数
class FlaxGPTJForCausalLMModule(nn.Module):
    config: GPTJConfig
    dtype: jnp.dtype = jnp.float32

    # 模块设置函数,初始化transformer和lm_head
    def setup(self):
        self.transformer = FlaxGPTJModule(self.config, dtype=self.dtype)
        # 使用配置中的词汇表大小初始化lm_head的全连接层
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            dtype=self.dtype,
            # 使用正态分布初始化全连接层的权重矩阵
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

    # 模块的调用函数,接收多个参数和关键字参数,并返回语言建模的输出
    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 使用transformer处理输入数据,返回各种输出
        outputs = self.transformer(
            input_ids,
            attention_mask,
            position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]  # 提取transformer的隐藏状态作为下一步处理的输入

        if self.config.tie_word_embeddings:
            # 如果配置要求共享词嵌入矩阵,则从transformer的参数中提取共享的权重矩阵,并应用到lm_head
            shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)  # 否则直接使用lm_head进行预测

        if not return_dict:
            # 如果不返回字典形式的输出,则返回一个元组,包括lm_logits和outputs的其余部分
            return (lm_logits,) + outputs[1:]

        # 返回格式化的语言建模输出对象,包括logits、隐藏状态和注意力权重(如果有的话)
        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)


# 为FlaxGPTJForCausalLM类添加起始文档字符串,描述其为一个带有语言建模头的GPTJ模型变体
@add_start_docstrings(
    """
    The GPTJ Model transformer with a language modeling head on top.
    """,
    GPTJ_START_DOCSTRING,
)
class FlaxGPTJForCausalLM(FlaxGPTJPreTrainedModel):
    module_class = FlaxGPTJForCausalLMModule  # 指定模块类为FlaxGPTJForCausalLMModule
    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # initializing the cache
        # 获取输入张量的批大小和序列长度
        batch_size, seq_length = input_ids.shape

        # 使用初始化方法初始化过去键值(用于缓存)
        past_key_values = self.init_cache(batch_size, max_length)
        
        # 注意:通常需要在 attention_mask 中的位置(大于 input_ids.shape[-1] 和小于 cache_length)放置0。
        # 但由于 GPTJ 使用因果掩码,这些位置已经被掩盖了。
        # 因此,我们可以在这里创建一个静态的 attention_mask,这对编译来说更有效率。
        
        # 创建一个扩展的 attention_mask,全部初始化为1
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        
        # 如果传入了 attention_mask,则根据它的累积和位置更新 extended_attention_mask
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            # 否则,使用广播方法创建位置索引,范围是 [0, seq_length)
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        # 返回包含过去键值、扩展后的 attention_mask 和位置索引的字典
        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # 更新用于生成的模型输入参数
        
        # 将模型输出的过去键值更新到模型参数中
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        
        # 更新位置索引,仅保留最后一个位置
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        
        # 返回更新后的模型参数
        return model_kwargs

调用函数append_call_sample_docstring,添加样例文档字符串到指定类和对象的关联属性中

append_call_sample_docstring(
FlaxGPTJForCausalLM, # 目标类:FlaxGPTJForCausalLM
_CHECKPOINT_FOR_DOC, # 样例文档字符串的检查点
FlaxCausalLMOutput, # 输出类:FlaxCausalLMOutput
_CONFIG_FOR_DOC, # 样例文档字符串的配置
)


# `.\models\gptj\modeling_gptj.py`

```py
# coding=utf-8
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch GPT-J model."""

import warnings
from typing import Optional, Tuple, Union

import torch
import torch.fx
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    is_torch_fx_proxy,
    logging,
)
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_gptj import GPTJConfig

# Check if flash attention v2 is available
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

# Get logger instance for this module
logger = logging.get_logger(__name__)

# Constants used for documentation and testing
_CHECKPOINT_FOR_DOC = "hf-internal-testing/tiny-random-gptj"
_REAL_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-j-6B"
_CONFIG_FOR_DOC = "GPTJConfig"

# List of pretrained model archives for GPT-J
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "EleutherAI/gpt-j-6B",
    # See all GPT-J models at https://huggingface.co/models?filter=gptj
]

# Function to get unpad data based on attention mask
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )

# Function to create sinusoidal positions based on number of positions and dimension
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
    # Calculate inverse frequency for sinusoidal function
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
    # Generate sinusoidal inputs
    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
    # Concatenate sine and cosine of sinusoidal inputs along dimension 1
    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)

# Wrap the following function in TorchFX framework for symbolic tracing
@torch.fx.wrap
def get_embed_positions(embed_positions, position_ids):
    # 将嵌入位置张量转移到与位置 ID 张量相同的设备上,并重复多次以匹配位置 ID 张量的形状
    return embed_positions.to(position_ids.device).repeat(position_ids.shape[0], 1, 1)
# 定义一个函数,用于将输入张量的每个位置的偶数索引位置的数据提取出来
def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    x1 = x[:, :, :, ::2]  # 提取偶数索引位置的数据
    x2 = x[:, :, :, 1::2]  # 提取奇数索引位置的数据
    x = torch.stack((-x2, x1), dim=-1)  # 将奇偶索引位置的数据组成新的张量,并进行堆叠
    return x.flatten(-2)  # 将最后两个维度展平,即将每对奇偶索引位置的数据合并成单个维度


# 定义一个函数,用于在给定的张量上应用旋转位置编码
def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)  # 在第3维上重复插入sin值,用于奇数索引位置
    cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)  # 在第3维上重复插入cos值,用于偶数索引位置
    return (tensor * cos) + (rotate_every_two(tensor) * sin)  # 应用旋转位置编码公式:tensor * cos + rotate_every_two(tensor) * sin


# 定义一个自注意力模块类,用于处理注意力机制相关的操作
class GPTJAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        max_positions = config.max_position_embeddings
        # 创建一个下三角矩阵作为偏置,用于掩码操作
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)  # 创建一个掩码偏置

        self.attn_dropout = nn.Dropout(config.attn_pdrop)  # 注意力权重的dropout
        self.resid_dropout = nn.Dropout(config.resid_pdrop)  # 残差连接的dropout

        self.is_causal = True  # 是否是因果关系(用于自回归模型)

        self.embed_dim = config.hidden_size  # 嵌入维度大小
        self.num_attention_heads = config.num_attention_heads  # 注意力头的数量
        self.head_dim = self.embed_dim // self.num_attention_heads  # 每个注意力头的维度
        if self.head_dim * self.num_attention_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                f" `num_attention_heads`: {self.num_attention_heads})."
            )
        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())

        # 定义四个线性映射层,分别用于计算查询、键、值和输出
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

        self.rotary_dim = config.rotary_dim  # 旋转位置编码的维度
        pos_embd_dim = self.rotary_dim or self.embed_dim  # 位置编码的维度,默认为嵌入维度
        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)  # 创建正弦位置编码
       
    # 将输入张量进行分头处理
    def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
        """
        Splits hidden dim into attn_head_size and num_attention_heads
        """
        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)  # 重塑张量的形状,分成多个头
        tensor = tensor.view(new_shape)
        if rotary:
            return tensor  # 如果使用旋转位置编码,直接返回分头后的张量
        if len(tensor.shape) == 5:
            return tensor.permute(0, 1, 3, 2, 4)  # 调整维度顺序,适用于5维张量
        elif len(tensor.shape) == 4:
            return tensor.permute(0, 2, 1, 3)  # 调整维度顺序,适用于4维张量
        else:
            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")  # 抛出维度错误异常
    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden dim
        """
        # 如果输入张量维度为5,则交换维度顺序使得 attn_head_size 和 num_attention_heads 维度合并到隐藏层维度
        if len(tensor.shape) == 5:
            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
        # 如果输入张量维度为4,则交换维度顺序使得 attn_head_size 和 num_attention_heads 维度合并到隐藏层维度
        elif len(tensor.shape) == 4:
            tensor = tensor.permute(0, 2, 1, 3).contiguous()
        else:
            # 抛出异常,如果张量维度既不是4也不是5
            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
        # 计算新的张量形状,将 attn_head_size 和 num_attention_heads 合并到最后两个维度
        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
        return tensor.view(new_shape)

    def _attn(
        self,
        query,
        key,
        value,
        attention_mask=None,
        head_mask=None,
    ):
        # 从 causal_mask buffer 计算 causal mask
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

        # 将 query 和 key 张量类型转换为 float32,以避免溢出问题
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        # 计算注意力权重
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        # 设置 mask_value 为最小的浮点数,与 attn_weights 张量相同类型
        mask_value = torch.finfo(attn_weights.dtype).min
        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
        # 根据 causal_mask,将不需要的位置设置为 mask_value
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        # 根据缩放因子缩放注意力权重
        attn_weights = attn_weights / self.scale_attn

        if attention_mask is not None:
            # 应用额外的注意力掩码
            attn_weights = attn_weights + attention_mask

        # 使用 softmax 计算最终的注意力权重
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        # 将注意力权重转换为与 value 张量相同的数据类型
        attn_weights = attn_weights.to(value.dtype)
        # 应用注意力 dropout
        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            # 如果需要,对注意力权重进行头部掩码操作
            attn_weights = attn_weights * head_mask

        # 计算最终的注意力输出
        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _get_embed_positions(self, position_ids):
        embed_positions = self.embed_positions
        # 如果 embed_positions 的设备与 position_ids 不同,则将其移到 position_ids 的设备上
        if embed_positions.device != position_ids.device:
            embed_positions = embed_positions.to(position_ids.device)
            self.embed_positions = embed_positions
        # 将 embed_positions 扩展到与 position_ids 的第一个维度相同
        return embed_positions.repeat(position_ids.shape[0], 1, 1)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor, Tuple[torch.Tensor]],
        Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
    ]:
        # 对隐藏状态进行投影以获得查询、键和值
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # 将查询、键、值按注意力头数和头维度拆分
        query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
        key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
        value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)

        if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
            # 在 torch.fx 框架中或正在追踪时,无法跟踪条件复制到 GPU 的逻辑,因此每次在 torch.fx 框架中执行此操作
            embed_positions = get_embed_positions(self.embed_positions, position_ids)
        else:
            # 获取嵌入位置
            embed_positions = self._get_embed_positions(position_ids)

        # 重复位置ID以匹配嵌入位置的形状
        repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
        sincos = torch.gather(embed_positions, 1, repeated_position_ids)
        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

        if self.rotary_dim is not None:
            # 如果存在旋转维度,则将键和查询分为旋转部分和传递部分,并应用旋转位置编码
            k_rot = key[:, :, :, : self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim :]

            q_rot = query[:, :, :, : self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim :]

            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
            # 否则,直接应用旋转位置编码到键和查询
            key = apply_rotary_pos_emb(key, sin, cos)
            query = apply_rotary_pos_emb(query, sin, cos)

        # 将键和查询的维度进行转置
        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)

        if layer_past is not None:
            # 如果存在过去的层状态,则将过去的键和值与当前的键和值连接起来
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            # 如果使用缓存,则返回带有浮点数类型的键和值,参考自 GitHub 上的实现
            present = (key.to(hidden_states.dtype), value)
        else:
            # 否则,不返回任何状态
            present = None

        # 计算自注意力输出和注意力权重
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        # 合并注意力头并进行输出投影
        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            # 如果需要输出注意力权重,则添加到输出中
            outputs += (attn_weights,)

        return outputs  # 返回注意力输出、状态以及(如果需要)注意力权重
class GPTJFlashAttention2(GPTJAttention):
    """
    GPTJ flash attention module. This module inherits from `GPTJAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        # Flag to determine if flash attention uses top-left aligned mask
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor, Tuple[torch.Tensor]],
        Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
    ]:
        """
        Forward pass of the GPTJFlashAttention2 module.
        
        Args:
        - hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size).
        - layer_past: Tuple of past key-value states.
        - attention_mask: Optional tensor with attention mask of shape (batch_size, seq_length).
        - position_ids: Optional tensor with position ids of shape (batch_size, seq_length).
        - head_mask: Optional tensor with mask for attention heads of shape (num_heads,).
        - use_cache: Optional boolean flag indicating whether to use caching.
        - output_attentions: Optional boolean flag indicating whether to output attention weights.
        
        Returns:
        - Tuple of output tensor and updated layer past.
        """
        # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
        def _flash_attention_forward(
            self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
        ):
            """
            Internal function to perform forward pass of flash attention.

            Args:
            - query_states: Query tensor of shape (batch_size, query_length, hidden_size).
            - key_states: Key tensor of shape (batch_size, key_length, hidden_size).
            - value_states: Value tensor of shape (batch_size, key_length, hidden_size).
            - attention_mask: Attention mask tensor of shape (batch_size, query_length, key_length).
            - query_length: Length of the query sequence.
            - dropout: Optional dropout rate.
            - softmax_scale: Optional scaling factor for softmax.

            Returns:
            - Tuple of output tensor and updated attention weights.
            """
            # Implementation of flash attention forward pass
            pass
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        # Determine if causal masking is needed based on configuration
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal  # Set causal to self.is_causal if not using top-left mask
        else:
            # Special case for RoCm compatibility: adjust causal based on query length
            # TODO: Remove this check after upgrading Flash Attention to version 2.1
            causal = self.is_causal and query_length != 1

        # Check if there are padding tokens in the input sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]  # Get the batch size from query states
            # Unpad the input sequences based on the attention mask
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            # Extract sequence lengths after unpadding
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            # Perform variable-length Flash Attention computation
            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            # Pad the attention output back to the original sequence length
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            # If no padding mask is provided, perform standard Flash Attention
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->num_attention_heads
    # 定义一个方法 `_upad_input`,接收以下参数:query_layer, key_layer, value_layer, attention_mask, query_length
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        # 调用 `_get_unpad_data` 方法获取解压后的数据的索引、cu_seqlens_k 和 max_seqlen_in_batch_k
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        
        # 获取 key_layer 的形状信息
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        # 将 key_layer 重塑成适合索引操作的形状,并按照 indices_k 进行索引
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        
        # 将 value_layer 重塑成适合索引操作的形状,并按照 indices_k 进行索引
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        
        # 如果 query_length 等于 kv_seq_len,则将 query_layer 重塑成适合索引操作的形状,并按照 indices_k 进行索引
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        # 如果 query_length 等于 1,则设置 max_seqlen_in_batch_q 为 1,cu_seqlens_q 为一个序列,然后将 query_layer 压缩
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # 对 attention_mask 进行切片,保留最后 query_length 列,然后调用 unpad_input 函数解压输入
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        # 返回更新后的 query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k)
        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
GPTJ_ATTENTION_CLASSES = {
    "eager": GPTJAttention,
    "flash_attention_2": GPTJFlashAttention2,
}

class GPTJMLP(nn.Module):
    def __init__(self, intermediate_size, config):  # in MLP: intermediate_size= 4 * embed_dim
        super().__init__()
        embed_dim = config.n_embd

        # 初始化输入层和输出层的线性变换
        self.fc_in = nn.Linear(embed_dim, intermediate_size)
        self.fc_out = nn.Linear(intermediate_size, embed_dim)

        # 选择激活函数
        self.act = ACT2FN[config.activation_function]
        # 添加dropout层,以减少过拟合风险
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor:
        # 输入数据进行线性变换和激活函数处理
        hidden_states = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc_out(hidden_states)
        # 对输出结果进行dropout处理
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化层归一化和注意力机制类
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJ_ATTENTION_CLASSES[config._attn_implementation](config)
        self.mlp = GPTJMLP(inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        # 残差连接
        residual = hidden_states
        # 应用层归一化
        hidden_states = self.ln_1(hidden_states)
        # 执行注意力机制
        attn_outputs = self.attn(
            hidden_states=hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # 获取注意力机制的输出
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]

        # 执行MLP前向传播
        feed_forward_hidden_states = self.mlp(hidden_states)
        # 结合注意力机制输出、MLP输出和残差连接结果
        hidden_states = attn_output + feed_forward_hidden_states + residual

        # 根据使用缓存选项决定输出
        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions)


class GPTJPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 配置模型的类
    config_class = GPTJConfig
    # 基础模型的前缀
    base_model_prefix = "transformer"
    # 模型支持并行处理
    is_parallelizable = True
    # 支持梯度检查点
    supports_gradient_checkpointing = True
    # 不需要分割的模块列表
    _no_split_modules = ["GPTJBlock"]
    # 定义类变量,用于指定在设备放置时跳过的键名
    _skip_keys_device_placement = "past_key_values"
    # 定义类变量,表示是否支持闪存注意力机制2
    _supports_flash_attn_2 = True
    
    # 初始化方法,接受任意位置参数和关键字参数,并调用父类的初始化方法
    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
    
    # 初始化神经网络模块的权重
    def _init_weights(self, module):
        """Initialize the weights."""
        # 如果模块是线性层
        if isinstance(module, (nn.Linear,)):
            # 使用正态分布初始化权重,均值为0,标准差为配置中的初始化范围
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果存在偏置项,则将其初始化为零
            if module.bias is not None:
                module.bias.data.zero_()
        # 如果模块是嵌入层
        elif isinstance(module, nn.Embedding):
            # 使用正态分布初始化权重,均值为0,标准差为配置中的初始化范围
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果指定了填充索引,则将填充索引对应的权重初始化为零
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # 如果模块是层归一化层
        elif isinstance(module, nn.LayerNorm):
            # 将偏置项初始化为零
            module.bias.data.zero_()
            # 将权重初始化为1
            module.weight.data.fill_(1.0)
"""
    This string defines the documentation (docstring) for describing the model class `GPTJ`, which is a subclass of
    `torch.nn.Module` in PyTorch. Users should consult the PyTorch documentation for general usage and behavior of
    `torch.nn.Module`.

    Parameters:
        config (`GPTJConfig`): This parameter is expected to be an instance of `GPTJConfig`, which holds all the
            configuration parameters for the model. It does not load the model weights, only the configuration. To
            load the weights, users should refer to the `from_pretrained` method of `PreTrainedModel`.

    Note:
        - The docstring explains the purpose and usage of the `GPTJ` model class.
        - It provides guidance on the `config` parameter and where to load model weights from.
"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            # 输入序列标记在词汇表中的索引。
            # 可以使用 [`AutoTokenizer`] 获取这些索引。详见 [`PreTrainedTokenizer.encode`] 和 [`PreTrainedTokenizer.__call__`]。
            # [什么是输入 ID?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            # 遮盖掩码,用于在填充的标记索引上避免进行注意力计算。
            # 掩码值在 `[0, 1]` 范围内:
            # - 1 表示**未被遮盖**的标记,
            # - 0 表示**被遮盖**的标记。
            # [什么是注意力遮盖?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            # 分段标记索引,用于指示输入的第一部分和第二部分。
            # 索引在 `[0, 1]` 范围内:
            # - 0 对应 *句子 A* 的标记,
            # - 1 对应 *句子 B* 的标记。
            # [什么是标记类型 ID?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            # 输入序列中每个标记的位置索引,用于位置嵌入。
            # 索引选择范围是 `[0, config.n_positions - 1]`。
            # [什么是位置 ID?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_attention_heads,)` or `(n_layer, num_attention_heads)`, *optional*):
            # 自注意力模块中用于屏蔽选定头部的掩码。
            # 掩码值在 `[0, 1]` 范围内:
            # - 1 表示头部**未被屏蔽**,
            # - 0 表示头部**被屏蔽**。
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_dim)`, *optional*):
            # 可选参数,可以直接传递嵌入表示而不是 `input_ids`。
            # 如果需要更多控制如何将 *input_ids* 索引转换为相关联的向量,而不是使用模型的内部嵌入查找矩阵,这将非常有用。
        output_attentions (`bool`, *optional*):
            # 是否返回所有注意力层的注意力张量。
            # 查看返回张量中的 `attentions` 以获取更多细节。
        output_hidden_states (`bool`, *optional*):
            # 是否返回所有层的隐藏状态。
            # 查看返回张量中的 `hidden_states` 以获取更多细节。
        return_dict (`bool`, *optional*):
            # 是否返回一个 [`~utils.ModelOutput`] 而不是一个简单的元组。
# 并行化功能的文档字符串,描述了该功能的实验性质以及如何使用设备映射来分配模型的注意力模块到多个设备上
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is a subject to change at a moment's notice. Uses a device map to distribute
    attention modules of the model across several devices. If no device map is given, it will evenly distribute blocks
    across all devices.

    Args:
        device_map (`Dict[int, list]`, optional, defaults to None):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the GPT-J models have the
            following number of attention modules:

                - gpt-j-6B: 28

    Example:

    ```
    # Here is an example of a device map on a machine with 4 GPUs using gpt-j-6B, which has a total of 28 attention modules:
    model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6],
        1: [7, 8, 9, 10, 11, 12, 13],
        2: [14, 15, 16, 17, 18, 19, 20],
        3: [21, 22, 23, 24, 25, 26, 27],
    }
    model.parallelize(device_map)
    ```
"""

# 反并行化功能的文档字符串,描述了将模型从模型并行状态移回 CPU 的过程
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to CPU from a model parallel state.

    Example:

    ```
    # On a 4 GPU machine with gpt-j-6B:
    model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6],
        1: [7, 8, 9, 10, 11, 12, 13],
        2: [14, 15, 16, 17, 18, 19, 20],
        3: [21, 22, 23, 24, 25, 26, 27],
    }
    model.parallelize(device_map)  # Splits the model across several devices
    model.deparallelize()  # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
    ```
"""

# GPT-J 模型的类定义,继承自 GPTJPreTrainedModel
@add_start_docstrings(
    "The bare GPT-J Model transformer outputting raw hidden-states without any specific head on top.",
    GPTJ_START_DOCSTRING,
)
class GPTJModel(GPTJPreTrainedModel):
    
    # 初始化方法,接受一个配置参数 config
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)

        # 初始化模型的一些基本属性
        self.embed_dim = config.n_embd  # 嵌入维度
        self.vocab_size = config.vocab_size  # 词汇表大小
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)  # 词嵌入模块
        self.drop = nn.Dropout(config.embd_pdrop)  # Dropout 模块
        self.h = nn.ModuleList([GPTJBlock(config) for _ in range(config.n_layer)])  # 多层 GPTJBlock 模块列表
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)  # 最后一层的 LayerNorm 模块

        # 模型并行化相关的属性
        self.model_parallel = False  # 是否启用模型并行化,默认为 False
        self.device_map = None  # 设备映射,默认为 None
        self.gradient_checkpointing = False  # 是否启用梯度检查点,默认为 False

        # 初始化权重并应用最终处理
        self.post_init()

        # 根据配置决定是否使用 flash_attention_2 实现
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    # 使用 PARALLELIZE_DOCSTRING 文档字符串装饰该方法
    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    # 使用装饰器添加文档字符串,文档字符串内容来自 DEPARALLELIZE_DOCSTRING
    def deparallelize(self):
        # 发出警告,指出 `deparallelize` 方法即将在 Transformers v5 中移除
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        # 禁用模型并设置相关属性
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        # 将输入嵌入层(self.wte)移动到 CPU
        self.wte = self.wte.to("cpu")
        # 将所有隐藏层(self.h)移动到 CPU
        for index in range(len(self.h)):
            self.h[index] = self.h[index].to("cpu")
        # 将最后一层归一化层(self.ln_f)移动到 CPU
        self.ln_f = self.ln_f.to("cpu")
        # 清空 CUDA 缓存
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        # 返回输入嵌入层(self.wte)
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        # 设置新的输入嵌入层(self.wte)
        self.wte = new_embeddings

    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
# 使用装饰器添加文档字符串描述 GPT-J 模型,这是一个在语言建模头部之上的变压器模型
@add_start_docstrings(
    """
    The GPT-J Model transformer with a language modeling head on top.
    """,
    GPTJ_START_DOCSTRING,
)
# 定义 GPTJForCausalLM 类,继承自 GPTJPreTrainedModel
class GPTJForCausalLM(GPTJPreTrainedModel):
    # 定义一个列表,指定需要共享权重的键
    _tied_weights_keys = ["lm_head.weight"]

    # 初始化方法,接收一个配置参数 config
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)
        # 创建一个 GPTJModel 实例,使用给定的配置
        self.transformer = GPTJModel(config)
        # 创建一个线性层,将 GPTJ 模型的隐藏状态映射到词汇表大小的输出
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

        # Model parallel
        # 是否开启模型并行计算,默认为 False
        self.model_parallel = False
        # 设备映射,默认为 None
        self.device_map = None

        # 调用 post_init 方法,初始化权重并进行最终处理
        self.post_init()

    # 使用装饰器添加并行化方法的文档字符串描述
    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # 发出警告,表明此方法即将在后续版本中移除
        warnings.warn(
            "`GPTJForCausalLM.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
            " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
            " 0, 'transformer.h.1': 1, ...}",
            FutureWarning,
        )
        # 如果 device_map 为 None,则使用 get_device_map 方法生成一个设备映射
        # 该映射将模型层分配到不同的 GPU 设备上
        self.device_map = (
            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        # 断言设备映射的正确性,确保每个模型层都被正确映射到对应设备
        assert_device_map(self.device_map, len(self.transformer.h))
        # 调用 GPTJModel 类的 parallelize 方法,根据设备映射进行模型并行化
        self.transformer.parallelize(self.device_map)
        # 将 lm_head 层移到 transformer 的第一个设备上
        self.lm_head = self.lm_head.to(self.transformer.first_device)
        # 标记模型已经进行了模型并行化
        self.model_parallel = True

    # 使用装饰器添加取消并行化方法的文档字符串描述
    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        # 发出警告,表明此方法即将在后续版本中移除
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        # 调用 GPTJModel 类的 deparallelize 方法,取消模型的并行化
        self.transformer.deparallelize()
        # 将 transformer 和 lm_head 层移回 CPU 上
        self.transformer = self.transformer.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        # 标记模型未进行模型并行化
        self.model_parallel = False
        # 清空 CUDA 缓存
        torch.cuda.empty_cache()

    # 返回 lm_head 层,用于获取输出的词嵌入
    def get_output_embeddings(self):
        return self.lm_head

    # 设置 lm_head 层的新词嵌入
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # 如果存在 past_key_values 参数,则忽略已经被包含在其中的输入 ID
        if past_key_values:
            # 计算 past_key_values 的长度,即历史输入的数量
            past_length = past_key_values[0][0].shape[2]

            # 检查输入 ID 的长度是否大于历史输入长度
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length  # 移除的前缀长度为历史输入长度
            else:
                # 默认行为:仅保留最后一个输入 ID
                remove_prefix_length = input_ids.shape[1] - 1

            # 移除前缀长度对应的部分输入 ID
            input_ids = input_ids[:, remove_prefix_length:]
            # 如果存在 token_type_ids,则也相应地截取
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # 在批量生成时动态创建 position_ids
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                # 如果存在 past_key_values,则仅保留对应的 position_ids
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # 如果传入了 inputs_embeds,则仅在第一个生成步骤中使用它们
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # 更新 model_inputs 字典,包含所有可能的模型输入参数
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            }
        )

        return model_inputs

    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
    )
    # 此方法用于模型的前向传播,接受多个可能的输入参数
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        # 根据需要决定是否返回字典类型的结果
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用Transformer模型进行前向传播
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # 如果使用模型并行化,则设置隐藏状态所在的设备
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        # 确保在fp16下采样工作正常,并使用fp32计算损失以匹配mesh-tf版本
        # 参考链接: https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
        lm_logits = self.lm_head(hidden_states).to(torch.float32)

        loss = None
        if labels is not None:
            # 将标签移动到正确的设备以启用模型并行化
            labels = labels.to(lm_logits.device)
            # 将logits向左移动一位,以便对比预测下一个token
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # 展平tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            loss = loss.to(hidden_states.dtype)

        # 如果不返回字典类型的结果,则组装输出
        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # 返回包含过去键值的CausalLMOutputWithPast对象
        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
        """
        This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
        [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """

        # 返回一个元组,其中每个元素也是一个元组,表示重新排序后的 past_key_values
        return tuple(
            # 对于 past_key_values 中的每个 layer_past,进行如下操作
            tuple(
                # 对于 layer_past 中的每个 past_state,通过 index_select 方法按照 beam_idx 的索引重新排序,
                # 并将结果移到 past_state 的设备上
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            # 遍历整个 past_key_values,对每个 layer_past 执行上述操作
            for layer_past in past_key_values
        )
"""
定义一个 GPT-J 模型,其顶部有一个序列分类头(线性层)。

[`GPTJForSequenceClassification`] 使用最后一个 token 来进行分类,与其他因果模型(如 GPT、GPT-2、GPT-Neo)类似。

由于它在最后一个 token 上进行分类,因此需要知道最后一个 token 的位置。如果配置中定义了 `pad_token_id`,则在每行中找到不是填充 token 的最后一个 token。如果没有定义 `pad_token_id`,则简单地取每个批次行的最后一个值。当传递 `inputs_embeds` 而不是 `input_ids` 时,由于无法猜测填充 token,它也采用相同的方式(取每行批次的最后一个值)。
"""
@add_start_docstrings(
    """
    The GPT-J Model transformer with a sequence classification head on top (linear layer).

    [`GPTJForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT, GPT-2, GPT-Neo) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    GPTJ_START_DOCSTRING,
)
class GPTJForSequenceClassification(GPTJPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GPTJModel(config)
        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="ydshieh/tiny-random-gptj-for-sequence-classification",
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # Forward 方法,接受多种输入参数,用于执行模型的前向推理。
@add_start_docstrings(
    """
    The GPT-J Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    GPTJ_START_DOCSTRING,
)
class GPTJForQuestionAnswering(GPTJPreTrainedModel):
"""
    # 初始化函数,接受一个配置对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法,传入配置对象
        super().__init__(config)
        # 设置类属性 num_labels,从配置对象中获取标签数量
        self.num_labels = config.num_labels
        # 创建 GPTJModel 实例并将其赋给类属性 transformer,传入配置对象作为参数
        self.transformer = GPTJModel(config)
        # 创建一个线性层 nn.Linear,将隐藏层大小调整为标签数量,赋给类属性 qa_outputs
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Model parallel
        # 设置模型并行为 False
        self.model_parallel = False
        # 设备映射设为 None
        self.device_map = None

        # 调用自定义的初始化方法 post_init
        # 初始化权重并进行最终处理
        self.post_init()

    # 前向传播函数,接受多个可选的张量作为输入
    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        # 初始化返回字典,若未提供则使用配置中的默认设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 使用transformer模型处理输入的各种参数,得到输出
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从transformer模型的输出中取得序列输出
        sequence_output = outputs[0]

        # 使用qa_outputs模型处理序列输出,得到开始和结束的logits
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # 如果使用多GPU,则扩展维度
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1).to(start_logits.device)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1).to(end_logits.device)
            # 忽略超出模型输入范围的起始/结束位置
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # 定义交叉熵损失函数,忽略指定的索引
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        # 如果不需要返回字典形式的结果,则返回元组
        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        # 如果需要返回字典形式的结果,则返回QuestionAnsweringModelOutput对象
        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

.\models\gptj\modeling_tf_gptj.py

# 设置文件编码为UTF-8,确保可以正确处理中文和其他特殊字符
# 版权声明,声明代码版权归EleutherAI和HuggingFace团队所有
#
# 根据Apache许可证2.0版,除非符合许可证规定,否则不得使用此文件
# 您可以在以下网址获取许可证的副本:http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则依据此许可证分发的软件是基于“按原样”分发的,
# 不附带任何明示或暗示的保证或条件。请参阅许可证获取更多详情。
""" TF 2.0 GPT-J模型 """

from __future__ import annotations

from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

# 导入自定义模块和函数
from ...activations_tf import get_tf_activation
from ...file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)
from ...modeling_tf_outputs import (
    TFBaseModelOutputWithPast,
    TFCausalLMOutputWithPast,
    TFQuestionAnsweringModelOutput,
    TFSequenceClassifierOutputWithPast,
)
from ...modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFModelInputType,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFSharedEmbeddings,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import logging
from .configuration_gptj import GPTJConfig

# 获取日志记录器
logger = logging.get_logger(__name__)

# 用于文档的模型检查点和配置
_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-j-6B"
_CONFIG_FOR_DOC = "GPTJConfig"

# 预训练模型存档列表
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "EleutherAI/gpt-j-6B",
    # 更多GPT-J模型详见 https://huggingface.co/models?filter=gptj
]
    # 初始化方法,接受一个GPTJConfig对象和其他关键字参数
    def __init__(self, config: GPTJConfig, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)

        # 设置嵌入维度为隐藏大小
        self.embed_dim = config.hidden_size
        # 设置注意力头的数量
        self.num_attention_heads = config.num_attention_heads
        # 计算每个注意力头的维度
        self.head_dim = self.embed_dim // self.num_attention_heads
        # 检查embed_dim是否能被num_attention_heads整除
        if self.head_dim * self.num_attention_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                f" `num_attention_heads`: {self.num_attention_heads})."
            )
        # 设置注意力的缩放因子
        self.scale_attn = self.head_dim**0.5
        # 设置旋转维度
        self.rotary_dim = config.rotary_dim

        # 设置注意力的dropout层
        self.attn_dropout = keras.layers.Dropout(config.attn_pdrop)
        # 设置残差连接的dropout层
        self.resid_dropout = keras.layers.Dropout(config.resid_pdrop)

        # 初始化查询投影层
        self.q_proj = keras.layers.Dense(
            self.embed_dim,
            use_bias=False,
            kernel_initializer=get_initializer(config.initializer_range),
            name="q_proj",
        )
        # 初始化键投影层
        self.k_proj = keras.layers.Dense(
            self.embed_dim,
            use_bias=False,
            kernel_initializer=get_initializer(config.initializer_range),
            name="k_proj",
        )
        # 初始化值投影层
        self.v_proj = keras.layers.Dense(
            self.embed_dim,
            use_bias=False,
            kernel_initializer=get_initializer(config.initializer_range),
            name="v_proj",
        )
        # 初始化输出投影层
        self.out_proj = keras.layers.Dense(
            self.embed_dim,
            use_bias=False,
            kernel_initializer=get_initializer(config.initializer_range),
            name="out_proj",
        )

        # 设置最大位置编码
        self.max_positions = config.max_position_embeddings
        # 创建一个下三角形的掩码矩阵
        self.lower_triangle_mask = tf.reshape(
            tf.cast(tf.experimental.numpy.tril(tf.ones((self.max_positions, self.max_positions))), tf.int8),
            (1, 1, self.max_positions, self.max_positions),
        )
        # 确定位置编码的维度
        pos_embd_dim = self.rotary_dim or self.embed_dim
        # 创建正弦位置编码
        self.embed_positions = create_sinusoidal_positions(self.max_positions, pos_embd_dim)

    # 获取因果掩码,用于自注意力机制
    def get_causal_mask(self, key_length, query_length) -> tf.Tensor:
        return tf.cast(self.lower_triangle_mask[:, :, key_length - query_length : key_length, :key_length], tf.bool)

    # 静态方法,返回一个用于掩码的偏置
    @staticmethod
    def get_masked_bias(dtype: tf.DType) -> tf.Tensor:
        return tf.cast(tf.constant(-1e9), dtype)
    def _split_heads(self, hidden_states: tf.Tensor, rotary: bool) -> tf.Tensor:
        """
        Splits hidden dim into attn_head_size and num_attention_heads
        """
        # Compute the new shape for splitting heads
        new_shape = shape_list(hidden_states)[:-1] + [self.num_attention_heads, self.head_dim]
        # Reshape the tensor to split heads
        hidden_states = tf.reshape(hidden_states, new_shape)
        if rotary:
            return hidden_states
        # Transpose tensor dimensions based on its rank
        if len(shape_list(hidden_states)) == 4:
            return tf.transpose(hidden_states, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)
        elif len(shape_list(hidden_states)) == 5:
            return tf.transpose(hidden_states, (0, 1, 3, 2, 4))  # (batch, blocks, head, block_length, head_features)
        else:
            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(shape_list(hidden_states))}")

    def _merge_heads(self, hidden_states: tf.Tensor) -> tf.Tensor:
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden dim
        """
        # Transpose tensor dimensions to merge heads back
        if len(shape_list(hidden_states)) == 4:
            hidden_states = tf.transpose(hidden_states, (0, 2, 1, 3))
        elif len(shape_list(hidden_states)) == 5:
            hidden_states = tf.transpose(hidden_states, (0, 1, 3, 2, 4))
        else:
            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(shape_list(hidden_states))}")
        # Compute the new shape after merging heads
        new_shape = shape_list(hidden_states)[:-2] + [self.num_attention_heads * self.head_dim]
        return tf.reshape(hidden_states, new_shape)

    def _attn(
        self,
        query: tf.Tensor,
        key: tf.Tensor,
        value: tf.Tensor,
        attention_mask: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        # compute causal mask from causal mask buffer
        query_length, key_length = shape_list(query)[-2], shape_list(key)[-2]
        # Generate a causal mask for self-attention
        causal_mask = self.get_causal_mask(key_length, query_length)

        # Keep the attention weights computation in fp32 to avoid overflow issues
        query = tf.cast(query, tf.float32)
        key = tf.cast(key, tf.float32)

        # Compute attention weights
        attn_weights = tf.matmul(query, key, transpose_b=True)
        # Apply causal mask to attention weights
        attn_weights = tf.where(causal_mask, attn_weights, self.get_masked_bias(attn_weights.dtype))

        # Scale attention weights
        attn_weights = attn_weights / self.scale_attn

        if attention_mask is not None:
            # Apply additional attention mask
            attn_weights = attn_weights + attention_mask

        # Apply stable softmax to compute attention probabilities
        attn_weights = stable_softmax(attn_weights, axis=-1)
        attn_weights = tf.cast(attn_weights, value.dtype)
        # Apply dropout to attention weights
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if specified
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        # Compute attention output by weighted sum of values
        attn_output = tf.matmul(attn_weights, value)

        return attn_output, attn_weights
    # 定义一个方法,用于处理自注意力机制的计算,输入包括隐藏状态、过去的键值对、注意力掩码、位置编码、头掩码、缓存使用标志和是否输出注意力权重
    def call(
        self,
        hidden_states: tf.Tensor,  # 输入的隐藏状态张量
        layer_past: Optional[Tuple[tf.Tensor, tf.Tensor]] = None,  # 可选的过去层的键值对
        attention_mask: tf.Tensor | None = None,  # 注意力掩码张量,可为None
        position_ids: tf.Tensor | None = None,  # 位置编码张量,可为None
        head_mask: tf.Tensor | None = None,  # 头掩码张量,可为None
        use_cache: bool = False,  # 是否使用缓存,默认为False
        output_attentions: bool = False,  # 是否输出注意力权重,默认为False
    ):
        # 使用三个不同的线性投影来生成查询、键和值
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # 将查询、键和值张量分割成多头
        query = self._split_heads(query, True)
        key = self._split_heads(key, True)
        value = self._split_heads(value, False)

        # 根据位置编码应用旋转位置嵌入(如果提供了旋转维度)
        sincos = tf.cast(tf.gather(self.embed_positions, position_ids, axis=0), hidden_states.dtype)
        sincos = tf.split(sincos, 2, axis=-1)
        if self.rotary_dim is not None:
            k_rot = key[:, :, :, : self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim :]

            q_rot = query[:, :, :, : self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim :]

            k_rot = apply_rotary_pos_emb(k_rot, sincos)
            q_rot = apply_rotary_pos_emb(q_rot, sincos)

            # 合并旋转后的部分和传递部分的键和查询
            key = tf.concat((k_rot, k_pass), axis=-1)
            query = tf.concat((q_rot, q_pass), axis=-1)
        else:
            key = apply_rotary_pos_emb(key, sincos)
            query = apply_rotary_pos_emb(query, sincos)

        # 转置键和查询张量的维度
        key = tf.transpose(key, (0, 2, 1, 3))
        query = tf.transpose(query, (0, 2, 1, 3))

        # 如果提供了过去的键值对,则将当前键和值与过去的键值对连接起来
        if layer_past is not None:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = tf.concat((past_key, key), axis=-2)
            value = tf.concat((past_value, value), axis=-2)

        # 如果设置了使用缓存,则将当前的键值对作为“present”返回,否则返回None
        if use_cache is True:
            present = (key, value)
        else:
            present = None

        # 计算自注意力机制的输出和注意力权重
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        # 合并多头注意力的输出
        attn_output = self._merge_heads(attn_output)
        
        # 通过输出投影层处理注意力输出
        attn_output = self.out_proj(attn_output)
        
        # 应用残差连接和dropout到注意力输出
        attn_output = self.resid_dropout(attn_output)

        # 构造最终输出元组,包括注意力输出和可能的“present”和注意力权重(如果设置输出注意力权重)
        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # 返回最终的输出元组:注意力输出、可能的“present”和(如果设置输出注意力权重)注意力权重
    # 如果模型已经建立,则直接返回,不进行重复建立
    if self.built:
        return
    
    # 标记模型为已建立状态
    self.built = True
    
    # 如果存在查询投影层对象,建立查询投影层,并指定输入形状为 [None, None, self.embed_dim]
    if getattr(self, "q_proj", None) is not None:
        with tf.name_scope(self.q_proj.name):
            self.q_proj.build([None, None, self.embed_dim])
    
    # 如果存在键投影层对象,建立键投影层,并指定输入形状为 [None, None, self.embed_dim]
    if getattr(self, "k_proj", None) is not None:
        with tf.name_scope(self.k_proj.name):
            self.k_proj.build([None, None, self.embed_dim])
    
    # 如果存在值投影层对象,建立值投影层,并指定输入形状为 [None, None, self.embed_dim]
    if getattr(self, "v_proj", None) is not None:
        with tf.name_scope(self.v_proj.name):
            self.v_proj.build([None, None, self.embed_dim])
    
    # 如果存在输出投影层对象,建立输出投影层,并指定输入形状为 [None, None, self.embed_dim]
    if getattr(self, "out_proj", None) is not None:
        with tf.name_scope(self.out_proj.name):
            self.out_proj.build([None, None, self.embed_dim])
class TFGPTJMLP(keras.layers.Layer):
    # 初始化函数,定义了模型层的各个组件和参数
    def __init__(self, intermediate_size: int, config: GPTJConfig, **kwargs):
        super().__init__(**kwargs)
        embed_dim = config.n_embd

        # 输入层全连接层,用于将输入向量映射到中间维度
        self.fc_in = keras.layers.Dense(
            intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="fc_in"
        )
        # 输出层全连接层,将中间维度映射回原始嵌入维度
        self.fc_out = keras.layers.Dense(
            embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="fc_out"
        )

        # 激活函数,根据配置选择合适的激活函数
        self.act = get_tf_activation(config.activation_function)
        # Dropout 层,用于防止过拟合
        self.dropout = keras.layers.Dropout(config.embd_pdrop)
        self.embed_dim = config.n_embd
        self.intermediate_size = intermediate_size

    # 前向传播函数,定义了层的计算流程
    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 输入向量经过输入层全连接层
        hidden_states = self.fc_in(hidden_states)
        # 应用激活函数
        hidden_states = self.act(hidden_states)
        # 经过输出层全连接层
        hidden_states = self.fc_out(hidden_states)
        # 应用 Dropout
        hidden_states = self.dropout(hidden_states)
        return hidden_states

    # 构建函数,用于构建层的参数
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果输入层存在,构建输入层
        if getattr(self, "fc_in", None) is not None:
            with tf.name_scope(self.fc_in.name):
                self.fc_in.build([None, None, self.embed_dim])
        # 如果输出层存在,构建输出层
        if getattr(self, "fc_out", None) is not None:
            with tf.name_scope(self.fc_out.name):
                self.fc_out.build([None, None, self.intermediate_size])


class TFGPTJBlock(keras.layers.Layer):
    # 初始化函数,定义了模型层的各个组件和参数
    def __init__(self, config: GPTJConfig, **kwargs):
        super().__init__(**kwargs)
        # 内部维度,用于确定 MLP 层的中间维度大小
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        # 第一层的 LayerNormalization 层
        self.ln_1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
        # 自注意力层
        self.attn = TFGPTJAttention(config, name="attn")
        # MLP 层,用于处理经过自注意力层后的隐藏状态
        self.mlp = TFGPTJMLP(inner_dim, config, name="mlp")
        self.config = config

    # 前向传播函数,定义了层的计算流程
    def call(
        self,
        hidden_states: tf.Tensor,
        layer_past: tf.Tensor | None = None,
        attention_mask: tf.Tensor | None = None,
        position_ids: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        # 使用自注意力的输出和输入的张量、位置的 ID 的张量和头部面具的张量的张量,以及缓存的缓存的布尔值使用的缓存的注意
        ):
            residual = hidden_states
            # 将隐藏状态进行 LayerNormalization
            hidden_states = self.ln_1(hidden_states)
            # 使用注意力机制进行计算
            attn_outputs = self.attn(
                hidden_states=hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                position_ids=position_ids,
                head_mask=head_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )  # attn_outputs: attn_output, present, (attentions)
            # 获取注意力输出
            attn_output = attn_outputs[0]
            # 剩余连接和前馈神经网络
            feed_forward_hidden_states = self.mlp(hidden_states)
            hidden_states = attn_output + feed_forward_hidden_states + residual

            # 如果使用缓存,则输出包含隐藏状态
            if use_cache:
                outputs = (hidden_states,) + outputs
            else:
                # 否则,输出中去除第一个元素
                outputs = (hidden_states,) + outputs[1:]
            return outputs  # hidden_states, present, (attentions)

        def build(self, input_shape=None):
            if self.built:
                return
            self.built = True
            # 构建 LayerNormalization 层
            if getattr(self, "ln_1", None) is not None:
                with tf.name_scope(self.ln_1.name):
                    self.ln_1.build([None, None, self.config.n_embd])
            # 构建注意力机制
            if getattr(self, "attn", None) is not None:
                with tf.name_scope(self.attn.name):
                    self.attn.build(None)
            # 构建前馈神经网络
            if getattr(self, "mlp", None) is not None:
                with tf.name_scope(self.mlp.name):
                    self.mlp.build(None)
@keras_serializable
class TFGPTJMainLayer(keras.layers.Layer):
    # 使用 keras_serializable 装饰器,表明这是一个可以序列化的 Keras 层
    config_class = GPTJConfig
    # 设置类属性 config_class 为 GPTJConfig,这是用于配置模型的类

    def __init__(self, config: GPTJConfig, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        # 调用父类的初始化方法

        self.config = config
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.use_cache = config.use_cache
        self.return_dict = config.use_return_dict
        # 初始化一些配置参数和控制输出的标志

        self.num_hidden_layers = config.n_layer
        self.n_embd = config.n_embd
        self.n_positions = config.n_positions
        self.initializer_range = config.initializer_range
        # 从配置中获取模型的隐藏层数、嵌入维度、位置编码数以及初始化范围等属性

        self.wte = TFSharedEmbeddings(
            config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
        )
        # 初始化共享嵌入层对象,用于将输入的 token 序列转换为向量表示
        self.drop = keras.layers.Dropout(config.embd_pdrop)
        # 初始化 Dropout 层,用于在训练过程中随机丢弃部分嵌入层输出
        self.h = [TFGPTJBlock(config, name=f"h_._{i}") for i in range(config.n_layer)]
        # 初始化 GPTJBlock 的列表,每个 block 是 GPTJ 模型中的一个处理块,用于构建完整的模型
        self.ln_f = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
        # 初始化 LayerNormalization 层,用于对最终输出进行归一化处理
        self.embed_dim = config.n_embd
        # 设置嵌入维度属性

    def get_input_embeddings(self):
        return self.wte
        # 返回输入嵌入层对象

    def set_input_embeddings(self, value: tf.Tensor):
        self.wte.weight = value
        self.wte.vocab_size = shape_list(value)[0]
        # 设置输入嵌入层的权重和词汇大小属性

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        raise NotImplementedError
        # 剪枝模型中的注意力头部,heads_to_prune 是一个字典,表示每个层需要剪枝的注意力头部列表

    @unpack_inputs
    def call(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        # 模型的前向传播函数,接收多个输入参数,并返回模型的输出

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果模型已经构建过,则直接返回

        if getattr(self, "wte", None) is not None:
            with tf.name_scope(self.wte.name):
                self.wte.build(None)
        # 如果存在 wte 属性,则调用其 build 方法构建嵌入层

        if getattr(self, "ln_f", None) is not None:
            with tf.name_scope(self.ln_f.name):
                self.ln_f.build([None, None, self.embed_dim])
        # 如果存在 ln_f 属性,则调用其 build 方法构建 LayerNormalization 层

        if getattr(self, "h", None) is not None:
            for layer in self.h:
                with tf.name_scope(layer.name):
                    layer.build(None)
        # 如果存在 h 属性,则遍历每个 GPTJBlock 层并调用其 build 方法构建模型中的处理块
    # 此模型继承自 `TFPreTrainedModel`。查看超类文档以了解库实现的通用方法,如下载或保存模型、调整输入嵌入大小、修剪头等。
    
    # 此模型也是一个 `keras.Model` 的子类。可以像普通的 TF 2.0 Keras 模型一样使用它,并参考 TF 2.0 文档了解一般用法和行为。
    
    # <Tip> 标签中的内容是关于 `transformers` 中 TensorFlow 模型和层接受输入的两种格式的说明:
    # - 使用关键字参数作为所有输入(类似于 PyTorch 模型)
    # - 将所有输入作为列表、元组或字典的第一个位置参数
    # 第二种格式的支持是因为 Keras 方法在传递输入给模型和层时更喜欢这种格式。因此,在使用 `model.fit()` 等方法时,只需以 `model.fit()` 支持的任何格式传递输入和标签即可正常工作!然而,如果想在 Keras 方法之外使用第二种格式,比如在使用 Keras `Functional` API 创建自己的层或模型时,有三种可能性可以用来收集所有输入张量到第一个位置参数中:
    # - 仅使用 `input_ids` 作为单个张量:`model(input_ids)`
    # - 使用变长列表,包含按照文档字符串中给定的顺序的一个或多个输入张量:`model([input_ids, attention_mask])` 或 `model([input_ids, attention_mask, token_type_ids])`
    # - 使用字典,将一个或多个输入张量与文档字符串中给定的输入名称相关联:`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
    # 注意,当使用子类化创建模型和层时,无需担心这些,可以像对待任何其他 Python 函数一样传递输入!
    
    # Parameters: 部分描述了模型的参数:
    # - config (`GPTJConfig` 类型):包含模型所有参数的模型配置类。
    #   初始化配置文件时不会加载与模型关联的权重,仅加载配置。查看 `~TFPreTrainedModel.from_pretrained` 方法以加载模型权重。
"""

GPTJ_INPUTS_DOCSTRING = r"""
"""


@add_start_docstrings(
    "The bare GPT-J Model transformer outputting raw hidden-states without any specific head on top.",
    GPTJ_START_DOCSTRING,
)
class TFGPTJModel(TFGPTJPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # 初始化时设置 GPT-J 主层
        self.transformer = TFGPTJMainLayer(config, name="transformer")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]:
        r"""
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past`). Set to `False` during training, `True` during generation
        """
        # 调用 GPT-J 主层的前向传播函数,返回输出结果
        outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                # 构建 GPT-J 主层
                self.transformer.build(None)


@add_start_docstrings(
    """
    The GPT-J Model transformer with a language modeling head on top.
    """,
    GPTJ_START_DOCSTRING,
)
class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
    # 这里会定义带有语言建模头部的 GPT-J 模型
    # 初始化方法,接收配置和其他输入参数,并调用父类的初始化方法
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # 创建一个名为transformer的TFGPTJMainLayer对象,使用给定的配置
        self.transformer = TFGPTJMainLayer(config, name="transformer")
        # 创建一个全连接层,称为lm_head,用于语言模型的输出预测
        self.lm_head = keras.layers.Dense(
            config.vocab_size, kernel_initializer=get_initializer(config.initializer_range), name="lm_head"
        )
        # 将配置信息保存在实例变量中
        self.config = config

    # 返回lm_head,用于获取输出的嵌入表示
    def get_output_embeddings(self):
        return self.lm_head

    # 设置lm_head的值为new_embeddings,用于更新输出的嵌入表示
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # 准备用于生成的输入数据,根据传入的参数设置不同的输入
    def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # 如果past_key_values存在,只使用输入的最后一个token
        if past_key_values:
            inputs = tf.expand_dims(inputs[:, -1], -1)
            # 如果存在token_type_ids,则也只使用最后一个token的类型
            if token_type_ids is not None:
                token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)

        position_ids = kwargs.get("position_ids", None)
        attention_mask = kwargs.get("attention_mask", None)

        # 如果attention_mask存在而position_ids不存在,则根据attention_mask计算position_ids
        if attention_mask is not None and position_ids is None:
            position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
            # 如果past_key_values存在,只使用计算后的position_ids的最后一个值

        # 返回一个包含所有生成输入的字典
        return {
            "input_ids": inputs,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "token_type_ids": token_type_ids,
        }

    # 调用模型的方法,接收多种输入参数,并进行前向传播计算
    @unpack_inputs
    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFCausalLMOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
        # 可选的训练参数,控制是否返回字典形式的输出
    ) -> Union[TFCausalLMOutputWithPast, Tuple[tf.Tensor]]:
        r"""
        labels (`np.ndarray` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """

        # 获取transformer模型的输出,根据传入的参数进行调用
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # 从transformer的输出中获取隐藏状态
        hidden_states = transformer_outputs[0]
        # 使用语言模型头部生成logits
        lm_logits = self.lm_head(hidden_states)

        # 初始化损失为None
        loss = None
        # 如果提供了标签(labels)
        if labels is not None:
            # 将logits向左移动一位并截断最后一个logit token
            shifted_logits = lm_logits[:, :-1]
            # 将标签向右移动一位以匹配shifted_logits的长度
            labels = labels[:, 1:]
            # 计算损失函数
            loss = self.hf_compute_loss(labels, shifted_logits)

        # 如果不需要返回字典,则按照非字典格式返回输出
        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # 如果需要返回字典格式的输出,创建TFCausalLMOutputWithPast对象
        return TFCausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def build(self, input_shape=None):
        # 如果模型已经构建,则直接返回
        if self.built:
            return
        # 标记模型为已构建状态
        self.built = True
        # 如果存在transformer模型,则构建transformer模型
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)
        # 如果存在lm_head模型,则构建lm_head模型
        if getattr(self, "lm_head", None) is not None:
            with tf.name_scope(self.lm_head.name):
                self.lm_head.build([None, None, self.config.n_embd])
    """
    The `build` method for constructing the TFGPTJForSequenceClassification model architecture.

    Ensures the model is built correctly by initializing necessary layers based on input shape and configuration.
    """
    def build(self, input_shape=None):
        # 如果已经构建过,直接返回,避免重复构建
        if self.built:
            return
        # 标记模型已经构建
        self.built = True
        
        # 如果存在 transformer 层,则构建 transformer 层
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)
        
        # 如果存在 score 层,则构建 score 层,其中输出的形状为 [None, None, self.config.n_embd]
        if getattr(self, "score", None) is not None:
            with tf.name_scope(self.score.name):
                self.score.build([None, None, self.config.n_embd])
    ```
    The GPT-J Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    GPTJ_START_DOCSTRING,
    )
    # TFGPTJForQuestionAnswering 类的定义,继承自 TFGPTJPreTrainedModel 和 TFQuestionAnsweringLoss
    class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss):
        # 加载时忽略的键列表,用于处理缺失情况
        _keys_to_ignore_on_load_missing = [r"h.\d+.attn.masked_bias", r"h.\d+.attn.bias", r"lm_head.weight"]

        def __init__(self, config, *inputs, **kwargs):
            super().__init__(config, *inputs, **kwargs)
            # 初始化时设置类别数量
            self.num_labels = config.num_labels
            # 创建 GPT-J 主层实例,并命名为 "transformer"
            self.transformer = TFGPTJMainLayer(config, name="transformer")
            # 初始化问答输出层,使用 Dense 层,内核初始化方式根据配置的初始化范围确定
            self.qa_outputs = keras.layers.Dense(
                self.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
            )
            # 保存配置对象的引用
            self.config = config

        @unpack_inputs
        # 将输入解包并添加到模型前向传播的文档字符串中,描述了输入的格式
        @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
        # 添加代码示例的文档字符串,描述了如何使用模型进行问答任务
        @add_code_sample_docstrings(
            checkpoint=_CHECKPOINT_FOR_DOC,
            output_type=TFQuestionAnsweringModelOutput,
            config_class=_CONFIG_FOR_DOC,
        )
        # 模型的前向传播函数,接收多个输入参数,包括输入的特征、位置编码、注意力掩码等
        def call(
            self,
            input_ids: TFModelInputType | None = None,
            past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
            attention_mask: np.ndarray | tf.Tensor | None = None,
            token_type_ids: np.ndarray | tf.Tensor | None = None,
            position_ids: np.ndarray | tf.Tensor | None = None,
            head_mask: np.ndarray | tf.Tensor | None = None,
            inputs_embeds: np.ndarray | tf.Tensor | None = None,
            start_positions: np.ndarray | tf.Tensor | None = None,
            end_positions: np.ndarray | tf.Tensor | None = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            training: Optional[bool] = False,
    ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
        r"""
        定义函数的签名和返回类型注解,指定函数返回的类型是 TFQuestionAnsweringModelOutput 或者 (tf.Tensor, tf.Tensor) 的元组。
        """

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # 获取transformer模型的输出
        sequence_output = transformer_outputs[0]

        # 通过qa_outputs模型计算起始位置和结束位置的logits
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        # 去除最后一个维度为1的维度
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        loss = None
        # 如果提供了起始和结束位置的标签,则计算损失
        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            loss = self.hf_compute_loss(labels, (start_logits, end_logits))

        # 如果不返回字典,则按照元组的方式返回输出
        if not return_dict:
            output = (start_logits, end_logits) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 如果返回字典,则构造 TFQuestionAnsweringModelOutput 对象
        return TFQuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果模型已经构建,则直接返回
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                # 构建transformer模型
                self.transformer.build(None)
        # 如果qa_outputs存在,则构建qa_outputs模型
        if getattr(self, "qa_outputs", None) is not None:
            with tf.name_scope(self.qa_outputs.name):
                self.qa_outputs.build([None, None, self.config.hidden_size])

.\models\gptj\__init__.py

# 版权声明和许可信息
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 引入类型检查
from typing import TYPE_CHECKING

# 引入必要的模块和函数
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_tf_available,
    is_torch_available,
)

# 定义模块的导入结构,用于延迟加载模块
_import_structure = {"configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig", "GPTJOnnxConfig"]}

# 检查是否支持 torch 库,如果不支持则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果支持 torch 库,则添加 torch 下相关模块到导入结构
    _import_structure["modeling_gptj"] = [
        "GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
        "GPTJForCausalLM",
        "GPTJForQuestionAnswering",
        "GPTJForSequenceClassification",
        "GPTJModel",
        "GPTJPreTrainedModel",
    ]

# 检查是否支持 tensorflow 库,如果不支持则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果支持 tensorflow 库,则添加 tensorflow 下相关模块到导入结构
    _import_structure["modeling_tf_gptj"] = [
        "TFGPTJForCausalLM",
        "TFGPTJForQuestionAnswering",
        "TFGPTJForSequenceClassification",
        "TFGPTJModel",
        "TFGPTJPreTrainedModel",
    ]

# 检查是否支持 flax 库,如果不支持则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果支持 flax 库,则添加 flax 下相关模块到导入结构
    _import_structure["modeling_flax_gptj"] = [
        "FlaxGPTJForCausalLM",
        "FlaxGPTJModel",
        "FlaxGPTJPreTrainedModel",
    ]

# 如果是类型检查模式,执行以下导入
if TYPE_CHECKING:
    # 从相应模块导入配置和模型类
    from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig, GPTJOnnxConfig

    try:
        # 检查是否支持 torch 库,如果不支持则抛出 OptionalDependencyNotAvailable 异常
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 如果支持 torch 库,则从 modeling_gptj 模块中导入相关类
        from .modeling_gptj import (
            GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
            GPTJForCausalLM,
            GPTJForQuestionAnswering,
            GPTJForSequenceClassification,
            GPTJModel,
            GPTJPreTrainedModel,
        )

    try:
        # 检查是否支持 tensorflow 库,如果不支持则抛出 OptionalDependencyNotAvailable 异常
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 如果支持 tensorflow 库,则从 modeling_tf_gptj 模块中导入相关类
        from .modeling_tf_gptj import (
            TFGPTJForCausalLM,
            TFGPTJForQuestionAnswering,
            TFGPTJForSequenceClassification,
            TFGPTJModel,
            TFGPTJPreTrainedModel,
        )

    try:
        # 检查是否支持 flax 库,如果不支持则抛出 OptionalDependencyNotAvailable 异常
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 如果支持 flax 库,则从 modeling_flax_gptj 模块中导入相关类
        from .modeling_flax_gptj import (
            FlaxGPTJForCausalLM,
            FlaxGPTJModel,
            FlaxGPTJPreTrainedModel,
        )
    # 捕获 OptionalDependencyNotAvailable 异常,如果发生则不做任何操作
    except OptionalDependencyNotAvailable:
        pass
    # 如果未发生异常,则导入以下模块
    else:
        from .modeling_flax_gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
else:
    # 导入 sys 模块,用于动态设置当前模块为懒加载模块
    import sys
    
    # 使用 sys.modules[__name__] 将当前模块注册为 _LazyModule 的实例,
    # __name__ 是当前模块的名称,__file__ 是当前模块的文件名,
    # _import_structure 是导入结构,module_spec=__spec__ 是模块规范
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\gptsan_japanese\configuration_gptsan_japanese.py

# coding=utf-8
# 指定代码文件的编码格式为UTF-8

# Copyright 2023, HuggingFace Inc.
# 版权声明,版权归HuggingFace Inc.所有,日期为2023年

# Licensed under the Apache License, Version 2.0 (the "License");
# 根据 Apache License, Version 2.0 许可证授权使用本文件

# you may not use this file except in compliance with the License.
# 除非遵守许可证的规定,否则不得使用本文件

# You may obtain a copy of the License at
# 可以在以下网址获取许可证的副本

#     http://www.apache.org/licenses/LICENSE-2.0
#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 除非法律要求或书面同意,否则按"原样"分发软件,无论是明示还是隐含的,不提供任何担保或条件

# See the License for the specific language governing permissions and
# limitations under the License.
# 请查阅许可证了解特定语言的授权内容及限制

"""  GPTSAN-japanese model configuration"""
# 模型配置的文档字符串说明,这是GPTSAN-japanese模型的配置

from ...configuration_utils import PretrainedConfig
# 导入PretrainedConfig类,用于存储预训练配置信息

from ...utils import logging
# 导入logging工具类,用于记录日志

logger = logging.get_logger(__name__)
# 获取当前模块的日志记录器对象

GPTSAN_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "tanreinama/GPTSAN-2.8B-spout_is_uniform": (
        "https://huggingface.co/tanreinama/GPTSAN-2.8B-spout_is_uniform/resolve/main/config.json"
    ),
}
# 预训练配置映射表,将模型名称映射到其配置文件的URL

class GPTSanJapaneseConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GPTSanJapaneseModel`]. It is used to instantiate
    a GPTSANJapanese model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the GPTSANJapanese
    [Tanrei/GPTSAN-japanese](https://huggingface.co/Tanrei/GPTSAN-japanese) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    """
    # GPTSanJapaneseConfig类的文档字符串,用于存储GPTSanJapaneseModel的配置信息

    model_type = "gptsan-japanese"
    # 模型类型定义为"gptsan-japanese"

    keys_to_ignore_at_inference = [
        "past_key_values",
    ]
    # 推理过程中忽略的键列表,在推理时不使用"past_key_values"

    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }
    # 属性映射字典,将配置中的部分属性名映射为其他名称,如"hidden_size"映射为"d_model"

    def __init__(
        self,
        vocab_size=36000,
        max_position_embeddings=1280,
        d_model=1024,
        d_ff=8192,
        d_ext=4096,
        d_spout=128,
        num_switch_layers=10,
        num_ext_layers=0,
        num_heads=16,
        num_experts=16,
        expert_capacity=128,
        dropout_rate=0.0,
        layer_norm_epsilon=1e-5,
        router_bias=False,
        router_jitter_noise=0.0,
        router_dtype="float32",
        router_ignore_padding_tokens=False,
        output_hidden_states=False,
        output_attentions=False,
        initializer_factor=0.002,
        output_router_logits=False,
        use_cache=True,
        separator_token_id=35998,
        pad_token_id=35995,
        eos_token_id=35999,
        **kwargs,
    ):
        # 初始化方法,用于创建一个新的GPTSanJapaneseConfig对象,设置模型的各种配置参数及其默认值
        ):
        # 初始化 TransformerXLConfig 类的实例,设定模型的各种超参数
        self.vocab_size = vocab_size  # 词汇表大小
        self.max_position_embeddings = max_position_embeddings  # 最大位置嵌入数
        self.d_model = d_model  # 模型的隐藏层大小
        self.d_ff = d_ff  # 前向传播神经网络中间层的大小
        self.d_ext = d_ext  # 扩展层的大小
        self.d_spout = d_spout  # 接口层的大小
        self.num_switch_layers = num_switch_layers  # 切换层的数量
        self.num_ext_layers = num_ext_layers  # 扩展层的数量
        self.num_layers = num_switch_layers + num_ext_layers  # 总层数
        self.num_heads = num_heads  # 注意力头的数量
        self.num_experts = num_experts  # 专家的数量
        self.expert_capacity = expert_capacity  # 专家的容量
        self.dropout_rate = dropout_rate  # 丢弃率
        self.layer_norm_epsilon = layer_norm_epsilon  # 层归一化的 epsilon 参数
        self.router_bias = router_bias  # 路由器的偏置
        self.router_jitter_noise = router_jitter_noise  # 路由器的抖动噪声
        self.router_dtype = router_dtype  # 路由器的数据类型
        self.router_ignore_padding_tokens = router_ignore_padding_tokens  # 是否忽略填充标记的路由
        self.output_hidden_states = output_hidden_states  # 是否输出隐藏状态
        self.output_attentions = output_attentions  # 是否输出注意力权重
        self.initializer_factor = initializer_factor  # 初始化因子
        self.output_router_logits = output_router_logits  # 是否输出路由器的对数
        self.use_cache = use_cache  # 是否使用缓存

        # 调用父类 TransformerXLConfig 的初始化方法,设置分隔符、填充符、终止符等参数
        super().__init__(
            separator_token_id=separator_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

.\models\gptsan_japanese\convert_gptsan_tf_checkpoint_to_pytorch.py

# 导入 argparse 模块,用于处理命令行参数解析
import argparse
# 导入 json 模块,用于读取和解析 JSON 格式的文件
import json
# 导入 os 模块,提供与操作系统交互的功能
import os
# 从 collections 模块中导入 OrderedDict 类,用于创建有序字典
from collections import OrderedDict

# 导入 numpy 库,一般用于科学计算,这里可能在后续的代码中使用
import numpy as np
# 导入 tensorflow 库,用于与 TensorFlow 模型相关的操作
import tensorflow as tf
# 导入 torch 库,用于与 PyTorch 模型相关的操作
import torch


# 定义函数 convert_tf_gptsan_to_pt,用于将 TensorFlow 模型转换为 PyTorch 模型
def convert_tf_gptsan_to_pt(args):
    # 构建参数文件的完整路径
    parameter_file = os.path.join(args.tf_model_dir, "parameters.json")
    # 读取并解析 JSON 格式的参数文件
    params = json.loads(open(parameter_file).read())
    # 如果参数文件为空,则抛出 ValueError 异常
    if not params:
        raise ValueError(
            f"It seems that the json file at {parameter_file} is empty. Make sure you have a correct json file."
        )
    # 如果输出路径不以 ".pt" 结尾,则自动添加 ".pt" 后缀
    if not args.output.endswith(".pt"):
        args.output = args.output + ".pt"
    # 创建一个空的有序字典 new_state
    new_state = OrderedDict()
    # 使用 PyTorch 的 torch.save 方法将 new_state 保存到指定的输出路径 args.output
    torch.save(new_state, args.output)


# 如果当前脚本作为主程序执行
if __name__ == "__main__":
    # 创建 ArgumentParser 对象 parser,用于解析命令行参数
    parser = argparse.ArgumentParser(
        description="model converter.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # 添加命令行参数 --tf_model_dir,指定 TensorFlow 模型的路径,类型为字符串,必需参数
    parser.add_argument("--tf_model_dir", metavar="PATH", type=str, required=True, help="import model")
    # 添加命令行参数 --output,指定输出 PyTorch 模型的路径,类型为字符串,必需参数
    parser.add_argument("--output", metavar="PATH", type=str, required=True, help="output model")
    # 解析命令行参数,将结果存储在 args 对象中
    args = parser.parse_args()
    # 调用 convert_tf_gptsan_to_pt 函数,传入 args 对象,执行 TensorFlow 到 PyTorch 模型的转换
    convert_tf_gptsan_to_pt(args)

.\models\gptsan_japanese\modeling_gptsan_japanese.py

# 编码声明,指定文件编码为UTF-8
# Copyright声明及许可信息,指明版权归属及许可协议
# 引入模块:copy模块用于复制对象;List, Optional, Tuple, Union用于类型提示

import copy  # 引入copy模块,用于对象复制
from typing import List, Optional, Tuple, Union  # 引入类型提示

import torch  # 引入PyTorch库
import torch.nn as nn  # 引入PyTorch的神经网络模块

# 引入相关模块和函数
from ...activations import ACT2FN  # 从本地模块引入ACT2FN函数
from ...modeling_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPastAndCrossAttentions  # 从本地模块引入模型输出相关类
from ...modeling_utils import PreTrainedModel  # 从本地模块引入PreTrainedModel类
from ...utils import (  # 从本地模块引入多个函数和常量
    DUMMY_INPUTS,  # 引入DUMMY_INPUTS常量
    DUMMY_MASK,  # 引入DUMMY_MASK常量
    add_start_docstrings,  # 引入add_start_docstrings函数
    add_start_docstrings_to_model_forward,  # 引入add_start_docstrings_to_model_forward函数
    is_torch_fx_proxy,  # 引入is_torch_fx_proxy函数
    logging,  # 引入logging模块
)
from .configuration_gptsan_japanese import GPTSanJapaneseConfig  # 从本地模块引入GPTSanJapaneseConfig类

logger = logging.get_logger(__name__)  # 获取当前模块的logger对象

_CONFIG_FOR_DOC = "GPTSanJapaneseConfig"  # 模型配置文档的名称
_CHECKPOINT_FOR_DOC = "Tanrei/GPTSAN-japanese"  # 预训练模型的检查点名称

####################################################
# This dict contains ids and associated url
# for the pretrained weights provided with the models
####################################################
# 预训练模型的ID和关联的URL列表
GPTSAN_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "Tanrei/GPTSAN-japanese",  # 预训练模型的名称和路径
    # 更多预训练模型详见https://huggingface.co/models?filter=gptsan-japanese
]


# 从transformers.models.switch_transformers.modeling_switch_transformers.router_z_loss_func中复制的函数
def router_z_loss_func(router_logits: torch.Tensor) -> float:
    r"""
    计算PyTorch中实现的路由器z-loss。

    Args:
        router_logits (`float`):
            输入的logits张量,形状为 [batch_size, sequence_length, num_experts]

    Returns:
        标量路由器z-loss。
    """
    num_groups, tokens_per_group, _ = router_logits.shape  # 获取logits张量的形状信息
    log_z = torch.logsumexp(router_logits, dim=-1)  # 计算log-sum-exp并存储在log_z中
    z_loss = log_z**2  # 计算z-loss
    return torch.sum(z_loss) / (num_groups * tokens_per_group)  # 返回平均z-loss


# 从transformers.models.switch_transformers.modeling_switch_transformers.load_balancing_loss_func中复制的函数
def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
    r"""
    计算PyTorch中实现的辅助负载平衡损失。

    Args:
        router_probs (`torch.Tensor`):
            路由概率张量,形状为 [batch_size, sequence_length, num_experts]
        expert_indices (`torch.Tensor`):
            专家索引张量,形状为 [batch_size, sequence_length]

    Returns:
        辅助负载平衡损失的标量。
    """
    num_groups, tokens_per_group, _ = router_probs.shape  # 获取路由概率张量的形状信息
    log_prob = torch.log(router_probs)  # 计算路由概率的对数并存储在log_prob中
    lb_loss = -torch.sum(log_prob * expert_indices) / (num_groups * tokens_per_group)  # 计算负载平衡损失
    return lb_loss  # 返回负载平衡损失值
    # 获取路由概率张量的最后一个维度大小,即专家数量
    num_experts = router_probs.shape[-1]

    # 将专家索引张量转换为 int64 类型,以便进行 one-hot 编码
    if expert_indices.dtype != torch.int64:
        expert_indices = expert_indices.to(torch.int64)

    # 如果专家索引张量的维度为 2,则添加一个维度来适应 one-hot 编码的需求
    if len(expert_indices.shape) == 2:
        expert_indices = expert_indices.unsqueeze(2)

    # 使用 one-hot 编码创建专家掩码张量
    expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)

    # 对于每个 token,确定其是否被路由到特定的专家
    expert_mask = torch.max(expert_mask, axis=-2).values

    # 将专家掩码张量转换为 float32 类型,以便计算平均值
    expert_mask = expert_mask.to(torch.float32)

    # 计算每个组和专家的平均 token 数量
    tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)

    # 计算每个组和专家的路由概率的平均值
    router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2)

    # 计算辅助损失,这是平均 token 数量和路由概率的乘积的平均值,乘以专家数量的平方
    return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
class GPTSanJapaneseDenseActDense(nn.Module):
    """
    FFN Layer for Switch Transformer and Extra layers

    GPTSAN can mix Switch Transformer layers and normal Transformer layers This class is used as Expert in Switch
    Transformer layers and as FFN in regular Transformer layers. RELU is used in the Switch Transformer layer, and
    Swish is used in the normal Transformer layer, so there is a choice of which is used in the argument.

    """

    def __init__(self, config: GPTSanJapaneseConfig, ext_layer=False):
        super().__init__()
        # 根据是否是额外层选择不同的中间维度
        d_inter = config.d_ext if ext_layer else config.d_ff
        # 输入层到中间层的线性变换,不带偏置项
        self.wi = nn.Linear(config.d_model, d_inter, bias=ext_layer)
        # 中间层到输出层的线性变换,不带偏置项
        self.wo = nn.Linear(d_inter, config.d_model, bias=ext_layer)
        # 如果是额外层,使用恒等映射作为dropout,否则使用配置中的dropout率
        self.dropout = nn.Identity() if ext_layer else nn.Dropout(config.dropout_rate)
        # 根据是否是额外层选择激活函数
        self.act = ACT2FN["swish" if ext_layer else "relu"]

    def forward(self, hidden_states):
        r"""
        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
        Returns:
            torch.Tensor[num_groups, tokens_per_group, hidden_dim]

        """
        # 输入经过输入层到中间层的线性变换
        hidden_states = self.wi(hidden_states)
        # 应用选择的激活函数
        hidden_states = self.act(hidden_states)
        # 应用dropout或者恒等映射
        hidden_states = self.dropout(hidden_states)
        # 中间层到输出层的线性变换
        hidden_states = self.wo(hidden_states)
        return hidden_states


# Copied from transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router with SwitchTransformers->GPTSanJapanese
class GPTSanJapaneseTop1Router(nn.Module):
    """
    Router using tokens choose top-1 experts assignment.

    This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE
    (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then
    routed to their choice of expert until the expert's expert_capacity is reached. **There is no guarantee that each
    token is processed by an expert**, or that each expert receives at least one token.

    """

    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__()
        # 专家数量
        self.num_experts = config.num_experts
        # 每个专家的容量
        self.expert_capacity = config.expert_capacity
        # 分类器层,将隐藏状态映射到专家数目的输出,带有偏置项
        self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias)
        # 路由噪声
        self.jitter_noise = config.router_jitter_noise
        # 是否忽略填充标记
        self.ignore_padding_tokens = config.router_ignore_padding_tokens
        # 路由数据类型
        self.dtype = getattr(torch, config.router_dtype)
    def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Computes router probabilities from input hidden states.

        Args:
            hidden_states (`torch.Tensor`):
                (batch_size, sequence_length, hidden_dim) from which router probabilities are computed.
        Returns:
            router_probabilities (`torch.Tensor`):
                Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each
                token and expert. Used for routing tokens to experts.
            router_logits (`torch.Tensor`):
                Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits.
                This is used later for computing router z-loss.
        """
        # float32 is used to ensure stability. See the discussion of "selective precision" in
        # https://arxiv.org/abs/2101.03961.
        # We also store the previous dtype to cast back the output to the previous dtype
        # 存储当前输入 hidden_states 的数据类型,以备将输出重新转换回该数据类型
        self.input_dtype = hidden_states.dtype
        # 将 hidden_states 转换为 self.dtype,以确保稳定性和一致性
        hidden_states = hidden_states.to(self.dtype)

        if self.training and self.jitter_noise > 0:
            # 如果在训练过程中,并且设置了 jitter_noise,则将输入的 token 乘以均匀分布的值,以添加一些噪音
            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)

        # Shape: [num_groups, tokens_per_group, num_experts]
        # 调用 _cast_classifier 方法,确保分类器的数据类型与 self.dtype 一致
        self._cast_classifier()
        # 计算 router_logits,即路由器的原始逻辑回归结果
        router_logits = self.classifier(hidden_states)

        # 应用 Softmax 函数,并将结果转换回原始的数据类型 self.input_dtype
        router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype)
        return router_probabilities, router_logits

    def _cast_classifier(self):
        r"""
        `bitsandbytes` `Linear8bitLt` layers does not support manual casting Therefore we need to check if they are an
        instance of the `Linear8bitLt` class by checking special attributes.
        """
        # 如果分类器不是 `Linear8bitLt` 类的实例(通过检查特定的属性),则将其转换为 self.dtype 类型
        if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")):
            self.classifier = self.classifier.to(self.dtype)
    def forward(self, hidden_states: torch.Tensor) -> Tuple:
        r"""
        Generic forward function for every Router class. Each Router expects to have the same input hidden states
        (`hidden_states`) corresponding to the hidden states for each token, the `expert_capacity` corresponding to the
        number of tokens the Router will send to each expert, some Routers can send up to few tokens to each expert.

        Each Router works as the following: it expects the hidden states for each token, gets the `router_probs` and
        `router_logits` from the `router_weights`. This will assign for each token, the raw probability to be assigned
        to an expert. Then each Router class will have to define its own `_compute_routing_instructions`.

        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
        Returns:
            Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`] Tuple containing the expert index, the router probs
            and the router logits. The router probabilities and logits are required to compute the loss.
        """
        # 计算路由概率和路由 logits
        router_probs, router_logits = self._compute_router_probabilities(hidden_states)

        # 根据概率选择每个 token 被分配到的专家索引
        expert_index = torch.argmax(router_probs, dim=-1)
        # 将专家索引转换成 one-hot 格式
        expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)

        # 计算每个专家接收的 token 数量的累积和
        token_priority = torch.cumsum(expert_index, dim=-2)
        # 创建专家接收容量的掩码,以限制 token 数量不超过 expert_capacity
        expert_capacity_mask = token_priority <= self.expert_capacity
        expert_index = expert_index * expert_capacity_mask

        # 取每个 token 的最大路由概率作为输出
        router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
        return expert_index, router_probs, router_logits
# Copied from transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP with SwitchTransformers->GPTSanJapanese
class GPTSanJapaneseSparseMLP(nn.Module):
    r"""
    Implementation of the Switch Transformers Sparse MLP module.
    """

    def __init__(self, config: GPTSanJapaneseConfig, expert_class: nn.Module = GPTSanJapaneseDenseActDense):
        super().__init__()
        # Step 1: Get the correct router according to its class
        self.router = GPTSanJapaneseTop1Router(config)

        # Step 2: Get the experts
        self.experts = nn.ModuleDict()
        for idx in range(config.num_experts):
            self.experts[f"expert_{idx}"] = expert_class(config)

    def forward(self, hidden_states):
        r"""
        Hold on, this will be slightly tricky to understand In the correct order, a MoE layer does the following:

        1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)`
        and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
        hidden states : they are broadcasted to the hidden states values (can be interpreted as a scaling factor).

        2- Dispatch the tokens to its associated experts. We do a classic for loop over the experts and assign for each
        expert the corresponding hidden states.

        """
        # Step 1: Get the router_mask from the router as wel as the probabilities
        router_mask, router_probs, router_logits = self.router(hidden_states)
        expert_index = torch.argmax(router_mask, dim=-1)

        # The routers introduced might not always map all the tokens, to a router, which means that some hidden states
        # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the seleced ones.

        next_states = hidden_states.clone()
        for idx, expert in enumerate(self.experts.values()):
            token_indices = router_mask[:, :, idx].bool()
            next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)

        hidden_states = router_probs * next_states
        return hidden_states, (router_logits, expert_index)


class GPTSanJapaneseLayerSparseFF(nn.Module):
    r"""
    Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts module.

    Parameters:
        config : ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """
    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__()
        self.mlp = GPTSanJapaneseSparseMLP(config)  # 初始化稀疏多层感知机(MLP)模型
        self.soft_bypass_mlp = nn.Linear(config.d_model, config.d_model, bias=False)  # 创建线性层,用于软绕过MLP
        self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)  # 初始化层归一化,使用给定的epsilon值

    def forward(self, hidden_states, output_router_logits):
        r"""
        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.  # 输入隐藏状态,形状为[num_groups, tokens_per_group, hidden_dim],发送给专家
            output_router_logits (`bool`) :
                output experts router output.  # 输出专家的路由器输出
        Returns:
            torch.Tensor[num_groups, tokens_per_group, hidden_dim]  # 返回形状为[num_groups, tokens_per_group, hidden_dim]的张量

        """
        forwarded_states, router_tuple = self.mlp(hidden_states)  # 使用MLP处理隐藏状态,获得前向状态和路由元组
        forwarded_states += torch.tanh(self.soft_bypass_mlp(hidden_states))  # 添加软绕过MLP的操作结果到前向状态
        output = hidden_states + self.norm(forwarded_states)  # 使用层归一化将前向状态与隐藏状态相加,得到最终输出

        if output_router_logits and router_tuple is not None:
            return output, router_tuple  # 如果需要输出路由器的输出且路由元组不为空,则返回输出和路由元组
        else:
            return output  # 否则只返回输出
class GPTSanJapaneseLayerDenseFF(nn.Module):
    r"""
    Extra Transformers Feed Forward layer module.

    Parameters:
        config : ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__()
        # 检查是否是稀疏层,如果不是则是密集层
        self.mlp = GPTSanJapaneseDenseActDense(config, ext_layer=True)
        # 使用 LayerNorm 对象进行归一化处理
        self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states):
        r"""
        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
        Returns:
            torch.Tensor[num_groups, tokens_per_group, hidden_dim]

        """
        # 通过 MLP 层处理隐藏状态
        forwarded_states = self.mlp(hidden_states)
        # 将处理后的状态与归一化后的隐藏状态相加作为输出
        output = hidden_states + self.norm(forwarded_states)
        return output


# 从 transformers.models.bart.modeling_bart.BartAttention 复制并修改为 GPTSanJapanese
class GPTSanJapaneseAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[GPTSanJapaneseConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 缩放因子设定为 head_dim 的负半径
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # 线性变换层,用于计算 Q、K、V、输出
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # 将张量重塑为 [bsz, num_heads, seq_len, head_dim] 的形状
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
    # 定义模型的前向传播方法,接受多个输入参数
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    """
    Self Attention and Normalization Unit
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        # 初始化自注意力层,使用配置中的模型维度和头数,同时作为解码器层
        self.self_attn = GPTSanJapaneseAttention(
            embed_dim=config.d_model,
            num_heads=config.num_heads,
            is_decoder=True,
            bias=has_relative_attention_bias,
        )
        # 初始化层归一化,使用配置中的模型维度和层归一化的 epsilon 参数
        self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,



    """
    Self Attention and FFN Unit
    """

    def __init__(self, config, ext_layer=False):
        super().__init__()
        # 初始化自注意力和前馈网络单元,根据 ext_layer 参数决定使用稠密或稀疏的前馈网络层
        self.self_attn = GPTSanJapaneseLayerSelfAttention(config)
        self.feed_forward = GPTSanJapaneseLayerDenseFF(config) if ext_layer else GPTSanJapaneseLayerSparseFF(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_router_tuple: Optional[bool] = False,



    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPTSanJapaneseConfig
    base_model_prefix = "gptsan_japanese"
    supports_gradient_checkpointing = False
    _no_split_modules = ["GPTSanJapaneseBlock"]
    _skip_keys_device_placement = "past_key_values"

    @property
    def dummy_inputs(self):
        # 创建一个包含虚拟输入的字典,包括输入标识符和注意力掩码
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
        }
        return dummy_inputs

    # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
    # 定义一个私有方法 `_shift_right`,接收参数 `input_ids`
    def _shift_right(self, input_ids):
        # 从模型配置中获取解码器起始标记的 ID
        decoder_start_token_id = self.config.decoder_start_token_id
        # 从模型配置中获取填充标记的 ID
        pad_token_id = self.config.pad_token_id

        # 如果解码器起始标记 ID 未定义,则引发数值错误
        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
                "See T5 docs for more information."
            )

        # 将输入向右移动一位,以便为解码器准备输入
        if is_torch_fx_proxy(input_ids):
            # 对于 Torch FX 代理对象,不支持原生的项目赋值操作
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            # 创建一个与输入形状相同的零张量
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            # 将输入向右移动一位,并在最左侧插入解码器起始标记 ID
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        # 如果填充标记 ID 未定义,则引发数值错误
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        
        # 将标签中可能存在的 -100 值替换为填充标记 ID
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        # 返回向右移位后的输入张量
        return shifted_input_ids
# 使用原始字符串字面值定义文档字符串,介绍了 GPTSAN-japanese 模型的概述和用途链接
GPTSAN_JAPANESE_START_DOCSTRING = r"""

    The [GPTSAN-japanese](https://github.com/tanreinama/GPTSAN) model was proposed in General-purpose Swich transformer
    based Japanese language model

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# 空字符串,准备用于描述模型输入的文档字符串
GPTSAN_JAPANESE_INPUTS_DOCSTRING = r"""
"""

# 添加描述信息到 GPTSanJapaneseModel 类的文档字符串中,指明它是不带任何特定头部的 GPTSAN-japanese 模型变压器,输出原始隐藏状态
@add_start_docstrings(
    "The bare GPTSAN-japanese Model transformer outputting raw hidden-states without any specific head on top.",
    GPTSAN_JAPANESE_START_DOCSTRING,
)
class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel):
    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__(config)
        # 初始化位置嵌入,用于模型的位置编码
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
        # 深拷贝模型配置
        self.config = copy.deepcopy(config)
        # 初始化词嵌入
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        # 初始化最后一个投影层,用于将模型输出映射回特定维度
        self.last_project = nn.Linear(config.d_model, config.d_model, bias=True)
        # 设置激活函数为 Swish
        self.act = ACT2FN["swish"]

        # 初始化模型块列表
        self.blocks = torch.nn.ModuleList([])
        # 添加 switch 层
        for _ in range(config.num_switch_layers):
            self.blocks.append(GPTSanJapaneseBlock(config))
        # 添加 ext 层
        for _ in range(config.num_ext_layers):
            self.blocks.append(GPTSanJapaneseBlock(config, ext_layer=True))

        # 如果存在额外的 ext 层,初始化额外的位置嵌入
        if config.num_ext_layers > 0:
            self.extra_position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)

        # 如果存在 d_spout,初始化 spout 层
        if config.d_spout:
            spouts = []
            for _ in range(8):
                spouts.append(nn.Linear(config.d_spout, config.d_spout, bias=False))
                spouts.append(nn.Tanh())
            spouts.append(nn.Linear(config.d_spout, config.num_layers * 2 * config.d_model, bias=False))
            self.spout = nn.Sequential(*spouts)

        # 执行初始化后的操作
        self.post_init()

    # 获取输入嵌入
    def get_input_embeddings(self):
        return self.embed_tokens

    # 设置输入嵌入
    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    # 添加描述信息到模型前向方法的文档字符串中,描述模型的输入
    @add_start_docstrings_to_model_forward(GPTSAN_JAPANESE_INPUTS_DOCSTRING)
    # 定义模型的前向传播方法,接受多个可选的输入参数
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,  # 输入的token IDs,可选的长整型张量
        attention_mask: Optional[torch.FloatTensor] = None,  # 注意力掩码,可选的浮点数张量
        token_type_ids: Optional[torch.FloatTensor] = None,  # token 类型 IDs,可选的浮点数张量
        spout: Optional[torch.FloatTensor] = None,  # 特定应用的张量,可选的浮点数张量
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,  # 过去的键值对,可选的张量元组
        head_mask: Optional[torch.FloatTensor] = None,  # 头部掩码,可选的浮点数张量
        use_cache: Optional[bool] = False,  # 是否使用缓存,默认为False
        inputs_embeds: Optional[torch.FloatTensor] = None,  # 输入的嵌入张量,可选的浮点数张量
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,  # 解码器输入的嵌入张量,可选的浮点数张量
        output_attentions: Optional[bool] = None,  # 是否输出注意力,默认为None
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态,默认为None
        return_dict: Optional[bool] = None,  # 是否返回字典,默认为None
        output_router_logits: Optional[bool] = None,  # 是否输出路由器日志,默认为None
        num_precontext: Optional[torch.LongTensor] = None,  # 前文上下文数目,可选的长整型张量
@add_start_docstrings(
    "The bare GPTSAN-japanese Model with a language modeling head.",
    GPTSAN_JAPANESE_START_DOCSTRING,
)
class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__(config)
        self.model = GPTSanJapaneseModel(config)  # 初始化一个GPTSanJapaneseModel模型
        self.register_buffer("final_logits_bias", torch.zeros([1, config.vocab_size]))  # 注册一个大小为1xvocab_size的零张量作为偏置项
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)  # 初始化一个线性层作为语言模型头部,无偏置项
        if not self.config.torchscript:
            self.lm_head.weight = self.model.embed_tokens.weight  # 如果不是torchscript模式,则将语言模型头部的权重与嵌入词嵌入权重绑定

    @add_start_docstrings_to_model_forward(GPTSAN_JAPANESE_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.FloatTensor] = None,
        spout: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        # 模型的前向传播方法,接受多种输入,返回生成的结果

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
        token_type_ids: Optional[torch.FloatTensor] = None,
        spout: Optional[Union[List, torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        **kwargs,
    ):
        if isinstance(spout, list):  # 如果spout是列表类型
            spout = torch.tensor(spout).float()  # 将其转换为torch张量并转换为float类型
            if input_ids is not None:  # 如果存在input_ids
                spout = spout.to(input_ids.device)  # 将spout移动到与input_ids相同的设备上
        if past_key_values is not None:  # 如果存在过去的键值
            return {
                "input_ids": input_ids[:, -1:] if input_ids is not None else None,  # 返回最后一个位置的input_ids
                "attention_mask": attention_mask,  # 返回attention_mask
                "token_type_ids": token_type_ids[:, -1:] if token_type_ids is not None else None,  # 返回最后一个位置的token_type_ids
                "spout": spout,  # 返回spout
                "past_key_values": past_key_values,  # 返回过去的键值
            }
        return {
            "input_ids": input_ids,  # 返回input_ids
            "attention_mask": attention_mask,  # 返回attention_mask
            "token_type_ids": token_type_ids,  # 返回token_type_ids
            "spout": spout,  # 返回spout
            "past_key_values": None,  # 返回空的过去的键值
        }

    # 从transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersForConditionalGeneration.prepare_decoder_input_ids_from_labels复制而来,改为使用GPTSanJapanese
    # 根据标签张量生成解码器的输入序列(右移一位)
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    # 从父类中继承的方法,用于调整词嵌入矩阵的大小,支持可选的多重填充
    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
        # 调用父类方法调整词嵌入矩阵大小
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # 调用本类方法调整最终输出偏置的大小
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        # 返回调整后的词嵌入矩阵
        return new_embeddings

    # 调整最终输出偏置的大小,以适应新的标记数量
    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        # 如果新的标记数量小于等于旧的数量,只截取现有偏置的一部分
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        # 如果新的标记数量大于旧的数量,则在偏置末尾填充零向量
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        # 注册更新后的偏置为缓冲区
        self.register_buffer("final_logits_bias", new_bias)

    # 获取模型的输入词嵌入
    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    # 设置模型的输入词嵌入
    def set_input_embeddings(self, new_embeddings):
        self.model.set_input_embeddings(new_embeddings)

    # 设置模型的输出词嵌入,用新的词嵌入替换语言模型头部
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # 获取模型的输出词嵌入,即语言模型头部
    def get_output_embeddings(self):
        return self.lm_head

    # 将路由器的输出解压,返回总的路由器 logits 和专家索引
    def _unpack_router_logits(self, router_outputs):
        total_router_logits = []
        total_expert_indexes = []
        for router_output in router_outputs:
            # 如果路由器输出的第一个张量维度大于1,表明有有效的路由器 logits 和专家索引
            if len(router_output[0].shape) > 1:
                router_logits, expert_indexes = router_output
                total_router_logits.append(router_logits)
                total_expert_indexes.append(expert_indexes)
        # 沿着第一个维度拼接所有路由器 logits 和专家索引
        return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)
posted @ 2024-06-30 15:39  绝不原创的飞龙  阅读(15)  评论(0编辑  收藏  举报