Transformers Source Code Analysis (Part 47)

.\models\esm\openfold_utils\chunk_utils.py

# Standard library: logging and math
import logging
import math
# functools.partial for partially-applied callables
from functools import partial
# Type hints
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

# PyTorch
import torch

# Helpers from the local tensor_utils module
from .tensor_utils import tensor_tree_map, tree_map

# Collect the shapes of every tensor leaf in a (possibly nested) tree structure
def _fetch_dims(tree: Union[dict, list, tuple, torch.Tensor]) -> List[Tuple[int, ...]]:
    # Accumulator for every shape found in the tree
    shapes = []
    # Dict: recurse into each value
    if isinstance(tree, dict):
        for v in tree.values():
            shapes.extend(_fetch_dims(v))
    # List or tuple: recurse into each element
    elif isinstance(tree, (list, tuple)):
        for t in tree:
            shapes.extend(_fetch_dims(t))
    # Tensor leaf: record its shape
    elif isinstance(tree, torch.Tensor):
        shapes.append(tree.shape)
    else:
        # Any other leaf type is unsupported
        raise ValueError("Not supported")

    # Return every collected shape
    return shapes


# Excluded from TorchScript compilation
@torch.jit.ignore
# Convert a flat (linear) index into a multi-dimensional index over the given dims
def _flat_idx_to_idx(flat_idx: int, dims: Tuple[int, ...]) -> Tuple[int, ...]:
    # Accumulator for the multi-dimensional index, built innermost dimension first
    idx = []
    # Walk the dimensions from last to first
    for d in reversed(dims):
        # Index within the current dimension
        idx.append(flat_idx % d)
        # Move on to the next (more significant) dimension
        flat_idx = flat_idx // d

    # Reverse to restore the original dimension order
    return tuple(reversed(idx))
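
As a quick illustration (toy numbers, not from the source): for dims (2, 3, 4), the flat index 7 equals 0*12 + 1*4 + 3, so the conversion should yield the multi-index (0, 1, 3). The loop below simply mirrors the function's own logic:

dims = (2, 3, 4)
flat_idx = 7
idx = []
for d in reversed(dims):
    idx.append(flat_idx % d)
    flat_idx = flat_idx // d
assert tuple(reversed(idx)) == (0, 1, 3)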


# Excluded from TorchScript compilation
@torch.jit.ignore
# Compute a minimal set of slices that covers a contiguous flat range of a tensor
def _get_minimal_slice_set(
    start: Sequence[int],
    end: Sequence[int],
    dims: Sequence[int],
    start_edges: Optional[Sequence[bool]] = None,
    end_edges: Optional[Sequence[bool]] = None,
) -> List[Tuple[slice, ...]]:
    """
    Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields
    tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of
    slices, and perhaps even the shortest possible (I'm pretty sure it's the latter).

    end is INCLUSIVE.
    """

    # If no start-edge information was supplied, a dimension starts on an edge when its start index is 0.
    # reduce_edge_list (a helper defined in the original file, not shown in this excerpt) then keeps a
    # dimension flagged only if every inner dimension is also at its edge.
    if start_edges is None:
        start_edges = [s == 0 for s in start]
        reduce_edge_list(start_edges)

    # Likewise, a dimension ends on an edge when its end index is the last index of that dimension
    if end_edges is None:
        end_edges = [e == (d - 1) for e, d in zip(end, dims)]
        reduce_edge_list(end_edges)

    # Base case: no dimensions left, return a single empty slice tuple
    if len(start) == 0:
        return [()]
    # Single dimension: one slice covering [start, end] (end is inclusive)
    elif len(start) == 1:
        return [(slice(start[0], end[0] + 1),)]

    # Accumulator for the resulting slice tuples
    slices: List[Tuple[slice, ...]] = []
    # Common prefix of single-index slices shared by start and end
    path_list: List[slice] = []

    # Walk start/end from the outermost dimension and collect the dimensions where they agree
    for s, e in zip(start, end):
        if s == e:
            path_list.append(slice(s, s + 1))  # start and end coincide here: select exactly this index
        else:
            break

    # The shared prefix as a tuple of slices
    path: Tuple[slice, ...] = tuple(path_list)
    # Index of the first dimension where start and end diverge
    divergence_idx = len(path)

    # If start and end are identical, the shared prefix already covers the whole range
    if divergence_idx == len(dims):
        return [path]

    # Handles the start side: everything from `start` to the end of the subtree at start[divergence_idx]
    def upper() -> Tuple[Tuple[slice, ...], ...]:
        assert start_edges is not None
        assert end_edges is not None

        sdi = start[divergence_idx]
        return tuple(
            path + (slice(sdi, sdi + 1),) + s
            for s in _get_minimal_slice_set(
                start[divergence_idx + 1 :],
                [d - 1 for d in dims[divergence_idx + 1 :]],
                dims[divergence_idx + 1 :],
                start_edges=start_edges[divergence_idx + 1 :],
                end_edges=[True for _ in end_edges[divergence_idx + 1 :]],
            )
        )

    # Handles the end side: everything from the beginning of the subtree at end[divergence_idx] up to `end`
    def lower() -> Tuple[Tuple[slice, ...], ...]:
        assert start_edges is not None
        assert end_edges is not None

        edi = end[divergence_idx]
        return tuple(
            path + (slice(edi, edi + 1),) + s
            for s in _get_minimal_slice_set(
                [0 for _ in start[divergence_idx + 1 :]],
                end[divergence_idx + 1 :],
                dims[divergence_idx + 1 :],
                start_edges=[True for _ in start_edges[divergence_idx + 1 :]],
                end_edges=end_edges[divergence_idx + 1 :],
            )
        )

    # Both start and end lie on the edges of their subtrees at the divergence point: one slice covers the whole range
    if start_edges[divergence_idx] and end_edges[divergence_idx]:
        slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),))
    # Only the start is on an edge: slice everything up to (but excluding) the end subtree, then handle the end subtree separately
    elif start_edges[divergence_idx]:
        slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),))
        slices.extend(lower())
    # Only the end is on an edge: handle the start subtree separately, then slice everything after it through the end
    elif end_edges[divergence_idx]:
        slices.extend(upper())
        slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),))
    # Neither is on an edge: handle both boundary subtrees individually; anything strictly between them can be indexed in one slice
    else:
        slices.extend(upper())
        middle_ground = end[divergence_idx] - start[divergence_idx]
        if middle_ground > 1:
            slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),))
        slices.extend(lower())

    return slices


# Excluded from TorchScript compilation
@torch.jit.ignore
def _chunk_slice(t: torch.Tensor, flat_start: int, flat_end: int, no_batch_dims: int) -> torch.Tensor:
    """
    Equivalent to

        t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]

    but without the need for the initial reshape call, which can be memory-intensive in certain situations. The only
    reshape operations in this function are performed on sub-tensors that scale with (flat_end - flat_start), the chunk
    size.
    """
    # Leading batch dimensions of the input tensor
    batch_dims = t.shape[:no_batch_dims]
    # Convert the flat start index into a multi-dimensional index over the batch dims
    start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
    # flat_end is exclusive, so the last included element is flat_end - 1
    end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))

    # Get an ordered list of slices to perform
    slices = _get_minimal_slice_set(
        start_idx,
        end_idx,
        batch_dims,
    )

    # Apply each slice to the input tensor
    sliced_tensors = [t[s] for s in slices]

    # Flatten each piece over the batch dims and concatenate; the result matches the flattened-then-sliced tensor
    return torch.cat([s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors])
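
To make the docstring's equivalence concrete, here is a small, hypothetical check of the reference expression on a toy tensor; `_chunk_slice` is expected to produce the same rows without materializing the full reshape:

import torch

t = torch.arange(2 * 3 * 5, dtype=torch.float32).reshape(2, 3, 5)
no_batch_dims = 2            # treat the leading (2, 3) dims as batch dims
flat_start, flat_end = 1, 5  # a chunk of 4 "sub-batches"

# Reference formulation from the docstring: flatten the batch dims, then slice
reference = t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
assert reference.shape == (4, 5)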


def chunk_layer(
    layer: Callable,
    inputs: Dict[str, Any],
    chunk_size: int,
    no_batch_dims: int,
    low_mem: bool = False,
    _out: Any = None,
    _add_into_out: bool = False,
) -> Any:
    """
    Implements the "chunking" procedure described in section 1.11.8.

    Layer outputs and inputs are assumed to be simple "pytrees," consisting only of (arbitrarily nested) lists, tuples,
    and dicts with torch.Tensor leaves.

    Args:
        layer:
            The layer to be applied chunk-wise
        inputs:
            A (non-nested) dictionary of keyworded inputs. All leaves must be tensors and must share the same batch
            dimensions.
        chunk_size:
            The number of sub-batches per chunk. If multiple batch dimensions are specified, a "sub-batch" is defined
            as a single indexing of all batch dimensions simultaneously (s.t. the number of sub-batches is the product
            of the batch dimensions).
        no_batch_dims:
            How many of the initial dimensions of each input tensor can be considered batch dimensions.
        low_mem:
            Avoids flattening potentially large input tensors. Unnecessary in most cases, and is ever so slightly
            slower than the default setting.
    Returns:
        The reassembled output of the layer on the inputs.
    """
    # At least one input is required
    if not (len(inputs) > 0):
        raise ValueError("Must provide at least one input")

    # Collect the leading (batch) dims of every input and compute the broadcast batch shape
    initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

    def _prep_inputs(t: torch.Tensor) -> torch.Tensor:
        # Default path: expand each tensor to the full batch shape, then flatten the batch dims
        if not low_mem:
            if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
                t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
            t = t.reshape(-1, *t.shape[no_batch_dims:])
        else:
            # Low-memory path: expand only, deferring the flatten to the chunked slicing
            t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
        return t

    # Apply the preparation to every tensor leaf in the input dict
    prepped_inputs: Dict[str, Any] = tensor_tree_map(_prep_inputs, inputs)
    # Pre-allocated outputs (if provided via _out), flattened over the batch dims
    prepped_outputs = None
    if _out is not None:
        prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out)

    # Total number of "sub-batches" = product of the batch dims
    flat_batch_dim = 1
    for d in orig_batch_dims:
        flat_batch_dim *= d

    # Number of chunks, rounding up when the batch size is not a multiple of chunk_size
    no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)

    # Select the current chunk from a flattened tensor (broadcast tensors with batch size 1 pass through)
    def _select_chunk(t: torch.Tensor) -> torch.Tensor:
        return t[i : i + chunk_size] if t.shape[0] != 1 else t

    # i tracks the start of the current chunk; out holds the (possibly pre-allocated) output
    i = 0
    out = prepped_outputs
    for _ in range(no_chunks):
        # Default path: plain slicing; low-memory path: _chunk_slice avoids the up-front flatten
        if not low_mem:
            select_chunk = _select_chunk
        else:
            select_chunk = partial(
                _chunk_slice,
                flat_start=i,
                flat_end=min(flat_batch_dim, i + chunk_size),
                no_batch_dims=len(orig_batch_dims),
            )

        # Slice every input tensor down to the current chunk
        chunks: Dict[str, Any] = tensor_tree_map(select_chunk, prepped_inputs)

        # Run the layer on the current chunk
        output_chunk = layer(**chunks)

        # Allocate the output lazily, once the first chunk's output shapes are known
        if out is None:
            out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk)

        # Write the chunk's output into the pre-allocated space
        if isinstance(output_chunk, dict):
            # Dict outputs: recursively assign (or accumulate) each leaf
            def assign(d1: dict, d2: dict) -> None:
                for k, v in d1.items():
                    if isinstance(v, dict):
                        assign(v, d2[k])
                    else:
                        # _add_into_out selects accumulation instead of assignment
                        if _add_into_out:
                            v[i : i + chunk_size] += d2[k]
                        else:
                            v[i : i + chunk_size] = d2[k]

            assign(out, output_chunk)
        elif isinstance(output_chunk, tuple):
            # Tuple outputs: handle each element pairwise
            for x1, x2 in zip(out, output_chunk):
                if _add_into_out:
                    x1[i : i + chunk_size] += x2
                else:
                    x1[i : i + chunk_size] = x2
        elif isinstance(output_chunk, torch.Tensor):
            # Tensor outputs: accumulate or assign directly
            if _add_into_out:
                out[i : i + chunk_size] += output_chunk
            else:
                out[i : i + chunk_size] = output_chunk
        else:
            # Any other output type is unsupported
            raise ValueError("Not supported")

        # Advance to the next chunk
        i += chunk_size

    # Restore the original batch dimensions on every output tensor
    out = tensor_tree_map(lambda t: t.view(orig_batch_dims + t.shape[1:]), out)

    # Return the reassembled output
    return out
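
A minimal usage sketch of chunk_layer, assuming the module path used by this repository; the toy layer, shapes and chunk size are invented for the example:

import torch
from transformers.models.esm.openfold_utils.chunk_utils import chunk_layer

# Hypothetical element-wise "layer" whose inputs share the batch dims (4, 6)
def toy_layer(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    return x + bias

x = torch.randn(4, 6, 8)
bias = torch.randn(4, 6, 8)

# 24 sub-batches, processed 5 at a time, then reassembled to the original batch shape
out = chunk_layer(toy_layer, {"x": x, "bias": bias}, chunk_size=5, no_batch_dims=2)
assert out.shape == (4, 6, 8)
assert torch.allclose(out, x + bias)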


# Helper class that searches for the largest workable chunk size
class ChunkSizeTuner:
    def __init__(
        self,
        # Largest chunk size to try. Empirically, throughput plateaus before 512 for most models on all GPUs tried.
        max_chunk_size: int = 512,
    ):
        self.max_chunk_size = max_chunk_size
        self.cached_chunk_size: Optional[int] = None  # chunk size found on the last tuning run
        self.cached_arg_data: Optional[tuple] = None  # argument metadata from the last tuning run

    def _determine_favorable_chunk_size(self, fn: Callable, args: tuple, min_chunk_size: int) -> int:
        logging.info("Tuning chunk size...")

        # Nothing to search if the minimum already reaches the cap
        if min_chunk_size >= self.max_chunk_size:
            return min_chunk_size

        # Candidate chunk sizes: powers of two up to max_chunk_size, preceded by min_chunk_size
        candidates: List[int] = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
        candidates = [c for c in candidates if c > min_chunk_size]
        candidates = [min_chunk_size] + candidates
        candidates[-1] += 4

        # A candidate is viable if fn runs without raising (e.g. without running out of memory)
        def test_chunk_size(chunk_size: int) -> bool:
            try:
                with torch.no_grad():
                    fn(*args, chunk_size=chunk_size)
                return True
            except RuntimeError:
                return False

        # Binary search for the largest viable candidate
        min_viable_chunk_size_index = 0
        i = len(candidates) - 1
        while i > min_viable_chunk_size_index:
            viable = test_chunk_size(candidates[i])
            if not viable:
                i = (min_viable_chunk_size_index + i) // 2
            else:
                min_viable_chunk_size_index = i
                i = (i + len(candidates) - 1) // 2

        # Return the largest chunk size found to be viable
        return candidates[min_viable_chunk_size_index]

    def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool:
        # Check whether two argument caches describe the same inputs
        consistent = True
        for a1, a2 in zip(ac1, ac2):
            assert type(ac1) == type(ac2)
            if isinstance(ac1, (list, tuple)):
                consistent &= self._compare_arg_caches(a1, a2)
            elif isinstance(ac1, dict):
                # Compare dict values in a deterministic (key-sorted) order
                a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])]
                a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])]
                consistent &= self._compare_arg_caches(a1_items, a2_items)
            else:
                consistent &= a1 == a2

        return consistent

    def tune_chunk_size(
        self,
        representative_fn: Callable,
        args: tuple,
        min_chunk_size: int,
    ) -> int:
        consistent = True
        # Replace tensors with their shapes so the arguments can be cached and compared cheaply
        arg_data: tuple = tree_map(lambda a: a.shape if isinstance(a, torch.Tensor) else a, args, object)
        if self.cached_arg_data is not None:
            # A cache exists: only re-tune if the new arguments differ from the cached ones
            assert len(self.cached_arg_data) == len(arg_data)
            consistent = self._compare_arg_caches(self.cached_arg_data, arg_data)
        else:
            # No cache yet, so tuning has to run
            consistent = False

        if not consistent:
            # Re-run the search and remember both the result and the argument metadata
            self.cached_chunk_size = self._determine_favorable_chunk_size(
                representative_fn,
                args,
                min_chunk_size,
            )
            self.cached_arg_data = arg_data

        assert self.cached_chunk_size is not None

        # Return the cached chunk size
        return self.cached_chunk_size
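
A hedged usage sketch of ChunkSizeTuner; the representative function, shapes and sizes below are invented for the example and simply reuse chunk_layer as the chunked computation:

import torch
from transformers.models.esm.openfold_utils.chunk_utils import ChunkSizeTuner, chunk_layer

def representative_fn(x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Stands in for a real chunked forward pass; it only needs to accept chunk_size
    return chunk_layer(lambda x: x * 2, {"x": x}, chunk_size=chunk_size, no_batch_dims=1)

tuner = ChunkSizeTuner(max_chunk_size=128)
x = torch.randn(64, 16)
chunk_size = tuner.tune_chunk_size(representative_fn, (x,), min_chunk_size=4)
# Calling again with same-shaped arguments reuses the cached result instead of re-tuning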

.\models\esm\openfold_utils\data_transforms.py

# Required imports
from typing import Dict

import numpy as np
import torch

from . import residue_constants as rc  # residue-level constants shared across the openfold utilities
from .tensor_utils import tensor_tree_map, tree_map  # tree-mapping helpers

def make_atom14_masks(protein: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Construct denser atom positions (14 dimensions instead of 37)."""
    # Per-residue-type mappings and masks, built in restype order
    restype_atom14_to_atom37_list = []
    restype_atom37_to_atom14_list = []
    restype_atom14_mask_list = []

    # Iterate over the 20 standard residue types
    for rt in rc.restypes:
        # atom14 atom names for this residue type
        atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]

        # atom14 index -> atom37 index (0 for padding entries)
        restype_atom14_to_atom37_list.append([(rc.atom_order[name] if name else 0) for name in atom_names])

        # atom37 index -> atom14 index (0 when the atom37 atom is absent from this residue's atom14 set)
        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14_list.append(
            [(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types]
        )

        # atom14 existence mask for this residue type
        restype_atom14_mask_list.append([(1.0 if name else 0.0) for name in atom_names])

    # Add dummy mappings and masks for the 'UNK' residue type
    restype_atom14_to_atom37_list.append([0] * 14)
    restype_atom37_to_atom14_list.append([0] * 37)
    restype_atom14_mask_list.append([0.0] * 14)

    # Convert the per-restype tables to tensors on the same device as the input
    restype_atom14_to_atom37 = torch.tensor(
        restype_atom14_to_atom37_list,
        dtype=torch.int32,
        device=protein["aatype"].device,
    )
    restype_atom37_to_atom14 = torch.tensor(
        restype_atom37_to_atom14_list,
        dtype=torch.int32,
        device=protein["aatype"].device,
    )
    restype_atom14_mask = torch.tensor(
        restype_atom14_mask_list,
        dtype=torch.float32,
        device=protein["aatype"].device,
    )

    # Residue types as long integers, used as indices into the tables above
    protein_aatype = protein["aatype"].to(torch.long)

    # Per-residue (residue index, atom14) --> atom37 mapping
    residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
    # Per-residue atom14 existence mask
    residx_atom14_mask = restype_atom14_mask[protein_aatype]

    # Store the results back into the protein dict
    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()

    # Inverse mapping, (residue index, atom37) --> atom14
    residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
    protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()

    # Build the corresponding atom37 existence mask:
    # a [21, 37] table (20 residue types + UNK) marking which atom37 atoms exist for each residue type
    restype_atom37_mask = torch.zeros([21, 37], dtype=torch.float32, device=protein["aatype"].device)

    # For every residue type, mark the heavy atoms it actually contains
    for restype, restype_letter in enumerate(rc.restypes):
        # One-letter code -> three-letter code
        restype_name = rc.restype_1to3[restype_letter]
        # Heavy atoms present in this residue type
        atom_names = rc.residue_atoms[restype_name]
        for atom_name in atom_names:
            # atom37 index of this atom name
            atom_type = rc.atom_order[atom_name]
            restype_atom37_mask[restype, atom_type] = 1

    # Gather the per-residue atom37 mask and store it
    residx_atom37_mask = restype_atom37_mask[protein_aatype]
    protein["atom37_atom_exists"] = residx_atom37_mask

    # Return the updated protein dict
    return protein

# NumPy wrapper around make_atom14_masks: takes array-valued inputs and returns array-valued outputs
def make_atom14_masks_np(batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
    # Move every numpy array in the batch onto torch tensors on the batch's device
    batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray)
    # Run the tensor version, then convert every output tensor back to a numpy array
    out = tensor_tree_map(lambda t: np.array(t), make_atom14_masks(batch))
    return out
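
For orientation, a small sketch of make_atom14_masks on a toy batch; the residue indices are chosen arbitrarily and the module path is the one used in this repository:

import torch
from transformers.models.esm.openfold_utils.data_transforms import make_atom14_masks

protein = {"aatype": torch.tensor([0, 7, 20])}  # e.g. ALA, GLY and the UNK dummy type
protein = make_atom14_masks(protein)
assert protein["atom14_atom_exists"].shape == (3, 14)
assert protein["residx_atom14_to_atom37"].shape == (3, 14)
assert protein["residx_atom37_to_atom14"].shape == (3, 37)
assert protein["atom37_atom_exists"].shape == (3, 37)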

.\models\esm\openfold_utils\feats.py

# Required imports and type declarations
from typing import Dict, Tuple, overload
import torch
import torch.types
from torch import nn

# Local helpers
from . import residue_constants as rc
from .rigid_utils import Rigid, Rotation
from .tensor_utils import batched_gather

# Overload: without atom masks, only the pseudo-beta positions are returned
@overload
def pseudo_beta_fn(aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: None) -> torch.Tensor:
    ...

# Overload: with atom masks, a (positions, mask) tuple is returned
@overload
def pseudo_beta_fn(
    aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    ...

# Compute pseudo-beta positions: CB coordinates, falling back to CA for glycine
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
    # Which residues are glycine (glycine has no CB atom)
    is_gly = aatype == rc.restype_order["G"]
    # atom37 indices of CA and CB
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    # Select CA coordinates for glycine, CB coordinates otherwise
    pseudo_beta = torch.where(
        is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    # If atom masks were provided, select the corresponding mask as well
    if all_atom_masks is not None:
        pseudo_beta_mask = torch.where(
            is_gly,
            all_atom_masks[..., ca_idx],
            all_atom_masks[..., cb_idx],
        )
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta
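
A small sketch of the glycine fallback, using random coordinates (nothing here comes from real data):

import torch
from transformers.models.esm.openfold_utils import residue_constants as rc
from transformers.models.esm.openfold_utils.feats import pseudo_beta_fn

aatype = torch.tensor([rc.restype_order["G"], rc.restype_order["A"]])  # glycine, alanine
all_atom_positions = torch.randn(2, 37, 3)
pseudo_beta = pseudo_beta_fn(aatype, all_atom_positions, None)
# Glycine has no CB, so its pseudo-beta is the CA position; other residues use CB
assert torch.allclose(pseudo_beta[0], all_atom_positions[0, rc.atom_order["CA"]])
assert torch.allclose(pseudo_beta[1], all_atom_positions[1, rc.atom_order["CB"]])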

# Map atom14 data to the atom37 representation
def atom14_to_atom37(atom14: torch.Tensor, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    # Gather atom14 entries into atom37 slots using the per-residue index map
    atom37_data = batched_gather(
        atom14,
        batch["residx_atom37_to_atom14"],
        dim=-2,
        no_batch_dims=len(atom14.shape[:-2]),
    )

    # Zero out the data for atoms that do not exist in the residue
    atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]

    return atom37_data

# Build the template angle feature
def build_template_angle_feat(template_feats: Dict[str, torch.Tensor]) -> torch.Tensor:
    # Template residue types and torsion-angle sin/cos values
    template_aatype = template_feats["template_aatype"]
    torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
    alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"]
    torsion_angles_mask = template_feats["template_torsion_angles_mask"]
    # Concatenate one-hot residue types, primary and alternative torsion sin/cos values, and the angle mask
    template_angle_feat = torch.cat(
        [
            nn.functional.one_hot(template_aatype, 22),
            torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14),
            alt_torsion_angles_sin_cos.reshape(*alt_torsion_angles_sin_cos.shape[:-2], 14),
            torsion_angles_mask,
        ],
        dim=-1,
    )

    return template_angle_feat

# Build the template pair feature. The signature is shown as it appears in the source;
# the rest of the signature and the function body are not included in this excerpt.
def build_template_pair_feat(
    batch: Dict[str, torch.Tensor],
    min_bin: torch.types.Number,    # lower bound of the distance binning
    max_bin: torch.types.Number,    # upper bound of the distance binning
    no_bins: int,                   # number of distance bins
    use_unit_vector: bool = False,  # whether to include the unit-vector features
    eps: float = 1e-20,             # small constant for numerical stability
    inf: float = 1e8,               # large value used as an effective infinity
def torsion_angles_to_frames(
    r: Rigid,
    alpha: torch.Tensor,
    aatype: torch.Tensor,
    rrgdf: torch.Tensor,
) -> Rigid:
    # [*, N, 8, 4, 4]
    # Default 4x4 transformation matrices for each residue type, one per rigid group
    default_4x4 = rrgdf[aatype, ...]

    # [*, N, 8] transformations, i.e.
    #   One [*, N, 8, 3, 3] rotation matrix and
    #   One [*, N, 8, 3]    translation matrix
    # Build Rigid objects (rotation + translation) from the default 4x4 matrices
    default_r = r.from_tensor_4x4(default_4x4)

    # A constant (sin, cos) = (0, 1) rotation, i.e. the identity, used for the backbone group
    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
    bb_rot[..., 1] = 1

    # [*, N, 8, 2]
    # Prepend the backbone rotation to the torsion angles
    alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)

    # All-zero rotation matrices with the same shape as the default rotations
    all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
    # Fill in rotation matrices of the form:
    # [
    #   [1, 0  , 0  ],
    #   [0, a_2,-a_1],
    #   [0, a_1, a_2]
    # ]
    # This follows the original code rather than the supplementary documentation,
    # which uses different indices.

    all_rots[..., 0, 0] = 1
    all_rots[..., 1, 1] = alpha[..., 1]
    all_rots[..., 1, 2] = -alpha[..., 0]
    all_rots[..., 2, 1:] = alpha

    # Compose the default frames with the torsion rotations
    all_frames = default_r.compose(Rigid(Rotation(rot_mats=all_rots), None))

    # Frame-to-frame transforms for chi2, chi3 and chi4
    chi2_frame_to_frame = all_frames[..., 5]
    chi3_frame_to_frame = all_frames[..., 6]
    chi4_frame_to_frame = all_frames[..., 7]

    # Chain the chi frames back to the backbone frame
    chi1_frame_to_bb = all_frames[..., 4]
    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

    # Concatenate all frames expressed in backbone coordinates
    all_frames_to_bb = Rigid.cat(
        [
            all_frames[..., :5],
            chi2_frame_to_bb.unsqueeze(-1),
            chi3_frame_to_bb.unsqueeze(-1),
            chi4_frame_to_bb.unsqueeze(-1),
        ],
        dim=-1,
    )

    # Compose with the global (backbone) rigid transform to obtain frames in global coordinates
    all_frames_to_global = r[..., None].compose(all_frames_to_bb)

    # Return the frames in global coordinates
    return all_frames_to_global
# The remainder of this file is the body of frames_and_literature_positions_to_atom14_pos;
# its signature is not included in this excerpt.

# Index group_idx by residue type, giving a [*, N, 14] rigid-group assignment per atom
group_mask = group_idx[aatype, ...]

# One-hot encode the group assignment: [*, N, 14, 8], where 8 is the number of rigid groups
# (default_frames.shape[-3])
group_mask_one_hot: torch.LongTensor = nn.functional.one_hot(
    group_mask,
    num_classes=default_frames.shape[-3],
)

# Weight the rigid transforms r by the one-hot group mask, giving a [*, N, 14, 8] tensor
t_atoms_to_global = r[..., None, :] * group_mask_one_hot

# Sum over the group dimension so each atom keeps exactly its own group's transform: [*, N, 14]
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))

# Per-residue atom mask, expanded to [*, N, 14, 1]
atom_mask = atom_mask[aatype, ...].unsqueeze(-1)

# Literature (idealized) atom positions for each residue type: [*, N, 14, 3]
lit_positions = lit_positions[aatype, ...]

# Apply the per-atom rigid transforms to the literature positions: [*, N, 14, 3]
pred_positions = t_atoms_to_global.apply(lit_positions)

# Zero out positions of atoms that do not exist in the residue: shape stays [*, N, 14, 3]
pred_positions = pred_positions * atom_mask

# Return the predicted atom positions
return pred_positions

.\models\esm\openfold_utils\loss.py

# Required imports
from typing import Dict, Optional, Tuple
import torch

# Compute the centers of the bins defined by a set of boundaries
def _calculate_bin_centers(boundaries: torch.Tensor) -> torch.Tensor:
    step = boundaries[1] - boundaries[0]  # spacing between consecutive boundaries
    bin_centers = boundaries + step / 2  # center of each bin
    bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0)  # append one extra center for the final, open-ended bin
    return bin_centers
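
A worked toy example of the bin-center computation (values invented): three boundaries with step 1.0 yield one center per boundary plus an extra center appended past the last one:

import torch

boundaries = torch.tensor([0.0, 1.0, 2.0])
step = boundaries[1] - boundaries[0]
bin_centers = boundaries + step / 2                                                  # [0.5, 1.5, 2.5]
bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0)
assert torch.allclose(bin_centers, torch.tensor([0.5, 1.5, 2.5, 3.5]))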

# Compute the expected aligned error from per-bin probabilities
def _calculate_expected_aligned_error(
    alignment_confidence_breaks: torch.Tensor,
    aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    bin_centers = _calculate_bin_centers(alignment_confidence_breaks)  # centers of the error bins
    return (
        torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),  # expected aligned distance error
        bin_centers[-1],  # the last bin center, i.e. the maximum possible error
    )

# Compute predicted aligned error metrics from logits
def compute_predicted_aligned_error(
    logits: torch.Tensor,
    max_bin: int = 31,
    no_bins: int = 64,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes aligned confidence metrics from logits.

    Args:
      logits: [*, num_res, num_res, num_bins] the logits output from PredictedAlignedErrorHead.
      max_bin: Maximum bin value
      no_bins: Number of bins
    Returns:
      aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted aligned error probabilities
        over bins for each residue pair.
      predicted_aligned_error: [*, num_res, num_res] the expected aligned distance error for each pair of
        residues.
      max_predicted_aligned_error: [*] the maximum predicted error possible.
    """
    boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device)  # bin boundaries

    aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)  # softmax over bins
    predicted_aligned_error, max_predicted_aligned_error = _calculate_expected_aligned_error(
        alignment_confidence_breaks=boundaries,
        aligned_distance_error_probs=aligned_confidence_probs,
    )

    return {
        "aligned_confidence_probs": aligned_confidence_probs,  # per-bin probabilities
        "predicted_aligned_error": predicted_aligned_error,  # expected aligned error
        "max_predicted_aligned_error": max_predicted_aligned_error,  # maximum possible error
    }
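
A shape-level usage sketch with random logits (the batch size and residue count are arbitrary):

import torch
from transformers.models.esm.openfold_utils.loss import compute_predicted_aligned_error

logits = torch.randn(1, 16, 16, 64)  # [*, num_res, num_res, no_bins]
out = compute_predicted_aligned_error(logits, max_bin=31, no_bins=64)
assert out["aligned_confidence_probs"].shape == (1, 16, 16, 64)
assert out["predicted_aligned_error"].shape == (1, 16, 16)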

# Compute the predicted TM-score
def compute_tm(
    logits: torch.Tensor,
    residue_weights: Optional[torch.Tensor] = None,
    max_bin: int = 31,
    no_bins: int = 64,
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    if residue_weights is None:
        residue_weights = logits.new_ones(logits.shape[-2])  # default to uniform residue weights

    # Evenly spaced bin boundaries from 0 to max_bin on the logits' device
    boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device)

    # Centers of the bins defined by those boundaries
    bin_centers = _calculate_bin_centers(boundaries)

    # Note: this sum is computed but never assigned or used (kept as in the original code)
    torch.sum(residue_weights)

    # Number of residues (size of the second-to-last dimension)
    n = logits.shape[-2]

    # Take the larger of n and 19
    clipped_n = max(n, 19)

    # d0 from the standard TM-score formula
    d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8

    # Per-bin probabilities via softmax over the last dimension
    probs = torch.nn.functional.softmax(logits, dim=-1)

    # TM-score contribution of each bin
    tm_per_bin = 1.0 / (1 + (bin_centers**2) / (d0**2))

    # Expected TM term per residue pair, i.e. the probability-weighted average over bins
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    # Normalized residue mask: residue weights divided by their (eps-stabilized) sum
    normed_residue_mask = residue_weights / (eps + residue_weights.sum())

    # Weighted sum of the TM terms per alignment
    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)

    # Weight each alignment by its residue weight
    weighted = per_alignment * residue_weights

    # Index of the alignment with the largest weighted score
    argmax = (weighted == torch.max(weighted)).nonzero()[0]

    # Return the per-alignment score at that index
    return per_alignment[tuple(argmax)]
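
And a corresponding sketch for compute_tm; with random logits the exact value is meaningless, but the result should be a scalar strictly between 0 and 1:

import torch
from transformers.models.esm.openfold_utils.loss import compute_tm

logits = torch.randn(16, 16, 64)  # [num_res, num_res, no_bins]
ptm = compute_tm(logits, max_bin=31, no_bins=64)
assert ptm.ndim == 0 and 0.0 < ptm.item() < 1.0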

.\models\esm\openfold_utils\protein.py

# Required imports
import dataclasses  # frozen dataclass for the Protein container
import re  # regular expressions for parsing ProteinNet records
import string  # ASCII letters used as chain identifiers
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple

import numpy as np

from . import residue_constants

FeatureDict = Mapping[str, np.ndarray]  # model inputs: string -> numpy array
ModelOutput = Mapping[str, Any]  # model outputs: string -> anything, usually a nested dict

PICO_TO_ANGSTROM = 0.01  # conversion factor from picometers to angstroms

@dataclasses.dataclass(frozen=True)
class Protein:
    """Protein structure representation."""

    # Cartesian coordinates of atoms in angstroms. The atom types correspond to
    # residue_constants.atom_types.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

    # Amino-acid type for each residue, represented as an integer between 0 and 20,
    # where 20 is 'X'.
    aatype: np.ndarray  # [num_res]

    # Binary float mask indicating the presence of a particular atom: 1.0 if the atom is present,
    # 0.0 otherwise. Used for loss masking.
    atom_mask: np.ndarray  # [num_res, num_atom_type]

    # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
    residue_index: np.ndarray  # [num_res]

    # B-factors, or temperature factors, of each residue (in square angstroms),
    # representing the displacement of the residue from its ground-truth mean value.
    b_factors: np.ndarray  # [num_res, num_atom_type]

    # Chain indices for multi-chain predictions
    chain_index: Optional[np.ndarray] = None  # [num_res]

    # Optional remark about the protein, included as a comment in output PDB files
    remark: Optional[str] = None

    # Templates used to generate this protein (prediction-only)
    parents: Optional[Sequence[str]] = None

    # Chain index corresponding to each parent
    parents_chain_index: Optional[Sequence[int]] = None

def from_proteinnet_string(proteinnet_str: str) -> Protein:
    # Regular expression matching section tags such as [XXXX]\n
    tag_re = r"(\[[A-Z]+\]\n)"
    # Split the ProteinNet string on tags and drop empty fragments
    tags: List[str] = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
    # Pair each tag with the data lines that follow it
    groups: Iterator[Tuple[str, List[str]]] = zip(tags[0::2], [l.split("\n") for l in tags[1::2]])

    atoms: List[str] = ["N", "CA", "C"]  # backbone atoms present in ProteinNet records
    aatype = None  # residue types, filled from the [PRIMARY] section
    atom_positions = None  # atom coordinates, filled from the [TERTIARY] section
    atom_mask = None  # atom mask, filled from the [MASK] section
    # Walk over every (tag, data lines) group
    for g in groups:
        # Primary structure: the amino-acid sequence
        if "[PRIMARY]" == g[0]:
            # Extract the sequence and strip surrounding whitespace
            seq = g[1][0].strip()
            # Replace any unknown residue symbol with "X"
            for i in range(len(seq)):
                if seq[i] not in residue_constants.restypes:
                    seq[i] = "X"  # FIXME: strings are immutable, so this assignment would fail
            # Map each residue symbol to its integer index
            aatype = np.array(
                [residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) for res_symbol in seq]
            )
        # Tertiary structure: atom coordinates
        elif "[TERTIARY]" == g[0]:
            tertiary: List[List[float]] = []
            # One line of coordinates per axis (x, y, z)
            for axis in range(3):
                tertiary.append(list(map(float, g[1][axis].split())))
            tertiary_np = np.array(tertiary)
            # Allocate the [num_res, num_atom_type, 3] position array
            atom_positions = np.zeros((len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)).astype(np.float32)
            # Fill in the backbone atom coordinates in atom37 order
            for i, atom in enumerate(atoms):
                atom_positions[:, residue_constants.atom_order[atom], :] = np.transpose(tertiary_np[:, i::3])
            # Convert the positions from picometers to angstroms
            atom_positions *= PICO_TO_ANGSTROM
        # Mask section: which residues are resolved
        elif "[MASK]" == g[0]:
            # Map "-" to 0 and "+" to 1
            mask = np.array(list(map({"-": 0, "+": 1}.get, g[1][0].strip())))
            # Allocate the [num_res, num_atom_type] mask array
            atom_mask = np.zeros(
                (
                    len(mask),
                    residue_constants.atom_type_num,
                )
            ).astype(np.float32)
            # Mark the backbone atoms as present...
            for i, atom in enumerate(atoms):
                atom_mask[:, residue_constants.atom_order[atom]] = 1
            # ...then zero out residues that are masked away
            atom_mask *= mask[..., None]

    # The [PRIMARY] section is required
    assert aatype is not None

    # Assemble the Protein: residues are indexed sequentially and B-factors are omitted
    return Protein(
        atom_positions=atom_positions,
        atom_mask=atom_mask,
        aatype=aatype,
        residue_index=np.arange(len(aatype)),
        b_factors=None,
    )
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> List[str]:
    pdb_headers: List[str] = []  # collected header lines

    remark = prot.remark
    if remark is not None:
        pdb_headers.append(f"REMARK {remark}")  # add a REMARK record if a remark is present

    parents = prot.parents
    parents_chain_index = prot.parents_chain_index
    if parents is not None and parents_chain_index is not None:
        # Keep only the parents belonging to the requested chain
        parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id]

    if parents is None or len(parents) == 0:
        parents = ["N/A"]  # fall back to a placeholder parent

    pdb_headers.append(f"PARENT {' '.join(parents)}")  # add the PARENT record

    return pdb_headers


def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
    """Add pdb headers to an existing PDB string. Useful during multi-chain
    recycling
    """
    out_pdb_lines: List[str] = []  # output PDB lines

    lines = pdb_str.split("\n")  # input PDB string as individual lines

    remark = prot.remark
    if remark is not None:
        out_pdb_lines.append(f"REMARK {remark}")  # add a REMARK record if a remark is present

    parents_per_chain: List[List[str]]  # parent list for every chain
    if prot.parents is not None and len(prot.parents) > 0:
        parents_per_chain = []
        if prot.parents_chain_index is not None:
            # Group parents by their chain index
            parent_dict: Dict[str, List[str]] = {}
            for p, i in zip(prot.parents, prot.parents_chain_index):
                parent_dict.setdefault(str(i), [])
                parent_dict[str(i)].append(p)

            # For every chain index up to the largest one, collect its parents (or "N/A" if it has none)
            max_idx = max([int(chain_idx) for chain_idx in parent_dict])
            for i in range(max_idx + 1):
                chain_parents = parent_dict.get(str(i), ["N/A"])
                parents_per_chain.append(chain_parents)
        else:
            # No chain indices: treat all parents as belonging to a single chain
            parents_per_chain.append(list(prot.parents))
    else:
        parents_per_chain = [["N/A"]]  # no parents at all

    # Format a PARENT record line
    def make_parent_line(p: Sequence[str]) -> str:
        return f"PARENT {' '.join(p)}"

    out_pdb_lines.append(make_parent_line(parents_per_chain[0]))  # PARENT record for the first chain

    chain_counter = 0
    for i, l in enumerate(lines):
        if "PARENT" not in l and "REMARK" not in l:
            out_pdb_lines.append(l)  # copy through every non-header line
        if "TER" in l and "END" not in lines[i + 1]:
            # A chain just ended and another follows: emit the next chain's PARENT record
            chain_counter += 1
            if not chain_counter >= len(parents_per_chain):
                chain_parents = parents_per_chain[chain_counter]
            else:
                chain_parents = ["N/A"]

            out_pdb_lines.append(make_parent_line(chain_parents))

    return "\n".join(out_pdb_lines)


def to_pdb(prot: Protein) -> str:
    """Converts a `Protein` instance to a PDB string.

    Args:
      prot: The protein to convert to PDB.

    Returns:
      PDB string.
    """
    restypes = residue_constants.restypes + ["X"]  # residue types plus the unknown type "X"

    def res_1to3(r: int) -> str:
        # One-letter code -> three-letter code, defaulting to "UNK"
        return residue_constants.restype_1to3.get(restypes[r], "UNK")

    atom_types = residue_constants.atom_types  # atom37 atom names

    pdb_lines: List[str] = []  # output PDB lines

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(np.int32)
    b_factors = prot.b_factors
    chain_index = prot.chain_index

    # Residue types outside the known range are invalid
    if np.any(aatype > residue_constants.restype_num):
        raise ValueError("Invalid aatypes.")

    # Emit the header records (REMARK/PARENT), if any
    headers = get_pdb_headers(prot)
    if len(headers) > 0:
        pdb_lines.extend(headers)

    n = aatype.shape[0]  # number of residues
    atom_index = 1  # running PDB atom serial number
    prev_chain_index = 0  # chain index of the previous residue
    chain_tags = string.ascii_uppercase  # chain identifiers A, B, C, ...
    chain_tag = None

    # Add all atom sites, one residue at a time
    for i in range(n):
        # Three-letter residue name
        res_name_3 = res_1to3(aatype[i])
        # Emit one ATOM record per present atom
        for atom_name, pos, mask, b_factor in zip(atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
            # Skip atoms that are masked out
            if mask < 0.5:
                continue

            record_type = "ATOM"
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"  # PDB atom name, left-padded when short
            alt_loc = ""  # alternate location indicator (unused)
            insertion_code = ""  # insertion code (unused)
            occupancy = 1.00
            element = atom_name[0]  # element symbol; protein atoms here are only C, N, O, S
            charge = ""

            chain_tag = "A"  # default chain identifier
            # With an explicit chain index, pick the matching uppercase letter
            if chain_index is not None:
                chain_tag = chain_tags[chain_index[i]]

            # Assemble the fixed-width ATOM record
            # (the column widths below are required by the PDB format)
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_tag:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1}   "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)
            atom_index += 1

        should_terminate = i == n - 1  # the last residue always terminates its chain
        if chain_index is not None:
            # Also terminate when the next residue belongs to a different chain
            if i != n - 1 and chain_index[i + 1] != prev_chain_index:
                should_terminate = True
                prev_chain_index = chain_index[i + 1]

        if should_terminate:
            # Emit the TER record that closes the current chain
            chain_end = "TER"
            chain_termination_line = (
                f"{chain_end:<6}{atom_index:>5}      {res_1to3(aatype[i]):>3} {chain_tag:>1}{residue_index[i]:>4}"
            )
            pdb_lines.append(chain_termination_line)
            atom_index += 1

            # If this is not the last residue, emit the headers of the next chain
            if i != n - 1:
                # "prev" is somewhat misleading here: it points at the chain that is about to start
                pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))

    # Close the file with an END record and a trailing newline
    pdb_lines.append("END")
    pdb_lines.append("")
    return "\n".join(pdb_lines)


# Compute an idealized atom mask for a given protein
def ideal_atom_mask(prot: Protein) -> np.ndarray:
    """Computes an ideal atom mask.

    `Protein.atom_mask` typically is defined according to the atoms that are reported in the PDB. This function
    computes a mask according to heavy atoms that should be present in the given sequence of amino acids.

    Args:
      prot: `Protein` whose fields are `numpy.ndarray` objects.

    Returns:
      An ideal atom mask.
    """
    # Look up the standard heavy-atom mask for each residue type in the sequence
    return residue_constants.STANDARD_ATOM_MASK[prot.aatype]


# Assemble a protein object from a prediction
def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    chain_index: Optional[np.ndarray] = None,
    remark: Optional[str] = None,
    parents: Optional[Sequence[str]] = None,
    parents_chain_index: Optional[Sequence[int]] = None,
) -> Protein:
    """Assembles a protein from a prediction.

    Args:
      features: Dictionary holding model inputs.
      result: Dictionary holding model outputs.
      b_factors: (Optional) B-factors to use for the protein.
      chain_index: (Optional) Chain indices for multi-chain predictions
      remark: (Optional) Remark about the prediction
      parents: (Optional) List of template names
    Returns:
      A protein instance.
    """
    # Build the Protein from the model inputs and outputs, filling defaults where arguments were omitted
    return Protein(
        aatype=features["aatype"],
        atom_positions=result["final_atom_positions"],
        atom_mask=result["final_atom_mask"],
        residue_index=features["residue_index"] + 1,
        b_factors=b_factors if b_factors is not None else np.zeros_like(result["final_atom_mask"]),
        chain_index=chain_index,
        remark=remark,
        parents=parents,
        parents_chain_index=parents_chain_index,
    )

.\models\esm\openfold_utils\residue_constants.py

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Constants used in AlphaFold."""

import collections
import copy
import functools
from importlib import resources
from typing import Dict, List, Mapping, Sequence, Tuple

import numpy as np


# Internal import (35fd).


# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096

# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
chi_angles_atoms: Dict[str, List[List[str]]] = {
    "ALA": [],
    # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
    "ARG": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "NE"], ["CG", "CD", "NE", "CZ"]],
    "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
    "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
    "CYS": [["N", "CA", "CB", "SG"]],
    "GLN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "OE1"]],
    "GLU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "OE1"]],
    "GLY": [],
    "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
    "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
    "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
    "LYS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "CE"], ["CG", "CD", "CE", "NZ"]],
    "MET": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "SD"], ["CB", "CG", "SD", "CE"]],
    "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
    "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
    "SER": [["N", "CA", "CB", "OG"]],
    "THR": [["N", "CA", "CB", "OG1"]],
    "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
    "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
    "VAL": [["N", "CA", "CB", "CG1"]],
}

# If chi angles given in fixed-length array, this matrix determines how to mask
# them for each AA type. The order is as per restype_order (see below).
chi_angles_mask: List[List[float]] = [
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [1.0, 1.0, 1.0, 1.0],  # ARG
    [1.0, 1.0, 0.0, 0.0],  # ASN
    [1.0, 1.0, 0.0, 0.0],  # ASP
    [1.0, 0.0, 0.0, 0.0],  # CYS
    [1.0, 1.0, 1.0, 0.0],  # GLN
    [1.0, 1.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [1.0, 1.0, 0.0, 0.0],  # HIS
    [1.0, 1.0, 0.0, 0.0],  # ILE
    [1.0, 1.0, 0.0, 0.0],  # LEU
    [1.0, 1.0, 1.0, 1.0],  # LYS
    [1.0, 1.0, 1.0, 0.0],  # MET
    [1.0, 1.0, 0.0, 0.0],  # PHE
    [1.0, 1.0, 0.0, 0.0],  # PRO
    [1.0, 0.0, 0.0, 0.0],  # SER
    [1.0, 0.0, 0.0, 0.0],  # THR
    [1.0, 1.0, 0.0, 0.0],  # TRP
    [1.0, 1.0, 0.0, 0.0],  # TYR
    [1.0, 0.0, 0.0, 0.0],  # VAL
]
# The following chi angles are pi-periodic: they can be rotated by a multiple of pi without
# affecting the structure. One row per residue type; each entry flags whether that chi angle is pi-periodic.
chi_pi_periodic: List[List[float]] = [
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [0.0, 0.0, 0.0, 0.0],  # ARG
    [0.0, 0.0, 0.0, 0.0],  # ASN
    [0.0, 1.0, 0.0, 0.0],  # ASP
    [0.0, 0.0, 0.0, 0.0],  # CYS
    [0.0, 0.0, 0.0, 0.0],  # GLN
    [0.0, 0.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [0.0, 0.0, 0.0, 0.0],  # HIS
    [0.0, 0.0, 0.0, 0.0],  # ILE
    [0.0, 0.0, 0.0, 0.0],  # LEU
    [0.0, 0.0, 0.0, 0.0],  # LYS
    [0.0, 0.0, 0.0, 0.0],  # MET
    [0.0, 1.0, 0.0, 0.0],  # PHE
    [0.0, 0.0, 0.0, 0.0],  # PRO
    [0.0, 0.0, 0.0, 0.0],  # SER
    [0.0, 0.0, 0.0, 0.0],  # THR
    [0.0, 0.0, 0.0, 0.0],  # TRP
    [0.0, 1.0, 0.0, 0.0],  # TYR
    [0.0, 0.0, 0.0, 0.0],  # VAL
    [0.0, 0.0, 0.0, 0.0],  # UNK
]

# Atom positions relative to the axis-end atom of the 8 rigid groups, defined by the
# pre-omega, phi, psi and chi angles:
# 0: 'backbone group',
# 1: 'pre-omega group', (empty)
# 2: 'phi group', (currently empty, because it only defines hydrogens)
# 3: 'psi group',
# 4,5,6,7: 'chi1,2,3,4 group'
# The atom positions are relative to the axis-end atom of the corresponding rotation axis.
# The x-axis points along the rotation axis, and the y-axis is defined such that the
# dihedral-angle-defining atom (the last entry in chi_angles_atoms) lies in the xy-plane
# (with a positive y-coordinate).
rigid_group_atom_positions: Dict[str, List[Tuple[str, int, Tuple[float, float, float]]]] = {
    "ALA": [
        ("N", 0, (-0.525, 1.363, 0.000)),
        ("CA", 0, (0.000, 0.000, 0.000)),
        ("C", 0, (1.526, -0.000, -0.000)),
        ("CB", 0, (-0.529, -0.774, -1.205)),
        ("O", 3, (0.627, 1.062, 0.000)),
    ],
    "ARG": [
        ("N", 0, (-0.524, 1.362, -0.000)),
        ("CA", 0, (0.000, 0.000, 0.000)),
        ("C", 0, (1.525, -0.000, -0.000)),
        ("CB", 0, (-0.524, -0.778, -1.209)),
        ("O", 3, (0.626, 1.062, 0.000)),
        ("CG", 4, (0.616, 1.390, -0.000)),
        ("CD", 5, (0.564, 1.414, 0.000)),
        ("NE", 6, (0.539, 1.357, -0.000)),
        ("NH1", 7, (0.206, 2.301, 0.000)),
        ("NH2", 7, (2.078, 0.978, -0.000)),
        ("CZ", 7, (0.758, 1.093, -0.000)),
    ],
    "ASN": [
        ("N", 0, (-0.536, 1.357, 0.000)),
        ("CA", 0, (0.000, 0.000, 0.000)),
        ("C", 0, (1.526, -0.000, -0.000)),
        ("CB", 0, (-0.531, -0.787, -1.200)),
        ("O", 3, (0.625, 1.062, 0.000)),
        ("CG", 4, (0.584, 1.399, 0.000)),
        ("ND2", 5, (0.593, -1.188, 0.001)),
        ("OD1", 5, (0.633, 1.059, 0.000)),
    ],
    # ASP: (atom name, rigid group index, position) triples
    "ASP": [
        ("N", 0, (-0.525, 1.362, -0.000)),
        ("CA", 0, (0.000, 0.000, 0.000)),
        ("C", 0, (1.527, 0.000, -0.000)),
        ("CB", 0, (-0.526, -0.778, -1.208)),
        ("O", 3, (0.626, 1.062, -0.000)),
        ("CG", 4, (0.593, 1.398, -0.000)),
        ("OD1", 5, (0.610, 1.091, 0.000)),
        ("OD2", 5, (0.592, -1.101, -0.003)),
    ],
    # CYS: (atom name, rigid group index, position) triples
    "CYS": [
        ("N", 0, (-0.522, 1.362, -0.000)),
        ("CA", 0, (0.000, 0.000, 0.000)),
        ("C", 0, (1.524, 0.000, 0.000)),
        ("CB", 0, (-0.519, -0.773, -1.212)),
        ("O", 3, (0.625, 1.062, -0.000)),
        ("SG", 4, (0.728, 1.653, 0.000)),
    ],
    # GLN: (atom name, rigid group index, position) triples
    "GLN": [
        ("N", 0, (-0.526, 1.361, -0.000)),
        ("CA", 0, (0.000, 0.000, 0.000)),
        ("C", 0, (1.526, 0.000, 0.000)),
        ("CB", 0, (-0.525, -0.779, -1.207)),
        ("O", 3, (0.626, 1.062, -0.000)),
        ("CG", 4, (0.615, 1.393, 0.000)),
        ("CD", 5, (0.587, 1.399, -0.000)),
        ("NE2", 6, (0.593, -1.189, -0.001)),
        ("OE1", 6, (0.634, 1.060, 0.000)),
    ],
    # GLU: (atom name, rigid group index, position) triples
    "GLU": [
        ("N", 0, (-0.528, 1.361, 0.000)),
        ("CA", 0, (0.000, 0.000, 0.000)),
        ("C", 0, (1.526, -0.000, -0.000)),
        ("CB", 0, (-0.526, -0.781, -1.207)),
        ("O", 3, (0.626, 1.062, 0.000)),
        ("CG", 4, (0.615, 1.392, 0.000)),
        ("CD", 5, (0.600, 1.397, 0.000)),
        ("OE1", 6, (0.607, 1.095, -0.000)),
        ("OE2", 6, (0.589, -1.104, -0.001)),
    ],
    # GLY: (atom name, rigid group index, position) triples
    "GLY": [
        ("N", 0, (-0.572, 1.337, 0.000)),
        ("CA", 0, (0.000, 0.000, 0.000)),
        ("C", 0, (1.517, -0.000, -0.000)),
        ("O", 3, (0.626, 1.062, -0.000)),
    ],
    # HIS: (atom name, rigid group index, position) triples
    "HIS": [
        ("N", 0, (-0.527, 1.360, 0.000)),
        ("CA", 0, (0.000, 0.000, 0.000)),
        ("C", 0, (1.525, 0.000, 0.000)),
        ("CB", 0, (-0.525, -0.778, -1.208)),
        ("O", 3, (0.625, 1.063, 0.000)),
        ("CG", 4, (0.600, 1.370, -0.000)),
        ("CD2", 5, (0.889, -1.021, 0.003)),
        ("ND1", 5, (0.744, 1.160, -0.000)),
        ("CE1", 5, (2.030, 0.851, 0.002)),
        ("NE2", 5, (2.145, -0.466, 0.004)),
    ],
    # ILE: (atom name, rigid group index, position) triples
    "ILE": [
        ("N", 0, (-0.493, 1.373, -0.000)),
        ("CA", 0, (0.000, 0.000, 0.000)),
        ("C", 0, (1.527, -0.000, -0.000)),
        ("CB", 0, (-0.536, -0.793, -1.213)),
        ("O", 3, (0.627, 1.062, -0.000)),
        ("CG1", 4, (0.534, 1.437, -0.000)),
        ("CG2", 4, (0.540, -0.785, -1.199)),
        ("CD1", 5, (0.619, 1.391, 0.000)),
    ],
    # LEU: (atom name, rigid group index, position) triples
    "LEU": [
        ("N", 0, (-0.520, 1.363, 0.000)),
        ("CA", 0, (0.000, 0.000, 0.000)),
        ("C", 0, (1.525, -0.000, -0.000)),
        ("CB", 0, (-0.522, -0.773, -1.214)),
        ("O", 3, (0.625, 1.063, -0.000)),
        ("CG", 4, (0.678, 1.371, 0.000)),
        ("CD1", 5, (0.530, 1.430, -0.000)),
        ("CD2", 5, (0.535, -0.774, 1.200)),
    ],
    "LYS": [
        ("N", 0, (-0.526, 1.362, -0.000)),  # 原子名称"N",电荷状态0,坐标(-0.526, 1.362, -0.000)
        ("CA", 0, (0.000, 0.000, 0.000)),   # 原子名称"CA",电荷状态0,坐标(0.000, 0.000, 0.000)
        ("C", 0, (1.526, 0.000, 0.000)),    # 原子名称"C",电荷状态0,坐标(1.526, 0.000, 0.000)
        ("CB", 0, (-0.524, -0.778, -1.208)),# 原子名称"CB",电荷状态0,坐标(-0.524, -0.778, -1.208)
        ("O", 3, (0.626, 1.062, -0.000)),   # 原子名称"O",电荷状态3,坐标(0.626, 1.062, -0.000)
        ("CG", 4, (0.619, 1.390, 0.000)),   # 原子名称"CG",电荷状态4,坐标(0.619, 1.390, 0.000)
        ("CD", 5, (0.559, 1.417, 0.000)),   # 原子名称"CD",电荷状态5,坐标(0.559, 1.417, 0.000)
        ("CE", 6, (0.560, 1.416, 0.000)),   # 原子名称"CE",电荷状态6,坐标(0.560, 1.416, 0.000)
        ("NZ", 7, (0.554, 1.387, 0.000)),   # 原子名称"NZ",电荷状态7,坐标(0.554, 1.387, 0.000)
    ],
    "MET": [
        ("N", 0, (-0.521, 1.364, -0.000)),  # 原子名称"N",电荷状态0,坐标(-0.521, 1.364, -0.000)
        ("CA", 0, (0.000, 0.000, 0.000)),   # 原子名称"CA",电荷状态0,坐标(0.000, 0.000, 0.000)
        ("C", 0, (1.525, 0.000, 0.000)),    # 原子名称"C",电荷状态0,坐标(1.525, 0.000, 0.000)
        ("CB", 0, (-0.523, -0.776, -1.210)),# 原子名称"CB",电荷状态0,坐标(-0.523, -0.776, -1.210)
        ("O", 3, (0.625, 1.062, -0.000)),   # 原子名称"O",电荷状态3,坐标(0.625, 1.062, -0.000)
        ("CG", 4, (0.613, 1.391, -0.000)),   # 原子名称"CG",电荷状态4,坐标(0.613, 1.391, -0.000)
        ("SD", 5, (0.703, 1.695, 0.000)),    # 原子名称"SD",电荷状态5,坐标(0.703, 1.695, 0.000)
        ("CE", 6, (0.320, 1.786, -0.000)),   # 原子名称"CE",电荷状态6,坐标(0.320, 1.786, -0.000)
    ],
    "PHE": [
        ("N", 0, (-0.518, 1.363, 0.000)),   # 原子名称"N",电荷状态0,坐标(-0.518, 1.363, 0.000)
        ("CA", 0, (0.000, 0.000, 0.000)),   # 原子名称"CA",电荷状态0,坐标(0.000, 0.000, 0.000)
        ("C", 0, (1.524, 0.000, -0.000)),   # 原子名称"C",电荷状态0,坐标(1.524, 0.000, -0.000)
        ("CB", 0, (-0.525, -0.776, -1.212)),# 原子名称"CB",电荷状态0,坐标(-0.525, -0.776, -1.212)
        ("O", 3, (0.626, 1.062, -0.000)),   # 原子名称"O",电荷状态3,坐标(0.626, 1.062, -0.000)
        ("CG", 4, (0.607, 1.377, 0.000)),   # 原子名称"CG",电荷状态4,坐标(0.607, 1.377, 0.000)
        ("CD1", 5, (0.709, 1.195, -0.000)), # 原子名称"CD1",电荷状态5,坐标(0.709, 1.195, -0.000)
        ("CD2", 5, (0.706, -1.196, 0.000)),  # 原子名称"CD2",电荷状态5,坐标(0.706, -1.196, 0.000)
        ("CE1", 5, (2.102, 1.198, -0.000)),  # 原子名称"CE1",电荷状态5,坐标(2.102, 1.198, -0.000)
        ("CE2", 5, (2.098, -1.201, -0.000)), # 原子名称"CE2",电荷状态5,坐标(2.098, -1.201, -0.000)
        ("CZ", 5, (2.794, -0.003, -0.001)),  # 原子名称"CZ",电荷状态5,坐标(2.794, -0.003, -0.001)
    ],
    "PRO": [
        ("N", 0, (-0.566, 1.351, -0.000)),  # 原子名称"N",电荷状态0,坐标(-0.566, 1.351, -0.000)
        ("CA", 0, (0.000, 0.000, 0.000)),   # 原子名称"CA",电荷状态0,坐标(0.000, 0.000, 0.000)
        ("C", 0, (1.527, -0.000, 0.000)),   # 原子名称"C",电荷状态0,坐标(1.527, -0.000, 0.000)
        ("CB", 0, (-0.546, -0.611, -1.293)),# 原子名称"CB",电荷状态0,坐标(-0.546, -0.611, -1.293)
        ("O", 3, (0.621, 1.066, 0.000)),    # 原子名称"O",电荷状态3,坐标(0.621, 1.066, 0.000)
        ("CG", 4, (0.382, 1.445, 0.0)),     # 原子名称"CG",电荷状态4,坐标(0.382, 1.445, 0.0)
        ("CD", 5, (0.477, 1.424, 0.0)),     # 原子名称"CD",电荷状态5,坐标(0.477, 1.424, 0.0)
        # ('CD', 5, (0.427, 1.440, 0.0)),   # 注释
    "TYR": [  # TYR 残基的描述开始
        ("N", 0, (-0.522, 1.362, 0.000)),  # 残基的氮原子,索引为 0,坐标为 (-0.522, 1.362, 0.000)
        ("CA", 0, (0.000, 0.000, 0.000)),  # 残基的α-碳原子,索引为 0,坐标为 (0.000, 0.000, 0.000)
        ("C", 0, (1.524, -0.000, -0.000)),  # 残基的碳原子,索引为 0,坐标为 (1.524, -0.000, -0.000)
        ("CB", 0, (-0.522, -0.776, -1.213)),  # 残基的侧链碳原子,索引为 0,坐标为 (-0.522, -0.776, -1.213)
        ("O", 3, (0.627, 1.062, -0.000)),  # 残基的氧原子,索引为 3,坐标为 (0.627, 1.062, -0.000)
        ("CG", 4, (0.607, 1.382, -0.000)),  # 残基的芳香环的碳原子,索引为 4,坐标为 (0.607, 1.382, -0.000)
        ("CD1", 5, (0.716, 1.195, -0.000)),  # 残基的芳香环的第一个碳原子,索引为 5,坐标为 (0.716, 1.195, -0.000)
        ("CD2", 5, (0.713, -1.194, -0.001)),  # 残基的芳香环的第二个碳原子,索引为 5,坐标为 (0.713, -1.194, -0.001)
        ("CE1", 5, (2.107, 1.200, -0.002)),  # 残基的芳香环的第一个环氧基碳原子,索引为 5,坐标为 (2.107, 1.200, -0.002)
        ("CE2", 5, (2.104, -1.201, -0.003)),  # 残基的芳香环的第二个环氧基碳原子,索引为 5,坐标为 (2.104, -1.201, -0.003)
        ("OH", 5, (4.168, -0.002, -0.005)),  # 残基的酚羟基氧原子,索引为 5,坐标为 (4.168, -0.002, -0.005)
        ("CZ", 5, (2.791, -0.001, -0.003)),  # 残基的芳香环的环氧基碳原子,索引为 5,坐标为 (2.791, -0.001, -0.003)
    ],  # TYR 残基的描述结束

    "VAL": [  # VAL 残基的描述开始
        ("N", 0, (-0.494, 1.373, -0.000)),  # 残基的氮原子,索引为 0,坐标为 (-0.494, 1.373, -0.000)
        ("CA", 0, (0.000, 0.000, 0.000)),  # 残基的α-碳原子,索引为 0,坐标为 (0.000, 0.000, 0.000)
        ("C", 0, (1.527, -0.000, -0.000)),  # 残基的碳原子,索引为 0,坐标为 (1.527, -0.000, -0.000)
        ("CB", 0, (-0.533, -0.795, -1.213)),  # 残基的侧链碳原子,索引为 0,坐标为 (-0.533, -0.795, -1.213)
        ("O", 3, (0.627, 1.062, -0.000)),  # 残基的氧原子,索引为 3,坐标为 (0.627, 1.062, -0.000)
        ("CG1", 4, (0.540, 1.429, -0.000)),  # 残基的第一个侧链碳原子,索引为 4,坐标为 (0.540, 1.429, -0.000)
        ("CG2", 4, (0.533, -0.776, 1.203)),  # 残基的第二个侧链碳原子,索引为 4,坐标为 (0.533, -0.776, 1.203)
    ],  # VAL 残基的描述结束
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
residue_atoms: Dict[str, List[str]] = {
    "ALA": ["C", "CA", "CB", "N", "O"],
    "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
    "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
    "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
    "CYS": ["C", "CA", "CB", "N", "O", "SG"],
    "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
    "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
    "GLY": ["C", "CA", "N", "O"],
    "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
    "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
    "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
    "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
    "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
    "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
    "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
    "SER": ["C", "CA", "CB", "N", "O", "OG"],
    "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
    "TRP": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE2", "CE3", "CZ2", "CZ3", "CH2", "N", "NE1", "O"],
    "TYR": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O", "OH"],
    "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
}

# Naming swaps for ambiguous atom names.
# Due to symmetries in the amino acids the naming of atoms is ambiguous in
# 4 of the 20 amino acids.
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
# the 'ambiguous' atoms and their neighbours)
# TODO: ^ interpret this
residue_atom_renaming_swaps: Dict[str, Dict[str, str]] = {
    "ASP": {"OD1": "OD2"},
    "GLU": {"OE1": "OE2"},
    "PHE": {"CD1": "CD2", "CE1": "CE2"},
    "TYR": {"CD1": "CD2", "CE1": "CE2"},
}

# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
van_der_waals_radius: Dict[str, float] = {
    "C": 1.7,
    "N": 1.55,
    "O": 1.52,
    "S": 1.8,
}

Bond = collections.namedtuple("Bond", ["atom1_name", "atom2_name", "length", "stddev"])
BondAngle = collections.namedtuple(
    "BondAngle",
    ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"],
)


def map_structure_with_atom_order(in_list: list, first_call: bool = True) -> list:
    # Maps strings in a nested list structure to their corresponding index in atom_order
    if first_call:
        in_list = copy.deepcopy(in_list)
    for i in range(len(in_list)):
        if isinstance(in_list[i], list):
            in_list[i] = map_structure_with_atom_order(in_list[i], first_call=False)
        elif isinstance(in_list[i], str):
            in_list[i] = atom_order[in_list[i]]
        else:
            raise ValueError("Unexpected type when mapping nested lists!")
    return in_list


@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> (
    Tuple[
        Mapping[str, List[Bond]],
        Mapping[str, List[Bond]],
        Mapping[str, List[BondAngle]],
    ]
):
    """Load stereo_chemical_props.txt into a structured form.

    Returns a tuple of three mappings: residue name -> list of Bond objects (bond lengths),
    residue name -> list of virtual Bond objects, and residue name -> list of BondAngle objects.
    """
    # Read stereo_chemical_props.txt from the package resources
    stereo_chemical_props = resources.read_text("openfold.resources", "stereo_chemical_props.txt")

    # Iterate over the file contents line by line
    lines_iter = iter(stereo_chemical_props.splitlines())

    # Collect the bond-length entries, keyed by residue name
    residue_bonds: Dict[str, List[Bond]] = {}
    next(lines_iter)  # Skip the header line.

    # Parse the bond-length section
    for line in lines_iter:
        if line.strip() == "-":
            break
        bond, resname, bond_length, stddev = line.split()
        atom1, atom2 = bond.split("-")
        # Create the list for this residue on first use
        if resname not in residue_bonds:
            residue_bonds[resname] = []
        # Append a Bond record for this residue
        residue_bonds[resname].append(Bond(atom1, atom2, float(bond_length), float(stddev)))
    residue_bonds["UNK"] = []  # Default entry for unknown residues.

    # Collect the bond-angle entries, keyed by residue name
    residue_bond_angles: Dict[str, List[BondAngle]] = {}
    next(lines_iter)  # Skip the empty line.
    next(lines_iter)  # Skip the header line.

    # Parse the bond-angle section (angles are converted from degrees to radians)
    for line in lines_iter:
        if line.strip() == "-":
            break
        bond, resname, angle_degree, stddev_degree = line.split()
        atom1, atom2, atom3 = bond.split("-")
        # Create the list for this residue on first use
        if resname not in residue_bond_angles:
            residue_bond_angles[resname] = []
        # Append a BondAngle record for this residue
        residue_bond_angles[resname].append(
            BondAngle(
                atom1,
                atom2,
                atom3,
                float(angle_degree) / 180.0 * np.pi,
                float(stddev_degree) / 180.0 * np.pi,
            )
        )
    residue_bond_angles["UNK"] = []  # Default entry for unknown residues.

    def make_bond_key(atom1_name: str, atom2_name: str) -> str:
        """Unique key used to look up a bond length."""
        return "-".join(sorted([atom1_name, atom2_name]))

    # Collect the virtual bonds (atom1-atom3 distances implied by each bond angle), keyed by residue name
    residue_virtual_bonds: Dict[str, List[Bond]] = {}
    for resname, bond_angles in residue_bond_angles.items():
        # 为键值对(resname, bond_angles)中的每个残基名称(resname)和键角(bond_angles)执行以下操作

        # 创建用于快速查找键长的字典。
        bond_cache: Dict[str, Bond] = {}
        # 遍历给定残基(resname)对应的键的列表,将键对(atom1_name, atom2_name)和键对象存入字典bond_cache中。
        for b in residue_bonds[resname]:
            bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b

        # 将残基的虚拟键列表初始化为空列表。
        residue_virtual_bonds[resname] = []

        # 遍历键角(bond_angles)中的每个键角(ba)。
        for ba in bond_angles:
            # 从bond_cache字典中获取键角的第一个键对应的键对象(bond1)和第二个键对应的键对象(bond2)。
            bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
            bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]

            # 使用余弦定理计算atom1和atom3之间的距离长度。
            gamma = ba.angle_rad
            length = np.sqrt(bond1.length**2 + bond2.length**2 - 2 * bond1.length * bond2.length * np.cos(gamma))

            # Propagate the uncertainty through the law of cosines, assuming the individual errors are uncorrelated.
            dl_outer = 0.5 / length
            dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
            dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
            dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
            stddev = np.sqrt(
                (dl_dgamma * ba.stddev) ** 2 + (dl_db1 * bond1.stddev) ** 2 + (dl_db2 * bond2.stddev) ** 2
            )

            # 将计算得到的虚拟键信息添加到residue_virtual_bonds[resname]列表中。
            residue_virtual_bonds[resname].append(Bond(ba.atom1_name, ba.atom3name, length, stddev))

    # 返回包含三个项目的元组:原始键(bonds)、虚拟键(virtual_bonds)和键角(bond_angles)。
    return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
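
# Worked example added for this walkthrough (not part of the original module; `np` is imported at
# the top of the file). It applies the same law-of-cosines formula used above to approximate
# textbook values for the N-CA (1.459 A) and CA-C (1.525 A) bonds with a ~111 degree N-CA-C angle;
# the implied virtual N-C distance comes out to roughly 2.46 A.
_example_b1, _example_b2 = 1.459, 1.525
_example_gamma = 111.0 / 180.0 * np.pi
_example_virtual_length = np.sqrt(
    _example_b1**2 + _example_b2**2 - 2 * _example_b1 * _example_b2 * np.cos(_example_gamma)
)
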
# 一对元组,分别表示普通键合和脯氨酸键合的残基间距长度(以埃为单位)。
between_res_bond_length_c_n: Tuple[float, float] = (1.329, 1.341)
# 一对元组,分别表示普通键合和脯氨酸键合的残基间距长度的标准偏差(以埃为单位)。
between_res_bond_length_stddev_c_n: Tuple[float, float] = (0.014, 0.016)

# 一对元组,分别表示残基间的余弦角度。
between_res_cos_angles_c_n_ca: Tuple[float, float] = (-0.5203, 0.0353)  # degrees: 121.352 +- 2.315
between_res_cos_angles_ca_c_n: Tuple[float, float] = (-0.4473, 0.0311)  # degrees: 116.568 +- 1.995

# Fixed ordering of the 37 heavy-atom types; it defines the layout used whenever per-residue atom
# data must be stored in a fixed-size container (e.g. a numpy array with 37 atom slots).
atom_types: List[str] = [
    "N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG", "CD",
    "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD", "CE", "CE1", "CE2", "CE3",
    "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2", "OH", "CZ", "CZ2",
    "CZ3", "NZ", "OXT",
]
# 字典,将原子类型映射到它们在列表中的索引位置。
atom_order: Dict[str, int] = {atom_type: i for i, atom_type in enumerate(atom_types)}
# 原子类型的数量,这里是固定的值 37。
atom_type_num = len(atom_types)  # := 37.
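
# Quick sanity sketch added for this walkthrough (not in the original file): "N" and "CA" are the
# first two entries of atom_types, and the table contains exactly 37 heavy-atom types.
assert atom_order["N"] == 0 and atom_order["CA"] == 1 and atom_type_num == 37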

# 字典,将每种氨基酸的简称映射到一个包含14个元素的原子名列表,用于紧凑的原子编码。
restype_name_to_atom14_names: Dict[str, List[str]] = {
    "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
    "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2", "", "", ""],
    "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""],
    "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""],
    "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
    "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", ""],
    "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", ""],
    "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
    "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2", "", "", "", ""],
    "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""],
    "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""],
    "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""],
    "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""],
    "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "", "", ""],
    "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
    "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
    "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""],
    "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"],
    "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH", "", ""],
    # 定义"VAL"键对应的列表,包含了特定的原子名称
    "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""],
    # 定义"UNK"键对应的列表,包含了空字符串,用于未知类型的占位符
    "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace

# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes: List[str] = [
    "A",   # Alanine
    "R",   # Arginine
    "N",   # Asparagine
    "D",   # Aspartic acid
    "C",   # Cysteine
    "Q",   # Glutamine
    "E",   # Glutamic acid
    "G",   # Glycine
    "H",   # Histidine
    "I",   # Isoleucine
    "L",   # Leucine
    "K",   # Lysine
    "M",   # Methionine
    "F",   # Phenylalanine
    "P",   # Proline
    "S",   # Serine
    "T",   # Threonine
    "W",   # Tryptophan
    "Y",   # Tyrosine
    "V",   # Valine
]

restype_order: Dict[str, int] = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes)  # Number of standard amino acid types (:= 20).
unk_restype_index = restype_num  # Catch-all index for unknown amino acid types.

restypes_with_x: List[str] = restypes + ["X"]  # Include 'X' for unknown amino acids.
restype_order_with_x: Dict[str, int] = {restype: i for i, restype in enumerate(restypes_with_x)}


def sequence_to_onehot(sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False) -> np.ndarray:
    """Maps the given sequence into a one-hot encoded matrix.

    Args:
      sequence: An amino acid sequence.
      mapping: A dictionary mapping amino acids to integers.
      map_unknown_to_x: If True, any amino acid that is not in the mapping will be
        mapped to the unknown amino acid 'X'. If the mapping doesn't contain amino acid 'X', an error will be thrown.
        If False, any amino acid not in the mapping will throw an error.

    Returns:
      A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of the sequence.

    Raises:
      ValueError: If the mapping doesn't contain values from 0 to
        num_unique_aas - 1 without any gaps.
    """
    num_entries = max(mapping.values()) + 1

    if sorted(set(mapping.values())) != list(range(num_entries)):
        raise ValueError(
            "The mapping must have values from 0 to num_unique_aas-1 without any gaps. Got: %s"
            % sorted(mapping.values())
        )

    one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)

    for aa_index, aa_type in enumerate(sequence):
        if map_unknown_to_x:
            if aa_type.isalpha() and aa_type.isupper():
                aa_id = mapping.get(aa_type, mapping["X"])  # Map unknown AA to 'X' if allowed.
            else:
                raise ValueError(f"Invalid character in the sequence: {aa_type}")
        else:
            aa_id = mapping[aa_type]  # Map AA based on the provided mapping.
        one_hot_arr[aa_index, aa_id] = 1

    return one_hot_arr
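
# Usage sketch added for this walkthrough (not part of the original module): encode a short
# sequence with the 21-letter mapping that includes "X". The letter "B" is not a standard residue,
# so with map_unknown_to_x=True it falls back to the "X" column.
_example_onehot = sequence_to_onehot("ACDB", restype_order_with_x, map_unknown_to_x=True)
assert _example_onehot.shape == (4, 21) and _example_onehot[3, restype_order_with_x["X"]] == 1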


restype_1to3: Dict[str, str] = {
    "A": "ALA",
    "R": "ARG",
    "N": "ASN",
    "D": "ASP",
    "C": "CYS",
    "Q": "GLN",
    "E": "GLU",
    "G": "GLY",
    "H": "HIS",
    "I": "ILE",
    "L": "LEU",
    "K": "LYS",
    "M": "MET",
    "F": "PHE",
    "P": "PRO",
    "S": "SER",
    "T": "THR",
    "W": "TRP",
    "Y": "TYR",
    "V": "VAL",
}

# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
# 将 restype_1to3 字典中的键值对反转,创建一个新的字典 restype_3to1,用于将三字母缩写映射回单字母缩写。
restype_3to1: Dict[str, str] = {v: k for k, v in restype_1to3.items()}

# 为所有未知的残基定义一个默认的 restype 名称。
unk_restype = "UNK"

# 根据 restypes 中的单字母缩写列表,加上 unk_restype,创建一个包含所有残基名称的列表 resnames。
resnames: List[str] = [restype_1to3[r] for r in restypes] + [unk_restype]

# 创建一个将残基名称映射到索引的字典 resname_to_idx。
resname_to_idx: Dict[str, int] = {resname: i for i, resname in enumerate(resnames)}

# HHBLITS_AA_TO_ID 和 ID_TO_HHBLITS_AA 是根据 hhblits 约定定义的字母到整数编码的映射。
# HHBLITS_AA_TO_ID 将每个氨基酸字母映射到一个整数 ID。
HHBLITS_AA_TO_ID: Dict[str, int] = {
    "A": 0,
    "B": 2,  # B 被映射到 D 的 ID
    "C": 1,  # C 和 U 共用相同的 ID
    "D": 2,  # D 和 B 共用相同的 ID
    "E": 3,  # E 和 Z 共用相同的 ID
    "F": 4,
    "G": 5,
    "H": 6,
    "I": 7,
    "J": 20,  # J 被映射到 X 的 ID
    "K": 8,
    "L": 9,
    "M": 10,
    "N": 11,
    "O": 20,  # O 被映射到 X 的 ID
    "P": 12,
    "Q": 13,
    "R": 14,
    "S": 15,
    "T": 16,
    "U": 1,   # U 和 C 共用相同的 ID
    "V": 17,
    "W": 18,
    "X": 20,  # X 表示任何氨基酸
    "Y": 19,
    "Z": 3,   # Z 和 E 共用相同的 ID
    "-": 21,  # - 表示在序列比对中缺失的氨基酸
}

# ID_TO_HHBLITS_AA 是 HHBLITS_AA_TO_ID 的部分反转,将整数 ID 映射回对应的氨基酸字母。
ID_TO_HHBLITS_AA: Dict[int, str] = {
    0: "A",
    1: "C",    # 也对应 U
    2: "D",    # 也对应 B
    3: "E",    # 也对应 Z
    4: "F",
    5: "G",
    6: "H",
    7: "I",
    8: "K",
    9: "L",
    10: "M",
    11: "N",
    12: "P",
    13: "Q",
    14: "R",
    15: "S",
    16: "T",
    17: "V",
    18: "W",
    19: "Y",
    20: "X",   # 包括 J 和 O
    21: "-",   # 表示缺失的氨基酸
}

# 将 restypes 和 ["X", "-"] 合并,创建一个包含所有氨基酸类型的列表 restypes_with_x_and_gap。
restypes_with_x_and_gap: List[str] = restypes + ["X", "-"]

# For each HHblits integer id i, look up its letter via ID_TO_HHBLITS_AA and take that letter's
# index in restypes_with_x_and_gap; the resulting tuple MAP_HHBLITS_AATYPE_TO_OUR_AATYPE remaps
# HHblits-encoded amino acid types to this module's encoding.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE: Tuple[int, ...] = tuple(
    restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) for i in range(len(restypes_with_x_and_gap))
)
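
# Illustrative check added for this walkthrough (not part of the original module): HHblits id 2 is
# the letter "D", which sits at index 3 of restypes_with_x_and_gap, so the remapping sends 2 -> 3.
assert MAP_HHBLITS_AATYPE_TO_OUR_AATYPE[2] == restypes_with_x_and_gap.index("D") == 3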


def _make_standard_atom_mask() -> np.ndarray:
    """Returns [num_res_types, num_atom_types] mask array."""
    # 创建一个二维数组 mask,维度为 [restype_num + 1, atom_type_num],初始值都为 0。
    # +1 是为了包括未知类型的残基 (all 0s)。
    mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
    
    # 遍历 restypes 中的每个单字母缩写 restype_letter。
    for restype, restype_letter in enumerate(restypes):
        # 获取 restype_letter 对应的三字母残基名称。
        restype_name = restype_1to3[restype_letter]
        # 获取该残基的所有原子名称列表。
        atom_names = residue_atoms[restype_name]
        # 遍历该残基的每个原子名称,将对应的原子类型置为 1。
        for atom_name in atom_names:
            atom_type = atom_order[atom_name]
            mask[restype, atom_type] = 1
    
    return mask


# 调用 _make_standard_atom_mask 函数,生成标准原子掩码数组,并将其赋值给 STANDARD_ATOM_MASK 变量。
STANDARD_ATOM_MASK = _make_standard_atom_mask()
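
# Sanity sketch added for this walkthrough (not part of the original module): glycine has only the
# four backbone heavy atoms (N, CA, C, O), so its row of the standard atom mask sums to 4.
assert STANDARD_ATOM_MASK.shape == (restype_num + 1, atom_type_num)
assert STANDARD_ATOM_MASK[restype_order["G"]].sum() == 4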


# 定义一个函数 chi_angle_atom,用于生成每个残基中每个 chi 角的轴定义原子的独热表示。
def chi_angle_atom(atom_index: int) -> np.ndarray:
    """Define chi-angle rigid groups via one-hot representations."""
    # 创建一个空字典 chi_angles_index 和一个空列表 one_hots。
    chi_angles_index = {}
    one_hots = []

    # 遍历 chi_angles_atoms 中的键值对 (k, v)。
    for k, v in chi_angles_atoms.items():
        # 对于每个 v 中的序列 s,将其第 atom_index 个原子类型的索引添加到 indices 列表中。
        indices = [atom_types.index(s[atom_index]) for s in v]
        # 如果 indices 的长度不足 4,则用 -1 填充到长度为 4。
        indices.extend([-1] * (4 - len(indices)))
        # 将 indices 列表赋值给 chi_angles_index 字典的键 k。
        chi_angles_index[k] = indices

    # 遍历 restypes 中的每个残基 r。
    for r in restypes:
        # 获取 r 对应的三字母残基名称 res3。
        res3 = restype_1to3[r]
        # 根据 chi_angles_index[res3] 中的每个索引,生成对应的独热表示,并添加到 one_hots 列表中。
        one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
        one_hots.append(one_hot)
    # 将一个全零的数组添加到 `one_hots` 列表中,数组形状为 [4, atom_type_num],用于表示残基 `X`。
    one_hots.append(np.zeros([4, atom_type_num]))
    
    # 将 `one_hots` 列表中的数组堆叠成一个新的数组 `one_hot`,沿着第一个轴堆叠。
    one_hot = np.stack(one_hots, axis=0)
    
    # Transpose `one_hot`, swapping the last two axes; the result has shape [num_restypes (21), atom_type_num, 4].
    one_hot = np.transpose(one_hot, [0, 2, 1])
    
    # 返回经过处理的 one_hot 数组作为函数的输出结果。
    return one_hot
# One-hot encodings for the atom at position 1 (and 2) within each chi angle's four-atom definition.
chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)
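
# Shape sketch added for this walkthrough (not part of the original module): both tensors cover the
# 20 standard residue types plus the unknown type, all 37 atom types, and up to four chi angles.
assert chi_atom_1_one_hot.shape == (21, 37, 4)
assert chi_atom_2_one_hot.shape == (21, 37, 4)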

# 生成一个与 chi_angles_atoms 类似的数组,但使用索引而不是名称
chi_angles_atom_indices_list: List[List[List[str]]] = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
# 使用函数 map_structure_with_atom_order 处理 chi_angles_atom_indices_list 中的结构,并返回结果
chi_angles_atom_indices_ours: list = map_structure_with_atom_order(chi_angles_atom_indices_list)
# 创建一个 numpy 数组,存储每个氨基酸的角度原子索引,用于计算
chi_angles_atom_indices = np.array(
    [chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) for chi_atoms in chi_angles_atom_indices_list]
)

# 从 (氨基酸名, 原子名) 对映射到原子的 chi 组索引及其在该组内的原子索引
chi_groups_for_atom: Dict[Tuple[str, str], List[Tuple[int, int]]] = collections.defaultdict(list)
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
    for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
        for atom_i, atom in enumerate(chi_group):
            chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
chi_groups_for_atom = dict(chi_groups_for_atom)

def _make_rigid_transformation_4x4(ex: np.ndarray, ey: np.ndarray, translation: np.ndarray) -> np.ndarray:
    """Create a rigid 4x4 transformation matrix from two axes and transl."""
    # 将 ex 向量归一化
    ex_normalized = ex / np.linalg.norm(ex)

    # 使 ey 向量垂直于 ex 向量
    ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
    ey_normalized /= np.linalg.norm(ey_normalized)

    # 计算 ez 向量作为 ex 和 ey 向量的叉乘
    eznorm = np.cross(ex_normalized, ey_normalized)
    # 创建 4x4 的刚体变换矩阵
    m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
    m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
    return m
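
# Minimal sketch added for this walkthrough (not part of the original module): with orthogonal unit
# axes the rotation block of the returned 4x4 frame is the identity and its last column holds the
# translation.
_example_frame = _make_rigid_transformation_4x4(
    ex=np.array([1.0, 0.0, 0.0]),
    ey=np.array([0.0, 1.0, 0.0]),
    translation=np.array([1.0, 2.0, 3.0]),
)
assert np.allclose(_example_frame[:3, :3], np.eye(3)) and np.allclose(_example_frame[:3, 3], [1.0, 2.0, 3.0])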

# 创建数组,存储 (氨基酸类型, 原子类型) 到刚体组索引的映射
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
# 创建数组,存储 (氨基酸类型, 原子类型) 到刚体组位置的掩码
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
# 创建数组,存储 (氨基酸类型, 原子类型) 到刚体组位置的坐标
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
# 创建数组,存储 (氨基酸类型, 原子类型) 到刚体组索引的映射(14 个原子的版本)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
# 创建数组,存储 (氨基酸类型, 原子类型) 到刚体组位置的掩码(14 个原子的版本)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
# 创建数组,存储 (氨基酸类型, 原子类型) 到刚体组位置的坐标(14 个原子的版本)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
# 创建数组,存储 (氨基酸类型) 到默认刚体组坐标系的转换矩阵
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)

def _make_rigid_group_constants() -> None:
    """Fill the arrays above."""
    # 遍历每个残基类型及其对应的字母表示
    for restype, restype_letter in enumerate(restypes):
        # 根据字母找到对应的残基名
        resname = restype_1to3[restype_letter]
        # 遍历该残基名对应的所有原子名称、原子组索引和原子位置信息
        for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]:
            # 根据原子名找到对应的原子类型
            atomtype = atom_order[atomname]

            # 将残基类型和原子类型映射到刚性组索引
            restype_atom37_to_rigid_group[restype, atomtype] = group_idx
            # 设置残基类型和原子类型的掩码为1,表示存在关联
            restype_atom37_mask[restype, atomtype] = 1
            # 设置残基类型和原子类型的刚性组位置信息
            restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position

            # 在残基名到14原子名列表中找到当前原子名的索引
            atom14idx = restype_name_to_atom14_names[resname].index(atomname)
            # 将残基类型和14原子名索引映射到刚性组索引
            restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
            # 设置残基类型和14原子名索引的掩码为1
            restype_atom14_mask[restype, atom14idx] = 1
            # 设置残基类型和14原子名索引的刚性组位置信息
            restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position
    # 遍历氨基酸类型及其对应的缩写
    for restype, restype_letter in enumerate(restypes):
        # 根据缩写获取氨基酸的全名
        resname = restype_1to3[restype_letter]
        # 从预定义的刚性群体原子位置中创建字典,键为原子名,值为位置数组
        atom_positions: Dict[str, np.ndarray] = {
            name: np.array(pos) for name, _, pos in rigid_group_atom_positions[resname]
        }

        # 将刚性群体的默认骨架到骨架的变换矩阵设为单位矩阵(身份变换)
        restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)

        # 将预Ω框架到骨架的变换矩阵设为单位矩阵(虚拟的身份变换)
        restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)

        # 将φ框架到骨架的变换矩阵计算为刚性变换矩阵
        mat = _make_rigid_transformation_4x4(
            ex=atom_positions["N"] - atom_positions["CA"],  # X轴方向的向量
            ey=np.array([1.0, 0.0, 0.0]),  # Y轴方向的向量
            translation=atom_positions["N"],  # 平移向量
        )
        restype_rigid_group_default_frame[restype, 2, :, :] = mat

        # 将ψ框架到骨架的变换矩阵计算为刚性变换矩阵
        mat = _make_rigid_transformation_4x4(
            ex=atom_positions["C"] - atom_positions["CA"],  # X轴方向的向量
            ey=atom_positions["CA"] - atom_positions["N"],  # Y轴方向的向量
            translation=atom_positions["C"],  # 平移向量
        )
        restype_rigid_group_default_frame[restype, 3, :, :] = mat

        # 如果存在χ1角度,则计算χ1框架到骨架的变换矩阵
        if chi_angles_mask[restype][0]:
            base_atom_names = chi_angles_atoms[resname][0]  # χ1角度的基础原子名列表
            base_atom_positions = [atom_positions[name] for name in base_atom_names]  # 基础原子的位置列表
            mat = _make_rigid_transformation_4x4(
                ex=base_atom_positions[2] - base_atom_positions[1],  # X轴方向的向量
                ey=base_atom_positions[0] - base_atom_positions[1],  # Y轴方向的向量
                translation=base_atom_positions[2],  # 平移向量
            )
            restype_rigid_group_default_frame[restype, 4, :, :] = mat

        # 依次计算χ2到χ4框架到前一框架的刚性变换矩阵
        # 由于所有下一个框架的旋转轴都从前一个框架的(0,0,0)开始,因此这里使用了固定的旋转轴
        for chi_idx in range(1, 4):
            if chi_angles_mask[restype][chi_idx]:
                axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]  # 当前角度的轴端原子名
                axis_end_atom_position = atom_positions[axis_end_atom_name]  # 轴端原子的位置
                mat = _make_rigid_transformation_4x4(
                    ex=axis_end_atom_position,  # X轴方向的向量
                    ey=np.array([-1.0, 0.0, 0.0]),  # Y轴方向的向量
                    translation=axis_end_atom_position,  # 平移向量
                )
                restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat
# 调用函数以初始化刚性群组的常量
_make_rigid_group_constants()

# 定义函数,计算原子间的上下界,以评估违规情况
def make_atom14_dists_bounds(
    overlap_tolerance: float = 1.5,  # 碰撞容忍度
    bond_length_tolerance_factor: int = 15,  # 键长容忍因子
) -> Dict[str, np.ndarray]:
    """compute upper and lower bounds for bonds to assess violations."""
    # 初始化数组以存储不同残基类型、原子间距离的下界、上界和标准差
    restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
    restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
    restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
    
    # 载入化学属性,包括原子间的键和虚拟键
    residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
    
    # 遍历每种残基类型
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        atom_list = restype_name_to_atom14_names[resname]

        # 创建碰撞的下界和上界
        for atom1_idx, atom1_name in enumerate(atom_list):
            if not atom1_name:
                continue
            atom1_radius = van_der_waals_radius[atom1_name[0]]
            for atom2_idx, atom2_name in enumerate(atom_list):
                if (not atom2_name) or atom1_idx == atom2_idx:
                    continue
                atom2_radius = van_der_waals_radius[atom2_name[0]]
                lower = atom1_radius + atom2_radius - overlap_tolerance
                upper = 1e10
                # 设置原子间距离的下界和上界
                restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
                restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
                restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
                restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper

        # 覆盖键和角度的下界和上界
        for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
            atom1_idx = atom_list.index(b.atom1_name)
            atom2_idx = atom_list.index(b.atom2_name)
            lower = b.length - bond_length_tolerance_factor * b.stddev
            upper = b.length + bond_length_tolerance_factor * b.stddev
            # 设置原子间距离的下界和上界,以及标准差
            restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
            restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
            restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
            restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
            restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
            restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
    
    # 返回包含键的下界、上界和标准差的字典
    return {
        "lower_bound": restype_atom14_bond_lower_bound,  # 形状为 (21,14,14)
        "upper_bound": restype_atom14_bond_upper_bound,  # 形状为 (21,14,14)
        "stddev": restype_atom14_bond_stddev,  # 形状为 (21,14,14)
    }

# 初始化数组以存储不同残基类型的模糊原子信息
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
# 创建索引数组,用于记录不同残基类型的原子序号
restype_atom14_ambiguous_atoms_swap_idx: np.ndarray = np.tile(np.arange(14, dtype=int), (21, 1))

# 定义函数,生成原子的模糊特征
def _make_atom14_ambiguity_feats() -> None:
    # 遍历 residue_atom_renaming_swaps 字典的每一项,其中 res 是键,pairs 是对应的值(另一个字典)
    for res, pairs in residue_atom_renaming_swaps.items():
        # 使用 restype_3to1 字典将三字母氨基酸代码 res 转换为索引 res_idx
        res_idx = restype_order[restype_3to1[res]]
        # 遍历 pairs 字典中的每一对 atom1 和 atom2
        for atom1, atom2 in pairs.items():
            # 在 restype_name_to_atom14_names[res] 列表中找到 atom1 的索引 atom1_idx
            atom1_idx = restype_name_to_atom14_names[res].index(atom1)
            # 在 restype_name_to_atom14_names[res] 列表中找到 atom2 的索引 atom2_idx
            atom2_idx = restype_name_to_atom14_names[res].index(atom2)
            # 将 restype_atom14_ambiguous_atoms 中 (res_idx, atom1_idx) 处置为 1,表示 atom1 是模糊的
            restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
            # 将 restype_atom14_ambiguous_atoms 中 (res_idx, atom2_idx) 处置为 1,表示 atom2 是模糊的
            restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
            # 记录 atom1_idx 处应该交换的索引是 atom2_idx
            restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom1_idx] = atom2_idx
            # 记录 atom2_idx 处应该交换的索引是 atom1_idx
            restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom2_idx] = atom1_idx
# 调用名为 `_make_atom14_ambiguity_feats` 的函数,执行其内部逻辑
_make_atom14_ambiguity_feats()
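
# Sketch added for this walkthrough (not part of the original module): for ASP the carboxylate
# oxygens OD1/OD2 are symmetry-ambiguous, so their atom14 swap indices point at each other.
_example_asp = restype_order["D"]
_example_od1 = restype_name_to_atom14_names["ASP"].index("OD1")
_example_od2 = restype_name_to_atom14_names["ASP"].index("OD2")
assert restype_atom14_ambiguous_atoms_swap_idx[_example_asp, _example_od1] == _example_od2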

# Convert an integer-encoded amino acid sequence `aatype` into its one-letter string form.
def aatype_to_str_sequence(aatype: Sequence[int]) -> str:
    # Index the restypes_with_x list (index -> one-letter code) with each integer and join the results.
    return "".join([restypes_with_x[aatype[i]] for i in range(len(aatype))])

.\models\esm\openfold_utils\rigid_utils.py

# Import the required modules and libraries
from __future__ import annotations  # allows `Rotation` / `Rigid` to appear in annotations before the classes are defined

from functools import lru_cache  # lru_cache decorator, used to memoize the identity-tensor helpers
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple  # type hints

import numpy as np  # numpy, used for the precomputed quaternion coefficient tables
import torch  # PyTorch, used for all tensor operations


def rot_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    执行两个旋转矩阵张量的矩阵乘法。手动编写以避免 AMP 下转换。

    Args:
        a: [*, 3, 3] 左乘数
        b: [*, 3, 3] 右乘数
    Returns:
        乘积 ab
    """

    def row_mul(i: int) -> torch.Tensor:
        # 计算矩阵乘法的每行结果
        return torch.stack(
            [
                a[..., i, 0] * b[..., 0, 0] + a[..., i, 1] * b[..., 1, 0] + a[..., i, 2] * b[..., 2, 0],
                a[..., i, 0] * b[..., 0, 1] + a[..., i, 1] * b[..., 1, 1] + a[..., i, 2] * b[..., 2, 1],
                a[..., i, 0] * b[..., 0, 2] + a[..., i, 1] * b[..., 1, 2] + a[..., i, 2] * b[..., 2, 2],
            ],
            dim=-1,
        )

    # 按行堆叠计算结果,形成最终的矩阵乘法结果
    return torch.stack(
        [
            row_mul(0),
            row_mul(1),
            row_mul(2),
        ],
        dim=-2,
    )


def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """
    对向量施加旋转。手动编写以避免 AMP 下转换。

    Args:
        r: [*, 3, 3] 旋转矩阵
        t: [*, 3] 坐标张量
    Returns:
        [*, 3] 旋转后的坐标
    """
    x, y, z = torch.unbind(t, dim=-1)
    # 计算旋转后的坐标
    return torch.stack(
        [
            r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
            r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
            r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
        ],
        dim=-1,
    )
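
# Quick sketch added for this walkthrough (not part of the original module): rotating the x unit
# vector by 90 degrees about the z axis with rot_vec_mul yields (approximately) the y unit vector.
_example_rz90 = torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
_example_rotated = rot_vec_mul(_example_rz90, torch.tensor([1.0, 0.0, 0.0]))
assert torch.allclose(_example_rotated, torch.tensor([0.0, 1.0, 0.0]))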


@lru_cache(maxsize=None)
def identity_rot_mats(
    batch_dims: Tuple[int, ...],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    返回指定批次维度下的单位旋转矩阵张量。

    Args:
        batch_dims: 批次维度的元组
        dtype: 张量数据类型,默认为 None
        device: 张量的设备,默认为 None
        requires_grad: 是否需要梯度,默认为 True
    Returns:
        torch.Tensor: 单位旋转矩阵张量
    """
    # 创建单位矩阵,并根据指定的批次维度进行形状调整和扩展
    rots = torch.eye(3, dtype=dtype, device=device, requires_grad=requires_grad)
    rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
    rots = rots.expand(*batch_dims, -1, -1)
    rots = rots.contiguous()

    return rots


@lru_cache(maxsize=None)
def identity_trans(
    batch_dims: Tuple[int, ...],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    返回指定批次维度下的单位平移张量。

    Args:
        batch_dims: 批次维度的元组
        dtype: 张量数据类型,默认为 None
        device: 张量的设备,默认为 None
        requires_grad: 是否需要梯度,默认为 True
    Returns:
        torch.Tensor: 单位平移张量
    """
    # 创建单位平移张量,并根据指定的批次维度进行形状调整和扩展
    trans = torch.zeros(3, dtype=dtype, device=device, requires_grad=requires_grad)
    trans = trans.view(*((1,) * len(batch_dims)), 3)
    trans = trans.expand(*batch_dims, -1)
    trans = trans.contiguous()

    return trans


# 使用LRU缓存装饰器包装的函数,生成批量维度的单位四元数
@lru_cache(maxsize=None)
def identity_quats(
    batch_dims: Tuple[int, ...],  # 批量维度,指定生成张量的形状
    dtype: Optional[torch.dtype] = None,  # 数据类型,默认为None
    device: Optional[torch.device] = None,  # 设备,默认为None
    requires_grad: bool = True,  # 是否需要梯度,默认为True
) -> torch.Tensor:  # 返回一个torch.Tensor对象
    quat = torch.zeros((*batch_dims, 4), dtype=dtype, device=device, requires_grad=requires_grad)

    # 使用torch.no_grad()上下文管理器,设置四元数的第一维度为1
    with torch.no_grad():
        quat[..., 0] = 1

    return quat


# 定义四元数元素列表
_quat_elements: List[str] = ["a", "b", "c", "d"]
# 生成四元数键的列表
_qtr_keys: List[str] = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
# 生成四元数键到索引的字典
_qtr_ind_dict: Dict[str, int] = {key: ind for ind, key in enumerate(_qtr_keys)}


# 定义一个将键值对列表转换为numpy数组的函数
def _to_mat(pairs: List[Tuple[str, int]]) -> np.ndarray:
    mat = np.zeros((4, 4))
    for key, value in pairs:
        ind = _qtr_ind_dict[key]
        mat[ind // 4][ind % 4] = value

    return mat


# 初始化一个形状为(4, 4, 3, 3)的四元数转换矩阵数组
_QTR_MAT = np.zeros((4, 4, 3, 3))
# 填充_QTR_MAT数组中的每个元素,使用_to_mat函数处理键值对列表
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])


def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    """
    Converts a quaternion to a rotation matrix.

    Args:
        quat: [*, 4] quaternions
    Returns:
        [*, 3, 3] rotation matrices
    """
    # [*, 4, 4] 扩展四元数的维度,用于矩阵乘法
    quat = quat[..., None] * quat[..., None, :]

    # [4, 4, 3, 3] 获取四元数转换矩阵
    mat = _get_quat("_QTR_MAT", dtype=quat.dtype, device=quat.device)

    # [*, 4, 4, 3, 3] 扩展_QTR_MAT数组的维度,用于矩阵乘法
    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
    quat = quat[..., None, None] * shaped_qtr_mat

    # [*, 3, 3] 沿着指定维度求和,得到旋转矩阵
    return torch.sum(quat, dim=(-3, -4))


def rot_to_quat(rot: torch.Tensor) -> torch.Tensor:
    if rot.shape[-2:] != (3, 3):
        raise ValueError("Input rotation is incorrectly shaped")

    [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = [[rot[..., i, j] for j in range(3)] for i in range(3)]

    k = [
        [
            xx + yy + zz,
            zy - yz,
            xz - zx,
            yx - xy,
        ],
        [
            zy - yz,
            xx - yy - zz,
            xy + yx,
            xz + zx,
        ],
        [
            xz - zx,
            xy + yx,
            yy - xx - zz,
            yz + zy,
        ],
        [
            yx - xy,
            xz + zx,
            yz + zy,
            zz - xx - yy,
        ],
    ]

    # 计算特征值和特征向量,返回最后一个特征向量作为四元数
    _, vectors = torch.linalg.eigh((1.0 / 3.0) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2))
    return vectors[..., -1]


# 初始化一个形状为(4, 4, 4)的四元数乘法表
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
# Coefficient tables of the quaternion Hamilton product: _QUAT_MULTIPLY[i, j, k] is the
# coefficient of quat1[i] * quat2[j] in output component k.
_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, -1]]

_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, -1, 0]]

_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, 1, 0, 0]]

_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], [1, 0, 0, 0]]

# 从_QUAT_MULTIPLY中选取索引为1到末尾的切片,这是用于纯向量四元数相乘的子矩阵
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]

# 初始化一个字典,包含缓存的四元数相关矩阵和矩阵切片
_CACHED_QUATS: Dict[str, np.ndarray] = {
    "_QTR_MAT": _QTR_MAT,
    "_QUAT_MULTIPLY": _QUAT_MULTIPLY,
    "_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC,
}

# 使用LRU缓存装饰器,缓存_get_quat函数的结果,以加快多次调用时的速度
@lru_cache(maxsize=None)
def _get_quat(quat_key: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)
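
# Round-trip sketch added for this walkthrough (not part of the original module): the identity
# quaternion (1, 0, 0, 0) converts to the 3x3 identity rotation matrix.
_example_q_identity = torch.tensor([1.0, 0.0, 0.0, 0.0])
assert torch.allclose(quat_to_rot(_example_q_identity), torch.eye(3))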

# 执行四元数相乘操作的函数,输入两个四元数,返回它们的乘积
def quat_multiply(quat1: torch.Tensor, quat2: torch.Tensor) -> torch.Tensor:
    """Multiply a quaternion by another quaternion."""
    # 获取用于四元数相乘的矩阵
    mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device)
    # 将矩阵形状调整为与输入四元数匹配
    reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
    # 执行张量运算,计算四元数乘积
    return torch.sum(reshaped_mat * quat1[..., :, None, None] * quat2[..., None, :, None], dim=(-3, -2))

# 执行四元数与纯向量四元数相乘操作的函数
def quat_multiply_by_vec(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """Multiply a quaternion by a pure-vector quaternion."""
    # 获取用于纯向量四元数相乘的子矩阵
    mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device)
    # 将矩阵形状调整为与输入四元数匹配
    reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
    # 执行张量运算,计算四元数与纯向量四元数的乘积
    return torch.sum(reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], dim=(-3, -2))

# 执行旋转矩阵转置操作的函数
def invert_rot_mat(rot_mat: torch.Tensor) -> torch.Tensor:
    return rot_mat.transpose(-1, -2)

# 执行四元数取逆操作的函数
def invert_quat(quat: torch.Tensor) -> torch.Tensor:
    # 创建四元数副本
    quat_prime = quat.clone()
    # 将四元数除了第一个元素外的其余元素取反
    quat_prime[..., 1:] *= -1
    # 计算四元数的逆
    inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True)
    return inv
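
# Sketch added for this walkthrough (not part of the original module): composing a unit quaternion
# with its inverse gives back the identity quaternion (1, 0, 0, 0) up to floating-point error.
_example_q = torch.tensor([0.5, 0.5, 0.5, 0.5])
_example_roundtrip = quat_multiply(_example_q, invert_quat(_example_q))
assert torch.allclose(_example_roundtrip, torch.tensor([1.0, 0.0, 0.0, 0.0]), atol=1e-6)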

# 表示一个3D旋转的类,支持旋转矩阵和四元数两种格式
class Rotation:
    """
    A 3D rotation. Depending on how the object is initialized, the rotation is represented by either a rotation matrix
    or a quaternion, though both formats are made available by helper functions. To simplify gradient computation, the
    underlying format of the rotation cannot be changed in-place. Like Rigid, the class is designed to mimic the
    behavior of a torch Tensor, almost as if each Rotation object were a tensor of rotations, in one format or another.
    """

    def __init__(
        self,
        rot_mats: Optional[torch.Tensor] = None,
        quats: Optional[torch.Tensor] = None,
        normalize_quats: bool = True,
    ):
        """
        Args:
            rot_mats:
                A [*, 3, 3] rotation matrix tensor. Mutually exclusive with quats
            quats:
                A [*, 4] quaternion. Mutually exclusive with rot_mats. If normalize_quats is not True, must be a unit
                quaternion
            normalize_quats:
                If quats is specified, whether to normalize quats
        """
        # 检查参数的合法性,确保只有一个输入参数被指定
        if (rot_mats is None and quats is None) or (rot_mats is not None and quats is not None):
            raise ValueError("Exactly one input argument must be specified")

        # 检查旋转矩阵和四元数的形状是否正确
        if (rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or (quats is not None and quats.shape[-1] != 4):
            raise ValueError("Incorrectly shaped rotation matrix or quaternion")

        # 强制使用全精度(float32)
        if quats is not None:
            quats = quats.to(dtype=torch.float32)
        if rot_mats is not None:
            rot_mats = rot_mats.to(dtype=torch.float32)

        # 如果指定了四元数且需要归一化,则进行归一化处理
        if quats is not None and normalize_quats:
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)

        # 将旋转矩阵和四元数存储在对象的私有属性中
        self._rot_mats = rot_mats
        self._quats = quats

    @staticmethod
    def identity(
        shape,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = True,
        fmt: str = "quat",
    ) -> Rotation:
        """
        Returns an identity Rotation.

        Args:
            shape:
                The "shape" of the resulting Rotation object. See documentation for the shape property
            dtype:
                The torch dtype for the rotation
            device:
                The torch device for the new rotation
            requires_grad:
                Whether the underlying tensors in the new rotation object should require gradient computation
            fmt:
                One of "quat" or "rot_mat". Determines the underlying format of the new object's rotation
        Returns:
            A new identity rotation
        """
        # 根据指定的 fmt 参数创建一个身份旋转对象
        if fmt == "rot_mat":
            rot_mats = identity_rot_mats(
                shape,
                dtype,
                device,
                requires_grad,
            )
            return Rotation(rot_mats=rot_mats, quats=None)
        elif fmt == "quat":
            quats = identity_quats(shape, dtype, device, requires_grad)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError(f"Invalid format: f{fmt}")

    # Magic methods
    def __getitem__(self, index: Any) -> Rotation:
        """
        Allows torch-style indexing over the virtual shape of the rotation object. See documentation for the shape
        property.

        Args:
            index:
                A torch index. E.g. (1, 3, 2), or (slice(None,))
        Returns:
            The indexed rotation
        """
        # 如果索引不是元组,则转换为元组形式
        if type(index) != tuple:
            index = (index,)

        # 如果存储旋转矩阵的属性不为空
        if self._rot_mats is not None:
            # 使用索引获取部分旋转矩阵,并返回一个新的 Rotation 对象
            rot_mats = self._rot_mats[index + (slice(None), slice(None))]
            return Rotation(rot_mats=rot_mats)
        # 如果存储四元数的属性不为空
        elif self._quats is not None:
            # 使用索引获取部分四元数,并返回一个新的 Rotation 对象
            quats = self._quats[index + (slice(None),)]
            return Rotation(quats=quats, normalize_quats=False)
        else:
            # 如果旋转矩阵和四元数都为空,则抛出异常
            raise ValueError("Both rotations are None")

    def __mul__(self, right: torch.Tensor) -> Rotation:
        """
        Pointwise left multiplication of the rotation with a tensor. Can be used to e.g. mask the Rotation.

        Args:
            right:
                The tensor multiplicand
        Returns:
            The product
        """
        # 确保右乘数是一个 Tensor
        if not (isinstance(right, torch.Tensor)):
            raise TypeError("The other multiplicand must be a Tensor")

        # 如果存储旋转矩阵的属性不为空
        if self._rot_mats is not None:
            # 对旋转矩阵逐点进行左乘操作,并返回一个新的 Rotation 对象
            rot_mats = self._rot_mats * right[..., None, None]
            return Rotation(rot_mats=rot_mats, quats=None)
        # 如果存储四元数的属性不为空
        elif self._quats is not None:
            # 对四元数逐点进行左乘操作,并返回一个新的 Rotation 对象
            quats = self._quats * right[..., None]
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            # 如果旋转矩阵和四元数都为空,则抛出异常
            raise ValueError("Both rotations are None")

    def __rmul__(self, left: torch.Tensor) -> Rotation:
        """
        Reverse pointwise multiplication of the rotation with a tensor.

        Args:
            left:
                The left multiplicand
        Returns:
            The product
        """
        # 右乘的逆操作,调用 __mul__ 方法实现
        return self.__mul__(left)

    # Properties

    @property
    def shape(self) -> torch.Size:
        """
        Returns the virtual shape of the rotation object. This shape is defined as the batch dimensions of the
        underlying rotation matrix or quaternion. If the Rotation was initialized with a [10, 3, 3] rotation matrix
        tensor, for example, the resulting shape would be [10].

        Returns:
            The virtual shape of the rotation object
        """
        # 如果存储旋转矩阵的属性不为空,则返回其形状的前两个维度
        if self._rot_mats is not None:
            return self._rot_mats.shape[:-2]
        # 如果存储四元数的属性不为空,则返回其形状的前一个维度
        elif self._quats is not None:
            return self._quats.shape[:-1]
        else:
            # 如果旋转矩阵和四元数都为空,则抛出异常
            raise ValueError("Both rotations are None")

    @property
    # 返回基础旋转的数据类型(dtype)
    def dtype(self) -> torch.dtype:
        """
        Returns the dtype of the underlying rotation.

        Returns:
            The dtype of the underlying rotation
        """
        # 如果存储旋转矩阵不为空,则返回其数据类型
        if self._rot_mats is not None:
            return self._rot_mats.dtype
        # 如果存储四元数不为空,则返回其数据类型
        elif self._quats is not None:
            return self._quats.dtype
        # 如果旋转矩阵和四元数都为空,则抛出数值错误异常
        else:
            raise ValueError("Both rotations are None")

    @property
    # 返回基础旋转所在的设备(device)
    def device(self) -> torch.device:
        """
        The device of the underlying rotation

        Returns:
            The device of the underlying rotation
        """
        # 如果存储旋转矩阵不为空,则返回其所在设备
        if self._rot_mats is not None:
            return self._rot_mats.device
        # 如果存储四元数不为空,则返回其所在设备
        elif self._quats is not None:
            return self._quats.device
        # 如果旋转矩阵和四元数都为空,则抛出数值错误异常
        else:
            raise ValueError("Both rotations are None")

    @property
    # 返回基础旋转张量是否需要梯度计算(requires_grad)
    def requires_grad(self) -> bool:
        """
        Returns the requires_grad property of the underlying rotation

        Returns:
            The requires_grad property of the underlying tensor
        """
        # 如果存储旋转矩阵不为空,则返回其是否需要梯度计算的属性
        if self._rot_mats is not None:
            return self._rot_mats.requires_grad
        # 如果存储四元数不为空,则返回其是否需要梯度计算的属性
        elif self._quats is not None:
            return self._quats.requires_grad
        # 如果旋转矩阵和四元数都为空,则抛出数值错误异常
        else:
            raise ValueError("Both rotations are None")

    # 返回基础旋转矩阵张量
    def get_rot_mats(self) -> torch.Tensor:
        """
        Returns the underlying rotation as a rotation matrix tensor.

        Returns:
            The rotation as a rotation matrix tensor
        """
        # 如果存储旋转矩阵不为空,则直接返回其存储的旋转矩阵张量
        if self._rot_mats is not None:
            return self._rot_mats
        # 如果存储四元数不为空,则将四元数转换为旋转矩阵张量并返回
        elif self._quats is not None:
            return quat_to_rot(self._quats)
        # 如果旋转矩阵和四元数都为空,则抛出数值错误异常
        else:
            raise ValueError("Both rotations are None")

    # 返回基础四元数张量
    def get_quats(self) -> torch.Tensor:
        """
        Returns the underlying rotation as a quaternion tensor.

        Depending on whether the Rotation was initialized with a quaternion, this function may call torch.linalg.eigh.

        Returns:
            The rotation as a quaternion tensor.
        """
        # 如果存储旋转矩阵不为空,则将旋转矩阵转换为四元数张量并返回
        if self._rot_mats is not None:
            return rot_to_quat(self._rot_mats)
        # 如果存储四元数不为空,则直接返回其存储的四元数张量
        elif self._quats is not None:
            return self._quats
        # 如果旋转矩阵和四元数都为空,则抛出数值错误异常
        else:
            raise ValueError("Both rotations are None")

    # 返回当前存储的旋转数据
    def get_cur_rot(self) -> torch.Tensor:
        """
        Return the underlying rotation in its current form

        Returns:
            The stored rotation
        """
        # 如果存储旋转矩阵不为空,则返回其存储的旋转矩阵张量
        if self._rot_mats is not None:
            return self._rot_mats
        # 如果存储四元数不为空,则返回其存储的四元数张量
        elif self._quats is not None:
            return self._quats
        # 如果旋转矩阵和四元数都为空,则抛出数值错误异常
        else:
            raise ValueError("Both rotations are None")

    # 旋转函数
    # 定义一个方法,用于计算并返回一个新的四元数旋转对象,通过一个四元数更新向量更新当前对象的底层旋转。
    # 更新向量以 [*, 3] 张量格式表示,最后三列表示 x、y、z,使得 (1, x, y, z) 是期望的(不一定是单位)四元数更新。
    
    def compose_q_update_vec(self, q_update_vec: torch.Tensor, normalize_quats: bool = True) -> Rotation:
        """
        Returns a new quaternion Rotation after updating the current object's underlying rotation with a quaternion
        update, formatted as a [*, 3] tensor whose final three columns represent x, y, z such that (1, x, y, z) is the
        desired (not necessarily unit) quaternion update.
    
        Args:
            q_update_vec:
                A [*, 3] quaternion update tensor
            normalize_quats:
                Whether to normalize the output quaternion
        Returns:
            An updated Rotation
        """
        # 获取当前对象的四元数
        quats = self.get_quats()
        # 计算新的四元数,通过当前四元数和更新向量相乘
        new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
        # 返回一个新的 Rotation 对象,使用新的四元数,可以选择是否归一化
        return Rotation(
            rot_mats=None,
            quats=new_quats,
            normalize_quats=normalize_quats,
        )
    
    # 定义一个方法,用于将当前 Rotation 对象的旋转矩阵与另一个 Rotation 对象的旋转矩阵组合。
    # 返回一个包含组合后旋转矩阵的新 Rotation 对象。
    
    def compose_r(self, r: Rotation) -> Rotation:
        """
        Compose the rotation matrices of the current Rotation object with those of another.
    
        Args:
            r:
                An update rotation object
        Returns:
            An updated rotation object
        """
        # 获取当前对象和参数对象的旋转矩阵
        r1 = self.get_rot_mats()
        r2 = r.get_rot_mats()
        # 计算新的旋转矩阵,通过矩阵乘法将两个旋转矩阵相乘
        new_rot_mats = rot_matmul(r1, r2)
        # 返回一个新的 Rotation 对象,使用新的旋转矩阵
        return Rotation(rot_mats=new_rot_mats, quats=None)
    
    # 定义一个方法,用于将当前 Rotation 对象的四元数与另一个 Rotation 对象的四元数组合。
    # 返回一个包含组合后四元数的新 Rotation 对象。
    
    def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
        """
        Compose the quaternions of the current Rotation object with those of another.
    
        Depending on whether either Rotation was initialized with quaternions, this function may call
        torch.linalg.eigh.
    
        Args:
            r:
                An update rotation object
        Returns:
            An updated rotation object
        """
        # 获取当前对象和参数对象的四元数
        q1 = self.get_quats()
        q2 = r.get_quats()
        # 计算新的四元数,通过四元数乘法将两个四元数相乘
        new_quats = quat_multiply(q1, q2)
        # 返回一个新的 Rotation 对象,使用新的四元数,可以选择是否归一化
        return Rotation(rot_mats=None, quats=new_quats, normalize_quats=normalize_quats)
    
    # 定义一个方法,将当前 Rotation 对象的旋转矩阵作为旋转矩阵应用到一组 3D 坐标上。
    # 返回一个 [*, 3] 形状的旋转后的点坐标集合。
    
    def apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Apply the current Rotation as a rotation matrix to a set of 3D coordinates.
    
        Args:
            pts:
                A [*, 3] set of points
        Returns:
            [*, 3] rotated points
        """
        # 获取当前对象的旋转矩阵
        rot_mats = self.get_rot_mats()
        # 将旋转矩阵应用到点集合上,返回旋转后的点坐标
        return rot_vec_mul(rot_mats, pts)
    
    # 定义一个方法,将当前 Rotation 对象的逆旋转矩阵作为旋转矩阵应用到一组 3D 坐标上。
    # 返回一个 [*, 3] 形状的逆旋转后的点坐标集合。
    
    def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
        The inverse of the apply() method.
    
        Args:
            pts:
                A [*, 3] set of points
        Returns:
            [*, 3] inverse-rotated points
        """
        # 获取当前对象的旋转矩阵
        rot_mats = self.get_rot_mats()
        # 计算旋转矩阵的逆矩阵
        inv_rot_mats = invert_rot_mat(rot_mats)
        # 将逆旋转矩阵应用到点集合上,返回逆旋转后的点坐标
        return rot_vec_mul(inv_rot_mats, pts)
    def invert(self) -> Rotation:
        """
        Returns the inverse of the current Rotation.

        Returns:
            The inverse of the current Rotation
        """
        # 如果旋转矩阵不为 None,则返回其逆矩阵对应的 Rotation 对象
        if self._rot_mats is not None:
            return Rotation(rot_mats=invert_rot_mat(self._rot_mats), quats=None)
        # 如果四元数不为 None,则返回其逆四元数对应的 Rotation 对象
        elif self._quats is not None:
            return Rotation(
                rot_mats=None,
                quats=invert_quat(self._quats),
                normalize_quats=False,
            )
        else:
            # 如果旋转矩阵和四元数都为 None,则抛出数值错误异常
            raise ValueError("Both rotations are None")

    # "Tensor" stuff

    def unsqueeze(self, dim: int) -> Rotation:
        """
        Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object.

        Args:
            dim: A positive or negative dimension index.
        Returns:
            The unsqueezed Rotation.
        """
        # 如果指定的维度超出了 Rotation 对象的形状范围,则抛出数值错误异常
        if dim >= len(self.shape):
            raise ValueError("Invalid dimension")

        # 如果有旋转矩阵,则按照指定维度对旋转矩阵进行 unsqueeze 操作
        if self._rot_mats is not None:
            rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
            return Rotation(rot_mats=rot_mats, quats=None)
        # 如果有四元数,则按照指定维度对四元数进行 unsqueeze 操作
        elif self._quats is not None:
            quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            # 如果旋转矩阵和四元数都为 None,则抛出数值错误异常
            raise ValueError("Both rotations are None")

    @staticmethod
    def cat(rs: Sequence[Rotation], dim: int) -> Rotation:
        """
        Concatenates rotations along one of the batch dimensions. Analogous to torch.cat().

        Note that the output of this operation is always a rotation matrix, regardless of the format of input
        rotations.

        Args:
            rs:
                A list of rotation objects
            dim:
                The dimension along which the rotations should be concatenated
        Returns:
            A concatenated Rotation object in rotation matrix format
        """
        # 将输入的 Rotation 对象列表中的旋转矩阵沿指定维度进行拼接
        rot_mats = torch.cat(
            [r.get_rot_mats() for r in rs],
            dim=dim if dim >= 0 else dim - 2,
        )

        return Rotation(rot_mats=rot_mats, quats=None)
    def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rotation:
        """
        Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can
        be used e.g. to sum out a one-hot batch dimension.

        Args:
            fn:
                A Tensor -> Tensor function to be mapped over the Rotation
        Returns:
            The transformed Rotation object
        """
        # 如果存在旋转矩阵 _rot_mats
        if self._rot_mats is not None:
            # 将 _rot_mats 的形状调整为去掉最后两个维度后再加上一个长度为 9 的维度的形状
            rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
            # 解绑 _rot_mats 的最后一个维度,并对每个解绑后的张量应用函数 fn,然后重新堆叠起来
            rot_mats = torch.stack(list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1)
            # 将 rot_mats 的形状调整回去,去掉最后一个维度并加上一个 3x3 的形状
            rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
            # 返回一个新的 Rotation 对象,传入新的旋转矩阵 rot_mats 和 None 的 quats
            return Rotation(rot_mats=rot_mats, quats=None)
        # 如果存在四元数 _quats
        elif self._quats is not None:
            # 对 _quats 解绑最后一个维度,并对每个解绑后的张量应用函数 fn,然后重新堆叠起来
            quats = torch.stack(list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1)
            # 返回一个新的 Rotation 对象,传入 None 的 rot_mats 和新的 quats,以及不需要归一化的标志
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            # 如果 _rot_mats 和 _quats 都是 None,则抛出异常
            raise ValueError("Both rotations are None")

    def cuda(self) -> Rotation:
        """
        Analogous to the cuda() method of torch Tensors

        Returns:
            A copy of the Rotation in CUDA memory
        """
        # 如果存在旋转矩阵 _rot_mats
        if self._rot_mats is not None:
            # 将 _rot_mats 移动到 CUDA 内存,并返回一个新的 Rotation 对象
            return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
        # 如果存在四元数 _quats
        elif self._quats is not None:
            # 将 _quats 移动到 CUDA 内存,并返回一个新的 Rotation 对象,不需要归一化
            return Rotation(rot_mats=None, quats=self._quats.cuda(), normalize_quats=False)
        else:
            # 如果 _rot_mats 和 _quats 都是 None,则抛出异常
            raise ValueError("Both rotations are None")

    def to(self, device: Optional[torch.device], dtype: Optional[torch.dtype]) -> Rotation:
        """
        Analogous to the to() method of torch Tensors

        Args:
            device:
                A torch device
            dtype:
                A torch dtype
        Returns:
            A copy of the Rotation using the new device and dtype
        """
        # 如果存在旋转矩阵 _rot_mats
        if self._rot_mats is not None:
            # 将 _rot_mats 转移到指定的 device 和 dtype,并返回一个新的 Rotation 对象
            return Rotation(
                rot_mats=self._rot_mats.to(device=device, dtype=dtype),
                quats=None,
            )
        # 如果存在四元数 _quats
        elif self._quats is not None:
            # 将 _quats 转移到指定的 device 和 dtype,并返回一个新的 Rotation 对象,不需要归一化
            return Rotation(
                rot_mats=None,
                quats=self._quats.to(device=device, dtype=dtype),
                normalize_quats=False,
            )
        else:
            # 如果 _rot_mats 和 _quats 都是 None,则抛出异常
            raise ValueError("Both rotations are None")
    # 返回一个 Rotation 对象的副本,其中底层的 Tensor 已从其 torch 图中分离
    def detach(self) -> Rotation:
        """
        Returns a copy of the Rotation whose underlying Tensor has been detached from its torch graph.

        Returns:
            A copy of the Rotation whose underlying Tensor has been detached from its torch graph
        """
        # 如果 _rot_mats 不为 None,则返回一个新的 Rotation 对象,其 rot_mats 被分离
        if self._rot_mats is not None:
            return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
        # 如果 _quats 不为 None,则返回一个新的 Rotation 对象,其 quats 被分离
        elif self._quats is not None:
            return Rotation(
                rot_mats=None,
                quats=self._quats.detach(),
                normalize_quats=False,
            )
        else:
            # 如果 _rot_mats 和 _quats 都是 None,则抛出数值错误异常
            raise ValueError("Both rotations are None")
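
下面用一个小示例演示 Rotation 的几个常用接口(演示代码,并非本文件源码;假设可以从 transformers 的内部模块导入该类,数值仅作示意):

import torch
from transformers.models.esm.openfold_utils.rigid_utils import Rotation

# 用 5 个单位旋转矩阵构造一个批量 Rotation,共享批量维度为 (5,)
rot = Rotation(rot_mats=torch.eye(3).expand(5, 3, 3), quats=None)
print(rot.shape)                                  # torch.Size([5])
rot2 = rot.to(device=torch.device("cpu"), dtype=torch.float32)  # 与张量的 to() 类似
rot_detached = rot.detach()                       # 底层张量脱离计算图
stacked = Rotation.cat([rot, rot], dim=0)         # 沿批量维度拼接
print(stacked.shape)                              # torch.Size([10])
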
class Rigid:
    """
    A class representing a rigid transformation. Little more than a wrapper around two objects: a Rotation object and a
    [*, 3] translation Designed to behave approximately like a single torch tensor with the shape of the shared batch
    dimensions of its component parts.
    """

    def __init__(self, rots: Optional[Rotation], trans: Optional[torch.Tensor]):
        """
        Args:
            rots: A [*, 3, 3] rotation tensor
            trans: A corresponding [*, 3] translation tensor
        """
        # 根据输入参数确定 batch 维度、数据类型、设备和梯度设置
        batch_dims, dtype, device, requires_grad = None, None, None, None
        if trans is not None:
            batch_dims = trans.shape[:-1]  # 获取除最后一维外的所有维度,即 batch 维度
            dtype = trans.dtype  # 获取数据类型
            device = trans.device  # 获取设备
            requires_grad = trans.requires_grad  # 获取梯度需求设置
        elif rots is not None:
            batch_dims = rots.shape  # 获取 rots 的形状作为 batch 维度
            dtype = rots.dtype  # 获取数据类型
            device = rots.device  # 获取设备
            requires_grad = rots.requires_grad  # 获取梯度需求设置
        else:
            raise ValueError("At least one input argument must be specified")  # 抛出数值错误,至少需要指定一个输入参数

        # 如果 rots 为 None,则使用 identity 方法创建默认的 Rotation 对象
        if rots is None:
            rots = Rotation.identity(
                batch_dims,
                dtype,
                device,
                requires_grad,
            )
        # 如果 trans 为 None,则使用 identity_trans 函数创建默认的 translation tensor
        elif trans is None:
            trans = identity_trans(
                batch_dims,
                dtype,
                device,
                requires_grad,
            )

        assert rots is not None
        assert trans is not None

        # 检查 rots 和 trans 的形状和设备是否兼容
        if (rots.shape != trans.shape[:-1]) or (rots.device != trans.device):
            raise ValueError("Rots and trans incompatible")  # 抛出数值错误,rots 和 trans 不兼容

        # 强制将 trans 转换为 torch.float32 数据类型
        trans = trans.to(dtype=torch.float32)

        self._rots = rots  # 将 rots 赋值给对象的 _rots 属性
        self._trans = trans  # 将 trans 赋值给对象的 _trans 属性

    @staticmethod
    def identity(
        shape: Tuple[int, ...],
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = True,
        fmt: str = "quat",
    ) -> Rigid:
        """
        Constructs an identity transformation.

        Args:
            shape:
                The desired shape
            dtype:
                The dtype of both internal tensors
            device:
                The device of both internal tensors
            requires_grad:
                Whether grad should be enabled for the internal tensors
        Returns:
            The identity transformation
        """
        # 使用 Rotation.identity 和 identity_trans 函数创建 identity transformation 对象并返回
        return Rigid(
            Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
            identity_trans(shape, dtype, device, requires_grad),
        )
    def __getitem__(self, index: Any) -> Rigid:
        """
        Indexes the affine transformation with PyTorch-style indices. The index is applied to the shared dimensions of
        both the rotation and the translation.

        E.g.::

            r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
            t = Rigid(r, torch.rand(10, 10, 3))
            indexed = t[3, 4:6]
            assert(indexed.shape == (2,))
            assert(indexed.get_rots().shape == (2,))
            assert(indexed.get_trans().shape == (2, 3))

        Args:
            index: A standard torch tensor index. E.g. 8, (10, None, 3),
            or (3, slice(0, 1, None))
        Returns:
            The indexed tensor
        """
        # 如果索引不是元组,则转换为元组形式
        if type(index) != tuple:
            index = (index,)

        # 返回一个新的 Rigid 对象,通过索引获取对应的旋转矩阵和平移向量
        return Rigid(
            self._rots[index],  # 使用索引获取旋转矩阵的子集
            self._trans[index + (slice(None),)],  # 使用索引获取平移向量的子集
        )

    def __mul__(self, right: torch.Tensor) -> Rigid:
        """
        Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. mask the Rigid.

        Args:
            right:
                The tensor multiplicand
        Returns:
            The product
        """
        # 如果 right 不是 torch.Tensor 类型,则抛出类型错误异常
        if not (isinstance(right, torch.Tensor)):
            raise TypeError("The other multiplicand must be a Tensor")

        # 对旋转矩阵和平移向量分别进行点乘操作
        new_rots = self._rots * right  # 对旋转矩阵进行点乘
        new_trans = self._trans * right[..., None]  # 对平移向量进行点乘(在最后一个维度上扩展)

        # 返回一个新的 Rigid 对象,代表点乘后的结果
        return Rigid(new_rots, new_trans)

    def __rmul__(self, left: torch.Tensor) -> Rigid:
        """
        Reverse pointwise multiplication of the transformation with a tensor.

        Args:
            left:
                The left multiplicand
        Returns:
            The product
        """
        # 调用 __mul__ 方法进行反向点乘操作
        return self.__mul__(left)

    @property
    def shape(self) -> torch.Size:
        """
        Returns the shape of the shared dimensions of the rotation and the translation.

        Returns:
            The shape of the transformation
        """
        # 返回旋转矩阵和平移向量共享维度的形状(去掉最后一个维度)
        return self._trans.shape[:-1]

    @property
    def device(self) -> torch.device:
        """
        Returns the device on which the Rigid's tensors are located.

        Returns:
            The device on which the Rigid's tensors are located
        """
        # 返回平移向量所在的设备
        return self._trans.device

    def get_rots(self) -> Rotation:
        """
        Getter for the rotation.

        Returns:
            The rotation object
        """
        # 返回存储的旋转矩阵对象
        return self._rots

    def get_trans(self) -> torch.Tensor:
        """
        Getter for the translation.

        Returns:
            The stored translation
        """
        # 返回存储的平移向量
        return self._trans
    def compose_q_update_vec(self, q_update_vec: torch.Tensor) -> Rigid:
        """
        Composes the transformation with a quaternion update vector of shape [*, 6], where the final 6 columns
        represent the x, y, and z values of a quaternion of form (1, x, y, z) followed by a 3D translation.

        Args:
            q_update_vec: The quaternion update vector.
        Returns:
            The composed transformation.
        """
        # Extract quaternion update vector and translation vector
        q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
        # Compose rotations with quaternion update vector
        new_rots = self._rots.compose_q_update_vec(q_vec)

        # Apply rotations to translation vector
        trans_update = self._rots.apply(t_vec)
        # Calculate new translation by adding current translation with applied rotations
        new_translation = self._trans + trans_update

        # Return composed transformation
        return Rigid(new_rots, new_translation)
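
compose_q_update_vec 的输入是形如 [*, 6] 的更新向量:前 3 列是四元数 (1, x, y, z) 中的 x、y、z,后 3 列是平移。下面是一个最小示例(演示代码,非源码,数值仅作示意):

import torch
from transformers.models.esm.openfold_utils.rigid_utils import Rigid

T = Rigid.identity((1,), fmt="quat")      # 单位变换,旋转以四元数形式存储
update = torch.randn(1, 6)                # [*, 6] 的更新向量
T_new = T.compose_q_update_vec(update)
print(T_new.shape)                        # torch.Size([1])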

    def compose(self, r: Rigid) -> Rigid:
        """
        Composes the current rigid object with another.

        Args:
            r:
                Another Rigid object
        Returns:
            The composition of the two transformations
        """
        # Compose rotations of current object with rotations of another object
        new_rot = self._rots.compose_r(r._rots)
        # Apply rotations of current object to translation of another object and add current translation
        new_trans = self._rots.apply(r._trans) + self._trans
        # Return composed transformation
        return Rigid(new_rot, new_trans)
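
compose 的语义可以用一个小例子验证:对任意点 x,self.compose(r) 作用于 x,等价于先用 r、再用 self 变换(演示代码,非源码):

import torch
from transformers.models.esm.openfold_utils.rigid_utils import Rigid, Rotation

T1 = Rigid(Rotation(rot_mats=None, quats=torch.randn(2, 4)), torch.randn(2, 3))
T2 = Rigid(Rotation(rot_mats=None, quats=torch.randn(2, 4)), torch.randn(2, 3))
x = torch.randn(2, 3)
assert torch.allclose(T1.compose(T2).apply(x), T1.apply(T2.apply(x)), atol=1e-4)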

    def apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Applies the transformation to a coordinate tensor.

        Args:
            pts: A [*, 3] coordinate tensor.
        Returns:
            The transformed points.
        """
        # Apply rotations to the input points
        rotated = self._rots.apply(pts)
        # Add translation to the rotated points
        return rotated + self._trans

    def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Applies the inverse of the transformation to a coordinate tensor.

        Args:
            pts: A [*, 3] coordinate tensor
        Returns:
            The transformed points.
        """
        # Subtract translation from the input points
        pts = pts - self._trans
        # Apply inverse rotations to the translated points
        return self._rots.invert_apply(pts)

    def invert(self) -> Rigid:
        """
        Inverts the transformation.

        Returns:
            The inverse transformation.
        """
        # Invert rotations
        rot_inv = self._rots.invert()
        # Apply inverted rotations to current translation and negate it
        trn_inv = rot_inv.apply(self._trans)

        # Return inverted transformation
        return Rigid(rot_inv, -1 * trn_inv)
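
invert 与 invert_apply 可以用往返变换做一个简单的自检(演示代码,非源码):

import torch
from transformers.models.esm.openfold_utils.rigid_utils import Rigid, Rotation

T = Rigid(Rotation(rot_mats=None, quats=torch.randn(5, 4)), torch.randn(5, 3))
x = torch.randn(5, 3)
assert torch.allclose(T.invert().apply(T.apply(x)), x, atol=1e-4)
assert torch.allclose(T.invert_apply(T.apply(x)), x, atol=1e-4)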

    def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
        """
        Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the
        translation/rotation dimensions respectively.

        Args:
            fn:
                A Tensor -> Tensor function to be mapped over the Rigid
        Returns:
            The transformed Rigid object
        """
        # Apply function to rotation tensors
        new_rots = self._rots.map_tensor_fn(fn)
        # Map function over each dimension of translation tensor and stack results
        new_trans = torch.stack(list(map(fn, torch.unbind(self._trans, dim=-1))), dim=-1)

        # Return transformed Rigid object
        return Rigid(new_rots, new_trans)
    def to_tensor_4x4(self) -> torch.Tensor:
        """
        Converts a transformation to a homogenous transformation tensor.

        Returns:
            A [*, 4, 4] homogenous transformation tensor
        """
        # 创建一个与当前对象形状相同的全零张量,形状为 [*self.shape, 4, 4]
        tensor = self._trans.new_zeros((*self.shape, 4, 4))
        # 将旋转矩阵填充到张量的前三行前三列
        tensor[..., :3, :3] = self._rots.get_rot_mats()
        # 将平移矢量填充到张量的前三行最后一列
        tensor[..., :3, 3] = self._trans
        # 最后一个元素设为1,构成齐次变换张量
        tensor[..., 3, 3] = 1
        return tensor

    @staticmethod
    def from_tensor_4x4(t: torch.Tensor) -> Rigid:
        """
        Constructs a transformation from a homogenous transformation tensor.

        Args:
            t: [*, 4, 4] homogenous transformation tensor
        Returns:
            T object with shape [*]
        """
        # 检查输入张量的形状是否为 [*, 4, 4]
        if t.shape[-2:] != (4, 4):
            raise ValueError("Incorrectly shaped input tensor")

        # 从输入张量中提取旋转矩阵部分
        rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
        # 从输入张量中提取平移矢量部分
        trans = t[..., :3, 3]

        # 返回一个 Rigid 类对象,其中包含旋转矩阵和平移矢量
        return Rigid(rots, trans)
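
to_tensor_4x4 与 from_tensor_4x4 互为逆操作,可以做一次往返检查(演示代码,非源码):

import torch
from transformers.models.esm.openfold_utils.rigid_utils import Rigid, Rotation

T = Rigid(Rotation(rot_mats=None, quats=torch.randn(2, 4)), torch.randn(2, 3))
m = T.to_tensor_4x4()                     # [2, 4, 4] 齐次变换矩阵
T2 = Rigid.from_tensor_4x4(m)
assert torch.allclose(T2.to_tensor_4x4(), m, atol=1e-6)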

    def to_tensor_7(self) -> torch.Tensor:
        """
        Converts a transformation to a tensor with 7 final columns, four for the quaternion followed by three for the
        translation.

        Returns:
            A [*, 7] tensor representation of the transformation
        """
        # 创建一个与当前对象形状相同的全零张量,形状为 [*self.shape, 7]
        tensor = self._trans.new_zeros((*self.shape, 7))
        # 将四元数填充到张量的前四列
        tensor[..., :4] = self._rots.get_quats()
        # 将平移矢量填充到张量的后三列
        tensor[..., 4:] = self._trans

        return tensor

    @staticmethod
    def from_tensor_7(t: torch.Tensor, normalize_quats: bool = False) -> Rigid:
        # 检查输入张量的形状是否为 [*..., 7]
        if t.shape[-1] != 7:
            raise ValueError("Incorrectly shaped input tensor")

        # 从输入张量中提取四元数部分和平移矢量部分
        quats, trans = t[..., :4], t[..., 4:]

        # 根据四元数和是否需要标准化创建 Rotation 对象
        rots = Rotation(rot_mats=None, quats=quats, normalize_quats=normalize_quats)

        # 返回一个 Rigid 类对象,其中包含旋转部分和平移部分
        return Rigid(rots, trans)

    @staticmethod
    def from_3_points(
        p_neg_x_axis: torch.Tensor, origin: torch.Tensor, p_xy_plane: torch.Tensor, eps: float = 1e-8
    ) -> Rigid:
        """
        Implements algorithm 21. Constructs transformations from sets of 3 points using the Gram-Schmidt algorithm.

        Args:
            p_neg_x_axis: [*, 3] coordinates
                Coordinates of points defining the negative x-axis direction
            origin: [*, 3] coordinates used as frame origins
                Coordinates of points defining the origin of the frame
            p_xy_plane: [*, 3] coordinates
                Coordinates of points defining the xy-plane orientation
            eps: Small epsilon value
                Small value added to avoid division by zero
        Returns:
            A transformation object of shape [*]
        """
        # Unbind tensors along the last dimension
        p_neg_x_axis_unbound = torch.unbind(p_neg_x_axis, dim=-1)
        origin_unbound = torch.unbind(origin, dim=-1)
        p_xy_plane_unbound = torch.unbind(p_xy_plane, dim=-1)

        # Calculate the first two orthonormal vectors e0 and e1 using Gram-Schmidt process
        e0 = [c1 - c2 for c1, c2 in zip(origin_unbound, p_neg_x_axis_unbound)]
        e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane_unbound, origin_unbound)]

        # Normalize e0
        denom = torch.sqrt(sum(c * c for c in e0) + eps * torch.ones_like(e0[0]))
        e0 = [c / denom for c in e0]

        # Calculate e1 orthogonal to e0 and normalize it
        dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
        e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
        denom = torch.sqrt(sum((c * c for c in e1)) + eps * torch.ones_like(e1[0]))
        e1 = [c / denom for c in e1]

        # Calculate the third orthonormal vector e2
        e2 = [
            e0[1] * e1[2] - e0[2] * e1[1],
            e0[2] * e1[0] - e0[0] * e1[2],
            e0[0] * e1[1] - e0[1] * e1[0],
        ]

        # Stack e0, e1, and e2 to form rotation matrices and reshape into proper shape
        rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
        rots = rots.reshape(rots.shape[:-1] + (3, 3))

        # Create a Rotation object using the calculated rotation matrices
        rot_obj = Rotation(rot_mats=rots, quats=None)

        # Return a Rigid transformation object with the rotation and origin vectors stacked
        return Rigid(rot_obj, torch.stack(origin_unbound, dim=-1))
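
from_3_points 的典型用法是由三组 [*, 3] 坐标构造局部坐标系(演示代码,非源码;变量名仅作示意):

import torch
from transformers.models.esm.openfold_utils.rigid_utils import Rigid

p1 = torch.randn(4, 3)        # 定义负 x 轴方向的点
origin = torch.randn(4, 3)    # 坐标系原点
p2 = torch.randn(4, 3)        # 决定 xy 平面朝向的点
frames = Rigid.from_3_points(p1, origin, p2)
print(frames.shape)                                  # torch.Size([4])
assert torch.allclose(frames.get_trans(), origin)    # 平移部分就是原点坐标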

    def unsqueeze(self, dim: int) -> Rigid:
        """
        Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation.

        Args:
            dim: A positive or negative dimension index.
        Returns:
            The unsqueezed transformation.
        """
        if dim >= len(self.shape):
            raise ValueError("Invalid dimension")
        # Unsqueeze rotation matrices and translations along the specified dimension
        rots = self._rots.unsqueeze(dim)
        trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)

        # Return a Rigid transformation object with unsqueezed rotation and translation tensors
        return Rigid(rots, trans)

    @staticmethod
    def cat(ts: Sequence[Rigid], dim: int) -> Rigid:
        """
        Concatenates transformations along a new dimension.

        Args:
            ts:
                A list of T objects
            dim:
                The dimension along which the transformations should be concatenated
        Returns:
            A concatenated transformation object
        """
        # Concatenate rotation matrices and translations along the specified dimension
        rots = Rotation.cat([t._rots for t in ts], dim)
        trans = torch.cat([t._trans for t in ts], dim=dim if dim >= 0 else dim - 1)

        # Return a Rigid transformation object with concatenated rotation and translation tensors
        return Rigid(rots, trans)
    def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Rigid:
        """
        Applies a Rotation -> Rotation function to the stored rotation object.

        Args:
            fn: A function of type Rotation -> Rotation

        Returns:
            A transformation object with a transformed rotation.
        """
        return Rigid(fn(self._rots), self._trans)

    def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
        """
        Applies a Tensor -> Tensor function to the stored translation.

        Args:
            fn:
                A function of type Tensor -> Tensor to be applied to the translation

        Returns:
            A transformation object with a transformed translation.
        """
        return Rigid(self._rots, fn(self._trans))

    def scale_translation(self, trans_scale_factor: float) -> Rigid:
        """
        Scales the translation by a constant factor.

        Args:
            trans_scale_factor:
                The constant factor

        Returns:
            A transformation object with a scaled translation.
        """
        return self.apply_trans_fn(lambda t: t * trans_scale_factor)

    def stop_rot_gradient(self) -> Rigid:
        """
        Detaches the underlying rotation object

        Returns:
            A transformation object with detached rotations
        """
        return self.apply_rot_fn(lambda r: r.detach())

    @staticmethod
    def make_transform_from_reference(
        n_xyz: torch.Tensor, ca_xyz: torch.Tensor, c_xyz: torch.Tensor, eps: float = 1e-20
    ) -> Rigid:
        """
        Returns a transformation object from reference coordinates.

        Note that this method does not take care of symmetries. If you provide the atom positions in the non-standard
        way, the N atom will end up not at [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
        need to take care of such cases in your code.

        Args:
            n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
            ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
            c_xyz: A [*, 3] tensor of carbon xyz coordinates.
        Returns:
            A transformation object. After applying the translation and rotation to the reference backbone, the
            coordinates will approximately equal to the input coordinates.
        """
        # Calculate translation vector by negating carbon alpha coordinates
        translation = -1 * ca_xyz
        # Translate nitrogen and carbon coordinates accordingly
        n_xyz = n_xyz + translation
        c_xyz = c_xyz + translation

        # Extract x, y, z components of carbon coordinates
        c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
        # Compute normalization factor with epsilon smoothing
        norm = torch.sqrt(eps + c_x**2 + c_y**2)
        # Calculate sine and cosine of the first rotation angle
        sin_c1 = -c_y / norm
        cos_c1 = c_x / norm

        # Initialize rotation matrices for the first rotation (around z-axis)
        c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
        c1_rots[..., 0, 0] = cos_c1
        c1_rots[..., 0, 1] = -1 * sin_c1
        c1_rots[..., 1, 0] = sin_c1
        c1_rots[..., 1, 1] = cos_c1
        c1_rots[..., 2, 2] = 1

        # Compute normalization factor with epsilon smoothing
        norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2)
        # Calculate sine and cosine of the second rotation angle
        sin_c2 = c_z / norm
        cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm

        # Initialize rotation matrices for the second rotation (around the y-axis)
        c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
        c2_rots[..., 0, 0] = cos_c2
        c2_rots[..., 0, 2] = sin_c2
        c2_rots[..., 1, 1] = 1
        c2_rots[..., 2, 0] = -1 * sin_c2
        c2_rots[..., 2, 2] = cos_c2

        # Combine the two rotation matrices
        c_rots = rot_matmul(c2_rots, c1_rots)
        # Rotate nitrogen coordinates using the combined rotation matrix
        n_xyz = rot_vec_mul(c_rots, n_xyz)

        # Extract y, z components of rotated nitrogen coordinates
        _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
        # Compute normalization factor with epsilon smoothing
        norm = torch.sqrt(eps + n_y**2 + n_z**2)
        # Calculate sine and cosine of the final rotation angle
        sin_n = -n_z / norm
        cos_n = n_y / norm

        # Initialize rotation matrices for the final rotation (around the x-axis)
        n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
        n_rots[..., 0, 0] = 1
        n_rots[..., 1, 1] = cos_n
        n_rots[..., 1, 2] = -1 * sin_n
        n_rots[..., 2, 1] = sin_n
        n_rots[..., 2, 2] = cos_n

        # Combine all rotations to get the final rotation matrix
        rots = rot_matmul(n_rots, c_rots)

        # Transpose the rotation matrix
        rots = rots.transpose(-1, -2)
        # Negate translation vector
        translation = -1 * translation

        # Create a Rotation object using the computed rotation matrix
        rot_obj = Rotation(rot_mats=rots, quats=None)

        # Return a Rigid object encapsulating the rotation and translation
        return Rigid(rot_obj, translation)
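
make_transform_from_reference 可以做一个简单的自检(演示代码,非源码):参考骨架中 CA 位于原点,因此把得到的变换作用在原点上应恢复输入的 CA 坐标。

import torch
from transformers.models.esm.openfold_utils.rigid_utils import Rigid

n, ca, c = torch.randn(3), torch.randn(3), torch.randn(3)
T = Rigid.make_transform_from_reference(n_xyz=n, ca_xyz=ca, c_xyz=c)
assert torch.allclose(T.apply(torch.zeros(3)), ca, atol=1e-5)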

    def cuda(self) -> Rigid:
        """
        Moves the transformation object to GPU memory

        Returns:
            A version of the transformation on GPU
        """
        # Move rotation and translation tensors to GPU
        return Rigid(self._rots.cuda(), self._trans.cuda())

.\models\esm\openfold_utils\tensor_utils.py

# 导入 functools 模块中的 partial 函数
from functools import partial
# 导入 typing 模块中的各种类型提示
from typing import Any, Callable, Dict, List, Type, TypeVar, Union, overload

# 导入 PyTorch 库
import torch
# 导入 torch.nn 模块
import torch.nn as nn
# 导入 torch.types 模块
import torch.types

# 定义一个函数 add,接受两个 torch.Tensor 类型的参数 m1 和 m2,以及一个布尔类型的 inplace 参数,返回一个 torch.Tensor
def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
    # 如果 inplace 参数为 False,则进行非就地操作
    # 梯度检查点(checkpoint)中的第一个操作不能是就地操作,但推理时就地加法更节省内存,因此同时提供两种方式
    if not inplace:
        m1 = m1 + m2  # 非就地加法
    else:
        m1 += m2  # 就地加法

    return m1  # 返回结果 m1

# 定义一个函数 permute_final_dims,接受一个 torch.Tensor 类型的 tensor 参数和一个 List[int] 类型的 inds 参数,返回一个 torch.Tensor
def permute_final_dims(tensor: torch.Tensor, inds: List[int]) -> torch.Tensor:
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])

# 定义一个函数 flatten_final_dims,接受一个 torch.Tensor 类型的 t 参数和一个整数类型的 no_dims 参数,返回一个 torch.Tensor
def flatten_final_dims(t: torch.Tensor, no_dims: int) -> torch.Tensor:
    return t.reshape(t.shape[:-no_dims] + (-1,))

# 定义一个函数 masked_mean,接受一个 torch.Tensor 类型的 mask 参数,一个 torch.Tensor 类型的 value 参数,一个整数类型的 dim 参数,以及一个浮点数类型的 eps 参数,默认值为 1e-4,返回一个 torch.Tensor
def masked_mean(mask: torch.Tensor, value: torch.Tensor, dim: int, eps: float = 1e-4) -> torch.Tensor:
    mask = mask.expand(*value.shape)
    return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
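
masked_mean 只统计掩码为 1 的位置,分母加 eps 以避免除零(演示代码,非源码):

import torch
from transformers.models.esm.openfold_utils.tensor_utils import masked_mean

value = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
print(masked_mean(mask, value, dim=-1))    # ≈ 1.5,只对前两个位置求均值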

# 定义一个函数 pts_to_distogram,接受一个 torch.Tensor 类型的 pts 参数,以及三个可选参数 min_bin、max_bin 和 no_bins,都是 torch.types.Number 类型,默认值分别为 2.3125、21.6875 和 64,返回一个 torch.Tensor
def pts_to_distogram(
    pts: torch.Tensor, min_bin: torch.types.Number = 2.3125, max_bin: torch.types.Number = 21.6875, no_bins: int = 64
) -> torch.Tensor:
    boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
    dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1))
    return torch.bucketize(dists, boundaries)

# 定义一个函数 dict_multimap,接受一个 Callable[[list], Any] 类型的 fn 参数和一个 List[dict] 类型的 dicts 参数,返回一个 dict
def dict_multimap(fn: Callable[[list], Any], dicts: List[dict]) -> dict:
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if isinstance(v, dict):
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)

    return new_dict

# 定义一个函数 one_hot,接受一个 torch.Tensor 类型的 x 参数和一个 torch.Tensor 类型的 v_bins 参数,返回一个 torch.Tensor
def one_hot(x: torch.Tensor, v_bins: torch.Tensor) -> torch.Tensor:
    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
    diffs = x[..., None] - reshaped_bins
    am = torch.argmin(torch.abs(diffs), dim=-1)
    return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
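
one_hot 把每个数值分配到最近的 bin,并返回对应的 one-hot 编码(演示代码,非源码):

import torch
from transformers.models.esm.openfold_utils.tensor_utils import one_hot

v_bins = torch.tensor([0.0, 1.0, 2.0])
x = torch.tensor([0.1, 1.6])
print(one_hot(x, v_bins))
# tensor([[1., 0., 0.],
#         [0., 0., 1.]])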

# 定义一个函数 batched_gather,接受一个 torch.Tensor 类型的 data 参数,一个 torch.Tensor 类型的 inds 参数,以及两个可选参数 dim 和 no_batch_dims,都是整数类型,默认值分别为 0 和 0,返回一个 torch.Tensor
def batched_gather(data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0) -> torch.Tensor:
    ranges: List[Union[slice, torch.Tensor]] = []
    # 遍历数据的形状的前几个维度(不包括批量维度)
    for i, s in enumerate(data.shape[:no_batch_dims]):
        # 创建一个包含从0到s-1的整数的张量
        r = torch.arange(s)
        # 根据当前维度i的索引,以及数据索引的形状,重新视图化张量r
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        # 将r添加到ranges列表中
        ranges.append(r)

    # 创建一个包含slice或者张量的列表,用于处理剩余的维度
    remaining_dims: List[Union[slice, torch.Tensor]] = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
    # 将inds插入到对应的维度位置中,处理维度偏移问题
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    # 将剩余的维度信息添加到ranges列表中
    ranges.extend(remaining_dims)

    # 返回根据ranges索引得到的数据
    # Matt 注意:修改此处以避免在最近的Numpy版本中使用列表作为数组索引的行为变化
    return data[tuple(ranges)]
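
batched_gather 在保留前 no_batch_dims 个批量维度的同时,沿指定维度取值,每个 batch 使用各自的索引(演示代码,非源码):

import torch
from transformers.models.esm.openfold_utils.tensor_utils import batched_gather

data = torch.arange(12).reshape(2, 3, 2)    # [batch=2, 3, 2]
inds = torch.tensor([[0, 2], [1, 1]])       # 每个 batch 各取两行
out = batched_gather(data, inds, dim=1, no_batch_dims=1)
print(out.shape)                            # torch.Size([2, 2, 2])
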
# 使用 TypeVar 创建一个泛型变量 T,用于表示函数的参数类型
T = TypeVar("T")

# 定义 dict_map 函数,用于对字典及其嵌套结构中的各种类型进行映射操作
def dict_map(
    fn: Callable[[T], Any], dic: Dict[Any, Union[dict, list, tuple, T]], leaf_type: Type[T]
) -> Dict[Any, Union[dict, list, tuple, Any]]:
    # 创建一个新的空字典,用于存储映射后的结果
    new_dict: Dict[Any, Union[dict, list, tuple, Any]] = {}
    # 遍历输入的字典 dic 的键值对
    for k, v in dic.items():
        # 如果值 v 是字典类型,则递归调用 dict_map 对其进行映射
        if isinstance(v, dict):
            new_dict[k] = dict_map(fn, v, leaf_type)
        # 否则,调用 tree_map 函数对 v 进行映射(tree_map 函数在后面定义)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)

    return new_dict


# tree_map 函数的重载定义:处理输入 tree 为单个元素的情况
@overload
def tree_map(fn: Callable[[T], Any], tree: T, leaf_type: Type[T]) -> Any:
    ...


# tree_map 函数的重载定义:处理输入 tree 为字典的情况
@overload
def tree_map(fn: Callable[[T], Any], tree: dict, leaf_type: Type[T]) -> dict:
    ...


# tree_map 函数的重载定义:处理输入 tree 为列表的情况
@overload
def tree_map(fn: Callable[[T], Any], tree: list, leaf_type: Type[T]) -> list:
    ...


# tree_map 函数的重载定义:处理输入 tree 为元组的情况
@overload
def tree_map(fn: Callable[[T], Any], tree: tuple, leaf_type: Type[T]) -> tuple:
    ...


# 定义 tree_map 函数,用于对树状数据结构 tree 进行映射操作
def tree_map(fn, tree, leaf_type):
    # 如果 tree 是字典类型,则调用 dict_map 对其进行映射
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    # 如果 tree 是列表类型,则递归调用 tree_map 对其内部每个元素进行映射
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    # 如果 tree 是元组类型,则递归调用 tree_map 对其内部每个元素进行映射,并返回元组
    elif isinstance(tree, tuple):
        return tuple(tree_map(fn, x, leaf_type) for x in tree)
    # 如果 tree 是 leaf_type 类型,则直接调用 fn 对其进行映射
    elif isinstance(tree, leaf_type):
        return fn(tree)
    # 如果 tree 不属于以上任何类型,则抛出 ValueError 异常
    else:
        print(type(tree))
        raise ValueError("Not supported")


# 使用 partial 函数创建 tensor_tree_map 函数,固定 leaf_type 参数为 torch.Tensor 类型
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
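
tree_map / tensor_tree_map 可以把一个张量函数递归地作用到嵌套的 dict/list/tuple 结构上(演示代码,非源码):

import torch
from transformers.models.esm.openfold_utils.tensor_utils import tensor_tree_map

tree = {"a": torch.ones(2), "b": [torch.zeros(3), {"c": torch.ones(1)}]}
doubled = tensor_tree_map(lambda t: t * 2, tree)
print(doubled["a"], doubled["b"][1]["c"])    # tensor([2., 2.]) tensor([2.])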

.\models\esm\openfold_utils\__init__.py

# 导入自定义模块中的函数和类

from .chunk_utils import chunk_layer
# 导入 chunk_layer 函数,用于处理数据的分块操作

from .data_transforms import make_atom14_masks
# 导入 make_atom14_masks 函数,用于生成 atom14 掩码

from .feats import atom14_to_atom37, frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames
# 导入三个函数:atom14_to_atom37、frames_and_literature_positions_to_atom14_pos、torsion_angles_to_frames,
# 用于特征转换和处理

from .loss import compute_predicted_aligned_error, compute_tm
# 导入 compute_predicted_aligned_error 和 compute_tm 函数,用于计算损失

from .protein import Protein as OFProtein
# 导入 Protein 类,并将其命名为 OFProtein,用于处理蛋白质数据

from .protein import to_pdb
# 导入 to_pdb 函数,用于将蛋白质数据输出为 PDB 文件

from .rigid_utils import Rigid, Rotation
# 导入 Rigid 和 Rotation 类,用于刚体和旋转操作的工具函数

from .tensor_utils import dict_multimap, flatten_final_dims, permute_final_dims
# 导入 dict_multimap、flatten_final_dims 和 permute_final_dims 函数,
# 用于处理张量的映射、维度展平和维度置换操作

.\models\esm\tokenization_esm.py

# coding=utf-8
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for ESM."""
import os
from typing import List, Optional

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging

# 获取名为 logging 的日志记录器
logger = logging.get_logger(__name__)

# 定义词汇文件的名称
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# 定义预训练模型对应的词汇文件映射
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/esm2_t6_8M_UR50D": "https://huggingface.co/facebook/esm2_t6_8M_UR50D/resolve/main/vocab.txt",
        "facebook/esm2_t12_35M_UR50D": "https://huggingface.co/facebook/esm2_t12_35M_UR50D/resolve/main/vocab.txt",
    },
}

# 定义预训练模型对应的位置嵌入大小映射
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/esm2_t6_8M_UR50D": 1024,
    "facebook/esm2_t12_35M_UR50D": 1024,
}


def load_vocab_file(vocab_file):
    # 打开给定路径的词汇文件,并将内容按行读取为列表
    with open(vocab_file, "r") as f:
        lines = f.read().splitlines()
        return [l.strip() for l in lines]


class EsmTokenizer(PreTrainedTokenizer):
    """
    Constructs an ESM tokenizer.
    """

    # 设置类属性:词汇文件名
    vocab_files_names = VOCAB_FILES_NAMES
    # 设置类属性:预训练模型对应的词汇文件映射
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # 设置类属性:预训练模型对应的位置嵌入大小映射
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # 设置类属性:模型输入名称列表
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        **kwargs,
    ):
        # 加载词汇文件中的所有词汇,并构建词汇表
        self.all_tokens = load_vocab_file(vocab_file)
        self._id_to_token = dict(enumerate(self.all_tokens))
        self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
        super().__init__(
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            **kwargs,
        )

        # TODO, all the tokens are added? But they are also part of the vocab... bit strange.
        # none of them are special, but they all need special splitting.

        # 将所有词汇加入到不需要拆分的特殊标记列表中
        self.unique_no_split_tokens = self.all_tokens
        # 更新基于特殊标记列表的 Trie 数据结构
        self._update_trie(self.unique_no_split_tokens)

    def _convert_id_to_token(self, index: int) -> str:
        # 根据索引将其转换为对应的词汇,若索引不存在则返回未知标记
        return self._id_to_token.get(index, self.unk_token)

    def _convert_token_to_id(self, token: str) -> int:
        # 根据词汇将其转换为对应的索引,若词汇不存在则返回未知标记的索引
        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
    # 将输入文本按空格分割,返回分割后的列表作为结果
    def _tokenize(self, text, **kwargs):
        return text.split()

    # 返回包含基础词汇的字典,包括_token_to_id和added_tokens_encoder的合并
    def get_vocab(self):
        base_vocab = self._token_to_id.copy()
        base_vocab.update(self.added_tokens_encoder)
        return base_vocab

    # 根据给定的token返回其对应的id,如果token不存在则返回unk_token对应的id
    def token_to_id(self, token: str) -> int:
        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))

    # 根据给定的index返回对应的token,如果index不存在则返回unk_token
    def id_to_token(self, index: int) -> str:
        return self._id_to_token.get(index, self.unk_token)

    # 构建包含特殊token的输入列表,处理单个或两个序列的情况
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        cls = [self.cls_token_id]
        sep = [self.eos_token_id]  # ESM词汇表中没有sep token
        if token_ids_1 is None:
            if self.eos_token_id is None:
                return cls + token_ids_0
            else:
                return cls + token_ids_0 + sep
        elif self.eos_token_id is None:
            raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
        return cls + token_ids_0 + sep + token_ids_1 + sep  # 多个输入始终有一个EOS token

    # 获取不包含特殊token的token列表的特殊token掩码
    def get_special_tokens_mask(
        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        检索没有添加特殊token的token列表的序列id。当使用tokenizer的`prepare_for_model`或`encode_plus`方法添加特殊token时调用此方法。
        
        Args:
            token_ids_0 (`List[int]`):
                第一个序列的id列表。
            token_ids_1 (`List[int]`, *可选*):
                第二个序列的id列表。
            already_has_special_tokens (`bool`, *可选*, 默认为 `False`):
                token列表是否已经格式化包含了模型的特殊token。

        Returns:
            一个整数列表,范围为[0, 1]:1表示特殊token,0表示序列token。
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )

            return [1 if token in self.all_special_ids else 0 for token in token_ids_0]
        
        # 创建一个mask列表,标识特殊token的位置
        mask = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            mask += [0] * len(token_ids_1) + [1]
        return mask

    # 将词汇表保存到指定目录下的文件中,文件名由filename_prefix和vocab.txt组成
    def save_vocabulary(self, save_directory, filename_prefix):
        vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt")
        with open(vocab_file, "w") as f:
            f.write("\n".join(self.all_tokens))
        return (vocab_file,)

    # 返回词汇表的大小,即all_tokens的长度
    @property
    def vocab_size(self) -> int:
        return len(self.all_tokens)
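
EsmTokenizer 的典型用法如下(演示代码,非源码;需要联网下载词表,模型名取自上面的映射表,输出仅作示意):

from transformers import EsmTokenizer

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
enc = tokenizer("MKTAYIAK")
print(tokenizer.convert_ids_to_tokens(enc["input_ids"]))
# ['<cls>', 'M', 'K', 'T', 'A', 'Y', 'I', 'A', 'K', '<eos>']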

.\models\esm\__init__.py

# 版权声明和许可证信息
#
# 版权所有 2022 年 Facebook 和 HuggingFace 团队。保留所有权利。
# 
# 根据 Apache 许可证 2.0 版本(“许可证”)许可;
# 除非符合许可证的规定,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则按“原样”分发软件,
# 没有任何明示或暗示的保证或条件。
# 请查阅许可证获取具体语言的权限和限制。
from typing import TYPE_CHECKING

# 从相对路径导入必要的模块和类
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available

# 定义模块的导入结构
_import_structure = {
    "configuration_esm": ["ESM_PRETRAINED_CONFIG_ARCHIVE_MAP", "EsmConfig"],
    "tokenization_esm": ["EsmTokenizer"],
}

# 检查是否 Torch 可用,否则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果 Torch 可用,导入 Torch 版本的 ESM 模块
    _import_structure["modeling_esm"] = [
        "ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
        "EsmForMaskedLM",
        "EsmForSequenceClassification",
        "EsmForTokenClassification",
        "EsmModel",
        "EsmPreTrainedModel",
    ]
    _import_structure["modeling_esmfold"] = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"]

# 检查是否 TensorFlow 可用,否则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果 TensorFlow 可用,导入 TensorFlow 版本的 ESM 模块
    _import_structure["modeling_tf_esm"] = [
        "TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFEsmForMaskedLM",
        "TFEsmForSequenceClassification",
        "TFEsmForTokenClassification",
        "TFEsmModel",
        "TFEsmPreTrainedModel",
    ]

# 如果是类型检查阶段,导入必要的类型
if TYPE_CHECKING:
    from .configuration_esm import ESM_PRETRAINED_CONFIG_ARCHIVE_MAP, EsmConfig
    from .tokenization_esm import EsmTokenizer

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入 Torch 版本的 ESM 模块(仅用于类型检查)
        from .modeling_esm import (
            ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
            EsmForMaskedLM,
            EsmForSequenceClassification,
            EsmForTokenClassification,
            EsmModel,
            EsmPreTrainedModel,
        )
        from .modeling_esmfold import EsmFoldPreTrainedModel, EsmForProteinFolding

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入 TensorFlow 版本的 ESM 模块(仅用于类型检查)
        from .modeling_tf_esm import (
            TF_ESM_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFEsmForMaskedLM,
            TFEsmForSequenceClassification,
            TFEsmForTokenClassification,
            TFEsmModel,
            TFEsmPreTrainedModel,
        )

# 如果不是类型检查阶段,将当前模块设为懒加载模块
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)

.\models\falcon\configuration_falcon.py

# coding=utf-8
# Copyright 2023 the Falcon authors and HuggingFace Inc. team.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Falcon configuration
"""
# 从 Transformers 库导入基类 PretrainedConfig
from ...configuration_utils import PretrainedConfig
# 从 Transformers 库导入日志工具
from ...utils import logging

# 获取 logger 对象,用于记录日志
logger = logging.get_logger(__name__)

# Falcon 模型预训练配置文件的映射字典,指定了模型名称与配置文件的 URL 地址
FALCON_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "tiiuae/falcon-40b": "https://huggingface.co/tiiuae/falcon-40b/resolve/main/config.json",
    "tiiuae/falcon-7b": "https://huggingface.co/tiiuae/falcon-7b/resolve/main/config.json",
}

# FalconConfig 类,继承自 PretrainedConfig 类,用于存储 Falcon 模型的配置信息
class FalconConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`FalconModel`]. It is used to instantiate a Falcon
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the
    [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Example:

    ```
    >>> from transformers import FalconModel, FalconConfig

    >>> # Initializing a small (2-layer) Falcon configuration
    >>> configuration = FalconConfig(num_hidden_layers=2)

    >>> # Initializing a model from the small configuration
    >>> model = FalconModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    # 模型类型标识为 "falcon"
    model_type = "falcon"
    # 推断时忽略的关键字列表
    keys_to_ignore_at_inference = ["past_key_values"]

    # FalconConfig 类的初始化方法,定义了模型的各种配置参数
    def __init__(
        self,
        vocab_size=65024,
        hidden_size=4544,
        num_hidden_layers=32,
        num_attention_heads=71,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        num_kv_heads=None,
        alibi=False,
        new_decoder_architecture=False,
        multi_query=True,
        parallel_attn=True,
        bias=False,
        max_position_embeddings=2048,
        rope_theta=10000.0,
        rope_scaling=None,
        bos_token_id=11,
        eos_token_id=11,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        # Backward compatibility with n_embed kwarg
        n_embed = kwargs.pop("n_embed", None)
        # 设置隐藏层大小,如果未指定 n_embed 则使用 hidden_size
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        # 如果 num_kv_heads 未指定,则使用 num_attention_heads
        self.num_kv_heads = num_attention_heads if num_kv_heads is None else num_kv_heads
        self.alibi = alibi
        self.new_decoder_architecture = new_decoder_architecture
        # 当 new_decoder_architecture 为 True 时,忽略 multi_query
        self.multi_query = multi_query  # Ignored when new_decoder_architecture is True
        self.parallel_attn = parallel_attn
        self.bias = bias
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # 运行 _rope_scaling_validation 方法验证 rope_scaling 的设置
        self._rope_scaling_validation()

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

    @property
    def head_dim(self):
        # 返回每个注意力头的维度
        return self.hidden_size // self.num_attention_heads

    @property
    def rotary(self):
        # 如果 alibi 为 False,则返回 True,表示支持旋转注意力
        return not self.alibi

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.
        """
        if self.rope_scaling is None:
            return

        if self.alibi:
            # 当 alibi 为 True 时,不支持 rope_scaling,抛出异常
            raise ValueError("`rope_scaling` is not supported when `alibi` is `True`.")

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            # rope_scaling 必须是包含 type 和 factor 两个字段的字典
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            # type 字段必须是 ['linear', 'dynamic'] 中的一个
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            # factor 字段必须是大于 1 的浮点数
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
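
rope_scaling 的合法取值可以用一个小配置演示(演示代码,非源码):

from transformers import FalconConfig

config = FalconConfig(
    num_hidden_layers=2,
    rope_scaling={"type": "linear", "factor": 2.0},  # type 只能是 "linear" 或 "dynamic",factor 必须是 > 1 的浮点数
)
print(config.head_dim, config.rotary)    # 64 True(默认 hidden_size=4544、num_attention_heads=71,且 alibi=False)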

.\models\falcon\convert_custom_code_checkpoint.py

# 导入所需的模块
import json  # 导入用于处理 JSON 格式的模块
from argparse import ArgumentParser  # 导入用于解析命令行参数的模块中的 ArgumentParser 类
from pathlib import Path  # 导入用于处理文件路径的模块中的 Path 类

"""
This script converts Falcon custom code checkpoints to modern Falcon checkpoints that use code in the Transformers
library. After conversion, performance (especially for generation) should improve and the checkpoint can be loaded
without needing trust_remote_code=True.
"""

# 如果当前脚本作为主程序运行
if __name__ == "__main__":
    # 创建参数解析器
    parser = ArgumentParser()
    # 添加命令行参数:--checkpoint_dir,类型为路径,必需参数,用于指定包含自定义代码检查点的目录
    parser.add_argument(
        "--checkpoint_dir",
        type=Path,
        required=True,
        help="Directory containing a custom code checkpoint to convert to a modern Falcon checkpoint.",
    )
    # 解析命令行参数
    args = parser.parse_args()

    # 检查指定的目录是否存在
    if not args.checkpoint_dir.is_dir():
        # 如果不存在,抛出数值错误异常
        raise ValueError("--checkpoint_dir argument should be a directory!")

    # 检查模型目录是否包含 configuration_RW.py 和 modelling_RW.py 文件
    if (
        not (args.checkpoint_dir / "configuration_RW.py").is_file()
        or not (args.checkpoint_dir / "modelling_RW.py").is_file()
    ):
        # 如果不包含,抛出数值错误异常
        raise ValueError(
            "The model directory should contain configuration_RW.py and modelling_RW.py files! Are you sure this is a custom code checkpoint?"
        )
    
    # 删除模型目录下的 configuration_RW.py 和 modelling_RW.py 文件
    (args.checkpoint_dir / "configuration_RW.py").unlink()
    (args.checkpoint_dir / "modelling_RW.py").unlink()

    # 读取并修改配置文件 config.json
    config = args.checkpoint_dir / "config.json"
    text = config.read_text()
    # 替换 JSON 文本中的特定字符串
    text = text.replace("RWForCausalLM", "FalconForCausalLM")
    text = text.replace("RefinedWebModel", "falcon")
    text = text.replace("RefinedWeb", "falcon")
    # 解析 JSON 文本为 Python 字典
    json_config = json.loads(text)
    # 删除字典中的 auto_map 键值对
    del json_config["auto_map"]

    # 根据键名替换字典中的键值对
    if "n_head" in json_config:
        json_config["num_attention_heads"] = json_config.pop("n_head")
    if "n_layer" in json_config:
        json_config["num_hidden_layers"] = json_config.pop("n_layer")
    if "n_head_kv" in json_config:
        json_config["num_kv_heads"] = json_config.pop("n_head_kv")
        json_config["new_decoder_architecture"] = True
    else:
        json_config["new_decoder_architecture"] = False
    
    # 获取字典中的 bos_token_id 和 eos_token_id,如果不存在默认为 1 和 2
    bos_token_id = json_config.get("bos_token_id", 1)
    eos_token_id = json_config.get("eos_token_id", 2)
    
    # 删除并重新写入修改后的配置文件 config.json
    config.unlink()
    config.write_text(json.dumps(json_config, indent=2, sort_keys=True))

    # 处理 tokenizer_config.json 文件
    tokenizer_config = args.checkpoint_dir / "tokenizer_config.json"
    if tokenizer_config.is_file():
        text = tokenizer_config.read_text()
        json_config = json.loads(text)
        # 如果 tokenizer_class 是 PreTrainedTokenizerFast,则修改 model_input_names
        if json_config["tokenizer_class"] == "PreTrainedTokenizerFast":
            json_config["model_input_names"] = ["input_ids", "attention_mask"]
            # 删除并重新写入修改后的 tokenizer_config.json
            tokenizer_config.unlink()
            tokenizer_config.write_text(json.dumps(json_config, indent=2, sort_keys=True))

    # 处理 generation_config.json 文件
    generation_config_path = args.checkpoint_dir / "generation_config.json"
    # 创建要写入的字典
    generation_dict = {
        "_from_model_config": True,
        "bos_token_id": bos_token_id,
        "eos_token_id": eos_token_id,
        "transformers_version": "4.33.0.dev0",
    }
    # 将生成配置写入 generation_config.json
    generation_config_path.write_text(json.dumps(generation_dict, indent=2, sort_keys=True))
    # 打印消息到标准输出,提示操作完成并建议用户验证新的检查点是否符合预期。
    print("Done! Please double-check that the new checkpoint works as expected.")

.\models\falcon\modeling_falcon.py

# 指定编码格式为 UTF-8

# 版权声明和许可协议,此处声明代码版权归 Falcon 作者及 HuggingFace Inc. 团队所有,保留所有权利
# 根据 Apache 许可证 2.0 版本发布,除非符合许可协议,否则不得使用此文件
# 可以从以下网址获取许可协议的副本:http://www.apache.org/licenses/LICENSE-2.0
# 如果适用法律要求或书面同意,软件将按“原样”分发,没有任何形式的担保或条件
"""PyTorch Falcon model."""

import math  # 导入 math 库提供数学函数
import warnings  # 导入 warnings 库用于警告处理
from typing import TYPE_CHECKING, Optional, Tuple, Union  # 导入类型提示相关模块

import torch  # 导入 PyTorch 库
import torch.utils.checkpoint  # 导入 PyTorch 的 checkpoint 功能
from torch import nn  # 导入 PyTorch 的神经网络模块
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss  # 导入损失函数
from torch.nn import functional as F  # 导入 PyTorch 的函数式接口,并重命名为 F

# 导入工具函数和类
from ...modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel  # 导入预训练模型基类
from ...pytorch_utils import is_torch_greater_or_equal_than_2_0  # 导入 PyTorch 版本判断工具函数
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
)

# 导入 FalconConfig 配置类
from .configuration_falcon import FalconConfig


if TYPE_CHECKING:
    from ...configuration_utils import PretrainedConfig

# 如果 flash_attn 2.x 可用,则导入相应的函数和模块
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

# 获取 logger 实例
logger = logging.get_logger(__name__)

# Falcon 模型的预训练模型存档列表
FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "tiiuae/falcon-40b",
    "tiiuae/falcon-40b-instruct",
    "tiiuae/falcon-7b",
    "tiiuae/falcon-7b-instruct",
    "tiiuae/falcon-rw-7b",
    "tiiuae/falcon-rw-1b",
]
# 用于文档的检查点名
_CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
# 用于文档的配置名
_CONFIG_FOR_DOC = "FalconConfig"


# 注意:在训练期间,我们未融合矩阵乘法和偏置项,这意味着操作之间需要一个额外的量化步骤到 bfloat16。
# 为了不降低 HF 代码的质量,我们在最终模型中保留了这些特征。
class FalconLinear(nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # 执行线性变换
        hidden_states = input @ self.weight.T
        # 如果没有偏置项,则直接返回变换后的结果
        if self.bias is None:
            return hidden_states
        # 否则,将偏置项加到变换后的结果中并返回
        return hidden_states + self.bias


# 从 transformers.models.llama.modeling_llama 中复制的函数,用于旋转输入张量一半隐藏维度的值
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    # 沿最后一维取出后半部分元素(前半部分已存入 x1),其他维度保持不变
    x2 = x[..., x.shape[-1] // 2 :]
    # 使用 torch.cat 函数沿着最后一个维度将 x2 和 x1 张量连接起来
    return torch.cat((-x2, x1), dim=-1)
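
rotate_half 的效果可以用一个具体向量看出来(演示代码,非源码):

import torch
from transformers.models.falcon.modeling_falcon import rotate_half

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(rotate_half(x))    # tensor([[-3., -4.,  1.,  2.]]):后半部分取负并移到前面
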
# 从transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb复制而来
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """对查询张量和键张量应用旋转位置嵌入。

    Args:
        q (`torch.Tensor`): 查询张量。
        k (`torch.Tensor`): 键张量。
        cos (`torch.Tensor`): 旋转嵌入的余弦部分。
        sin (`torch.Tensor`): 旋转嵌入的正弦部分。
        position_ids (`torch.Tensor`):
            对应于查询和键张量的标记位置索引。例如,当使用KV缓存时,可以传递偏移的位置ID。
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            'unsqueeze_dim' 参数指定沿其进行展开的维度,以便将 cos[position_ids] 和 sin[position_ids] 正确广播到 q 和 k 的维度。
            例如,注意 cos[position_ids] 和 sin[position_ids] 的形状为 [batch_size, seq_len, head_dim]。然后,
            如果 q 和 k 的形状为 [batch_size, heads, seq_len, head_dim],设置 unsqueeze_dim=1 使得 cos[position_ids] 和 sin[position_ids]
            可以广播到 q 和 k 的形状。类似地,如果 q 和 k 的形状为 [batch_size, seq_len, heads, head_dim],则设置 unsqueeze_dim=2。

    Returns:
        `tuple(torch.Tensor)`: 包含使用旋转位置嵌入旋转后的查询和键张量。
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# 从transformers.models.llama.modeling_llama._get_unpad_data复制而来
def _get_unpad_data(attention_mask):
    """获取未填充数据。

    Args:
        attention_mask (`torch.Tensor`): 注意力掩码张量。

    Returns:
        `tuple`: 包含以下三个元素的元组:
            - `torch.Tensor`: 指示非填充位置索引的张量。
            - `torch.Tensor`: 指示累积序列长度的张量,用于填充。
            - `int`: 批次中最大序列长度。
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
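
# ——— 注释者补充的最小演示(非原始源码;假设上面定义的 _get_unpad_data 在作用域内,
#     且文件顶部已按原始源码导入 torch.nn.functional as F)———
# 对一个 batch_size=2 的注意力掩码(第一行末尾有 1 个填充位,第二行无填充),
# 函数返回:非填充位置的扁平索引、各序列长度的前缀和 cu_seqlens,以及批内最大序列长度。
import torch

_demo_mask = torch.tensor([[1, 1, 1, 0],
                           [1, 1, 1, 1]])
_indices, _cu_seqlens, _max_len = _get_unpad_data(_demo_mask)
print(_indices)      # tensor([0, 1, 2, 4, 5, 6, 7]) —— 跳过了第一行的填充位(扁平索引 3)
print(_cu_seqlens)   # tensor([0, 3, 7], dtype=torch.int32)
print(_max_len)      # 4
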


# 从transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding复制而来,更改为FalconRotaryEmbedding
class FalconRotaryEmbedding(nn.Module):
    # 初始化函数,设置模型参数和缓存
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        # 调用父类初始化方法
        super().__init__()

        # 设置模型维度
        self.dim = dim
        # 设置最大位置编码长度,默认为2048
        self.max_position_embeddings = max_position_embeddings
        # 设置基础频率,默认为10000
        self.base = base

        # 计算频率的倒数,用于位置编码
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        # 将频率的倒数注册为模型的缓存,不持久化
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # 调用内部方法设置余弦和正弦缓存,以便 `torch.jit.trace` 方法能正常工作
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    # 内部方法,设置余弦和正弦缓存
    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # 记录当前缓存的最大序列长度
        self.max_seq_len_cached = seq_len
        # 创建序列长度张量 t,并转换为与 inv_freq 相同类型
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        # 计算频率张量与位置张量的外积
        freqs = torch.outer(t, self.inv_freq)
        # 按最后一个维度拼接频率张量,构成位置编码矩阵
        emb = torch.cat((freqs, freqs), dim=-1)
        # 将余弦值缓存注册为模型的缓存,不持久化,并转换为指定数据类型
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        # 将正弦值缓存注册为模型的缓存,不持久化,并转换为指定数据类型
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    # 前向传播方法
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]

        # 如果指定了新的序列长度超过当前缓存的最大序列长度
        if seq_len > self.max_seq_len_cached:
            # 重新设置余弦和正弦缓存
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        # 返回当前缓存中的余弦和正弦值,截取到指定的序列长度,转换为输入张量 x 的数据类型
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
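
# ——— 注释者补充的最小演示(非原始源码;假设上面定义的 FalconRotaryEmbedding 与
#     apply_rotary_pos_emb 均在作用域内)———
# 演示 cos/sin 缓存的形状,以及如何把旋转位置编码应用到查询/键张量上。
import torch

_head_dim, _seq_len = 8, 5
_rope = FalconRotaryEmbedding(_head_dim, max_position_embeddings=16)
_q = torch.randn(1, 2, _seq_len, _head_dim)          # [batch, num_heads, seq_len, head_dim]
_k = torch.randn(1, 2, _seq_len, _head_dim)
_cos, _sin = _rope(_q, seq_len=_seq_len)             # 形状均为 [seq_len, head_dim]
_pos_ids = torch.arange(_seq_len).unsqueeze(0)       # [1, seq_len]
_q_rot, _k_rot = apply_rotary_pos_emb(_q, _k, _cos, _sin, _pos_ids)
print(_cos.shape, _q_rot.shape)                      # torch.Size([5, 8]) torch.Size([1, 2, 5, 8])
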
# 从transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding复制并修改为Falcon
# TODO @joao: 经过静态缓存后不再从LLama复制,修复我(复制 -> Copied)
class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
    """使用线性缩放扩展的FalconRotaryEmbedding。由Reddit用户/u/kaiokendev贡献"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # 创建一个整数序列t,其长度为max_seq_len_cached,使用给定的设备和数据类型
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        # 将整数序列t除以scaling_factor,以得到频率
        t = t / self.scaling_factor

        # 计算频率矩阵,使用torch.outer计算外积
        freqs = torch.outer(t, self.inv_freq)
        # 与论文不同,但使用不同的排列方式以获得相同的计算结果
        emb = torch.cat((freqs, freqs), dim=-1)
        # 将计算得到的cosine和sine值注册为缓冲区,使用给定的数据类型,并标记为非持久化
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# 从transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding复制并修改为Falcon
# TODO @joao: 经过静态缓存后不再从LLama复制,修复我(复制 -> Copied)
class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
    """使用动态NTK缩放扩展的FalconRotaryEmbedding。由Reddit用户/u/bloc97和/u/emozilla贡献"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        # 如果序列长度超过最大位置嵌入数,计算基础频率
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        # 创建一个整数序列t,其长度为max_seq_len_cached,使用给定的设备和数据类型
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        # 计算频率矩阵,使用torch.outer计算外积
        freqs = torch.outer(t, self.inv_freq)
        # 与论文不同,但使用不同的排列方式以获得相同的计算结果
        emb = torch.cat((freqs, freqs), dim=-1)
        # 将计算得到的cosine和sine值注册为缓冲区,使用给定的数据类型,并标记为非持久化
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    # 从注意力掩码构建Alibi张量,指定批次大小和序列长度
    batch_size, seq_length = attention_mask.shape
    # 计算最接近的2的幂次方,小于等于给定数值num_heads的最大整数幂次方
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    # 计算基数,用于生成注意力偏置(attention bias)
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
    )
    # 生成幂次序列,从1开始到最接近的2的幂次方(包含)
    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
    # 计算斜率,即基数的各幂次方
    slopes = torch.pow(base, powers)

    # 如果最接近的2的幂次方不等于num_heads,则需要额外计算
    if closest_power_of_2 != num_heads:
        # 计算额外基数
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
        )
        # 计算剩余头数
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        # 计算额外幂次序列,步长为2
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
        # 将额外的斜率拼接到原始斜率序列中
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    # 创建alibi张量,用于相对位置偏置(relative position bias),其形状为(batch_size, num_heads, query_length, key_length)
    # 在此设置为(batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length),query_length维度将会正确广播
    # 这与T5模型中的相对位置偏置基本相同,参见:https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None].bfloat16() * arange_tensor
    # 重新形状化alibi张量,并转换为指定的dtype
    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
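
# ——— 注释者补充的最小演示(非原始源码;假设上面定义的 build_alibi_tensor 在作用域内)———
# num_heads=4(恰为 2 的幂)时,斜率依次为 1/4、1/16、1/64、1/256;
# 返回张量形状为 [batch_size * num_heads, 1, seq_length],填充位置的偏置为 0。
import torch

_demo_mask = torch.tensor([[1, 1, 1, 0]])            # batch_size=1, seq_length=4,最后一位是填充
_alibi = build_alibi_tensor(_demo_mask, num_heads=4, dtype=torch.float32)
print(_alibi.shape)    # torch.Size([4, 1, 4])
print(_alibi[0, 0])    # tensor([0.0000, 0.2500, 0.5000, 0.0000]) —— 第一个头,斜率 0.25
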
# 从transformers.models.bloom.modeling_bloom.dropout_add复制而来,定义了一个dropout add函数
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    """
    Dropout add function

    Args:
        x (`torch.tensor`, *required*):
            input tensor 输入张量
        residual (`torch.tensor`, *required*):
            residual tensor 剩余张量
        prob (`float`, *required*):
            dropout probability dropout概率
        training (`bool`, *required*):
            training mode 训练模式
    """
    # 对输入张量x应用dropout操作,根据training参数决定是否使用
    out = F.dropout(x, p=prob, training=training)
    # 将dropout后的结果与剩余张量residual相加
    out = residual + out
    # 返回结果张量out
    return out
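
# ——— 注释者补充的最小演示(非原始源码;假设上面定义的 dropout_add 在作用域内)———
# 推理模式(training=False)下 dropout 不生效,dropout_add 退化为简单的残差相加。
import torch

_demo_x = torch.ones(2, 3)
_demo_residual = torch.full((2, 3), 10.0)
print(dropout_add(_demo_x, _demo_residual, prob=0.5, training=False))   # 全部为 11.0
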


class FalconAttention(nn.Module):
    def __init__(self, config: FalconConfig):
        super().__init__()

        # 初始化FalconAttention类的配置参数
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.split_size = self.hidden_size
        self.hidden_dropout = config.hidden_dropout
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self._use_sdpa = config._attn_implementation == "sdpa"

        # 检查hidden_size是否能被num_heads整除,若不能则抛出错误
        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
                f" {self.num_heads})."
            )

        # 如果配置为使用rotary,则初始化rope
        if config.rotary:
            self._init_rope()

        # Layer-wise attention scaling,初始化注意力权重的缩放因子
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.beta = self.inv_norm_factor
        
        # 根据配置选择不同的输出维度qkv_out_dim
        if config.new_decoder_architecture:
            qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
        elif config.multi_query:
            qkv_out_dim = self.hidden_size + 2 * self.head_dim
        else:
            qkv_out_dim = 3 * self.hidden_size
        
        # 初始化query_key_value线性层,用于计算查询、键、值
        self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
        self.new_decoder_architecture = config.new_decoder_architecture
        self.multi_query = config.multi_query
        
        # 初始化输出的线性层dense
        self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
        
        # 初始化注意力的dropout层
        self.attention_dropout = nn.Dropout(config.attention_dropout)
        
        # 根据配置初始化num_kv_heads
        self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1

    # 从transformers.models.llama.modeling_llama.LlamaAttention._init_rope复制而来,用于初始化rope
    # 初始化 RoPE(Rotary Positional Embedding),根据配置设置不同的缩放方式
    def _init_rope(self):
        # 如果配置中没有指定 RoPE 缩放方式,则使用默认的 FalconRotaryEmbedding
        if self.config.rope_scaling is None:
            self.rotary_emb = FalconRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            # 否则根据配置选择不同的 RoPE 缩放类型
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                # 使用线性缩放方式的 RoPE
                self.rotary_emb = FalconLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                # 使用动态 NTK 缩放方式的 RoPE
                self.rotary_emb = FalconDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                # 如果缩放类型未知,则抛出 ValueError 异常
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    # 将融合后的查询/键/值张量拆分为多个头部,根据模型的不同架构选择不同的拆分方式
    def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`

        Args:
            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]

        Returns:
            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
            value: [batch_size, seq_length, num_heads, head_dim]
        """
        if self.new_decoder_architecture:
            # 如果是新的解码器架构,则按照指定方式拆分
            batch, seq_len, _ = fused_qkv.shape
            qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
            query = qkv[:, :, :, :-2]
            key = qkv[:, :, :, [-2]]
            value = qkv[:, :, :, [-1]]
            key = torch.broadcast_to(key, query.shape)
            value = torch.broadcast_to(value, query.shape)

            query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
            return query, key, value
        elif not self.multi_query:
            # 如果不是多查询模式,则按照普通的拆分方式
            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
            return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
        else:
            # 否则按照另一种特定的拆分方式处理
            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
            return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]

    # 从 transformers 库中复制的函数,用于合并注意力机制中的多个头部
    # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Merge heads together over the last dimension

        Args:
            x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]

        Returns:
            torch.tensor: [batch_size, seq_length, num_heads * head_dim]
        """
        # 获取输入张量的形状信息:batch_size * num_heads, seq_length, head_dim
        batch_size_and_num_heads, seq_length, _ = x.shape
        # 计算真实的 batch_size,即 batch_size * num_heads 的商,表示真正的 batch 数量
        batch_size = batch_size_and_num_heads // self.num_heads

        # 将输入张量重新视图化以分解批次大小
        # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
        x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)

        # 通过维度置换将头部维度与序列长度维度交换位置
        # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
        x = x.permute(0, 2, 1, 3)

        # 将 num_heads 和 head_dim 合并到一个维度中
        # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
        return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
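
# ——— 注释者补充的最小演示(非原始源码;假设 FalconAttention 与文件顶部导入的 FalconConfig
#     均可直接使用,且所用 transformers 版本中配置默认的 _attn_implementation 为 "eager")———
# 说明三种架构下 query_key_value 输出维度 qkv_out_dim 的差异,并验证 multi_query
# 模式下 _split_heads 的输出形状:查询有 num_heads 个头,键/值各只有 1 个头。
import torch

_demo_config = FalconConfig(
    hidden_size=64, num_attention_heads=4, num_hidden_layers=1,
    multi_query=True, new_decoder_architecture=False,
)
_demo_attn = FalconAttention(_demo_config)
# multi_query:qkv_out_dim = hidden_size + 2 * head_dim = 64 + 2 * 16 = 96
# (对比:普通多头为 3 * hidden_size = 192;新解码器架构为 (num_kv_heads * 2 + num_heads) * head_dim)
print(_demo_attn.query_key_value.weight.shape)             # torch.Size([96, 64])

_demo_fused = _demo_attn.query_key_value(torch.randn(2, 5, 64))   # [batch, seq_len, 96]
_demo_q, _demo_k, _demo_v = _demo_attn._split_heads(_demo_fused)
print(_demo_q.shape, _demo_k.shape, _demo_v.shape)
# torch.Size([2, 5, 4, 16]) torch.Size([2, 5, 1, 16]) torch.Size([2, 5, 1, 16])
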
# 定义一个名为 FalconFlashAttention2 的类,继承自 FalconAttention 类。
class FalconFlashAttention2(FalconAttention):
    """
    Falcon flash attention 模块,其权重未被修改。唯一需要更改的是前向传播,在这里需要正确调用
    flash attention 的公共 API,并处理输入中可能存在的填充标记。
    """

    # 从 transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ 复制而来
    # FalconFlashAttention2 类的构造函数,调用父类构造函数初始化。
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: 一旦 RoCm 的 Flash Attention 升级到 2.1 版本,此处应该移除。
        # flash_attn<2.1 生成左上角对齐的因果掩码,而此处需要的是右下角对齐,默认 flash_attn>=2.1 才支持此特性。该属性用于处理这两个版本之间的差异。
        # 参考:https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0
        # 注意,在 flash_attn<2.1 版本中,除非 q_seqlen == 1,否则使用 q_seqlen != k_seqlen 会产生错误的掩码(左上角对齐)。
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    # FalconFlashAttention2 类的私有方法 _flash_attention_forward,执行 flash attention 的前向传播。
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        # Determine if causal masking is required based on the current configuration
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # Temporary condition until Flash Attention for RoCm is updated
            causal = self.is_causal and query_length != 1

        # Check if there are padding tokens in the input sequences
        if attention_mask is not None:
            # Get the batch size
            batch_size = query_states.shape[0]
            # Unpad the input sequences based on the attention mask
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            # Extract sequence lengths after unpadding
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            # Extract maximum sequence lengths in the current batch
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            # Compute attention scores for unpad input using variable-length Flash Attention
            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            # Pad the attention scores to match the original input length
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            # Compute attention scores using standard Flash Attention
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        # Return the computed attention scores
        return attn_output
    # 在 `_upad_input` 方法中,根据给定的注意力掩码和查询长度处理输入数据
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        # 获取未填充数据的索引、当前序列长度和批次中的最大序列长度
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        
        # 获取批次大小、键值序列长度、头数以及头维度
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        # 通过索引重组键层数据,以适应未填充的序列
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        
        # 通过索引重组值层数据,以适应未填充的序列
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        
        # 根据查询长度处理查询层数据
        if query_length == kv_seq_len:
            # 若查询长度等于键值序列长度,直接重组查询层数据
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # 若查询长度为1,进行相关处理
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # 这里有一个 memcpy,这是非常糟糕的。
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # 处理左填充的情况,使用 -query_length: 切片
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        # 返回处理后的查询层、键层、值层、查询索引、当前序列长度元组和最大序列长度元组
        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
# 定义一个名为 FalconMLP 的神经网络模块
class FalconMLP(nn.Module):
    # 初始化方法,接受一个 FalconConfig 类型的参数 config
    def __init__(self, config: FalconConfig):
        super().__init__()
        # 从配置中获取隐藏层大小
        hidden_size = config.hidden_size

        # 定义全连接层,将隐藏层映射到4倍隐藏层大小,带有可选的偏置
        self.dense_h_to_4h = FalconLinear(hidden_size, 4 * hidden_size, bias=config.bias)
        # 使用 GELU 激活函数
        self.act = nn.GELU()
        # 定义全连接层,将4倍隐藏层大小映射回隐藏层大小,带有可选的偏置
        self.dense_4h_to_h = FalconLinear(4 * hidden_size, hidden_size, bias=config.bias)
        # 获取隐藏层的 dropout 概率
        self.hidden_dropout = config.hidden_dropout

    # 前向传播方法,接受输入张量 x,返回输出张量
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 使用 GELU 激活函数的全连接层操作
        x = self.act(self.dense_h_to_4h(x))
        # 第二个全连接层操作
        x = self.dense_4h_to_h(x)
        return x
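
# ——— 注释者补充的最小演示(非原始源码;假设 FalconMLP 与文件顶部导入的 FalconConfig 均可直接使用)———
# MLP 先把隐藏维度放大 4 倍并经过 GELU,再投影回原维度,因此输入输出形状一致。
import torch

_demo_config = FalconConfig(hidden_size=32, num_attention_heads=4, num_hidden_layers=1)
_demo_mlp = FalconMLP(_demo_config)
print(_demo_mlp.dense_h_to_4h.weight.shape)    # torch.Size([128, 32])
print(_demo_mlp(torch.randn(2, 7, 32)).shape)  # torch.Size([2, 7, 32])
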


# FalconAttention 类的字典映射,根据配置不同选择不同的注意力机制类
FALCON_ATTENTION_CLASSES = {
    "eager": FalconAttention,
    "sdpa": FalconAttention,  # FalconAttention 原始实现同时包含有/无 SDPA 的前向传播
    "flash_attention_2": FalconFlashAttention2,
}


# FalconDecoderLayer 类定义
class FalconDecoderLayer(nn.Module):
    # 初始化方法,接受一个 FalconConfig 类型的参数 config
    def __init__(self, config: FalconConfig):
        super().__init__()
        # 从配置中获取隐藏层大小
        hidden_size = config.hidden_size
        # 获取注意力头的数量
        self.num_heads = config.num_attention_heads

        # 根据配置选择适当的注意力机制类初始化 self_attention 属性
        self.self_attention = FALCON_ATTENTION_CLASSES[config._attn_implementation](config)
        # 初始化 MLP 层
        self.mlp = FalconMLP(config)
        # 获取隐藏层的 dropout 概率
        self.hidden_dropout = config.hidden_dropout
        # 保存配置
        self.config = config

        # 根据配置选择不同的解码器架构
        if config.new_decoder_architecture:
            # 在 self-attention 前进行层归一化
            self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
            # 在 MLP 前进行层归一化
            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        else:
            # 输入层的层归一化
            self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
            # 如果不使用并行注意力,则在注意力后进行层归一化
            if not config.parallel_attn:
                self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

    # 前向传播方法
    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        **kwargs,
    ):
        # 检查是否传入了 "padding_mask" 参数,如果是则发出警告,因为在 v4.37 版本中将移除,请使用 `attention_mask` 替代。
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # 保存输入的隐藏状态,以备后续计算中使用
        residual = hidden_states

        # 根据配置选择不同的层归一化方法
        if self.config.new_decoder_architecture:
            # 使用新的解码器架构,应用注意力层和MLP层的归一化
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            # 使用旧的解码器架构,只应用输入层的归一化
            attention_layernorm_out = self.input_layernorm(hidden_states)

        # 自注意力机制
        attn_outputs = self.self_attention(
            attention_layernorm_out,
            layer_past=layer_past,
            attention_mask=attention_mask,
            position_ids=position_ids,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs,
        )

        attention_output = attn_outputs[0]

        # 如果不使用新的解码器架构,则进行残差连接和Dropout操作
        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                # 并行注意力模式,直接使用注意力层的归一化输出
                mlp_layernorm_out = attention_layernorm_out
            else:
                # 非并行注意力模式,进行残差连接和Dropout操作
                residual = dropout_add(
                    attention_output, residual, self.config.attention_dropout, training=self.training
                )
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        # 提取自注意力机制的输出
        outputs = attn_outputs[1:]

        # MLP层的前向传播
        mlp_output = self.mlp(mlp_layernorm_out)

        # 如果使用新的解码器架构或并行注意力模式,则将自注意力输出与MLP输出相加
        if self.config.new_decoder_architecture or self.config.parallel_attn:
            mlp_output += attention_output

        # 最终的输出,进行Dropout和残差连接
        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)

        # 如果使用缓存,输出中保留 present(键值缓存)与注意力权重;否则丢弃 present,只保留注意力权重
        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        # 返回模型的输出,包括隐藏状态、present以及注意力信息
        return outputs  # hidden_states, present, attentions
# FalconPreTrainedModel 类的文档字符串,描述了该类继承自 PreTrainedModel,介绍了其方法和通用功能。
# 包含了关于如何使用该模型以及其参数配置的信息。
FALCON_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`FalconConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# FalconPreTrainedModel 类的输入文档字符串,目前为空。
FALCON_INPUTS_DOCSTRING = r"""
"""


class FalconPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # FalconPreTrainedModel 类的配置类,指定为 FalconConfig。
    config_class = FalconConfig
    # 模型基础前缀,用于模型的标识。
    base_model_prefix = "transformer"
    # 是否支持梯度检查点。
    supports_gradient_checkpointing = True
    # 不进行分割的模块列表。
    _no_split_modules = ["FalconDecoderLayer"]
    # 是否支持 Flash Attention 2。
    _supports_flash_attn_2 = True
    # 是否支持 SDPA(scaled dot-product attention,缩放点积注意力)。
    _supports_sdpa = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        # 初始化模型权重的函数。
        if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
            # 对于线性层和特定的自定义线性层 FalconLinear,使用正态分布初始化权重。
            # 与 TensorFlow 版本稍有不同,后者使用截断正态分布进行初始化。
            # 参考 https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果存在偏置,则将其初始化为零。
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # 对于嵌入层,使用正态分布初始化权重。
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # 如果定义了填充索引,则将该索引处的权重初始化为零。
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNorm):
            # 对于 LayerNorm 层,将偏置初始化为零,将权重初始化为1.0。
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    # 从 transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa 适配而来的方法。
    @classmethod
    # 检查并启用 SDPA(Scaled Dot-Product Attention)的设置,可能会修改配置并返回更新后的配置对象。
    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
        # 注意:自 PyTorch 2.0 起,Falcon 支持 SDPA。为了向后兼容性,保持这样的设定(torch>=2.0 自动使用 SDPA)。
        # 如果只进行严格检查,且当前的 torch 版本不符合要求,则抛出 ImportError 异常。
        if hard_check_only:
            if not is_torch_greater_or_equal_than_2_0:
                raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.0.")

        # 如果当前 torch 版本不符合要求,则直接返回原始配置对象。
        if not is_torch_greater_or_equal_than_2_0:
            return config

        # 检查是否使用了 BetterTransformer,如果是,则直接返回原始配置对象。
        _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
        if _is_bettertransformer:
            return config

        # 如果不是严格检查模式,将注意力机制实现设为 "sdpa"。
        if not hard_check_only:
            config._attn_implementation = "sdpa"
        # 返回更新后的配置对象。
        return config
# 使用装饰器添加文档字符串,描述这是一个不带特定头部的 Falcon 模型变换器的类
@add_start_docstrings(
    "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
    FALCON_START_DOCSTRING,
)
# FalconModel 类,继承自 FalconPreTrainedModel
class FalconModel(FalconPreTrainedModel):
    # 初始化方法,接受一个 FalconConfig 类型的参数 config
    def __init__(self, config: FalconConfig):
        # 调用父类 FalconPreTrainedModel 的初始化方法
        super().__init__(config)

        # 设置模型的嵌入维度和注意力头数
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_alibi = config.alibi

        # 嵌入层 + LayerNorm 嵌入层
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)

        # Transformer 块
        self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self._use_sdpa = config._attn_implementation == "sdpa"

        # 最终的 Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # 梯度检查点设置为 False
        self.gradient_checkpointing = False

        # 初始化权重并应用最终处理
        self.post_init()

    # 获取输入嵌入的方法
    def get_input_embeddings(self):
        return self.word_embeddings

    # 设置输入嵌入的方法,接受一个新的嵌入张量作为参数
    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.word_embeddings = new_embeddings

    # 使用装饰器添加文档字符串,描述这是一个 Falcon 模型变换器的前向方法,包含输入参数的详细说明
    @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    # 前向方法,接受多个输入参数并返回多个输出
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,



# 使用装饰器添加文档字符串,描述这是一个带有语言建模头的 Falcon 模型变换器的类
@add_start_docstrings(
    "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
    FALCON_START_DOCSTRING,
)
# FalconForCausalLM 类,继承自 FalconPreTrainedModel
class FalconForCausalLM(FalconPreTrainedModel):
    # 静态变量,指示与输入嵌入权重相关联的键名列表
    _tied_weights_keys = ["lm_head.weight"]

    # 初始化方法,接受一个 FalconConfig 类型的参数 config
    def __init__(self, config: FalconConfig):
        # 调用父类 FalconPreTrainedModel 的初始化方法
        super().__init__(config)
        
        # 创建一个 FalconModel 实例,并传入配置 config
        self.transformer = FalconModel(config)
        
        # 创建一个线性层用于语言建模的头部,输入维度为 config.hidden_size,输出维度为 config.vocab_size,没有偏置
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # 初始化权重并应用最终处理
        self.post_init()

    # 获取输出嵌入的方法
    def get_output_embeddings(self):
        return self.lm_head

    # 设置输出嵌入的方法,接受一个新的嵌入张量作为参数
    def set_output_embeddings(self, new_embeddings: torch.Tensor):
        self.lm_head = new_embeddings
    # 准备生成的输入参数,返回一个字典
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        if past_key_values is not None:
            # 获取过去键值张量的长度
            past_length = past_key_values[0][0].shape[2]

            # 一些生成方法已经只传递了最后一个输入 ID
            if input_ids.shape[1] > past_length:
                # 如果输入的长度大于过去的长度,则移除前缀长度为过去的长度
                remove_prefix_length = past_length
            else:
                # 否则默认保留最后一个 ID
                remove_prefix_length = input_ids.shape[1] - 1

            # 更新输入的 ID,移除前缀部分
            input_ids = input_ids[:, remove_prefix_length:]

        # 注意:Falcon 版本中带有 alibi 的情况下不使用 position_ids。它在 RoPE 中使用。
        if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
            # 为批量生成创建即时的 position_ids
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                # 如果有过去的键值,只保留与输入长度相匹配的 position_ids
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # 返回包含所有生成输入的字典
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }

    # 将模型前向传播方法装饰为添加文档字符串和代码示例文档字符串
    @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        
        # 如果未指定返回字典,则使用配置中的默认设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        # 使用 Transformer 模型处理输入数据,获取模型的输出
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 获取 Transformer 输出中的隐藏状态
        hidden_states = transformer_outputs[0]

        # 使用语言模型头部生成预测的 logits
        lm_logits = self.lm_head(hidden_states)

        # 初始化损失为 None
        loss = None
        if labels is not None:
            # 将 logits 向左移动一个位置,以便预测下一个 token
            shift_logits = lm_logits[..., :-1, :].contiguous()
            # 将 labels 向右移动一个位置,与 shift_logits 对齐
            shift_labels = labels[..., 1:].contiguous()
            # 获取 batch_size, seq_length 和 vocab_size 的大小
            batch_size, seq_length, vocab_size = shift_logits.shape
            # 使用交叉熵损失函数计算损失
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
            )

        # 如果不要求返回字典,则返回模型输出的元组形式
        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output
        
        # 如果需要返回字典形式的输出,则创建 CausalLMOutputWithCrossAttentions 对象
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
    # 用于 beam search 生成时重排 past_key_values 缓存
    def _reorder_cache(
        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """

        # 获取在所有需要索引的设备上的 `beam_idx` 的副本。
        device_to_beam_idx = {
            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
        }
        # 对 `past` 进行重新排序,以便与每个生成步骤中正确的 `beam_idx` 匹配。
        reordered_past = tuple(
            (
                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
            )
            for layer_past in past
        )
        # 返回重新排序后的 `past`。
        return reordered_past
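
# ——— 注释者补充的最小演示(非原始源码)———
# FalconForCausalLM.forward 中损失的计算方式:logits 去掉最后一个时间步、labels 去掉第一个
# 时间步,使位置 t 的 logits 预测位置 t+1 的 token;下面用随机张量复现这一步。
import torch
from torch.nn import CrossEntropyLoss

_demo_batch, _demo_seq, _demo_vocab = 2, 6, 11
_demo_lm_logits = torch.randn(_demo_batch, _demo_seq, _demo_vocab)
_demo_labels = torch.randint(0, _demo_vocab, (_demo_batch, _demo_seq))

_shift_logits = _demo_lm_logits[..., :-1, :].contiguous()   # [2, 5, 11]
_shift_labels = _demo_labels[..., 1:].contiguous()           # [2, 5]
_demo_loss = CrossEntropyLoss()(
    _shift_logits.view(-1, _demo_vocab), _shift_labels.view(-1)
)
print(_demo_loss.shape)   # torch.Size([]) —— 标量损失,对应 labels = input_ids 的常见用法
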
@add_start_docstrings(
    """
    The Falcon Model transformer with a sequence classification head on top (linear layer).

    [`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    FALCON_START_DOCSTRING,
)
class FalconForSequenceClassification(FalconPreTrainedModel):
    def __init__(self, config: FalconConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = FalconModel(config)  # 初始化 FalconModel,并传入配置信息
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)  # 生成线性层,用于分类

        # Initialize weights and apply final processing
        self.post_init()  # 执行初始化权重和最终处理操作


@add_start_docstrings(
    """
    Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    FALCON_START_DOCSTRING,
)
class FalconForTokenClassification(FalconPreTrainedModel):
    def __init__(self, config: FalconConfig):
        super().__init__(config)
        self.num_labels = config.num_labels  # 根据配置设置标签数量

        self.transformer = FalconModel(config)  # 初始化 FalconModel,并传入配置信息

        # 设置分类器的 dropout 概率
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)  # 创建 dropout 层
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)  # 创建线性层,用于分类

        # Initialize weights and apply final processing
        self.post_init()  # 执行初始化权重和最终处理操作
    # 将模型前向传播方法添加文档字符串,用于文档化模型输入参数和示例代码
    @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # 定义模型的前向传播方法,接受多个输入参数并返回分类器输出或损失
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        # 如果没有显式指定 return_dict,则使用模型配置中的默认值
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 使用 transformer 处理输入数据,得到变换器的输出结果
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 从变换器的输出中获取隐藏状态并应用 dropout
        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        
        # 使用分类器模型对隐藏状态进行分类预测
        logits = self.classifier(hidden_states)

        # 初始化损失为 None
        loss = None
        # 如果有提供标签信息,则计算损失值
        if labels is not None:
            batch_size, seq_length = labels.shape
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
            )

        # 如果 return_dict 为 False,则组装输出为元组
        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # 如果 return_dict 为 True,则返回 TokenClassifierOutput 对象
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
@add_start_docstrings(
    """
    The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    FALCON_START_DOCSTRING,
)
class FalconForQuestionAnswering(FalconPreTrainedModel):
    """
    Falcon model for question answering tasks, extending FalconPreTrainedModel.

    Inherits from FalconPreTrainedModel and implements a transformer with a span classification head
    for tasks such as SQuAD. It includes linear layers to compute `span start logits` and `span end logits`.
    """

    def __init__(self, config):
        """
        Initializes the FalconForQuestionAnswering model.

        Args:
            config (FalconConfig): Configuration object specifying the model architecture and parameters.
        """
        super().__init__(config)
        # Initialize the FalconModel with the provided configuration
        self.transformer = FalconModel(config)
        # Linear layer for predicting start and end positions in the span
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and perform additional post-initialization processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        # Decide whether to use the return_dict based on input or default configuration
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Pass inputs to the transformer model and retrieve outputs
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Extract the sequence output from the transformer outputs
        sequence_output = outputs[0]

        # Get logits for question answering from the sequence output
        logits = self.qa_outputs(sequence_output)

        # Split logits into start and end logits for the answer span
        start_logits, end_logits = logits.split(1, dim=-1)

        # Squeeze the logits tensors to remove unnecessary dimensions
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Handle multi-GPU training by squeezing additional dimensions if present
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            # Clamp positions to prevent them from exceeding sequence length
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # Compute the CrossEntropyLoss for start and end positions
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)

            # Calculate total loss as the average of start and end losses
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # Return outputs without loss if return_dict is False
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        # Return structured output using QuestionAnsweringModelOutput if return_dict is True
        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
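
# ——— 注释者补充的最小演示(非原始源码)———
# FalconForQuestionAnswering.forward 中,qa_outputs 把每个位置映射成 2 个 logit,
# 按最后一维拆成 start/end 两路,分别与标注位置做交叉熵,最后取平均;下面用随机张量复现。
import torch
from torch.nn import CrossEntropyLoss

_demo_batch, _demo_seq = 2, 8
_demo_logits = torch.randn(_demo_batch, _demo_seq, 2)        # 相当于 self.qa_outputs(sequence_output)
_start_logits, _end_logits = _demo_logits.split(1, dim=-1)
_start_logits = _start_logits.squeeze(-1)                    # [2, 8]
_end_logits = _end_logits.squeeze(-1)

_start_positions = torch.tensor([1, 3])
_end_positions = torch.tensor([4, 6])
_loss_fct = CrossEntropyLoss(ignore_index=_demo_seq)         # 与源码一致:越界位置被 clamp 到 ignored_index 后忽略
_total_loss = (_loss_fct(_start_logits, _start_positions) + _loss_fct(_end_logits, _end_positions)) / 2
print(_total_loss.item())
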