Transformers-源码解析-六十四-

Transformers 源码解析(六十四)

.\models\led\tokenization_led.py

# coding=utf-8
# 版权所有 2021 Iz Beltagy,Matthew E. Peters,Arman Cohan 和 HuggingFace Inc. 团队。保留所有权利。
#
# 根据 Apache 许可证版本 2.0 进行许可;
# 除非符合许可证的条款,否则不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件根据“原样”分发,
# 没有任何形式的明示或暗示的保证或条件。
# 有关特定语言的权限,请参阅许可证。
"""LED 的分词类。"""

import json
import os
from functools import lru_cache
from typing import Dict, List, Optional, Tuple, Union

import regex as re

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding, EncodedInput
from ...utils import PaddingStrategy, logging

# 获取日志记录器实例
logger = logging.get_logger(__name__)

# 定义词汇文件名
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}

# 预训练模型的词汇文件映射
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/vocab.json",
    },
    "merges_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/merges.txt",
    },
    "tokenizer_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/tokenizer.json",
    },
}

# 预训练模型的位置编码嵌入尺寸
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "allenai/led-base-16384": 16384,
}


@lru_cache()
# 从 transformers.models.bart.tokenization_bart.bytes_to_unicode 复制的函数
def bytes_to_unicode():
    """
    返回 utf-8 字节列表和 Unicode 字符串的映射。避免映射到空白字符或控制字符,以免引起 bpe 代码错误。
    
    可逆的 bpe 代码适用于 Unicode 字符串。这意味着如果要避免 UNK(未知)符号,词汇表中需要大量的 Unicode 字符。
    当数据集达到约 100 亿个标记时,您需要大约 5000 个 Unicode 字符以获得良好的覆盖率。
    这在普通的 32K bpe 词汇表中占有相当大的比例。为了避免这种情况,我们需要 utf-8 字节和 Unicode 字符串之间的查找表。
    """
    bs = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
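下面用一个简单示例说明这个映射的效果(非源码的一部分,仅作演示):

# 示例:观察 bytes_to_unicode 返回的映射
byte_to_unicode = bytes_to_unicode()

print(len(byte_to_unicode))       # 256,覆盖所有可能的字节值
print(byte_to_unicode[ord("A")])  # 'A',可打印字符保持原样
print(byte_to_unicode[ord(" ")])  # 'Ġ',空格等空白/控制字节被映射为可见的 Unicode 字符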


# 从 transformers.models.bart.tokenization_bart.get_pairs 复制的函数
def get_pairs(word):
    """
    返回单词中的符号对集合。

    单词表示为符号元组(符号是长度可变的字符串)。
    """
    # 创建一个空的集合用于存储字符对
    pairs = set()
    # 取单词的第一个字符作为前一个字符
    prev_char = word[0]
    # 遍历单词中除第一个字符外的每个字符
    for char in word[1:]:
        # 将前一个字符和当前字符作为一个元组,添加到集合中
        pairs.add((prev_char, char))
        # 更新前一个字符为当前字符,为下一次循环做准备
        prev_char = char
    # 返回存储了单词中相邻字符对的集合
    return pairs
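一个简单示例(非源码,仅作演示),展示 get_pairs 返回的相邻符号对:

# 示例:对字符元组 ("l", "o", "w") 求相邻符号对
print(get_pairs(("l", "o", "w")))
# 输出(集合,元素顺序不定):{('l', 'o'), ('o', 'w')}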
class LEDTokenizer(PreTrainedTokenizer):
    """
    Constructs a LED tokenizer, which is similar to the ROBERTa tokenizer, using byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```
    >>> from transformers import LEDTokenizer

    >>> tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
    >>> tokenizer("Hello world")["input_ids"]
    [0, 31414, 232, 2]

    >>> tokenizer(" Hello world")["input_ids"]
    [0, 20920, 232, 2]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.
    """
    # LEDTokenizer 类,继承自 PreTrainedTokenizer
    # 用于构建 LED 分词器,其使用字节级字节对编码(Byte-Pair-Encoding)
    
    # 词汇文件的名称映射,用于指定预训练模型的词汇文件
    vocab_files_names = VOCAB_FILES_NAMES
    # 预训练模型的词汇文件映射,指定每个预训练模型对应的词汇文件路径
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # 预训练位置嵌入大小的映射,指定每个预训练模型的最大输入尺寸
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # 模型输入的名称列表,指定输入数据的名称,通常为 "input_ids" 和 "attention_mask"
    model_input_names = ["input_ids", "attention_mask"]

    # 以下代码段是从 transformers.models.bart.tokenization_bart.BartTokenizer.__init__ 中复制的
    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        add_prefix_space=False,
        **kwargs,
    ):
        # 如果 `bos_token`, `eos_token`, `sep_token`, `cls_token`, `unk_token`, `pad_token` 是字符串类型,则将它们封装为 `AddedToken` 对象,否则保持原样
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token

        # 如果 `mask_token` 是字符串类型,则将它封装为 `AddedToken` 对象,并且在左侧去掉空格,保持右侧不变;否则保持原样
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        # 使用 UTF-8 编码打开 `vocab_file` 文件,并将其加载为 JSON 格式,存储到 `self.encoder` 中
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        
        # 创建 `self.decoder` 字典,将 `self.encoder` 中的键值对反转,以便根据索引获取词汇
        self.decoder = {v: k for k, v in self.encoder.items()}
        
        # 指定解码过程中处理错误的策略,存储到 `self.errors` 中
        self.errors = errors  # how to handle errors in decoding
        
        # 转换字节到 Unicode 编码的映射表,通过调用 `bytes_to_unicode()` 函数实现
        self.byte_encoder = bytes_to_unicode()
        
        # 创建 `self.byte_decoder` 字典,将 `self.byte_encoder` 中的键值对反转,以便根据编码获取原始字节
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        
        # 使用 UTF-8 编码打开 `merges_file` 文件,读取 BPE 合并操作,存储到 `bpe_merges` 中
        with open(merges_file, encoding="utf-8") as merges_handle:
            bpe_merges = merges_handle.read().split("\n")[1:-1]
        
        # 将每行 BPE 合并操作转换为元组形式,存储到 `bpe_merges` 列表中
        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
        
        # 创建 `self.bpe_ranks` 字典,将 BPE 合并操作与其在列表中的索引关联起来
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        
        # 初始化缓存 `self.cache`,用于存储临时数据
        self.cache = {}
        
        # 设置是否在词前添加空格的标志,存储到 `self.add_prefix_space` 中
        self.add_prefix_space = add_prefix_space

        # 使用正则表达式创建 `self.pat` 模式,用于分词时处理合并和大小写
        # 添加 `re.IGNORECASE` 标志以支持对大写版本的缩写进行 BPE 合并
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        # 调用父类的初始化方法,传递参数并完成初始化
        super().__init__(
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

    @property
    # 返回 `self.encoder` 中的条目数量,即词汇表的大小
    # 从 `transformers.models.bart.tokenization_bart.BartTokenizer.vocab_size` 处复制
    def vocab_size(self):
        return len(self.encoder)

    # 返回包含 `self.encoder` 和 `self.added_tokens_encoder` 的词汇表字典
    # 从 `transformers.models.bart.tokenization_bart.BartTokenizer.get_vocab` 处复制
    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    # 从 transformers.models.bart.tokenization_bart.BartTokenizer.bpe 复制而来
    def bpe(self, token):
        # 如果缓存中已经存在该 token 的处理结果,则直接返回缓存中的结果
        if token in self.cache:
            return self.cache[token]
        
        # 将 token 转换为字符元组
        word = tuple(token)
        # 获取 token 的所有可能的字符对
        pairs = get_pairs(word)

        # 如果没有字符对,则直接返回原始的 token
        if not pairs:
            return token
        
        # 进入循环,直到 token 不再有字符对为止
        while True:
            # 找到当前字符对中排名最低的字符对
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            # 如果找到的字符对不在字符对排名中,退出循环
            if bigram not in self.bpe_ranks:
                break
            # 分割字符对
            first, second = bigram
            new_word = []
            i = 0
            # 遍历 token 的字符
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # 如果找不到字符 first,则将剩余部分添加到新的 token 中
                    new_word.extend(word[i:])
                    break
                else:
                    # 将 first 之前的字符添加到新的 token 中
                    new_word.extend(word[i:j])
                    i = j

                # 检查是否匹配字符对 first 和 second
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    # 如果匹配,则将字符对添加到新的 token 中
                    new_word.append(first + second)
                    i += 2
                else:
                    # 否则将当前字符添加到新的 token 中
                    new_word.append(word[i])
                    i += 1
            # 更新 word 为新的字符元组
            new_word = tuple(new_word)
            word = new_word
            # 如果 token 的长度为 1,则退出循环
            if len(word) == 1:
                break
            else:
                # 获取更新后的字符对
                pairs = get_pairs(word)
        
        # 将字符元组转换为字符串
        word = " ".join(word)
        # 将处理结果添加到缓存中
        self.cache[token] = word
        # 返回处理后的字符串
        return word
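下面用一个假设的、极小的合并表演示 bpe 方法中的贪心合并过程(非源码,仅作演示;tiny_ranks 为虚构数据):

# 示例:手动走一遍与 bpe 方法一致的合并逻辑
tiny_ranks = {("l", "o"): 0, ("lo", "w"): 1}

word = ("l", "o", "w")
while True:
    pairs = get_pairs(word)
    # 取当前排名最靠前(rank 最小)的符号对
    bigram = min(pairs, key=lambda pair: tiny_ranks.get(pair, float("inf")))
    if bigram not in tiny_ranks:
        break
    first, second = bigram
    # 简化版合并:把相邻的 (first, second) 替换为 first + second
    merged, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
            merged.append(first + second)
            i += 2
        else:
            merged.append(word[i])
            i += 1
    word = tuple(merged)
    if len(word) == 1:
        break

print(" ".join(word))  # 输出: low,即 ("l","o","w") -> ("lo","w") -> ("low",)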

    # Copied from transformers.models.bart.tokenization_bart.BartTokenizer._tokenize
    def _tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
        # 使用正则表达式找到文本中的所有 token
        for token in re.findall(self.pat, text):
            # 将每个 token 编码为字节,并映射为 unicode 字符串,避免 BPE 中的控制符号(在我们的情况下是空格)
            token = "".join(
                self.byte_encoder[b] for b in token.encode("utf-8")
            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
            # 使用 BPE 算法处理 token,并将处理结果拆分为多个子 token
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
        # 返回处理后的所有子 token 列表
        return bpe_tokens

    # Copied from transformers.models.bart.tokenization_bart.BartTokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # 使用词汇表将 token 转换为对应的 id,如果 token 不存在于词汇表中,则返回未知 token 的 id
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    # Copied from transformers.models.bart.tokenization_bart.BartTokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        # 使用词汇表将 id 转换为对应的 token
        return self.decoder.get(index)

    # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.convert_tokens_to_string
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # 将所有 token 合并为一个字符串
        text = "".join(tokens)
        # 将合并后的字符串解码为 utf-8 格式的文本
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        # 返回解码后的文本
        return text
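一个往返示例(非源码,仅作演示),说明 byte_encoder 与 byte_decoder 互为逆映射:

# 示例:字节级编码与解码的往返
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

text = "Hello world"
# 编码:每个 utf-8 字节映射为一个可见的 Unicode 字符(空格变为 'Ġ')
encoded = "".join(byte_encoder[b] for b in text.encode("utf-8"))
print(encoded)  # HelloĠworld

# 解码:按逆映射还原字节序列,再按 utf-8 解码
decoded = bytearray([byte_decoder[c] for c in encoded]).decode("utf-8")
print(decoded)  # Hello world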

    # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.save_vocabulary
    # 将词汇表保存到指定目录下的文件中
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # 检查保存目录是否存在,若不存在则记录错误并返回
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        
        # 构建词汇表文件名和合并文件名
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        # 将编码器的内容以 JSON 格式写入词汇表文件
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        # 将 BPE(Byte Pair Encoding)标记和其索引写入合并文件
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # 若 BPE 合并索引不连续,记录警告信息
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                # 写入 BPE 标记
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        # 返回保存的词汇表文件名和合并文件名
        return vocab_file, merge_file

    # 从 BartTokenizer.build_inputs_with_special_tokens 复制并修改为 LEDTokenizer 的特殊标记构建方法
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        通过添加特殊标记,为序列分类任务构建模型输入。LED 序列有以下格式:

        - 单个序列:`<s> X </s>`
        - 序列对:`<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                要添加特殊标记的 ID 列表。
            token_ids_1 (`List[int]`, *可选*):
                第二个序列的 ID 列表(用于序列对任务)。

        Returns:
            `List[int]`: 包含适当特殊标记的输入 ID 列表。
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
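一个说明性示例(非源码,仅作演示),展示上述两种拼接格式;其中 id 0 对应 `<s>`(cls),id 2 对应 `</s>`(sep),与类 docstring 中的编码示例一致,token_ids_1 为假设的占位 id:

cls_id, sep_id = 0, 2
token_ids_0 = [31414, 232]   # "Hello world" 对应的 BPE id(不含特殊标记)
token_ids_1 = [100, 200]     # 假设的第二个序列 id,仅作占位

# 单个序列:<s> X </s>
print([cls_id] + token_ids_0 + [sep_id])                                   # [0, 31414, 232, 2]
# 序列对:<s> A </s></s> B </s>
print([cls_id] + token_ids_0 + [sep_id, sep_id] + token_ids_1 + [sep_id])  # [0, 31414, 232, 2, 2, 100, 200, 2]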

    # 从 BartTokenizer.get_special_tokens_mask 复制并修改为 LEDTokenizer 的特殊标记掩码方法
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.create_token_type_ids_from_sequences with BART->LED
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. LED does not
        make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        # Initialize separator and classification token IDs for sequence masking
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            # Return a list of zeros of length equal to the combined length of cls, token_ids_0, and sep
            return len(cls + token_ids_0 + sep) * [0]
        # Return a list of zeros of length equal to the combined length of cls, token_ids_0, sep, sep, token_ids_1, and sep
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    # Copied from transformers.models.bart.tokenization_bart.BartTokenizer.prepare_for_tokenization
    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        """
        Prepare text for tokenization by adding a space prefix if specified and not already present.

        Args:
            text (str): The input text to be tokenized.
            is_split_into_words (bool, optional): Whether the text is already split into words.
            **kwargs: Additional keyword arguments.

        Returns:
            tuple: A tuple containing the modified text and remaining keyword arguments.
        """
        # Check if a space prefix should be added and modify text accordingly
        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
        if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
            text = " " + text
        return (text, kwargs)

    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        # 调用父类的 `_pad` 方法,对编码输入进行填充
        encoded_inputs = super()._pad(
            encoded_inputs=encoded_inputs,
            max_length=max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )

        # 如果 `return_attention_mask` 为 None,则根据模型默认值确定是否返回注意力掩码
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        # 如果需要返回注意力掩码,并且编码输入中存在 `global_attention_mask`
        if return_attention_mask and "global_attention_mask" in encoded_inputs:
            required_input = encoded_inputs[self.model_input_names[0]]
            # `global_attention_mask` 需要与其他(顺序)输入具有相同的长度
            needs_to_be_padded = len(encoded_inputs["global_attention_mask"]) != len(required_input)

            # 如果需要填充
            if needs_to_be_padded:
                difference = len(required_input) - len(encoded_inputs["global_attention_mask"])

                # 根据填充方向进行处理
                if self.padding_side == "right":
                    # 使用 `-1`,因为 `global_attention_mask` 中的 `0` 表示局部注意力而不是不需要注意
                    encoded_inputs["global_attention_mask"] = (
                        encoded_inputs["global_attention_mask"] + [-1] * difference
                    )
                elif self.padding_side == "left":
                    encoded_inputs["global_attention_mask"] = [-1] * difference + encoded_inputs[
                        "global_attention_mask"
                    ]
                else:
                    # 抛出异常,无效的填充策略
                    raise ValueError("Invalid padding strategy:" + str(self.padding_side))

        # 返回填充后的编码输入字典
        return encoded_inputs
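下面的示例(非源码,仅作演示)展示 `_pad` 对 `global_attention_mask` 的特殊处理:新增的填充位置用 `-1` 补齐。运行它需要能够下载或已缓存 allenai/led-base-16384 的词表文件:

from transformers import LEDTokenizer

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("Hello world")                # input_ids 为 [0, 31414, 232, 2]
inputs["global_attention_mask"] = [1, 0, 0, 0]   # 仅让第一个 token(<s>)使用全局注意力

padded = tokenizer.pad(inputs, padding="max_length", max_length=8)
print(padded["input_ids"])              # [0, 31414, 232, 2, 1, 1, 1, 1],用 pad_token_id(1)补齐
print(padded["global_attention_mask"])  # [1, 0, 0, 0, -1, -1, -1, -1],新增位置填 -1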

.\models\led\tokenization_led_fast.py

# 定义 LEDTokenizerFast 类,继承自 PreTrainedTokenizerFast 类
class LEDTokenizerFast(PreTrainedTokenizerFast):
    r"""
    Construct a "fast" LED tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2 tokenizer,
    using byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```
    >>> from transformers import LEDTokenizerFast

    >>> tokenizer = LEDTokenizerFast.from_pretrained("allenai/led-base-16384")
    >>> tokenizer("Hello world")["input_ids"]
    [0, 31414, 232, 2]

    >>> tokenizer(" Hello world")["input_ids"]
    [0, 20920, 232, 2]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (LED tokenizer detect beginning of words by the preceding space).
        trim_offsets (`bool`, *optional*, defaults to `True`):
            Whether the post processing step should trim offsets to avoid including whitespaces.
    """

    # 设置两个常量变量
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # 设置最大模型输入尺寸为预训练位置嵌入大小
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # 指定慢速分词器的类为 LEDTokenizer
    slow_tokenizer_class = LEDTokenizer
    # 模型输入的名称列表,包括 input_ids 和 attention_mask
    model_input_names = ["input_ids", "attention_mask"]

    # 以下内容是从 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.__init__ 中复制过来的
    # 初始化方法
    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        errors="replace",
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        add_prefix_space=False,
        trim_offsets=True,
        **kwargs,
    ):
        # 如果 `mask_token` 是字符串,创建一个带有特殊标志的 AddedToken 对象,用于表示特殊的 MASK 标记
        mask_token = (
            AddedToken(mask_token, lstrip=True, normalized=True, special=True)
            if isinstance(mask_token, str)
            else mask_token
        )
        # 调用父类的初始化方法,初始化 LEDTokenizerFast 对象
        super().__init__(
            vocab_file,
            merges_file,
            tokenizer_file=tokenizer_file,
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            trim_offsets=trim_offsets,
            **kwargs,
        )

        # 获取当前前置处理器的状态,并检查是否需要更新 `add_prefix_space` 属性
        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
            # 如果前置处理器的 `add_prefix_space` 属性不匹配当前设定,更新前置处理器的状态
            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
            pre_tok_state["add_prefix_space"] = add_prefix_space
            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)

        self.add_prefix_space = add_prefix_space

        # 检查后处理器的状态,并更新 `sep` 和 `cls` 标记为元组,以便与 LED 的 `post_processor` 兼容
        tokenizer_component = "post_processor"
        tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
        if tokenizer_component_instance:
            state = json.loads(tokenizer_component_instance.__getstate__())

            if "sep" in state:
                state["sep"] = tuple(state["sep"])
            if "cls" in state:
                state["cls"] = tuple(state["cls"])

            changes_to_apply = False

            # 检查后处理器的状态是否需要更新 `add_prefix_space` 和 `trim_offsets` 属性
            if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
                state["add_prefix_space"] = add_prefix_space
                changes_to_apply = True

            if state.get("trim_offsets", trim_offsets) != trim_offsets:
                state["trim_offsets"] = trim_offsets
                changes_to_apply = True

            # 如果有更改需要应用,则创建新的后处理器实例并更新到 LEDTokenizerFast 对象中
            if changes_to_apply:
                component_class = getattr(processors, state.pop("type"))
                new_value = component_class(**state)
                setattr(self.backend_tokenizer, tokenizer_component, new_value)
    @property
    def mask_token(self) -> str:
        """
        `str`: 获取掩码标记,用于训练掩码语言建模的模型。如果尚未设置,则记录错误信息。
        
        LED 分词器具有特殊的掩码标记,用于填充掩码管道中的空白。掩码标记将贪婪地包括在 *<mask>* 前的空格。
        """
        # 如果掩码标记未设置,则记录错误信息并返回 None
        if self._mask_token is None:
            if self.verbose:
                logger.error("Using mask_token, but it is not set yet.")
            return None
        # 返回掩码标记的字符串表示
        return str(self._mask_token)

    @mask_token.setter
    def mask_token(self, value):
        """
        设置掩码标记的默认行为,使其在之前包含空格。

        这是为了与所有先前使用的基于 LED 的模型保持向后兼容所必需的。
        """
        # 如果值是字符串类型,则创建 AddedToken 对象,并设置 lstrip=True,rstrip=False,使掩码标记行为类似普通词
        value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
        self._mask_token = value

    # 从 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast._batch_encode_plus 复制而来
    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)

        # 如果输入被预分词且没有添加前缀空格,则抛出 ValueError
        if is_split_into_words and not self.add_prefix_space:
            raise ValueError(
                f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
                "to use it with pretokenized inputs."
            )

        # 调用父类的 _batch_encode_plus 方法进行批处理编码
        return super()._batch_encode_plus(*args, **kwargs)

    # 从 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast._encode_plus 复制而来
    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)

        # 如果输入被预分词且没有添加前缀空格,则抛出 ValueError
        if is_split_into_words and not self.add_prefix_space:
            raise ValueError(
                f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
                "to use it with pretokenized inputs."
            )

        # 调用父类的 _encode_plus 方法进行编码
        return super()._encode_plus(*args, **kwargs)

    # 从 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.save_vocabulary 复制而来
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # 调用内部的 tokenizer.model.save 方法保存词汇表到指定目录下
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)

    # 从 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.build_inputs_with_special_tokens 复制而来
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        # 构建带有特殊标记的输入,包括起始标记、终止标记
        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        if token_ids_1 is None:
            return output

        # 如果存在第二个输入序列,添加终止标记,并连接第二个输入序列及其终止标记
        return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
    # 从 BART -> LED 的转换中复制的方法,用于根据输入序列创建token类型ID
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        创建用于序列对分类任务的掩码。LED 不使用token类型ID,因此返回一个由零组成的列表。
    
        Args:
            token_ids_0 (`List[int]`):
                第一个序列的ID列表。
            token_ids_1 (`List[int]`, *optional*):
                第二个序列的ID列表,用于序列对。
    
        Returns:
            `List[int]`: 全零列表。
        """
        sep = [self.sep_token_id]  # 分隔符的token ID列表
        cls = [self.cls_token_id]  # 类别标记的token ID列表
    
        if token_ids_1 is None:
            # 如果只有一个输入序列,则返回一个由零填充的列表
            return len(cls + token_ids_0 + sep) * [0]
        # 如果有两个输入序列,则返回一个由零填充的列表,包括两个分隔符
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
    
    # 从 transformers.models.led.tokenization_led.LEDTokenizer._pad 复制的方法
    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        # 使用 `super()._pad` 方法对输入进行填充操作,得到填充后的编码输入字典
        encoded_inputs = super()._pad(
            encoded_inputs=encoded_inputs,
            max_length=max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )

        # Load from model defaults
        # 如果 `return_attention_mask` 为 None,则根据模型输入名称中是否包含 "attention_mask" 决定其取值
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        # 如果需要返回注意力掩码,且 `encoded_inputs` 中包含 `global_attention_mask`:
        # 获取第一个模型输入,并检查 `global_attention_mask` 的长度是否与该输入一致
        if return_attention_mask and "global_attention_mask" in encoded_inputs:
            required_input = encoded_inputs[self.model_input_names[0]]
            # `global_attention_mask` need to have the same length as other (sequential) inputs.
            needs_to_be_padded = len(encoded_inputs["global_attention_mask"]) != len(required_input)

            # 如果需要填充:计算长度差,并根据填充方向(右侧或左侧)向 `global_attention_mask` 补 `-1`,
            # 使其与其他输入保持相同的长度
            if needs_to_be_padded:
                difference = len(required_input) - len(encoded_inputs["global_attention_mask"])

                if self.padding_side == "right":
                    # Use `-1` since `0` in `global_attention_mask` means `local attention` instead of `not to attend`
                    encoded_inputs["global_attention_mask"] = (
                        encoded_inputs["global_attention_mask"] + [-1] * difference
                    )
                elif self.padding_side == "left":
                    encoded_inputs["global_attention_mask"] = [-1] * difference + encoded_inputs[
                        "global_attention_mask"
                    ]
                else:
                    raise ValueError("Invalid padding strategy:" + str(self.padding_side))

        # 返回填充后的编码输入字典 `encoded_inputs`
        return encoded_inputs

.\models\led\__init__.py

# 导入类型检查模块
from typing import TYPE_CHECKING

# 导入必要的模块和函数
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
)

# 定义导入结构的字典,包含LED模型配置和标记化
_import_structure = {
    "configuration_led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig"],
    "tokenization_led": ["LEDTokenizer"],
}

# 检查是否可用标记器,若不可用则引发异常
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,将LED快速标记化添加到导入结构中
    _import_structure["tokenization_led_fast"] = ["LEDTokenizerFast"]

# 检查是否可用PyTorch,若不可用则引发异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,将LED模型相关类添加到导入结构中
    _import_structure["modeling_led"] = [
        "LED_PRETRAINED_MODEL_ARCHIVE_LIST",
        "LEDForConditionalGeneration",
        "LEDForQuestionAnswering",
        "LEDForSequenceClassification",
        "LEDModel",
        "LEDPreTrainedModel",
    ]

# 检查是否可用TensorFlow,若不可用则引发异常
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,将TensorFlow版LED模型相关类添加到导入结构中
    _import_structure["modeling_tf_led"] = ["TFLEDForConditionalGeneration", "TFLEDModel", "TFLEDPreTrainedModel"]

# 如果正在进行类型检查
if TYPE_CHECKING:
    # 从相应模块导入LED模型配置和标记化类
    from .configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
    from .tokenization_led import LEDTokenizer

    # 检查是否可用标记器,若不可用则忽略
    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 如果可用,从快速标记化模块导入LED快速标记化类
        from .tokenization_led_fast import LEDTokenizerFast

    # 检查是否可用PyTorch,若不可用则忽略
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 如果可用,从LED建模模块导入相关类
        from .modeling_led import (
            LED_PRETRAINED_MODEL_ARCHIVE_LIST,
            LEDForConditionalGeneration,
            LEDForQuestionAnswering,
            LEDForSequenceClassification,
            LEDModel,
            LEDPreTrainedModel,
        )

    # 检查是否可用TensorFlow,若不可用则忽略
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 如果可用,从TensorFlow版LED建模模块导入相关类
        from .modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel

# 如果不是类型检查阶段,导入sys模块
else:
    import sys
    # 将当前模块注册到 sys.modules 中,使用 _LazyModule 进行延迟加载
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
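借助上面的延迟加载结构,这些类都可以直接从 transformers 顶层导入。下面是一个简单的使用示例(非源码,仅作演示;需要安装 PyTorch 并能下载或已缓存对应权重):

from transformers import LEDForConditionalGeneration, LEDTokenizer

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("A very long document ...", return_tensors="pt")
generated_ids = model.generate(**inputs, max_length=32)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))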

.\models\levit\configuration_levit.py

# coding=utf-8
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
LeViT model configuration
"""

# 导入需要的库
from collections import OrderedDict  # 导入有序字典模块
from typing import Mapping  # 导入 Mapping 类型提示

from packaging import version  # 导入版本相关的模块

# 导入配置相关的工具函数和类
from ...configuration_utils import PretrainedConfig  # 导入预训练配置类
from ...onnx import OnnxConfig  # 导入 ONNX 配置类
from ...utils import logging  # 导入日志相关的工具函数

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 预训练模型名称与其配置文件的映射字典
LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/levit-128S": "https://huggingface.co/facebook/levit-128S/resolve/main/config.json",
    # 查看所有 LeViT 模型的列表:https://huggingface.co/models?filter=levit
}


class LevitConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`LevitModel`]. It is used to instantiate a LeViT
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the LeViT
    [facebook/levit-128S](https://huggingface.co/facebook/levit-128S) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """
    # 设定模型类型为 "levit"
    model_type = "levit"

    # 初始化函数,设置模型的各项参数
    def __init__(
        self,
        image_size=224,  # 输入图像的尺寸,默认为224
        num_channels=3,  # 输入图像的通道数,默认为3
        kernel_size=3,   # 初始卷积层的卷积核大小,默认为3
        stride=2,        # 初始卷积层的步长大小,默认为2
        padding=1,       # 初始卷积层的填充大小,默认为1
        patch_size=16,   # 嵌入的补丁大小,默认为16
        hidden_sizes=[128, 256, 384],     # 每个编码器块的隐藏层维度,默认为[128, 256, 384]
        num_attention_heads=[4, 8, 12],   # 每个Transformer编码器块中注意力层的注意力头数,默认为[4, 8, 12]
        depths=[4, 4, 4],                 # 每个编码器块中的层的数量,默认为[4, 4, 4]
        key_dim=[16, 16, 16],             # 每个编码器块中键的大小,默认为[16, 16, 16]
        drop_path_rate=0,                 # 用于随机深度中的dropout概率,默认为0
        mlp_ratio=[2, 2, 2],              # Mix FFNs中隐藏层大小与输入层大小的比例,默认为[2, 2, 2]
        attention_ratio=[2, 2, 2],        # 注意力层输出维度与输入维度的比例,默认为[2, 2, 2]
        initializer_range=0.02,           # 初始化所有权重矩阵的截断正态分布标准差,默认为0.02
        **kwargs,                         # 其他参数,使用关键字参数方式接收
    ):
        # 调用父类的初始化方法,传入所有的关键字参数
        super().__init__(**kwargs)
        # 设置图像大小
        self.image_size = image_size
        # 设置通道数
        self.num_channels = num_channels
        # 设置卷积核大小
        self.kernel_size = kernel_size
        # 设置步长
        self.stride = stride
        # 设置填充
        self.padding = padding
        # 设置隐藏层大小
        self.hidden_sizes = hidden_sizes
        # 设置注意力头数目
        self.num_attention_heads = num_attention_heads
        # 设置深度
        self.depths = depths
        # 设置键的维度
        self.key_dim = key_dim
        # 设置丢弃路径的比率
        self.drop_path_rate = drop_path_rate
        # 设置补丁大小
        self.patch_size = patch_size
        # 设置注意力比率
        self.attention_ratio = attention_ratio
        # 设置MLP比率
        self.mlp_ratio = mlp_ratio
        # 设置初始化器范围
        self.initializer_range = initializer_range
        # 设置下采样操作列表
        self.down_ops = [
            # 第一个下采样操作
            ["Subsample", key_dim[0], hidden_sizes[0] // key_dim[0], 4, 2, 2],
            # 第二个下采样操作
            ["Subsample", key_dim[0], hidden_sizes[1] // key_dim[0], 4, 2, 2],
        ]
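一个简单的实例化示例(非源码,仅作演示),展示默认配置以及由 key_dim 和 hidden_sizes 推导出的 down_ops:

from transformers import LevitConfig

config = LevitConfig()
print(config.hidden_sizes)   # [128, 256, 384]
print(config.down_ops[0])    # ['Subsample', 16, 8, 4, 2, 2],其中 8 = hidden_sizes[0] // key_dim[0]

# 也可以覆盖部分参数
custom = LevitConfig(hidden_sizes=[192, 288, 384], num_attention_heads=[3, 5, 6], key_dim=[32, 32, 32])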
# 从transformers.models.vit.configuration_vit.ViTOnnxConfig复制而来的LevitOnnxConfig类
class LevitOnnxConfig(OnnxConfig):
    # 定义torch_onnx_minimum_version属性为1.11版本
    torch_onnx_minimum_version = version.parse("1.11")

    # 定义inputs属性为一个OrderedDict,包含映射关系
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                # 输入映射,将输入通道名称映射到索引位置
                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
            ]
        )

    # 定义atol_for_validation属性为浮点数1e-4,用于验证时的绝对误差容忍度
    @property
    def atol_for_validation(self) -> float:
        return 1e-4
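一个说明性示例(非源码,仅作演示),查看该 ONNX 配置声明的输入轴与验证容差:

from transformers import LevitConfig
from transformers.models.levit.configuration_levit import LevitOnnxConfig

onnx_config = LevitOnnxConfig(LevitConfig())
print(onnx_config.inputs)
# OrderedDict([('pixel_values', {0: 'batch', 1: 'num_channels', 2: 'height', 3: 'width'})])
print(onnx_config.atol_for_validation)  # 0.0001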

.\models\levit\convert_levit_timm_to_pytorch.py

# 设置编码格式为 UTF-8
# 版权声明,这段代码由 HuggingFace Inc. 团队版权所有,遵循 Apache License, Version 2.0 授权
#
# 根据许可证规定,除非符合许可证的条件,否则不得使用此文件
# 可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件按"原样"分发,
# 没有任何形式的担保或条件,无论是明示的还是默示的。
# 详见许可证了解更多信息。
"""从 timm 转换 LeViT 检查点。"""

# 导入必要的库和模块
import argparse  # 用于解析命令行参数
import json  # 用于处理 JSON 格式数据
from collections import OrderedDict  # 有序字典,用于按照插入顺序存储键值对
from functools import partial  # 用于创建偏函数
from pathlib import Path  # 处理文件路径的类库

import timm  # 提供预训练模型的创建和管理
import torch  # PyTorch 深度学习框架
from huggingface_hub import hf_hub_download  # 用于从 HuggingFace Hub 下载模型和文件

from transformers import LevitConfig, LevitForImageClassificationWithTeacher, LevitImageProcessor  # LeViT 模型相关类
from transformers.utils import logging  # 日志记录模块

# 设置日志输出为 info 级别
logging.set_verbosity_info()
logger = logging.get_logger()

# 定义函数:转换权重并推送到 Hub
def convert_weight_and_push(
    hidden_sizes: int, name: str, config: LevitConfig, save_directory: Path, push_to_hub: bool = True
):
    print(f"Converting {name}...")

    # 禁用梯度计算
    with torch.no_grad():
        # 根据不同的 hidden_sizes 加载不同的 LeViT 模型
        if hidden_sizes == 128:
            if name[-1] == "S":
                from_model = timm.create_model("levit_128s", pretrained=True)
            else:
                from_model = timm.create_model("levit_128", pretrained=True)
        elif hidden_sizes == 192:
            from_model = timm.create_model("levit_192", pretrained=True)
        elif hidden_sizes == 256:
            from_model = timm.create_model("levit_256", pretrained=True)
        elif hidden_sizes == 384:
            from_model = timm.create_model("levit_384", pretrained=True)

        # 设置模型为评估模式
        from_model.eval()
        our_model = LevitForImageClassificationWithTeacher(config).eval()
        huggingface_weights = OrderedDict()

        # 获取源模型的权重,并根据键的映射将其赋给新模型
        weights = from_model.state_dict()
        og_keys = list(from_model.state_dict().keys())
        new_keys = list(our_model.state_dict().keys())
        print(len(og_keys), len(new_keys))
        for i in range(len(og_keys)):
            huggingface_weights[new_keys[i]] = weights[og_keys[i]]
        our_model.load_state_dict(huggingface_weights)

        # 创建随机输入张量并计算两个模型的输出结果
        x = torch.randn((2, 3, 224, 224))
        out1 = from_model(x)
        out2 = our_model(x).logits

    # 检查两个模型输出是否相等
    assert torch.allclose(out1, out2), "The model logits don't match the original one."

    # 设置检查点名称为模型名称
    checkpoint_name = name
    print(checkpoint_name)

    # 如果指定推送到 Hub,则保存模型和相关处理器,并输出推送成功信息
    if push_to_hub:
        our_model.save_pretrained(save_directory / checkpoint_name)
        image_processor = LevitImageProcessor()
        image_processor.save_pretrained(save_directory / checkpoint_name)

        print(f"Pushed {checkpoint_name}")

# 定义函数:转换权重并推送到 Hub
def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
    filename = "imagenet-1k-id2label.json"
    num_labels = 1000
    # 预期模型输出的形状为 (1, num_labels)
    expected_shape = (1, num_labels)

    # 定义用于下载模型配置的 Hugging Face 仓库 ID
    repo_id = "huggingface/label-files"
    # 将 num_labels 赋值给变量 num_labels
    num_labels = num_labels
    # 使用 Hugging Face Hub 下载指定文件名的数据集,并加载为 JSON 格式
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    # 将 id2label 中的键转换为整数类型,并保留原始值
    id2label = {int(k): v for k, v in id2label.items()}

    # 将 id2label 赋值给变量 id2label
    id2label = id2label
    # 创建一个将标签映射到 ID 的字典
    label2id = {v: k for k, v in id2label.items()}

    # 定义一个部分应用的函数,使用 ImageNet 预训练配置创建 LevitConfig
    ImageNetPreTrainedConfig = partial(LevitConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)

    # 定义不同 Levit 模型名称到隐藏层大小的映射关系
    names_to_hidden_sizes = {
        "levit-128S": 128,
        "levit-128": 128,
        "levit-192": 192,
        "levit-256": 256,
        "levit-384": 384,
    }

    # 定义不同 Levit 模型名称到其配置对象的映射关系
    names_to_config = {
        "levit-128S": ImageNetPreTrainedConfig(
            hidden_sizes=[128, 256, 384],
            num_attention_heads=[4, 6, 8],
            depths=[2, 3, 4],
            key_dim=[16, 16, 16],
            drop_path_rate=0,
        ),
        "levit-128": ImageNetPreTrainedConfig(
            hidden_sizes=[128, 256, 384],
            num_attention_heads=[4, 8, 12],
            depths=[4, 4, 4],
            key_dim=[16, 16, 16],
            drop_path_rate=0,
        ),
        "levit-192": ImageNetPreTrainedConfig(
            hidden_sizes=[192, 288, 384],
            num_attention_heads=[3, 5, 6],
            depths=[4, 4, 4],
            key_dim=[32, 32, 32],
            drop_path_rate=0,
        ),
        "levit-256": ImageNetPreTrainedConfig(
            hidden_sizes=[256, 384, 512],
            num_attention_heads=[4, 6, 8],
            depths=[4, 4, 4],
            key_dim=[32, 32, 32],
            drop_path_rate=0,
        ),
        "levit-384": ImageNetPreTrainedConfig(
            hidden_sizes=[384, 512, 768],
            num_attention_heads=[6, 9, 12],
            depths=[4, 4, 4],
            key_dim=[32, 32, 32],
            drop_path_rate=0.1,
        ),
    }

    # 如果给定了模型名称,则转换权重并推送到指定的 Hub
    if model_name:
        convert_weight_and_push(
            names_to_hidden_sizes[model_name], model_name, names_to_config[model_name], save_directory, push_to_hub
        )
    else:  # 否则对所有模型进行转换权重并推送操作
        for model_name, config in names_to_config.items():
            convert_weight_and_push(names_to_hidden_sizes[model_name], model_name, config, save_directory, push_to_hub)
    
    # 返回最终的配置对象和预期的输出形状
    return config, expected_shape
if __name__ == "__main__":
    # 如果当前脚本作为主程序执行,则执行以下代码块

    parser = argparse.ArgumentParser()
    # 创建命令行参数解析器对象

    # 必选参数
    parser.add_argument(
        "--model_name",
        default=None,
        type=str,
        help="The name of the model you wish to convert, it must be one of the supported Levit* architecture,",
    )
    # 添加模型名称参数,指定需要转换的模型名称,必须是支持的 Levit* 架构之一

    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="levit-dump-folder/",
        type=Path,
        required=False,
        help="Path to the output PyTorch model directory.",
    )
    # 添加 PyTorch 模型输出文件夹路径参数,默认为 'levit-dump-folder/',指定输出 PyTorch 模型的目录路径

    parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")
    # 添加推送到 Hub 的选项参数,如果设置该参数,则推送模型和图像处理器到 Hub

    parser.add_argument(
        "--no-push_to_hub",
        dest="push_to_hub",
        action="store_false",
        help="Do not push model and image processor to the hub",
    )
    # 添加不推送到 Hub 的选项参数,设置该参数则不将模型和图像处理器推送到 Hub

    args = parser.parse_args()
    # 解析命令行参数,将参数存储在 args 对象中

    pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path
    # 获取 PyTorch 模型输出文件夹路径,并将其赋值给 pytorch_dump_folder_path 变量
    pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True)
    # 创建 PyTorch 模型输出文件夹,如果不存在则创建,确保存在父文件夹路径

    convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub)
    # 调用函数,将权重转换并推送到指定的 PyTorch 模型文件夹路径,使用指定的模型名称和推送到 Hub 的标志

.\models\levit\feature_extraction_levit.py

# coding=utf-8
# 版权 2022 年 Meta Platforms, Inc. 和 HuggingFace Inc. 团队。保留所有权利。
#
# 根据 Apache 许可证 2.0 版本(“许可证”)许可;
# 除非符合许可证的规定,否则您不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据“许可证”分发的软件是基于“按原样提供”的基础上分发的,
# 没有任何明示或暗示的保证或条件。
# 有关具体语言的详细信息,请参阅许可证。
"""LeViT 的特征提取器类。"""

# 导入警告模块
import warnings

# 从当前包的 utils 模块中导入 logging 功能
from ...utils import logging
# 从本地模块中导入 LevitImageProcessor 类
from .image_processing_levit import LevitImageProcessor

# 获取 logger 对象
logger = logging.get_logger(__name__)

# 定义 LevitFeatureExtractor 类,继承自 LevitImageProcessor 类
class LevitFeatureExtractor(LevitImageProcessor):
    
    # 初始化方法,接受任意位置参数和关键字参数,并发出警告
    def __init__(self, *args, **kwargs) -> None:
        # 发出警告,提醒 LevitFeatureExtractor 类即将在 Transformers 的第五个版本中被移除,建议使用 LevitImageProcessor 替代
        warnings.warn(
            "The class LevitFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
            " use LevitImageProcessor instead.",
            FutureWarning,
        )
        # 调用父类(LevitImageProcessor)的初始化方法
        super().__init__(*args, **kwargs)

.\models\levit\image_processing_levit.py

# 引入必要的库和模块
from typing import Dict, Iterable, Optional, Union

import numpy as np  # 导入 NumPy 库

# 导入图像处理相关的工具函数和类
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
    get_resize_output_image_size,
    resize,
    to_channel_dimension_format,
)
from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,  # 导入图像处理的常量,如默认均值和标准差
    IMAGENET_DEFAULT_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_kwargs,
    validate_preprocess_arguments,
)
from ...utils import TensorType, logging  # 导入工具函数和日志记录器

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)


class LevitImageProcessor(BaseImageProcessor):
    r"""
    Constructs a LeViT image processor.
    """
    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            是否调整输入图像的最短边至 int(256/224 * size),可以在 `preprocess` 方法中的 `do_resize` 参数中覆盖。
        size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 224}`):
            调整后的输出图像尺寸。如果 `size` 是一个包含 "width" 和 "height" 键的字典,图像将被调整至 `(size["height"], size["width"])`。如果 `size` 是一个包含 "shortest_edge" 键的字典,最短边的值 `c` 将被重新缩放为 `int(c * (256/224))`。图像的较小边将被匹配到此值,例如,如果 height > width,则图像将被缩放至 `(size["shortest_edge"] * height / width, size["shortest_edge"])`。可以在 `preprocess` 方法中的 `size` 参数中覆盖。
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            如果调整图像大小,使用的重采样滤波器。可以在 `preprocess` 方法中的 `resample` 参数中覆盖。
        do_center_crop (`bool`, *optional*, defaults to `True`):
            是否对输入图像进行中心裁剪至 `(crop_size["height"], crop_size["width"])`。可以在 `preprocess` 方法中的 `do_center_crop` 参数中覆盖。
        crop_size (`Dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
            `center_crop` 后的期望图像尺寸。可以在 `preprocess` 方法中的 `crop_size` 参数中覆盖。
        do_rescale (`bool`, *optional*, defaults to `True`):
            控制是否按指定的比例因子 `rescale_factor` 重新缩放图像。可以在 `preprocess` 方法中的 `do_rescale` 参数中覆盖。
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            如果重新缩放图像,要使用的缩放因子。可以在 `preprocess` 方法中的 `rescale_factor` 参数中覆盖。
        do_normalize (`bool`, *optional*, defaults to `True`):
            控制是否对图像进行归一化。可以在 `preprocess` 方法中的 `do_normalize` 参数中覆盖。
        image_mean (`List[int]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
            如果归一化图像,要使用的均值。这是一个浮点数或与图像通道数相同长度的浮点数列表。可以在 `preprocess` 方法中的 `image_mean` 参数中覆盖。
        image_std (`List[int]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
            如果归一化图像,要使用的标准差。这是一个浮点数或与图像通道数相同长度的浮点数列表。可以在 `preprocess` 方法中的 `image_std` 参数中覆盖。
    """

    model_input_names = ["pixel_values"]
    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, Iterable[float]]] = IMAGENET_DEFAULT_MEAN,
        image_std: Optional[Union[float, Iterable[float]]] = IMAGENET_DEFAULT_STD,
        **kwargs,
    ) -> None:
        # 调用父类初始化方法
        super().__init__(**kwargs)
        # 如果 size 参数为 None,则设置默认的最短边为 224
        size = size if size is not None else {"shortest_edge": 224}
        # 根据给定的 size 参数获取大小的字典,确保不会默认为正方形
        size = get_size_dict(size, default_to_square=False)
        # 如果 crop_size 参数为 None,则设置默认的高度和宽度均为 224
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        # 根据给定的 crop_size 参数获取裁剪大小的字典
        crop_size = get_size_dict(crop_size, param_name="crop_size")

        # 初始化类成员变量
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
        # 设置有效的处理器关键字列表,包括图像处理相关的参数和数据格式参数
        self._valid_processor_keys = [
            "images",
            "do_resize",
            "size",
            "resample",
            "do_center_crop",
            "crop_size",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "image_mean",
            "image_std",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image.

        If size is a dict with keys "width" and "height", the image will be resized to `(size["height"],
        size["width"])`.

        If size is a dict with key "shortest_edge", the shortest edge value `c` is rescaled to `int(c * (256/224))`.
        The smaller edge of the image will be matched to this value, i.e., if height > width, the image will be rescaled
        to `(size["shortest_edge"] * height / width, size["shortest_edge"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image after resizing. If size is a dict with keys "width" and "height", the image
                will be resized to (height, width). If size is a dict with key "shortest_edge", the shortest edge value
                `c` is rescaled to int(`c` * (256/224)). The smaller edge of the image will be matched to this value
                i.e, if height > width, then image will be rescaled to (size * height / width, size).
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        # Determine the target size dictionary based on input size parameters
        size_dict = get_size_dict(size, default_to_square=False)

        # Check if 'shortest_edge' is specified in size dictionary
        if "shortest_edge" in size:
            # Calculate the length of the shortest edge based on the scaling factor (256/224)
            shortest_edge = int((256 / 224) * size["shortest_edge"])
            # Determine the output size after resizing based on the calculated shortest edge
            output_size = get_resize_output_image_size(
                image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format
            )
            # Update size_dict to reflect the height and width after resizing
            size_dict = {"height": output_size[0], "width": output_size[1]}

        # Ensure that the size_dict contains both 'height' and 'width' keys
        if "height" not in size_dict or "width" not in size_dict:
            # Raise an error if the size_dict does not have the required keys
            raise ValueError(
                f"Size dict must have keys 'height' and 'width' or 'shortest_edge'. Got {size_dict.keys()}"
            )

        # Resize the image to the specified dimensions using the resize function
        return resize(
            image,
            size=(size_dict["height"], size_dict["width"]),
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
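
# Standalone illustration (not part of the library source) of the "shortest_edge"
# branch documented above: the requested edge is first scaled by 256/224, then the
# shorter image side is matched to it while preserving the aspect ratio. Plain
# Python; the helper name below is made up for the demo.
def _demo_levit_output_size(height: int, width: int, shortest_edge: int = 224) -> tuple:
    target = int((256 / 224) * shortest_edge)  # e.g. 224 -> 256
    if height >= width:  # width is the shorter side
        return (int(target * height / width), target)
    return (target, int(target * width / height))  # height is the shorter side


print(_demo_levit_output_size(480, 640))  # (256, 341); later center-cropped to 224x224
print(_demo_levit_output_size(640, 480))  # (341, 256)
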
    # 定义一个预处理方法,用于处理图像数据
    def preprocess(
        self,
        images: ImageInput,  # 图像输入,可以是单张图像或图像列表
        do_resize: Optional[bool] = None,  # 是否调整大小的标志,默认为None
        size: Optional[Dict[str, int]] = None,  # 调整大小的目标尺寸,字典形式,包含宽和高
        resample: PILImageResampling = None,  # 调整大小时使用的重采样方法,默认为None
        do_center_crop: Optional[bool] = None,  # 是否进行中心裁剪的标志,默认为None
        crop_size: Optional[Dict[str, int]] = None,  # 中心裁剪的目标尺寸,字典形式,包含宽和高
        do_rescale: Optional[bool] = None,  # 是否进行重新缩放的标志,默认为None
        rescale_factor: Optional[float] = None,  # 重新缩放的因子,默认为None
        do_normalize: Optional[bool] = None,  # 是否进行归一化的标志,默认为None
        image_mean: Optional[Union[float, Iterable[float]]] = None,  # 图像归一化的均值,默认为None
        image_std: Optional[Union[float, Iterable[float]]] = None,  # 图像归一化的标准差,默认为None
        return_tensors: Optional[TensorType] = None,  # 返回的张量类型,默认为None
        data_format: ChannelDimension = ChannelDimension.FIRST,  # 数据的通道格式,默认为第一通道
        input_data_format: Optional[Union[str, ChannelDimension]] = None,  # 输入数据的通道格式,默认为None
        **kwargs,  # 其他可能的关键字参数,以字典形式接收

.\models\levit\modeling_levit.py

# coding=utf-8
# 声明文件编码为 UTF-8

# 版权声明及许可证信息

# 导入所需的库和模块
import itertools
from dataclasses import dataclass
from typing import Optional, Tuple, Union

# 导入 PyTorch 相关模块
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# 导入模型输出类
from ...modeling_outputs import (
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPoolingAndNoAttention,
    ImageClassifierOutputWithNoAttention,
    ModelOutput,
)

# 导入预训练模型基类
from ...modeling_utils import PreTrainedModel

# 导入工具函数和日志记录
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging

# 导入 LevitConfig 配置类
from .configuration_levit import LevitConfig

# 获取 logger 实例
logger = logging.get_logger(__name__)

# 模型配置文档字符串
_CONFIG_FOR_DOC = "LevitConfig"

# 模型检查点文档字符串
_CHECKPOINT_FOR_DOC = "facebook/levit-128S"

# 预期输出形状
_EXPECTED_OUTPUT_SHAPE = [1, 16, 384]

# 图像分类模型检查点
_IMAGE_CLASS_CHECKPOINT = "facebook/levit-128S"

# 图像分类预期输出示例
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"

# 预训练模型存档列表
LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/levit-128S",
    # 更多 Levit 模型可在 https://huggingface.co/models?filter=levit 查看
]

# LevitForImageClassificationWithTeacherOutput 类,继承自 ModelOutput 类
@dataclass
class LevitForImageClassificationWithTeacherOutput(ModelOutput):
    """
    [`LevitForImageClassificationWithTeacher`] 的输出类型。
    """

    # logits参数是一个形状为(batch_size, config.num_labels)的张量,包含了预测分数,
    # 这些分数是cls_logits和distillation_logits的平均值。
    # cls_logits是分类头部的预测分数,即在类标记的最终隐藏状态之上的线性层。
    # distillation_logits是蒸馏头部的预测分数,即在蒸馏标记的最终隐藏状态之上的线性层。
    # hidden_states参数是一个可选的元组,包含了模型每一层的隐藏状态张量,
    # 形状为(batch_size, sequence_length, hidden_size),包括初始嵌入输出。
    
    logits: torch.FloatTensor = None
    cls_logits: torch.FloatTensor = None
    distillation_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class LevitConvEmbeddings(nn.Module):
    """
    LeViT Conv Embeddings with Batch Norm, used in the initial patch embedding layer.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bn_weight_init=1
    ):
        super().__init__()
        # 定义卷积层,用于将输入的图像数据转换成特征图
        self.convolution = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=False
        )
        # 定义批归一化层,用于规范化卷积输出的特征图
        self.batch_norm = nn.BatchNorm2d(out_channels)

    def forward(self, embeddings):
        # 执行卷积操作,将输入的嵌入数据转换成特征图
        embeddings = self.convolution(embeddings)
        # 执行批归一化操作,规范化卷积输出的特征图
        embeddings = self.batch_norm(embeddings)
        return embeddings


class LevitPatchEmbeddings(nn.Module):
    """
    LeViT patch embeddings, for final embeddings to be passed to transformer blocks. It consists of multiple
    `LevitConvEmbeddings`.
    """

    def __init__(self, config):
        super().__init__()
        # 第一个卷积嵌入层及其激活函数
        self.embedding_layer_1 = LevitConvEmbeddings(
            config.num_channels, config.hidden_sizes[0] // 8, config.kernel_size, config.stride, config.padding
        )
        self.activation_layer_1 = nn.Hardswish()

        # 第二个卷积嵌入层及其激活函数
        self.embedding_layer_2 = LevitConvEmbeddings(
            config.hidden_sizes[0] // 8, config.hidden_sizes[0] // 4, config.kernel_size, config.stride, config.padding
        )
        self.activation_layer_2 = nn.Hardswish()

        # 第三个卷积嵌入层及其激活函数
        self.embedding_layer_3 = LevitConvEmbeddings(
            config.hidden_sizes[0] // 4, config.hidden_sizes[0] // 2, config.kernel_size, config.stride, config.padding
        )
        self.activation_layer_3 = nn.Hardswish()

        # 第四个卷积嵌入层,不带激活函数
        self.embedding_layer_4 = LevitConvEmbeddings(
            config.hidden_sizes[0] // 2, config.hidden_sizes[0], config.kernel_size, config.stride, config.padding
        )
        self.num_channels = config.num_channels

    def forward(self, pixel_values):
        # 检查输入的像素值张量是否与配置中设置的通道数匹配
        num_channels = pixel_values.shape[1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # 依次执行四个卷积嵌入层及其后的激活函数
        embeddings = self.embedding_layer_1(pixel_values)
        embeddings = self.activation_layer_1(embeddings)
        embeddings = self.embedding_layer_2(embeddings)
        embeddings = self.activation_layer_2(embeddings)
        embeddings = self.embedding_layer_3(embeddings)
        embeddings = self.activation_layer_3(embeddings)
        embeddings = self.embedding_layer_4(embeddings)
        # 将结果展平并转置,以便传递给变压器块
        return embeddings.flatten(2).transpose(1, 2)
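
# Rough shape walk-through (illustrative only) for the four LevitConvEmbeddings above,
# assuming the standard LeViT defaults: 224x224 input, kernel_size=3, stride=2, padding=1.
# Each convolution halves the spatial resolution, so the transformer stages receive
# 14 * 14 = 196 tokens.
height = width = 224
for _ in range(4):  # four stride-2 convolutions
    height = (height + 2 * 1 - 3) // 2 + 1  # standard conv output-size formula
    width = (width + 2 * 1 - 3) // 2 + 1
print(height, width, height * width)  # 14 14 196
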


class MLPLayerWithBN(nn.Module):
    """
    MLP layer with Batch Norm, used in the transformer blocks.
    """

    def __init__(self, input_dim, output_dim, bn_weight_init=1):
        super().__init__()
        # 定义线性层,用于进行全连接操作
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim, bias=False)
        # 定义批归一化层,用于规范化全连接层的输出
        self.batch_norm = nn.BatchNorm1d(output_dim)
    # 定义前向传播方法,接收隐藏状态作为输入参数
    def forward(self, hidden_state):
        # 将隐藏状态通过线性层进行变换
        hidden_state = self.linear(hidden_state)
        # 将变换后的隐藏状态展平并应用批归一化处理
        hidden_state = self.batch_norm(hidden_state.flatten(0, 1)).reshape_as(hidden_state)
        # 返回处理后的隐藏状态
        return hidden_state
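
# Minimal sketch (assuming torch is installed) of the flatten/reshape trick above:
# nn.BatchNorm1d normalizes the channel dimension of a (N, C) tensor, so the
# (batch, seq_len, channels) activations are flattened to (batch * seq_len, channels)
# for normalization and then reshaped back. Shapes below are arbitrary demo values.
import torch
from torch import nn

linear = nn.Linear(16, 32, bias=False)
batch_norm = nn.BatchNorm1d(32)
hidden = torch.randn(2, 49, 16)                               # (batch, tokens, channels)
hidden = linear(hidden)
hidden = batch_norm(hidden.flatten(0, 1)).reshape_as(hidden)  # BN over 2*49 rows, back to (2, 49, 32)
print(hidden.shape)                                           # torch.Size([2, 49, 32])
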
# 定义一个名为 LevitSubsample 的自定义神经网络模块,继承自 nn.Module
class LevitSubsample(nn.Module):
    # 初始化函数,接受步长(stride)和分辨率(resolution)两个参数
    def __init__(self, stride, resolution):
        super().__init__()
        # 设置对象属性 stride 和 resolution
        self.stride = stride
        self.resolution = resolution

    # 前向传播函数,接受隐藏状态(hidden_state)作为输入
    def forward(self, hidden_state):
        # 获取输入张量的批量大小(batch_size)、通道数(channels)
        batch_size, _, channels = hidden_state.shape
        # 将隐藏状态重新视图化为指定分辨率的形状,并进行下采样
        hidden_state = hidden_state.view(batch_size, self.resolution, self.resolution, channels)[
            :, :: self.stride, :: self.stride
        ].reshape(batch_size, -1, channels)
        # 返回下采样后的隐藏状态张量
        return hidden_state
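
# Quick standalone check of how many tokens survive the strided slicing above: a
# resolution x resolution grid sliced with [::stride] keeps (resolution - 1) // stride + 1
# positions per axis, which is exactly the resolution_out formula used later in LevitStage.
resolution, stride = 14, 2
kept = len(range(resolution)[::stride])
print(kept, (resolution - 1) // stride + 1)  # 7 7
print(kept * kept)                           # 49 tokens remain after subsampling
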


# 定义一个名为 LevitAttention 的自定义神经网络模块,继承自 nn.Module
class LevitAttention(nn.Module):
    # 初始化函数,接受隐藏层大小(hidden_sizes)、键维度(key_dim)、注意力头数(num_attention_heads)、
    # 注意力比率(attention_ratio)、分辨率(resolution)五个参数
    def __init__(self, hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution):
        super().__init__()
        # 设置对象属性
        self.num_attention_heads = num_attention_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.attention_ratio = attention_ratio
        # 计算键-值对和投影输出的维度
        self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads * 2
        self.out_dim_projection = attention_ratio * key_dim * num_attention_heads

        # 创建查询、键和值的 MLP 层并进行批归一化
        self.queries_keys_values = MLPLayerWithBN(hidden_sizes, self.out_dim_keys_values)
        # 激活函数采用 Hardswish
        self.activation = nn.Hardswish()
        # 创建投影层的 MLP 层并进行批归一化
        self.projection = MLPLayerWithBN(self.out_dim_projection, hidden_sizes, bn_weight_init=0)

        # 生成所有可能点的笛卡尔积
        points = list(itertools.product(range(resolution), range(resolution)))
        len_points = len(points)
        attention_offsets, indices = {}, []

        # 计算所有点对之间的偏移量及其对应的索引
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                indices.append(attention_offsets[offset])

        # 初始化注意力偏置的缓存和参数
        self.attention_bias_cache = {}
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
        self.register_buffer(
            "attention_bias_idxs", torch.LongTensor(indices).view(len_points, len_points), persistent=False
        )

    # 用于训练时无梯度更新的装饰器函数
    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        # 如果是训练模式且存在注意力偏置缓存,则清空缓存
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # 清空注意力偏置缓存

    # 获取注意力偏置的函数,根据设备类型缓存不同的注意力偏置
    def get_attention_biases(self, device):
        if self.training:
            # 如果是训练模式,则直接返回计算得到的注意力偏置
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            # 如果是推断模式,则根据设备类型缓存注意力偏置
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]
    # 定义一个前向传播函数,接受隐藏状态作为输入
    def forward(self, hidden_state):
        # 获取输入隐藏状态的批大小、序列长度和特征维度
        batch_size, seq_length, _ = hidden_state.shape
        # 使用self.queries_keys_values方法计算查询、键和值
        queries_keys_values = self.queries_keys_values(hidden_state)
        # 将查询、键、值重新组织成指定形状,以便进行多头注意力计算
        query, key, value = queries_keys_values.view(batch_size, seq_length, self.num_attention_heads, -1).split(
            [self.key_dim, self.key_dim, self.attention_ratio * self.key_dim], dim=3
        )
        # 将查询张量转置,以适应多头注意力计算的形状要求
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)

        # 计算注意力分数,包括缩放和注意力偏置
        attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
        # 对注意力分数进行 softmax 归一化
        attention = attention.softmax(dim=-1)
        # 计算加权后的值向量,然后重新排列以恢复原始形状
        hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, seq_length, self.out_dim_projection)
        # 应用激活函数、投影层和最终投影,得到最终的隐藏状态表示
        hidden_state = self.projection(self.activation(hidden_state))
        # 返回处理后的隐藏状态
        return hidden_state
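
# Tiny standalone reproduction (illustrative only) of the attention-bias bucketing built
# in LevitAttention.__init__ above: every pair of grid points is keyed by its absolute
# (dx, dy) offset, so the learned bias table needs one entry per distinct offset rather
# than one per point pair. Uses resolution=4 purely as a demo value.
import itertools

resolution = 4
points = list(itertools.product(range(resolution), range(resolution)))
offsets, indices = {}, []
for p1 in points:
    for p2 in points:
        off = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
        indices.append(offsets.setdefault(off, len(offsets)))
print(len(points) ** 2)  # 256 point pairs to cover
print(len(offsets))      # 16 distinct offsets -> 16 learned bias entries per head
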
class LevitAttentionSubsample(nn.Module):
    # LevitAttentionSubsample 类,继承自 nn.Module
    def __init__(
        self,
        input_dim,
        output_dim,
        key_dim,
        num_attention_heads,
        attention_ratio,
        stride,
        resolution_in,
        resolution_out,
    ):
        # 初始化函数,设置模块参数
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.scale = key_dim**-0.5  # 缩放因子,用于缩放注意力机制中的键值
        self.key_dim = key_dim  # 注意力键的维度
        self.attention_ratio = attention_ratio  # 注意力比率
        self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads
        self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
        self.resolution_out = resolution_out
        # resolution_in 是初始分辨率,resolution_out 是下采样后的最终分辨率

        # 初始化模块:MLPLayerWithBN 是带批量归一化的 MLP 层
        self.keys_values = MLPLayerWithBN(input_dim, self.out_dim_keys_values)
        self.queries_subsample = LevitSubsample(stride, resolution_in)  # 对查询进行下采样
        self.queries = MLPLayerWithBN(input_dim, key_dim * num_attention_heads)  # 查询的 MLP 层
        self.activation = nn.Hardswish()  # 激活函数使用 Hardswish
        self.projection = MLPLayerWithBN(self.out_dim_projection, output_dim)  # 投影到最终输出维度

        self.attention_bias_cache = {}  # 初始化注意力偏置缓存

        # 计算注意力偏置的索引
        points = list(itertools.product(range(resolution_in), range(resolution_in)))
        points_ = list(itertools.product(range(resolution_out), range(resolution_out)))
        len_points, len_points_ = len(points), len(points_)
        attention_offsets, indices = {}, []
        for p1 in points_:
            for p2 in points:
                size = 1
                offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), abs(p1[1] * stride - p2[1] + (size - 1) / 2))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                indices.append(attention_offsets[offset])

        # 初始化注意力偏置参数和索引
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
        self.register_buffer(
            "attention_bias_idxs", torch.LongTensor(indices).view(len_points_, len_points), persistent=False
        )

    @torch.no_grad()
    def train(self, mode=True):
        # 重写父类的 train 方法,并设置为不需要梯度
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # 如果是训练模式且注意力偏置缓存不为空,则清空缓存

    def get_attention_biases(self, device):
        # 获取注意力偏置方法
        if self.training:
            return self.attention_biases[:, self.attention_bias_idxs]  # 如果是训练模式,直接返回注意力偏置
        else:
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
                # 如果设备键不在缓存中,则将注意力偏置缓存到设备键
            return self.attention_bias_cache[device_key]  # 返回设备键对应的注意力偏置
    # 定义前向传播方法,接收隐藏状态作为输入
    def forward(self, hidden_state):
        # 获取输入张量的批量大小、序列长度和特征维度
        batch_size, seq_length, _ = hidden_state.shape
        
        # 使用 self.keys_values 方法生成键和值,然后重新组织张量形状
        key, value = (
            self.keys_values(hidden_state)
            .view(batch_size, seq_length, self.num_attention_heads, -1)
            .split([self.key_dim, self.attention_ratio * self.key_dim], dim=3)
        )
        
        # 对键张量进行维度重排,以便后续计算注意力
        key = key.permute(0, 2, 1, 3)
        
        # 对值张量进行维度重排,以便后续计算注意力
        value = value.permute(0, 2, 1, 3)

        # 使用 self.queries_subsample 方法对隐藏状态进行查询抽样
        query = self.queries(self.queries_subsample(hidden_state))
        
        # 重新组织查询张量的形状,以便后续计算注意力
        query = query.view(batch_size, self.resolution_out**2, self.num_attention_heads, self.key_dim).permute(
            0, 2, 1, 3
        )

        # 计算注意力分数,包括缩放、添加注意力偏置
        attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
        
        # 对注意力分数进行 softmax 归一化
        attention = attention.softmax(dim=-1)
        
        # 计算加权后的值张量,然后进行维度重排和形状调整
        hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, -1, self.out_dim_projection)
        
        # 对加权后的值张量应用激活函数和投影层
        hidden_state = self.projection(self.activation(hidden_state))
        
        # 返回处理后的隐藏状态张量
        return hidden_state
# 定义一个 LevitMLPLayer 类,继承自 nn.Module,用于实现 MLP 层,相比 ViT 只扩展 2 倍。
class LevitMLPLayer(nn.Module):
    """
    MLP Layer with `2X` expansion in contrast to ViT with `4X`.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # 使用带有批归一化的 MLPLayerWithBN 类来定义线性变换(升维)
        self.linear_up = MLPLayerWithBN(input_dim, hidden_dim)
        # 激活函数使用 Hardswish
        self.activation = nn.Hardswish()
        # 使用带有批归一化的 MLPLayerWithBN 类定义线性变换(降维)
        self.linear_down = MLPLayerWithBN(hidden_dim, input_dim)

    def forward(self, hidden_state):
        # 执行升维操作
        hidden_state = self.linear_up(hidden_state)
        # 应用激活函数 Hardswish
        hidden_state = self.activation(hidden_state)
        # 执行降维操作
        hidden_state = self.linear_down(hidden_state)
        return hidden_state


# 定义一个 LevitResidualLayer 类,继承自 nn.Module,用于实现 LeViT 的残差块。
class LevitResidualLayer(nn.Module):
    """
    Residual Block for LeViT
    """

    def __init__(self, module, drop_rate):
        super().__init__()
        # 保存作为残差的模块
        self.module = module
        # 设定丢弃率(dropout rate)
        self.drop_rate = drop_rate

    def forward(self, hidden_state):
        # 如果处于训练模式并且设置了丢弃率
        if self.training and self.drop_rate > 0:
            # 随机生成与隐藏状态维度相同的随机数张量,用于丢弃
            rnd = torch.rand(hidden_state.size(0), 1, 1, device=hidden_state.device)
            # 将随机数张量转换为掩码,根据丢弃率进行归一化
            rnd = rnd.ge_(self.drop_rate).div(1 - self.drop_rate).detach()
            # 计算残差块的输出,同时应用丢弃掩码
            hidden_state = hidden_state + self.module(hidden_state) * rnd
            return hidden_state
        else:
            # 计算残差块的输出
            hidden_state = hidden_state + self.module(hidden_state)
            return hidden_state
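
# Small sanity check (assuming torch) of the per-sample drop-path math above: each sample
# keeps its residual branch with probability 1 - drop_rate and is rescaled by
# 1 / (1 - drop_rate), so the mask is unbiased in expectation (mean close to 1).
import torch

drop_rate = 0.2
rnd = torch.rand(100_000, 1, 1)
mask = rnd.ge_(drop_rate).div(1 - drop_rate)
print(mask.mean().item())  # approximately 1.0
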


# 定义一个 LevitStage 类,继承自 nn.Module,表示 LeViT 模型中的一个阶段,包括 LevitMLPLayer 和 LevitAttention 层。
class LevitStage(nn.Module):
    """
    LeViT Stage consisting of `LevitMLPLayer` and `LevitAttention` layers.
    """

    def __init__(
        self,
        config,
        idx,
        hidden_sizes,
        key_dim,
        depths,
        num_attention_heads,
        attention_ratio,
        mlp_ratio,
        down_ops,
        resolution_in,
    ):
        # 调用父类的构造函数初始化对象
        super().__init__()
        # 初始化图层列表
        self.layers = []
        # 设置配置参数
        self.config = config
        # 设置初始分辨率和最终分辨率
        self.resolution_in = resolution_in
        # resolution_in 是初始分辨率,resolution_out 是经过降采样后的最终分辨率

        # 根据深度循环构建层对象
        for _ in range(depths):
            # 添加注意力机制层到层列表
            self.layers.append(
                LevitResidualLayer(
                    LevitAttention(hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution_in),
                    self.config.drop_path_rate,
                )
            )
            # 如果 mlp_ratio 大于 0,则构建 MLP 层并添加到层列表
            if mlp_ratio > 0:
                hidden_dim = hidden_sizes * mlp_ratio
                self.layers.append(
                    LevitResidualLayer(LevitMLPLayer(hidden_sizes, hidden_dim), self.config.drop_path_rate)
                )

        # 如果第一个 down_ops 是 "Subsample",则执行下采样操作
        if down_ops[0] == "Subsample":
            # 计算经过下采样后的分辨率 resolution_out
            self.resolution_out = (self.resolution_in - 1) // down_ops[5] + 1
            # 添加下采样注意力机制层到层列表
            self.layers.append(
                LevitAttentionSubsample(
                    *self.config.hidden_sizes[idx : idx + 2],
                    key_dim=down_ops[1],
                    num_attention_heads=down_ops[2],
                    attention_ratio=down_ops[3],
                    stride=down_ops[5],
                    resolution_in=resolution_in,
                    resolution_out=self.resolution_out,
                )
            )
            # 更新当前分辨率为下采样后的分辨率
            self.resolution_in = self.resolution_out
            # 如果 down_ops[4] 大于 0,则构建 MLP 层并添加到层列表
            if down_ops[4] > 0:
                hidden_dim = self.config.hidden_sizes[idx + 1] * down_ops[4]
                self.layers.append(
                    LevitResidualLayer(
                        LevitMLPLayer(self.config.hidden_sizes[idx + 1], hidden_dim), self.config.drop_path_rate
                    )
                )

        # 将层列表转换为 nn.ModuleList 对象
        self.layers = nn.ModuleList(self.layers)

    # 获取当前模型的分辨率
    def get_resolution(self):
        return self.resolution_in

    # 前向传播函数
    def forward(self, hidden_state):
        # 对每一层进行前向传播计算
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        # 返回最终的隐藏状态
        return hidden_state
class LevitEncoder(nn.Module):
    """
    LeViT Encoder consisting of multiple `LevitStage` stages.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config  # 初始化模型配置参数
        resolution = self.config.image_size // self.config.patch_size  # 计算分辨率
        self.stages = []  # 初始化阶段列表
        self.config.down_ops.append([""])  # 追加一个哨兵条目,使最后一个阶段不会进入 "Subsample" 下采样分支(有意为之,并非 bug)

        for stage_idx in range(len(config.depths)):  # 遍历每个阶段的深度
            stage = LevitStage(  # 创建LevitStage阶段实例
                config,
                stage_idx,
                config.hidden_sizes[stage_idx],
                config.key_dim[stage_idx],
                config.depths[stage_idx],
                config.num_attention_heads[stage_idx],
                config.attention_ratio[stage_idx],
                config.mlp_ratio[stage_idx],
                config.down_ops[stage_idx],
                resolution,
            )
            resolution = stage.get_resolution()  # 获取当前阶段的分辨率
            self.stages.append(stage)  # 将当前阶段添加到阶段列表中

        self.stages = nn.ModuleList(self.stages)  # 转换阶段列表为PyTorch的模块列表

    def forward(self, hidden_state, output_hidden_states=False, return_dict=True):
        all_hidden_states = () if output_hidden_states else None  # 初始化所有隐藏状态的元组或空值

        for stage in self.stages:  # 遍历所有阶段
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)  # 将当前隐藏状态添加到所有隐藏状态元组中
            hidden_state = stage(hidden_state)  # 将隐藏状态传递给当前阶段进行前向计算

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)  # 将最终隐藏状态添加到所有隐藏状态元组中
        if not return_dict:
            return tuple(v for v in [hidden_state, all_hidden_states] if v is not None)  # 如果不返回字典,则返回所有非空的值元组

        return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states)  # 返回基本模型输出对象


class LevitClassificationLayer(nn.Module):
    """
    LeViT Classification Layer
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(input_dim)  # 初始化批标准化层
        self.linear = nn.Linear(input_dim, output_dim)  # 初始化线性层

    def forward(self, hidden_state):
        hidden_state = self.batch_norm(hidden_state)  # 批标准化操作
        logits = self.linear(hidden_state)  # 计算输出logits
        return logits  # 返回logits


class LevitPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = LevitConfig  # 设置配置类为LevitConfig
    base_model_prefix = "levit"  # 基础模型前缀名
    main_input_name = "pixel_values"  # 主输入名称

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):  # 如果是线性层或卷积层
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)  # 使用正态分布初始化权重
            if module.bias is not None:
                module.bias.data.zero_()  # 如果存在偏置,则初始化为零
        elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):  # 如果是批标准化层
            module.bias.data.zero_()  # 初始化偏置为零
            module.weight.data.fill_(1.0)  # 初始化权重为1.0
# 定义 LevitModel 类,继承自 LevitPreTrainedModel,用于构建 Levit 模型
@add_start_docstrings(
    "The bare Levit model outputting raw features without any specific head on top.",  # 添加关于 Levit 模型的文档说明
    LEVIT_START_DOCSTRING,  # 添加 Levit 模型的配置参数说明和初始化信息
)
class LevitModel(LevitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)  # 调用父类 LevitPreTrainedModel 的初始化方法
        self.config = config  # 将传入的配置参数 config 存储为实例变量
        self.patch_embeddings = LevitPatchEmbeddings(config)  # 初始化图像的 patch embeddings
        self.encoder = LevitEncoder(config)  # 初始化 Levit 编码器
        # Initialize weights and apply final processing
        self.post_init()  # 调用自定义的 post_init 方法,用于初始化权重和应用最终处理

    @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)  # 添加前向传播函数的文档说明,包括输入参数
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,  # 添加代码示例的文档说明,显示如何使用模型
        output_type=BaseModelOutputWithPoolingAndNoAttention,  # 指定输出类型的文档说明
        config_class=_CONFIG_FOR_DOC,  # 指定模型配置类的文档说明
        modality="vision",  # 指明模型适用的领域为视觉
        expected_output=_EXPECTED_OUTPUT_SHAPE,  # 添加预期输出形状的文档说明
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,  # 输入参数 pixel_values,代表像素值的浮点张量
        output_hidden_states: Optional[bool] = None,  # 是否返回所有层的隐藏状态的布尔值参数
        return_dict: Optional[bool] = None,  # 是否返回 ModelOutput 对象而不是普通元组的布尔值参数
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
        # 定义函数签名,指定返回类型为元组或BaseModelOutputWithPoolingAndNoAttention类型

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果output_hidden_states不为None,则使用其值;否则使用self.config.output_hidden_states的值

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 如果return_dict不为None,则使用其值;否则使用self.config.use_return_dict的值

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        # 如果pixel_values为None,则抛出数值错误异常,要求指定pixel_values

        embeddings = self.patch_embeddings(pixel_values)
        # 将像素值转换为嵌入向量

        encoder_outputs = self.encoder(
            embeddings,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 使用编码器对嵌入向量进行编码,返回编码器的输出

        last_hidden_state = encoder_outputs[0]
        # 取编码器输出的第一个元素作为最终的隐藏状态表示

        # global average pooling, (batch_size, seq_length, hidden_sizes) -> (batch_size, hidden_sizes)
        pooled_output = last_hidden_state.mean(dim=1)
        # 对最终隐藏状态进行全局平均池化,将每个序列的隐藏状态平均到一个向量中

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
        # 如果return_dict为False,则返回元组形式的输出:最终隐藏状态、池化输出以及其余的编码器输出

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )
        # 如果return_dict为True,则使用BaseModelOutputWithPoolingAndNoAttention类封装最终隐藏状态、池化输出和所有隐藏状态的列表,并返回该对象
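
# Hedged usage sketch for LevitModel (not taken verbatim from the library docs): requires
# torch, Pillow and requests, plus network access to download the "facebook/levit-128S"
# checkpoint and the sample COCO image. The shapes in the comments follow the expected
# output shape noted earlier in this file.
from PIL import Image
import requests
import torch
from transformers import AutoImageProcessor, LevitModel

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
processor = AutoImageProcessor.from_pretrained("facebook/levit-128S")
model = LevitModel.from_pretrained("facebook/levit-128S")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 16, 384]) for levit-128S
print(outputs.pooler_output.shape)      # torch.Size([1, 384]) after mean pooling
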
# 定义一个 Levit 图像分类模型,基于 Levit 模型并添加一个分类器头部
@add_start_docstrings(
    """
    Levit Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    LEVIT_START_DOCSTRING,
)
class LevitForImageClassification(LevitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config  # 保存配置信息
        self.num_labels = config.num_labels  # 获取标签数量
        self.levit = LevitModel(config)  # 初始化基础的 Levit 模型

        # 分类器头部
        self.classifier = (
            # 如果标签数量大于 0,则创建 Levit 分类层;否则创建一个恒等映射
            LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
            if config.num_labels > 0
            else torch.nn.Identity()
        )

        # 初始化权重并进行最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # 如果 return_dict 不为 None,则使用 return_dict;否则使用 self.config.use_return_dict
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 self.levit 方法,传入像素值 pixel_values 和其他参数,获取模型输出
        outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        # 从模型输出中获取序列输出(通常是最后一层隐藏状态的输出),并计算其在第1维度上的平均值
        sequence_output = outputs[0]
        sequence_output = sequence_output.mean(1)

        # 将平均后的序列输出输入分类器,得到 logits(未经 softmax 处理的分类器输出)
        logits = self.classifier(sequence_output)

        # 初始化损失为 None
        loss = None

        # 如果提供了 labels,则计算损失
        if labels is not None:
            # 如果未指定问题类型,则根据情况自动判断问题类型(回归、单标签分类、多标签分类)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # 根据问题类型计算相应的损失
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        # 如果 return_dict 为 False,则返回一个元组,包含 logits 和可能的其他模型输出
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 如果 return_dict 为 True,则构建一个 ImageClassifierOutputWithNoAttention 对象,并返回
        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
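
# Similar hedged sketch for the classification head (same assumptions: torch, Pillow,
# requests and network access to the "facebook/levit-128S" checkpoint). The expected
# label string follows the _IMAGE_CLASS_EXPECTED_OUTPUT constant defined above.
from PIL import Image
import requests
import torch
from transformers import AutoImageProcessor, LevitForImageClassification

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
processor = AutoImageProcessor.from_pretrained("facebook/levit-128S")
model = LevitForImageClassification.from_pretrained("facebook/levit-128S")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])  # e.g. "tabby, tabby cat"
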
@add_start_docstrings(
    """
    LeViT Model transformer with image classification heads on top (a linear layer on top of the final hidden state and
    a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.

    .. warning::

           This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
           supported.
    """,
    LEVIT_START_DOCSTRING,
)
class LevitForImageClassificationWithTeacher(LevitPreTrainedModel):
    """
    构建一个基于 LeViT 模型的图像分类器,带有两个分类头部(一个用于最终隐藏状态的线性层,另一个用于蒸馏令牌最终隐藏状态的线性层),适用于 ImageNet 等数据集。
    注意:该模型仅支持推断,暂不支持使用蒸馏(即与教师模型进行微调)。

    Attributes:
        config (LevitConfig): 模型的配置对象,包含模型的各种参数设定。
        num_labels (int): 分类任务中的标签数量。
        levit (LevitModel): 底层的 LeViT 模型实例。

    """
    def __init__(self, config):
        """
        初始化方法,用于创建一个新的 LevitForImageClassificationWithTeacher 实例。

        Args:
            config (LevitConfig): 模型的配置对象,包含模型的各种参数设定。
        """
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.levit = LevitModel(config)

        # Classifier head
        self.classifier = (
            LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
            if config.num_labels > 0
            else torch.nn.Identity()
        )
        self.classifier_distill = (
            LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
            if config.num_labels > 0
            else torch.nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=LevitForImageClassificationWithTeacherOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LevitForImageClassificationWithTeacherOutput]:
        """
        前向传播方法,执行模型的推断过程。

        Args:
            pixel_values (torch.FloatTensor, optional): 输入的像素值张量。默认为 None。
            output_hidden_states (bool, optional): 是否返回隐藏状态。默认为 None。
            return_dict (bool, optional): 是否以字典形式返回输出。默认为 None。

        Returns:
            Union[Tuple, LevitForImageClassificationWithTeacherOutput]: 根据 return_dict 的设置,返回不同的输出形式。
                如果 return_dict 为 False,则返回一个元组。
                如果 return_dict 为 True,则返回一个 LevitForImageClassificationWithTeacherOutput 对象。

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 使用 LeViT 模型进行前向传播
        outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        # 获取序列输出并对其进行平均池化
        sequence_output = outputs[0]
        sequence_output = sequence_output.mean(1)

        # 分别通过分类头部和蒸馏头部计算 logits
        cls_logits, distill_logits = self.classifier(sequence_output), self.classifier_distill(sequence_output)
        logits = (cls_logits + distill_logits) / 2

        if not return_dict:
            # 如果 return_dict 为 False,则返回一个元组形式的输出
            output = (logits, cls_logits, distill_logits) + outputs[2:]
            return output

        # 如果 return_dict 为 True,则返回一个 LevitForImageClassificationWithTeacherOutput 对象
        return LevitForImageClassificationWithTeacherOutput(
            logits=logits,
            cls_logits=cls_logits,
            distillation_logits=distill_logits,
            hidden_states=outputs.hidden_states,
        )

.\models\levit\__init__.py

# 版权声明和许可信息,指明版权归属和使用许可
# 详情请查阅Apache License, Version 2.0: http://www.apache.org/licenses/LICENSE-2.0
#
# 如果依据许可法律要求或以书面形式同意,软件将按“原样”分发,不附任何明示或暗示的保证或条件
# 请参阅许可协议以了解特定的语言版本

from typing import TYPE_CHECKING

# 从自定义模块中导入所需函数和类,用以检查环境是否支持特定功能
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available

# 定义模块导入的结构字典,初始化一些模块路径
_import_structure = {"configuration_levit": ["LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LevitConfig", "LevitOnnxConfig"]}

# 检查视觉处理功能是否可用,若不可用则引发OptionalDependencyNotAvailable异常
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,添加相关模块到导入结构字典中
    _import_structure["feature_extraction_levit"] = ["LevitFeatureExtractor"]
    _import_structure["image_processing_levit"] = ["LevitImageProcessor"]

# 检查是否支持PyTorch环境,若不支持则引发OptionalDependencyNotAvailable异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果支持PyTorch,添加相关模块到导入结构字典中
    _import_structure["modeling_levit"] = [
        "LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "LevitForImageClassification",
        "LevitForImageClassificationWithTeacher",
        "LevitModel",
        "LevitPreTrainedModel",
    ]

# 如果是类型检查模式,导入具体的模块和类以进行类型检查
if TYPE_CHECKING:
    from .configuration_levit import LEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LevitConfig, LevitOnnxConfig

    # 检查视觉处理功能是否可用,若不可用则跳过
    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入视觉特征提取和图像处理相关模块
        from .feature_extraction_levit import LevitFeatureExtractor
        from .image_processing_levit import LevitImageProcessor

    # 检查是否支持PyTorch环境,若不支持则跳过
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入模型相关的PyTorch模块
        from .modeling_levit import (
            LEVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
            LevitForImageClassification,
            LevitForImageClassificationWithTeacher,
            LevitModel,
            LevitPreTrainedModel,
        )

# 如果不是类型检查模式,将当前模块注册为_LazyModule的懒加载模块
else:
    import sys

    # 将当前模块重新指定为_LazyModule的实例,用于延迟导入
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
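
# Illustrative-only example of what the lazy structure above buys: importing the package
# is cheap, and a submodule such as configuration_levit is only imported when one of the
# names listed in _import_structure is first accessed (assumes transformers is installed).
import transformers.models.levit as levit

config = levit.LevitConfig(num_labels=10)        # attribute access triggers the real import
print(type(config).__name__, config.num_labels)  # LevitConfig 10
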

.\models\lilt\configuration_lilt.py

# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LiLT configuration"""

from ...configuration_utils import PretrainedConfig  # 导入PretrainedConfig类,用于处理预训练模型配置
from ...utils import logging  # 导入logging模块,用于日志记录

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

LILT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "SCUT-DLVCLab/lilt-roberta-en-base": (
        "https://huggingface.co/SCUT-DLVCLab/lilt-roberta-en-base/resolve/main/config.json"
    ),
}

class LiltConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`LiltModel`]. It is used to instantiate a LiLT
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the LiLT
    [SCUT-DLVCLab/lilt-roberta-en-base](https://huggingface.co/SCUT-DLVCLab/lilt-roberta-en-base) architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Examples:

    ```
    >>> from transformers import LiltConfig, LiltModel

    >>> # Initializing a LiLT SCUT-DLVCLab/lilt-roberta-en-base style configuration
    >>> configuration = LiltConfig()
    >>> # Randomly initializing a model from the SCUT-DLVCLab/lilt-roberta-en-base style configuration
    >>> model = LiltModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "lilt"  # 定义模型类型为"lilt"

    def __init__(
        self,
        vocab_size=30522,  # 词汇表大小,默认为30522
        hidden_size=768,  # 隐藏层大小,默认为768
        num_hidden_layers=12,  # 隐藏层数,默认为12
        num_attention_heads=12,  # 注意力头数,默认为12
        intermediate_size=3072,  # 中间层大小,默认为3072
        hidden_act="gelu",  # 隐藏层激活函数,默认为GELU
        hidden_dropout_prob=0.1,  # 隐藏层Dropout概率,默认为0.1
        attention_probs_dropout_prob=0.1,  # 注意力概率Dropout概率,默认为0.1
        max_position_embeddings=512,  # 最大位置嵌入长度,默认为512
        type_vocab_size=2,  # 类型词汇表大小,默认为2
        initializer_range=0.02,  # 初始化范围,默认为0.02
        layer_norm_eps=1e-12,  # LayerNorm的epsilon,默认为1e-12
        pad_token_id=0,  # 填充token的ID,默认为0
        position_embedding_type="absolute",  # 位置嵌入类型,默认为绝对位置编码
        classifier_dropout=None,  # 分类器的Dropout,默认为None
        channel_shrink_ratio=4,  # 通道缩小比例,默认为4
        max_2d_position_embeddings=1024,  # 最大二维位置嵌入长度,默认为1024
        **kwargs,  # 其他关键字参数
    ):
        """
        Initializes a new instance of LiltConfig with optional parameters to define the model architecture.

        Parameters:
        - vocab_size: The size of the vocabulary.
        - hidden_size: The size of the hidden layers.
        - num_hidden_layers: The number of hidden layers.
        - num_attention_heads: The number of attention heads in the multi-head attention setups.
        - intermediate_size: The size of the intermediate (i.e., feed-forward) layer in the transformer blocks.
        - hidden_act: The activation function (e.g., "gelu").
        - hidden_dropout_prob: The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        - attention_probs_dropout_prob: The dropout ratio for the attention probabilities.
        - max_position_embeddings: The maximum length of the input sequences.
        - type_vocab_size: The size of the token type vocab.
        - initializer_range: The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        - layer_norm_eps: The epsilon used by LayerNorm layers.
        - pad_token_id: The ID of the padding token.
        - position_embedding_type: The type of position embeddings.
        - classifier_dropout: The dropout ratio for classifier.
        - channel_shrink_ratio: The shrink ratio of channel.
        - max_2d_position_embeddings: The maximum length of the 2D position embeddings.
        - **kwargs: Additional keyword arguments.

        """
        super().__init__(**kwargs)  # 调用父类的初始化方法,传入所有关键字参数
        ):
            # 调用父类的初始化方法,设定填充标记的 ID 和其他可选参数
            super().__init__(pad_token_id=pad_token_id, **kwargs)

        # 设置模型的词汇表大小
        self.vocab_size = vocab_size
        # 设置隐藏层的大小
        self.hidden_size = hidden_size
        # 设置隐藏层的数量
        self.num_hidden_layers = num_hidden_layers
        # 设置注意力头的数量
        self.num_attention_heads = num_attention_heads
        # 设置隐藏层激活函数的类型
        self.hidden_act = hidden_act
        # 设置中间层大小
        self.intermediate_size = intermediate_size
        # 设置隐藏层的 dropout 概率
        self.hidden_dropout_prob = hidden_dropout_prob
        # 设置注意力概率 dropout 概率
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        # 设置最大位置嵌入的大小
        self.max_position_embeddings = max_position_embeddings
        # 设置类型词汇表的大小
        self.type_vocab_size = type_vocab_size
        # 设置初始化范围
        self.initializer_range = initializer_range
        # 设置层归一化的 epsilon 值
        self.layer_norm_eps = layer_norm_eps
        # 设置位置嵌入的类型
        self.position_embedding_type = position_embedding_type
        # 设置分类器 dropout 概率
        self.classifier_dropout = classifier_dropout
        # 设置通道收缩比率
        self.channel_shrink_ratio = channel_shrink_ratio
        # 设置最大二维位置嵌入的大小
        self.max_2d_position_embeddings = max_2d_position_embeddings

.\models\lilt\modeling_lilt.py

# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LiLT model."""

import math
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_lilt import LiltConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LiltConfig"

LILT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "SCUT-DLVCLab/lilt-roberta-en-base",
    # See all LiLT models at https://huggingface.co/models?filter=lilt
]


class LiltTextEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化词嵌入层,用于将输入词编号转换为向量表示
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # 初始化位置嵌入层,用于表示词的位置信息
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # 初始化类型嵌入层,用于表示输入的类型信息
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm 使用 nn.LayerNorm 进行层归一化,保持和 TensorFlow 模型变量名一致以便加载 TensorFlow 检查点文件
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 使用 dropout 进行随机失活,以防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        # 注册一个持久化的 buffer,用于存储位置 ID,这些位置 ID 在序列化时会被导出
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        # 设置位置嵌入类型,默认为绝对位置编码
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        # End copy
        # 设置填充标记的索引
        self.padding_idx = config.pad_token_id
        # 初始化位置嵌入层,用于表示词的位置信息,带有填充标记的索引设置
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
    ):
        if position_ids is None:
            if input_ids is not None:
                # 如果位置 ids 为空且输入 ids 不为空,则从输入 token ids 创建位置 ids。任何填充的 token 保持填充状态。
                position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx).to(
                    input_ids.device
                )
            else:
                # 如果位置 ids 为空且输入 ids 也为空,则从输入嵌入创建位置 ids。
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if input_ids is not None:
            # 如果输入 ids 不为空,获取其形状
            input_shape = input_ids.size()
        else:
            # 否则,获取输入嵌入的形状,去掉最后一个维度(即序列长度)
            input_shape = inputs_embeds.size()[:-1]

        if token_type_ids is None:
            # 如果 token 类型 ids 为空,则创建全零的 token 类型 ids,与输入形状相同
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            # 如果输入嵌入为空,则通过输入 ids 获取单词嵌入
            inputs_embeds = self.word_embeddings(input_ids)
        # 获取 token 类型嵌入
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # 将输入嵌入与 token 类型嵌入相加,得到整体嵌入
        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            # 如果位置嵌入类型为 "absolute",则添加绝对位置嵌入
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        # 应用 LayerNorm 对 embeddings 进行归一化
        embeddings = self.LayerNorm(embeddings)
        # 对 embeddings 进行 dropout 处理
        embeddings = self.dropout(embeddings)
        # 返回 embeddings 和 position_ids
        return embeddings, position_ids

    def create_position_ids_from_input_ids(self, input_ids, padding_idx):
        """
        非填充符号替换为它们的位置编号。位置编号从 padding_idx+1 开始。忽略填充符号。这是从 fairseq 的 `utils.make_positions` 修改而来。

        Args:
            input_ids: torch.Tensor
            padding_idx: int

        Returns: torch.Tensor
        """
        # 创建一个 mask,标记出非填充符号位置
        mask = input_ids.ne(padding_idx).int()
        # 使用累加的方式生成位置 ids,确保在填充符号处保持填充状态
        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
        return incremental_indices.long() + padding_idx
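
# Worked example (assuming torch) of the cumsum trick above: padded positions stay at
# padding_idx while real tokens get consecutive positions starting at padding_idx + 1.
# padding_idx=1 below is only a demo value; the model uses config.pad_token_id.
import torch

padding_idx = 1
input_ids = torch.tensor([[5, 7, 9, 1, 1]])    # two trailing pad tokens
mask = input_ids.ne(padding_idx).int()         # tensor([[1, 1, 1, 0, 0]])
positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
print(positions)                               # tensor([[2, 3, 4, 1, 1]])
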

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        我们直接提供嵌入。无法推断哪些是填充符号,因此只生成顺序的位置 ids。

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        # 获取输入嵌入的形状,去掉最后一个维度得到序列长度
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        # 生成顺序的位置 ids,从 padding_idx+1 开始到 sequence_length + padding_idx+1
        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
class LiltLayoutEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 将隐藏大小除以6,因为有6种不同的布局嵌入:
        # 左侧位置、上侧位置、右侧位置、下侧位置、高度、宽度
        self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
        self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
        self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
        self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)

        self.padding_idx = config.pad_token_id
        # 使用config中的参数初始化嵌入层,设置padding_idx为padding标记的ID
        self.box_position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size // config.channel_shrink_ratio,
            padding_idx=self.padding_idx,
        )
        # 线性层,将隐藏大小映射到更小的尺寸,用于嵌入向量的线性变换
        self.box_linear_embeddings = nn.Linear(
            in_features=config.hidden_size, out_features=config.hidden_size // config.channel_shrink_ratio
        )
        # LayerNorm 层,用于归一化输入向量
        self.LayerNorm = nn.LayerNorm(config.hidden_size // config.channel_shrink_ratio, eps=config.layer_norm_eps)
        # Dropout 层,用于随机丢弃输入向量的一部分,防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, bbox=None, position_ids=None):
        try:
            # 从bbox中提取左侧、上侧、右侧、下侧位置的嵌入向量
            left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
            upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
            right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
            lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
        except IndexError as e:
            # 抛出异常,如果bbox的坐标值不在0-1000范围内
            raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e

        # 计算高度和宽度的嵌入向量
        h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
        w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])

        # 拼接左侧、上侧、右侧、下侧、高度、宽度的嵌入向量
        spatial_position_embeddings = torch.cat(
            [
                left_position_embeddings,
                upper_position_embeddings,
                right_position_embeddings,
                lower_position_embeddings,
                h_position_embeddings,
                w_position_embeddings,
            ],
            dim=-1,
        )
        # 对拼接的嵌入向量进行线性变换
        spatial_position_embeddings = self.box_linear_embeddings(spatial_position_embeddings)
        # 获取位置ID对应的位置嵌入向量
        box_position_embeddings = self.box_position_embeddings(position_ids)

        # 将位置嵌入向量加到拼接的嵌入向量上
        spatial_position_embeddings = spatial_position_embeddings + box_position_embeddings

        # 对加和后的嵌入向量进行LayerNorm归一化
        spatial_position_embeddings = self.LayerNorm(spatial_position_embeddings)
        # 对归一化后的嵌入向量进行Dropout操作
        spatial_position_embeddings = self.dropout(spatial_position_embeddings)

        # 返回最终的空间位置嵌入向量
        return spatial_position_embeddings
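下面给出一个形状层面的小示例(非源码;用 types.SimpleNamespace 充当 LiltConfig,数值仅作示意):左、上、右、下坐标以及高度、宽度各映射为 hidden_size // 6 维,拼接后经线性层降到 hidden_size // channel_shrink_ratio 维,再加上框的位置嵌入。

```
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(
    max_2d_position_embeddings=1024, hidden_size=768, pad_token_id=0,
    max_position_embeddings=512, channel_shrink_ratio=4,
    layer_norm_eps=1e-12, hidden_dropout_prob=0.1,
)
layout_emb = LiltLayoutEmbeddings(cfg)

bbox = torch.tensor([[[10, 20, 110, 60]] * 6])   # (batch=1, seq=6, 4),坐标需在 0-1000 之间
position_ids = torch.arange(1, 7).unsqueeze(0)   # (1, 6)
out = layout_emb(bbox=bbox, position_ids=position_ids)
print(out.shape)  # torch.Size([1, 6, 192]),即 hidden_size // channel_shrink_ratio
```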


class LiltSelfAttention(nn.Module):
    # 初始化函数,接收配置和位置嵌入类型作为参数
    def __init__(self, config, position_embedding_type=None):
        # 调用父类初始化方法
        super().__init__()
        # 如果隐藏层大小不能被注意力头数整除,且配置中没有 embedding_size 属性,则抛出错误
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            # 抛出数值错误,显示隐藏层大小不能被注意力头数整除
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        # 设置注意力头数和每个头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        # 计算所有头的总大小
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 定义查询、键、值的线性层
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # 定义布局查询、键、值的线性层
        self.layout_query = nn.Linear(
            config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio
        )
        self.layout_key = nn.Linear(
            config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio
        )
        self.layout_value = nn.Linear(
            config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio
        )

        # 定义 dropout 层
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # 设置位置嵌入类型,如果未提供则默认为绝对位置
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        # 如果位置嵌入类型是相对键或相对键查询,则初始化距离嵌入层
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        # 设置通道缩减比例
        self.channel_shrink_ratio = config.channel_shrink_ratio
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class LiltSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 定义一个全连接层,将输入的hidden_size维度映射到hidden_size维度
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # LayerNorm层,用于对输入进行归一化处理
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Dropout层,用于在训练过程中随机将一部分输入置为0,以防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # 全连接层操作,将hidden_states映射到相同维度
        hidden_states = self.dense(hidden_states)
        # Dropout操作,随机置0
        hidden_states = self.dropout(hidden_states)
        # LayerNorm操作,对映射后的结果进行归一化处理并与原始输入相加
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class LiltAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # LiltSelfAttention模块,用于计算注意力机制
        self.self = LiltSelfAttention(config, position_embedding_type=position_embedding_type)
        # LiltSelfOutput模块,用于处理自注意力的输出
        self.output = LiltSelfOutput(config)
        # 用于存储被剪枝的注意力头索引
        self.pruned_heads = set()

        # 保存原始的hidden_size,并根据channel_shrink_ratio调整hidden_size大小
        ori_hidden_size = config.hidden_size
        config.hidden_size = config.hidden_size // config.channel_shrink_ratio
        # 用于处理布局输入的LiltSelfOutput模块
        self.layout_output = LiltSelfOutput(config)
        config.hidden_size = ori_hidden_size

    # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
    def prune_heads(self, heads):
        # 如果没有要剪枝的头部,则直接返回
        if len(heads) == 0:
            return
        # 调用帮助函数找到可剪枝的头部和对应索引
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # 对线性层进行剪枝
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # 更新超参数并存储已剪枝的头部索引
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layout_inputs: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # 调用self模块的forward方法,计算自注意力机制
        self_outputs = self.self(
            hidden_states,
            layout_inputs,
            attention_mask,
            head_mask,
            output_attentions,
        )
        # 对自注意力的输出进行处理,传入self.output模块
        attention_output = self.output(self_outputs[0][0], hidden_states)
        # 对布局注意力的输出进行处理,传入self.layout_output模块
        layout_attention_output = self.layout_output(self_outputs[0][1], layout_inputs)
        # 如果有需要,则添加注意力输出到结果中
        outputs = ((attention_output, layout_attention_output),) + self_outputs[1:]
        return outputs
# 定义 LiltLayer 类,继承自 nn.Module,表示一个自定义的神经网络层
class LiltLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化 LiltLayer 类,设置一些基本属性
        self.chunk_size_feed_forward = config.chunk_size_feed_forward  # 设定前向传播的块大小
        self.seq_len_dim = 1  # 序列长度维度为 1
        self.attention = LiltAttention(config)  # 初始化注意力层对象
        self.intermediate = LiltIntermediate(config)  # 初始化中间层对象
        self.output = LiltOutput(config)  # 初始化输出层对象

        # 保存原始的隐藏大小和中间大小
        ori_hidden_size = config.hidden_size
        ori_intermediate_size = config.intermediate_size

        # 根据配置调整隐藏大小和中间大小
        config.hidden_size = config.hidden_size // config.channel_shrink_ratio
        config.intermediate_size = config.intermediate_size // config.channel_shrink_ratio

        # 创建新的中间层和输出层对象,用于布局处理
        self.layout_intermediate = LiltIntermediate(config)
        self.layout_output = LiltOutput(config)

        # 恢复原始的隐藏大小和中间大小
        config.hidden_size = ori_hidden_size
        config.intermediate_size = ori_intermediate_size

    def forward(
        self,
        hidden_states: torch.Tensor,
        layout_inputs: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # 前向传播方法:接收隐藏状态、布局输入以及可选的注意力掩码和头部掩码,
        # 返回一个包含 torch.Tensor 的元组
        # 调用 self.attention 方法,传入多个参数
        # hidden_states: 隐藏状态
        # layout_inputs: 布局输入
        # attention_mask: 注意力掩码
        # head_mask: 头部掩码
        # output_attentions: 是否输出注意力权重,默认为 False
        self_attention_outputs = self.attention(
            hidden_states,
            layout_inputs,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        # 从 self_attention_outputs 中获取第一个元素的第一个元素,即 self attention 的输出
        attention_output = self_attention_outputs[0][0]
        # 从 self_attention_outputs 中获取第一个元素的第二个元素,即 layout attention 的输出
        layout_attention_output = self_attention_outputs[0][1]

        # 如果输出注意力权重,则将 self attentions 添加到输出中
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # 调用 apply_chunking_to_forward 函数,对 self.feed_forward_chunk 进行分块处理
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        # 调用 apply_chunking_to_forward 函数,对 self.layout_feed_forward_chunk 进行分块处理
        layout_layer_output = apply_chunking_to_forward(
            self.layout_feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, layout_attention_output
        )
        # 将处理后的输出添加到 outputs 元组中
        outputs = ((layer_output, layout_layer_output),) + outputs

        # 返回最终的输出元组
        return outputs

    # 从 transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk 复制过来的函数
    def feed_forward_chunk(self, attention_output):
        # 调用 self.intermediate 方法,对 attention_output 进行处理
        intermediate_output = self.intermediate(attention_output)
        # 调用 self.output 方法,对 intermediate_output 和 attention_output 进行处理,得到最终的输出
        layer_output = self.output(intermediate_output, attention_output)
        # 返回处理后的输出
        return layer_output

    # 定义的函数,处理 layout attention 的输出
    def layout_feed_forward_chunk(self, attention_output):
        # 调用 self.layout_intermediate 方法,对 attention_output 进行处理
        intermediate_output = self.layout_intermediate(attention_output)
        # 调用 self.layout_output 方法,对 intermediate_output 和 attention_output 进行处理,得到最终的输出
        layer_output = self.layout_output(intermediate_output, attention_output)
        # 返回处理后的输出
        return layer_output
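chunk_size_feed_forward 默认为 0(不分块);当它大于 0 时,apply_chunking_to_forward 会沿序列维度把输入切成小块逐块前向再拼接,结果与一次性前向一致,只是峰值显存更小。下面是一个与源码无关的小示例(假设使用 transformers.pytorch_utils 中的该工具函数):

```
import torch
from transformers.pytorch_utils import apply_chunking_to_forward

dense = torch.nn.Linear(16, 16)

def feed_forward_chunk(hidden_states):
    return dense(hidden_states)

hidden = torch.randn(2, 8, 16)                                          # (batch, seq_len, hidden)
full = feed_forward_chunk(hidden)                                       # 一次性前向
chunked = apply_chunking_to_forward(feed_forward_chunk, 4, 1, hidden)   # 每次只处理 4 个 token
print(torch.allclose(full, chunked, atol=1e-6))                         # True
```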
# 声明一个名为 LiltEncoder 的类,继承自 nn.Module
class LiltEncoder(nn.Module):
    # 初始化函数,接受一个配置参数 config
    # 从 transformers.models.bert.modeling_bert.BertEncoder.__init__ 复制而来,将 Bert 替换为 Lilt
    def __init__(self, config):
        super().__init__()
        # 将配置参数保存到实例变量中
        self.config = config
        # 创建一个 nn.ModuleList,其中包含 config.num_hidden_layers 个 LiltLayer 实例
        self.layer = nn.ModuleList([LiltLayer(config) for _ in range(config.num_hidden_layers)])
        # 设置梯度检查点功能为 False
        self.gradient_checkpointing = False

    # 前向传播函数,接受多个输入参数和可选的返回类型注解
    def forward(
        self,
        hidden_states: torch.Tensor,  # 输入的隐藏状态张量
        layout_inputs: torch.Tensor,  # 布局输入张量
        attention_mask: Optional[torch.FloatTensor] = None,  # 可选的注意力掩码张量,默认为 None
        head_mask: Optional[torch.FloatTensor] = None,  # 可选的头部掩码张量,默认为 None
        output_attentions: Optional[bool] = False,  # 是否输出注意力张量的开关,默认为 False
        output_hidden_states: Optional[bool] = False,  # 是否输出所有隐藏状态的开关,默认为 False
        return_dict: Optional[bool] = True,  # 是否返回字典形式的结果,默认为 True
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:  # 返回类型为元组或 BaseModelOutput 类型

        # 如果需要输出隐藏状态,则初始化空的所有隐藏状态元组
        all_hidden_states = () if output_hidden_states else None
        # 如果需要输出注意力,则初始化空的所有自注意力元组
        all_self_attentions = () if output_attentions else None

        # 遍历 self.layer 中的每个层次模块
        for i, layer_module in enumerate(self.layer):
            # 如果需要输出隐藏状态,则将当前隐藏状态添加到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 获取当前层的头部掩码
            layer_head_mask = head_mask[i] if head_mask is not None else None

            # 如果梯度检查点开启且处于训练状态,则使用 _gradient_checkpointing_func 函数来计算层输出
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layout_inputs,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                # 否则,直接调用层模块的 __call__ 方法计算层输出
                layer_outputs = layer_module(
                    hidden_states,
                    layout_inputs,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )

            # 更新隐藏状态和布局输入为当前层的输出结果的第一个元素和第二个元素
            hidden_states = layer_outputs[0][0]
            layout_inputs = layer_outputs[0][1]

            # 如果需要输出注意力,则将当前层的注意力张量添加到 all_self_attentions 中
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # 如果需要输出隐藏状态,则将最终的隐藏状态添加到 all_hidden_states 中
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果不需要返回字典形式的结果,则返回非 None 的所有值的元组
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )
        # 否则,返回一个 BaseModelOutput 类型的对象,包含最终的隐藏状态、所有隐藏状态和所有自注意力
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


# 从 transformers.models.bert.modeling_bert.BertPooler 复制而来的 LiltPooler 类
class LiltPooler(nn.Module):
    # 初始化函数,接受一个配置参数 config
    def __init__(self, config):
        super().__init__()
        # 创建一个线性层,输入维度为 config.hidden_size,输出维度为 config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 创建一个 Tanh 激活函数实例
        self.activation = nn.Tanh()
    # 定义一个前向传播方法,接收隐藏状态作为输入,并返回转换后的张量作为输出
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 通过选择第一个标记对应的隐藏状态来“汇聚”模型
        first_token_tensor = hidden_states[:, 0]
        # 将第一个标记的隐藏状态输入全连接层,进行线性变换
        pooled_output = self.dense(first_token_tensor)
        # 对线性变换后的结果应用激活函数
        pooled_output = self.activation(pooled_output)
        # 返回经过汇聚和激活处理后的输出张量
        return pooled_output
class LiltPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 使用 LiltConfig 作为配置类
    config_class = LiltConfig
    # 模型的前缀名称
    base_model_prefix = "lilt"
    # 支持梯度检查点
    supports_gradient_checkpointing = True
    # 不进行模块分割的模块列表
    _no_split_modules = []

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # 如果是线性层,使用正态分布初始化权重,均值为0,标准差为配置文件中的初始化范围
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果存在偏置,将其初始化为零
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # 如果是嵌入层,使用正态分布初始化权重,均值为0,标准差为配置文件中的初始化范围
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果定义了填充索引,将对应索引的权重初始化为零
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # 如果是 LayerNorm 层,初始化偏置为零,初始化权重为1
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


LILT_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LiltConfig`]): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

LILT_INPUTS_DOCSTRING = r"""
"""


@add_start_docstrings(
    "The bare LiLT Model transformer outputting raw hidden-states without any specific head on top.",
    LILT_START_DOCSTRING,
)
class LiltModel(LiltPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        # 设置模型配置
        self.config = config

        # 初始化嵌入层、布局嵌入和编码器
        self.embeddings = LiltTextEmbeddings(config)
        self.layout_embeddings = LiltLayoutEmbeddings(config)
        self.encoder = LiltEncoder(config)

        # 如果需要,添加池化层
        self.pooler = LiltPooler(config) if add_pooling_layer else None

        # 初始化权重并进行最终处理
        self.post_init()

    def get_input_embeddings(self):
        # 返回输入嵌入层的权重
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # 设置输入嵌入层的权重
        self.embeddings.word_embeddings = value
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # 遍历 heads_to_prune 字典中的每个元素,其中 key 是层号,value 是需要剪枝的头部列表
        for layer, heads in heads_to_prune.items():
            # 在编码器的指定层中的注意力模型中执行剪枝操作
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(LILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        bbox: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Performs the forward pass of the model.

        Args:
            input_ids (Optional[torch.Tensor], optional): Input tensors for the model.
            bbox (Optional[torch.Tensor], optional): Bounding box tensors.
            attention_mask (Optional[torch.Tensor], optional): Attention mask tensors.
            token_type_ids (Optional[torch.Tensor], optional): Token type ID tensors.
            position_ids (Optional[torch.Tensor], optional): Position ID tensors.
            head_mask (Optional[torch.Tensor], optional): Head mask tensors.
            inputs_embeds (Optional[torch.Tensor], optional): Embedded input tensors.
            output_attentions (Optional[bool], optional): Whether to output attentions.
            output_hidden_states (Optional[bool], optional): Whether to output hidden states.
            return_dict (Optional[bool], optional): Whether to return as dictionary.
        """
        # 实现模型的前向传播,接收和处理各种输入张量
        # 具体参数作用见函数说明文档和相关注释
        pass
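上面 LiltModel.forward 的具体实现在本文中省略。其调用方式与下文 LiltForTokenClassification 的示例基本一致,这里给出一个用法示意(需要联网下载权重,数据集名称沿用下文示例):

```
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
model = AutoModel.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")

dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
example = dataset[0]
encoding = tokenizer(example["tokens"], boxes=example["bboxes"], return_tensors="pt")

outputs = model(**encoding)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
```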
@add_start_docstrings(
    """
    LiLT Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    LILT_START_DOCSTRING,
)
class LiltForSequenceClassification(LiltPreTrainedModel):
    # 从transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification.__init__复制而来,将Roberta替换为Lilt,roberta替换为lilt
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels  # 初始化分类标签数
        self.config = config

        self.lilt = LiltModel(config, add_pooling_layer=False)  # 初始化Lilt模型,不添加池化层
        self.classifier = LiltClassificationHead(config)  # 初始化分类头部

        # 初始化权重并进行最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(LILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        bbox: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,



@add_start_docstrings(
    """
    Lilt Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    LILT_START_DOCSTRING,
)
class LiltForTokenClassification(LiltPreTrainedModel):
    # 从transformers.models.roberta.modeling_roberta.RobertaForTokenClassification.__init__复制而来,将Roberta替换为Lilt,roberta替换为lilt
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels  # 初始化分类标签数

        self.lilt = LiltModel(config, add_pooling_layer=False)  # 初始化Lilt模型,不添加池化层
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)  # 初始化Dropout层
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)  # 初始化线性分类器层

        # 初始化权重并进行最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(LILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,  # 输入的token ID序列,可以为空
        bbox: Optional[torch.LongTensor] = None,  # 包围框信息的张量,可以为空
        attention_mask: Optional[torch.FloatTensor] = None,  # 注意力掩码张量,可以为空
        token_type_ids: Optional[torch.LongTensor] = None,  # token类型ID张量,可以为空
        position_ids: Optional[torch.LongTensor] = None,  # 位置ID张量,可以为空
        head_mask: Optional[torch.FloatTensor] = None,  # 头部掩码张量,可以为空
        inputs_embeds: Optional[torch.FloatTensor] = None,  # 嵌入输入张量,可以为空
        labels: Optional[torch.LongTensor] = None,  # 用于计算标记分类损失的标签张量,可以为空
        output_attentions: Optional[bool] = None,  # 是否输出注意力张量的标志,可以为空
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态的标志,可以为空
        return_dict: Optional[bool] = None,  # 是否返回字典类型的输出,可以为空
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        返回Lilt模型的前向传播结果。

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            用于计算标记分类损失的标签。索引应在 `[0, ..., config.num_labels - 1]` 范围内。

        Returns:
            如果 `return_dict=False`:
                返回一个包含 `(logits, hidden_states, attentions)` 的元组,其中 `logits` 是预测的分类结果张量。
                如果 `loss` 不为空,则还包含 `loss`。

            如果 `return_dict=True`:
                返回一个 `TokenClassifierOutput` 对象,包含 `loss`、`logits`、`hidden_states` 和 `attentions` 属性。

        Examples:

        ```
        >>> from transformers import AutoTokenizer, AutoModelForTokenClassification
        >>> from datasets import load_dataset

        >>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
        >>> model = AutoModelForTokenClassification.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")

        >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
        >>> example = dataset[0]
        >>> words = example["tokens"]
        >>> boxes = example["bboxes"]

        >>> encoding = tokenizer(words, boxes=boxes, return_tensors="pt")

        >>> outputs = model(**encoding)
        >>> predicted_class_indices = outputs.logits.argmax(-1)
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict  # 确定是否使用模型配置中的返回字典选项

        outputs = self.lilt(
            input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]  # 获取模型输出的序列输出

        sequence_output = self.dropout(sequence_output)  # 对序列输出应用dropout操作
        logits = self.classifier(sequence_output)  # 使用分类器对序列输出进行分类

        loss = None
        if labels is not None:
            # 将标签移动到正确的设备以启用模型并行计算
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))  # 计算交叉熵损失

        if not return_dict:
            output = (logits,) + outputs[2:]  # 构建输出元组
            return ((loss,) + output) if loss is not None else output  # 如果有损失则返回损失和输出,否则只返回输出

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )  # 返回TokenClassifierOutput对象,包含损失、logits、隐藏状态和注意力
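上面的损失计算把 (batch_size, sequence_length) 展平成一维后逐 token 求交叉熵,下面用随机张量演示这一步(非源码):

```
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, num_labels = 2, 5, 7
logits = torch.randn(batch_size, seq_len, num_labels)
labels = torch.randint(0, num_labels, (batch_size, seq_len))

loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))  # 展平后逐 token 计算
print(loss.item())  # 标量损失
```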
# 从 transformers.models.roberta.modeling_roberta.RobertaClassificationHead 复制代码,并将 Roberta 替换为 Lilt
class LiltClassificationHead(nn.Module):
    """用于句子级分类任务的头部模块。"""

    def __init__(self, config):
        super().__init__()
        # 全连接层,输入和输出维度都是 config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 分类器的 dropout 概率,默认为 config.hidden_dropout_prob
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # 最终的输出全连接层,输出维度是 config.num_labels
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        # 取序列中第一个 token 的隐藏状态作为特征
        x = features[:, 0, :]  # 相当于取 <s> token (等同于 [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


@add_start_docstrings(
    """
    在 Lilt 模型顶部添加用于提取式问答任务的 span 分类头部(在隐藏状态输出之上的线性层,计算 `span start logits` 和 `span end logits`)。
    """,
    LILT_START_DOCSTRING,
)
class LiltForQuestionAnswering(LiltPreTrainedModel):
    # 从 transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering.__init__ 复制代码,并将 Roberta 替换为 Lilt, roberta 替换为 lilt
    def __init__(self, config):
        super().__init__(config)
        # 设置模型的标签数目
        self.num_labels = config.num_labels

        # 使用 LiltModel 初始化,不添加汇聚层
        self.lilt = LiltModel(config, add_pooling_layer=False)
        # 输出层,线性层,输入维度为 config.hidden_size,输出维度为 config.num_labels
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化权重并应用最终处理
        self.post_init()

    @add_start_docstrings_to_model_forward(LILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        bbox: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,

.\models\lilt\__init__.py

# 引入所需的模块和函数
from typing import TYPE_CHECKING

# 从自定义的工具包中引入异常处理类和延迟加载模块类
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# 定义模块的导入结构,包括配置和模型的名称
_import_structure = {
    "configuration_lilt": ["LILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LiltConfig"],
}

# 尝试检查是否可用 Torch 库,如果不可用则引发自定义的异常类
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果 Torch 可用,则添加相关模型的导入结构
    _import_structure["modeling_lilt"] = [
        "LILT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "LiltForQuestionAnswering",
        "LiltForSequenceClassification",
        "LiltForTokenClassification",
        "LiltModel",
        "LiltPreTrainedModel",
    ]

# 如果是类型检查阶段,则从配置和模型模块中导入特定的类和常量
if TYPE_CHECKING:
    from .configuration_lilt import LILT_PRETRAINED_CONFIG_ARCHIVE_MAP, LiltConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_lilt import (
            LILT_PRETRAINED_MODEL_ARCHIVE_LIST,
            LiltForQuestionAnswering,
            LiltForSequenceClassification,
            LiltForTokenClassification,
            LiltModel,
            LiltPreTrainedModel,
        )

# 如果不是类型检查阶段,则设置模块为延迟加载模式
else:
    import sys

    # 将当前模块设置为 LazyModule,以便在需要时按需加载
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\llama\configuration_llama.py

# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LLaMA model configuration"""

# Importing necessary classes from transformers library
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# Getting the logger instance for logging messages related to this module
logger = logging.get_logger(__name__)

# Mapping dictionary to store pretrained configurations for LLaMA models
LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}

# Configuration class inheriting from PretrainedConfig to define LLaMA model configuration
class LlamaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the LLaMA-7B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    ```
    >>> from transformers import LlamaModel, LlamaConfig

    >>> # Initializing a LLaMA llama-7b style configuration
    >>> configuration = LlamaConfig()

    >>> # Initializing a model from the llama-7b style configuration
    >>> model = LlamaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    # Setting model_type attribute for LLaMA model identification
    model_type = "llama"
    
    # List of keys to ignore during inference
    keys_to_ignore_at_inference = ["past_key_values"]

    # Constructor method to initialize LLaMA configuration parameters
    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
        ):
        # 设置模型的参数:词汇表大小
        self.vocab_size = vocab_size
        # 设置模型的参数:最大位置编码长度
        self.max_position_embeddings = max_position_embeddings
        # 设置模型的参数:隐藏层大小
        self.hidden_size = hidden_size
        # 设置模型的参数:中间层大小
        self.intermediate_size = intermediate_size
        # 设置模型的参数:隐藏层的数量
        self.num_hidden_layers = num_hidden_layers
        # 设置模型的参数:注意力头的数量
        self.num_attention_heads = num_attention_heads

        # 兼容性处理:如果未指定键值头的数量,则默认与注意力头数量相同
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # 设置模型的参数:键值头的数量
        self.num_key_value_heads = num_key_value_heads
        # 设置模型的参数:隐藏层激活函数
        self.hidden_act = hidden_act
        # 设置模型的参数:初始化范围
        self.initializer_range = initializer_range
        # 设置模型的参数:RMS归一化的epsilon值
        self.rms_norm_eps = rms_norm_eps
        # 设置模型的参数:预训练时使用的张量并行度(pretraining_tp)
        self.pretraining_tp = pretraining_tp
        # 设置模型的参数:是否使用缓存
        self.use_cache = use_cache
        # 设置模型的参数:RoPE(旋转位置编码)的基数 theta
        self.rope_theta = rope_theta
        # 设置模型的参数:RoPE 的缩放配置
        self.rope_scaling = rope_scaling
        # 调用私有方法验证Rope模型的缩放参数是否合法
        self._rope_scaling_validation()
        # 设置模型的参数:注意力偏置
        self.attention_bias = attention_bias
        # 设置模型的参数:注意力dropout率
        self.attention_dropout = attention_dropout

        # 调用父类的初始化方法,设置模型的其他参数
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.
        """
        # 如果没有设置Rope模型的缩放参数,则直接返回
        if self.rope_scaling is None:
            return

        # 检查Rope模型的缩放参数是否为字典且包含两个字段
        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        # 获取Rope模型的缩放类型和缩放因子
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        # 检查Rope模型的缩放类型是否合法
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        # 检查Rope模型的缩放因子是否为浮点数且大于1
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
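基于上面展示的校验逻辑,可以这样构造并验证 `rope_scaling` 配置(示例,非源码):

```
from transformers import LlamaConfig

# 合法配置:type 为 "linear" 或 "dynamic",factor 为大于 1 的浮点数
config = LlamaConfig(rope_scaling={"type": "linear", "factor": 2.0})
print(config.rope_scaling)

# 非法配置(factor <= 1)会在 _rope_scaling_validation 中抛出 ValueError
try:
    LlamaConfig(rope_scaling={"type": "linear", "factor": 0.5})
except ValueError as err:
    print(err)
```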

.\models\llama\convert_llama_weights_to_hf.py

# 导入必要的库
import argparse  # 用于解析命令行参数
import gc  # Python 的垃圾回收模块
import json  # 用于 JSON 文件的读写操作
import os  # 提供了对操作系统的接口,用于文件和目录操作
import shutil  # 提供高级的文件操作功能
import warnings  # 用于处理警告信息

import torch  # 引入 PyTorch 库

# 从 transformers 库中导入所需的类
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

try:
    # 尝试从 tokenizers 库中导入 LlamaTokenizerFast 类
    from transformers import LlamaTokenizerFast
except ImportError as e:
    # 如果导入失败,发出警告并提示使用慢速的 tokenizer
    warnings.warn(e)
    warnings.warn(
        "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
    )
    LlamaTokenizerFast = None

"""
样例用法:

python src/transformers/models/llama/convert_llama_weights_to_hf.py \
    --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
"""

# 各模型尺寸对应的分片数目
NUM_SHARDS = {
    "7B": 1,
    "7Bf": 1,
    "13B": 2,
    "13Bf": 2,
    "34B": 4,
    "30B": 4,
    "65B": 8,
    "70B": 8,
    "70Bf": 8,
}

def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    # 计算中间层的大小,确保是给定倍数的整数
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
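以 LLaMA-7B(dim=4096,默认 ffn_dim_multiplier=1、multiple_of=256)为例验证该函数:int(8 * 4096 / 3) = 10922,向上取整到 256 的倍数得到 43 * 256 = 11008,正好是 LlamaConfig 中默认的 intermediate_size。

```
assert compute_intermediate_size(4096) == 11008
```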

def read_json(path):
    # 从指定路径读取 JSON 文件并返回其内容
    with open(path, "r") as f:
        return json.load(f)

def write_json(text, path):
    # 将文本内容写入指定路径的 JSON 文件
    with open(path, "w") as f:
        json.dump(text, f)

def write_model(
    model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, llama_version=1
):
    # 向后兼容:旧用法要求权重目录形如 `my_repo/model_size`,因此若根目录下找不到 params.json,则进入对应尺寸的子目录
    if not os.path.isfile(os.path.join(input_base_path, "params.json")):
        input_base_path = os.path.join(input_base_path, model_size)

    # 创建模型路径和临时模型路径
    os.makedirs(model_path, exist_ok=True)
    tmp_model_path = os.path.join(model_path, "tmp")
    os.makedirs(tmp_model_path, exist_ok=True)

    # 读取模型参数 JSON 文件
    params = read_json(os.path.join(input_base_path, "params.json"))
    num_shards = NUM_SHARDS[model_size]
    params = params.get("model", params)
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    # 计算每个分片中的注意力头数
    n_heads_per_shard = n_heads // num_shards

    # 从参数字典中获取维度信息
    dim = params["dim"]

    # 计算每个头部的维度大小
    dims_per_head = dim // n_heads

    # 获取参数中的 "rope_theta",默认为 10000.0
    base = params.get("rope_theta", 10000.0)

    # 计算逆频率,用于位置编码
    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))

    # 根据 base 的大小确定最大位置嵌入的值
    if base > 10000.0:
        max_position_embeddings = 16384
    else:
        # 根据 Llama 的版本确定默认的最大位置嵌入
        if llama_version == 1:
            max_position_embeddings = 2048
        elif llama_version == 2:
            max_position_embeddings = 4096
        else:
            # 抛出未实现错误,对于不支持的 Llama 版本
            raise NotImplementedError(
                f"Version {llama_version} of llama is not supported yet. "
                "Current supported versions of llama are [1, 2]."
            )

    # 根据 LlamaTokenizerFast 是否为 None 选择正确的 tokenizer 类
    tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast

    # 如果提供了 tokenizer_path,则初始化 tokenizer 并保存到 model_path
    if tokenizer_path is not None:
        tokenizer = tokenizer_class(tokenizer_path)
        tokenizer.save_pretrained(model_path)

    # 根据 tokenizer_path 是否为 None 决定词汇表大小
    vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000

    # 如果参数中提供了 n_kv_heads,则使用其定义的键值头数,否则使用默认值
    if params.get("n_kv_heads", None) is not None:
        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
        key_value_dim = dim // num_key_value_heads
    else:
        # 兼容性处理,对于其他检查点使用默认值
        num_key_value_heads = n_heads
        num_local_key_value_heads = n_heads_per_shard
        key_value_dim = dim

    # 定义用于分片旋转的置换函数
    def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
        return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

    # 打印加载检查点参数的信息
    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")

    # 加载权重
    if num_shards == 1:
        # 如果不分片,则加载单个文件
        loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
    else:
        # 如果分片,则加载所有分片的文件
        loaded = [
            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
            for i in range(num_shards)
        ]

    # 初始化参数计数器和索引字典
    param_count = 0
    index_dict = {"weight_map": {}}

    # 构建模型文件名
    filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"

    if num_shards == 1:
        # 如果不分片,则构建状态字典
        state_dict = {
            "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
            "model.norm.weight": loaded["norm.weight"],
            "lm_head.weight": loaded["output.weight"],
        }
    else:
        # 如果分片,则合并各分片的权重
        state_dict = {
            "model.norm.weight": loaded[0]["norm.weight"],
            "model.embed_tokens.weight": torch.cat(
                [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
            ),
            "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
        }
    # 遍历状态字典中的键值对,将键(参数名称)映射到文件名
    for k, v in state_dict.items():
        index_dict["weight_map"][k] = filename
        # 累加参数张量中元素的数量,计算模型参数总数
        param_count += v.numel()
    
    # 使用PyTorch保存模型参数到文件系统中
    torch.save(state_dict, os.path.join(tmp_model_path, filename))

    # 写入配置信息到索引字典中
    index_dict["metadata"] = {"total_size": param_count * 2}
    # 将索引字典以JSON格式写入到文件系统中
    write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
    
    # 根据参数中的配置,确定FFN维度的倍增器和倍数
    ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
    multiple_of = params["multiple_of"] if "multiple_of" in params else 256
    
    # 创建Llama模型的配置对象
    config = LlamaConfig(
        hidden_size=dim,
        intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
        num_key_value_heads=num_key_value_heads,
        vocab_size=vocab_size,
        rope_theta=base,
        max_position_embeddings=max_position_embeddings,
    )
    # 将配置保存到临时模型路径
    config.save_pretrained(tmp_model_path)

    # 释放不再需要的对象,清理内存
    del state_dict
    del loaded
    gc.collect()

    # 打印加载Llama模型检查点的消息
    print("Loading the checkpoint in a Llama model.")
    # 从预训练模型路径加载Llama模型,指定张量数据类型和低CPU内存使用
    model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
    
    # 避免将此项设置保存为配置的一部分
    del model.config._name_or_path
    # 将模型配置的张量数据类型设置为float16
    model.config.torch_dtype = torch.float16
    
    # 打印保存为Transformers格式的消息
    print("Saving in the Transformers format.")
    # 将Llama模型保存到指定的模型路径,进行安全序列化
    model.save_pretrained(model_path, safe_serialization=safe_serialization)
    
    # 递归删除临时模型路径及其内容
    shutil.rmtree(tmp_model_path)
# 主函数,程序的入口点
def main():
    # 创建参数解析器
    parser = argparse.ArgumentParser()
    # 添加命令行参数:输入目录,包含 LLAMA 权重文件,包括 tokenizer.model 和 model 文件夹
    parser.add_argument(
        "--input_dir",
        help="Location of LLaMA weights, which contains tokenizer.model and model folders",
    )
    # 添加命令行参数:模型大小,可选项为不同大小的 Llama 模型或仅令牌化器
    parser.add_argument(
        "--model_size",
        choices=["7B", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"],
        help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama",
    )
    # 添加命令行参数:输出目录,用于写入 HF 模型和令牌化器
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model and tokenizer",
    )
    # 添加命令行参数:安全序列化选项,指示是否使用 `safetensors` 进行保存
    parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
    # 添加命令行参数:LLAMA 版本,选择 1 或 2,用于控制上下文大小
    parser.add_argument(
        "--llama_version",
        choices=[1, 2],
        default=1,
        type=int,
        help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
    )
    # 解析命令行参数
    args = parser.parse_args()
    
    # 构造令牌化器模型文件路径
    spm_path = os.path.join(args.input_dir, "tokenizer.model")
    
    # 如果模型大小不是 "tokenizer_only",则调用写入模型函数
    if args.model_size != "tokenizer_only":
        write_model(
            model_path=args.output_dir,
            input_base_path=args.input_dir,
            model_size=args.model_size,
            safe_serialization=args.safe_serialization,
            tokenizer_path=spm_path,
            llama_version=args.llama_version,
        )
    else:
        # 否则,仅写入令牌化器
        write_tokenizer(args.output_dir, spm_path)


# 如果当前脚本作为主程序运行,则执行主函数
if __name__ == "__main__":
    main()

.\models\llama\modeling_flax_llama.py

# 引入必要的模块和库
from functools import partial  # 导入 functools 模块中的 partial 函数,用于创建带有部分参数的新函数
from typing import Optional, Tuple  # 导入 typing 模块中的 Optional 和 Tuple 类型,用于类型标注

import flax.linen as nn  # 导入 Flax 的 linen 模块,并用 nn 别名引用
import jax  # 导入 JAX 库,用于自动求导和并行计算
import jax.numpy as jnp  # 导入 JAX 库中的 numpy 模块,并用 jnp 别名引用
import numpy as np  # 导入 numpy 库,并用 np 别名引用
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # 导入 Flax 中的 FrozenDict 等相关函数
from flax.linen import combine_masks, make_causal_mask  # 导入 Flax 中的相关函数和类
from flax.linen.attention import dot_product_attention_weights  # 导入 Flax 中的 dot_product_attention_weights 函数
from flax.traverse_util import flatten_dict, unflatten_dict  # 导入 Flax 中的 flatten_dict 和 unflatten_dict 函数
from jax import lax  # 从 JAX 库中导入 lax 模块

# 导入模型相关的输出类和函数
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging

# 导入 LLaMA 模型的配置类
from .configuration_llama import LlamaConfig

# 获取 logger 对象,用于日志记录
logger = logging.get_logger(__name__)

# 用于文档的配置、检查点和真实检查点的字符串常量
_CONFIG_FOR_DOC = "LlamaConfig"
_CHECKPOINT_FOR_DOC = "afmck/testing-llama-tiny"
_REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_llama_3b_v2"

# LLaMA 模型的起始文档字符串,包含了模型的继承信息和特性说明
LLAMA_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
    # 参数:
    # config ([`LlamaConfig`]): 模型配置类,包含模型的所有参数。
    #     用配置文件初始化不会加载与模型关联的权重,仅加载配置。
    #     可以查看 [`~FlaxPreTrainedModel.from_pretrained`] 方法来加载模型权重。
    # dtype (`jax.numpy.dtype`, *可选*, 默认为 `jax.numpy.float32`):
    #     计算的数据类型。可以是 `jax.numpy.float32`, `jax.numpy.float16`, 或 `jax.numpy.bfloat16` 中的一种。
    # 
    #     这可用于在 GPU 或 TPU 上启用混合精度训练或半精度推断。如果指定,则所有计算将使用给定的 `dtype` 进行。
    # 
    #     **请注意,这仅指定计算的数据类型,不影响模型参数的数据类型。**
    # 
    #     如果您希望更改模型参数的数据类型,请参阅 [`~FlaxPreTrainedModel.to_fp16`] 和 [`~FlaxPreTrainedModel.to_bf16`]。
# 创建正弦位置编码矩阵,用于将位置索引映射为正弦波形式的向量表示
def create_sinusoidal_positions(num_pos, dim):
    # 计算正弦编码的频率逆频率
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    # 计算位置索引乘以频率得到的矩阵,每个维度都是浮点数
    freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
    # 按照最后一个维度将两个频率矩阵连接起来,形成最终的正弦位置编码矩阵
    emb = np.concatenate((freqs, freqs), axis=-1)
    # 将 emb 数组中的每个元素应用正弦函数,然后与对应元素应用余弦函数的结果拼接起来
    out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
    # 从拼接后的数组中取出前 num_pos 列,并转换为 JAX 数组格式返回
    return jnp.array(out[:, :, :num_pos])
# 定义一个函数,用于将输入张量的后一半隐藏维度旋转
def rotate_half(tensor):
    # 将张量按照其最后一个维度的一半进行拼接,实现旋转操作
    rotate_half_tensor = jnp.concatenate(
        (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
    )
    return rotate_half_tensor


# 定义一个函数,将旋转的位置嵌入应用到输入张量上
def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
    # 将输入张量乘以余弦位置编码,然后加上经过旋转半隐藏维度的正弦位置编码
    return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)
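一个最小示例(非源码)演示这两个函数:rotate_half 把向量的后一半取负后移到前面;当位置编码角度为 0(sin=0、cos=1)时,apply_rotary_pos_emb 不改变输入。

```
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x))  # [-3. -4.  1.  2.]

sin_pos = jnp.zeros(4)
cos_pos = jnp.ones(4)
print(apply_rotary_pos_emb(x, sin_pos, cos_pos))  # [1. 2. 3. 4.],与输入相同
```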


# 定义一个名为FlaxLlamaRMSNorm的类,继承自nn.Module
class FlaxLlamaRMSNorm(nn.Module):
    # 类的配置信息
    config: LlamaConfig
    dtype: jnp.dtype = jnp.float32

    # 设置方法,在类实例化时调用,用于初始化权重和其他参数
    def setup(self):
        # 设置 epsilon 参数为 RMS 归一化的小数值
        self.epsilon = self.config.rms_norm_eps
        # 初始化权重矩阵,形状为隐藏大小(hidden_size)
        self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size)

    # 类的调用方法,对隐藏状态进行处理
    def __call__(self, hidden_states):
        # 将隐藏状态转换为 jax 数组,并将数据类型设置为 jnp.float32
        variance = jnp.asarray(hidden_states, dtype=jnp.float32)
        # 对每个元素求平方
        variance = jnp.power(variance, 2)
        # 沿最后一个维度求平方的均值,即得到方差
        variance = variance.mean(-1, keepdims=True)
        # RMS 归一化:将隐藏状态除以 sqrt(方差 + epsilon)
        hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon)

        # 返回加权后的隐藏状态
        return self.weight * jnp.asarray(hidden_states, dtype=self.dtype)
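FlaxLlamaRMSNorm 的计算等价于 weight * x / sqrt(mean(x**2) + eps),下面是一个独立的数值示意(非源码):

```
import jax.numpy as jnp

def rms_norm(x, weight, eps=1e-6):
    variance = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
    return weight * (x / jnp.sqrt(variance + eps))

x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
print(rms_norm(x, jnp.ones(4)))  # 每个元素除以 RMS = sqrt((1+4+9+16)/4) ≈ 2.739
```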


# 定义一个名为FlaxLlamaRotaryEmbedding的类,继承自nn.Module
class FlaxLlamaRotaryEmbedding(nn.Module):
    # 类的配置信息
    config: LlamaConfig
    dtype: jnp.dtype = jnp.float32

    # 设置方法,在类实例化时调用,用于初始化位置编码
    def setup(self):
        # 计算每个注意力头的维度
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        # 创建正弦余弦位置编码
        self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim)

    # 类的调用方法,将位置编码应用到键、查询和位置ID上
    def __call__(self, key, query, position_ids):
        # 获取指定位置ID的正弦余弦位置编码
        sincos = self.sincos[position_ids]
        sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1)

        # 将正弦余弦位置编码应用到键和查询上
        key = apply_rotary_pos_emb(key, sin_pos, cos_pos)
        query = apply_rotary_pos_emb(query, sin_pos, cos_pos)

        # 将键和查询转换为 jax 数组,并将数据类型设置为 self.dtype
        key = jnp.asarray(key, dtype=self.dtype)
        query = jnp.asarray(query, dtype=self.dtype)

        # 返回处理后的键和查询
        return key, query


# 定义一个名为FlaxLlamaAttention的类,继承自nn.Module
class FlaxLlamaAttention(nn.Module):
    # 类的配置信息
    config: LlamaConfig
    dtype: jnp.dtype = jnp.float32
    causal: bool = True
    is_cross_attention: bool = False

    # 设置方法,在类实例化时调用,用于初始化注意力机制的参数
    def setup(self):
        # 从配置中获取参数
        config = self.config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.attention_softmax_in_fp32 = self.dtype is not jnp.float32

        # 创建偏置注意力层
        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=config.attention_bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        # 初始化查询、键、值和输出投影层
        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.o_proj = dense()

        # 创建因果遮罩
        self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
        # 创建旋转嵌入层
        self.rotary_emb = FlaxLlamaRotaryEmbedding(config, dtype=self.dtype)

    # 内部方法,用于将隐藏状态分割为多个注意力头
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
    def _merge_heads(self, hidden_states):
        # 将输入的 hidden_states 重塑成形状为 (batch_size, sequence_length, self.embed_dim) 的张量,并返回
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    @nn.compact
    # 从 transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache 复制而来
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slighly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # 检测是否通过缺少现有缓存数据进行初始化
        is_initialized = self.has_variable("cache", "cached_key")
        # 如果未初始化,则创建形状和类型与 key 相同的零张量作为 cached_key
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        # 如果未初始化,则创建形状和类型与 value 相同的零张量作为 cached_value
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        # 如果未初始化,则创建初始值为 0 的 cache_index
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            # 获取当前缓存张量的形状信息
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # 使用新的 1 维空间切片更新 key、value 缓存
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            # 更新 cache_index 值,增加已更新的缓存向量数量
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # 生成用于缓存的因果掩码:我们的单个查询位置只应关注已生成和缓存的键位置,而不是剩余的零元素
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            # 将 pad_mask 与 attention_mask 结合
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        ):
            # 使用投影函数计算查询向量
            query = self.q_proj(hidden_states)
            # 使用投影函数计算键向量
            key = self.k_proj(hidden_states)
            # 使用投影函数计算值向量
            value = self.v_proj(hidden_states)

            # 将查询向量分割成多个头
            query = self._split_heads(query)
            # 将键向量分割成多个头
            key = self._split_heads(key)
            # 将值向量分割成多个头
            value = self._split_heads(value)

            # 应用旋转位置编码器到键和查询向量
            key, query = self.rotary_emb(key, query, position_ids)

            # 获取查询向量和键向量的长度
            query_length, key_length = query.shape[1], key.shape[1]

            # 构建因果掩码
            if self.has_variable("cache", "cached_key"):
                # 如果有缓存的键,根据缓存索引和最大解码器长度动态切片因果掩码
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                # 否则,使用静态切片获取因果掩码
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]

            # 获取批次大小
            batch_size = hidden_states.shape[0]
            # 将因果掩码广播到与查询向量和键向量匹配的形状
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

            # 广播注意力掩码以匹配因果掩码的形状
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            # 结合注意力掩码和因果掩码
            attention_mask = combine_masks(attention_mask, causal_mask)

            # 初始化 dropout RNG
            dropout_rng = None
            if not deterministic and self.config.attention_dropout > 0.0:
                dropout_rng = self.make_rng("dropout")

            # 在快速自回归解码期间,逐步一次性输入一个位置,逐步缓存键和值。
            if self.has_variable("cache", "cached_key") or init_cache:
                key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

            # 将布尔掩码转换为浮点数掩码
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )

            # 标准点积注意力
            attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
            attn_weights = dot_product_attention_weights(
                query,
                key,
                bias=attention_bias,
                dropout_rng=dropout_rng,
                dropout_rate=self.config.attention_dropout,
                deterministic=deterministic,
                dtype=attention_dtype,
            )

            # 如果需要,将注意力权重转换为指定的数据类型
            if self.attention_softmax_in_fp32:
                attn_weights = attn_weights.astype(self.dtype)

            # 使用注意力权重计算注意力输出
            attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
            # 合并多头得到的注意力输出
            attn_output = self._merge_heads(attn_output)
            # 应用输出投影层
            attn_output = self.o_proj(attn_output)

            # 准备输出,包括注意力输出和注意力权重(如果需要)
            outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
            return outputs
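
# Illustrative sketch (added annotation, not part of the original file): what `_concatenate_to_cache`
# does, reproduced with plain jax.numpy on toy shapes. All names and shapes below are made up for the
# example.
import jax.numpy as jnp
from jax import lax

batch, max_length, num_heads, head_dim = 1, 8, 2, 4
cached_key = jnp.zeros((batch, max_length, num_heads, head_dim))

cur_index = 2                                         # two positions are already cached
new_key = jnp.ones((batch, 1, num_heads, head_dim))   # key of the single token being decoded

# Write the new key into the pre-allocated cache slot at `cur_index`.
cached_key = lax.dynamic_update_slice(cached_key, new_key, (0, cur_index, 0, 0))

# The query may only attend to already-filled positions (< cur_index + 1); later slots are still zeros.
num_updated = new_key.shape[1]
pad_mask = jnp.broadcast_to(
    jnp.arange(max_length) < cur_index + num_updated,
    (batch, 1, num_updated, max_length),
)
print(pad_mask.astype(jnp.int32))  # [[[[1 1 1 0 0 0 0 0]]]]

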
class FlaxLlamaMLP(nn.Module):
    config: LlamaConfig  # 类型注解:指定该类的配置信息来自于LlamaConfig类
    dtype: jnp.dtype = jnp.float32  # 类型注解:指定数据类型为jnp.float32,默认为浮点数类型

    def setup(self):
        embed_dim = self.config.hidden_size  # 从配置中获取隐藏层大小作为嵌入维度
        inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
        # 计算内部层维度,如果配置中有中间大小定义则使用,否则使用默认值4倍的嵌入维度

        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
        # 使用正态分布初始化器初始化核参数,范围由配置的initializer_range定义
        self.act = ACT2FN[self.config.hidden_act]
        # 从ACT2FN字典中获取激活函数,并存储在act属性中,其类型由配置的hidden_act指定

        self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
        # 创建具有inner_dim大小的全连接层,不使用偏置,使用上述初始化器初始化权重
        self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
        # 创建具有embed_dim大小的全连接层,不使用偏置,使用上述初始化器初始化权重
        self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
        # 创建具有inner_dim大小的全连接层,不使用偏置,使用上述初始化器初始化权重

    def __call__(self, hidden_states):
        up_proj_states = self.up_proj(hidden_states)
        # 使用up_proj层处理输入的隐藏状态
        gate_states = self.act(self.gate_proj(hidden_states))
        # 使用激活函数act处理gate_proj层处理后的隐藏状态

        hidden_states = self.down_proj(up_proj_states * gate_states)
        # 使用down_proj层处理up_proj_states与gate_states的乘积,并将结果存储在隐藏状态中
        return hidden_states
        # 返回处理后的隐藏状态作为结果
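
# A minimal sketch (added annotation, not from the original file) of the gated feed-forward computed by
# FlaxLlamaMLP, written with plain jax.numpy and toy weights. With hidden_act="silu" this is the
# SwiGLU-style MLP: down_proj(silu(gate_proj(x)) * up_proj(x)).
import jax
import jax.numpy as jnp

embed_dim, inner_dim = 8, 32
k1, k2, k3, kx = jax.random.split(jax.random.PRNGKey(0), 4)
w_gate = 0.02 * jax.random.normal(k1, (embed_dim, inner_dim))
w_up = 0.02 * jax.random.normal(k2, (embed_dim, inner_dim))
w_down = 0.02 * jax.random.normal(k3, (inner_dim, embed_dim))

x = jax.random.normal(kx, (2, 5, embed_dim))   # (batch, seq_len, hidden_size)
gate = jax.nn.silu(x @ w_gate)                 # activation is applied to the gate branch only
up = x @ w_up
out = (up * gate) @ w_down                     # same ordering as FlaxLlamaMLP.__call__ above
print(out.shape)                               # (2, 5, 8)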


class FlaxLlamaDecoderLayer(nn.Module):
    config: LlamaConfig  # 类型注解:指定该类的配置信息来自于LlamaConfig类
    dtype: jnp.dtype = jnp.float32  # 类型注解:指定数据类型为jnp.float32,默认为浮点数类型

    def setup(self):
        self.input_layernorm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype)
        # 创建一个使用LlamaConfig和指定数据类型的FlaxLlamaRMSNorm实例,存储在input_layernorm属性中
        self.self_attn = FlaxLlamaAttention(self.config, dtype=self.dtype)
        # 创建一个使用LlamaConfig和指定数据类型的FlaxLlamaAttention实例,存储在self_attn属性中
        self.post_attention_layernorm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype)
        # 创建一个使用LlamaConfig和指定数据类型的FlaxLlamaRMSNorm实例,存储在post_attention_layernorm属性中
        self.mlp = FlaxLlamaMLP(self.config, dtype=self.dtype)
        # 创建一个使用LlamaConfig和指定数据类型的FlaxLlamaMLP实例,存储在mlp属性中

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        residual = hidden_states
        # 将输入的隐藏状态存储在变量residual中,用于残差连接
        hidden_states = self.input_layernorm(hidden_states)
        # 使用input_layernorm对隐藏状态进行规范化处理

        outputs = self.self_attn(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
        )
        # 使用self_attn处理规范化后的隐藏状态,传递额外参数attention_mask、position_ids等,并将结果存储在outputs中

        attn_output = outputs[0]
        # 从outputs中获取注意力机制的输出
        hidden_states = residual + attn_output
        # 将residual与注意力输出相加得到新的隐藏状态

        residual = hidden_states
        # 将新的隐藏状态存储在变量residual中,用于下一步的残差连接
        hidden_states = self.post_attention_layernorm(hidden_states)
        # 使用post_attention_layernorm对新的隐藏状态进行规范化处理
        hidden_states = self.mlp(hidden_states)
        # 使用mlp处理规范化后的隐藏状态,得到最终的输出

        hidden_states = residual + hidden_states
        # 将残差连接的结果与MLP处理后的隐藏状态相加,作为最终的输出

        return (hidden_states,) + outputs[1:]
        # 返回包含最终输出和outputs中其他项的元组
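
# Structural sketch (added annotation, not part of the original file): the pre-norm residual pattern
# implemented by FlaxLlamaDecoderLayer above, written as a plain function over generic callables.
def _pre_norm_block_sketch(x, norm1, attn, norm2, mlp):
    x = x + attn(norm1(x))   # self-attention sub-layer, residual added around it
    x = x + mlp(norm2(x))    # feed-forward sub-layer, residual added around it
    return x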


# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Llama, GPT_NEO->LLAMA, transformer->model
class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = LlamaConfig
    # 指定配置类为LlamaConfig
    base_model_prefix = "model"
    # 指定基础模型前缀为"model"
    module_class: nn.Module = None
    # 指定模块类为nn.Module,初始值为None
    def __init__(
        self,
        config: LlamaConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # 使用给定的配置和参数初始化模块对象
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # 调用父类的初始化方法,传入配置、模块对象以及其他参数
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 初始化输入张量
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # 创建与 input_ids 相同形状的全 1 张量作为 attention_mask
        attention_mask = jnp.ones_like(input_ids)
        # 根据 input_ids 的维度生成位置编码
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        # 利用输入的随机种子分割出两个随机数生成器
        params_rng, dropout_rng = jax.random.split(rng)
        # 将随机数生成器存入字典
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # 利用模块的初始化方法初始化参数
        random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]

        # 如果提供了额外的参数,则将随机初始化的参数与提供的参数进行合并
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                fast auto-regressive decoding 使用的批大小。定义初始化缓存的批大小。
            max_length (`int`):
                auto-regressive decoding 的最大可能长度。定义初始化缓存的序列长度。
        """
        # 初始化输入变量以检索缓存
        input_ids = jnp.ones((batch_size, max_length))
        # 创建与 input_ids 相同形状的全 1 张量作为 attention_mask
        attention_mask = jnp.ones_like(input_ids)
        # 根据 input_ids 的形状生成位置编码
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # 利用模块的初始化方法初始化变量,并指定初始化缓存
        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # 如果没有显式提供输出注意力的设置,则使用配置中的默认值
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 如果没有显式提供输出隐藏状态的设置,则使用配置中的默认值
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果没有显式提供返回字典的设置,则使用配置中的默认值
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # 获取输入张量的批量大小和序列长度
        batch_size, sequence_length = input_ids.shape

        # 如果未提供位置编码,则根据序列长度创建默认位置编码
        if position_ids is None:
            # 如果传入了过去的键值(past_key_values),则需要明确提供位置编码,否则抛出异常
            if past_key_values is not None:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
            
            # 使用广播操作将序列长度范围内的数组扩展为指定批次大小的位置编码张量
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # 如果未提供注意力遮罩,则创建全1的注意力遮罩张量
        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # 处理任何需要的伪随机数生成器
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # 准备输入参数字典,包括模型参数或者传入的参数
        inputs = {"params": params or self.params}

        # 如果传入了过去的键值(past_key_values),则将其作为缓存传递给模型,确保缓存是可变的以便后续更新
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        # 应用模型的正向传播,传递所有必要的输入张量和设置
        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # 如果传入了过去的键值(past_key_values)并且需要返回字典,则将更新后的缓存添加到模型输出中
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        # 如果传入了过去的键值(past_key_values)但不需要返回字典,则将更新后的缓存添加到模型输出元组中
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        # 返回模型的输出结果
        return outputs
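
# Usage sketch (added annotation, not part of the original file) of the caching interface defined above.
# The checkpoint name is only an assumption for illustration; any Flax-compatible Llama checkpoint works:
#
#   import numpy as np
#   from transformers import AutoTokenizer, FlaxLlamaForCausalLM
#
#   tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b_v2")
#   model = FlaxLlamaForCausalLM.from_pretrained("openlm-research/open_llama_3b_v2")
#   inputs = tokenizer("Hello", return_tensors="np")
#
#   past_key_values = model.init_cache(batch_size=1, max_length=32)
#   # Note: when `past_key_values` is passed, `position_ids` must be provided explicitly (see __call__ above).
#   position_ids = np.arange(inputs["input_ids"].shape[-1])[None, :]
#   outputs = model(**inputs, past_key_values=past_key_values, position_ids=position_ids)
#   # outputs["past_key_values"] now holds the updated cache for the next decoding step.
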
class FlaxLlamaLayerCollection(nn.Module):
    # LlamaConfig 类型的配置信息
    config: LlamaConfig
    # 默认数据类型为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 初始化方法
    def setup(self):
        # 创建一系列 FlaxLlamaDecoderLayer 对象并存储在 self.blocks 中
        self.blocks = [
            FlaxLlamaDecoderLayer(self.config, dtype=self.dtype, name=str(i))
            for i in range(self.config.num_hidden_layers)
        ]

    # 调用实例时执行的方法
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
    ):
        # 如果输出注意力矩阵,则初始化空的元组 all_attentions
        all_attentions = () if output_attentions else None
        # 如果输出隐藏状态,则初始化空的元组 all_hidden_states
        all_hidden_states = () if output_hidden_states else None

        # 遍历 self.blocks 中的每个 FlaxLlamaDecoderLayer 对象
        for block in self.blocks:
            # 如果输出隐藏状态,则将当前隐藏状态 hidden_states 添加到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # 调用 block 对象,计算层的输出结果 layer_outputs
            layer_outputs = block(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
            )
            # 更新 hidden_states 为当前层的输出结果中的第一个元素
            hidden_states = layer_outputs[0]

            # 如果输出注意力矩阵,则将当前层的注意力矩阵添加到 all_attentions 中
            if output_attentions:
                all_attentions += (layer_outputs[1],)

        # 输出结果包括 hidden_states, all_hidden_states, all_attentions
        # 注意:all_hidden_states 和 all_attentions 可能包含 None 值,由 FlaxLlamaModule 进行过滤处理
        outputs = (hidden_states, all_hidden_states, all_attentions)

        return outputs


class FlaxLlamaModule(nn.Module):
    # LlamaConfig 类型的配置信息
    config: LlamaConfig
    # 默认数据类型为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 初始化方法
    def setup(self):
        # 设置隐藏大小为 config 中的隐藏大小
        self.hidden_size = self.config.hidden_size
        # 使用正态分布初始化 embed_tokens 层,存储在 self.embed_tokens 中
        embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range)
        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            self.hidden_size,
            embedding_init=embedding_init,
            dtype=self.dtype,
        )
        # 创建 FlaxLlamaLayerCollection 对象并存储在 self.layers 中
        self.layers = FlaxLlamaLayerCollection(self.config, dtype=self.dtype)
        # 创建 FlaxLlamaRMSNorm 对象并存储在 self.norm 中
        self.norm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype)

    # 调用实例时执行的方法
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Embed the input token IDs (cast to int32)
        input_embeds = self.embed_tokens(input_ids.astype("i4"))

        # Run the embeddings through the stack of decoder layers
        outputs = self.layers(
            input_embeds,
            position_ids=position_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The first element of the layer outputs is the last hidden state
        hidden_states = outputs[0]

        # Apply the final RMS norm
        hidden_states = self.norm(hidden_states)

        # If all hidden states are requested, append the normalized last hidden state
        if output_hidden_states:
            all_hidden_states = outputs[1] + (hidden_states,)
            outputs = (hidden_states, all_hidden_states) + outputs[2:]
        else:
            outputs = (hidden_states,) + outputs[1:]

        # Without return_dict, return a tuple of all non-None values
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        # Otherwise wrap the result in a FlaxBaseModelOutput
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )
# 添加起始文档字符串和元数据到 FlaxLlamaModel 类,说明它是一个裸 Llama 模型变换器,输出原始隐藏状态,没有特定的顶部头部。
# 使用 LLAMA_START_DOCSTRING 定义的起始文档字符串作为补充信息。
@add_start_docstrings(
    "The bare Llama Model transformer outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class FlaxLlamaModel(FlaxLlamaPreTrainedModel):
    module_class = FlaxLlamaModule


# 向 FlaxLlamaModel 类添加调用示例的文档字符串
append_call_sample_docstring(
    FlaxLlamaModel,
    _CHECKPOINT_FOR_DOC,
    FlaxBaseModelOutput,
    _CONFIG_FOR_DOC,
    real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)


# 定义 FlaxLlamaForCausalLMModule 类,用于支持因果语言建模任务
class FlaxLlamaForCausalLMModule(nn.Module):
    # 模块配置参数
    config: LlamaConfig
    # 数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # 使用给定的配置参数和数据类型创建 Llama 模型
        self.model = FlaxLlamaModule(self.config, dtype=self.dtype)
        # 创建语言建模头部,一个全连接层,用于生成词汇表大小的输出
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

    # 定义模块的调用方法
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 调用 Llama 模型来处理输入序列
        outputs = self.model(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取模型的隐藏状态
        hidden_states = outputs[0]
        # 使用语言建模头部生成最终的语言建模输出
        lm_logits = self.lm_head(hidden_states)

        # 如果不要求返回字典格式的输出,则返回元组形式的输出
        if not return_dict:
            return (lm_logits,) + outputs[1:]

        # 返回格式化后的因果语言建模输出
        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)


# 向 FlaxLlamaForCausalLM 类添加起始文档字符串,说明它是带有语言建模头部的 Llama 模型变换器
@add_start_docstrings(
    """
    The Llama Model transformer with a language modeling head (linear layer) on top.
    """,
    LLAMA_START_DOCSTRING,
)
# 从 transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM 复制到 FlaxLlamaForCausalLM,
# 并将其中的 GPTJ 替换为 Llama
class FlaxLlamaForCausalLM(FlaxLlamaPreTrainedModel):
    module_class = FlaxLlamaForCausalLMModule
    # 为生成准备输入数据的方法
    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # 初始化缓存
        batch_size, seq_length = input_ids.shape

        # 使用模型的初始化方法创建缓存
        past_key_values = self.init_cache(batch_size, max_length)

        # 注意:通常需要在 attention_mask 中对超出 input_ids.shape[-1] 和 cache_length 之外的位置置为 0。
        # 但由于 Llama 使用因果注意力机制,这些位置已经被掩码处理。
        # 因此,在这里我们可以创建一个单一的静态 attention_mask,这样更高效地进行编译。
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        
        # 如果有传入 attention_mask,则根据它计算 position_ids
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            # 否则,根据序列长度广播创建 position_ids
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        # 返回生成所需的输入数据字典
        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    # 更新生成过程中的输入数据的方法
    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # 更新模型关键值缓存和 position_ids
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        
        # 返回更新后的输入数据字典
        return model_kwargs
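
# Note (added annotation, not part of the original file): these two hooks are what the Flax generation
# mixin calls around each decoding step; in practice one simply uses `model.generate`, e.g.
#
#   out = model.generate(inputs["input_ids"], max_length=20)
#   # out.sequences holds the generated token ids
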
# 将样本文档字符串附加到指定类中的方法上
append_call_sample_docstring(
    # 目标类:FlaxLlamaForCausalLM,用于添加文档字符串
    FlaxLlamaForCausalLM,
    # 用于文档的检查点对象的名称或引用:_CHECKPOINT_FOR_DOC
    _CHECKPOINT_FOR_DOC,
    # 生成的文档字符串应描述的输出对象类型:FlaxCausalLMOutput
    FlaxCausalLMOutput,
    # 用于文档的配置对象的名称或引用:_CONFIG_FOR_DOC
    _CONFIG_FOR_DOC,
    # 实际使用的检查点对象的名称或引用:_REAL_CHECKPOINT_FOR_DOC
    real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)

.\models\llama\modeling_llama.py

# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMA model."""

# 引入数学库和警告库
import math
import warnings
# 引入类型提示相关的模块
from typing import List, Optional, Tuple, Union

# 引入PyTorch相关的模块
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
# 引入PyTorch的神经网络模块
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# 引入各种工具函数和模型输出相关的模块
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
# 引入LLaMA配置模块
from .configuration_llama import LlamaConfig

# 检查是否可用新的注意力机制库
if is_flash_attn_2_available():
    # 如果可用,引入相关函数
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

# 获取日志记录器
logger = logging.get_logger(__name__)

# 文档使用的配置名称
_CONFIG_FOR_DOC = "LlamaConfig"

# 辅助函数:获取未填充数据
def _get_unpad_data(attention_mask):
    # 计算每个序列在批次中的长度
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # 找出attention_mask中为1的位置
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # 计算批次中最大序列长度
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # 计算累积的序列长度
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # 返回结果
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
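
# Worked example (added annotation, not part of the original file): for a batch with sequence lengths
# 3 and 1 padded to length 4,
#
#   mask = torch.tensor([[1, 1, 1, 0],
#                        [1, 0, 0, 0]])
#   indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
#   # indices    -> tensor([0, 1, 2, 4])       positions of non-padding tokens in the flattened batch
#   # cu_seqlens -> tensor([0, 3, 4], ...)     cumulative sequence lengths, as used by flash-attn
#   # max_seqlen -> 3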

# LLaMA模型的RMS归一化层
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        # 初始化权重参数
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # 定义方差的小量值
        self.variance_epsilon = eps
    # 定义一个前向传播方法,接受隐藏状态作为输入
    def forward(self, hidden_states):
        # 获取输入张量的数据类型
        input_dtype = hidden_states.dtype
        # 将隐藏状态张量转换为 float32 数据类型
        hidden_states = hidden_states.to(torch.float32)
        # 计算隐藏状态张量每个元素的平方,并沿着最后一个维度求均值,保持维度
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # 对隐藏状态张量进行归一化处理,使用倒数平方根公式
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # 返回归一化后的隐藏状态张量乘以权重张量
        return self.weight * hidden_states
# 将 LlamaRMSNorm 类添加到 ALL_LAYERNORM_LAYERS 列表中
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
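
# Numerical illustration (added annotation, not part of the original file) of the RMSNorm computation:
# y = weight * x / sqrt(mean(x**2, dim=-1) + eps), i.e. no mean subtraction and no bias, unlike LayerNorm.
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
eps = 1e-6
rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
weight = torch.ones(4)            # freshly initialized LlamaRMSNorm weight
print(weight * (x * rms))         # tensor([[0.3651, 0.7303, 1.0954, 1.4606]])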

# 定义 LlamaRotaryEmbedding 类,继承自 nn.Module
class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # 计算频率的倒数,用于位置编码
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        # 将 inv_freq 注册为不可训练的缓冲区
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # 缓存最大序列长度
        self.max_seq_len_cached = max_position_embeddings
        # 创建位置编码的张量 t,并根据缩放因子调整
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        t = t / self.scaling_factor
        # 计算频率矩阵 freqs,并进行拼接以生成位置嵌入 emb
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # 将 cos 和 sin 值缓存起来,注册为不可训练的缓冲区
        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)

    @property
    def sin_cached(self):
        # 警告:sin_cached 属性将在 4.39 版本中移除,建议使用 RoPE 的 forward 方法代替
        logger.warning_once(
            "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
        )
        return self._sin_cached

    @property
    def cos_cached(self):
        # 警告:cos_cached 属性将在 4.39 版本中移除,建议使用 RoPE 的 forward 方法代替
        logger.warning_once(
            "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
        )
        return self._cos_cached

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # 扩展 inv_freq 和 position_ids 的维度,以便进行矩阵乘法
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # 在关闭自动混合精度的情况下,计算频率并计算 cos 和 sin
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        # 返回计算得到的 cos 和 sin,转换为与 x 相同的数据类型
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding 扩展,添加了线性缩放。感谢 Reddit 用户 /u/kaiokendev 的贡献。"""
    # 定义一个方法 `forward`,接收输入 `x` 和位置标识 `position_ids`
    def forward(self, x, position_ids):
        # 将位置标识转换为浮点数,并应用缩放因子,以调整位置标识的范围
        position_ids = position_ids.float() / self.scaling_factor
        # 调用父类的 `forward` 方法,传入 `x` 和调整后的位置标识 `position_ids`
        cos, sin = super().forward(x, position_ids)
        # 返回计算得到的余弦和正弦值
        return cos, sin
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def forward(self, x, position_ids):
        # 计算序列的长度,找到最大的位置 ID 并加 1
        seq_len = torch.max(position_ids) + 1
        # 如果序列长度超过了最大位置嵌入的设定值
        if seq_len > self.max_position_embeddings:
            # 计算基础值,考虑动态的 NTK 缩放因子
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            # 计算新的频率倒数张量
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
            )
            # 将频率倒数张量注册为缓冲区,以便在模型运行中使用,不会被视为模型的参数
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

        # 调用父类的 forward 方法计算余弦和正弦部分
        cos, sin = super().forward(x, position_ids)
        # 返回余弦和正弦部分作为输出
        return cos, sin
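
# Quick check (added annotation, not part of the original file) of the dynamic-NTK base adjustment above,
# using made-up hyperparameters: once the sequence grows past `max_position_embeddings`, the base (and
# hence the wavelengths of the rotary frequencies) is enlarged.
import torch

base, dim, max_pos, scaling_factor = 10000.0, 128, 4096, 2.0
seq_len = 8192  # longer than max_pos
new_base = base * ((scaling_factor * seq_len / max_pos) - (scaling_factor - 1)) ** (dim / (dim - 2))
inv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(new_base)         # ≈ 3.05e4, roughly 3x the original base of 1e4
print(inv_freq.shape)   # torch.Size([64])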


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    # 将输入张量的一半维度旋转180度
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    # 拼接旋转后的两部分张量
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # 在指定的维度上对余弦和正弦部分进行展开
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    # 对查询张量和键张量应用旋转位置嵌入
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    # 返回旋转后的查询张量和键张量作为结果
    return q_embed, k_embed
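
# Tiny numerical example (added annotation, not part of the original file) of the two helpers above.
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
x1, x2 = x[: x.shape[-1] // 2], x[x.shape[-1] // 2 :]
print(torch.cat((-x2, x1)))     # tensor([-3., -4.,  1.,  2.])  == rotate_half(x)

# Rotating a single 4-dim query by an angle of 0.5 rad: q_embed = q*cos + rotate_half(q)*sin
theta = torch.tensor(0.5)
q = torch.tensor([1.0, 0.0, 1.0, 0.0])
q_embed = q * theta.cos() + torch.cat((-q[2:], q[:2])) * theta.sin()
print(q_embed)                  # tensor([0.3982, 0.0000, 1.3570, 0.0000])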


class LlamaMLP(nn.Module):
    # 初始化方法,接受一个配置对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 将配置对象保存为实例变量
        self.config = config
        # 从配置对象中获取隐藏层大小并保存为实例变量
        self.hidden_size = config.hidden_size
        # 从配置对象中获取中间层大小并保存为实例变量
        self.intermediate_size = config.intermediate_size
        # 创建一个线性变换层,将隐藏层大小映射到中间层大小,没有偏置项
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # 创建一个线性变换层,将隐藏层大小映射到中间层大小,没有偏置项
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # 创建一个线性变换层,将中间层大小映射回隐藏层大小,没有偏置项
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # 根据配置中的激活函数名称从预定义的映射中获取对应的激活函数
        self.act_fn = ACT2FN[config.hidden_act]

    # 前向传播方法,接受输入张量 x 作为参数
    def forward(self, x):
        # 如果预训练类型大于 1
        if self.config.pretraining_tp > 1:
            # 计算每个分片的大小
            slice = self.intermediate_size // self.config.pretraining_tp
            # 将gate_proj权重分片
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            # 将up_proj权重分片
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            # 将down_proj权重分片
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            # 对输入张量 x 执行多个线性变换,然后拼接在一起,形成 gate_proj 的结果
            gate_proj = torch.cat(
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
            )
            # 对输入张量 x 执行多个线性变换,然后拼接在一起,形成 up_proj 的结果
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

            # 对 gate_proj 的结果应用激活函数,并与 up_proj 相乘,然后按照 slice 进行分片
            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
            # 对每个分片应用 down_proj 的线性变换,然后将结果相加
            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)
        else:
            # 如果预训练类型不大于 1,直接计算 down_proj 的结果
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        # 返回 down_proj 结果作为前向传播的输出
        return down_proj
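
# Sanity check (added annotation, not part of the original file) for the `pretraining_tp > 1` branch
# above: splitting a projection weight along its output dimension and concatenating the partial results
# reproduces the single full matmul.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 5, 16)                       # (batch, seq_len, hidden_size)
weight = torch.randn(64, 16)                    # full (intermediate_size, hidden_size) weight
tp = 4
shards = weight.split(64 // tp, dim=0)          # shard along the output dimension
sharded_out = torch.cat([F.linear(x, w) for w in shards], dim=-1)
print(torch.allclose(sharded_out, F.linear(x, weight), atol=1e-6))  # True
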
# 定义一个函数 repeat_kv,用于复制输入张量的内容。这相当于 torch.repeat_interleave(x, dim=1, repeats=n_rep) 的功能。
# 输入参数 hidden_states 是一个四维张量,表示隐藏状态,维度为(batch, num_key_value_heads, seqlen, head_dim)。
# n_rep 是重复复制的次数。
# 函数返回一个张量,将隐藏状态从 (batch, num_key_value_heads, seqlen, head_dim) 转换为 (batch, num_attention_heads, seqlen, head_dim)。

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    # 获取输入张量的维度信息
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    
    # 如果 n_rep 等于 1,则直接返回原始的隐藏状态张量
    if n_rep == 1:
        return hidden_states
    
    # 将隐藏状态张量扩展为新的形状,以便复制内容
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    
    # 将扩展后的张量重新整形为所需的形状,即 (batch, num_attention_heads, seqlen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
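
# Quick check (added annotation, not part of the original file): the expand + reshape above is the same
# as torch.repeat_interleave along the head dimension.
import torch

kv = torch.randn(2, 4, 7, 8)      # (batch, num_key_value_heads, seq_len, head_dim)
n_rep = 3
out = kv[:, :, None, :, :].expand(2, 4, n_rep, 7, 8).reshape(2, 4 * n_rep, 7, 8)
print(out.shape)                                                        # torch.Size([2, 12, 7, 8])
print(torch.equal(out, torch.repeat_interleave(kv, n_rep, dim=1)))      # True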


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        # 初始化 LlamaAttention 类的属性
        self.config = config
        self.layer_idx = layer_idx
        
        # 如果未提供 layer_idx,发出警告,因为在使用缓存时可能导致前向调用错误
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        # 设置注意力机制的相关参数
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        # 检查 hidden_size 是否可以被 num_heads 整除,否则抛出 ValueError
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # 初始化线性变换层,用于查询、键、值和输出的投影
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        
        # 初始化相关的参数和设置
        self._init_rope()
    # 初始化 RoPE(Rotary Positional Embedding)模块
    def _init_rope(self):
        # 检查是否配置了 RoPE 的缩放参数
        if self.config.rope_scaling is None:
            # 如果未配置缩放参数,则使用默认的 LlamaRotaryEmbedding
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            # 如果配置了缩放参数,则根据类型选择相应的 RoPE 实现
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                # 使用线性缩放的 RoPE 实现
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                # 使用动态 NTK 缩放的 RoPE 实现
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                # 抛出异常,提示未知的 RoPE 缩放类型
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    # 前向传播函数定义,接受输入的张量和可选的参数
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        ...  # (the body of `forward` is omitted in this excerpt)


class LlamaFlashAttention2(LlamaAttention):
    """
    Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        # 标志变量用于处理 Flash Attention 版本 2.1 以下的兼容性问题,此版本的 flash_attn 生成左上角对齐的因果蒙版,而本模块需要右下角对齐的默认行为。参考:https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0。
        # 注意,在 flash_attn<2.1 的情况下,如果 q_seqlen != k_seqlen(除了 q_seqlen == 1 的情况),会产生错误的蒙版(左上角)。
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        # 前向传播函数,用于计算注意力机制
        # hidden_states: 输入的隐藏状态张量
        # attention_mask: 可选的注意力蒙版张量,默认为 None
        # position_ids: 可选的位置 ID 张量,默认为 None
        # past_key_value: 可选的缓存键值对,默认为 None
        # output_attentions: 是否输出注意力权重,默认为 False
        # use_cache: 是否使用缓存,默认为 False
        # cache_position: 可选的缓存位置张量,默认为 None
        # **kwargs: 其他关键字参数

    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        # 内部函数,执行 Flash Attention 的前向传播
        # query_states: 查询状态张量
        # key_states: 键状态张量
        # value_states: 值状态张量
        # attention_mask: 注意力蒙版张量
        # query_length: 查询长度
        # dropout: dropout 概率,默认为 0.0
        # softmax_scale: softmax 缩放参数,可选
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        # Determine if the attention mechanism should be causal based on configuration and query length
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Check if there are any padding tokens in the input sequence
        if attention_mask is not None:
            # Retrieve batch size from query states tensor
            batch_size = query_states.shape[0]
            # Unpad input states based on attention mask and query length
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            # Extract lengths of effective sequences after unpadding
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            # Perform variable-length Flash Attention calculation
            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            # Pad the attention output to match original sequence lengths
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            # Perform standard Flash Attention calculation without padding
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        # Return the final attention output
        return attn_output
    # 定义一个方法来处理无需填充的输入数据,根据输入的query_layer、key_layer、value_layer、attention_mask和query_length参数
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        # 调用_get_unpad_data函数获取未填充数据的索引、当前序列长度和批次中的最大序列长度
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        # 获取key_layer的形状信息:批次大小、键值对序列长度、键值头的数量和头维度
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        # 根据索引重新排列key_layer,以便按第一个轴索引重新组织
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        # 根据索引重新排列value_layer,以便按第一个轴索引重新组织
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )

        # 如果query_length等于kv_seq_len,则按索引重新排列query_layer,并更新相关变量
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        # 如果query_length等于1,则设置max_seqlen_in_batch_q为1,cu_seqlens_q为从0到batch_size+1的整数,indices_q为cu_seqlens_q的前一部分
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # 这里有一个memcpy操作,效率很差。
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # 如果以上条件都不满足,则假设左填充,并截取attention_mask的后-query_length列
            attention_mask = attention_mask[:, -query_length:]
            # 调用unpad_input函数,根据query_layer和截取后的attention_mask获取unpad后的输入数据
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        # 返回更新后的query_layer、key_layer、value_layer、indices_q、cu_seqlens_q和max_seqlen_in_batch_q
        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
class LlamaSdpaAttention(LlamaAttention):
    """
    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        """
        Forward pass of the LlamaSdpaAttention module.

        Args:
            hidden_states (torch.Tensor): The input hidden states.
            attention_mask (Optional[torch.Tensor], optional): The attention mask. Defaults to None.
            position_ids (Optional[torch.LongTensor], optional): The position ids. Defaults to None.
            past_key_value (Optional[Cache], optional): The past key value cache. Defaults to None.
            output_attentions (bool, optional): Whether to output attentions. Defaults to False.
            use_cache (bool, optional): Whether to use caching. Defaults to False.
            cache_position (Optional[torch.LongTensor], optional): The position for caching. Defaults to None.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The output tensor from the attention layer.
        """
        # Forward pass implementation goes here

LLAMA_ATTENTION_CLASSES = {
    "eager": LlamaAttention,
    "flash_attention_2": LlamaFlashAttention2,
    "sdpa": LlamaSdpaAttention,
}
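
# Usage note (added annotation, not part of the original file): which entry of this mapping gets
# instantiated is decided by `config._attn_implementation`, normally chosen at load time, e.g.
#
#   model = LlamaForCausalLM.from_pretrained("<llama-checkpoint>", attn_implementation="sdpa")
#   # other options: "eager" (the plain math implementation) and "flash_attention_2"
#   # ("flash_attention_2" additionally requires the flash-attn package and a supported GPU).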


class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Initialize self attention mechanism based on config's specified implementation
        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        # 如果kwargs中包含"padding_mask",则发出警告,该功能将在v4.37版本中移除
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # 保存输入的残差连接
        residual = hidden_states

        # 对输入的hidden_states进行LayerNorm处理
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention部分
        # 调用self_attn方法进行自注意力计算,得到新的hidden_states、自注意力权重self_attn_weights以及新的缓存present_key_value
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        # 加上残差连接
        hidden_states = residual + hidden_states

        # Fully Connected部分
        # 保存新的残差连接
        residual = hidden_states

        # 对新的hidden_states进行LayerNorm处理
        hidden_states = self.post_attention_layernorm(hidden_states)

        # 经过MLP层处理
        hidden_states = self.mlp(hidden_states)

        # 加上残差连接
        hidden_states = residual + hidden_states

        # 输出结果
        outputs = (hidden_states,)

        # 如果需要返回注意力权重,则添加到输出结果中
        if output_attentions:
            outputs += (self_attn_weights,)

        # 如果需要使用缓存,则添加present_key_value到输出结果中
        if use_cache:
            outputs += (present_key_value,)

        return outputs
    "Document the inputs the LLAMA model accepts (`model_input_ids`, `attention_mask`, etc.) See the superclass "
    "documentation for more details."
    LLAMA_INPUTS_DOCSTRING,
)
    # 创建一个包含字符串的元组,第一个元素是字符串描述模型的功能,第二个元素是模型文档字符串的起始部分
    (
        "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
        LLAMA_START_DOCSTRING,
    )
# 定义 LlamaForCausalLM 类,继承自 LlamaPreTrainedModel 类
class LlamaForCausalLM(LlamaPreTrainedModel):
    # 定义权重共享的键列表
    _tied_weights_keys = ["lm_head.weight"]

    # 初始化方法
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)
        # 创建一个 LlamaModel 实例,传入配置参数
        self.model = LlamaModel(config)
        # 设置词汇表大小
        self.vocab_size = config.vocab_size
        # 创建一个线性层 lm_head,用于预测词汇表中的词
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # 初始化权重并进行最终处理
        self.post_init()

    # 返回输入的嵌入层对象
    def get_input_embeddings(self):
        return self.model.embed_tokens

    # 设置输入的嵌入层对象
    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    # 返回输出的嵌入层对象
    def get_output_embeddings(self):
        return self.lm_head

    # 设置输出的嵌入层对象
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # 设置解码器对象
    def set_decoder(self, decoder):
        self.model = decoder

    # 获取解码器对象
    def get_decoder(self):
        return self.model
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    # 添加模型前向传播方法的文档字符串,使用指定的输入文档字符串
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
    # 此方法定义了模型的前向传播过程,接受多个可选参数用于生成预测结果或计算损失

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
    # 准备生成过程的输入,接受多个参数用于生成新的模型输入

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # 静态方法,用于重新排序缓存中的过去键值,以便与给定的beam索引匹配
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
                # 对每个层级的过去状态进行重新排序,使其与beam索引匹配
            )
        return reordered_past
    # 返回重新排序后的过去键值
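
# Note (added annotation, not part of the original file): `_reorder_cache` is used by `generate` during
# beam search so that each layer's cached key/value states follow the surviving beams, e.g.
#
#   output_ids = model.generate(input_ids, num_beams=4, max_new_tokens=20, use_cache=True)
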
"""
The Llama Model transformer with a sequence classification head on top (linear layer).

[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.

Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
"""
@add_start_docstrings(
"""
The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
LLAMA_START_DOCSTRING,
)
class LlamaForQuestionAnswering(LlamaPreTrainedModel):
    base_model_prefix = "transformer"

    # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
    def __init__(self, config):
        super().__init__(config)
        # Initialize Llama model with given configuration
        self.transformer = LlamaModel(config)
        # Linear layer for question-answering output (span start and end logits)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Retrieve input embeddings from the Llama model
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        # Set input embeddings for the Llama model
        self.transformer.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    # Define the `forward` method for the model's forward pass.
    def forward(
        self,
        # Token IDs of the input sequence (optional LongTensor).
        input_ids: Optional[torch.LongTensor] = None,
        # Attention mask (optional FloatTensor).
        attention_mask: Optional[torch.FloatTensor] = None,
        # Position IDs (optional LongTensor).
        position_ids: Optional[torch.LongTensor] = None,
        # Past key/value states (optional list of FloatTensors).
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        # Pre-computed input embeddings (optional FloatTensor).
        inputs_embeds: Optional[torch.FloatTensor] = None,
        # Start positions of the labelled spans (optional LongTensor).
        start_positions: Optional[torch.LongTensor] = None,
        # End positions of the labelled spans (optional LongTensor).
        end_positions: Optional[torch.LongTensor] = None,
        # Whether to return the attention tensors (optional bool).
        output_attentions: Optional[bool] = None,
        # Whether to return the hidden states (optional bool).
        output_hidden_states: Optional[bool] = None,
        # Whether to return a dict-style output (optional bool).
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        # 确保返回的字典存在,如果未提供则使用配置中的默认设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 使用 Transformer 模型处理输入,获取输出
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从模型输出中提取序列输出
        sequence_output = outputs[0]

        # 将序列输出传递给问答模型输出层,获取开始和结束位置的 logits
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # On multi-GPU setups the position tensors may carry an extra dimension;
            # squeeze it away and move them to the same device as the logits.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1).to(start_logits.device)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1).to(end_logits.device)
            # 忽略超出模型输入长度的位置
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # 定义交叉熵损失函数,忽略指定索引
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        # 如果不要求返回字典形式的输出,则按元组形式返回结果
        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        # 返回 QuestionAnsweringModelOutput 类型的对象,包含损失、开始和结束位置的 logits,以及隐藏状态和注意力权重
        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
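
A hedged usage sketch for `LlamaForQuestionAnswering`; the checkpoint name is a placeholder assumption, and the span decoding below is the standard argmax-over-logits recipe rather than code from this file:

```
import torch
from transformers import AutoTokenizer, LlamaForQuestionAnswering

# Placeholder: a LLaMA checkpoint fine-tuned for extractive QA.
checkpoint = "path/to/llama-qa-checkpoint"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = LlamaForQuestionAnswering.from_pretrained(checkpoint)

question = "Who introduced LLaMA?"
context = "LLaMA was introduced by researchers at Meta AI."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Pick the most likely start/end positions and decode the span between them.
start = outputs.start_logits.argmax(-1).item()
end = outputs.end_logits.argmax(-1).item()
answer_ids = inputs["input_ids"][0, start : end + 1]
print(tokenizer.decode(answer_ids, skip_special_tokens=True))
```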