Transformers-源码解析-四十六-

Transformers 源码解析(四十六)

.\models\ernie_m\tokenization_ernie_m.py

# coding=utf-8
# 上面的注释声明了文件的编码格式为 UTF-8,并非代码实际操作,仅为信息说明

# 版权声明,指出该文件的版权归属于 Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang 以及 HuggingFace Inc. 团队所有
# 版权声明的文本通常包括对软件使用的限制和许可,这里声明使用 Apache License, Version 2.0,详细信息可在指定网址查看
# http://www.apache.org/licenses/LICENSE-2.0

# 导入所需的模块和类
import io
import os
import unicodedata
from typing import Any, Dict, List, Optional, Tuple

# 导入 sentencepiece 模块,用于分词
import sentencepiece as spm

# 从 HuggingFace 库中导入必要的类和函数
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging

# 获取 logger 对象,用于记录日志
logger = logging.get_logger(__name__)

# 定义常量 SPIECE_UNDERLINE,用于表示子词之间的连接符
SPIECE_UNDERLINE = "▁"

# 定义词汇文件的默认名称映射,包括词汇文件和 sentencepiece 模型文件
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "sentencepiece_model_ckpt": "sentencepiece.bpe.model"}

# 定义资源文件的默认名称映射,包括 sentencepiece 模型文件和词汇文件
RESOURCE_FILES_NAMES = {
    "sentencepiece_model_file": "sentencepiece.bpe.model",
    "vocab_file": "vocab.txt",
}

# 预训练模型的词汇文件映射,根据模型名称指定对应的下载链接
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "ernie-m-base": "https://huggingface.co/susnato/ernie-m-base_pytorch/blob/main/vocab.txt",
        "ernie-m-large": "https://huggingface.co/susnato/ernie-m-base_pytorch/blob/main/vocab.txt",
    },
    "sentencepiece_model_file": {
        "ernie-m-base": "https://huggingface.co/susnato/ernie-m-base_pytorch/blob/main/sentencepiece.bpe.model",
        "ernie-m-large": "https://huggingface.co/susnato/ernie-m-base_pytorch/blob/main/sentencepiece.bpe.model",
    },
}

# 预训练模型的位置嵌入大小映射,根据模型名称指定对应的嵌入大小
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "ernie-m-base": 514,
    "ernie-m-large": 514,
}

# 预训练模型的初始化配置映射,根据模型名称指定对应的初始化参数
PRETRAINED_INIT_CONFIGURATION = {
    "ernie-m-base": {"do_lower_case": False},
    "ernie-m-large": {"do_lower_case": False},
}

# 以下是一个类的定义,用于构建 Ernie-M 的分词器
class ErnieMTokenizer(PreTrainedTokenizer):
    r"""
    Constructs a Ernie-M tokenizer. It uses the `sentencepiece` tools to cut the words to sub-words.
    """
    # 这里是类的初始化方法和构造函数,用于初始化一个 Ernie-M 分词器对象
    def __init__(
        self,
        # 指定分词器的初始化参数
        vocab_file: Optional[str] = None,
        sentencepiece_model_file: Optional[str] = None,
        do_lower_case=False,
        **kwargs
    ):
        """
        :param vocab_file: 词汇文件的路径(可选)
        :param sentencepiece_model_file: sentencepiece 模型文件的路径(可选)
        :param do_lower_case: 是否将所有输入转换为小写(默认为 False)
        """
        # 调用父类 PreTrainedTokenizer 的构造函数,初始化分词器
        super().__init__(
            # 指定初始化参数
            vocab_file=vocab_file,
            sentencepiece_model_file=sentencepiece_model_file,
            do_lower_case=do_lower_case,
            **kwargs
        )
    # 类描述了用于构建和初始化Ernie-M模型所需的参数与配置, 包括预训练语言模型的超参数和初始化工具.
    """
    Args:
        sentencepiece_model_ckpt (`str`):
            某句段语法模型检查点的路径, 用于序列到序列的编码和解码任务.
    
        vocab_file (`str`, *optional*):
            字典文件路径, 若未提供则继承默认词汇表.
    
        do_lower_case (`str`, *optional*, defaults to `True`):
            是否将输入文本转换为小写, 当在数据预处理阶段处理文本时启用.
    
        encoding (`str`, *optional*, defaults to `utf8`):
            编码方式, 默认使用UTF-8用于解析输入数据.
    
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            未知词汇(外域词汇)的标记, 用于替换未在词汇表中的词汇.
    
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            用于分隔不同句子在同一批文本序列中.
    
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            用于填充序列, 使所有序列长度相等适用于批处理.
    
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            分类器的标志符, 表示序列开始的典型符号.
    
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            用于替换的标记符号, 该模型将其视为需要预测原始未掩码的令牌的例证.
    
        sp_model_kwargs: `Optional[Dict[str, Any]]` = None:
            用于初始化句段模型的可选参数字典.
    
        kwargs:
            其他可能的初始化参数, 用于扩展上述参数的功能.
    """
    
    # 定义用于Ernie-M模型关键输入名称的列表.
    model_input_names: List[str] = ["input_ids"]
    
    # 载入预训练配置集合与需要提供的初始配置工具
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    
    # 载入预先构建的词汇表文件映射表
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    
    # 载入额外资源文件的定义与配置
    resource_files_names = RESOURCE_FILES_NAMES
    
    # 构建模型实例以初始化, 包括指定的参数和可能的额外配置项.
    def __init__(
        self,
        sentencepiece_model_ckpt,
        vocab_file=None,
        do_lower_case=False,
        encoding="utf8",
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        pass
    ) -> None:
        # 定义一个初始化方法,接受多个参数,用于初始化对象的各种属性和参数

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        # 如果 sp_model_kwargs 为 None,则将其设置为空字典,否则使用传入的 sp_model_kwargs

        self.do_lower_case = do_lower_case
        # 初始化一个属性 do_lower_case,表示是否进行小写处理

        self.sentencepiece_model_ckpt = sentencepiece_model_ckpt
        # 初始化 sentencepiece_model_ckpt 属性,表示 SentencePiece 模型的检查点路径

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        # 使用 sp_model_kwargs 初始化 SentencePieceProcessor 对象,并赋值给 sp_model 属性

        self.sp_model.Load(sentencepiece_model_ckpt)
        # 载入 SentencePiece 模型的检查点文件

        # 模仿 paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer 的功能
        if vocab_file is not None:
            self.vocab = self.load_vocab(filepath=vocab_file)
            # 如果提供了 vocab_file 参数,则调用 load_vocab 方法加载词汇表
        else:
            self.vocab = {self.sp_model.id_to_piece(id): id for id in range(self.sp_model.get_piece_size())}
            # 否则,根据 SentencePiece 模型的大小构建词汇表,使用 id_to_piece 方法和 get_piece_size 方法

        self.reverse_vocab = {v: k for k, v in self.vocab.items()}
        # 创建反向词汇表,将 id 映射到词汇

        super().__init__(
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            vocab_file=vocab_file,
            encoding=encoding,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )
        # 调用父类的初始化方法,传递相关参数和关键字参数

    def get_offset_mapping(self, text):
        if text is None:
            return None
        # 如果文本为空,则返回 None

        split_tokens = self.tokenize(text)
        # 使用当前对象的 tokenize 方法对文本进行分词,得到分词后的列表 split_tokens

        normalized_text, char_mapping = "", []
        # 初始化 normalized_text 和 char_mapping

        for i, ch in enumerate(text):
            if ch in self.SP_CHAR_MAPPING:
                ch = self.SP_CHAR_MAPPING.get(ch)
                # 如果字符在 SP_CHAR_MAPPING 中,使用映射后的字符替换原字符
            else:
                ch = unicodedata.normalize("NFKC", ch)
                # 否则,使用 NFKC 规范化处理字符

            if self.is_whitespace(ch):
                continue
            # 如果字符是空白字符,则跳过

            normalized_text += ch
            # 将处理后的字符追加到 normalized_text 中
            char_mapping.extend([i] * len(ch))
            # 根据字符长度,将相应索引追加到 char_mapping 中

        text, token_mapping, offset = normalized_text, [], 0
        # 将处理后的文本赋值给 text,初始化 token_mapping 和 offset

        if self.do_lower_case:
            text = text.lower()
            # 如果需要进行小写处理,则将文本转换为小写

        for token in split_tokens:
            if token[:1] == "▁":
                token = token[1:]
                # 如果 token 以 "▁" 开头,去除 "▁"
            start = text[offset:].index(token) + offset
            # 找到 token 在 text 中的起始索引
            end = start + len(token)
            # 计算 token 在 text 中的结束索引

            token_mapping.append((char_mapping[start], char_mapping[end - 1] + 1))
            # 将 token 的字符映射加入 token_mapping 中
            offset = end
            # 更新 offset

        return token_mapping
        # 返回字符映射列表

    @property
    def vocab_size(self):
        return len(self.vocab)
        # 返回词汇表大小

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)
        # 返回词汇表及额外的 token 编码器

    def __getstate__(self):
        state = self.__dict__.copy()
        # 复制对象的属性字典到 state 中
        state["sp_model"] = None
        # 将 sp_model 设置为 None
        return state
        # 返回对象的状态字典

    def __setstate__(self, d):
        self.__dict__ = d
        # 将状态字典 d 中的内容赋值给对象的属性字典

        # 用于向后兼容
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}
            # 如果对象中不存在 sp_model_kwargs 属性,则设置为空字典

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        # 使用 sp_model_kwargs 初始化 SentencePieceProcessor 对象,并赋值给 sp_model 属性
        self.sp_model.Load(self.sentencepiece_model_ckpt)
        # 载入 SentencePiece 模型的检查点文件
    # 对文本进行清洗,去除无效字符并清理空白
    def clean_text(self, text):
        return "".join((self.SP_CHAR_MAPPING.get(c, c) for c in text))

    # 对字符串进行分词处理
    def _tokenize(self, text, enable_sampling=False, nbest_size=64, alpha=0.1):
        """Tokenize a string."""

        # 如果在参数中启用了采样,则将 enable_sampling 设置为 True
        if self.sp_model_kwargs.get("enable_sampling") is True:
            enable_sampling = True
        # 如果参数中指定了 alpha,则使用参数中的值
        if self.sp_model_kwargs.get("alpha") is not None:
            alpha = self.sp_model_kwargs.get("alpha")
        # 如果参数中指定了 nbest_size,则使用参数中的值
        if self.sp_model_kwargs.get("nbest_size") is not None:
            nbest_size = self.sp_model_kwargs.get("nbest_size")

        # 根据是否启用采样来选择使用 EncodeAsPieces 还是 SampleEncodeAsPieces 方法
        if not enable_sampling:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, nbest_size, alpha)
        
        new_pieces = []
        # 遍历分词后的片段
        for pi, piece in enumerate(pieces):
            # 处理特殊标记 SPIECE_UNDERLINE
            if piece == SPIECE_UNDERLINE:
                # 如果当前标记是 SPIECE_UNDERLINE 且下一个标记不以 SPIECE_UNDERLINE 开头且不是第一个标记,则添加 SPIECE_UNDERLINE
                if not pieces[pi + 1].startswith(SPIECE_UNDERLINE) and pi != 0:
                    new_pieces.append(SPIECE_UNDERLINE)
                    continue
                else:
                    continue
            lst_i = 0
            # 遍历当前片段中的每个字符
            for i, chunk in enumerate(piece):
                # 跳过 SPIECE_UNDERLINE
                if chunk == SPIECE_UNDERLINE:
                    continue
                # 判断字符是否为中文字符或标点符号
                if self.is_ch_char(chunk) or self.is_punct(chunk):
                    # 如果当前字符不是第一个且前一个字符不是 SPIECE_UNDERLINE,则添加前面的部分到 new_pieces
                    if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
                        new_pieces.append(piece[lst_i:i])
                    # 添加当前字符到 new_pieces
                    new_pieces.append(chunk)
                    # 更新 lst_i 为当前索引加 1
                    lst_i = i + 1
                # 如果字符是数字且不是第一个,并且前一个字符不是数字,则添加前面的部分到 new_pieces
                elif chunk.isdigit() and i > 0 and not piece[i - 1].isdigit():
                    if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
                        new_pieces.append(piece[lst_i:i])
                    lst_i = i
                # 如果字符不是数字且不是第一个,并且前一个字符是数字,则添加前面的部分到 new_pieces
                elif not chunk.isdigit() and i > 0 and piece[i - 1].isdigit():
                    if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
                        new_pieces.append(piece[lst_i:i])
                    lst_i = i
            # 如果片段长度大于 lst_i,则添加剩余部分到 new_pieces
            if len(piece) > lst_i:
                new_pieces.append(piece[lst_i:])
        
        # 返回处理后的片段列表
        return new_pieces

    # 将分词后的 tokens 列表转换为单个字符串
    def convert_tokens_to_string(self, tokens):
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    # 将 token ids 转换为单个字符串
    def convert_ids_to_string(self, ids):
        # 将 token ids 转换为 tokens
        tokens = self.convert_ids_to_tokens(ids)
        # 将 tokens 列表转换为单个字符串,并替换 SPIECE_UNDERLINE
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    # 模仿 paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer 的功能,将 token 转换为其对应的 id
    def _convert_token_to_id(self, token):
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    # 模仿 paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer 的功能
    def _convert_id_to_token(self, index):
        """
        Converts an index (integer) into a token (str) using the vocabulary.
        
        Args:
            index (int): Index to convert into a token.
        
        Returns:
            str: The corresponding token if found in the vocabulary, otherwise returns the unknown token (self.unk_token).
        """
        # 使用反向词汇表将索引转换为对应的标记,如果索引不存在则返回未知标记
        return self.reverse_vocab.get(index, self.unk_token)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        构建用于序列分类任务的模型输入,通过连接和添加特殊标记。ErnieM 序列的格式如下:

        - 单个序列:`[CLS] X [SEP]`
        - 序列对:`[CLS] A [SEP] [SEP] B [SEP]`

        Args:
            token_ids_0 (List[int]): 要添加特殊标记的 ID 列表。
            token_ids_1 (List[int], optional): 第二个序列的 ID 列表(可选)。

        Returns:
            List[int]: 包含适当特殊标记的输入 ID 列表。
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        _cls = [self.cls_token_id]
        _sep = [self.sep_token_id]
        return _cls + token_ids_0 + _sep + _sep + token_ids_1 + _sep

    def build_offset_mapping_with_special_tokens(self, offset_mapping_0, offset_mapping_1=None):
        """
        构建偏移映射,通过连接和添加特殊标记的偏移量。Ernie-M 偏移映射的格式如下:

        - 单个序列:`(0,0) X (0,0)`
        - 序列对:`(0,0) A (0,0) (0,0) B (0,0)`

        Args:
            offset_mapping_ids_0 (List[tuple]): 要添加特殊标记的字符偏移列表。
            offset_mapping_ids_1 (List[tuple], optional): 第二个序列的单词片段偏移列表(可选)。

        Returns:
            List[tuple]: 包含适当特殊标记偏移量的单词片段偏移列表。
        """
        if offset_mapping_1 is None:
            return [(0, 0)] + offset_mapping_0 + [(0, 0)]

        return [(0, 0)] + offset_mapping_0 + [(0, 0), (0, 0)] + offset_mapping_1 + [(0, 0)]
    # 检查给定的 token_ids_0 是否已经包含特殊标记,如果是则返回特殊标记掩码
    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        r"""
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `encode` method.

        Args:
            token_ids_0 (`List[int]`):
                List of ids of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`str`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.
        Returns:
            `List[int]`:
                The list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            # 如果 token_ids_0 已经包含了特殊标记,并且 token_ids_1 也被提供了,则抛出 ValueError
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            # 返回 token_ids_0 中特殊标记的掩码
            return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0]

        if token_ids_1 is not None:
            # 如果 token_ids_1 存在,则创建包含特殊标记的掩码列表,用于处理序列对
            return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
        # 否则,仅处理 token_ids_0,返回包含特殊标记的掩码列表
        return [1] + ([0] * len(token_ids_0)) + [1]

    # 根据传入的 token_ids_0 和 token_ids_1 创建对应的 token 类型 ID 列表
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create the token type IDs corresponding to the sequences passed. [What are token type
        IDs?](../glossary#token-type-ids) Should be overridden in a subclass if the model has a special way of
        building: those.

        Args:
            token_ids_0 (`List[int]`):
                The first tokenized sequence.
            token_ids_1 (`List[int]`, *optional*):
                The second tokenized sequence.
        Returns:
            `List[int]`: The token type ids.
        """
        # 当 `add_special_tokens` 为 True 时调用,因此需要与 `build_inputs_with_special_tokens` 方法对齐
        if token_ids_1 is None:
            # [CLS] X [SEP] 的序列,对应的 token 类型 ID 全为 0
            return (len(token_ids_0) + 2) * [0]

        # [CLS] A [SEP] [SEP] B [SEP] 的序列,构建对应的 token 类型 ID 列表
        return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 3)

    # 检查字符是否为中文字符
    def is_ch_char(self, char):
        """
        is_ch_char
        """
        if "\u4e00" <= char <= "\u9fff":
            return True
        return False

    # 检查字符是否为字母
    def is_alpha(self, char):
        """
        is_alpha
        """
        if ("a" <= char <= "z") or ("A" <= char <= "Z"):
            return True
        return False

    # 检查字符是否为标点符号
    def is_punct(self, char):
        """
        is_punct
        """
        if char in ",;:.?!~,;:。?!《》【】":
            return True
        return False
    # 判断字符是否为空白字符
    def is_whitespace(self, char):
        """
        is whitespace
        """
        # 检查字符是否为空格、制表符、换行符或回车符
        if char == " " or char == "\t" or char == "\n" or char == "\r":
            return True
        # 如果字符长度为1,则使用 unicodedata 模块检查其分类是否为 Zs(空格分隔符)
        if len(char) == 1:
            cat = unicodedata.category(char)
            if cat == "Zs":
                return True
        # 若以上条件都不满足,则返回 False,表示字符不是空白字符
        return False

    # 加载词汇表文件,并返回 token 到索引的映射字典
    def load_vocab(self, filepath):
        token_to_idx = {}
        # 使用 utf-8 编码打开文件,并逐行读取
        with io.open(filepath, "r", encoding="utf-8") as f:
            for index, line in enumerate(f):
                # 去除行尾的换行符,并将 token 作为键,索引作为值存入字典
                token = line.rstrip("\n")
                token_to_idx[token] = int(index)
        # 返回 token 到索引的映射字典
        return token_to_idx

    # 将词汇表保存到指定目录,返回保存的文件路径元组
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        index = 0
        # 检查保存目录是否存在,若不存在则直接使用保存文件名
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        # 使用 utf-8 编码打开 vocab_file 文件,并写入词汇表
        with open(vocab_file, "w", encoding="utf-8") as writer:
            # 按照词汇表中的索引顺序排序词汇,并写入文件
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # 如果发现索引不连续,记录警告信息
                    logger.warning(
                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!"
                    )
                    index = token_index
                # 将 token 写入文件,并在末尾添加换行符
                writer.write(token + "\n")
                index += 1

        # 将 tokenizer 模型保存为二进制文件
        tokenizer_model_file = os.path.join(save_directory, "sentencepiece.bpe.model")
        with open(tokenizer_model_file, "wb") as fi:
            # 获取序列化的 tokenizer 模型,并写入文件
            content_spiece_model = self.sp_model.serialized_model_proto()
            fi.write(content_spiece_model)

        # 返回保存的词汇表文件路径的元组
        return (vocab_file,)

.\models\ernie_m\__init__.py

# 版权声明和保留声明,声明了代码版权及许可协议
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

# 导入异常:当依赖项不可用时引发异常
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_torch_available

# 定义导入结构的字典
_import_structure = {
    "configuration_ernie_m": ["ERNIE_M_PRETRAINED_CONFIG_ARCHIVE_MAP", "ErnieMConfig"],
}

# 检查是否有句子分割器可用,若不可用则引发异常
try:
    if not is_sentencepiece_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 将 tokenization_ernie_m 模块添加到导入结构中
    _import_structure["tokenization_ernie_m"] = ["ErnieMTokenizer"]

# 检查是否有 Torch 可用,若不可用则引发异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 将 modeling_ernie_m 模块添加到导入结构中
    _import_structure["modeling_ernie_m"] = [
        "ERNIE_M_PRETRAINED_MODEL_ARCHIVE_LIST",
        "ErnieMForMultipleChoice",
        "ErnieMForQuestionAnswering",
        "ErnieMForSequenceClassification",
        "ErnieMForTokenClassification",
        "ErnieMModel",
        "ErnieMPreTrainedModel",
        "ErnieMForInformationExtraction",
    ]

# 如果是类型检查模式,则进行额外的导入
if TYPE_CHECKING:
    # 导入配置和配置类
    from .configuration_ernie_m import ERNIE_M_PRETRAINED_CONFIG_ARCHIVE_MAP, ErnieMConfig

    try:
        # 再次检查是否有句子分割器可用,若不可用则跳过
        if not is_sentencepiece_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入 ErnieMTokenizer 类
        from .tokenization_ernie_m import ErnieMTokenizer

    try:
        # 再次检查是否有 Torch 可用,若不可用则跳过
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入 modeling_ernie_m 中的多个类和常量
        from .modeling_ernie_m import (
            ERNIE_M_PRETRAINED_MODEL_ARCHIVE_LIST,
            ErnieMForInformationExtraction,
            ErnieMForMultipleChoice,
            ErnieMForQuestionAnswering,
            ErnieMForSequenceClassification,
            ErnieMForTokenClassification,
            ErnieMModel,
            ErnieMPreTrainedModel,
        )

# 若非类型检查模式,则创建懒加载模块
else:
    import sys

    # 使用 _LazyModule 创建模块,提供导入结构和模块规范
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\models\esm\configuration_esm.py

"""
# coding=utf-8
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ESM model configuration"""

# Import necessary modules
from dataclasses import asdict, dataclass
from typing import Optional

# Import configuration utilities
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# Get logger for this module
logger = logging.get_logger(__name__)

# TODO Update this
# Mapping of pretrained model names to their configuration URLs
ESM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/esm-1b": "https://huggingface.co/facebook/esm-1b/resolve/main/config.json",
    # See all ESM models at https://huggingface.co/models?filter=esm
}

# Configuration class for the ESM model, inheriting from PretrainedConfig
class EsmConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ESMModel`]. It is used to instantiate a ESM model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the ESM
    [facebook/esm-1b](https://huggingface.co/facebook/esm-1b) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Examples:

    ```
    >>> from transformers import EsmModel, EsmConfig

    >>> # Initializing a ESM facebook/esm-1b style configuration >>> configuration = EsmConfig()

    >>> # Initializing a model from the configuration >>> model = ESMModel(configuration)

    >>> # Accessing the model configuration >>> configuration = model.config
    ```
    """

    model_type = "esm"

    def __init__(
        self,
        vocab_size=None,
        mask_token_id=None,
        pad_token_id=None,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=1026,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        position_embedding_type="absolute",
        use_cache=True,
        emb_layer_norm_before=None,
        token_dropout=False,
        is_folding_model=False,
        esmfold_config=None,
        vocab_list=None,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
        # 调用父类的初始化方法,传入特定的参数来初始化当前类

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.emb_layer_norm_before = emb_layer_norm_before
        self.token_dropout = token_dropout
        self.is_folding_model = is_folding_model
        # 初始化多个模型配置的参数

        if is_folding_model:
            if esmfold_config is None:
                logger.info("No esmfold_config supplied for folding model, using default values.")
                # 如果没有提供 esmfold_config 参数,则使用默认配置并记录日志信息
                esmfold_config = EsmFoldConfig()
            elif isinstance(esmfold_config, dict):
                esmfold_config = EsmFoldConfig(**esmfold_config)
                # 如果 esmfold_config 是一个字典,则根据字典内容创建 EsmFoldConfig 对象
            self.esmfold_config = esmfold_config
            if vocab_list is None:
                logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!")
                # 如果没有提供 vocab_list 参数,则假设使用 ESM-2 词汇表,并记录警告信息
                self.vocab_list = get_default_vocab_list()
            else:
                self.vocab_list = vocab_list
                # 否则,使用提供的 vocab_list 参数
        else:
            self.esmfold_config = None
            self.vocab_list = None
            # 如果不是折叠模型,则将 esmfold_config 和 vocab_list 设置为 None

        if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False):
            raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!")
            # 如果 esmfold_config 不为 None,且其属性 use_esm_attn_map 为 True,则抛出值错误异常

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = super().to_dict()
        # 调用父类的 to_dict 方法,将父类的序列化结果添加到 output 字典中

        if isinstance(self.esmfold_config, EsmFoldConfig):
            output["esmfold_config"] = self.esmfold_config.to_dict()
            # 如果 esmfold_config 是 EsmFoldConfig 类型的对象,则将其序列化为字典并加入 output 中

        return output
        # 返回包含当前实例所有属性的字典作为序列化结果
# 数据类 EsmFoldConfig,用于配置 ESM 折叠模型的参数
@dataclass
class EsmFoldConfig:
    # ESM 类型,默认为 None
    esm_type: str = None
    # 是否使用 FP16 格式的 ESM
    fp16_esm: bool = True
    # 是否使用 ESM 注意力映射
    use_esm_attn_map: bool = False
    # 是否剔除 ESM 的成对序列
    esm_ablate_pairwise: bool = False
    # 是否剔除 ESM 的序列
    esm_ablate_sequence: bool = False
    # ESM 输入的 dropout 概率
    esm_input_dropout: float = 0

    # 是否嵌入氨基酸信息
    embed_aa: bool = True
    # 是否绕过语言模型
    bypass_lm: bool = False

    # LDDT 头部隐藏维度
    lddt_head_hid_dim: int = 128
    # EsmFoldConfig 的 trunk 配置,如果为 None 则使用默认配置
    trunk: "TrunkConfig" = None

    # 初始化方法,在对象创建后调用,处理 trunk 属性
    def __post_init__(self):
        # 如果 trunk 为 None,则使用默认的 TrunkConfig
        if self.trunk is None:
            self.trunk = TrunkConfig()
        # 如果 trunk 是 dict 类型,则将其转换为 TrunkConfig 对象
        elif isinstance(self.trunk, dict):
            self.trunk = TrunkConfig(**self.trunk)

    # 将当前实例序列化为 Python 字典的方法
    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        # 将当前实例转换为字典
        output = asdict(self)
        # 将 trunk 属性也转换为字典
        output["trunk"] = self.trunk.to_dict()
        return output


# 数据类 TrunkConfig,用于配置 ESM 折叠模型的 trunk 参数
@dataclass
class TrunkConfig:
    # trunk 的块数
    num_blocks: int = 48
    # 序列状态维度
    sequence_state_dim: int = 1024
    # 成对状态维度
    pairwise_state_dim: int = 128
    # 序列头部宽度
    sequence_head_width: int = 32
    # 成对头部宽度
    pairwise_head_width: int = 32
    # 位置分箱数
    position_bins: int = 32
    # dropout 概率
    dropout: float = 0
    # 层丢弃概率
    layer_drop: float = 0
    # 是否使用 CPU 梯度检查点
    cpu_grad_checkpoint: bool = False
    # 最大循环次数
    max_recycles: int = 4
    # 分块大小
    chunk_size: Optional[int] = 128
    # 结构模块配置
    structure_module: "StructureModuleConfig" = None
    # 初始化方法,在对象实例化后自动调用。确保配置的正确性和一致性。
    def __post_init__(self):
        # 如果结构模块未指定,则使用默认的结构模块配置
        if self.structure_module is None:
            self.structure_module = StructureModuleConfig()
        # 如果结构模块是一个字典,则将其转换为结构模块配置对象
        elif isinstance(self.structure_module, dict):
            self.structure_module = StructureModuleConfig(**self.structure_module)

        # 检查最大循环次数是否大于零,否则抛出数值错误异常
        if self.max_recycles <= 0:
            raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.")
        
        # 检查序列状态维度是否是其自身的倍数,否则抛出数值错误异常
        if self.sequence_state_dim % self.sequence_state_dim != 0:
            raise ValueError(
                "`sequence_state_dim` should be a round multiple of `sequence_state_dim`, got"
                f" {self.sequence_state_dim} and {self.sequence_state_dim}."
            )
        
        # 检查成对状态维度是否是其自身的倍数,否则抛出数值错误异常
        if self.pairwise_state_dim % self.pairwise_state_dim != 0:
            raise ValueError(
                "`pairwise_state_dim` should be a round multiple of `pairwise_state_dim`, got"
                f" {self.pairwise_state_dim} and {self.pairwise_state_dim}."
            )

        # 计算序列头的数量,确保序列状态维度与序列头宽度的乘积相等
        sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
        if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
            raise ValueError(
                "`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got"
                f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
            )
        
        # 计算成对头的数量,确保成对状态维度与成对头宽度的乘积相等
        pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width
        if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
            raise ValueError(
                "`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got"
                f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
            )
        
        # 检查成对状态维度是否为偶数,否则抛出数值错误异常
        if self.pairwise_state_dim % 2 != 0:
            raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.")

        # 检查丢弃率是否小于0.4,否则抛出数值错误异常
        if self.dropout >= 0.4:
            raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.")

    # 将当前实例序列化为Python字典的方法。覆盖默认的to_dict方法。
    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        # 将对象的所有属性转换为字典
        output = asdict(self)
        # 将结构模块属性转换为其对应的字典表示
        output["structure_module"] = self.structure_module.to_dict()
        return output
@dataclass
class StructureModuleConfig:
    """
    定义了结构模块的配置参数的数据类。

    Args:
        sequence_dim:
            单一表示通道的维度
        pairwise_dim:
            成对表示通道的维度
        ipa_dim:
            IPA 隐藏通道的维度
        resnet_dim:
            Angle resnet(Alg. 23 lines 11-14)隐藏通道的维度
        num_heads_ipa:
            IPA 头的数量
        num_qk_points:
            在IPA期间生成的查询/键点的数量
        num_v_points:
            在IPA期间生成的值点的数量
        dropout_rate:
            层中使用的dropout率
        num_blocks:
            结构模块的块数量
        num_transition_layers:
            单一表示转换中的层数(Alg. 23 lines 8-9)
        num_resnet_blocks:
            Angle resnet 中的块数量
        num_angles:
            Angle resnet 中生成的角度数量
        trans_scale_factor:
            单一表示转换的隐藏维度的比例因子
        epsilon:
            Angle resnet 归一化中使用的小数值
        inf:
            用于注意力屏蔽的大数值
    """

    sequence_dim: int = 384
    pairwise_dim: int = 128
    ipa_dim: int = 16
    resnet_dim: int = 128
    num_heads_ipa: int = 12
    num_qk_points: int = 4
    num_v_points: int = 8
    dropout_rate: float = 0.1
    num_blocks: int = 8
    num_transition_layers: int = 1
    num_resnet_blocks: int = 2
    num_angles: int = 7
    trans_scale_factor: int = 10
    epsilon: float = 1e-8
    inf: float = 1e5

    def to_dict(self):
        """
        将数据类实例转换为字典的方法。
        """
        return asdict(self)


def get_default_vocab_list():
    """
    返回默认的词汇表列表。

    Returns:
        tuple: 包含默认词汇的元组
    """
    return (
        "<cls>",
        "<pad>",
        "<eos>",
        "<unk>",
        "L",
        "A",
        "G",
        "V",
        "S",
        "E",
        "R",
        "T",
        "I",
        "D",
        "P",
        "K",
        "Q",
        "N",
        "F",
        "Y",
        "M",
        "H",
        "W",
        "C",
        "X",
        "B",
        "U",
        "Z",
        "O",
        ".",
        "-",
        "<null_1>",
        "<mask>",
    )

.\models\esm\convert_esm.py

# coding=utf-8
# 版权 2022 年 HuggingFace Inc. 团队所有。
#
# 根据 Apache 许可证 2.0 版本("许可证")授权;
# 除非符合许可证的规定,否则您不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件根据"原样"的基础分发,
# 不提供任何形式的明示或暗示保证或条件。
# 有关详细信息,请参阅许可证。

"""Convert ESM checkpoint."""

# 导入必要的库和模块
import argparse  # 导入命令行参数解析模块
import pathlib   # 导入路径操作模块
from pathlib import Path  # 导入路径操作模块中的Path类
from tempfile import TemporaryDirectory  # 导入临时目录模块

# 导入ESM相关的模块和类
import esm as esm_module  # 导入ESM模块
import torch  # 导入PyTorch库
from esm.esmfold.v1.misc import batch_encode_sequences as esmfold_encode_sequences  # 导入序列编码函数
from esm.esmfold.v1.pretrained import esmfold_v1  # 导入ESM-Fold v1预训练模型

# 导入Transformers相关的类和函数
from transformers.models.esm.configuration_esm import EsmConfig, EsmFoldConfig  # 导入ESM和ESM-Fold的配置类
from transformers.models.esm.modeling_esm import (  # 导入ESM模型相关类
    EsmForMaskedLM,
    EsmForSequenceClassification,
    EsmIntermediate,
    EsmLayer,
    EsmOutput,
    EsmSelfAttention,
    EsmSelfOutput,
)
from transformers.models.esm.modeling_esmfold import EsmForProteinFolding  # 导入蛋白质折叠相关的ESM模型类
from transformers.models.esm.tokenization_esm import EsmTokenizer  # 导入ESM的分词器类
from transformers.utils import logging  # 导入日志记录模块

# 设置日志的详细级别为信息级别
logging.set_verbosity_info()
# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义样本数据,包含蛋白质序列和标识
SAMPLE_DATA = [
    (
        "protein1",
        "MNGTEGPNFYVPFSNATGVVRSPFEYPQYYLAEPWQFSMLAAYMFLLIVLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVLGGFTSTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLAGWSRYIPEGLQCSCGIDYYTLKPEVNNESFVIYMFVVHFTIPMIIIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWVPYASVAFYIFTHQGSNFGPIFMTIPAFFAKSAAIYNPVIYIMMNKQFRNCMLTTICCGKNPLGDDEASATVSKTETSQVAPA",
    ),
    ("protein2", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"),
    ("protein3", "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLAGG"),
    ("protein4", "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLA"),
]

# 定义ESM模型的名称与模型对象的映射关系
MODEL_MAPPING = {
    "esm1b_t33_650M_UR50S": esm_module.pretrained.esm1b_t33_650M_UR50S,
    "esm1v_t33_650M_UR90S_1": esm_module.pretrained.esm1v_t33_650M_UR90S_1,
    "esm1v_t33_650M_UR90S_2": esm_module.pretrained.esm1v_t33_650M_UR90S_2,
    "esm1v_t33_650M_UR90S_3": esm_module.pretrained.esm1v_t33_650M_UR90S_3,
    "esm1v_t33_650M_UR90S_4": esm_module.pretrained.esm1v_t33_650M_UR90S_4,
    "esm1v_t33_650M_UR90S_5": esm_module.pretrained.esm1v_t33_650M_UR90S_5,
    "esm2_t48_15B_UR50D": esm_module.pretrained.esm2_t48_15B_UR50D,
    "esm2_t36_3B_UR50D": esm_module.pretrained.esm2_t36_3B_UR50D,
    "esm2_t33_650M_UR50D": esm_module.pretrained.esm2_t33_650M_UR50D,
    "esm2_t30_150M_UR50D": esm_module.pretrained.esm2_t30_150M_UR50D,
    "esm2_t12_35M_UR50D": esm_module.pretrained.esm2_t12_35M_UR50D,
}
    # 将模型名称映射到预训练模型对象的引用:"esm2_t6_8M_UR50D"映射到esm_module.pretrained.esm2_t6_8M_UR50D
    "esm2_t6_8M_UR50D": esm_module.pretrained.esm2_t6_8M_UR50D,
    # 将模型名称映射到预训练模型对象的引用:"esmfold_v1"映射到esmfold_v1
    "esmfold_v1": esmfold_v1,
}

# 定义氨基酸类型列表
restypes = list("ARNDCQEGHILKMFPSTWYV")

# 在氨基酸类型列表中加入额外的字符 'X'
restypes_with_x = restypes + ["X"]

# 在带有 'X' 的氨基酸类型列表中再加入特殊的 token
restypes_with_extras = restypes_with_x + ["<pad>", "<mask>", "<cls>", "<sep>", "<eos>"]

# 返回一个 ESM 模型的 tokenizer 对象
def get_esmfold_tokenizer():
    # 使用临时目录创建词汇表文件并写入字符列表
    with TemporaryDirectory() as tempdir:
        vocab = "\n".join(restypes_with_extras)
        vocab_file = Path(tempdir) / "vocab.txt"
        vocab_file.write_text(vocab)
        # 使用词汇表文件创建 ESM tokenizer 对象
        hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))
    # 设置 padding token 的 ID
    hf_tokenizer.pad_token_id = 0  # 与 'A' 重叠,但这似乎是他们想要的
    return hf_tokenizer

# 将原始模型的权重转移并检查到我们的模型中
def transfer_and_check_weights(original_module, our_module):
    status = our_module.load_state_dict(original_module.state_dict())
    # 如果有缺失的键,则引发 ValueError 异常
    if status.missing_keys:
        raise ValueError(f"Missing keys: {status.missing_keys}")
    # 如果有意外的键,则引发 ValueError 异常
    if status.unexpected_keys:
        raise ValueError(f"Unexpected keys: {status.unexpected_keys}")

# 将 ESM 模型检查点转换为 PyTorch 的格式
def convert_esm_checkpoint_to_pytorch(
    model: str, pytorch_dump_folder_path: str, classification_head: bool, push_to_repo: str, auth_token: str
):
    """
    复制/粘贴/调整 esm 的权重到我们的 BERT 结构中。
    """
    # 如果模型以 "esmfold" 开头,则创建相应的 ESM 模型实例
    if model.startswith("esmfold"):
        esm = MODEL_MAPPING[model]()
    else:
        esm, alphabet = MODEL_MAPPING[model]()
    
    # 将模型设为评估模式,禁用 dropout
    esm.eval()

    # 根据模型类型设置各种参数和配置
    if model.startswith("esmfold"):
        embed_dim = esm.esm.embed_dim
        num_layers = esm.esm.num_layers
        num_attention_heads = esm.esm.attention_heads
        intermediate_size = 4 * embed_dim
        token_dropout = esm.esm.token_dropout
        emb_layer_norm_before = False  # 这条代码路径在 ESM-2 中不存在
        position_embedding_type = "rotary"
        is_folding_model = True
        esmfold_config = EsmFoldConfig()
        # 更新 ESMFoldConfig 对象的配置项
        for key, val in esm.cfg.items():
            if hasattr(esmfold_config, key) and key != "trunk":
                setattr(esmfold_config, key, val)
        for key, val in esm.cfg.trunk.items():
            if hasattr(esmfold_config.trunk, key) and key != "structure_module":
                setattr(esmfold_config.trunk, key, val)
        for key, val in esm.cfg.trunk.structure_module.items():
            if hasattr(esmfold_config.trunk.structure_module, key):
                setattr(esmfold_config.trunk.structure_module, key, val)
    elif hasattr(esm, "args"):
        # 表明是 ESM-1b 或 ESM-1v 模型
        embed_dim = esm.args.embed_dim
        num_layers = esm.args.layers
        num_attention_heads = esm.args.attention_heads
        intermediate_size = esm.args.ffn_embed_dim
        token_dropout = esm.args.token_dropout
        emb_layer_norm_before = True if esm.emb_layer_norm_before else False
        position_embedding_type = "absolute"
        is_folding_model = False
        esmfold_config = None
    else:
        # 表示这是一个 ESM-2 模型
        embed_dim = esm.embed_dim
        num_layers = esm.num_layers
        num_attention_heads = esm.attention_heads
        intermediate_size = 4 * embed_dim  # 这个值在 ESM-2 中是硬编码的
        token_dropout = esm.token_dropout
        emb_layer_norm_before = False  # 这个代码路径在 ESM-2 中不存在
        position_embedding_type = "rotary"
        is_folding_model = False
        esmfold_config = None

    if is_folding_model:
        alphabet = esm.esm.alphabet
    vocab_list = tuple(alphabet.all_toks)
    mask_token_id = alphabet.mask_idx
    pad_token_id = alphabet.padding_idx

    if is_folding_model:
        original_esm_model = esm.esm
    else:
        original_esm_model = esm

    config = EsmConfig(
        vocab_size=original_esm_model.embed_tokens.num_embeddings,
        mask_token_id=mask_token_id,
        hidden_size=embed_dim,
        num_hidden_layers=num_layers,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
        max_position_embeddings=1026,
        layer_norm_eps=1e-5,  # 在 fairseq 中使用的 PyTorch 默认值
        attention_probs_dropout_prob=0.0,
        hidden_dropout_prob=0.0,
        pad_token_id=pad_token_id,
        emb_layer_norm_before=emb_layer_norm_before,
        token_dropout=token_dropout,
        position_embedding_type=position_embedding_type,
        is_folding_model=is_folding_model,
        esmfold_config=esmfold_config,
        vocab_list=vocab_list,
    )
    if classification_head:
        config.num_labels = esm.classification_heads["mnli"].out_proj.weight.shape[0]
    print("Our ESM config:", config)

    if model.startswith("esmfold"):
        model_class = EsmForProteinFolding
    elif classification_head:
        model_class = EsmForSequenceClassification
    else:
        model_class = EsmForMaskedLM
    model = model_class(config)
    model.eval()

    # 现在我们来复制所有的权重。
    # Embeddings
    model.esm.embeddings.word_embeddings.weight = original_esm_model.embed_tokens.weight
    if position_embedding_type == "absolute":
        model.esm.embeddings.position_embeddings.weight = original_esm_model.embed_positions.weight

    if config.emb_layer_norm_before:
        model.esm.embeddings.layer_norm.weight = original_esm_model.emb_layer_norm_before.weight
        model.esm.embeddings.layer_norm.bias = original_esm_model.emb_layer_norm_before.bias

    model.esm.encoder.emb_layer_norm_after.weight = original_esm_model.emb_layer_norm_after.weight
    model.esm.encoder.emb_layer_norm_after.bias = original_esm_model.emb_layer_norm_after.bias
    # 如果是折叠模型(folding model),则执行以下操作
    if is_folding_model:
        # 将 ESM 模型的 esm_s_combine 数据传输到 model 的 esm_s_combine 中
        model.esm_s_combine.data = esm.esm_s_combine.data
        # 将 ESM 模型的 af2_to_esm 数据传输到 model 的 af2_to_esm 中
        model.af2_to_esm.data = esm.af2_to_esm.data
        # 将 ESM 模型的 embedding 数据传输到 model 的 embedding 中,并检查权重
        transfer_and_check_weights(esm.embedding, model.embedding)
        # 将 ESM 模型的 esm_s_mlp 数据传输到 model 的 esm_s_mlp 中,并检查权重
        transfer_and_check_weights(esm.esm_s_mlp, model.esm_s_mlp)
        # 将 ESM 模型的 trunk 数据传输到 model 的 trunk 中,并检查权重
        transfer_and_check_weights(esm.trunk, model.trunk)
        # 将 ESM 模型的 distogram_head 数据传输到 model 的 distogram_head 中,并检查权重
        transfer_and_check_weights(esm.distogram_head, model.distogram_head)
        # 将 ESM 模型的 ptm_head 数据传输到 model 的 ptm_head 中,并检查权重
        transfer_and_check_weights(esm.ptm_head, model.ptm_head)
        # 将 ESM 模型的 lm_head 数据传输到 model 的 lm_head 中,并检查权重
        transfer_and_check_weights(esm.lm_head, model.lm_head)
        # 将 ESM 模型的 lddt_head 数据传输到 model 的 lddt_head 中,并检查权重
        transfer_and_check_weights(esm.lddt_head, model.lddt_head)

    # 否则,如果是分类头(classification head),执行以下操作
    elif classification_head:
        # 将 ESM 模型的 "mnli" 分类头的权重传输到 model 的 classifier.dense.weight 中
        model.classifier.dense.weight = esm.esm.classification_heads["mnli"].dense.weight
        # 将 ESM 模型的 "mnli" 分类头的偏置传输到 model 的 classifier.dense.bias 中
        model.classifier.dense.bias = esm.classification_heads["mnli"].dense.bias
        # 将 ESM 模型的 "mnli" 分类头的输出投影权重传输到 model 的 classifier.out_proj.weight 中
        model.classifier.out_proj.weight = esm.classification_heads["mnli"].out_proj.weight
        # 将 ESM 模型的 "mnli" 分类头的输出投影偏置传输到 model 的 classifier.out_proj.bias 中
        model.classifier.out_proj.bias = esm.classification_heads["mnli"].out_proj.bias

    # 否则,执行以下操作(通常是语言模型头)
    else:
        # 将 ESM 模型的 lm_head 的 dense.weight 数据传输到 model 的 lm_head.dense.weight 中
        model.lm_head.dense.weight = esm.lm_head.dense.weight
        # 将 ESM 模型的 lm_head 的 dense.bias 数据传输到 model 的 lm_head.dense.bias 中
        model.lm_head.dense.bias = esm.lm_head.dense.bias
        # 将 ESM 模型的 lm_head 的 layer_norm.weight 数据传输到 model 的 lm_head.layer_norm.weight 中
        model.lm_head.layer_norm.weight = esm.lm_head.layer_norm.weight
        # 将 ESM 模型的 lm_head 的 layer_norm.bias 数据传输到 model 的 lm_head.layer_norm.bias 中
        model.lm_head.layer_norm.bias = esm.lm_head.layer_norm.bias
        # 将 ESM 模型的 lm_head 的 weight 数据传输到 model 的 lm_head.decoder.weight 中
        model.lm_head.decoder.weight = esm.lm_head.weight
        # 将 ESM 模型的 lm_head 的 bias 数据传输到 model 的 lm_head.bias 中
        model.lm_head.bias = esm.lm_head.bias

    # 将 ESM 模型的 contact_head 数据传输到 model 的 esm.contact_head 中,并检查权重
    transfer_and_check_weights(esm.contact_head, model.esm.contact_head)

    # 准备数据(来自 ESMStructuralSplitDataset 超家族的前两个序列 / 4)
    if is_folding_model:
        # 对于折叠模型,采样前两个数据样本,因为折叠模型不会使用掩码输入且不喜欢掩码令牌
        sample_data = SAMPLE_DATA[:2]
    else:
        # 对于其他模型,采样全部数据样本
        sample_data = SAMPLE_DATA

    if is_folding_model:
        # 获取 ESMFold 的 tokenizer
        hf_tokenizer = get_esmfold_tokenizer()
        # 使用 ESMFold tokenizer 处理样本数据,返回 PyTorch 张量格式,进行填充,并不添加特殊令牌
        hf_tokens = hf_tokenizer(
            [row[1] for row in sample_data], return_tensors="pt", padding=True, add_special_tokens=False
        )
        # 使用 ESMFold 编码函数处理样本数据,获取氨基酸序列、掩码等信息
        esmfold_aas, esmfold_mask, _, _, _ = esmfold_encode_sequences([row[1] for row in sample_data])
        # 检查是否成功匹配 ESMFold tokenizer 输出的 input_ids 和 attention_mask 与 hf_tokens 中的对应值
        success = torch.all(hf_tokens["input_ids"] == esmfold_aas) and torch.all(
            hf_tokens["attention_mask"] == esmfold_mask
        )
    else:
        # 否则,检查两个模型的 tokenizer 是否输出相同的 tokens
        batch_converter = alphabet.get_batch_converter()
        # 使用 batch_converter 处理样本数据,返回批次标签、字符串和 tokens
        batch_labels, batch_strs, batch_tokens = batch_converter(sample_data)
        # 准备 tokenizer,并确保其与 batch_tokens 匹配
        with TemporaryDirectory() as tempdir:
            # 创建临时目录,写入 alphabet 的全部 tokens 作为 vocab
            vocab = "\n".join(alphabet.all_toks)
            vocab_file = Path(tempdir) / "vocab.txt"
            vocab_file.write_text(vocab)
            # 使用 EsmTokenizer 初始化 hf_tokenizer
            hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))

        # 使用 hf_tokenizer 处理样本数据,返回 PyTorch 张量格式,进行填充
        hf_tokens = hf_tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True)
        # 检查是否成功匹配 hf_tokens 的 input_ids 与 batch_tokens 中的对应值
        success = torch.all(hf_tokens["input_ids"] == batch_tokens)

    # 打印是否两个模型的 tokenizer 输出相同的 tokens,如果相同则输出 "🔥",否则输出 "💩"
    print("Do both models tokenizers output the same tokens?", "🔥" if success else "💩")
    # 如果成功标志为假,则引发异常并显示消息
    if not success:
        raise Exception("Tokenization does not match!")

    # 禁用梯度计算,因为这是推断阶段
    with torch.no_grad():
        # 如果是折叠模型
        if is_folding_model:
            # 分阶段测试模型
            # ESMFold 总是将 ESM stem 转换为 float16,需要在 GPU 上执行 float16 操作
            # 这在 CPU 上不支持。因此,我们需要在 GPU 上运行它。然而,
            # ESMFold 是社区中所谓的“大型模型”,因此我们强烈避免同时将原始模型和转换后的模型放在 GPU 上。
            their_output = esm.cuda().infer([row[1] for row in sample_data])
            # 使用模型在 GPU 上运行推理
            our_output = model.cuda()(
                input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
            )
        else:
            # 在模型上运行输入以生成输出隐藏状态
            our_output = model(**hf_tokens, output_hidden_states=True)
            # 从输出中提取逻辑回归层结果
            our_output = our_output["logits"]
            if classification_head:
                # 如果是分类头,使用 ESM 模型的多功能自然语言推理分类
                their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens))
            else:
                # 使用 ESM 模型对输入进行推理并返回逻辑回归结果
                their_output = esm(hf_tokens["input_ids"], repr_layers=list(range(999)))
                their_output = their_output["logits"]

        # 如果是折叠模型,则计算位置差的最大绝对值,并检查输出是否全部接近
        if is_folding_model:
            max_absolute_diff = torch.max(torch.abs(our_output["positions"] - their_output["positions"])).item()
            success = torch.allclose(our_output["positions"], their_output["positions"], atol=1e-5)
        else:
            # 否则计算输出差的最大绝对值,并检查输出是否全部接近
            max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
            success = torch.allclose(our_output, their_output, atol=1e-5)

        # 打印最大绝对差异的值
        print(f"max_absolute_diff = {max_absolute_diff}")  # 大约为 1e-5
        # 打印模型是否输出相同的张量
        print("Do both models output the same tensors?", "🔥" if success else "💩")

        # 如果没有成功匹配输出,则引发异常
        if not success:
            raise Exception("Something went wRoNg")

        # 如果不是折叠模型,进行接触预测测试
        if not is_folding_model:
            # 使用模型预测接触点
            our_output = model.predict_contacts(hf_tokens["input_ids"], hf_tokens["attention_mask"])
            # 使用 ESM 模型预测接触点
            their_output = esm.predict_contacts(hf_tokens["input_ids"])
            # 计算接触预测的最大绝对值差异,并检查输出是否全部接近
            max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
            success = torch.allclose(our_output, their_output, atol=1e-5)

            # 打印接触预测测试结果
            print("Contact prediction testing:")
            print(f"max_absolute_diff = {max_absolute_diff}")  # 大约为 1e-5
            print("Do both models output the same tensors?", "🔥" if success else "💩")

            # 如果没有成功匹配输出,则引发异常
            if not success:
                raise Exception("Something went wRoNg")

        # 创建目录以保存 PyTorch 模型
        pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
        # 打印正在保存模型的消息
        print(f"Saving model to {pytorch_dump_folder_path}")
        # 保存模型的预训练参数到指定路径
        model.save_pretrained(pytorch_dump_folder_path)

        # 在继续之前释放部分内存
        del esm

    # 打印正在保存分词器的消息
    print(f"Saving tokenizer to {pytorch_dump_folder_path}")
    # 保存分词器的预训练参数到指定路径
    hf_tokenizer.save_pretrained(pytorch_dump_folder_path)
    # 如果 push_to_repo 为真,则执行下面的操作
    if push_to_repo:
        # 调用 model 对象的 push_to_hub 方法,将模型推送到指定的仓库
        model.push_to_hub(repo_id=push_to_repo, token_token=auth_token)
        # 调用 hf_tokenizer 对象的 push_to_hub 方法,将 tokenizer 推送到指定的仓库
        hf_tokenizer.push_to_hub(repo_id=push_to_repo, token_token=auth_token)
if __name__ == "__main__":
    # 如果脚本作为主程序执行,则执行以下代码块

    parser = argparse.ArgumentParser()
    # 创建参数解析器对象

    # 必填参数
    parser.add_argument(
        "--pytorch_dump_folder_path", type=str, required=True, help="Path to the output PyTorch model."
    )
    # 参数:pytorch_dump_folder_path,类型为字符串,必填,用于指定输出 PyTorch 模型的路径

    parser.add_argument(
        "--classification_head", action="store_true", help="Whether to convert a final classification head."
    )
    # 参数:classification_head,如果存在则设置为 True,用于指定是否转换最终分类头部

    parser.add_argument("--model", default=None, type=str, required=True, help="Name of model to convert.")
    # 参数:model,类型为字符串,默认值为 None,必填,用于指定要转换的模型的名称

    parser.add_argument("--push_to_repo", type=str, help="Repo to upload to (including username!).")
    # 参数:push_to_repo,类型为字符串,用于指定要上传的仓库(包括用户名)

    parser.add_argument("--auth_token", type=str, help="HuggingFace auth token.")
    # 参数:auth_token,类型为字符串,用于指定 HuggingFace 的认证令牌

    args = parser.parse_args()
    # 解析命令行参数并存储到 args 对象中

    convert_esm_checkpoint_to_pytorch(
        args.model, args.pytorch_dump_folder_path, args.classification_head, args.push_to_repo, args.auth_token
    )
    # 调用函数 convert_esm_checkpoint_to_pytorch,传入解析后的参数来执行模型转换操作

.\models\esm\modeling_esm.py

# 设置文件编码为 UTF-8

# 版权声明和许可证信息

# 导入必要的库和模块
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# 导入各种辅助函数和类
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    MaskedLMOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from .configuration_esm import EsmConfig

# 获取 logger 实例
logger = logging.get_logger(__name__)

# 文档中使用的模型检查点名称
_CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D"

# 文档中使用的配置文件名称
_CONFIG_FOR_DOC = "EsmConfig"

# ESM 预训练模型存档列表
ESM_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    # This is not a complete list of all ESM models!
    # See all ESM models at https://huggingface.co/models?filter=esm
]

# 定义一个函数,将输入张量沿着最后一个维度分成两部分,然后交换这两部分的顺序
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# 应用旋转位置嵌入到输入张量中
def apply_rotary_pos_emb(x, cos, sin):
    # 限制余弦和正弦嵌入的长度与输入张量的前两个维度一致
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]

    # 返回应用旋转位置嵌入后的结果张量
    return (x * cos) + (rotate_half(x) * sin)

# 实现原始 ESM 仓库中的 GELU 激活函数
def gelu(x):
    """
    This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results.
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

# 使张量在最后两个维度上对称化,用于接触预测
def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    return x + x.transpose(-1, -2)

# 执行平均产品修正,用于接触预测
def average_product_correct(x):
    "Perform average product correct, used for contact prediction."
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)

    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    normalized = x - avg
    return normalized

# 定义旋转嵌入类,基于 RoFormer 中的旋转位置嵌入实现
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """
    def __init__(self, dim: int):
        super().__init__()
        # 生成并保存反频率缓冲区(非可训练)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        inv_freq = inv_freq
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        seq_len = x.shape[seq_dimension]

        # 如果序列长度发生变化,或者在新设备上(可能由于追踪等原因),则重置表格
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # 更新余弦和正弦表格,使用 k 张量的序列维度作为参数
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
class EsmContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias=True,
        eos_idx: int = 2,
    ):
        super().__init__()
        self.in_features = in_features
        self.eos_idx = eos_idx
        # 定义一个线性层,用于执行 logistic 回归
        self.regression = nn.Linear(in_features, 1, bias)
        # 定义激活函数为 Sigmoid
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        # 移除 EOS 标记的注意力
        eos_mask = tokens.ne(self.eos_idx).to(attentions)
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # 移除 CLS 标记的注意力
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # 特征:批次 x 通道 x 标记 x 标记(对称)
        attentions = attentions.to(
            self.regression.weight.device
        )  # 注意力始终是 float32,可能需要转换为 float16
        # 对注意力矩阵进行对称化处理和平均产品校正
        attentions = average_product_correct(symmetrize(attentions))
        # 将维度重新排列以匹配线性层的输入要求
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))


class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """

    def __init__(self, config):
        super().__init__()
        # 词嵌入层,根据配置参数创建
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        if config.emb_layer_norm_before:
            # 如果配置中指定在嵌入之前进行层归一化,则初始化层归一化层
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        # Dropout 层,用于随机置零输入张量的元素
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 位置嵌入类型,绝对或相对
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # 在内存中创建位置 IDs 张量,用于序列中每个位置的索引
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

        self.padding_idx = config.pad_token_id
        # 位置嵌入层,根据配置参数创建
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )
        # 是否启用 token dropout
        self.token_dropout = config.token_dropout
        # Mask token 的 ID
        self.mask_token_id = config.mask_token_id

    def forward(
        self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
        ):
            # 如果未提供位置编码,但提供了输入的token ids,则从token ids创建位置编码,保留任何填充的token。
            if position_ids is None:
                if input_ids is not None:
                    position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
                else:
                    position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

            # 如果未提供输入的嵌入表示,则使用word_embeddings来生成。
            if inputs_embeds is None:
                inputs_embeds = self.word_embeddings(input_ids)

            # 如果希望在未来支持ESM-1(而不是1b!),则可能需要在这里支持一个embedding_scale因子。
            embeddings = inputs_embeds

            # Matt: ESM在MLM中有一个处理masking的略微不同的选项。如果token_dropout标志为False,则处理方式与BERT/RoBERTa相同。
            # 如果设置为True,则掩码token被视为已选择进行输入的dropout并将其置零。
            # 当训练期间没有掩码的token时,通过将嵌入乘以 (训练期间未掩码token的分数) / (样本中未掩码token的分数),来补偿"mask-dropout"。
            # 这类似于dropout层在评估期间未实际丢弃值时缩减输出(或者在训练期间增加未丢弃的输出)的方式。
            if self.token_dropout:
                embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
                mask_ratio_train = 0.15 * 0.8  # 在所有ESM模型训练中硬编码的比率
                src_lengths = attention_mask.sum(-1)
                mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
                embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
                    embeddings.dtype
                )

            # 如果位置嵌入类型为"absolute",则添加绝对位置嵌入到嵌入表示中。
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            # 如果层归一化函数不为None,则对嵌入表示进行层归一化。
            if self.layer_norm is not None:
                embeddings = self.layer_norm(embeddings)

            # 如果存在注意力遮罩,则将其应用于嵌入表示。
            if attention_mask is not None:
                embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)

            # Matt: 我认为这行代码从BERT中复制过来时出现了错误,暂时禁用它。
            # embeddings = self.dropout(embeddings)

            # 返回最终的嵌入表示。
            return embeddings
    # 根据输入的嵌入张量生成位置 ID。由于我们直接提供了嵌入向量,无法推断哪些是填充的,因此生成连续的位置 ID。

    # 获取输入嵌入张量的形状,去除最后一个维度(通常是 batch 维度)
    input_shape = inputs_embeds.size()[:-1]
    # 获取序列长度,即嵌入张量的第二个维度大小
    sequence_length = input_shape[1]

    # 使用 torch.arange 生成从 self.padding_idx + 1 到 sequence_length + self.padding_idx + 1 的整数序列
    # 结果类型为 long 型,设备为 inputs_embeds 的设备
    position_ids = torch.arange(
        self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
    )

    # 将 position_ids 在第0维度上增加一个维度,并在各维度上重复 input_shape 次数,以便与 inputs_embeds 的形状匹配
    return position_ids.unsqueeze(0).expand(input_shape)
# 定义一个自注意力模块,继承自 nn.Module
class EsmSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # 检查隐藏大小是否能够被注意力头数整除,或者检查是否有嵌入大小属性
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        # 初始化注意力头数和每个注意力头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 定义用于查询、键、值的线性变换层
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # 定义用于dropout的层
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        
        # 设置位置嵌入的类型,默认为绝对位置编码
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        
        # 如果位置编码类型是相对键或者相对键查询,则初始化距离嵌入层
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        
        # 如果位置编码类型是旋转类型,则初始化旋转嵌入层
        elif self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

        # 标记是否为解码器
        self.is_decoder = config.is_decoder

    # 将输入张量转换为分数矩阵
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    # 前向传播函数
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    def __init__(self, config):
        super().__init__()
        # 初始化 self 属性为 EsmSelfAttention 对象,使用给定的配置参数
        self.self = EsmSelfAttention(config)
        # 初始化 output 属性为 EsmSelfOutput 对象,使用给定的配置参数
        self.output = EsmSelfOutput(config)
        # 初始化 pruned_heads 为一个空集合,用于存储被剪枝的注意力头部索引
        self.pruned_heads = set()
        # 初始化 LayerNorm 属性为 nn.LayerNorm,使用给定的 hidden_size 和 layer_norm_eps 参数
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        # 调用 find_pruneable_heads_and_indices 函数找到可剪枝的头部及其索引
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # 剪枝线性层
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # 更新超参数并存储被剪枝的头部索引
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        # 对输入的 hidden_states 进行 LayerNorm 处理
        hidden_states_ln = self.LayerNorm(hidden_states)
        # 调用 self.self 的 forward 方法进行自注意力计算
        self_outputs = self.self(
            hidden_states_ln,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # 将自注意力的输出与原始 hidden_states 应用 self.output 进行最终的输出
        attention_output = self.output(self_outputs[0], hidden_states)
        # 如果需要输出注意力信息,则将其加入到输出中
        outputs = (attention_output,) + self_outputs[1:]  # 如果需要输出注意力信息,则加入
        return outputs
class EsmIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化一个全连接层,将输入维度为 config.hidden_size 转换为 config.intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 使用全连接层进行前向传播,将隐藏状态 hidden_states 映射到 intermediate_size 的维度
        hidden_states = self.dense(hidden_states)
        # 使用 GELU 激活函数处理 hidden_states
        hidden_states = gelu(hidden_states)
        return hidden_states


class EsmOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 初始化一个全连接层,将 intermediate_size 的输入转换为 hidden_size 的输出
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # 初始化一个 dropout 层,以概率 config.hidden_dropout_prob 随机将输出置为 0
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        # 使用全连接层进行前向传播,将 hidden_states 映射回 hidden_size 的维度
        hidden_states = self.dense(hidden_states)
        # 对输出进行 dropout 处理
        hidden_states = self.dropout(hidden_states)
        # 将 dropout 处理后的输出与输入 input_tensor 相加,实现残差连接
        hidden_states = hidden_states + input_tensor
        return hidden_states


class EsmLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 设置用于分块前向传播的 chunk_size_feed_forward
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # 设置序列长度的维度为 1
        self.seq_len_dim = 1
        # 初始化自注意力层
        self.attention = EsmAttention(config)
        # 设置是否为解码器的标志位
        self.is_decoder = config.is_decoder
        # 设置是否添加跨层注意力的标志位
        self.add_cross_attention = config.add_cross_attention
        # 如果添加了跨层注意力,确保模型被用作解码器模型
        if self.add_cross_attention:
            if not self.is_decoder:
                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
            # 初始化跨层注意力层
            self.crossattention = EsmAttention(config)
        # 初始化中间层和输出层
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        # 初始化层归一化层
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        ):
        # 对隐藏状态进行自注意力计算
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            )
        # 从自注意力输出中提取隐藏状态
        attention_output = self_attention_outputs[0]
        # 如果添加了跨层注意力,计算跨层注意力输出
        if self.add_cross_attention:
            cross_attention_outputs = self.crossattention(
                attention_output,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions=output_attentions,
                )
            attention_output = cross_attention_outputs[0]
        # 应用中间层和输出层处理
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        # 应用层归一化层处理输出
        layer_output = self.LayerNorm(layer_output + attention_output)
        outputs = (layer_output,) + self_attention_outputs[1:]  # 添加注意力输出信息
        if output_attentions:
            outputs = outputs + cross_attention_outputs[1:]  # 添加跨层注意力输出信息
        return outputs
        # 如果过去的键/值对存在,只保留自注意力缓存的前两个位置的值
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # 执行自注意力计算,传入隐藏状态、注意力掩码、头掩码等参数
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        # 获取自注意力计算的输出
        attention_output = self_attention_outputs[0]

        # 如果当前实例为解码器,最后一个输出是自注意力缓存的元组
        if self.is_decoder:
            # 只保留自注意力计算输出中除了最后一个元组的部分
            outputs = self_attention_outputs[1:-1]
            # 提取自注意力的当前键/值对
            present_key_value = self_attention_outputs[-1]
        else:
            # 否则,包含所有自注意力计算输出(如果输出注意力权重的话)
            outputs = self_attention_outputs[1:]

        # 初始化交叉注意力的当前键/值对为 None
        cross_attn_present_key_value = None
        # 如果当前实例是解码器且存在编码器的隐藏状态
        if self.is_decoder and encoder_hidden_states is not None:
            # 检查是否存在交叉注意力层,如果没有则引发错误
            if not hasattr(self, "crossattention"):
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
                    " with cross-attention layers by setting `config.add_cross_attention=True`"
                )

            # 如果过去的键/值对存在,只保留交叉注意力缓存的最后两个位置的值
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            # 执行交叉注意力计算,传入自注意力的输出、注意力掩码、头掩码、编码器的隐藏状态等参数
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            # 获取交叉注意力计算的输出
            attention_output = cross_attention_outputs[0]
            # 将交叉注意力计算输出中除了最后一个元组的部分添加到自注意力计算输出中
            outputs = outputs + cross_attention_outputs[1:-1]

            # 将交叉注意力的当前键/值对添加到自注意力的当前键/值对中的最后两个位置
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # 使用注意力输出执行前馈网络的计算
        layer_output = self.feed_forward_chunk(attention_output)

        # 将前馈网络的输出添加到输出元组中
        outputs = (layer_output,) + outputs

        # 如果当前实例是解码器,将注意力的键/值对作为输出元组的最后一个元素返回
        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        return outputs
    ```
class EsmEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config  # 初始化模型配置
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])  # 创建多层ESM层
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # 对嵌入进行层归一化
        self.gradient_checkpointing = False  # 梯度检查点标记为False,表示不使用梯度检查点

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        # 返回字典类型的结果



class EsmPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # 密集连接层
        self.activation = nn.Tanh()  # 激活函数为Tanh

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 通过获取第一个token对应的隐藏状态来“池化”模型
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)  # 使用线性层进行池化
        pooled_output = self.activation(pooled_output)  # 应用Tanh激活函数
        return pooled_output



class EsmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = EsmConfig  # 配置类为EsmConfig
    base_model_prefix = "esm"  # 基础模型前缀为"esm"
    supports_gradient_checkpointing = True  # 支持梯度检查点

    _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
    # 不拆分的模块列表

    # 初始化权重的函数,根据不同类型的模块进行初始化
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # 使用正态分布初始化线性层的权重,偏置置零
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # 使用正态分布初始化嵌入层的权重,偏置置零
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # 层归一化的权重初始化,偏置置零,缩放参数置1
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


ESM_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    # 将其作为常规的 PyTorch 模块使用,并参考 PyTorch 文档以获取有关一般用法和行为的所有信息。

    Parameters:
        # config ([`EsmConfig`]): 模型配置类,包含模型的所有参数。
        # 初始化时使用配置文件不会加载与模型相关的权重,只加载配置信息。
        # 可以查看 [`~PreTrainedModel.from_pretrained`] 方法来加载模型权重。
# 定义了一个原始的 ESM 模型类,继承自 EsmPreTrainedModel
@add_start_docstrings(
    "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
    ESM_START_DOCSTRING,
)
class EsmModel(EsmPreTrainedModel):
    """
    ESM 模型类,可以作为编码器(只包含自注意力)或解码器使用,后者则在自注意力层之间添加了一层交叉注意力,遵循了 Ashish Vaswani 等人在《Attention is all you need》中描述的架构。
    """
    """
    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    # 根据给定的配置初始化模型,可选择添加一个池化层
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        # 初始化嵌入层
        self.embeddings = EsmEmbeddings(config)
        # 初始化编码器
        self.encoder = EsmEncoder(config)

        # 如果设置了添加池化层,则初始化池化层,否则为None
        self.pooler = EsmPooler(config) if add_pooling_layer else None

        # 初始化联系预测头部
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )

        # 执行初始化后的权重和最终处理
        self.post_init()

    # 返回输入嵌入层
    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    # 设置输入嵌入层
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    # 剪枝模型中的注意力头部
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    # 前向传播函数,接受多种输入和参数
    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    # 预测联系人函数,接受 tokens 和 attention_mask 作为输入
    def predict_contacts(self, tokens, attention_mask):
        # 使用模型进行推断,返回注意力矩阵列表
        attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
        # 将注意力矩阵堆叠起来,以匹配原始模型的布局
        attns = torch.stack(attns, dim=1)  # Matches the original model layout
        
        # 在原始模型中,对于填充的 token,其注意力被完全置零。
        # 大多数情况下这不会有影响,因为其他 token 不会关注它们,
        # 但对于需要将注意力作为输入的联系人预测任务而言,这一点非常重要,
        # 因此我们需要在这里模仿这种处理方式。
        
        # 将注意力矩阵乘以 attention_mask,以将填充的 token 的注意力置零
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        
        # 使用联系人头部模型进行联系人预测,并返回结果
        return self.contact_head(tokens, attns)
# 定义一个 EsmForMaskedLM 类,继承自 EsmPreTrainedModel 类,并添加了语言建模头部
@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class EsmForMaskedLM(EsmPreTrainedModel):
    # 定义了与 lm_head.decoder.weight 相关的权重绑定键
    _tied_weights_keys = ["lm_head.decoder.weight"]

    # 初始化方法,接收一个配置对象 config
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__(config)

        # 如果配置中 is_decoder 为 True,则发出警告信息
        if config.is_decoder:
            logger.warning(
                "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # 创建 EsmModel 对象,不添加池化层
        self.esm = EsmModel(config, add_pooling_layer=False)
        # 创建 EsmLMHead 对象
        self.lm_head = EsmLMHead(config)

        # 初始化模型权重
        self.init_weights()

    # 返回 lm_head.decoder 对象,用于输出嵌入
    def get_output_embeddings(self):
        return self.lm_head.decoder

    # 设置 lm_head.decoder 的新嵌入
    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    # 前向传播方法,接收多个输入参数并返回输出
    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="<mask>",
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # 以下为输入参数的详细说明
    ):
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        kwargs (`Dict[str, any]`, optional, defaults to *{}*):
            Used to hide legacy arguments that have been deprecated.
        """
        # 根据参数 `return_dict` 确定是否返回字典类型的输出
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用模型 `esm` 进行前向传播,传入各种输入参数
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 获取模型输出的序列输出
        sequence_output = outputs[0]
        # 对序列输出进行预测得到预测分数
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        # 如果存在标签,则计算掩码语言建模损失
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            # 将标签移动到与预测分数相同的设备上
            labels = labels.to(prediction_scores.device)
            # 计算掩码语言建模的损失
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        # 如果不返回字典类型的输出,则组织最终输出格式
        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        # 返回掩码语言建模任务的输出,包括损失、预测分数、隐藏状态和注意力权重
        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        # 调用模型 `esm` 的方法进行接触预测
        return self.esm.predict_contacts(tokens, attention_mask=attention_mask)
class EsmLMHead(nn.Module):
    """ESM Head for masked language modeling."""

    def __init__(self, config):
        super().__init__()
        # 定义一个全连接层,将输入特征空间映射到隐藏大小的空间
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Layer normalization,对输入进行归一化处理
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # 用于输出,将隐藏大小映射回词汇表大小的线性层,无偏置
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # 定义一个偏置参数,长度为词汇表大小,用于模型输出的偏移
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, features, **kwargs):
        # 前向传播函数
        x = self.dense(features)  # 全连接层映射
        x = gelu(x)  # 使用 GELU 激活函数
        x = self.layer_norm(x)  # Layer normalization 归一化处理

        # 用线性层映射回词汇表大小,并加上偏置
        x = self.decoder(x) + self.bias
        return x


@add_start_docstrings(
    """
    ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ESM_START_DOCSTRING,
)
class EsmForSequenceClassification(EsmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # ESM 模型主体部分,不添加池化层
        self.esm = EsmModel(config, add_pooling_layer=False)
        # 分类头部,用于序列分类任务
        self.classifier = EsmClassificationHead(config)

        self.init_weights()  # 初始化模型权重

    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # 如果 return_dict 为 None,则使用 self.config.use_return_dict 决定是否返回字典形式的输出
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 使用 ESM 模型进行前向传播,获取模型的输出
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 从模型输出中获取序列输出
        sequence_output = outputs[0]
        # 将序列输出输入分类器,得到预测 logits
        logits = self.classifier(sequence_output)

        # 初始化损失为 None
        loss = None
        # 如果存在 labels,则计算损失
        if labels is not None:
            # 将 labels 移动到 logits 所在的设备上
            labels = labels.to(logits.device)

            # 根据问题类型确定问题类型("regression", "single_label_classification", "multi_label_classification")
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # 根据问题类型计算相应的损失
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        # 如果不要求返回字典形式的输出,则按元组形式返回结果
        if not return_dict:
            output = (logits,) + outputs[2:]  # 将 logits 和其他输出组成元组
            return ((loss,) + output) if loss is not None else output  # 如果有损失,则将损失与输出一起返回,否则只返回输出

        # 返回 SequenceClassifierOutput 对象,包括损失、logits、隐藏状态和注意力权重
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ESM_START_DOCSTRING,
)



class EsmForTokenClassification(EsmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # 初始化 ESM 模型,不添加池化层
        self.esm = EsmModel(config, add_pooling_layer=False)
        
        # Dropout 层,用于防止过拟合
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        
        # 分类器,将隐藏状态映射到标签数的线性层
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # 初始化模型权重
        self.init_weights()



    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )



    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        
        # 确定是否返回字典类型的输出
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 获取 ESM 模型的输出
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取序列输出
        sequence_output = outputs[0]

        # 应用 Dropout 层
        sequence_output = self.dropout(sequence_output)
        
        # 使用分类器将序列输出映射到标签空间
        logits = self.classifier(sequence_output)

        # 计算损失
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            # 将标签移到与 logits 相同的设备上
            labels = labels.to(logits.device)
            # 计算交叉熵损失
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # 如果不返回字典,则以元组形式返回输出
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 返回 TokenClassifierOutput 对象
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )



class EsmClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""
    # 初始化函数,用于创建一个新的神经网络模型实例
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 创建一个全连接层,输入和输出维度都是 config.hidden_size
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 创建一个 Dropout 层,使用 config.hidden_dropout_prob 作为丢弃概率
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 创建一个全连接层,输入维度是 config.hidden_size,输出维度是 config.num_labels
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    # 前向传播函数,定义了数据从输入到输出的流程
    def forward(self, features, **kwargs):
        # 取 features 的第一个位置的数据,通常表示起始 token(如 [CLS])
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        # 对取出的数据应用 Dropout,随机部分神经元失活,防止过拟合
        x = self.dropout(x)
        # 将数据通过全连接层 self.dense 进行线性变换
        x = self.dense(x)
        # 对变换后的数据应用双曲正切函数进行非线性变换
        x = torch.tanh(x)
        # 再次应用 Dropout 层,进一步随机失活神经元
        x = self.dropout(x)
        # 将数据通过全连接层 self.out_proj 进行线性变换,得到最终的输出结果
        x = self.out_proj(x)
        # 返回神经网络模型的输出结果
        return x
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: torch.Tensor, input tensor containing token IDs
        padding_idx: int, the index of padding tokens in input_ids
        past_key_values_length: int, optional, length of past key values for incremental processing

    Returns:
        torch.Tensor, tensor of position IDs corresponding to input_ids
    """
    # 创建一个掩码,标记非填充符号的位置为1,填充符号为0
    mask = input_ids.ne(padding_idx).int()
    # 计算每个非填充符号的位置编号,位置编号从 padding_idx+1 开始,乘以掩码以忽略填充符号
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    # 将位置编号转换为长整型并加上 padding_idx,得到最终的位置 ID
    return incremental_indices.long() + padding_idx

.\models\esm\modeling_esmfold.py

# 设置编码格式为 UTF-8
# 版权声明和许可证信息,表明此代码的使用和分发需要遵循 Apache License, Version 2.0
# 导入必要的库和模块
import math
import sys
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np  # 导入 NumPy 库,用于数值计算
import torch  # 导入 PyTorch 深度学习框架
import torch.nn as nn  # 导入 PyTorch 的神经网络模块
from torch.nn import LayerNorm  # 导入 PyTorch 的 LayerNorm 模块

# 导入相关的模块和函数,用于 DeepSpeed 集成、模型输出、文档字符串处理等
from ...integrations.deepspeed import is_deepspeed_available
from ...modeling_outputs import ModelOutput
from ...utils import (
    ContextManagers,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_scipy_available,
    logging,
    replace_return_docstrings,
)
# 导入 ESM 模型的配置文件和模型定义
from .configuration_esm import EsmConfig
from .modeling_esm import ESM_START_DOCSTRING, EsmModel, EsmPreTrainedModel
# 导入与蛋白质折叠相关的工具函数和类
from .openfold_utils import (
    OFProtein,
    Rigid,
    Rotation,
    atom14_to_atom37,
    chunk_layer,
    compute_predicted_aligned_error,
    compute_tm,
    frames_and_literature_positions_to_atom14_pos,
    make_atom14_masks,
    residue_constants,
    to_pdb,
    torsion_angles_to_frames,
)

# 获取日志记录器对象
logger = logging.get_logger(__name__)
# 文档中的模型检查点和配置信息
_CHECKPOINT_FOR_DOC = "facebook/esmfold_v1"
_CONFIG_FOR_DOC = "EsmConfig"

@dataclass
class EsmForProteinFoldingOutput(ModelOutput):
    """
    [`EsmForProteinFoldingOutput`] 的输出类型。
    """
    Args:
        frames (`torch.FloatTensor`):
            输出帧。
            模型预测的帧输出。
        sidechain_frames (`torch.FloatTensor`):
            侧链帧。
            模型预测的侧链帧输出。
        unnormalized_angles (`torch.FloatTensor`):
            预测的未归一化主链和侧链扭转角度。
            模型预测的未归一化主链和侧链扭转角度。
        angles (`torch.FloatTensor`):
            预测的主链和侧链扭转角度。
            模型预测的主链和侧链扭转角度。
        positions (`torch.FloatTensor`):
            预测的主链和侧链原子的位置。
            模型预测的主链和侧链原子位置。
        states (`torch.FloatTensor`):
            蛋白质折叠主干的隐藏状态。
            来自蛋白质折叠主干的隐藏状态。
        s_s (`torch.FloatTensor`):
            每个残基嵌入。
            通过连接ESM-2 LM stem每层的隐藏状态得到的每个残基嵌入。
        s_z (`torch.FloatTensor`):
            成对残基嵌入。
            成对残基嵌入。
        distogram_logits (`torch.FloatTensor`):
            距离直方图的输入对数。
            用于计算残基距离的输入对数。
        lm_logits (`torch.FloatTensor`):
            ESM-2蛋白质语言模型主干的输出对数。
            ESM-2蛋白质语言模型主干的输出对数。
        aatype (`torch.FloatTensor`):
            输入的氨基酸(AlphaFold2索引)。
            输入的氨基酸(AlphaFold2索引)。
        atom14_atom_exists (`torch.FloatTensor`):
            每个原子在atom14表示中是否存在。
            每个原子在atom14表示中是否存在。
        residx_atom14_to_atom37 (`torch.FloatTensor`):
            atom14到atom37表示之间的映射。
            atom14到atom37表示之间的映射。
        residx_atom37_to_atom14 (`torch.FloatTensor`):
            atom37到atom14表示之间的映射。
            atom37到atom14表示之间的映射。
        atom37_atom_exists (`torch.FloatTensor`):
            每个原子在atom37表示中是否存在。
            每个原子在atom37表示中是否存在。
        residue_index (`torch.FloatTensor`):
            蛋白链中每个残基的索引。
            蛋白链中每个残基的索引。
        lddt_head (`torch.FloatTensor`):
            lddt头部的原始输出。
            用于计算plddt的lddt头部的原始输出。
        plddt (`torch.FloatTensor`):
            每个残基的置信度分数。
            模型预测结构可能不确定或蛋白结构无序的区域可能表明低置信度的区域。
        ptm_logits (`torch.FloatTensor`):
            用于计算ptm的原始logits。
            用于计算ptm的原始logits。
        ptm (`torch.FloatTensor`):
            TM-score输出,代表模型对整体结构的高级置信度。
            TM-score输出,代表模型对整体结构的高级置信度。
        aligned_confidence_probs (`torch.FloatTensor`):
            对齐结构的每个残基置信度分数。
            对齐结构的每个残基置信度分数。
        predicted_aligned_error (`torch.FloatTensor`):
            模型预测与真实值之间的预测误差。
            模型预测与真实值之间的预测误差。
        max_predicted_aligned_error (`torch.FloatTensor`):
            每个样本的最大预测误差。
            每个样本的最大预测误差。
    """

    frames: torch.FloatTensor = None
    sidechain_frames: torch.FloatTensor = None
    unnormalized_angles: torch.FloatTensor = None
    angles: torch.FloatTensor = None
    # 定义一系列变量,每个变量的类型均为 torch.FloatTensor,初始赋值为 None
    positions: torch.FloatTensor = None  # 用于存储位置信息的张量
    states: torch.FloatTensor = None  # 用于存储状态信息的张量
    s_s: torch.FloatTensor = None  # 用于存储 s_s 信息的张量
    s_z: torch.FloatTensor = None  # 用于存储 s_z 信息的张量
    distogram_logits: torch.FloatTensor = None  # 用于存储距离直方图 logits 的张量
    lm_logits: torch.FloatTensor = None  # 用于存储语言模型 logits 的张量
    aatype: torch.FloatTensor = None  # 用于存储氨基酸类型的张量
    atom14_atom_exists: torch.FloatTensor = None  # 用于存储 atom14 是否存在的张量
    residx_atom14_to_atom37: torch.FloatTensor = None  # 用于存储 residue index 到 atom37 的映射的张量
    residx_atom37_to_atom14: torch.FloatTensor = None  # 用于存储 residue index 到 atom14 的映射的张量
    atom37_atom_exists: torch.FloatTensor = None  # 用于存储 atom37 是否存在的张量
    residue_index: torch.FloatTensor = None  # 用于存储残基索引的张量
    lddt_head: torch.FloatTensor = None  # 用于存储 lddt 头信息的张量
    plddt: torch.FloatTensor = None  # 用于存储 plddt 信息的张量
    ptm_logits: torch.FloatTensor = None  # 用于存储 ptm logits 的张量
    ptm: torch.FloatTensor = None  # 用于存储 ptm 信息的张量
    aligned_confidence_probs: torch.FloatTensor = None  # 用于存储对齐置信度概率的张量
    predicted_aligned_error: torch.FloatTensor = None  # 用于存储预测的对齐误差的张量
    max_predicted_aligned_error: torch.FloatTensor = None  # 用于存储最大预测对齐误差的张量
# 定义一个多行文档字符串,描述了函数 `ESMFOLD_INPUTS_DOCSTRING` 的参数及其含义
ESMFOLD_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        masking_pattern (`torch.LongTensor` of shape `({0})`, *optional*):
            Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`.
        num_recycles (`int`, *optional*, defaults to `None`):
            Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling"
            consists of passing the output of the folding trunk back in as input to the trunk. During training, the
            number of recycles should vary with each batch, to ensure that the model learns to output valid predictions
            after each recycle. During inference, num_recycles should be set to the highest value that the model was
            trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is
            used.
"""


def is_fp16_enabled():
    # 检查当前是否启用了 FP16 自动转换
    fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
    fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
    return fp16_enabled


def is_deepspeed_initialized():
    # 检查是否初始化了 DeepSpeed,如果 DeepSpeed 可用但未初始化则返回 False
    if is_deepspeed_available():
        return False
    else:
        try:
            import deepspeed

            # 尝试调用 DeepSpeed 的初始化检查函数,部分版本可能不支持此功能
            return deepspeed.utils.is_initialized()
        except Exception:
            # 捕获所有异常,返回 False 表示未初始化
            return False


def collate_dense_tensors(samples: List[torch.Tensor], pad_v: float = 0) -> torch.Tensor:
    """
    将一个张量列表堆叠并填充成一个单一张量,所有张量的维度必须一致。
    参数:
        samples: 包含多个张量的列表,每个张量的维度必须相同。
        pad_v: 填充值,默认为 0。
    返回:
        堆叠并填充后的单一张量。
    异常:
        如果 samples 中张量的维度不一致,抛出 RuntimeError 异常。
    """
    if len(samples) == 0:
        return torch.Tensor()  # 如果 samples 列表为空,则返回空张量

    if len({x.dim() for x in samples}) != 1:
        # 检查 samples 中张量的维度是否一致,不一致则抛出异常
        raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}")
    # 从 samples 中获取设备信息,假设所有样本都在同一设备上
    (device,) = tuple({x.device for x in samples})
    
    # 计算 samples 中每个样本的最大形状的每个维度的最大值
    max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
    
    # 使用 torch.empty 创建一个与最大形状匹配的张量 result,长度为 len(samples),数据类型与 samples[0] 相同,设备与 samples 相同
    result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device)
    
    # 用 pad_v 填充 result 张量
    result.fill_(pad_v)
    
    # 遍历每个样本并将其复制到 result 张量的适当位置
    for i in range(len(samples)):
        result_i = result[i]  # 获取 result 中的第 i 个子张量
        t = samples[i]         # 获取第 i 个样本张量 t
        # 将样本张量 t 复制到 result_i 的正确位置
        result_i[tuple(slice(0, k) for k in t.shape)] = t
    
    # 返回填充后的 result 张量,其中包含了所有样本的数据
    return result
# 定义函数,用于将张量的最后几个维度展平成一个维度
def flatten_final_dims(t: torch.Tensor, no_dims: int):
    return t.reshape(t.shape[:-no_dims] + (-1,))


# 定义函数,用于对张量的最后几个维度进行置换
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    # 计算最后几个维度的起始索引
    zero_index = -1 * len(inds)
    # 获取前面的维度索引列表
    first_inds = list(range(len(tensor.shape[:zero_index])))
    # 对张量进行置换操作
    return tensor.permute(first_inds + [zero_index + i for i in inds])


# 定义函数,对多个字典中相同键的值应用指定的函数
def dict_multimap(fn, dicts):
    # 获取第一个字典
    first = dicts[0]
    new_dict = {}
    # 遍历第一个字典的键值对
    for k, v in first.items():
        # 收集所有字典中相同键的值列表
        all_v = [d[k] for d in dicts]
        # 如果第一个字典中的值是字典类型,则递归调用dict_multimap函数
        if isinstance(v, dict):
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            # 否则,对所有值应用给定的函数fn
            new_dict[k] = fn(all_v)
    # 返回应用函数后的新字典
    return new_dict


# 定义函数,使用截断正态分布初始化权重张量
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    shape = weights.shape
    # 计算缩放系数
    scale = scale / max(1, shape[1])

    # 检查是否存在SciPy库,如果不存在,则给出警告
    if not is_scipy_available():
        logger.warning(
            "This init requires scipy, but scipy was not found, default to an approximation that might not be"
            " equivalent."
        )
        # 使用近似值初始化权重张量
        std = math.sqrt(scale)
        torch.nn.init.normal_(weights, std=std).clamp(min=0.0, max=2.0 * std)

    else:
        from scipy.stats import truncnorm

        # 使用SciPy的截断正态分布生成权重样本
        std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1)
        samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel())
        samples = np.reshape(samples, shape)
        # 将生成的样本复制到权重张量中
        weights.copy_(torch.tensor(samples, device=weights.device))


# 定义函数,使用指定值初始化权重张量
def ipa_point_weights_init_(weights):
    with torch.no_grad():
        softplus_inverse_1 = 0.541324854612918
        # 用给定值填充权重张量
        weights.fill_(softplus_inverse_1)


# 定义类,继承自torch.nn.Linear,实现了自定义的初始化方法
class EsmFoldLinear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear.

    Implements the initializers in 1.11.4, plus some additional ones found in the code.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
        # 继承父类构造方法,定义额外的初始化参数
        **kwargs
    ):
        super().__init__(in_dim, out_dim, bias=bias, **kwargs)
    ):
        """
        Args:
            in_dim:
                输入层的最终维度
            out_dim:
                层输出的最终维度
            bias:
                是否学习一个可加偏置,默认为True
            init:
                要使用的初始化器。可选项包括:

                "default": LeCun fan-in截断正态分布初始化
                "relu": 带截断正态分布的He初始化
                "glorot": Fan-average Glorot均匀分布初始化
                "gating": 权重=0,偏置=1
                "normal": 标准差为1/sqrt(fan_in)的正态分布初始化
                "final": 权重=0,偏置=0

                如果init_fn不为None,则被init_fn覆盖。
            init_fn:
                接受权重和偏置作为输入的自定义初始化器。如果不为None,则覆盖init。
        """
        # 调用父类构造函数,初始化输入维度、输出维度和是否有偏置
        super().__init__(in_dim, out_dim, bias=bias)

        # 如果有偏置,用0填充偏置项
        if bias:
            with torch.no_grad():
                self.bias.fill_(0)

        # 初始化器和自定义初始化器赋值
        self.init = init
        self.init_fn = init_fn

        # 检查init参数是否合法
        if init not in ["default", "relu", "glorot", "gating", "normal", "final"]:
            raise ValueError("Invalid init string.")
class EsmFoldLayerNorm(nn.Module):
    def __init__(self, c_in, eps=1e-5):
        super().__init__()

        self.c_in = (c_in,)  # 输入通道数的元组,用于后续操作
        self.eps = eps  # Layer normalization 中的 epsilon 参数

        self.weight = nn.Parameter(torch.ones(c_in))  # 可学习的权重参数,默认为全1
        self.bias = nn.Parameter(torch.zeros(c_in))  # 可学习的偏置参数,默认为全0

    def forward(self, x):
        d = x.dtype  # 获取输入张量 x 的数据类型
        if d is torch.bfloat16 and not is_deepspeed_initialized():  # 如果输入是 bfloat16 并且没有启用深度速度优化
            with torch.cuda.amp.autocast(enabled=False):  # 禁用自动混合精度
                out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps)  # 使用 layer normalization 进行归一化
        else:
            out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps)  # 使用 layer normalization 进行归一化

        return out


@torch.jit.ignore
def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of type bfloat16
    """
    d = t.dtype  # 获取输入张量 t 的数据类型
    if d is torch.bfloat16 and not is_deepspeed_initialized():  # 如果输入是 bfloat16 并且没有启用深度速度优化
        with torch.cuda.amp.autocast(enabled=False):  # 禁用自动混合精度
            s = torch.nn.functional.softmax(t, dim=dim)  # 使用 softmax 计算张量 t 在指定维度上的概率分布
    else:
        s = torch.nn.functional.softmax(t, dim=dim)  # 使用 softmax 计算张量 t 在指定维度上的概率分布

    return s


class EsmFoldAttention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors.
    """

    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
        """
        super().__init__()

        self.c_q = c_q  # 查询数据的输入维度
        self.c_k = c_k  # 键数据的输入维度
        self.c_v = c_v  # 值数据的输入维度
        self.c_hidden = c_hidden  # 每个注意力头的隐藏层维度
        self.no_heads = no_heads  # 注意力头的数量
        self.gating = gating  # 是否使用查询数据对输出进行门控

        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
        # stated in the supplement, but the overall channel dimension.

        self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")  # 查询线性变换层
        self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")  # 键线性变换层
        self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")  # 值线性变换层
        self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final")  # 输出线性变换层

        self.linear_g = None
        if self.gating:
            self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating")  # 门控线性变换层

        self.sigmoid = nn.Sigmoid()  # Sigmoid 激活函数的实例化
    # 准备 Q/K/V 查询、键、值的线性变换
    def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # 对查询向量 q_x 执行线性变换
        q = self.linear_q(q_x)
        # 对键向量 kv_x 执行线性变换
        k = self.linear_k(kv_x)
        # 对值向量 kv_x 执行线性变换
        v = self.linear_v(kv_x)

        # 重新塑形以适应多头注意力机制的输入格式
        # [*, Q/K/V, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        # 将多头维度与注意力头部数交换位置,以便后续计算注意力权重
        # [*, H, Q/K, C_hidden]
        q = q.transpose(-2, -3)
        k = k.transpose(-2, -3)
        v = v.transpose(-2, -3)

        # 缩放 Q 向量,以便在计算注意力权重时更稳定
        q /= math.sqrt(self.c_hidden)

        return q, k, v

    # 处理输出结果 o,并应用可选的全局门控线性变换
    def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
        if self.linear_g is not None:
            # 计算全局门控线性变换的输出,并应用 Sigmoid 激活函数
            g = self.sigmoid(self.linear_g(q_x))

            # 重新塑形以适应多头注意力机制的输入格式
            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # 将多头注意力机制的输出展平最后两个维度
        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # 对最终的输出应用线性变换,将其映射到输出空间
        # [*, Q, C_q]
        o = self.linear_o(o)

        return o

    # 实现模型的前向传播
    def forward(
        self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        biases: Optional[List[torch.Tensor]] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        lma_q_chunk_size: int = 1024,
        lma_kv_chunk_size: int = 4096,
        use_flash: bool = False,
        flash_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            q_x:
                [*, Q, C_q] query data  # 输入的查询数据,形状为 [*, Q, C_q]
            kv_x:
                [*, K, C_k] key data  # 输入的键数据,形状为 [*, K, C_k]
            biases:
                List of biases that broadcast to [*, H, Q, K]  # 广播到 [*, H, Q, K] 的偏置列表
            use_memory_efficient_kernel:
                Whether to use a custom memory-efficient attention kernel. This should be the default choice for most.
                If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead
                是否使用自定义的内存高效注意力核。对于大多数情况,这应该是默认选择。
                如果没有一个 "use_<...>" 标志为 True,则使用标准的 PyTorch 实现
            use_lma:
                Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a
                stock PyTorch implementation is used instead
                是否使用低内存注意力 (Staats & Rabe 2021)。
                如果没有一个 "use_<...>" 标志为 True,则使用标准的 PyTorch 实现
            lma_q_chunk_size:
                Query chunk size (for LMA)  # 查询分块大小(用于低内存注意力)
            lma_kv_chunk_size:
                Key/Value chunk size (for LMA)  # 键/值分块大小(用于低内存注意力)
        Returns
            [*, Q, C_q] attention update  # 注意力更新后的输出,形状为 [*, Q, C_q]
        """
        if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
            raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided")
            # 如果使用低内存注意力,并且没有提供查询或键/值的分块大小,则抛出数值错误异常

        if use_flash and biases is not None:
            raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead")
            # 如果同时使用闪存和偏置选项,则抛出数值错误异常。应使用 flash_mask 进行遮罩操作而非偏置。

        attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
        if sum(attn_options) > 1:
            raise ValueError("Choose at most one alternative attention algorithm")
            # 如果选择了多个注意力算法选项,则抛出数值错误异常。只能选择最多一个备选注意力算法。

        if biases is None:
            biases = []

        # [*, H, Q/K, C_hidden]
        query, key, value = self._prep_qkv(q_x, kv_x)
        key = permute_final_dims(key, (1, 0))
        # 准备查询、键、值,形状为 [*, H, Q/K, C_hidden],并将键的最后两个维度进行置换

        # [*, H, Q, K]
        output = torch.matmul(query, key)
        # 执行矩阵乘法得到注意力分数矩阵 [*, H, Q, K]
        for b in biases:
            output += b
        # 添加偏置到输出
        output = softmax_no_cast(output, -1)
        # 在最后一个维度上执行 softmax 操作,得到注意力权重

        # [*, H, Q, C_hidden]
        output = torch.matmul(output, value)
        # 使用注意力权重加权值,得到加权后的值矩阵,形状为 [*, H, Q, C_hidden]
        output = output.transpose(-2, -3)
        # 对输出进行维度转置,将倒数第二个和倒数第三个维度进行交换
        output = self._wrap_up(output, q_x)
        # 调用 _wrap_up 方法对输出进行包装处理,根据查询数据 q_x

        return output
class EsmFoldTriangleAttention(nn.Module):
    # 定义 EsmFoldTriangleAttention 类,继承自 nn.Module
    def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
        """
        Args:
            c_in:
                输入通道维度
            c_hidden:
                总体隐藏通道维度(非每个注意力头)
            no_heads:
                注意力头的数量
        """
        super().__init__()
        
        # 初始化类的属性
        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf
        
        # 初始化层归一化对象
        self.layer_norm = LayerNorm(self.c_in)
        
        # 初始化线性层对象
        self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal")
        
        # 初始化自定义的注意力对象
        self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        "triangle! triangle!"
        # 准备输入参数字典给多头注意力的 chunk_layer 方法
        mha_inputs = {
            "q_x": x,
            "kv_x": x,
            "biases": biases,
        }

        # 使用 chunk_layer 函数对注意力进行分块处理
        return chunk_layer(
            partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma),
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
            _out=x if inplace_safe else None,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        # 正向传播函数,接收输入张量 x 和可选的掩码 mask
        pass  # 实际实现在此处省略
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
        Returns:
            [*, I, J, C_in] output tensor
        """
        # 如果没有提供掩码,则创建一个形状为 [*, I, J] 的新张量,所有元素为1
        if mask is None:
            mask = x.new_ones(
                x.shape[:-1],
            )

        # 如果不是起始状态,交换输入张量的倒数第二和倒数第三个维度
        if not self.starting:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # 对输入张量进行 layer normalization,形状保持不变 [*, I, J, C_in]
        x = self.layer_norm(x)

        # 创建一个形状为 [*, I, 1, 1, J] 的张量,其中 mask_bias 的计算基于 mask 张量
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # 对线性层的输出进行维度变换,形状为 [*, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

        # 在倒数第四个维度上扩展 triangle_bias,形状变为 [*, 1, H, I, J]
        triangle_bias = triangle_bias.unsqueeze(-4)

        # 将 mask_bias 和 triangle_bias 放入列表中作为偏置项
        biases = [mask_bias, triangle_bias]

        # 如果指定了 chunk_size,则调用 _chunk 方法处理输入 x 和 biases
        if chunk_size is not None:
            x = self._chunk(
                x,
                biases,
                chunk_size,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
            )
        else:
            # 否则调用 self.mha 进行多头注意力计算,使用给定的 biases
            x = self.mha(
                q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma
            )

        # 如果不是起始状态,恢复 x 的倒数第二和倒数第三个维度的顺序
        if not self.starting:
            x = x.transpose(-2, -3)

        # 返回处理后的张量 x
        return x
    """
    Implements Algorithms 11 and 12.
    实现第 11 和第 12 算法。
    """

    def __init__(self, config, _outgoing=True):
        # 初始化函数,设置模型参数
        super().__init__()
        # 从配置中获取隐藏状态的维度
        c_hidden = config.pairwise_state_dim
        # 是否是外部输出
        self._outgoing = _outgoing

        # 定义线性层,用于算法中的计算
        self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden)
        self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden)
        self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final")

        # 初始化输入和输出的 LayerNorm
        self.layer_norm_in = LayerNorm(c_hidden)
        self.layer_norm_out = LayerNorm(c_hidden)

        # 定义 Sigmoid 激活函数
        self.sigmoid = nn.Sigmoid()

    def _combine_projections(
        self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        # 组合投影函数,根据 _outgoing 参数确定维度顺序
        if self._outgoing:
            a = permute_final_dims(a, (2, 0, 1))
            b = permute_final_dims(b, (2, 1, 0))
        else:
            a = permute_final_dims(a, (2, 1, 0))
            b = permute_final_dims(b, (2, 0, 1))

        # 如果指定了 _inplace_chunk_size,使用循环方式批量处理
        if _inplace_chunk_size is not None:
            # 待替换为 torch vmap 的部分
            for i in range(0, a.shape[-3], _inplace_chunk_size):
                a_chunk = a[..., i : i + _inplace_chunk_size, :, :]
                b_chunk = b[..., i : i + _inplace_chunk_size, :, :]
                a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul(
                    a_chunk,
                    b_chunk,
                )

            p = a
        else:
            # 否则直接进行矩阵乘法运算
            p = torch.matmul(a, b)

        return permute_final_dims(p, (1, 2, 0))

    def _inference_forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_chunk_size: Optional[int] = None,
        with_add: bool = True,
    ):
        # 推断过程的前向传播函数,包括处理 mask、是否进行 in-place 操作和是否添加额外计算
        ...

    def forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_safe: bool = False,
        _add_with_inplace: bool = False,
        _inplace_chunk_size: Optional[int] = 256,
    ):
        # 模型的前向传播函数,接受输入张量 z 和可选的 mask,执行模型计算
        ...
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, N_res, N_res, C_z] input tensor 输入张量,形状为 [*, N_res, N_res, C_z]
            mask:
                [*, N_res, N_res] input mask 输入的遮罩,形状为 [*, N_res, N_res]
        Returns:
            [*, N_res, N_res, C_z] output tensor 输出张量,形状为 [*, N_res, N_res, C_z]
        """
        if inplace_safe:
            x = self._inference_forward(
                z,
                mask,
                inplace_chunk_size=_inplace_chunk_size,  # 设置原地操作的块大小
                with_add=_add_with_inplace,  # 原地操作时是否进行加法
            )
            return x  # 返回处理后的张量

        if mask is None:
            mask = z.new_ones(z.shape[:-1])  # 使用输入 z 的形状创建全为 1 的遮罩

        mask = mask.unsqueeze(-1)  # 在最后一个维度上增加一个维度,形状变为 [*, N_res, N_res, 1]

        z = self.layer_norm_in(z)  # 输入 z 执行层归一化操作
        a = mask  # 将 mask 赋值给变量 a
        a = a * self.sigmoid(self.linear_a_g(z))  # a 乘以线性变换后经过 sigmoid 函数的结果
        a = a * self.linear_a_p(z)  # a 乘以另一个线性变换的结果
        b = mask  # 将 mask 赋值给变量 b
        b = b * self.sigmoid(self.linear_b_g(z))  # b 乘以线性变换后经过 sigmoid 函数的结果
        b = b * self.linear_b_p(z)  # b 乘以另一个线性变换的结果

        if is_fp16_enabled():  # 如果启用了 FP16 计算
            with torch.cuda.amp.autocast(enabled=False):  # 关闭自动混合精度计算
                x = self._combine_projections(a.float(), b.float())  # 使用浮点数进行投影组合
        else:
            x = self._combine_projections(a, b)  # 使用原始数据类型进行投影组合

        del a, b  # 删除变量 a 和 b
        x = self.layer_norm_out(x)  # 对输出 x 进行层归一化操作
        x = self.linear_z(x)  # 对归一化后的 x 进行线性变换
        g = self.sigmoid(self.linear_g(z))  # 对 z 执行线性变换后经过 sigmoid 函数的结果
        x = x * g  # 将 x 乘以 g

        return x  # 返回处理后的张量
class EsmFoldPreTrainedModel(EsmPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Subclass `EsmPreTrainedModel` to handle special initialization of weights
    def _init_weights(self, module):
        """Initialize the weights of the given module."""
        # Check if the module is an instance of `EsmFoldLinear`
        if isinstance(module, EsmFoldLinear):
            # Apply weight initialization based on module's initialization method
            with torch.no_grad():
                # Initialize using custom function if specified
                if module.init_fn is not None:
                    module.init_fn(module.weight, module.bias)
                # Initialize using truncated normal distribution with scale 1.0
                elif module.init == "default":
                    trunc_normal_init_(module.weight, scale=1.0)
                # Initialize using truncated normal distribution with scale 2.0
                elif module.init == "relu":
                    trunc_normal_init_(module.weight, scale=2.0)
                # Initialize using Xavier uniform initialization
                elif module.init == "glorot":
                    nn.init.xavier_uniform_(module.weight, gain=1)
                # Initialize weights to zero for "gating" type
                elif module.init == "gating":
                    module.weight.fill_(0.0)
                    # Initialize bias to 1.0 if bias exists
                    if module.bias:
                        module.bias.fill_(1.0)
                # Initialize using Kaiming normal distribution for "normal" type
                elif module.init == "normal":
                    torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear")
                # Initialize weights to zero for "final" type
                elif module.init == "final":
                    module.weight.fill_(0.0)
        # Initialize weights for `EsmFoldInvariantPointAttention` module
        elif isinstance(module, EsmFoldInvariantPointAttention):
            ipa_point_weights_init_(module.head_weights)
        # Initialize weights for `EsmFoldTriangularSelfAttentionBlock` module
        elif isinstance(module, EsmFoldTriangularSelfAttentionBlock):
            # Initialize various linear layers' weights and biases to zero
            torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight)
            torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias)
            torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight)
            torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias)
            torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight)
            torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias)
            torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight)
            torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias)

            torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight)
            torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias)
            torch.nn.init.zeros_(module.pair_to_sequence.linear.weight)
            torch.nn.init.zeros_(module.seq_attention.o_proj.weight)
            torch.nn.init.zeros_(module.seq_attention.o_proj.bias)
            torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight)
            torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias)
            torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight)
            torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias)
        else:
            # Call superclass method to initialize weights
            super()._init_weights(module)


class EsmFoldSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, head_width, gated=False):
        super().__init__()
        assert embed_dim == num_heads * head_width

        self.embed_dim = embed_dim  # 设置嵌入维度
        self.num_heads = num_heads  # 设置头的数量
        self.head_width = head_width  # 设置每个头的宽度

        # 定义投影层,将输入映射到更高维度的空间,不使用偏置项
        self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        # 输出投影层,将多头注意力的结果映射回原始的嵌入维度,使用偏置项
        self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.gated = gated  # 是否启用门控机制
        if gated:
            # 门控投影层,用于门控机制的加权输出
            self.g_proj = nn.Linear(embed_dim, embed_dim)
            torch.nn.init.zeros_(self.g_proj.weight)  # 初始化权重为零
            torch.nn.init.ones_(self.g_proj.bias)  # 初始化偏置为一

        self.rescale_factor = self.head_width**-0.5  # 缩放因子

        torch.nn.init.zeros_(self.o_proj.bias)  # 输出投影层偏置初始化为零

    def forward(self, x, mask=None, bias=None, indices=None):
        """
        基础的自注意力机制,可选带掩码和外部的注意力偏置。用于处理不同长度的序列,使用掩码。

        Inputs:
            x: 输入序列的批量 (.. x L x C) mask: 批量的布尔掩码,其中 1=有效,0=填充位置 (.. x L_k) bias: 批量的标量注意力偏置 (.. x Lq x Lk x num_heads)

        Outputs:
            序列投影 (B x L x embed_dim), 注意力映射 (B x L x L x num_heads)
        """

        t = self.proj(x).view(*x.shape[:2], self.num_heads, -1)  # 投影并重塑张量形状
        t = t.permute(0, 2, 1, 3)  # 转置张量的维度顺序
        q, k, v = t.chunk(3, dim=-1)  # 拆分成查询、键、值

        q = self.rescale_factor * q  # 缩放查询向量
        a = torch.einsum("...qc,...kc->...qk", q, k)  # 执行注意力计算

        # 添加外部注意力偏置
        if bias is not None:
            a = a + bias.permute(0, 3, 1, 2)

        # 不参与填充令牌的注意力
        if mask is not None:
            mask = mask[:, None, None]
            a = a.masked_fill(mask == False, -np.inf)  # noqa: E712

        a = nn.functional.softmax(a, dim=-1)  # 执行 softmax 操作得到注意力权重

        y = torch.einsum("...hqk,...hkc->...qhc", a, v)  # 应用注意力权重到值上
        y = y.reshape(*y.shape[:2], -1)  # 重塑输出形状

        if self.gated:
            y = self.g_proj(x).sigmoid() * y  # 使用门控机制调节输出
        y = self.o_proj(y)  # 最终的输出投影

        return y, a.permute(0, 3, 1, 2)  # 返回结果及注意力权重的转置
class EsmFoldDropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask along a particular dimension.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        super().__init__()

        self.r = r  # 设定 dropout 的概率 r
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim  # 指定需要共享 dropout mask 的维度
        self.dropout = nn.Dropout(self.r)  # 初始化 Dropout 层

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = list(x.shape)  # 获取输入张量 x 的形状
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1  # 将指定维度的大小设为 1,用于共享 dropout mask
        return x * self.dropout(x.new_ones(shape))  # 对输入张量 x 应用 dropout 操作


class EsmFoldSequenceToPair(nn.Module):
    def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
        super().__init__()

        self.layernorm = nn.LayerNorm(sequence_state_dim)  # 序列归一化层
        self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)  # 线性投影层
        self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)  # 输出线性投影层

        torch.nn.init.zeros_(self.proj.bias)  # 将投影层的偏置项初始化为零
        torch.nn.init.zeros_(self.o_proj.bias)  # 将输出投影层的偏置项初始化为零

    def forward(self, sequence_state):
        """
        Inputs:
          sequence_state: B x L x sequence_state_dim

        Output:
          pairwise_state: B x L x L x pairwise_state_dim

        Intermediate state:
          B x L x L x 2*inner_dim
        """

        assert len(sequence_state.shape) == 3  # 断言输入张量的形状为 B x L x sequence_state_dim

        s = self.layernorm(sequence_state)  # 序列归一化
        s = self.proj(s)  # 应用线性投影
        q, k = s.chunk(2, dim=-1)  # 将投影后的结果切分为两部分,q 和 k

        prod = q[:, None, :, :] * k[:, :, None, :]  # 计算乘积部分
        diff = q[:, None, :, :] - k[:, :, None, :]  # 计算差异部分

        x = torch.cat([prod, diff], dim=-1)  # 拼接乘积和差异部分
        x = self.o_proj(x)  # 应用输出投影层

        return x  # 返回输出张量


class EsmFoldPairToSequence(nn.Module):
    def __init__(self, pairwise_state_dim, num_heads):
        super().__init__()

        self.layernorm = nn.LayerNorm(pairwise_state_dim)  # 对成对状态维度进行归一化
        self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)  # 线性层,用于生成成对偏置

    def forward(self, pairwise_state):
        """
        Inputs:
          pairwise_state: B x L x L x pairwise_state_dim

        Output:
          pairwise_bias: B x L x L x num_heads
        """
        assert len(pairwise_state.shape) == 4  # 断言输入张量的形状为 B x L x L x pairwise_state_dim
        z = self.layernorm(pairwise_state)  # 应用归一化层
        pairwise_bias = self.linear(z)  # 应用线性层生成成对偏置
        return pairwise_bias  # 返回成对偏置张量


class EsmFoldResidueMLP(nn.Module):
    def __init__(self, embed_dim, inner_dim, dropout=0):
        super().__init__()

        self.mlp = nn.Sequential(
            nn.LayerNorm(embed_dim),  # 对嵌入维度进行归一化
            nn.Linear(embed_dim, inner_dim),  # 第一个线性层
            nn.ReLU(),  # ReLU 激活函数
            nn.Linear(inner_dim, embed_dim),  # 第二个线性层
            nn.Dropout(dropout),  # Dropout 层
        )

    def forward(self, x):
        return x + self.mlp(x)  # 返回输入张量加上 MLP 处理后的结果


class EsmFoldTriangularSelfAttentionBlock(nn.Module):
    """
    Placeholder for a module implementing a triangular self-attention block.
    This class is not fully implemented in the provided code snippet.
    """
    # 初始化函数,用于创建对象实例时的初始化操作,接受一个配置参数config
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 将配置参数保存到实例的config属性中
        self.config = config

        # 从配置参数中获取序列状态维度和成对状态维度
        sequence_state_dim = config.sequence_state_dim
        pairwise_state_dim = config.pairwise_state_dim

        # 根据配置参数计算序列自注意力机制的头数
        sequence_num_heads = sequence_state_dim // config.sequence_head_width
        # 根据配置参数计算成对自注意力机制的头数
        pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width

        # 创建一个序列层归一化模块,传入序列状态维度作为参数
        self.layernorm_1 = nn.LayerNorm(sequence_state_dim)

        # 创建一个将序列映射为成对表示的模块,传入序列状态维度和一半的成对状态维度作为参数
        self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim)
        # 创建一个将成对表示映射回序列表示的模块,传入成对状态维度和序列自注意力机制头数作为参数
        self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads)

        # 创建一个序列自注意力机制模块,传入序列状态维度、头数、头宽度和是否启用门控机制作为参数
        self.seq_attention = EsmFoldSelfAttention(
            sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True
        )
        
        # 创建一个序列三角形形态更新模块(输出方向),传入配置参数和输出方向(True表示输出方向)
        self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True)
        # 创建一个序列三角形形态更新模块(输入方向),传入配置参数和输出方向(False表示输入方向)
        self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False)

        # 创建一个成对三角形注意力模块(起始方向),传入成对状态维度、头宽度、头数、无穷大值和是否起始方向为True作为参数
        self.tri_att_start = EsmFoldTriangleAttention(
            pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True
        )
        # 创建一个成对三角形注意力模块(结束方向),传入成对状态维度、头宽度、头数、无穷大值和是否起始方向为False作为参数
        self.tri_att_end = EsmFoldTriangleAttention(
            pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False
        )

        # 创建一个序列残差MLP模块,传入序列状态维度、4倍的序列状态维度和dropout概率作为参数
        self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout)
        # 创建一个成对残差MLP模块,传入成对状态维度、4倍的成对状态维度和dropout概率作为参数
        self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout)

        # 创建一个普通的dropout模块,传入dropout概率作为参数
        self.drop = nn.Dropout(config.dropout)
        # 创建一个行dropout模块,传入2倍的dropout概率和1作为参数
        self.row_drop = EsmFoldDropout(config.dropout * 2, 2)
        # 创建一个列dropout模块,传入2倍的dropout概率和1作为参数
        self.col_drop = EsmFoldDropout(config.dropout * 2, 1)
class EsmCategoricalMixture:
    # 定义一个混合分类分布的类
    def __init__(self, param, bins=50, start=0, end=1):
        # 初始化方法,接收参数和一些配置信息
        # 所有的张量都是形状为 ..., bins
        self.logits = param
        # 创建一个等间距的张量 bins,用于表示值的中心点
        bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype)
        # 计算每个 bin 的中心值
        self.v_bins = (bins[:-1] + bins[1:]) / 2

    def log_prob(self, true):
        # 计算给定值的对数概率
        # Shapes are:
        #     self.probs: ... x bins
        #     true      : ...
        # 找到最接近 true 的值在 v_bins 中的索引
        true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1)
        # 计算 logits 的对数 softmax,并计算负对数似然
        nll = self.logits.log_softmax(-1)
        # 返回 true_index 处的对数概率
        return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)

    def mean(self):
        # 计算混合分布的均值
        return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)


def categorical_lddt(logits, bins=50):
    # 计算混合分类分布的均值
    # Logits are ..., 37, bins.
    return EsmCategoricalMixture(logits, bins=bins).mean()


def get_axial_mask(mask):
    """
    Helper to convert B x L mask of valid positions to axial mask used in row column attentions.

    Input:
      mask: B x L tensor of booleans

    Output:
      mask: B x L x L tensor of booleans
    """
    # 将 B x L 的有效位置掩码转换为用于行列注意力的轴向掩码的辅助函数

    if mask is None:
        return None

    if len(mask.shape) != 2:
        # 如果掩码的维度不是 2,则抛出异常
        raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
    batch_dim, seq_dim = mask.shape
    # 在第二个维度上扩展掩码,以便生成 B x L x L 的掩码
    m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
    m = m.reshape(batch_dim * seq_dim, seq_dim)
    return m


class EsmFoldRelativePosition(nn.Module):
    # 相对位置编码模块
    def __init__(self, config):
        super().__init__()
        self.bins = config.position_bins

        # Note an additional offset is used so that the 0th position
        # is reserved for masked pairs.
        # 使用额外的偏移量,确保第 0 位置留给掩码对

        self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim)

    def forward(self, residue_index, mask=None):
        """
        Input:
          residue_index: B x L tensor of indices (dytpe=torch.long) mask: B x L tensor of booleans

        Output:
          pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
        """
        # 前向传播函数,接收残基索引和掩码,返回残基的嵌入向量

        if residue_index.dtype != torch.long:
            # 如果残基索引的数据类型不是 torch.long,则抛出异常
            raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.")
        if mask is not None and residue_index.shape != mask.shape:
            # 如果掩码不为空且形状与残基索引不一致,则抛出异常
            raise ValueError(
                f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}."
            )

        # 计算残基索引之间的距离,并进行截断
        diff = residue_index[:, None, :] - residue_index[:, :, None]
        diff = diff.clamp(-self.bins, self.bins)
        diff = diff + self.bins + 1  # Add 1 to adjust for padding index.

        if mask is not None:
            # 如果掩码不为空,则应用掩码
            mask = mask[:, None, :] * mask[:, :, None]
            diff[mask == False] = 0  # noqa: E712

        # 使用嵌入层将距离转换为嵌入向量
        output = self.embedding(diff)
        return output


class EsmFoldAngleResnetBlock(nn.Module):
    # ESM 折叠角度 ResNet 块,未完整提供代码,无需注释
    # 初始化函数,接受一个配置对象作为参数
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()

        # 创建第一个线性层,输入和输出维度都为 config.resnet_dim,使用 ReLU 激活函数初始化
        self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu")
        
        # 创建第二个线性层,输入和输出维度也为 config.resnet_dim,使用 "final" 方法进行初始化
        self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final")

        # 创建 ReLU 激活函数层
        self.relu = nn.ReLU()

    # 前向传播函数,接受一个 torch.Tensor 类型的输入 a,返回一个 torch.Tensor 类型的输出
    def forward(self, a: torch.Tensor) -> torch.Tensor:
        # 保存初始输入 a 到 s_initial 中
        s_initial = a

        # 对输入 a 应用 ReLU 激活函数
        a = self.relu(a)
        
        # 将经过 ReLU 激活函数后的输入 a 传入第一个线性层 self.linear_1
        a = self.linear_1(a)
        
        # 再次应用 ReLU 激活函数
        a = self.relu(a)
        
        # 将经过第一个线性层和 ReLU 后的输出 a 传入第二个线性层 self.linear_2
        a = self.linear_2(a)

        # 返回最终输出,它是第二个线性层的输出与初始输入的和
        return a + s_initial
class EsmFoldAngleResnet(nn.Module):
    """
    Implements Algorithm 20, lines 11-14
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # 初始化输入线性层,将输入维度转换为ResNet维度
        self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
        # 初始化初始线性层,将输入维度转换为ResNet维度
        self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim)

        # 初始化ResNet块的列表
        self.layers = nn.ModuleList()
        for _ in range(config.num_resnet_blocks):
            layer = EsmFoldAngleResnetBlock(config)
            self.layers.append(layer)

        # 初始化输出线性层,将ResNet维度转换为角度预测的维度(num_angles * 2)
        self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2)

        # 定义ReLU激活函数
        self.relu = nn.ReLU()

    def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            s:
                [*, C_hidden] 单个嵌入向量
            s_initial:
                [*, C_hidden] StructureModule 开始时的单个嵌入向量
        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                [*, no_angles, 2] 预测的角度
        """
        # 注意:补充资料中未提及对输入应用ReLU,但在源代码中存在。
        # 为了最大兼容性,保留源代码中的实现方式。

        # 对 s_initial 应用ReLU激活函数
        s_initial = self.relu(s_initial)
        # 经过初始线性层处理
        s_initial = self.linear_initial(s_initial)
        # 对 s 应用ReLU激活函数
        s = self.relu(s)
        # 经过输入线性层处理
        s = self.linear_in(s)
        # 加上初始嵌入向量处理后的结果
        s = s + s_initial

        # 遍历所有的ResNet块
        for l in self.layers:
            s = l(s)

        # 对结果应用ReLU激活函数
        s = self.relu(s)

        # 经过输出线性层处理,得到未归一化的预测值
        s = self.linear_out(s)

        # 将输出形状变换为 [*, no_angles, 2]
        s = s.view(s.shape[:-1] + (-1, 2))

        # 对 s 进行归一化处理
        unnormalized_s = s  # 保存未归一化的预测值
        norm_denom = torch.sqrt(
            torch.clamp(
                torch.sum(s**2, dim=-1, keepdim=True),
                min=self.config.epsilon,
            )
        )
        s = s / norm_denom  # 归一化处理

        return unnormalized_s, s


class EsmFoldInvariantPointAttention(nn.Module):
    """
    Implements Algorithm 22.
    """
    # 初始化函数,接受一个配置对象作为参数
    def __init__(self, config):
        # 调用父类的初始化函数
        super().__init__()
        # 将配置对象保存到实例属性中
        self.config = config

        # 从配置对象中获取各个维度的设定
        c_s = config.sequence_dim
        c_z = config.pairwise_dim
        self.hidden_dim = config.ipa_dim
        self.num_heads = config.num_heads_ipa
        self.num_qk_points = config.num_qk_points
        self.num_v_points = config.num_v_points

        # 下面的线性层与说明书中的规格不同。
        # 说明书中,它们没有偏置并使用Glorot初始化。
        # 在这里和官方源码中,它们带有偏置并使用默认的Lecun初始化。
        
        # 计算线性层q的输出维度
        hc = config.ipa_dim * config.num_heads_ipa
        # 创建线性层q,输入维度为c_s,输出维度为hc
        self.linear_q = EsmFoldLinear(c_s, hc)
        
        # 计算线性层kv的输出维度
        self.linear_kv = EsmFoldLinear(c_s, 2 * hc)

        # 计算线性层q_points的输出维度
        hpq = config.num_heads_ipa * config.num_qk_points * 3
        self.linear_q_points = EsmFoldLinear(c_s, hpq)

        # 计算线性层kv_points的输出维度
        hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3
        self.linear_kv_points = EsmFoldLinear(c_s, hpkv)

        # 创建线性层b,输入维度为c_z,输出维度为config.num_heads_ipa
        self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa)

        # 创建可学习的参数,用于存储头部权重
        self.head_weights = nn.Parameter(torch.zeros((config.num_heads_ipa)))

        # 计算拼接后的输出维度
        concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4)
        # 创建线性层out,输入维度为concat_out_dim,输出维度为c_s,使用"final"初始化方式
        self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final")

        # 创建softmax激活函数,沿着最后一个维度进行softmax操作
        self.softmax = nn.Softmax(dim=-1)
        # 创建softplus激活函数
        self.softplus = nn.Softplus()

    # 前向传播函数定义,接受多个输入参数并返回一个输出
    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        r: Rigid,
        mask: torch.Tensor,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
class EsmFoldBackboneUpdate(nn.Module):
    """
    Implements part of Algorithm 23.
    """

    def __init__(self, config):
        super().__init__()

        # Initialize a linear layer for updating the backbone with 6 output features
        self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final")

    def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            [*, N_res, C_s] single representation
        Returns:
            [*, N_res, 6] update vector
        """
        # Compute the update vector using the linear layer
        update = self.linear(s)

        return update


class EsmFoldStructureModuleTransitionLayer(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Initialize three linear layers for transformation, using ReLU activation for the first two
        self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
        self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
        self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final")

        self.relu = nn.ReLU()

    def forward(self, s):
        # Save the initial input for later residual connection
        s_initial = s

        # Pass through the three linear layers with ReLU activations in between
        s = self.linear_1(s)
        s = self.relu(s)
        s = self.linear_2(s)
        s = self.relu(s)
        s = self.linear_3(s)

        # Add the initial input to the transformed output (residual connection)
        s = s + s_initial

        return s


class EsmFoldStructureModuleTransition(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Initialize a series of transition layers based on the specified number in config
        self.layers = nn.ModuleList()
        for _ in range(config.num_transition_layers):
            l = EsmFoldStructureModuleTransitionLayer(config)
            self.layers.append(l)

        # Apply dropout and layer normalization
        self.dropout = nn.Dropout(config.dropout_rate)
        self.layer_norm = LayerNorm(config.sequence_dim)

    def forward(self, s):
        # Forward pass through each transition layer
        for l in self.layers:
            s = l(s)

        # Apply dropout and layer normalization to the final output
        s = self.dropout(s)
        s = self.layer_norm(s)

        return s


class EsmFoldStructureModule(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Buffers to be lazily initialized later
        # self.default_frames
        # self.group_idx
        # self.atom_mask
        # self.lit_positions

        # Initialize layer normalization for sequence and pairwise dimensions
        self.layer_norm_s = LayerNorm(config.sequence_dim)
        self.layer_norm_z = LayerNorm(config.pairwise_dim)

        # Linear layer for initial transformation of input sequence
        self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim)

        # Initialize Invariant Point Attention and its associated dropout and layer normalization
        self.ipa = EsmFoldInvariantPointAttention(config)
        self.ipa_dropout = nn.Dropout(config.dropout_rate)
        self.layer_norm_ipa = LayerNorm(config.sequence_dim)

        # Initialize transition module, backbone update, and angle resnet modules
        self.transition = EsmFoldStructureModuleTransition(config)
        self.bb_update = EsmFoldBackboneUpdate(config)
        self.angle_resnet = EsmFoldAngleResnet(config)

    def forward(
        self,
        evoformer_output_dict,
        aatype,
        mask=None,
        _offload_inference=False,
    ):
        # Implementation of forward pass for the entire structure module is not provided here
        pass
    # 初始化残基常量,如果不存在默认帧,则注册为缓冲区张量
    def _init_residue_constants(self, float_dtype, device):
        if not hasattr(self, "default_frames"):
            self.register_buffer(
                "default_frames",
                torch.tensor(
                    residue_constants.restype_rigid_group_default_frame,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        # 如果不存在组索引,则注册为缓冲区张量
        if not hasattr(self, "group_idx"):
            self.register_buffer(
                "group_idx",
                torch.tensor(
                    residue_constants.restype_atom14_to_rigid_group,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        # 如果不存在原子掩码,则注册为缓冲区张量
        if not hasattr(self, "atom_mask"):
            self.register_buffer(
                "atom_mask",
                torch.tensor(
                    residue_constants.restype_atom14_mask,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        # 如果不存在文献位置,则注册为缓冲区张量
        if not hasattr(self, "lit_positions"):
            self.register_buffer(
                "lit_positions",
                torch.tensor(
                    residue_constants.restype_atom14_rigid_group_positions,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

    # 将扭转角转换为帧
    def torsion_angles_to_frames(self, r, alpha, f):
        # 懒惰地在正确的设备上初始化残基常量
        self._init_residue_constants(alpha.dtype, alpha.device)
        # 将扭转角转换为帧,使用默认帧作为参数之一
        return torsion_angles_to_frames(r, alpha, f, self.default_frames)

    # 将帧和文献位置转换为原子14位置
    def frames_and_literature_positions_to_atom14_pos(self, r, f):  # [*, N, 8]  # [*, N]
        # 懒惰地在正确的设备上初始化残基常量
        self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
        # 使用帧、组索引、原子掩码和文献位置将帧和文献位置转换为原子14位置
        return frames_and_literature_positions_to_atom14_pos(
            r,
            f,
            self.default_frames,
            self.group_idx,
            self.atom_mask,
            self.lit_positions,
        )
# 定义一个名为 EsmFoldingTrunk 的神经网络模块类,继承自 nn.Module
class EsmFoldingTrunk(nn.Module):
    # 初始化方法,接收一个 config 参数
    def __init__(self, config):
        super().__init__()
        # 将传入的 config 参数保存在实例变量 self.config 中
        self.config = config

        # 从 config 中获取序列状态维度和成对状态维度,并保存到本地变量 c_s 和 c_z 中
        c_s = config.sequence_state_dim
        c_z = config.pairwise_state_dim

        # 创建一个 EsmFoldRelativePosition 实例,用于生成成对位置嵌入
        self.pairwise_positional_embedding = EsmFoldRelativePosition(config)

        # 创建一个由多个 EsmFoldTriangularSelfAttentionBlock 实例组成的模块列表,
        # 列表的长度由 config.num_blocks 决定
        self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)])

        # 设置循环使用的桶数为 15
        self.recycle_bins = 15
        # 创建一个用于序列状态归一化的 LayerNorm 实例,参数为 c_s
        self.recycle_s_norm = nn.LayerNorm(c_s)
        # 创建一个用于成对状态归一化的 LayerNorm 实例,参数为 c_z
        self.recycle_z_norm = nn.LayerNorm(c_z)
        # 创建一个嵌入层,用于存储循环分布信息,有 recycle_bins 个桶,每个桶长度为 c_z
        self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
        # 将嵌入层的第一个权重向量初始化为零
        self.recycle_disto.weight[0].detach().zero_()

        # 创建一个 EsmFoldStructureModule 实例,用于处理结构模块相关任务
        self.structure_module = EsmFoldStructureModule(config.structure_module)
        # 创建一个线性层,将序列状态映射到结构模块的序列维度大小
        self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim)
        # 创建一个线性层,将成对状态映射到结构模块的成对维度大小
        self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim)

        # 初始化块的默认大小,用于分块处理注意力机制的输入
        self.chunk_size = config.chunk_size

    # 设置块的大小,用于分块处理注意力机制的输入
    def set_chunk_size(self, chunk_size):
        # 参数 chunk_size 指示将使用分块方式计算轴向注意力机制。
        # 这可以使得内存使用大致为 O(L) 而不是 O(L^2)。
        # 相当于在我们迭代的维度的块上运行一个 for 循环,
        # 其中 chunk_size 是块的大小,比如如果设置为 128,则意味着解析长度为 128 的块。
        self.chunk_size = chunk_size
    def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
        """
        Inputs:
          seq_feats: B x L x C tensor of sequence features pair_feats: B x L x L x C tensor of pair features residx: B
          x L long tensor giving the position in the sequence mask: B x L boolean tensor indicating valid residues

        Output:
          predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object
        """

        # 获取输入张量 seq_feats 的设备信息
        device = seq_feats.device
        # 初始化原始的序列特征和对特征
        s_s_0 = seq_feats
        s_z_0 = pair_feats

        # 如果未提供 no_recycles 参数,则使用配置中的最大循环次数
        if no_recycles is None:
            no_recycles = self.config.max_recycles
        else:
            # 如果提供了 no_recycles 参数,确保其不为负数
            if no_recycles < 0:
                raise ValueError("Number of recycles must not be negative.")
            # 将 no_recycles 值增加 1,因为第一个 'recycle' 是通过模型的标准前向传播
            no_recycles += 1

        def trunk_iter(s, z, residx, mask):
            # 为 z 添加位置编码嵌入
            z = z + self.pairwise_positional_embedding(residx, mask=mask)

            # 遍历所有的块(blocks),每个块执行一次
            for block in self.blocks:
                s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
            return s, z

        # 将初始的序列特征和对特征赋值给 s_s 和 s_z
        s_s = s_s_0
        s_z = s_z_0
        # 初始化用于循环的张量
        recycle_s = torch.zeros_like(s_s)
        recycle_z = torch.zeros_like(s_z)
        recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)

        # 执行循环指定的次数(no_recycles)
        for recycle_idx in range(no_recycles):
            with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
                # === Recycling ===
                # 对 recycle_s 和 recycle_z 进行归一化处理,并转移到指定设备上
                recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
                recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
                # 添加距离约束到 recycle_z
                recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)

                # 执行 trunk_iter 函数,更新 s_s 和 s_z
                s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)

                # === Structure module ===
                # 使用结构模块生成结构预测,传入单体和对体的转换结果,真实的氨基酸序列和掩码
                structure = self.structure_module(
                    {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
                    true_aa,
                    mask.float(),
                )

                # 更新 recycle_s 和 recycle_z 为当前的 s_s 和 s_z
                recycle_s = s_s
                recycle_z = s_z
                # 计算距离直方图所需的 bins,调用 distogram 方法
                recycle_bins = EsmFoldingTrunk.distogram(
                    structure["positions"][-1][:, :, :3],
                    3.375,
                    21.375,
                    self.recycle_bins,
                )

        # 将最终的 s_s 和 s_z 存储在结构对象中,并返回结构对象
        structure["s_s"] = s_s
        structure["s_z"] = s_z

        return structure

    @staticmethod
    def distogram(coords, min_bin, max_bin, num_bins):
        # 计算距离直方图,输入参数分别为坐标数组,最小bin值,最大bin值,bin的数量

        # 使用 torch.linspace 在设备上生成一组均匀间隔的边界值
        boundaries = torch.linspace(
            min_bin,
            max_bin,
            num_bins - 1,
            device=coords.device,
        )
        boundaries = boundaries**2  # 将边界值平方

        # 将输入的坐标数组按照特定维度切分成 N, CA, C 坐标数组
        N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]

        # 推断出 CB 坐标
        b = CA - N
        c = C - CA
        a = b.cross(c, dim=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA

        # 计算 CB 坐标之间的距离的平方和,得到距离矩阵
        dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)

        # 计算每对 CB 坐标之间的距离所属的 bin 编号
        bins = torch.sum(dists > boundaries, dim=-1)  # 得到距离直方图的矩阵

        return bins
# 导入函数用于添加文档字符串(docstring)信息到类
@add_start_docstrings(
    """
    ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed
    by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to
    the rest of the model combined! It outputs a dictionary containing predicted structural information about the input
    protein(s).
    """,
    ESM_START_DOCSTRING,
)
# 定义 EsmForProteinFolding 类,继承自 EsmPreTrainedModel 类
class EsmForProteinFolding(EsmPreTrainedModel):
    # 不需要拆分的模块列表,用于模型训练和推理阶段的处理
    _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]
    # 初始化函数,接受一个配置参数,并调用父类的初始化方法
    def __init__(self, config):
        # 调用父类的初始化方法,传入配置参数
        super().__init__(config)

        # 将配置参数保存到当前对象的属性中
        self.config = config

        # 定义直方图分箱的数量为64
        self.distogram_bins = 64

        # 创建一个 EsmModel 对象,禁用添加池化层的选项
        self.esm = EsmModel(config, add_pooling_layer=False)

        # 将 EsmModel 的参数设置为不需要梯度
        self.esm.requires_grad_(False)
        
        # 如果配置中指定使用 fp16 模式,则将 EsmModel 切换为半精度
        if self.config.esmfold_config.fp16_esm:
            self.esm.half()

        # 设置 ESM 特征的维度为配置中指定的隐藏层大小
        self.esm_feats = self.config.hidden_size

        # 计算 ESM 注意力头的数量
        self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads

        # 设置 ESM 层数为配置中指定的隐藏层数
        self.esm_layers = self.config.num_hidden_layers

        # 使用从词汇表中得到的映射创建一个缓冲区,用于将序列特征映射到 ESM 的表示
        self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list))

        # 创建一个可学习的参数,用于结合不同层的 ESM 输出
        self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1))

        # 从配置中获取 ESMFold 的 trunk 配置
        trunk_config = self.config.esmfold_config.trunk

        # 定义序列状态维度和配对状态维度
        c_s = trunk_config.sequence_state_dim
        c_z = trunk_config.pairwise_state_dim

        # 定义一个序列,包含一系列的层次归一化和线性变换,用于将 ESM 特征映射到序列状态维度
        self.esm_s_mlp = nn.Sequential(
            LayerNorm(self.esm_feats),
            nn.Linear(self.esm_feats, c_s),
            nn.ReLU(),
            nn.Linear(c_s, c_s),
        )

        # 定义序列的嵌入标记数量,包括填充标记、未知残基标记和掩码标记
        self.n_tokens_embed = residue_constants.restype_num + 3
        self.pad_idx = 0
        self.unk_idx = self.n_tokens_embed - 2
        self.mask_idx = self.n_tokens_embed - 1

        # 获取词汇表中特定标记的索引,如 "<cls>", "<mask>", "<eos>", "<pad>"
        self.esm_dict_cls_idx = self.config.vocab_list.index("<cls>")
        self.esm_dict_mask_idx = self.config.vocab_list.index("<mask>")
        self.esm_dict_eos_idx = self.config.vocab_list.index("<eos>")
        self.esm_dict_padding_idx = self.config.vocab_list.index("<pad>")

        # 如果配置指定要嵌入氨基酸标记,则创建一个嵌入层
        if self.config.esmfold_config.embed_aa:
            self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)

        # 创建 ESMFold 的 trunk 部分
        self.trunk = EsmFoldingTrunk(trunk_config)

        # 定义直方图头部的线性层和蛋白质结构的头部线性层
        self.distogram_head = nn.Linear(c_z, self.distogram_bins)
        self.ptm_head = nn.Linear(c_z, self.distogram_bins)
        self.lm_head = nn.Linear(c_s, self.n_tokens_embed)

        # 定义 LDDT 预测的分箱数量
        self.lddt_bins = 50

        # 获取 trunk 配置中结构模块的配置信息
        structure_module_config = trunk_config.structure_module

        # 定义 LDDT 预测头部的线性层序列
        self.lddt_head = nn.Sequential(
            nn.LayerNorm(structure_module_config.sequence_dim),
            nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim),
            nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim),
            nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins),
        )
    # 定义模型的前向传播方法,接受多个输入张量作为参数
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        masking_pattern: Optional[torch.Tensor] = None,
        num_recycles: Optional[int] = None,
    ):
        # 省略的是前向传播的具体实现,根据输入参数计算模型输出
        pass

    # 将从AF2空间到ESM空间的索引映射转换为与输入设备相同的设备,以避免设备上的索引错误
    def af2_idx_to_esm_idx(self, aa, mask):
        if self.af2_to_esm.device != aa.device:
            self.af2_to_esm = self.af2_to_esm.to(aa.device)
        # 将aa中的每个元素加一,并将非1的位置用0填充
        aa = (aa + 1).masked_fill(mask != 1, 0)
        # 使用af2_to_esm映射aa中的每个元素,返回对应的索引
        return self.af2_to_esm[aa]

    # 计算语言模型的表示,接受ESM的张量作为输入,并返回处理后的张量
    def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor:
        device = next(self.parameters()).device
        B, L = esmaa.shape  # B为批次大小,L为序列长度

        # 如果配置要求绕过语言模型,则返回全零的张量作为输出
        if self.config.esmfold_config.bypass_lm:
            esm_s = torch.zeros(B, L, self.esm_s_combine.size[0], -1, self.esm_feats, device=device)
            return esm_s

        # 获取开始和结束的特殊标记索引
        bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx
        # 在序列的开头和结尾添加特殊标记索引
        bos = esmaa.new_full((B, 1), bosi)
        eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx)
        esmaa = torch.cat([bos, esmaa, eos], dim=1)
        # 在推断过程中,使用第一个填充索引作为结束标记
        esmaa[range(B), (esmaa != 1).sum(1)] = eosi

        # 计算ESM模型的隐藏状态,返回多层隐藏状态的张量
        esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"]
        esm_s = torch.stack(esm_hidden_states, dim=2)

        # 移除序列开头和结尾的特殊标记
        esm_s = esm_s[:, 1:-1]  # B, L, nLayers, C

        return esm_s

    # 对输入的aa和esmaa张量进行BERT掩码操作,并返回处理后的新张量
    def bert_mask(self, aa, esmaa, mask, pattern):
        new_aa = aa.clone()
        target = aa.clone()
        new_esmaa = esmaa.clone()
        # 将pattern为1的位置在new_aa中替换为mask_idx
        new_aa[pattern == 1] = self.mask_idx
        # 将pattern不为1的位置在target中替换为0
        target[pattern != 1] = 0
        # 将pattern为1的位置在new_esmaa中替换为esm_dict_mask_idx
        new_esmaa[pattern == 1] = self.esm_dict_mask_idx
        return new_aa, new_esmaa, target

    # 声明推断方法,接受序列文本或列表作为输入,不进行梯度计算
    @torch.no_grad()
    def infer(
        self,
        seqs: Union[str, List[str]],
        position_ids=None,
    ):
        if isinstance(seqs, str):
            # 如果输入的序列是字符串,则转换为单元素列表
            lst = [seqs]
        else:
            # 否则,直接使用输入的序列列表
            lst = seqs
        # 获取模型参数的设备信息
        device = next(self.parameters()).device
        # 使用自定义函数将输入序列转换为 one-hot 编码的张量
        aatype = collate_dense_tensors(
            [
                torch.from_numpy(
                    residue_constants.sequence_to_onehot(
                        sequence=seq,
                        mapping=residue_constants.restype_order_with_x,
                        map_unknown_to_x=True,
                    )
                )
                .to(device)
                .argmax(dim=1)
                for seq in lst
            ]
        )  # B=1 x L
        # 为每个序列生成掩码张量
        mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
        # 生成位置 ID 张量,如果未提供则创建一个新的
        position_ids = (
            torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
            if position_ids is None
            else position_ids.to(device)
        )
        # 如果位置 ID 张量的维度为 1,则扩展为二维张量
        if position_ids.ndim == 1:
            position_ids = position_ids.unsqueeze(0)
        # 调用模型的 forward 方法进行推断
        return self.forward(
            aatype,
            mask,
            position_ids=position_ids,
        )

    @staticmethod
    def output_to_pdb(output: Dict) -> List[str]:
        """Returns the pdb (file) string from the model given the model output."""
        # 将模型输出中的张量转移到 CPU 上,并转换为 numpy 数组
        output = {k: v.to("cpu").numpy() for k, v in output.items()}
        pdbs = []
        # 获取最终的原子位置和掩码信息
        final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
        final_atom_mask = output["atom37_atom_exists"]
        # 遍历每个样本的预测结果,并生成相应的 PDB 对象
        for i in range(output["aatype"].shape[0]):
            aa = output["aatype"][i]
            pred_pos = final_atom_positions[i]
            mask = final_atom_mask[i]
            resid = output["residue_index"][i] + 1
            # 使用预测的信息创建 OFProtein 对象
            pred = OFProtein(
                aatype=aa,
                atom_positions=pred_pos,
                atom_mask=mask,
                residue_index=resid,
                b_factors=output["plddt"][i],
            )
            # 将生成的 PDB 对象转换为 PDB 文件格式字符串并添加到列表中
            pdbs.append(to_pdb(pred))
        return pdbs

    def infer_pdb(self, seqs, *args, **kwargs) -> str:
        """Returns the pdb (file) string from the model given an input sequence."""
        # 确保输入序列为字符串
        assert isinstance(seqs, str)
        # 调用 infer 方法进行推断
        output = self.infer(seqs, *args, **kwargs)
        # 将推断结果转换为 PDB 文件格式字符串并返回第一个结果
        return self.output_to_pdb(output)[0]

    def infer_pdbs(self, seqs: List[str], *args, **kwargs) -> List[str]:
        """Returns the pdb (file) string from the model given an input sequence."""
        # 调用 infer 方法进行推断
        output = self.infer(seqs, *args, **kwargs)
        # 将推断结果转换为 PDB 文件格式字符串列表并返回
        return self.output_to_pdb(output)

.\models\esm\modeling_tf_esm.py

# 设定编码格式为 UTF-8

# 版权声明和许可证信息

# 导入所需的库和模块
from __future__ import annotations  # 使用未来版本的 annotations 特性

import os  # 导入操作系统相关的功能
from typing import Optional, Tuple, Union  # 引入类型提示需要的数据结构

import numpy as np  # 导入 NumPy 库,用于科学计算
import tensorflow as tf  # 导入 TensorFlow 深度学习框架

# 导入 HuggingFace Transformers 相关的文件操作和模型输出等
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_tf_outputs import (
    TFBaseModelOutputWithPastAndCrossAttentions,
    TFBaseModelOutputWithPoolingAndCrossAttentions,
    TFMaskedLMOutput,
    TFSequenceClassifierOutput,
    TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (
    TFMaskedLanguageModelingLoss,
    TFModelInputType,
    TFPreTrainedModel,
    TFSequenceClassificationLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras,
    shape_list,
    unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, stable_softmax
from ...utils import logging  # 导入日志记录工具
from .configuration_esm import EsmConfig  # 导入 ESM 模型的配置文件

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

# 用于文档的模型检查点和配置信息
_CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D"
_CONFIG_FOR_DOC = "EsmConfig"

# 预训练模型存档列表
TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    # 这里没有列出所有 ESM 模型,可以在 https://huggingface.co/models?filter=esm 查看完整列表
]


def rotate_half(x):
    """
    将张量沿最后一个维度分割成两半,然后进行旋转操作。
    Args:
        x: 输入的张量

    Returns:
        tf.Tensor: 旋转后的张量
    """
    x1, x2 = tf.split(x, 2, axis=-1)
    return tf.concat((-x2, x1), axis=-1)


def apply_rotary_pos_emb(x, cos, sin):
    """
    应用旋转位置嵌入到输入张量 x 中。
    Args:
        x: 输入的张量
        cos: 余弦值张量
        sin: 正弦值张量

    Returns:
        tf.Tensor: 应用旋转位置嵌入后的张量
    """
    cos = cos[:, :, : tf.shape(x)[-2], :]
    sin = sin[:, :, : tf.shape(x)[-2], :]

    return (x * cos) + (rotate_half(x) * sin)


def symmetrize(x):
    """
    对最后两个维度进行转置操作,使层对称化,用于接触预测。
    Args:
        x: 输入张量

    Returns:
        tf.Tensor: 对称化后的张量
    """
    return x + tf.linalg.matrix_transpose(x)  # 仅转置最后两个维度


def average_product_correct(x):
    """
    执行平均产品校正,用于接触预测。
    Args:
        x: 输入张量

    Returns:
        tf.Tensor: 校正后的张量
    """
    a1 = tf.reduce_sum(x, -1, keepdims=True)
    a2 = tf.reduce_sum(x, -2, keepdims=True)
    a12 = tf.reduce_sum(x, (-1, -2), keepdims=True)

    avg = a1 * a2
    avg = avg / a12
    normalized = x - avg
    return normalized


class TFRotaryEmbedding(keras.layers.Layer):
    """
    基于 RoFormer 中的旋转位置嵌入,对查询和键进行旋转矩阵变换,依赖它们的相对位置。
    """

    # 在此类中定义相关的方法和初始化操作
    def __init__(self, dim: int, name=None):
        super().__init__(name=name)
        # Matt: The PyTorch version of this layer does a lot of work to cache values, but we just rely on TF compilation
        # and/or XLA to sort out constants like that. It actually may not seem like this layer needs to be stateful at
        # all when we benefit from TF compilation, but it does. The reason is that self.inv_freq is a buffer in the
        # original implementation, but all the shared ESM checkpoints were trained with fp16 params. This means that
        # the inv_freq tensor was stored as a float16, and we need to replicate those lower-precision values or our
        # models give different outputs from the original.
        self.dim = dim

    def build(self, input_shape):
        super().build(input_shape)
        # 创建一个名为 "inv_freq" 的权重变量,其形状为 (self.dim // 2,),数据类型为 tf.float32,初始化为 1.0,不可训练
        self.inv_freq = self.add_weight(
            "inv_freq", shape=(self.dim // 2,), dtype=tf.float32, initializer=get_initializer(1.0), trainable=False
        )
        # 计算 inv_freq 的值,这是一个与序列长度相关的正弦余弦嵌入频率
        self.inv_freq.assign(
            1.0 / (10000 ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim))
        )

    def _compute_cos_sin(self, x, seq_dimension=2):
        # 获取输入张量 x 的序列长度
        seq_len = tf.shape(x)[seq_dimension]

        # 创建一个序列 t,数据类型与 self.inv_freq 相同,长度为 seq_len
        t = tf.range(seq_len, dtype=self.inv_freq.dtype)
        # 计算频率矩阵 freqs,是 t 和 self.inv_freq 的外积
        freqs = tf.einsum("i, j -> ij", t, self.inv_freq)  # Outer multiplication
        # 创建正弦和余弦嵌入矩阵 emb,通过连接 freqs 和其自身的拷贝,axis=-1 表示在最后一个维度上连接
        emb = tf.concat((freqs, freqs), axis=-1)[None, None, :, :]

        return tf.cos(emb), tf.sin(emb)

    def call(self, q: tf.Tensor, k: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        # 计算正弦和余弦嵌入矩阵 cos_emb 和 sin_emb,针对张量 k 在序列维度上进行计算
        cos_emb, sin_emb = self._compute_cos_sin(k, seq_dimension=-2)

        # 应用旋转位置嵌入到输入张量 q 和 k 上,并返回结果
        return (
            apply_rotary_pos_emb(q, cos_emb, sin_emb),
            apply_rotary_pos_emb(k, cos_emb, sin_emb),
        )
class TFEsmContactPredictionHead(keras.layers.Layer):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias=True,
        eos_idx: int = 2,
        name=None,
    ):
        super().__init__(name=name)
        self.eos_idx = eos_idx  # 设置 eos 标记的索引值
        self.in_features = in_features  # 输入特征的维度
        self.regression = keras.layers.Dense(1, use_bias=bias, activation="sigmoid", name="regression")  # 定义逻辑回归层

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True  # 标记层已经构建
        if getattr(self, "regression", None) is not None:
            with tf.name_scope(self.regression.name):
                self.regression.build((None, self.in_features))  # 构建逻辑回归层的计算图

    def call(self, tokens, attentions):
        # remove eos token attentions
        eos_mask = tf.cast(tokens != self.eos_idx, attentions.dtype)  # 创建一个用于屏蔽 eos 标记的掩码
        eos_mask = tf.expand_dims(eos_mask, 1) * tf.expand_dims(eos_mask, 2)  # 将掩码扩展到适当的维度
        attentions = attentions * eos_mask[:, None, None, :, :]  # 使用掩码屏蔽 eos 标记的注意力值
        attentions = attentions[..., :-1, :-1]  # 移除最后一个维度中的 eos 标记的注意力值

        # remove cls token attentions
        attentions = attentions[..., 1:, 1:]  # 移除第一个维度中的 cls 标记的注意力值
        batch_size, layers, heads, seqlen, _ = shape_list(attentions)  # 获取注意力张量的形状信息
        attentions = tf.reshape(attentions, (batch_size, layers * heads, seqlen, seqlen))  # 重新整形注意力张量的维度

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = average_product_correct(symmetrize(attentions))  # 对注意力张量进行对称化和平均产品校正
        attentions = tf.transpose(attentions, perm=(0, 2, 3, 1))  # 转置注意力张量的维度顺序
        return tf.squeeze(self.regression(attentions), 3)  # 使用逻辑回归层进行预测,并压缩维度以匹配输出形状


class TFEsmEmbeddings(keras.layers.Layer):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """
    # 初始化函数,用于创建一个新的对象实例
    def __init__(self, config, name=None):
        # 调用父类的初始化函数
        super().__init__(name=name)
        # 创建词嵌入层,用于将词汇索引映射到向量表示
        self.word_embeddings = keras.layers.Embedding(
            config.vocab_size,
            config.hidden_size,
            embeddings_initializer=get_initializer(config.initializer_range),
            name="word_embeddings",
        )
        # 创建位置嵌入层,用于表示输入序列中每个位置的信息
        self.position_embeddings = keras.layers.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            embeddings_initializer=get_initializer(config.initializer_range),
            name="position_embeddings",
        )

        # 根据配置选择是否添加层归一化操作
        if config.emb_layer_norm_before:
            self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        else:
            self.layer_norm = None
        # 定义位置嵌入类型,默认为绝对位置编码
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        # 创建位置 ID,用于表示序列中每个位置的索引
        self.position_ids = tf.range(config.max_position_embeddings)[None, :]

        # 定义填充符的索引
        self.padding_idx = config.pad_token_id
        # 定义是否对 token 进行 dropout 的配置
        self.token_dropout = config.token_dropout
        # 定义 mask token 的索引
        self.mask_token_id = config.mask_token_id
        # 保存配置对象的引用
        self.config = config
        ):
            if position_ids is None:
                if input_ids is not None:
                    # 从输入的标记 IDs 创建位置 IDs。任何填充的标记保持填充状态。
                    position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
                else:
                    position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

            if inputs_embeds is None:
                # 检查输入的标记 IDs 是否在词汇表大小范围内
                check_embeddings_within_bounds(input_ids, self.config.vocab_size)
                inputs_embeds = self.word_embeddings(input_ids)

            # 注意:如果未来要支持 ESM-1(而不是1b!),则需要在此处支持嵌入比例因子。
            embeddings = inputs_embeds

            # Matt: ESM 有处理 MLM 掩码的选项,稍微不同于通常。如果 token_dropout 标志为 False,
            # 则与 BERT/RoBERTa 处理方式相同。如果设置为 True,则屏蔽的标记被视为选择输入丢失并清零。
            # 当屏蔽的标记不存在时,通过缩放嵌入来补偿 (训练期间未屏蔽标记的比例) / (样本中未屏蔽标记的比例)。
            # 这类似于评估期间丢弃层缩小输出的方式(或者等价地,在训练期间缩放未丢弃的输出)。
            if self.token_dropout:
                # 将屏蔽标记的嵌入清零
                embeddings = tf.where((input_ids == self.mask_token_id)[:, :, None], 0.0, embeddings)
                # 训练时的屏蔽比率,硬编码为所有 ESM 模型训练运行中使用的比率
                mask_ratio_train = 0.15 * 0.8
                # 计算源长度
                src_lengths = tf.cast(tf.reduce_sum(attention_mask, axis=-1), tf.float32)
                # 检查是否有屏蔽的标记
                masked_tokens = input_ids == self.mask_token_id
                # 观察到的屏蔽比率
                mask_ratio_observed = tf.math.count_nonzero(masked_tokens, dtype=tf.float32, axis=-1) / src_lengths
                # 缩放嵌入以补偿 mask-dropout
                embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

            if self.position_embedding_type == "absolute":
                # 如果位置嵌入类型为绝对位置,则添加位置嵌入到嵌入中
                position_embeddings = self.position_embeddings(position_ids)
                embeddings += position_embeddings

            if self.layer_norm is not None:
                # 如果有层归一化,则对嵌入进行归一化
                embeddings = self.layer_norm(embeddings)
            if attention_mask is not None:
                # 如果存在注意力掩码,则将其应用于嵌入
                embeddings = embeddings * tf.cast(tf.expand_dims(attention_mask, -1), embeddings.dtype)
            # Matt: 我认为这行代码从 BERT 复制过来时出错了,暂时禁用它。
            # embeddings = self.dropout(embeddings)
            return embeddings
    # 如果已经构建过模型则直接返回,避免重复构建
    if self.built:
        return
    
    # 标记模型已经构建
    self.built = True
    
    # 如果存在词嵌入,则构建词嵌入层
    if getattr(self, "word_embeddings", None) is not None:
        # 在词嵌入的命名空间下,构建词嵌入层
        with tf.name_scope(self.word_embeddings.name):
            self.word_embeddings.build(None)
    
    # 如果存在位置嵌入,则构建位置嵌入层
    if getattr(self, "position_embeddings", None) is not None:
        # 在位置嵌入的命名空间下,构建位置嵌入层
        with tf.name_scope(self.position_embeddings.name):
            self.position_embeddings.build(None)
    
    # 如果存在层归一化,则构建层归一化层
    if getattr(self, "layer_norm", None) is not None:
        # 在层归一化的命名空间下,构建层归一化层,输入形状为 [None, None, self.config.hidden_size]
        with tf.name_scope(self.layer_norm.name):
            self.layer_norm.build([None, None, self.config.hidden_size])
class TFEsmSelfAttention(keras.layers.Layer):
    # 定义一个自注意力层的 TensorFlow 扩展类
    def __init__(self, config, position_embedding_type=None, name=None):
        # 初始化函数,设置参数并配置层的名称
        super().__init__(name=name)
        # 检查隐藏大小是否可以被注意力头数整除,若不能则抛出错误
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        # 设置注意力头数和每个头的大小
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # 创建查询、键、值的全连接层
        self.query = keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )

        # 设置注意力概率的 dropout 层
        self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
        # 设置位置嵌入类型,默认为绝对位置
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        # 如果位置嵌入类型是相对键或者相对键-查询,则创建距离嵌入层
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = keras.layers.Embedding(
                2 * config.max_position_embeddings - 1,
                self.attention_head_size,
                embeddings_initializer=get_initializer(config.initializer_range),
            )
        # 如果位置嵌入类型是旋转,则创建旋转嵌入对象
        elif self.position_embedding_type == "rotary":
            self.rotary_embeddings = TFRotaryEmbedding(dim=self.attention_head_size, name="rotary_embeddings")

        # 设置是否为解码器的标志和存储配置信息
        self.is_decoder = config.is_decoder
        self.config = config

    def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
        # 重新排列张量的维度以便进行注意力计算
        new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size]
        x = tf.reshape(x, new_x_shape)
        return tf.transpose(x, perm=(0, 2, 1, 3))

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        encoder_hidden_states: tf.Tensor | None = None,
        encoder_attention_mask: tf.Tensor | None = None,
        past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
        output_attentions: Optional[bool] = False,
        training: bool = False,
        **kwargs
    ) -> Tuple[tf.Tensor, Optional[Tuple[tf.Tensor]]]:
        # 定义自注意力层的调用方法,处理输入张量并返回处理后的张量和可选的注意力张量
    # 定义 build 方法,用于构建模型结构
    def build(self, input_shape=None):
        # 如果已经构建过,直接返回,避免重复构建
        if self.built:
            return
        # 标记模型为已构建状态
        self.built = True
        
        # 如果存在查询(query)属性,则构建查询的结构
        if getattr(self, "query", None) is not None:
            # 使用查询的名称作为命名空间
            with tf.name_scope(self.query.name):
                # 构建查询的形状为 [None, None, self.config.hidden_size]
                self.query.build([None, None, self.config.hidden_size])
        
        # 如果存在键(key)属性,则构建键的结构
        if getattr(self, "key", None) is not None:
            # 使用键的名称作为命名空间
            with tf.name_scope(self.key.name):
                # 构建键的形状为 [None, None, self.config.hidden_size]
                self.key.build([None, None, self.config.hidden_size])
        
        # 如果存在值(value)属性,则构建值的结构
        if getattr(self, "value", None) is not None:
            # 使用值的名称作为命名空间
            with tf.name_scope(self.value.name):
                # 构建值的形状为 [None, None, self.config.hidden_size]
                self.value.build([None, None, self.config.hidden_size])
        
        # 如果存在旋转嵌入(rotary_embeddings)属性,则构建其结构
        if getattr(self, "rotary_embeddings", None) is not None:
            # 使用旋转嵌入的名称作为命名空间
            with tf.name_scope(self.rotary_embeddings.name):
                # 构建旋转嵌入,传入 None 作为输入形状参数
                self.rotary_embeddings.build(None)
# 自定义 Keras 层,实现自注意力机制的输出层
class TFEsmSelfOutput(keras.layers.Layer):
    def __init__(self, config, name=None):
        super().__init__(name=name)
        # 创建一个全连接层,用于映射隐藏状态到指定大小的输出空间
        self.dense = keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # 创建一个 Dropout 层,用于在训练时随机丢弃部分神经元,防止过拟合
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
        # 保存配置信息
        self.config = config

    def call(self, hidden_states, input_tensor, training=False):
        # 将隐藏状态通过全连接层映射,并应用激活函数
        hidden_states = self.dense(hidden_states)
        # 在训练模式下,对映射后的输出应用 Dropout
        hidden_states = self.dropout(hidden_states, training=training)
        # 将映射后的输出与输入张量相加,实现残差连接
        hidden_states += input_tensor
        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果层已构建,则直接返回;否则,构建全连接层
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # 构建全连接层,输入形状为 [None, None, hidden_size]
                self.dense.build([None, None, self.config.hidden_size])


# 自定义 Keras 层,实现注意力机制的中间层
class TFEsmAttention(keras.layers.Layer):
    def __init__(self, config, name=None):
        super().__init__(name=name)
        # 创建自注意力层
        self.self = TFEsmSelfAttention(config, name="self")
        # 创建自注意力层的输出层
        self.output_layer = TFEsmSelfOutput(config, name="output")
        # 初始化一个空集合,用于存储要剪枝的注意力头
        self.pruned_heads = set()
        # 创建 LayerNormalization 层,用于对输入进行归一化
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # 保存配置信息
        self.config = config

    def prune_heads(self, heads):
        # 剪枝方法暂未实现
        raise NotImplementedError

    def call(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        training=False,
    ):
        # 对输入的隐藏状态进行 LayerNormalization
        hidden_states_ln = self.LayerNorm(hidden_states)
        # 调用自注意力层进行计算,传入各种参数
        self_outputs = self.self(
            hidden_states_ln,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
            training,
        )
        # 将自注意力层的输出传递给输出层,同时传入原始的隐藏状态
        attention_output = self.output_layer(self_outputs[0], hidden_states)
        # 组装最终的输出,包括注意力输出和可能的其他信息
        outputs = (attention_output,) + self_outputs[1:]  # 如果需要输出注意力,将其添加到输出中
        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果层已构建,则直接返回;否则,构建自注意力层和输出层
        if getattr(self, "self", None) is not None:
            with tf.name_scope(self.self.name):
                # 构建自注意力层
                self.self.build(None)
        if getattr(self, "output_layer", None) is not None:
            with tf.name_scope(self.output_layer.name):
                # 构建自注意力层的输出层
                self.output_layer.build(None)
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                # 构建 LayerNormalization 层,输入形状为 [None, None, hidden_size]
                self.LayerNorm.build([None, None, self.config.hidden_size])
    # 初始化函数,用于创建一个新的实例
    def __init__(self, config: EsmConfig, **kwargs):
        # 调用父类的初始化方法,传入额外的关键字参数
        super().__init__(**kwargs)

        # 创建一个全连接层,用于处理输入数据
        self.dense = keras.layers.Dense(
            units=config.intermediate_size,  # 设置全连接层的输出单元数
            kernel_initializer=get_initializer(config.initializer_range),  # 初始化权重的方式
            name="dense",  # 设置层的名称
        )
        self.config = config  # 保存配置信息到实例中

    # 调用函数,用于定义模型的前向传播逻辑
    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)  # 将输入数据传入全连接层处理
        hidden_states = tf.nn.gelu(hidden_states)  # 使用GELU激活函数处理全连接层输出
        return hidden_states  # 返回处理后的数据

    # 构建函数,用于构建模型的层次结构
    def build(self, input_shape=None):
        if self.built:  # 如果模型已经构建过,直接返回
            return
        self.built = True  # 标记模型已构建

        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):  # 使用全连接层的名称作为命名空间
                self.dense.build([None, None, self.config.hidden_size])  # 构建全连接层的结构
# 自定义的 Transformer Encoder 层,继承自 keras.layers.Layer
class TFEsmLayer(keras.layers.Layer):
    # 初始化方法,接收配置参数 config 和可选的层名字 name
    def __init__(self, config, name=None):
        # 调用父类的初始化方法
        super().__init__(name=name)
        # 设定前馈传播时的块大小
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # 序列长度维度,设定为1
        self.seq_len_dim = 1
        # 创建自注意力层 TFEsmAttention 对象
        self.attention = TFEsmAttention(config, name="attention")
        # 是否作为解码器使用的标志
        self.is_decoder = config.is_decoder
        # 是否添加交叉注意力机制的标志
        self.add_cross_attention = config.add_cross_attention
        # 如果添加了交叉注意力且不是解码器,则引发运行时错误
        if self.add_cross_attention:
            if not self.is_decoder:
                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
            # 创建交叉注意力层 TFEsmAttention 对象
            self.crossattention = TFEsmAttention(config)
        # 创建中间层对象 TFEsmIntermediate
        self.intermediate = TFEsmIntermediate(config, name="intermediate")
        # 创建输出层对象 TFEsmOutput
        self.output_layer = TFEsmOutput(config, name="output")
        # 创建层归一化对象 LayerNorm
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # 保存配置对象到 self.config
        self.config = config

    # 调用方法,实现层的前向传播逻辑
    def call(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        training=False,
    ):
        # 如果过去的键/值对存在,则提取自注意力的缓存键/值对,位置在1和2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # 使用自注意力模块处理隐藏状态,应用注意力掩码和头掩码,输出注意力信息
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
            training=training,
        )
        # 提取自注意力模块的输出结果
        attention_output = self_attention_outputs[0]

        # 如果是解码器模型,则最后一个输出是自注意力缓存的元组
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]  # 提取除了最后一个元素外的所有元素
            present_key_value = self_attention_outputs[-1]  # 提取最后一个元素作为当前键/值对
        else:
            outputs = self_attention_outputs[1:]  # 如果输出注意力权重,则添加自注意力信息

        cross_attn_present_key_value = None
        # 如果是解码器且有编码器的隐藏状态
        if self.is_decoder and encoder_hidden_states is not None:
            # 如果模型没有交叉注意力层,则抛出异常
            if not hasattr(self, "crossattention"):
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
                    " with cross-attention layers by setting `config.add_cross_attention=True`"
                )

            # 提取交叉注意力模块的缓存键/值对,位置在过去键/值对元组的倒数第二和倒数第一位置
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            # 使用交叉注意力模块处理自注意力输出、注意力掩码、头掩码、编码器隐藏状态等信息
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
                training=training,
            )
            # 提取交叉注意力模块的输出结果
            attention_output = cross_attention_outputs[0]
            # 添加交叉注意力信息到输出中
            outputs = outputs + cross_attention_outputs[1:-1]

            # 将交叉注意力的键/值对添加到当前键/值对元组的第三和第四位置
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # 对注意力输出进行 LayerNorm 处理
        layernorm_output = self.LayerNorm(attention_output)
        # 使用中间层处理 LayerNorm 后的输出
        intermediate_output = self.intermediate(hidden_states=layernorm_output)
        # 使用输出层处理中间层的输出和注意力输出
        layer_output = self.output_layer(
            hidden_states=intermediate_output, input_tensor=attention_output, training=training
        )
        # 将处理后的输出添加到总输出中
        outputs = (layer_output,) + outputs

        # 如果是解码器模型,将注意力的键/值对作为最后一个输出返回
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs
    # 构建函数,用于构建模型的各个部分
    def build(self, input_shape=None):
        # 如果已经构建过,则直接返回,避免重复构建
        if self.built:
            return
        # 标记模型已经构建
        self.built = True
        
        # 如果存在注意力层,构建注意力层并设置名称作用域
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        
        # 如果存在中间层,构建中间层并设置名称作用域
        if getattr(self, "intermediate", None) is not None:
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)
        
        # 如果存在输出层,构建输出层并设置名称作用域
        if getattr(self, "output_layer", None) is not None:
            with tf.name_scope(self.output_layer.name):
                self.output_layer.build(None)
        
        # 如果存在 LayerNorm 层,构建 LayerNorm 层并设置名称作用域,
        # 输入形状为 [None, None, self.config.hidden_size]
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
# 定义自定义的 Transformer 编码器层,继承自 Keras 的 Layer 类
class TFEsmEncoder(keras.layers.Layer):
    # 初始化方法,接收配置参数和可选的名称
    def __init__(self, config, name=None):
        super().__init__(name=name)
        # 保存配置参数
        self.config = config
        # 创建多个 Transformer 编码层,根据配置中的隐藏层数量
        self.layer = [TFEsmLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
        # 创建一个 LayerNormalization 层,用于对嵌入层之后的结果进行归一化处理
        self.emb_layer_norm_after = keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="emb_layer_norm_after"
        )

    # 定义调用方法,处理输入数据和各种选项
    def call(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        training=False,
    ):
        # 初始化存储所有隐藏状态、自注意力和交叉注意力的元组,如果需要输出的话
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        # 初始化存储下一个解码器缓存的元组,如果需要使用缓存的话
        next_decoder_cache = () if use_cache else None

        # 遍历所有 Transformer 编码层
        for i, layer_module in enumerate(self.layer):
            # 如果需要输出隐藏状态,则将当前层的隐藏状态添加到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 获取当前层的头部掩码
            layer_head_mask = head_mask[i] if head_mask is not None else None
            # 获取当前层的过去键值对(如果有的话)
            past_key_value = past_key_values[i] if past_key_values is not None else None

            # 调用当前层的处理方法,获取该层的输出结果
            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
                training,
            )

            # 更新隐藏状态为当前层的输出的第一个元素
            hidden_states = layer_outputs[0]
            # 如果需要使用缓存,则将当前层的输出缓存加入到 next_decoder_cache 中
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            # 如果需要输出注意力分布,则将当前层的自注意力加入到 all_self_attentions 中
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # 如果配置中包含交叉注意力,则将当前层的交叉注意力加入到 all_cross_attentions 中
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # 如果存在嵌入层之后的归一化层,则对最终的隐藏状态进行归一化处理
        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        # 如果需要输出所有隐藏状态,则将最终的隐藏状态添加到 all_hidden_states 中
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果不需要以字典形式返回结果,则以元组形式返回多个结果
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        
        # 如果需要以字典形式返回结果,则创建 TFBaseModelOutputWithPastAndCrossAttentions 对象返回
        return TFBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
    # 如果已经构建过网络,则直接返回,避免重复构建
    if self.built:
        return

    # 标记网络已经构建
    self.built = True

    # 如果存在额外的嵌入层归一化操作,构建该层
    if getattr(self, "emb_layer_norm_after", None) is not None:
        # 在 TensorFlow 的命名空间下构建嵌入层归一化操作
        with tf.name_scope(self.emb_layer_norm_after.name):
            # 构建嵌入层归一化操作,指定输入形状为 [None, None, self.config.hidden_size]
            self.emb_layer_norm_after.build([None, None, self.config.hidden_size])

    # 如果存在多层网络结构,逐层构建网络
    if getattr(self, "layer", None) is not None:
        # 遍历每一层网络
        for layer in self.layer:
            # 在 TensorFlow 的命名空间下构建当前层网络
            with tf.name_scope(layer.name):
                # 构建当前层网络,输入形状暂时为 None,表示动态适配
                layer.build(None)
"""
定义一个自定义的 Keras 层 TFEsmPooler,用于 ESM 模型的池化操作。
从 transformers.models.bert.modeling_tf_bert.TFBertPooler 复制并修改为 ESM。

Parameters:
    config (EsmConfig): ESM 模型的配置对象,包含模型的各种参数。

Attributes:
    dense (Dense): 密集连接层,用于处理隐藏状态向量。
    config (EsmConfig): ESM 模型的配置对象。

Methods:
    call(hidden_states: tf.Tensor) -> tf.Tensor:
        对隐藏状态进行池化操作,只使用第一个 token 对应的隐藏状态。
    build(input_shape=None):
        构建层,初始化密集连接层。
"""

"""
ESM 模型的预训练模型基类 TFEsmPreTrainedModel。

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.

Attributes:
    config_class (EsmConfig): 模型配置类,指定为 EsmConfig。
    base_model_prefix (str): 基础模型名称前缀,设为 "esm"。

Notes:
    该类提供了预训练模型的通用方法,如初始化权重、下载和加载预训练模型等。
"""

"""
ESM 模型的输入文档字符串,描述模型的基本信息和使用方法。

This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)

This model is also a Keras [Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a
regular Keras model and refer to the TF/Keras documentation for all matters related to general usage and behavior.

Parameters:
    config ([`EsmConfig`]): Model configuration class with all the parameters of the
    model. Initializing with a config file does not load the weights associated with the model, only the
    configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

"""
ESM 模型的输入文档字符串,描述输入参数的详细信息和用法示例。
"""
        Args:
            input_ids (`tf.Tensor` of shape `({0})`):
                # 输入序列中的词汇索引。可以使用 `AutoTokenizer` 获取这些索引。参见 `PreTrainedTokenizer.encode` 和 `PreTrainedTokenizer.__call__`。
                # [什么是输入 ID?](../glossary#input-ids)
                Indices of input sequence tokens in the vocabulary.

            attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
                # 遮罩,用于避免在填充令牌的索引上执行注意力操作。
                # 遮罩值选在 `[0, 1]`:
                # - 1 表示 **不遮罩** 的标记,
                # - 0 表示 **遮罩** 的标记。
                Mask to avoid performing attention on padding token indices.

            position_ids (`tf.Tensor` of shape `({0})`, *optional*):
                # 输入序列标记在位置嵌入中的位置索引。选在 `[0, config.max_position_embeddings - 1]` 范围内。
                # [什么是位置 ID?](../glossary#position-ids)
                Indices of positions of each input sequence tokens in the position embeddings.

            head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
                # 用于置空自注意力模块中的选择头部的遮罩。
                # 遮罩值选在 `[0, 1]`:
                # - 1 表示头部 **不被遮罩**,
                # - 0 表示头部 **被遮罩**。
                Mask to nullify selected heads of the self-attention modules.

            inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
                # 可选,直接传递嵌入表示而不是 `input_ids`。如果想要更精确地控制如何将 `input_ids` 索引转换为相关联的向量,这很有用。
                # 比模型内部嵌入查找矩阵更有控制力。
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.

            output_attentions (`bool`, *optional*):
                # 是否返回所有注意力层的注意力张量。查看返回张量下的 `attentions` 获取更多细节。
                Whether or not to return the attentions tensors of all attention layers.

            output_hidden_states (`bool`, *optional*):
                # 是否返回所有层的隐藏状态。查看返回张量下的 `hidden_states` 获取更多细节。
                Whether or not to return the hidden states of all layers.

            return_dict (`bool`, *optional*):
                # 是否返回 [`~file_utils.ModelOutput`] 而不是普通的元组。
                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
    ESM_START_DOCSTRING,
)
class TFEsmMainLayer(keras.layers.Layer):
    """
    
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, add_pooling_layer=True, name=None, **kwargs):
        super().__init__(name=name, **kwargs)

        self.config = config
        self.is_decoder = config.is_decoder  # 初始化解码器标志位

        self.embeddings = TFEsmEmbeddings(config, name="embeddings")  # 初始化嵌入层
        self.encoder = TFEsmEncoder(config, name="encoder")  # 初始化编码器
        self.pooler = TFEsmPooler(config, name="pooler") if add_pooling_layer else None  # 初始化池化层(如果需要)

        self.contact_head = TFEsmContactPredictionHead(
            in_features=self.config.num_hidden_layers * self.config.num_attention_heads, bias=True, name="contact_head"
        )  # 初始化接触预测头部

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)  # 构建嵌入层
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)  # 构建编码器
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build(None)  # 构建池化层
        if getattr(self, "contact_head", None) is not None:
            with tf.name_scope(self.contact_head.name):
                self.contact_head.build(None)  # 构建接触预测头部

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings  # 获取输入嵌入层的词嵌入

    def set_input_embeddings(self, value: tf.Variable):
        self.embeddings.word_embeddings.weight = value  # 设置输入嵌入层的权重
        self.embeddings.vocab_size = shape_list(value)[0]  # 设置词汇表大小

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError  # 剪枝头部的方法,未实现
    # 定义一个方法,用于调用模型,接收多种输入参数并返回预测结果
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ):
        # 定义一个方法,用于预测模型的接触点(contacts)
        def predict_contacts(self, tokens, attention_mask):
            # 调用当前对象(self)的call方法,传入tokens和attention_mask作为输入,
            # 并设定return_dict和output_attentions参数为True,以获取注意力权重信息。
            attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
            # 将得到的注意力权重列表堆叠成一个张量,维度顺序与原始模型一致
            attns = tf.stack(attns, axis=1)
            
            # 在原始模型中,对于填充标记的注意力权重被完全置零。
            # 这通常不会有太大影响,因为其他标记不会关注它们,
            # 但是在接触点预测任务中,它们作为输入需要被模仿。
            # 因此,这里要做的是将填充标记对应位置的注意力权重置零。
            attention_mask = tf.cast(attention_mask, attns.dtype)
            attns *= attention_mask[:, None, None, None]  # 扩展维度匹配注意力权重张量
            attns *= attention_mask[:, None, None, :, None]  # 扩展维度匹配注意力权重张量
            
            # 调用模型的contact_head方法,传入tokens和处理后的注意力权重attns作为参数,
            # 返回接触点预测的结果。
            return self.contact_head(tokens, attns)
# 给 TFEsmModel 类添加文档字符串,描述其作为没有特定顶部头的原始隐藏状态输出的 ES 模型转换器
@add_start_docstrings(
    "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
    ESM_START_DOCSTRING,
)
class TFEsmModel(TFEsmPreTrainedModel):
    def __init__(self, config: EsmConfig, add_pooling_layer=True, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # 初始化 ES 模型的主层,根据给定的配置和是否添加池化层
        self.esm = TFEsmMainLayer(config, add_pooling_layer=add_pooling_layer, name="esm")

    # 对 call 方法进行装饰,添加文档字符串以描述模型前向传播的输入
    @unpack_inputs
    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
        # 这里继续列出所有的参数,描述它们的作用和可选性
    ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
        r"""
        encoder_hidden_states  (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
            contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`). Set to `False` during training, `True` during generation
        """
        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        return outputs

    def predict_contacts(self, tokens, attention_mask):
        # 调用模型的方法来预测接触点
        return self.esm.predict_contacts(tokens, attention_mask)

    def build(self, input_shape=None):
        if self.built:
            return
        # 标记模型已构建
        self.built = True
        if getattr(self, "esm", None) is not None:
            with tf.name_scope(self.esm.name):
                # 构建模型的子模块
                self.esm.build(None)
# 为模型添加文档字符串,描述其为带有顶部语言建模头的ESM模型
@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
    # 在加载过程中忽略缺失的关键字列表
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    # 在加载过程中忽略意外的关键字列表
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)

        # 如果配置指示为decoder,则发出警告
        if config.is_decoder:
            logger.warning(
                "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # 初始化ESM主层,不添加池化层,并命名为"esm"
        self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
        # 初始化ESM语言建模头,并命名为"lm_head"
        self.lm_head = TFEsmLMHead(config, name="lm_head")
        
        # 如果需要绑定词嵌入
        if config.tie_word_embeddings:
            # 确保词嵌入已构建,以便进行绑定
            with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")):
                self.esm.embeddings.word_embeddings.build((None, None))
            # 将lm_head的解码器设置为与ESM的词嵌入权重相同
            self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0]

    # 获取输出嵌入
    def get_output_embeddings(self):
        return self.lm_head.decoder

    # 设置输出嵌入
    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    # 获取语言建模头
    def get_lm_head(self):
        return self.lm_head

    # 模型调用函数,解包输入并添加模型前向传播的文档字符串
    @unpack_inputs
    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="<mask>",
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
        ):
        # 模型前向传播逻辑在此实现
    ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        kwargs (`Dict[str, any]`, optional, defaults to *{}*):
            Used to hide legacy arguments that have been deprecated.
        """
        # 设置是否返回字典格式的输出,如果未提供,则使用配置中的默认设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 使用 ESM 模型进行前向传播
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # 获取模型输出的序列特征
        sequence_output = outputs[0]
        # 使用语言模型头部生成预测分数
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        # 如果提供了标签,则计算掩码语言建模损失
        if labels is not None:
            masked_lm_loss = self.hf_compute_loss(labels=labels, logits=prediction_scores)

        # 如果不要求返回字典格式的输出
        if not return_dict:
            # 构造输出元组,包含预测分数及可能的额外输出
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        # 返回 TFMaskedLMOutput 对象,包含损失、预测分数、隐藏状态和注意力权重
        return TFMaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        # 调用 ESM 模型的预测接口,用于生成联系
        return self.esm.predict_contacts(tokens, attention_mask)

    def build(self, input_shape=None):
        # 如果模型已经构建,则直接返回
        if self.built:
            return
        # 设置模型为已构建状态
        self.built = True
        # 如果存在 ESM 模型,则在命名空间下构建它
        if getattr(self, "esm", None) is not None:
            with tf.name_scope(self.esm.name):
                self.esm.build(None)
        # 如果存在语言模型头部,则在命名空间下构建它
        if getattr(self, "lm_head", None) is not None:
            with tf.name_scope(self.lm_head.name):
                self.lm_head.build(None)
class TFEsmLMHead(keras.layers.Layer):
    """ESM Head for masked language modeling."""

    def __init__(self, config, name=None):
        super().__init__(name=name)
        # 创建一个全连接层,用于将输入特征映射到隐藏层大小的输出空间
        self.dense = keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        # 添加一个 LayerNormalization 层,用于标准化输入向量
        self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")

        # 如果设置了 tie_word_embeddings,decoder 为 None;否则创建一个全连接层,用于解码到词汇表大小
        if config.tie_word_embeddings:
            self.decoder = None
        else:
            self.decoder = keras.layers.Dense(
                config.vocab_size,
                kernel_initializer=get_initializer(config.initializer_range),
                name="decoder",
                use_bias=False,
            )
        self.config = config

    def build(self, input_shape=None):
        # 分离偏置项以匹配 PT 模型,并允许权重交叉加载工作
        # 将其放在 build 方法中,以便在将其添加为权重时获得正确的名称
        if self.built:
            return
        self.built = True
        # 添加一个名为 "bias" 的权重,形状为 (config.vocab_size,),并初始化为零,可训练
        self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # 构建 dense 层,输入形状为 [None, None, config.hidden_size]
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "layer_norm", None) is not None:
            with tf.name_scope(self.layer_norm.name):
                # 构建 layer_norm 层,输入形状为 [None, None, config.hidden_size]
                self.layer_norm.build([None, None, self.config.hidden_size])
        if getattr(self, "decoder", None) is not None and not self.config.tie_word_embeddings:
            with tf.name_scope(self.decoder.name):
                # 构建 decoder 层,输入形状为 [None, None, config.hidden_size]
                self.decoder.build([None, None, self.config.hidden_size])

    def get_bias(self):
        return {"bias": self.bias}

    def call(self, features):
        # 经过 dense 层映射特征
        x = self.dense(features)
        # 使用 gelu 激活函数
        x = tf.nn.gelu(x)
        # 使用 layer_norm 层标准化输出
        x = self.layer_norm(x)

        # 根据 tie_word_embeddings 决定如何将 x 投影回词汇表大小,同时加上偏置
        if self.config.tie_word_embeddings:
            x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
        else:
            x = self.decoder(x) + self.bias
        return x


@add_start_docstrings(
    """
    ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ESM_START_DOCSTRING,
)
class TFEsmForSequenceClassification(TFEsmPreTrainedModel, TFSequenceClassificationLoss):
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        # 设置分类或回归任务的标签数量
        self.num_labels = config.num_labels
        self.config = config

        # 创建 ESM 主层,不添加池化层,命名为 "esm"
        self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
        # 创建分类头部,命名为 "classifier"
        self.classifier = TFEsmClassificationHead(config, name="classifier")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # 将当前函数用作代码示例的文档字符串,指定了一些参数和返回类型的信息
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # 设置 return_dict 变量,若未提供则使用 self.config.use_return_dict 中的默认值
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 self.esm 方法,执行序列编码模型的前向传播
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # 从模型输出中获取序列输出
        sequence_output = outputs[0]
        # 将序列输出传递给分类器,生成分类任务的 logits
        logits = self.classifier(sequence_output)

        # 计算损失,如果 labels 不为 None,则使用 labels 和 logits 计算损失值
        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        # 如果 return_dict 为 False,则构建输出元组
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 如果 return_dict 为 True,则构建 TFSequenceClassifierOutput 对象作为输出
        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # 构建模型,设置输入形状并初始化模型的各个组件
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果存在 self.esm,则在命名空间 self.esm.name 下构建它
        if getattr(self, "esm", None) is not None:
            with tf.name_scope(self.esm.name):
                self.esm.build(None)
        # 如果存在 self.classifier,则在命名空间 self.classifier.name 下构建它
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build(None)
@add_start_docstrings(
    """
    ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ESM_START_DOCSTRING,
)
class TFEsmForTokenClassification(TFEsmPreTrainedModel, TFTokenClassificationLoss):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        # 初始化时设置分类标签数量
        self.num_labels = config.num_labels

        # 创建 ESM 主模型层,不包含池化层
        self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
        # Dropout 层,用于防止过拟合
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
        # 分类器层,将隐藏状态输出转化为分类预测
        self.classifier = keras.layers.Dense(config.num_labels, name="classifier")
        # 保存配置信息
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        # 确定是否返回字典格式的输出结果
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 调用 ESM 主模型进行前向传播
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # 获取序列输出
        sequence_output = outputs[0]

        # 在训练时使用 Dropout 层防止过拟合
        sequence_output = self.dropout(sequence_output, training=training)
        # 使用分类器层生成分类预测 logits
        logits = self.classifier(sequence_output)

        # 如果没有提供标签,则不计算损失
        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        # 根据是否返回字典格式来组织输出
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 返回 TFTokenClassifierOutput 格式的结果
        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    # 如果模型已经构建好,直接返回,不做重复构建
    if self.built:
        return
    # 将模型标记为已构建状态
    self.built = True
    
    # 如果存在名为"esm"的属性,并且不为None,执行以下操作
    if getattr(self, "esm", None) is not None:
        # 在命名空间下以"esm"的名称构建模型
        with tf.name_scope(self.esm.name):
            # 调用esm对象的build方法,传入None作为输入形状
            self.esm.build(None)
    
    # 如果存在名为"classifier"的属性,并且不为None,执行以下操作
    if getattr(self, "classifier", None) is not None:
        # 在命名空间下以"classifier"的名称构建模型
        with tf.name_scope(self.classifier.name):
            # 调用classifier对象的build方法,传入[None, None, self.config.hidden_size]作为输入形状
            self.classifier.build([None, None, self.config.hidden_size])
class TFEsmClassificationHead(keras.layers.Layer):
    """Head for sentence-level classification tasks."""

    def __init__(self, config, name=None):
        super().__init__(name=name)
        # 定义一个全连接层,用于生成隐藏层大小的输出,激活函数为tanh
        self.dense = keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )
        # 定义一个Dropout层,用于在训练时随机丢弃部分输入,以防止过拟合
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
        # 定义一个全连接层,用于生成类别数目大小的输出,激活函数为线性(即无激活函数)
        self.out_proj = keras.layers.Dense(
            config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="linear",
            name="out_proj",
        )
        self.config = config

    def call(self, features, training=False):
        # 提取features中的第一个位置的向量(对应于<s> token,即[CLS]),作为输入x
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        # 在训练阶段使用dropout随机丢弃部分输入向量,防止过拟合
        x = self.dropout(x, training=training)
        # 将输入向量x通过全连接层dense进行线性变换,并应用tanh激活函数
        x = self.dense(x)
        # 再次在训练阶段使用dropout随机丢弃部分输出向量,防止过拟合
        x = self.dropout(x, training=training)
        # 将处理后的向量x通过全连接层out_proj进行线性变换,生成最终的分类输出
        x = self.out_proj(x)
        return x

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果dense层已定义,则建立其内部权重
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # 建立dense层的权重,输入形状为[None, None, hidden_size]
                self.dense.build([None, None, self.config.hidden_size])
        # 如果out_proj层已定义,则建立其内部权重
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                # 建立out_proj层的权重,输入形状为[None, None, hidden_size]
                self.out_proj.build([None, None, self.config.hidden_size])


def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: 输入的整数张量,表示输入的符号序列
        padding_idx: 表示填充符号的索引
        past_key_values_length: 过去键值长度,用于计算增量索引

    Returns:
        tf.Tensor: 包含位置ID的张量,替换非填充符号为其位置数字
    """
    # 创建一个掩码,标记出不是填充符号的位置
    mask = tf.cast(input_ids != padding_idx, tf.int64)
    # 计算每个位置的增量索引,跳过填充符号,位置编号从padding_idx+1开始
    incremental_indices = (tf.cumsum(mask, axis=1) + past_key_values_length) * mask
    # 将增量索引加上padding_idx,得到最终的位置ID张量
    return incremental_indices + padding_idx
posted @ 2024-06-30 15:36  绝不原创的飞龙  阅读(11)  评论(0编辑  收藏  举报