Lucidrains-系列项目源码解析-二-

Lucidrains 系列项目源码解析(二)

.\lucidrains\alphafold2\scripts\refinement.py

# 导入所需的库和模块
import os
import json
import warnings
# 科学计算库
import numpy as np
# 尝试导入 pyrosetta 模块,如果导入失败则发出警告
try: 
    import pyrosetta
except ModuleNotFoundError:
    msg = "Unable to find an existing installation of the PyRosetta module. " +\
          "Functions involving this module such as the FastRelax pipeline " +\
          "will not work."
    warnings.warn(msg) # no pyRosetta was found


#####################
### ROSETTA STUFF ###
#####################


def pdb2rosetta(route):
    """ Load one or several PDB files as Rosetta poses.
        Input:
        * route: str (a single PDB path) or an iterable of PDB paths.
        Output: list of poses - one element for a str input, otherwise
                whatever `pyrosetta.io.poses_from_files` yields, as a list.
    """
    if not isinstance(route, str):
        # several routes: let pyrosetta iterate over the files
        return list(pyrosetta.io.poses_from_files(route))
    single_pose = pyrosetta.io.pose_from_pdb(route)
    return [single_pose]

def rosetta2pdb(pose, route, verbose=True):
    """ Takes pose(s) as input and saves pdb(s) to disk.
        Input:
        * pose: a single rosetta pose, or a list/tuple of poses.
        * route: str or list of str. destination filenames to be written,
                 paired positionally with the poses.
        * verbose: bool. warns if lengths dont match and prints @ every write.
        Output: None. If the numbers of poses and routes differ, only the
                first min(len(pose), len(route)) pairs are written.
        Inspo:
        * https://www.rosettacommons.org/demos/latest/tutorials/input_and_output/input_and_output#controlling-output_common-structure-output-files_pdb-file
        * https://graylab.jhu.edu/PyRosetta.documentation/pyrosetta.rosetta.core.io.pdb.html#pyrosetta.rosetta.core.io.pdb.dump_pdb
    """
    # Wrap a lone pose into a list. The old `isinstance(pose, str)` check never
    # matched a Pose object, so a single pose was never wrapped.
    poses = list(pose) if isinstance(pose, (list, tuple)) else [pose]
    routes = [route] if isinstance(route, str) else list(route)
    if verbose and len(poses) != len(routes):
        print("Length of pose and route are not the same. Will stop at the minimum.")
    # zip stops at the shorter input, as the warning above promises.
    # Indexing route[i] raised IndexError when there were fewer routes than poses.
    for pos, destination in zip(poses, routes):
        pyrosetta.rosetta.core.io.pdb.dump_pdb(pos, destination)
        if verbose:
            # Report this pose's own destination; concatenating the whole
            # route list raised TypeError.
            print("Saved structure @ " + destination)
    return

def run_fast_relax(config_route, pdb_route=None, pose=None):
    """ Runs the Fast-Relax pipeline.
        * config_route: path to a json file with the config (an already
                        opened file object is also accepted).
        * pdb_route: optional str or list of str. pdb file(s) to load. A list
                     runs the pipeline once per file.
        * pose: rosetta pose to run the pipeline on
        Output: rosetta pose (a list of results when pdb_route is a list)
        Raises: NotImplementedError - the relax step itself is not written yet.
    """
    # Load rosetta pose(s): convert routes to poses and recurse.
    if isinstance(pdb_route, str):
        # pdb2rosetta always returns a list; a single route yields one pose.
        pose = pdb2rosetta(pdb_route)[0]
        return run_fast_relax(config_route, pose=pose)
    elif isinstance(pdb_route, list):
        return [run_fast_relax(config_route, pdb_route=pdb) for pdb in pdb_route]
    # json.load needs a file object, so open the route when given a path.
    if hasattr(config_route, "read"):
        config = json.load(config_route)
    else:
        with open(config_route) as config_file:
            config = json.load(config_file)
    # `config` is kept for the (not yet implemented) relax step below.
    # Run the Fast-Relax pipeline - examples:
    # https://colab.research.google.com/github/RosettaCommons/PyRosetta.notebooks/blob/master/notebooks/06.02-Packing-design-and-regional-relax.ipynb#scrollTo=PYr025Rn1Q8i
    # https://nbviewer.jupyter.org/github/RosettaCommons/PyRosetta.notebooks/blob/master/notebooks/06.03-Design-with-a-resfile-and-relax.ipynb
    # https://faculty.washington.edu/dimaio/files/demo2.py
    raise NotImplementedError("Last step. Not implemented yet.")

.\lucidrains\alphafold2\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  # --- identity ---
  name = 'alphafold2-pytorch',
  version = '0.4.32',
  license = 'MIT',
  description = 'AlphaFold2 - Pytorch',
  # NOTE(review): no `long_description` is passed, so this content type has
  # nothing to describe; consider reading README.md into long_description.
  long_description_content_type = 'text/markdown',
  author = 'Phil Wang, Eric Alcaide',
  author_email = 'lucidrains@gmail.com, ericalcaide1@gmail.com',
  url = 'https://github.com/lucidrains/alphafold2',
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'protein folding'
  ],
  classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.7',
  ],
  # --- contents ---
  # every package directory discovered under the project root
  packages = find_packages(),
  # --- dependencies ---
  install_requires = [
    'einops>=0.3',
    'En-transformer>=0.2.3',
    'invariant-point-attention',
    'mdtraj>=1.8',
    'numpy',
    'proDy',
    'pytorch3d',
    'requests',
    'sidechainnet',
    'torch>=1.6',
    'transformers',
    'tqdm',
    'biopython',
    'mp-nerf>=0.1.5'
  ],
  setup_requires = [
    'pytest-runner',
  ],
  tests_require = [
    'pytest'
  ],
)

.\lucidrains\alphafold2\tests\test_attention.py

import torch
from torch import nn
from einops import repeat

from alphafold2_pytorch.alphafold2 import Alphafold2
from alphafold2_pytorch.utils import *

# 定义测试函数 test_main
def test_main():
    """Smoke test: forward pass with a (2, 128) sequence and a (2, 5, 128) MSA, nothing masked."""
    model = Alphafold2(dim = 32, depth = 2, heads = 2, dim_head = 32)

    # token ids drawn from [0, 21)
    seq = torch.randint(0, 21, (2, 128))
    msa = torch.randint(0, 21, (2, 5, 128))

    model(
        seq,
        msa,
        mask = torch.ones_like(seq, dtype = torch.bool),
        msa_mask = torch.ones_like(msa, dtype = torch.bool)
    )
    # passing == the forward pass did not raise

# 定义测试函数 test_no_msa
def test_no_msa():
    """Smoke test: the model runs on a token sequence alone, without an MSA."""
    model = Alphafold2(dim = 32, depth = 2, heads = 2, dim_head = 32)

    seq = torch.randint(0, 21, (2, 128))
    model(seq, mask = torch.ones_like(seq, dtype = torch.bool))
    # passing == the forward pass did not raise

# 定义测试函数 test_anglegrams
def test_anglegrams():
    """Smoke test: forward pass succeeds with predict_angles=True."""
    model = Alphafold2(
        dim = 32,
        depth = 2,
        heads = 2,
        dim_head = 32,
        predict_angles = True
    )

    seq = torch.randint(0, 21, (2, 128))
    msa = torch.randint(0, 21, (2, 5, 128))

    model(
        seq,
        msa,
        mask = torch.ones_like(seq, dtype = torch.bool),
        msa_mask = torch.ones_like(msa, dtype = torch.bool)
    )
    # passing == the forward pass did not raise

# 定义测试函数 test_templates
def test_templates():
    """Smoke test: forward pass with template features, template angles and a template mask."""
    model = Alphafold2(
        dim = 32,
        depth = 2,
        heads = 2,
        dim_head = 32,
        templates_dim = 32,
        templates_angles_feats_dim = 32
    )

    seq = torch.randint(0, 21, (2, 16))
    msa = torch.randint(0, 21, (2, 5, 16))

    # trailing size 32 presumably has to match templates_dim /
    # templates_angles_feats_dim above
    template_inputs = dict(
        templates_feats = torch.randn(2, 3, 16, 16, 32),
        templates_angles = torch.randn(2, 3, 16, 32),
        templates_mask = torch.ones(2, 3, 16).bool()
    )

    model(
        seq,
        msa,
        mask = torch.ones_like(seq, dtype = torch.bool),
        msa_mask = torch.ones_like(msa, dtype = torch.bool),
        **template_inputs
    )
    # passing == the forward pass did not raise

# 定义测试函数 test_extra_msa
def test_extra_msa():
    """Smoke test: coordinate prediction accepts an extra MSA with its own mask."""
    model = Alphafold2(
        dim = 128,
        depth = 2,
        heads = 2,
        dim_head = 32,
        predict_coords = True
    )

    seq = torch.randint(0, 21, (2, 4))
    msa = torch.randint(0, 21, (2, 5, 4))
    extra_msa = torch.randint(0, 21, (2, 5, 4))

    model(
        seq,
        msa,
        mask = torch.ones_like(seq, dtype = torch.bool),
        msa_mask = torch.ones_like(msa, dtype = torch.bool),
        extra_msa = extra_msa,
        extra_msa_mask = torch.ones_like(extra_msa, dtype = torch.bool)
    )
    # passing == the forward pass did not raise

# 定义测试函数 test_embeddings
def test_embeddings():
    """Smoke test: precomputed embeddings are accepted, first without and then with a mask."""
    model = Alphafold2(dim = 32, depth = 2, heads = 2, dim_head = 32)

    seq = torch.randint(0, 21, (2, 16))
    seq_mask = torch.ones_like(seq, dtype = torch.bool)
    embedds = torch.randn(2, 1, 16, 1280)

    # mask over every axis of `embedds` except the last one
    embedds_mask = torch.ones_like(embedds[..., -1], dtype = torch.bool)

    # first pass without a mask for the embeddings, second pass with one
    for candidate_mask in (None, embedds_mask):
        model(
            seq,
            mask = seq_mask,
            embedds = embedds,
            msa_mask = candidate_mask
        )
    # passing == both forward passes did not raise

# 定义测试函数 test_coords
def test_coords():
    """With predict_coords=True the model returns coordinates shaped (2, 16, 3)."""
    structure_module = dict(
        structure_module_depth = 1,
        structure_module_heads = 1,
        structure_module_dim_head = 1,
    )
    model = Alphafold2(dim = 32, depth = 2, heads = 2, dim_head = 32, predict_coords = True, **structure_module)

    seq = torch.randint(0, 21, (2, 16))
    msa = torch.randint(0, 21, (2, 5, 16))

    coords = model(
        seq,
        msa,
        mask = torch.ones_like(seq, dtype = torch.bool),
        msa_mask = torch.ones_like(msa, dtype = torch.bool)
    )

    assert coords.shape == (2, 16, 3), 'must output coordinates'

# 定义测试函数 test_coords_backbone_with_cbeta
def test_coords_backbone_with_cbeta():
    """Coordinate prediction returns shape (2, 16, 3).

    NOTE(review): this body is identical to test_coords; the name suggests a
    backbone-plus-C-beta configuration was intended - confirm.
    """
    structure_module = dict(
        structure_module_depth = 1,
        structure_module_heads = 1,
        structure_module_dim_head = 1,
    )
    model = Alphafold2(dim = 32, depth = 2, heads = 2, dim_head = 32, predict_coords = True, **structure_module)

    seq = torch.randint(0, 21, (2, 16))
    msa = torch.randint(0, 21, (2, 5, 16))

    coords = model(
        seq,
        msa,
        mask = torch.ones_like(seq, dtype = torch.bool),
        msa_mask = torch.ones_like(msa, dtype = torch.bool)
    )

    assert coords.shape == (2, 16, 3), 'must output coordinates'

# 定义测试函数 test_coords_all_atoms
def test_coords_all_atoms():
    """Coordinate prediction returns shape (2, 16, 3).

    NOTE(review): this body is identical to test_coords; the name suggests an
    all-atom configuration was intended - confirm.
    """
    structure_module = dict(
        structure_module_depth = 1,
        structure_module_heads = 1,
        structure_module_dim_head = 1,
    )
    model = Alphafold2(dim = 32, depth = 2, heads = 2, dim_head = 32, predict_coords = True, **structure_module)

    seq = torch.randint(0, 21, (2, 16))
    msa = torch.randint(0, 21, (2, 5, 16))

    coords = model(
        seq,
        msa,
        mask = torch.ones_like(seq, dtype = torch.bool),
        msa_mask = torch.ones_like(msa, dtype = torch.bool)
    )

    assert coords.shape == (2, 16, 3), 'must output coordinates'

# 定义测试函数 test_mds
def test_mds():
    """Coordinate prediction returns shape (2, 16, 3).

    NOTE(review): this body is identical to test_coords; nothing here selects
    an MDS-specific path - confirm what was meant to be exercised.
    """
    structure_module = dict(
        structure_module_depth = 1,
        structure_module_heads = 1,
        structure_module_dim_head = 1,
    )
    model = Alphafold2(dim = 32, depth = 2, heads = 2, dim_head = 32, predict_coords = True, **structure_module)

    seq = torch.randint(0, 21, (2, 16))
    msa = torch.randint(0, 21, (2, 5, 16))

    coords = model(
        seq,
        msa,
        mask = torch.ones_like(seq, dtype = torch.bool),
        msa_mask = torch.ones_like(msa, dtype = torch.bool)
    )

    assert coords.shape == (2, 16, 3), 'must output coordinates'

# 定义测试函数 test_edges_to_equivariant_network
def test_edges_to_equivariant_network():
    """Coordinates + angles prediction with return_confidence=True yields a 2-tuple."""
    model = Alphafold2(
        dim = 32,
        depth = 1,
        heads = 2,
        dim_head = 32,
        predict_coords = True,
        predict_angles = True
    )

    seq = torch.randint(0, 21, (2, 32))
    msa = torch.randint(0, 21, (2, 5, 32))

    # unpacking into two names is itself the check on the return structure
    coords, confidences = model(
        seq,
        msa,
        mask = torch.ones_like(seq, dtype = torch.bool),
        msa_mask = torch.ones_like(msa, dtype = torch.bool),
        return_confidence = True
    )
    assert True, 'should run without errors'

# 定义测试函数 test_coords_backwards
def test_coords_backwards():
    """Gradients flow back from the predicted coordinates."""
    structure_module = dict(
        structure_module_depth = 1,
        structure_module_heads = 1,
        structure_module_dim_head = 1,
    )
    model = Alphafold2(dim = 256, depth = 2, heads = 2, dim_head = 32, predict_coords = True, **structure_module)

    seq = torch.randint(0, 21, (2, 16))
    msa = torch.randint(0, 21, (2, 5, 16))

    coords = model(
        seq,
        msa,
        mask = torch.ones_like(seq, dtype = torch.bool),
        msa_mask = torch.ones_like(msa, dtype = torch.bool)
    )

    # reduce to a scalar so backward() needs no explicit gradient argument
    scalar = coords.sum()
    scalar.backward()
    assert True, 'must be able to go backwards through MDS and center distogram'

# 定义测试函数 test_confidence
def test_confidence():
    """Confidences agree with coordinates on every axis except the last."""
    model = Alphafold2(dim = 256, depth = 1, heads = 2, dim_head = 32, predict_coords = True)

    seq = torch.randint(0, 21, (2, 16))
    msa = torch.randint(0, 21, (2, 5, 16))

    coords, confidences = model(
        seq,
        msa,
        mask = torch.ones_like(seq, dtype = torch.bool),
        msa_mask = torch.ones_like(msa, dtype = torch.bool),
        return_confidence = True
    )

    leading_coords, leading_conf = coords.shape[:-1], confidences.shape[:-1]
    assert leading_coords == leading_conf

# 定义测试函数 test_recycling
def test_recycling():
    """A second forward pass can consume the recyclables returned by the first."""
    model = Alphafold2(
        dim = 128,
        depth = 2,
        heads = 2,
        dim_head = 32,
        predict_coords = True,
    )

    seq = torch.randint(0, 21, (2, 4))
    msa = torch.randint(0, 21, (2, 5, 4))
    extra_msa = torch.randint(0, 21, (2, 5, 4))

    # arguments shared by both passes
    shared_inputs = dict(
        mask = torch.ones_like(seq, dtype = torch.bool),
        msa_mask = torch.ones_like(msa, dtype = torch.bool),
        extra_msa = extra_msa,
        extra_msa_mask = torch.ones_like(extra_msa, dtype = torch.bool),
        return_aux_logits = True,
        return_recyclables = True
    )

    # first pass: nothing to recycle yet
    coords, ret = model(seq, msa, **shared_inputs)

    # second pass: feed back the recyclables produced by the first pass
    coords, ret = model(seq, msa, recyclables = ret.recyclables, **shared_inputs)

    # passing == both passes ran and returned 2-tuples

.\lucidrains\alphafold2\tests\test_utils.py

import torch
import numpy as np
from alphafold2_pytorch.utils import *

# 测试 mat_input_to_masked 函数
def test_mat_to_masked():
    """mat_input_to_masked accepts an edge index list, a dense edge matrix, and batched input."""
    num_nodes = 19
    nodes = torch.ones(num_nodes, 3)
    node_mask = torch.randn(num_nodes) > -0.3
    adjacency = torch.randn(num_nodes, num_nodes) < 1
    edge_index = torch.nonzero(adjacency, as_tuple=False).t()

    # unbatched: edges as an index list, then as a dense matrix
    mat_input_to_masked(nodes, node_mask, edges=edge_index)
    mat_input_to_masked(nodes, node_mask, edges_mat=adjacency)

    # batched: two identical copies stacked along a new leading dimension
    def batch_of_two(t):
        return torch.stack([t, t], dim=0)

    mat_input_to_masked(batch_of_two(nodes), batch_of_two(node_mask), edges_mat=batch_of_two(adjacency))

# 测试 center_distogram_torch 函数
def test_center_distogram_median():
    """center_distogram_torch accepts center='median' and returns a (distances, weights) pair."""
    logits = torch.randn(1, 128, 128, 37)
    distances, weights = center_distogram_torch(logits, center='median')

# 测试 scn_backbone_mask 函数
def test_masks():
    """scn_backbone_mask(boolean=True) returns three masks (N, CA, C)."""
    tokens = torch.randint(20, size=(2, 50))
    N_mask, CA_mask, C_mask = scn_backbone_mask(tokens, boolean=True)

# 测试 MDScaling 函数
def test_mds_and_mirrors():
    """MDScaling with mirror fixing returns coordinates of shape [2, 3, 96]."""
    num_points = 32 * 3
    distogram = torch.randn(2, num_points, num_points, 37)
    distances, weights = center_distogram_torch(distogram)

    # zero the trailing block of weights for the first example only
    for i, pad in enumerate((7, 0)):
        if pad:
            weights[i, -pad:, -pad:] = 0.

    # points cycle through N, CA, C along the point axis
    atom_slot = torch.arange(num_points) % 3
    coords_3d, _ = MDScaling(
        distances,
        weights=weights,
        iters=5,
        fix_mirror=2,
        N_mask=atom_slot == 0,
        CA_mask=atom_slot == 1,
        C_mask=None,
    )
    assert list(coords_3d.shape) == [2, 3, num_points], 'coordinates must be of the right shape after MDS'

# 测试 sidechain_container 函数
def test_sidechain_container():
    """sidechain_container expands a 4-atom backbone into a [2, 137, 14, 3] container."""
    length = 137
    seqs = torch.tensor([[0] * length, [3] * length]).long()
    backbone = torch.randn(2, length * 4, 3)
    # first 4 of the 14 atom slots present, the remaining 10 absent
    atom_mask = torch.tensor([1] * 4 + [0] * 10)
    proto_3d = sidechain_container(seqs, backbone, atom_mask=atom_mask)
    assert list(proto_3d.shape) == [2, length, 14, 3]

# 测试 distmat_loss_torch 函数
def test_distmat_loss():
    """distmat_loss_torch runs with p=2, q=2 (MSE on distance matrices)."""
    shape = (2, 137, 14, 3)
    pred, target = torch.randn(*shape), torch.randn(*shape)
    distmat_loss_torch(pred, target, p=2, q=2)

# 测试 lddt_ca_torch 函数
def test_lddt():
    """lddt_ca_torch returns one score per (batch, residue) pair: shape [2, 137]."""
    shape = (2, 137, 14, 3)
    pred, target = torch.randn(*shape), torch.randn(*shape)
    cloud_mask = torch.ones(pred.shape[:-1], dtype=torch.bool)
    scores = lddt_ca_torch(pred, target, cloud_mask)
    assert list(scores.shape) == [2, 137]

# 测试 Kabsch 函数
def test_kabsch():
    """Kabsch returns aligned point sets with the same shape as its inputs."""
    points_a, points_b = torch.randn(3, 8), torch.randn(3, 8)
    aligned_a, aligned_b = Kabsch(points_a, points_b)
    assert points_a.shape == aligned_a.shape

# 测试 TMscore 函数
def test_tmscore():
    """TMscore runs on a batch of two 3 x 8 point sets."""
    points_a, points_b = torch.randn(2, 3, 8), torch.randn(2, 3, 8)
    TMscore(points_a, points_b)

# 测试 GDT 函数
def test_gdt():
    """GDT runs with a scalar weight on a single 3 x 8 point set pair."""
    points_a, points_b = torch.randn(1, 3, 8), torch.randn(1, 3, 8)
    GDT(points_a, points_b, weights=1)

.\lucidrains\alphafold2\training_scripts\datasets\trrosetta.py

import pickle
import string
from argparse import ArgumentParser
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import numpy.linalg as LA
import prody
import torch
from Bio import SeqIO
from einops import repeat
from sidechainnet.utils.measure import get_seq_coords_and_angles
from sidechainnet.utils.sequence import ProteinVocabulary
from torch.utils.data import DataLoader, Dataset
from alphafold2_pytorch.constants import DISTOGRAM_BUCKETS
from tqdm import tqdm

try:
    import pytorch_lightning as pl

    LightningDataModule = pl.LightningDataModule
except ImportError:
    LightningDataModule = object

# Local cache for downloaded / extracted trRosetta data.
CACHE_PATH = Path("~/.cache/alphafold2_pytorch").expanduser()
DATA_DIR = CACHE_PATH.joinpath("trrosetta", "trrosetta")
URL = "http://s3.amazonaws.com/proteindata/data_pytorch/trrosetta.tar.gz"

# Characters deleted from every sequence by read_fasta (which treats them as
# insertions): all lowercase ASCII letters plus "." and "*".
REMOVE_KEYS = {**dict.fromkeys(string.ascii_lowercase), ".": None, "*": None}
translation = str.maketrans(REMOVE_KEYS)

DEFAULT_VOCAB = ProteinVocabulary()


def default_tokenize(seq: str) -> List[int]:
    """Map each character of ``seq`` to its id in the module-level ``DEFAULT_VOCAB``."""
    token_ids = []
    for residue in seq:
        token_ids.append(DEFAULT_VOCAB[residue])
    return token_ids


def read_fasta(filename: str) -> List[Tuple[str, str]]:
    """Parse a FASTA-formatted file into ``(description, sequence)`` pairs.

    Every character listed in the module-level ``translation`` table
    (lowercase letters, ``.`` and ``*``) is stripped from each sequence.
    """
    records = []
    for record in SeqIO.parse(filename, "fasta"):
        cleaned = str(record.seq).translate(translation)
        records.append((record.description, cleaned))
    return records


def read_pdb(pdb: str):
    """Parse a PDB file and return sidechainnet features of its first chain.

    Returns:
        ``(angles, coords, seq)`` from
        ``sidechainnet.utils.measure.get_seq_coords_and_angles`` for the first
        chain yielded by prody's ``iterChains()``; further chains are ignored.

    Raises:
        ValueError: if no chain is found. Previously the function silently
        returned ``None`` here, which failed later with an opaque unpacking error.
    """
    ag = prody.parsePDB(pdb)
    # guard in case parsePDB yields nothing for this file
    first_chain = next(iter(ag.iterChains()), None) if ag is not None else None
    if first_chain is None:
        raise ValueError(f"No chains found in PDB file {pdb}")
    angles, coords, seq = get_seq_coords_and_angles(first_chain)
    return angles, coords, seq


def download_file(url, filename=None, root=CACHE_PATH):
    """Download ``url`` into ``root`` and return the local path.

    Args:
        url: address to fetch.
        filename: name of the file inside ``root``; defaults to the last path
            component of ``url``.
        root: destination directory (``Path`` or str); created if missing.

    Returns:
        ``Path`` of the downloaded file. An existing regular file with that
        name is returned as-is without downloading again.

    Raises:
        RuntimeError: if the target path exists but is not a regular file.
    """
    import os
    import urllib.request  # plain `import urllib` does not load the `request` submodule

    root = Path(root)
    root.mkdir(exist_ok=True, parents=True)
    filename = filename or os.path.basename(url)

    download_target = root / filename
    # Stream into a temporary name first, so an interrupted download never
    # leaves a partial file at `download_target` (which would be reused as-is).
    download_target_tmp = root / f"tmp.{filename}"

    if download_target.exists() and not download_target.is_file():
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if download_target.is_file():
        return download_target

    with urllib.request.urlopen(url) as source, open(
        download_target_tmp, "wb"
    ) as output:
        # Content-Length may be absent (e.g. chunked responses); tqdm then
        # shows a running count instead of crashing on int(None).
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    download_target_tmp.rename(download_target)
    return download_target


def get_or_download(url: str = URL):
    """
    download and extract trrosetta data

    Returns ``CACHE_PATH / "trrosetta"``. If that directory already exists it
    is returned immediately. Otherwise the archive
    ``CACHE_PATH / "trrosetta.tar.gz"`` is downloaded from ``url`` when
    missing, extracted into a temporary directory, and renamed into place.
    """
    import tarfile

    archive = CACHE_PATH / "trrosetta.tar.gz"
    target_dir = CACHE_PATH / "trrosetta"
    temp_dir = CACHE_PATH / "trrosetta_tmp"
    if target_dir.is_dir():
        print(f"Load cached data from {target_dir}")
        return target_dir

    if not archive.is_file():
        print(f"Cache not found, download from {url} to {archive}")
        # Pin the filename. Without it the download is named after the URL's
        # basename, so a custom `url` would never land at `archive`.
        download_file(url, filename=archive.name, root=CACHE_PATH)

    print(f"Extract data from {archive} to {target_dir}")
    with tarfile.open(archive, "r:gz") as tar:
        # The archive comes from the network. Where this Python's tarfile
        # supports extraction filters, use "data", which rejects absolute
        # paths, `..` members and links escaping the destination.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(temp_dir, filter="data")
        else:
            tar.extractall(temp_dir)

    # extract to a temp dir, then rename, so a half-extracted tree is never
    # mistaken for a complete cache by the is_dir() check above
    temp_dir.rename(target_dir)

    return target_dir


def pad_sequences(sequences, constant_value=0, dtype=None) -> np.ndarray:
    """Stack variable-shaped arrays or tensors into one padded batch.

    Args:
        sequences: non-empty list of ``np.ndarray`` or of ``torch.Tensor``
            (the type of the first element selects the output type), all with
            the same number of dimensions.
        constant_value: fill value for padded positions.
        dtype: output dtype; defaults to the dtype of ``sequences[0]``.

    Returns:
        Array/tensor of shape ``(len(sequences), *max_shape)``, where
        ``max_shape`` is the element-wise maximum of the input shapes. Each
        input is copied into the leading corner of its slot.

    Raises:
        ValueError: if ``sequences`` is empty.
        TypeError: if the first element is neither an ndarray nor a Tensor.
    """
    if len(sequences) == 0:
        raise ValueError("pad_sequences needs at least one sequence")

    first = sequences[0]
    if isinstance(first, np.ndarray):
        full = np.full
    elif isinstance(first, torch.Tensor):
        full = torch.full
    else:
        # previously fell through and crashed later with UnboundLocalError on `array`
        raise TypeError(
            f"expected np.ndarray or torch.Tensor, got {type(first).__name__}"
        )

    batch_size = len(sequences)
    shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist()

    if dtype is None:
        dtype = first.dtype

    array = full(shape, constant_value, dtype=dtype)
    for arr, seq in zip(array, sequences):
        # write `seq` into the top-left corner of its (view) slot
        arrslice = tuple(slice(dim) for dim in seq.shape)
        arr[arrslice] = seq

    return array


class TrRosettaDataset(Dataset):
    """Dataset over the trRosetta ``npz`` / ``pdb`` / ``a3m`` files of one split.

    Each item is a dict with keys ``id``, ``seq``, ``msa``, ``coords``,
    ``angles`` and ``dist``. Parsed items are pickled under
    ``data_dir / "cache"`` and re-read on later accesses unless ``overwrite``
    is set. MSA subsampling and length cropping are applied on every access,
    after the cache is read or written.
    """

    def __init__(
        self,
        data_dir: Path,
        list_path: Path,
        tokenize: Callable[[str], List[int]],
        seq_pad_value: int = 20,
        random_sample_msa: bool = False,
        max_seq_len: int = 300,
        max_msa_num: int = 300,
        overwrite: bool = False,
    ):
        self.data_dir = data_dir
        self.file_list: List[Path] = self.read_file_list(data_dir, list_path)

        # used by collate_fn to turn characters into token ids; seq_pad_value
        # fills padded positions of both sequences and MSAs
        self.tokenize = tokenize
        self.seq_pad_value = seq_pad_value

        self.random_sample_msa = random_sample_msa
        self.max_seq_len = max_seq_len
        self.max_msa_num = max_msa_num

        # when True, cached pickles are ignored and items are re-parsed
        self.overwrite = overwrite

    def __len__(self) -> int:
        return len(self.file_list)

    def read_file_list(self, data_dir: Path, list_path: Path):
        """Return the ``data_dir/npz/*.npz`` files whose names are listed in ``list_path``.

        Raises:
            ValueError: if ``list_path`` contains no names.
            FileNotFoundError: if some listed names have no matching file.
        """
        file_glob = (data_dir / "npz").glob("*.npz")
        files = set(list_path.read_text().split())
        if len(files) == 0:
            raise ValueError("Passed an empty split file set")

        file_list = [f for f in file_glob if f.name in files]
        if len(file_list) != len(files):
            num_missing = len(files) - len(file_list)
            raise FileNotFoundError(
                f"{num_missing} specified split files not found in directory"
            )

        return file_list

    def _cache_path(self, index) -> Path:
        """Location of the pickled item for ``self.file_list[index]``."""
        return (self.data_dir / "cache" / self.file_list[index].stem).with_suffix(
            ".pkl"
        )

    def has_cache(self, index):
        if self.overwrite:
            return False
        return self._cache_path(index).is_file()

    def write_cache(self, index, data):
        path = self._cache_path(index)
        path.parent.mkdir(exist_ok=True, parents=True)
        with open(path, "wb") as file:
            pickle.dump(data, file)

    def read_cache(self, index):
        # NOTE: pickle.load can execute arbitrary code; only point data_dir at a
        # cache directory written by this class.
        with open(self._cache_path(index), "rb") as file:
            return pickle.load(file)

    def __getitem__(self, index):
        if self.has_cache(index):
            item = self.read_cache(index)
        else:
            item = self._parse_item(index)
            self.write_cache(index, item)

        # sampling/cropping happen after caching so the cache keeps full data
        item["msa"] = self.sample(item["msa"], self.max_msa_num, self.random_sample_msa)
        item = self.crop(item, self.max_seq_len)
        return item

    def _parse_item(self, index):
        """Build the uncached item for ``index`` from its ``pdb`` and ``a3m`` files."""
        item_id = self.file_list[index].stem
        pdb_path = self.data_dir / "pdb" / f"{item_id}.pdb"
        msa_path = self.data_dir / "a3m" / f"{item_id}.a3m"
        _, msa = zip(*read_fasta(str(msa_path)))
        # one row per MSA sequence, one character per column
        msa = np.array([np.array(list(seq)) for seq in msa])
        angles, coords, seq = read_pdb(str(pdb_path))
        seq = np.array(list(seq))
        # flat atom coordinates regrouped as (residues, 14 atoms, 3)
        coords = coords.reshape((coords.shape[0] // 14, 14, 3))
        dist = self.get_bucketed_distance(seq, coords, subset="ca")
        return {
            "id": item_id,
            "seq": seq,
            "msa": msa,
            "coords": coords,
            "angles": angles,
            "dist": dist,
        }

    def calc_cb(self, coord):
        """Place a virtual CB from the N, CA and C atoms (rows 0, 1, 2 of ``coord``).

        Uses fixed coefficients on ``b = CA - N``, ``c = C - CA`` and ``a = b x c``.
        """
        N, CA, C = coord[0], coord[1], coord[2]

        b = CA - N
        c = C - CA
        a = np.cross(b, c)
        return -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA

    def get_bucketed_distance(
        self, seq, coords, subset="ca", start=2, bins=DISTOGRAM_BUCKETS-1, step=0.5
    ):
        """Pairwise distance matrix discretised into integer buckets.

        Args:
            seq: per-residue one-letter codes (only used for ``subset="cb"``).
            coords: per-residue atom coordinates as built in ``_parse_item``,
                i.e. ``(L, 14, 3)``; atom 1 is used for ``"ca"``, atom 4 for
                ``"cb"``. Glycine (``"G"``) gets a virtual CB from ``calc_cb``.
            subset: ``"ca"`` or ``"cb"``.
            start, step, bins: a distance ``d`` maps to ``(d - start) // step``;
                ``d >= start + step * bins`` maps to ``bins``.

        Returns:
            ``(L, L)`` int array. Pairs closer than ``start`` (including the
            diagonal) are 0.
        """
        assert subset in ("ca", "cb")

        if subset == "ca":
            points = coords[:, 1, :]
        else:
            # glycine has no side-chain CB atom, so it is computed from the backbone
            points = np.array([
                self.calc_cb(coord) if res == "G" else coord[4, :]
                for res, coord in zip(seq, coords)
            ])

        # float64 matches the promotion the original `coords + np.zeros(...)` did
        points = np.asarray(points, dtype=np.float64)
        # diff[i, j] = points[j] - points[i]
        diff = points[None, :, :] - points[:, None, :]
        distance_map = LA.norm(diff, axis=2)

        mask = np.ones(distance_map.shape) - np.eye(distance_map.shape[0])
        low_pos = np.where(distance_map < start)
        high_pos = np.where(distance_map >= start + step * bins)

        mask[low_pos] = 0
        distance_map = (distance_map - start) // step
        # everything beyond the last bucket edge collapses into bucket `bins`
        distance_map[high_pos] = bins
        return (distance_map * mask).astype(int)

    def crop(self, item, max_seq_len: int):
        """Truncate every per-residue field of ``item`` to ``max_seq_len`` positions.

        Always keeps the first ``max_seq_len`` positions (the window start is
        fixed at 0). ``max_seq_len <= 0`` disables cropping.
        """
        seq_len = len(item["seq"])

        if seq_len <= max_seq_len or max_seq_len <= 0:
            return item

        start = 0
        end = start + max_seq_len

        item["seq"] = item["seq"][start:end]
        item["msa"] = item["msa"][:, start:end]
        item["coords"] = item["coords"][start:end]
        item["angles"] = item["angles"][start:end]
        item["dist"] = item["dist"][start:end, start:end]
        return item

    def sample(self, msa, max_msa_num: int, random: bool):
        """Keep at most ``max_msa_num`` MSA rows (``max_msa_num <= 0`` keeps all).

        With ``random`` the first row (presumably the query) is always kept and
        the rest are drawn without replacement from rows ``1..N-1``; otherwise
        the first ``max_msa_num`` rows are kept.
        """
        num_msa = len(msa)

        if num_msa <= max_msa_num or max_msa_num <= 0:
            return msa

        if random:
            num_sample = max_msa_num - 1
            indices = np.random.choice(num_msa - 1, size=num_sample, replace=False) + 1
            # prepend index 0 so row 0 is always part of the sample
            indices = np.pad(indices, [1, 0], "constant")
            return msa[indices]
        else:
            return msa[:max_msa_num]

    def collate_fn(self, batch):
        """Pad a list of items into batch tensors and build the masks.

        Returns a dict with ``id`` (list), padded ``seq``, ``msa``, ``coords``,
        ``angles``, ``dist`` (padded with -100) and boolean ``mask`` /
        ``msa_mask``. The masks are derived from each item's MSA width and
        depth, so they assume MSA rows are as long as the sequence.
        """
        b = len(batch)
        # list of dicts -> dict of lists
        batch = {k: [item[k] for item in batch] for k in batch[0]}

        ids = batch["id"]
        seq = batch["seq"]
        msa = batch["msa"]
        coords = batch["coords"]
        angles = batch["angles"]
        dist = batch["dist"]

        lengths = torch.LongTensor([len(x[0]) for x in msa])
        depths = torch.LongTensor([len(x) for x in msa])
        # plain ints: einops expects integer axis lengths
        max_len = int(lengths.max())
        max_depth = int(depths.max())

        seq = pad_sequences(
            [torch.LongTensor(self.tokenize(seq_)) for seq_ in seq], self.seq_pad_value,
        )

        msa = pad_sequences(
            [torch.LongTensor([self.tokenize(seq_) for seq_ in msa_]) for msa_ in msa],
            self.seq_pad_value,
        )

        coords = pad_sequences([torch.FloatTensor(x) for x in coords], 0.0)

        angles = pad_sequences([torch.FloatTensor(x) for x in angles], 0.0)

        # -100 presumably matches the default ignore_index of CrossEntropyLoss
        dist = pad_sequences([torch.LongTensor(x) for x in dist], -100)

        # position < length of that example
        mask = repeat(torch.arange(max_len), "l -> b l", b=b) < repeat(
            lengths, "b -> b l", l=max_len
        )

        msa_seq_mask = repeat(
            torch.arange(max_len), "l -> b s l", b=b, s=max_depth
        ) < repeat(lengths, "b -> b s l", s=max_depth, l=max_len)

        msa_depth_mask = repeat(
            torch.arange(max_depth), "s -> b s l", b=b, l=max_len
        ) < repeat(depths, "b -> b s l", s=max_depth, l=max_len)

        msa_mask = msa_seq_mask & msa_depth_mask

        return {
            "id": ids,
            "seq": seq,
            "msa": msa,
            "coords": coords,
            "angles": angles,
            "mask": mask,
            "msa_mask": msa_mask,
            "dist": dist,
        }


class TrRosettaDataModule(LightningDataModule):
    """Lightning data module serving trRosetta train / validation / test splits.

    Loader settings (batch sizes, worker count), per-split sequence-length and
    MSA-count limits, the tokenizer and the sequence padding value are stored
    at construction time and forwarded to ``TrRosettaDataset`` in ``setup``.
    """

    @staticmethod
    def add_data_specific_args(parent_parser):
        """Return a child ``ArgumentParser`` (inheriting ``parent_parser``) with the data options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--data_dir", type=str, default=str(DATA_DIR))

        # NOTE(review): the MSA-count defaults here (256 / 256 / 1000) differ
        # from the constructor defaults (32 / 32 / 64); confirm which is intended.
        integer_options = (
            ("--train_batch_size", 1),
            ("--eval_batch_size", 1),
            ("--test_batch_size", 1),
            ("--num_workers", 0),
            ("--train_max_seq_len", 256),
            ("--eval_max_seq_len", 256),
            ("--test_max_seq_len", -1),
            ("--train_max_msa_num", 256),
            ("--eval_max_msa_num", 256),
            ("--test_max_msa_num", 1000),
        )
        for flag, fallback in integer_options:
            parser.add_argument(flag, type=int, default=fallback)

        parser.add_argument("--overwrite", dest="overwrite", action="store_true")
        return parser

    def __init__(
        self,
        data_dir: str = DATA_DIR,
        train_batch_size: int = 1,
        eval_batch_size: int = 1,
        test_batch_size: int = 1,
        num_workers: int = 0,
        train_max_seq_len: int = 256,
        eval_max_seq_len: int = 256,
        test_max_seq_len: int = -1,
        train_max_msa_num: int = 32,
        eval_max_msa_num: int = 32,
        test_max_msa_num: int = 64,
        tokenize: Callable[[str], List[int]] = default_tokenize,
        seq_pad_value: int = 20,
        overwrite: bool = False,
        **kwargs,
    ):
        super().__init__()
        # absolute data directory with "~" expanded
        self.data_dir = Path(data_dir).expanduser().resolve()

        # DataLoader settings
        self.train_batch_size, self.eval_batch_size, self.test_batch_size = (
            train_batch_size,
            eval_batch_size,
            test_batch_size,
        )
        self.num_workers = num_workers

        # per-split limits forwarded to TrRosettaDataset
        self.train_max_seq_len, self.eval_max_seq_len, self.test_max_seq_len = (
            train_max_seq_len,
            eval_max_seq_len,
            test_max_seq_len,
        )
        self.train_max_msa_num, self.eval_max_msa_num, self.test_max_msa_num = (
            train_max_msa_num,
            eval_max_msa_num,
            test_max_msa_num,
        )

        # tokenisation / padding / cache behaviour
        self.tokenize = tokenize
        self.seq_pad_value = seq_pad_value
        self.overwrite = overwrite

        get_or_download()

    def setup(self, stage: Optional[str] = None):
        """Build the train, validation and test datasets (``stage`` is ignored)."""

        def make_split(file_name, random_sample_msa, max_seq_len, max_msa_num):
            return TrRosettaDataset(
                self.data_dir,
                self.data_dir / file_name,
                self.tokenize,
                self.seq_pad_value,
                random_sample_msa=random_sample_msa,
                max_seq_len=max_seq_len,
                max_msa_num=max_msa_num,
                overwrite=self.overwrite,
            )

        self.train = make_split("train_files.txt", True, self.train_max_seq_len, self.train_max_msa_num)
        self.val = make_split("valid_files.txt", False, self.eval_max_seq_len, self.eval_max_msa_num)
        # NOTE(review): the test split is also read from "valid_files.txt"; confirm intended.
        self.test = make_split("valid_files.txt", False, self.test_max_seq_len, self.test_max_msa_num)

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        """Shuffled loader over the training split."""
        dataset = self.train
        return DataLoader(
            dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            collate_fn=dataset.collate_fn,
            num_workers=self.num_workers,
        )

    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        """Unshuffled loader over the validation split."""
        dataset = self.val
        return DataLoader(
            dataset,
            batch_size=self.eval_batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            num_workers=self.num_workers,
        )

    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        """Unshuffled loader over the test split."""
        dataset = self.test
        return DataLoader(
            dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            num_workers=self.num_workers,
        )


def test():
    """Smoke test: build the data module and print the first training batch."""
    dm = TrRosettaDataModule(train_batch_size=1, num_workers=4)
    dm.setup()

    # only the first batch is inspected
    batch = next(iter(dm.train_dataloader()), None)
    if batch is None:
        return

    print("id", batch["id"])
    seq, msa, dist = batch["seq"], batch["msa"], batch["dist"]
    print("seq", seq.shape, seq)
    print("msa", msa.shape, msa[..., :20])
    print("msa", msa.shape, msa[..., -20:])
    for key in ("coords", "angles", "mask", "msa_mask"):
        print(key, batch[key].shape)
    print("dist", dist.shape, dist)


if __name__ == "__main__":
    test()

.\lucidrains\alphafold2\training_scripts\datasets\__init__.py

# Rectangle area helper.
def calculate_area(length, width):
    """Return the area of a ``length`` x ``width`` rectangle."""
    return length * width

.\lucidrains\alphafold2\training_scripts\deepspeed.py

# Rectangle area helper.
def calculate_area(length, width):
    """Return the area of a ``length`` x ``width`` rectangle."""
    return length * width

.\lucidrains\alphafold2\training_scripts\lightning.py

# Rectangle area helper.
def calculate_area(length, width):
    """Return the area of a ``length`` x ``width`` rectangle."""
    return length * width

.\lucidrains\alphafold2\train_end2end.py

# 导入所需的库
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
import torch.nn.functional as F
from einops import rearrange

# 导入数据处理相关的库
import sidechainnet as scn
from sidechainnet.sequence.utils import VOCAB
from sidechainnet.structure.build_info import NUM_COORDS_PER_RES

# 导入模型相关的库
from alphafold2_pytorch import Alphafold2
import alphafold2_pytorch.constants as constants

from se3_transformer_pytorch import SE3Transformer
from alphafold2_pytorch.utils import *

# Script configuration
FEATURES = "esm"                 # which extra sequence features to compute ("esm", "msa", or anything else for none)
NUM_BATCHES = int(1e5)           # number of optimizer steps
GRADIENT_ACCUMULATE_EVERY = 16   # batches whose gradients are accumulated per optimizer step
LEARNING_RATE = 3e-4
IGNORE_INDEX = -100
THRESHOLD_LENGTH = 250           # length cut-off applied by data_cond when sampling batches
TO_PDB = False                   # whether to write prediction / label PDB files during training
SAVE_DIR = ""                    # prefix passed to coords2pdb for those files

# device and distogram settings are taken from the package constants
DEVICE = constants.DEVICE
DISTOGRAM_BUCKETS = constants.DISTOGRAM_BUCKETS

# ESM-1b (loaded through torch.hub) provides the sequence embeddings
if FEATURES == "esm":
    embedd_model, alphabet = torch.hub.load("facebookresearch/esm", "esm1b_t33_650M_UR50S")
    batch_converter = alphabet.get_batch_converter()

# helper: endless iteration over a data loader
def cycle(loader, cond = lambda x: True):
    """Yield items from ``loader`` over and over, restarting after each pass.

    Only items for which ``cond(item)`` is truthy are yielded.

    If a complete pass over ``loader`` yields nothing, the generator stops
    instead of busy-looping forever. This covers an empty loader, a
    one-shot iterator that is already exhausted, and a ``cond`` that rejects
    every item.
    """
    while True:
        yielded_any = False
        for data in loader:
            if not cond(data):
                continue
            yielded_any = True
            yield data
        if not yielded_any:
            return

# Load the CASP12 SidechainNet splits as PyTorch dataloaders (batch size 1)
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = 1,
    dynamic_batching = False
)

# Pass the re-iterable training DataLoader itself, not ``iter(...)`` of it.
# ``cycle`` starts a fresh pass after every epoch, which a one-shot iterator
# cannot provide: once exhausted, every further pass would be empty.
data = data['train']
# length filter; presumably t[1] is the (batch, length, ...) sequence tensor — TODO confirm
data_cond = lambda t: t[1].shape[1] < THRESHOLD_LENGTH
dl = cycle(data, data_cond)

# Alphafold2 with the coordinate-predicting structure module enabled
model = Alphafold2(
    dim = 256,
    depth = 1,
    heads = 8,
    dim_head = 64,
    predict_coords = True,
    structure_module_dim = 8,
    structure_module_depth = 2,
    structure_module_heads = 4,
    structure_module_dim_head = 16,
    structure_module_refinement_iters = 2
).to(DEVICE)

# Loss and optimizer
dispersion_weight = 0.1  # weight of the dispersion term added to the loss in the training loop
# ``nn`` is never imported by name in this script, so reach it through ``torch``
criterion = torch.nn.MSELoss()
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# Training loop: every outer step accumulates gradients over
# GRADIENT_ACCUMULATE_EVERY batches, then takes one optimizer step.
for _ in range(NUM_BATCHES):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        batch = next(dl)
        seq, coords, mask = batch.seqs, batch.crds, batch.msks

        b, l, _ = seq.shape

        # prepare data and masks: one-hot sequence -> token indices, move to DEVICE
        seq, coords, mask = seq.argmax(dim = -1).to(DEVICE), coords.to(DEVICE), mask.to(DEVICE)

        # optional sequence features; both start unset
        # (unpacking a single ``None`` into two names would raise TypeError)
        msa, embedds = None, None

        if FEATURES == "esm":
            embedds = get_esm_embedd(seq, embedd_model, batch_converter)
        elif FEATURES == "msa":
            pass  # TODO: MSA features are not produced yet
        else:
            pass  # no extra features

        # predict
        refined = model(
            seq,
            msa = msa,
            embedds = embedds,
            mask = mask
        )

        # atom masks; these must exist before the alignment below indexes with them
        cloud_mask = scn_cloud_mask(seq, boolean = False)
        flat_cloud_mask = rearrange(cloud_mask, 'b l c -> b (l c)')

        # TODO(review): a side-chain container was built here with
        # ``sidechain_container(coords_3d, n_aa=batch, cloud_mask=cloud_mask,
        # place_oxygen=False)``, but ``coords_3d`` is not defined anywhere in
        # this script and the result was never used. Re-add it once its input is settled.

        # rotate / align the prediction onto the labels
        # NOTE(review): ``boolean = False`` suggests ``flat_cloud_mask`` is not a
        # bool tensor; confirm ``coords[flat_cloud_mask]`` selects atoms as intended.
        coords_aligned, labels_aligned = Kabsch(refined, coords[flat_cloud_mask])

        # chain mask
        # NOTE(review): ``mask`` presumably is (b, l) while ``cloud_mask`` is
        # rearranged as (b, l, c) above; confirm the broadcast. Boolean/advanced
        # indexing with ``[cloud_mask]`` also flattens its result, which conflicts
        # with the 'b l c' pattern on the next line.
        chain_mask = (mask * cloud_mask)[cloud_mask]
        flat_chain_mask = rearrange(chain_mask, 'b l c -> b (l c)')

        # optionally save prediction and aligned label as PDB files
        # NOTE(review): after argmax ``seq`` is 2-D (b, l), so ``seq[idx, :, 0]``
        # has one index too many; confirm what coords2pdb expects.
        if TO_PDB:
            idx = 0
            coords2pdb(seq[idx, :, 0], coords_aligned[idx], cloud_mask, prefix=SAVE_DIR, name="pred.pdb")
            coords2pdb(seq[idx, :, 0], labels_aligned[idx], cloud_mask, prefix=SAVE_DIR, name="label.pdb")

        # loss: square root of the masked MSE plus a weighted dispersion term
        # TODO(review): ``weights`` is not defined anywhere in this script, so this
        # line raises NameError until the dispersion term's input is wired in.
        loss = torch.sqrt(criterion(coords_aligned[flat_chain_mask], labels_aligned[flat_chain_mask])) + \
                          dispersion_weight * torch.norm( (1/weights)-1 )

        loss.backward()

    print('loss:', loss.item())

    optim.step()
    optim.zero_grad()

.\lucidrains\alphafold2\train_pre.py

# 导入所需的库
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
import torch.nn.functional as F
from einops import rearrange

# 导入自定义库
import sidechainnet as scn
from alphafold2_pytorch import Alphafold2
import alphafold2_pytorch.constants as constants
from alphafold2_pytorch.utils import get_bucketed_distance_matrix

# constants

NUM_BATCHES = int(1e5)           # number of optimizer steps
GRADIENT_ACCUMULATE_EVERY = 16   # batches whose gradients are accumulated per optimizer step
LEARNING_RATE = 3e-4
IGNORE_INDEX = -100              # distogram label skipped by the cross-entropy loss
THRESHOLD_LENGTH = 250           # length cut-off applied by data_cond when sampling batches

# distogram bucketing and device come from the package constants
DISTOGRAM_BUCKETS = constants.DISTOGRAM_BUCKETS
DEVICE = constants.DEVICE

# helper functions

def cycle(loader, cond = lambda x: True):
    """Yield items from ``loader`` over and over, restarting after each pass.

    Only items for which ``cond(item)`` is truthy are yielded.

    If a complete pass over ``loader`` yields nothing, the generator stops
    instead of busy-looping forever. This covers an empty loader, a
    one-shot iterator that is already exhausted, and a ``cond`` that rejects
    every item.
    """
    while True:
        yielded_any = False
        for data in loader:
            if not cond(data):
                continue
            yielded_any = True
            yield data
        if not yielded_any:
            return

# data

# Load the CASP12 SidechainNet splits as PyTorch dataloaders (batch size 1)
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = 1,
    dynamic_batching = False
)

# Pass the re-iterable training DataLoader itself, not ``iter(...)`` of it.
# ``cycle`` starts a fresh pass after every epoch, which a one-shot iterator
# cannot provide: once exhausted, every further pass would be empty.
data = data['train']
# length filter; presumably t[1] is the (batch, length, ...) sequence tensor — TODO confirm
data_cond = lambda t: t[1].shape[1] < THRESHOLD_LENGTH
dl = cycle(data, data_cond)

# model: its output is treated as a distogram in the training loop below
model = Alphafold2(dim = 256, depth = 1, heads = 8, dim_head = 64)
model = model.to(DEVICE)

# optimizer: Adam over every model parameter
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# Training loop: each outer step accumulates gradients over
# GRADIENT_ACCUMULATE_EVERY batches before a single optimizer update.
for _ in range(NUM_BATCHES):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        batch = next(dl)

        # one-hot sequence -> token ids; masks become booleans; all on DEVICE
        seq = batch.seqs
        b, l, _ = seq.shape
        seq = seq.argmax(dim = -1).to(DEVICE)
        mask = batch.msks.to(DEVICE).bool()
        # flat per-atom coordinates -> (batch, residue, atom, xyz)
        coords = rearrange(batch.crds.to(DEVICE), 'b (l c) d -> b l c d', l = l)

        # bucketed pairwise distances of atom index 1 of every residue are the targets
        discretized_distances = get_bucketed_distance_matrix(coords[:, :, 1], mask, DISTOGRAM_BUCKETS, IGNORE_INDEX)

        # predicted distogram, classes moved to dim 1 as cross_entropy expects
        distogram = rearrange(model(seq, mask = mask), 'b i j c -> b c i j')

        loss = F.cross_entropy(
            distogram,
            discretized_distances,
            ignore_index = IGNORE_INDEX
        )
        loss.backward()

    # report the last micro-batch loss, then apply and clear the accumulated gradients
    print('loss:', loss.item())
    optim.step()
    optim.zero_grad()

.\lucidrains\AMIE-pytorch\AMIE_pytorch\AMIE_pytorch.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn, einsum 模块
from torch import nn, einsum
# 从 torch.nn 模块中导入 Module, ModuleList 类
from torch.nn import Module, ModuleList

# 导入 einops 库中的 rearrange 函数
from einops import rearrange

# 定义函数

# presence check
def exists(v):
    """Return ``True`` unless ``v`` is ``None``."""
    if v is None:
        return False
    return True

# fallback helper
def default(v, d):
    """Return ``v``, or ``d`` when ``v`` is ``None`` (as judged by ``exists``)."""
    if not exists(v):
        return d
    return v

# Self-critique prompt (Figure A.15 of the paper).
# Asks a model to explain an existing human rating of a doctor-patient dialogue.
# The angle-bracket fields (<criterion>, <definition>, <dialogue>, <human rating>)
# are placeholders; they are presumably substituted by the caller, since nothing
# in this file fills them in.

PROMPT_EVALUATE_EXPLANATION = """
I have a doctor-patient dialogue and the corresponding rating that quantifies its quality according to
the following criterion: <criterion> (e.g., maintaining patient welfare). The rating of the dialogue is
on a scale of 1 to 5 where:

5: <definition> e.g., “Treats patient respectfully, and ensures comfort, safety and dignity”
1: <definition> e.g., “Causes patient physical or emotional discomfort AND jeopardises patient safety”

First, describe which parts of the dialogue are good with respect to the criterion. Then, describe which parts are bad with respect to the criterion. Lastly, summarise the above to explain the
provided rating, using the following format:

Good: ...
Bad: ...
Summary: ...

DIALOGUE: <dialogue>
Rating: <human rating>
EVALUATION:
"""

# Figure A.16 of the paper: few-shot prompt asking a model to critique and then
# rate a dialogue itself. The angle-bracket fields are placeholders (presumably
# filled in by the caller; nothing in this file substitutes them).

PROMPT_EVALUATE_QUALITATIVE = """
I have a doctor-patient dialogue which I would like you to evaluate on the following criterion:
<criterion> (e.g., maintaining patient welfare). The dialogue should be rated on a scale of 1-5 with
respect to the criterion where:

5: <definition> e.g., “Treats patient respectfully, and ensures comfort, safety and dignity”
1: <definition> e.g., “Causes patient physical or emotional discomfort AND jeopardises patient safety”

Here are some example dialogues and their ratings:
DIALOGUE: <example dialog>
EVALUATION: <example self-generated explanation>
Rating: <example rating>
...

Now, please rate the following dialogue as instructed below. First, describe which parts of the dialogue
are good with respect to the criterion. Then, describe which parts are bad with respect to the criterion.
Third, summarise the above findings. Lastly, rate the dialogue on a scale of 1-5 with respect to the
criterion, according to this schema:

Good: ...
Bad: ...
Summary: ...
Rating: ...

DIALOGUE: <dialogue>
EVALUATION:
"""

# self-play modules (placeholders)

class OuterSelfPlay(Module):
    """Placeholder for the outer self-play component; constructing it raises ``NotImplementedError``."""

    def __init__(self):
        super().__init__()
        raise NotImplementedError()

class InnerSelfPlay(Module):
    """Placeholder for the inner self-play component; constructing it raises ``NotImplementedError``."""

    def __init__(self):
        super().__init__()
        raise NotImplementedError()

class PatientAgent(Module):
    """Placeholder for the patient agent; constructing it raises ``NotImplementedError``."""

    def __init__(self):
        super().__init__()
        raise NotImplementedError()

class ClinicalVignetteGenerator(Module):
    """Placeholder for the clinical vignette generator; constructing it raises ``NotImplementedError``."""

    def __init__(self):
        super().__init__()
        raise NotImplementedError()

class Moderator(Module):
    """Placeholder for the dialogue moderator; constructing it raises ``NotImplementedError``."""

    def __init__(self):
        super().__init__()
        raise NotImplementedError()

class DoctorAgent(Module):
    """Placeholder for the doctor agent; constructing it raises ``NotImplementedError``."""

    def __init__(self):
        super().__init__()
        raise NotImplementedError()

class SimulatedDialogue(Module):
    """Placeholder for a simulated dialogue; constructing it raises ``NotImplementedError``."""

    def __init__(self):
        super().__init__()
        raise NotImplementedError()

class Critic(Module):
    """Placeholder for the critic; constructing it raises ``NotImplementedError``."""

    def __init__(self):
        super().__init__()
        raise NotImplementedError()

# main class

class AMIE(Module):
    """Top-level AMIE module; for now ``forward`` is the identity."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        output = x
        return output

.\lucidrains\AMIE-pytorch\AMIE_pytorch\__init__.py

# 从AMIE_pytorch模块中导入AMIE类
from AMIE_pytorch.AMIE_pytorch import AMIE

AMIE - Pytorch (wip)

Implementation of the general framework for AMIE, from the paper "Towards Conversational Diagnostic AI", out of Google DeepMind.

Reach out to me if you are at least a third-year medical student, have kept up with the current state of deep learning, and are interested in this project.

Todo

Citations

@inproceedings{Tu2024TowardsCD,
    title   = {Towards Conversational Diagnostic AI},
    author  = {Tao Tu and Anil Palepu and Mike Schaekermann and Khaled Saab and Jan Freyberg and Ryutaro Tanno and Amy Wang and Brenna Li and Mohamed Amin and Nenad Toma{\vs}ev and Shekoofeh Azizi and Karan Singhal and Yong Cheng and Le Hou and Albert Webson and Kavita Kulkarni and S Sara Mahdavi and Christopher Semturs and Juraj Gottweis and Joelle Barral and Katherine Chou and Greg S. Corrado and Yossi Matias and Alan Karthikesalingam and Vivek Natarajan},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:266933212}
}

.\lucidrains\AMIE-pytorch\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# Package metadata for AMIE-pytorch.
setup(
  name = 'AMIE-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.1',
  license='MIT',
  description = 'AMIE',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/AMIE-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'differential diagnosis'
  ],
  # torch>=2.0 (below) supports only Python 3.8+, so declare that floor
  # instead of advertising Python 3.6, which cannot satisfy the requirements.
  python_requires = '>=3.8',
  install_requires=[
    'accelerate',
    'beartype',
    'einops>=0.7.0',
    'einx>=0.1.2',
    'torch>=2.0',
    'tqdm'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.8',
  ],
)

.\lucidrains\anymal-belief-state-encoder-decoder-pytorch\anymal_belief_state_encoder_decoder_pytorch\networks.py

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import GRUCell
from torch.distributions import Categorical
from torch.optim import Adam

from einops import rearrange
from einops_exts import check_shape
from einops.layers.torch import Rearrange

from anymal_belief_state_encoder_decoder_pytorch.running import RunningStats

# helper functions

# presence check
def exists(val):
    """Return ``True`` unless ``val`` is ``None``."""
    if val is None:
        return False
    return True

# helpers for freezing networks (the teacher has to be frozen)

# toggle gradient tracking for a whole module
def set_module_requires_grad_(module, requires_grad):
    """Set ``requires_grad`` in place on every parameter of ``module``."""
    for parameter in module.parameters():
        parameter.requires_grad_(requires_grad)

# freeze every layer
def freeze_all_layers_(module):
    """Disable gradients for all parameters of ``module`` (in place)."""
    set_module_requires_grad_(module, requires_grad = False)

# unfreeze every layer
def unfreeze_all_layers_(module):
    """Enable gradients for all parameters of ``module`` (in place)."""
    set_module_requires_grad_(module, requires_grad = True)

# In the paper, the network's attention gates the exteroception, which is then
# added to the belief state.
# todo: make sure the padding is on the correct side

def sum_with_zeropad(x, y):
    """Add ``x`` and ``y`` whose last dimensions may differ.

    The narrower tensor is zero-padded at the *start* of its last dimension
    until both widths match; the sum is then returned.
    """
    gap = y.shape[-1] - x.shape[-1]

    if gap > 0:
        x = F.pad(x, (gap, 0))
    elif gap < 0:
        y = F.pad(y, (-gap, 0))

    return x + y

# basic multi-layer perceptron

class MLP(nn.Module):
    """Stack of ``Linear`` layers with an activation after every hidden layer.

    ``dims`` lists the input, hidden and output widths (at least three
    entries). The activation follows the output layer as well only when
    ``final_activation`` is set. ``forward`` concatenates a list/tuple input
    along the last dimension first.
    """

    def __init__(
        self,
        dims,
        activation = nn.LeakyReLU,
        final_activation = False
    ):
        super().__init__()
        assert isinstance(dims, (list, tuple))
        assert len(dims) > 2, 'must have at least 3 dimensions (input, *hiddens, output)'

        last_index = len(dims) - 2
        layers = []
        for index, (dim_in, dim_out) in enumerate(zip(dims, dims[1:])):
            layers.append(nn.Linear(dim_in, dim_out))
            if index < last_index or final_activation:
                layers.append(activation())

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        if isinstance(x, (list, tuple)):
            x = torch.cat(x, dim = -1)
        return self.net(x)

# student model
class Student(nn.Module):
    """Recurrent student policy acting on proprioception and exteroception only.

    A stack of ``GRUCell`` layers carries a recurrent state. The last layer's
    output is encoded into a belief state, sigmoid-gated latent exteroception
    is added onto it (``sum_with_zeropad``), and the result is concatenated
    with proprioception and mapped to action logits. Decoders on the same GRU
    output can reconstruct the privileged and exteroceptive inputs for an
    auxiliary reconstruction loss.
    """
    def __init__(
        self,
        num_actions,
        proprio_dim = 133,
        extero_dim = 52,  # in paper, height samples was marked as 208, but wasn't sure if that was per leg, or (4 legs x 52) = 208
        latent_extero_dim = 24,
        extero_encoder_hidden = (80, 60),
        belief_state_encoder_hiddens = (64, 64),
        extero_gate_encoder_hiddens = (64, 64),
        belief_state_dim = 120,  # should be equal to teacher's extero_dim + privileged_dim (part of the GRU's responsibility is to maintain a hidden state that forms an opinion on the privileged information)
        gru_num_layers = 2,
        gru_hidden_size = 50,
        mlp_hidden = (256, 160, 128),
        num_legs = 4,
        privileged_dim = 50,
        privileged_decoder_hiddens = (64, 64),
        extero_decoder_hiddens = (64, 64),
    ):
        super().__init__()
        # the gated latent exteroception (num_legs * latent_extero_dim wide) is
        # added onto the belief state in forward; require the belief state to be wider
        assert belief_state_dim > (num_legs * latent_extero_dim)
        self.num_legs = num_legs
        self.proprio_dim = proprio_dim
        self.extero_dim = extero_dim

        # encoding of exteroception (MLP acts on the last dim, i.e. per leg)
        self.extero_encoder = MLP((extero_dim, *extero_encoder_hidden, latent_extero_dim))

        # GRU stack: the first cell consumes proprioception + flattened latent
        # exteroception, every later cell consumes the previous cell's output
        gru_input_dim = (latent_extero_dim * num_legs) + proprio_dim
        gru_input_dims = (gru_input_dim, *((gru_hidden_size,) * (gru_num_layers - 1)))
        self.gru_cells = nn.ModuleList([GRUCell(input_dim, gru_hidden_size) for input_dim in gru_input_dims])
        self.gru_hidden_size = gru_hidden_size

        # belief state encoding
        self.belief_state_encoder = MLP((gru_hidden_size, *belief_state_encoder_hiddens, belief_state_dim))

        # attention gating of the latent exteroception
        self.to_latent_extero_attn_gate = MLP((gru_hidden_size, *extero_gate_encoder_hiddens, latent_extero_dim * num_legs))

        # belief state decoders (privileged information and raw exteroception)
        self.privileged_decoder = MLP((gru_hidden_size, *privileged_decoder_hiddens, privileged_dim))
        self.extero_decoder = MLP((gru_hidden_size, *extero_decoder_hiddens, extero_dim * num_legs))

        # gate letting the raw exteroception through into its reconstruction
        self.to_extero_attn_gate = MLP((gru_hidden_size, *extero_gate_encoder_hiddens, extero_dim * num_legs))

        # final MLP to action logits
        self.to_logits = MLP((
            belief_state_dim + proprio_dim,
            *mlp_hidden
        ))

        self.to_action_head = nn.Sequential(
            nn.LeakyReLU(),
            nn.Linear(mlp_hidden[-1], num_actions)
        )

    def get_gru_hiddens(self):
        """Return zero initial hidden states of shape ``(num_gru_layers, gru_hidden_size)``.

        The tensor is allocated on the same device as the module's parameters,
        so it can be fed straight back into ``forward``. There is no batch
        dimension; callers presumably expand it per batch (TODO confirm).
        """
        device = next(self.parameters()).device
        return torch.zeros((len(self.gru_cells), self.gru_hidden_size), device = device)

    def forward(
        self,
        proprio,
        extero,
        hiddens = None,
        return_estimated_info = False,  # for returning estimated privileged info + exterceptive info, for reconstruction loss
        return_action_categorical_dist = False
    ):
        """Run one recurrent step.

        Args:
            proprio: ``(batch, proprio_dim)`` proprioceptive observation.
            extero: ``(batch, num_legs, extero_dim)`` exteroceptive observation.
            hiddens: previous per-layer GRU states stacked along dim -2, or
                ``None`` to let each ``GRUCell`` start from its default (zeros).
            return_estimated_info: also return reconstructions of the
                privileged and exteroceptive inputs.
            return_action_categorical_dist: return a ``Categorical`` over
                actions instead of raw logits.

        Returns:
            ``(action, next_hiddens)``, or
            ``(action, next_hiddens, (recon_privileged, recon_extero))`` when
            ``return_estimated_info`` is set. ``next_hiddens`` is stacked the
            same way as ``hiddens``.
        """
        check_shape(proprio, 'b d', d = self.proprio_dim)
        check_shape(extero, 'b n d', n = self.num_legs, d = self.extero_dim)

        latent_extero = self.extero_encoder(extero)
        latent_extero = rearrange(latent_extero, 'b ... -> b (...)')

        # RNN: feed each GRU cell's output into the next one

        if not exists(hiddens):
            prev_hiddens = (None,) * len(self.gru_cells)
        else:
            prev_hiddens = hiddens.unbind(dim = -2)

        gru_input = torch.cat((proprio, latent_extero), dim = -1)

        next_hiddens = []
        for gru_cell, prev_hidden in zip(self.gru_cells, prev_hiddens):
            gru_input = gru_cell(gru_input, prev_hidden)
            next_hiddens.append(gru_input)

        gru_output = gru_input

        next_hiddens = torch.stack(next_hiddens, dim = -2)

        # attention gating of exteroception

        latent_extero_attn_gate = self.to_latent_extero_attn_gate(gru_output)
        gated_latent_extero = latent_extero * latent_extero_attn_gate.sigmoid()

        # belief state plus gated exteroception

        belief_state = self.belief_state_encoder(gru_output)
        belief_state = sum_with_zeropad(belief_state, gated_latent_extero)

        # to action logits

        belief_state_with_proprio = torch.cat((
            proprio,
            belief_state,
        ), dim = 1)

        logits = self.to_logits(belief_state_with_proprio)

        pi_logits = self.to_action_head(logits)

        # raw policy logits, or action probabilities wrapped in a Categorical
        return_action = Categorical(pi_logits.softmax(dim = -1)) if return_action_categorical_dist else pi_logits

        if not return_estimated_info:
            return return_action, next_hiddens

        # belief state decoding: reconstruct privileged and exteroceptive
        # information from the hidden belief state

        recon_privileged = self.privileged_decoder(gru_output)
        recon_extero = self.extero_decoder(gru_output)
        extero_attn_gate = self.to_extero_attn_gate(gru_output)

        gated_extero = rearrange(extero, 'b ... -> b (...)') * extero_attn_gate.sigmoid()
        recon_extero = recon_extero + gated_extero
        recon_extero = rearrange(recon_extero, 'b (n d) -> b n d', n = self.num_legs)

        return return_action, next_hiddens, (recon_privileged, recon_extero)

# teacher model
class Teacher(nn.Module):
    """Feed-forward teacher policy with access to privileged observations.

    Exteroception (per leg) and privileged observations are each encoded by an
    MLP. The latents are concatenated with proprioception and mapped to action
    logits and, optionally, a value estimate.
    """
    def __init__(
        self,
        num_actions,
        proprio_dim = 133,
        extero_dim = 52,  # in paper, height samples was marked as 208, but wasn't sure if that was per leg, or (4 legs x 52) = 208
        latent_extero_dim = 24,
        extero_encoder_hidden = (80, 60),
        privileged_dim = 50,
        latent_privileged_dim = 24,
        privileged_encoder_hidden = (64, 32),
        mlp_hidden = (256, 160, 128),
        num_legs = 4
        ):
        super().__init__()
        # observation sizes, checked against the inputs in forward
        self.num_legs = num_legs
        self.proprio_dim = proprio_dim
        self.extero_dim = extero_dim
        self.privileged_dim = privileged_dim

        # per-leg exteroception encoder and privileged-information encoder
        self.extero_encoder = MLP((extero_dim, *extero_encoder_hidden, latent_extero_dim))
        self.privileged_encoder = MLP((privileged_dim, *privileged_encoder_hidden, latent_privileged_dim))

        # shared trunk over [proprio, flattened latent extero, latent privileged]
        self.to_logits = MLP((
            latent_extero_dim * num_legs + latent_privileged_dim + proprio_dim,
            *mlp_hidden
        ))

        # action head
        self.to_action_head = nn.Sequential(
            nn.LeakyReLU(),
            nn.Linear(mlp_hidden[-1], num_actions)
        )

        # value head; the trailing singleton dimension is removed
        self.to_value_head = nn.Sequential(
            nn.LeakyReLU(),
            nn.Linear(mlp_hidden[-1], 1),
            Rearrange('... 1 -> ...')
        )

    def forward(
        self,
        proprio,
        extero,
        privileged,
        return_value_head = False,
        return_action_categorical_dist = False
    ):
        """Compute the teacher's action (and optionally its value).

        Args:
            proprio: ``(batch, proprio_dim)`` proprioceptive observation.
            extero: ``(batch, num_legs, extero_dim)`` exteroceptive observation.
            privileged: ``(batch, privileged_dim)`` privileged observation.
            return_value_head: also return the value estimate.
            return_action_categorical_dist: return a ``Categorical`` over
                actions instead of raw logits. This applies on both return paths.

        Returns:
            ``action``, or ``(action, value)`` when ``return_value_head`` is set.
        """
        check_shape(proprio, 'b d', d = self.proprio_dim)
        check_shape(extero, 'b n d', n = self.num_legs, d = self.extero_dim)
        check_shape(privileged, 'b d', d = self.privileged_dim)

        latent_extero = self.extero_encoder(extero)
        latent_extero = rearrange(latent_extero, 'b ... -> b (...)')

        latent_privileged = self.privileged_encoder(privileged)

        latent = torch.cat((
            proprio,
            latent_extero,
            latent_privileged,
        ), dim = -1)

        logits = self.to_logits(latent)

        pi_logits = self.to_action_head(logits)

        # honour the categorical-distribution flag whether or not the value head is requested
        return_action = Categorical(pi_logits.softmax(dim = -1)) if return_action_categorical_dist else pi_logits

        if not return_value_head:
            return return_action

        value_logits = self.to_value_head(logits)

        return return_action, value_logits
# module that owns both the teacher and the student
class Anymal(nn.Module):
    """Container for a ``Teacher`` and a ``Student`` built with matching observation sizes.

    ``forward`` distils the frozen teacher into the student. The returned loss
    is the MSE between teacher and student action logits plus
    ``recon_loss_weight`` times the student's reconstruction loss.
    """
    def __init__(
        self,
        num_actions,
        proprio_dim = 133,
        extero_dim = 52,
        privileged_dim = 50,
        num_legs = 4,
        latent_extero_dim = 24,
        latent_privileged_dim = 24,
        teacher_extero_encoder_hidden = (80, 60),
        teacher_privileged_encoder_hidden = (64, 32),
        student_extero_gate_encoder_hiddens = (64, 64),
        student_belief_state_encoder_hiddens = (64, 64),
        student_belief_state_dim = 120,
        student_gru_num_layers = 2,
        student_gru_hidden_size = 50,
        student_privileged_decoder_hiddens = (64, 64),
        student_extero_decoder_hiddens = (64, 64),
        student_extero_encoder_hidden = (80, 60),
        mlp_hidden = (256, 160, 128),
        recon_loss_weight = 0.5
    ):
        super().__init__()
        self.proprio_dim = proprio_dim
        self.num_legs = num_legs
        self.extero_dim = extero_dim

        # hyperparameters both networks are built with
        shared = dict(
            num_actions = num_actions,
            proprio_dim = proprio_dim,
            extero_dim = extero_dim,
            latent_extero_dim = latent_extero_dim,
            privileged_dim = privileged_dim,
            mlp_hidden = mlp_hidden,
            num_legs = num_legs,
        )

        # the student is constructed before the teacher (parameter init order)
        self.student = Student(
            **shared,
            extero_encoder_hidden = student_extero_encoder_hidden,
            belief_state_encoder_hiddens = student_belief_state_encoder_hiddens,
            extero_gate_encoder_hiddens = student_extero_gate_encoder_hiddens,
            belief_state_dim = student_belief_state_dim,
            gru_num_layers = student_gru_num_layers,
            gru_hidden_size = student_gru_hidden_size,
            privileged_decoder_hiddens = student_privileged_decoder_hiddens,
            extero_decoder_hiddens = student_extero_decoder_hiddens,
        )

        self.teacher = Teacher(
            **shared,
            extero_encoder_hidden = teacher_extero_encoder_hidden,
            latent_privileged_dim = latent_privileged_dim,
            privileged_encoder_hidden = teacher_privileged_encoder_hidden,
        )

        self.recon_loss_weight = recon_loss_weight

    def get_observation_running_stats(self):
        """Return fresh ``RunningStats`` for proprioception and for per-leg exteroception."""
        proprio_stats = RunningStats(self.proprio_dim)
        extero_stats = RunningStats((self.num_legs, self.extero_dim))
        return proprio_stats, extero_stats

    def init_student_with_teacher(self):
        """Copy the teacher's extero encoder, logits trunk and action head weights into the student."""
        for submodule_name in ('extero_encoder', 'to_logits', 'to_action_head'):
            source = getattr(self.teacher, submodule_name)
            getattr(self.student, submodule_name).load_state_dict(source.state_dict())

    def forward_teacher(self, *args, return_value_head = False, **kwargs):
        """Call the teacher, forwarding ``return_value_head`` explicitly."""
        return self.teacher(*args, return_value_head = return_value_head, **kwargs)

    def forward_student(self, *args, **kwargs):
        """Call the student with the given arguments."""
        return self.student(*args, **kwargs)

    def forward(
        self,
        proprio,
        extero,
        privileged,
        teacher_states = None,
        hiddens = None,
        noise_strength = 0.1
    ):
        """Compute the distillation loss for one step.

        ``teacher_states``, when given, is a ``(proprio, extero)`` pair fed to
        the teacher instead of the student's observations. Returns
        ``(loss, next_hiddens)``.
        """
        teacher = self.teacher
        teacher.eval()
        freeze_all_layers_(teacher)

        if exists(teacher_states):
            teacher_proprio, teacher_extero = teacher_states
        else:
            teacher_proprio, teacher_extero = proprio, extero

        with torch.no_grad():
            teacher_action_logits = self.forward_teacher(teacher_proprio, teacher_extero, privileged)

        # perturb the student's exteroception with noise scaled by noise_strength
        # NOTE(review): rand_like draws from [0, 1), so this noise is not zero-mean;
        # confirm rand_like (rather than randn_like) is intended.
        noised_extero = extero + torch.rand_like(extero) * noise_strength

        student_action_logits, hiddens, (recon_privileged, recon_extero) = self.student(
            proprio,
            noised_extero,
            hiddens = hiddens,
            return_estimated_info = True
        )

        # reconstruct the privileged info and the clean (un-noised) exteroception
        recon_loss = F.mse_loss(recon_privileged, privileged) + F.mse_loss(recon_extero, extero)

        # behaviour cloning on the action logits
        behavior_loss = F.mse_loss(teacher_action_logits, student_action_logits)

        return behavior_loss + recon_loss * self.recon_loss_weight, hiddens

.\lucidrains\anymal-belief-state-encoder-decoder-pytorch\anymal_belief_state_encoder_decoder_pytorch\ppo.py

# 导入必要的库
from collections import namedtuple, deque
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from anymal_belief_state_encoder_decoder_pytorch import Anymal
from anymal_belief_state_encoder_decoder_pytorch.networks import unfreeze_all_layers_
from einops import rearrange

# One transition collected during a PPO rollout.
Memory = namedtuple('Memory', 'state action action_log_prob reward done value')

class ExperienceDataset(Dataset):
    """Map-style dataset over parallel columns of experience.

    ``data`` is a sequence of indexable columns, which are expected to have
    equal length. Item ``i`` is the tuple of every column's ``i``-th entry.
    The dataset length is the length of the first column.
    """

    def __init__(self, data):
        super().__init__()
        self.data = data

    def __len__(self):
        first_column = self.data[0]
        return len(first_column)

    def __getitem__(self, ind):
        return tuple(column[ind] for column in self.data)

def create_shuffled_dataloader(data, batch_size):
    """Wrap parallel experience columns in a DataLoader with ``shuffle=True``."""
    return DataLoader(ExperienceDataset(data), batch_size = batch_size, shuffle = True)

def normalize(t, eps = 1e-5):
    """Standardize ``t`` over all its elements: ``(t - mean) / (std + eps)``.

    ``eps`` guards against division by a zero standard deviation. A tensor
    with fewer than two elements has no sample standard deviation
    (``Tensor.std`` returns NaN), so it is only mean-centered, giving zeros
    instead of NaNs. This case occurs e.g. for a final minibatch of size one
    when normalizing PPO advantages.
    """
    centered = t - t.mean()
    if t.numel() < 2:
        return centered
    return centered / (t.std() + eps)

def clipped_value_loss(values, rewards, old_values, clip):
    """PPO clipped value-function loss.

    The new value prediction is kept within ``clip`` of the old prediction.
    The loss is the mean of the element-wise maximum of the clipped and
    unclipped squared errors against ``rewards`` (the returns).

    All three tensors are flattened first. This pairs them element-wise even
    when their shapes differ only by singleton dims, e.g. ``values`` of shape
    ``(b, 1)`` and ``old_values`` of shape ``(b,)``. Without flattening,
    those shapes would broadcast to ``(b, b)``.
    """
    values, rewards, old_values = (t.flatten() for t in (values, rewards, old_values))
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    value_loss_1 = (value_clipped - rewards) ** 2
    value_loss_2 = (values - rewards) ** 2
    return torch.mean(torch.max(value_loss_1, value_loss_2))

class MockEnv:
    """Stand-in environment that emits random observations and rewards.

    An observation is a ``(proprio, extero, privileged)`` tuple of
    standard-normal tensors with shapes ``(proprio_dim,)``,
    ``(num_legs, extero_dim)`` and ``(privileged_dim,)``. ``step`` ignores
    the action and always reports ``done`` as ``tensor([False])``.
    """

    def __init__(
        self,
        proprio_dim,
        extero_dim,
        privileged_dim,
        num_legs = 4
    ):
        self.proprio_dim = proprio_dim
        self.extero_dim = extero_dim
        self.privileged_dim = privileged_dim
        self.num_legs = num_legs

    def rand_state(self):
        shapes = (
            (self.proprio_dim,),
            (self.num_legs, self.extero_dim),
            (self.privileged_dim,),
        )
        return tuple(torch.randn(shape) for shape in shapes)

    def reset(self):
        return self.rand_state()

    def step(self, action):
        """Return ``(next_state, reward, done, info)``; ``action`` is unused."""
        # the reward is drawn before the next state (RNG order matters for seeding)
        reward, done = torch.randn((1,)), torch.tensor([False])
        next_state = self.rand_state()
        return next_state, reward, done, None

class PPO(nn.Module):
    """PPO trainer for the Anymal teacher policy.

    ``forward`` runs one episode in ``env`` with the teacher. Proprio- and
    exteroceptive observations are normalized with running statistics, which
    are reset at the start of every episode. Every ``update_timesteps`` steps,
    ``learn_from_memories`` runs ``epochs`` epochs of clipped-PPO updates. Only
    the teacher's parameters are optimized.
    """

    def __init__(
        self,
        *,
        env,
        anymal,
        epochs = 2,
        lr = 5e-4,
        betas = (0.9, 0.999),
        eps_clip = 0.2,
        beta_s = 0.005,
        value_clip = 0.4,
        max_timesteps = 10000,
        update_timesteps = 5000,
        lam = 0.95,
        gamma = 0.99,
        minibatch_size = 8300
    ):
        super().__init__()
        assert isinstance(anymal, Anymal)
        self.env = env
        self.anymal = anymal

        self.minibatch_size = minibatch_size
        self.optimizer = Adam(anymal.teacher.parameters(), lr = lr, betas = betas)
        self.epochs = epochs

        self.max_timesteps = max_timesteps
        self.update_timesteps = update_timesteps

        self.beta_s = beta_s          # entropy bonus coefficient
        self.eps_clip = eps_clip      # clip range of the policy probability ratio
        self.value_clip = value_clip  # clip range of the value prediction

        self.lam = lam                # GAE lambda
        self.gamma = gamma            # discount factor

        # in the paper, the observations passed to the teacher are normalized with a running mean
        self.running_proprio, self.running_extero = anymal.get_observation_running_stats()

    def _normalize_observations(self, states):
        """Normalize proprio/extero with the running stats (without updating them); privileged passes through."""
        proprio, extero, privileged = states
        return (
            self.running_proprio.norm(proprio),
            self.running_extero.norm(extero),
            privileged
        )

    def learn_from_memories(
        self,
        memories,
        next_states
    ):
        """Run clipped-PPO updates on the collected ``memories``.

        Args:
            memories: ``Memory`` transitions in time order.
            next_states: observation tuple that follows the last memory. It
                should be normalized the same way as the stored states, because
                it is fed to the teacher to bootstrap the value of the final
                transition.
        """
        device = next(self.parameters()).device

        # unpack the memories into parallel lists
        states = []
        actions = []
        old_log_probs = []
        rewards = []
        masks = []
        values = []

        for mem in memories:
            states.append(mem.state)
            actions.append(torch.tensor(mem.action))
            old_log_probs.append(mem.action_log_prob)
            rewards.append(mem.reward)
            masks.append(1 - float(mem.done))
            values.append(mem.value)

        states = tuple(zip(*states))

        # value of the state after the last transition, used to bootstrap GAE
        next_states = map(lambda t: rearrange(t.to(device), '... -> 1 ...'), next_states)

        with torch.no_grad():
            _, next_value = self.anymal.forward_teacher(*next_states, return_value_head = True)

        values = values + [next_value]

        # generalized advantage estimation, accumulated backwards in time
        returns = []
        gae = 0
        for i in reversed(range(len(rewards))):
            delta = rewards[i] + self.gamma * values[i + 1] * masks[i] - values[i]
            gae = delta + self.gamma * self.lam * masks[i] * gae
            returns.append(gae + values[i])
        returns.reverse()

        # convert to detached tensors on the training device
        to_torch_tensor = lambda t: torch.stack(t).to(device).detach()

        states = map(to_torch_tensor, states)
        actions = to_torch_tensor(actions)
        old_log_probs = to_torch_tensor(old_log_probs)

        old_values = to_torch_tensor(values[:-1])
        old_values = rearrange(old_values, '... 1 -> ...')

        rewards = torch.tensor(returns).float().to(device)

        dl = create_shuffled_dataloader([*states, actions, old_log_probs, rewards, old_values], self.minibatch_size)

        # policy phase training, as in the original PPO
        for _ in range(self.epochs):
            for proprio, extero, privileged, batch_actions, batch_old_log_probs, batch_returns, batch_old_values in dl:

                dist, batch_values = self.anymal.forward_teacher(
                    proprio, extero, privileged,
                    return_value_head = True,
                    return_action_categorical_dist = True
                )

                action_log_probs = dist.log_prob(batch_actions)

                entropy = dist.entropy()
                ratios = (action_log_probs - batch_old_log_probs).exp()
                advantages = normalize(batch_returns - batch_old_values.detach())
                surr1 = ratios * advantages
                surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * advantages

                # clipped surrogate objective (negated, since it is maximized) with an entropy bonus
                policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropy

                value_loss = clipped_value_loss(batch_values, batch_returns, batch_old_values, self.value_clip)

                (policy_loss.mean() + value_loss.mean()).backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

    def forward(self):
        """Collect one episode (at most ``max_timesteps`` steps), learning every ``update_timesteps`` steps."""
        device = next(self.parameters()).device
        unfreeze_all_layers_(self.anymal)

        states = self.env.reset() # states are assumed to be (proprio, extero, privileged)
        memories = deque([])

        # running observation statistics restart every episode
        self.running_proprio.clear()
        self.running_extero.clear()

        for timestep in range(1, self.max_timesteps + 1):
            proprio, extero, privileged = map(lambda t: t.to(device), states)

            # update the running statistics used to normalize the teacher's observations
            self.running_proprio.push(proprio)
            self.running_extero.push(extero)

            states = self._normalize_observations((proprio, extero, privileged))

            anymal_states = [rearrange(t, '... -> 1 ...') for t in states]

            # acting only collects data: gradients are recomputed in learn_from_memories,
            # so there is no need to keep an autograd graph alive for every rollout step
            with torch.no_grad():
                dist, values = self.anymal.forward_teacher(
                    *anymal_states,
                    return_value_head = True,
                    return_action_categorical_dist = True
                )

                action = dist.sample()
                action_log_prob = dist.log_prob(action)

            action = action.item()

            next_states, reward, done, _ = self.env.step(action)

            memories.append(Memory(states, action, action_log_prob, reward, done, values))

            states = next_states

            if timestep % self.update_timesteps == 0:
                # bootstrap from the next observation, normalized like the stored states
                normed_next_states = self._normalize_observations(tuple(t.to(device) for t in next_states))
                self.learn_from_memories(memories, normed_next_states)
                memories.clear()

            if done:
                break

        print('trained for 1 episode')

.\lucidrains\anymal-belief-state-encoder-decoder-pytorch\anymal_belief_state_encoder_decoder_pytorch\running.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn

class RunningStats(nn.Module):
    """Online mean/variance tracker for tensors of a fixed shape (Welford's algorithm).

    ``new_std``/``old_std`` actually hold the running sum of squared
    deviations (M2). ``variance()`` divides it by ``n - 1``. The buffers are
    registered as non-persistent: they follow the module across devices but
    are excluded from ``state_dict``.
    """

    def __init__(self, shape, eps = 1e-5):
        super().__init__()
        if not isinstance(shape, tuple):
            shape = (shape,)

        self.shape = shape
        self.eps = eps
        self.n = 0

        for buffer_name in ('old_mean', 'new_mean', 'old_std', 'new_std'):
            self.register_buffer(buffer_name, torch.zeros(shape), persistent = False)

    def clear(self):
        """Forget all samples; the buffers are re-seeded by the next ``push``."""
        self.n = 0

    def push(self, x):
        """Add one sample ``x`` (expected to broadcast to ``self.shape``)."""
        self.n += 1

        if self.n == 1:
            # the first sample is the mean, and the squared-deviation sum starts at zero
            for mean_buffer in (self.old_mean, self.new_mean):
                mean_buffer.copy_(x.data)
            for m2_buffer in (self.old_std, self.new_std):
                m2_buffer.zero_()
            return

        delta = x - self.old_mean
        self.new_mean.copy_(self.old_mean + delta / self.n)
        self.new_std.copy_(self.old_std + delta * (x - self.new_mean))

        self.old_mean.copy_(self.new_mean)
        self.old_std.copy_(self.new_std)

    def mean(self):
        """Running mean, or zeros before any sample has been pushed."""
        if not self.n:
            return torch.zeros_like(self.new_mean)
        return self.new_mean

    def variance(self):
        """Unbiased running variance, or zeros with fewer than two samples."""
        if self.n <= 1:
            return torch.zeros_like(self.new_std)
        return self.new_std / (self.n - 1)

    def rstd(self):
        """Reciprocal standard deviation, ``1 / sqrt(variance + eps)``."""
        return (self.variance() + self.eps).rsqrt()

    def norm(self, x):
        """Standardize ``x`` with the running statistics."""
        return (x - self.mean()) * self.rstd()

.\lucidrains\anymal-belief-state-encoder-decoder-pytorch\anymal_belief_state_encoder_decoder_pytorch\trainer.py

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

from collections import deque
from einops import rearrange

from anymal_belief_state_encoder_decoder_pytorch import Anymal

class ExperienceDataset(Dataset):
    """Dataset whose ``i``-th item is the tuple of every column's ``i``-th entry.

    ``data`` is a sequence of indexable columns, which are expected to have
    equal length. The length of the first column is the dataset length.
    """

    def __init__(self, data):
        super().__init__()
        self.data = data

    def __len__(self):
        columns = self.data
        return len(columns[0])

    def __getitem__(self, ind):
        return tuple([column[ind] for column in self.data])

def create_dataloader(data, batch_size):
    """Wrap parallel experience columns in an in-order DataLoader.

    No shuffling is done, so consecutive batches follow the order of ``data``.
    An incomplete final batch is dropped (``drop_last=True``).
    """
    dataset = ExperienceDataset(data)
    loader = DataLoader(dataset, batch_size = batch_size, drop_last = True)
    return loader

class StudentTrainer(nn.Module):
    """Rollout-and-update loop for training the Anymal student.

    ``forward`` runs one episode of the student in ``env``. At every step it
    stores four things:

    * the raw observations;
    * the proprio/extero observations normalized with running statistics
      (for the teacher);
    * the student's recurrent hidden state;
    * a done flag.

    Every ``update_timesteps`` steps, the memories are replayed in order with
    truncated backpropagation through time. Only the student's parameters are
    optimized, using the loss returned by ``anymal(...)``.
    """

    def __init__(
        self,
        *,
        anymal,
        env,
        epochs = 2,
        lr = 5e-4,
        max_timesteps = 10000,
        update_timesteps = 5000,
        minibatch_size = 16,
        truncate_tpbtt = 10
    ):
        super().__init__()
        self.env = env
        self.anymal = anymal
        self.optimizer = Adam(anymal.student.parameters(), lr = lr)
        self.epochs = epochs

        self.max_timesteps = max_timesteps
        self.update_timesteps = update_timesteps
        self.minibatch_size = minibatch_size
        # number of minibatches between optimizer steps / hidden-state truncations
        self.truncate_tpbtt = truncate_tpbtt

        self.running_proprio, self.running_extero = anymal.get_observation_running_stats()

    def learn_from_memories(
        self,
        memories,
        next_states,
        noise_strength = 0.
    ):
        """Train the student on ``memories`` with truncated BPTT.

        Args:
            memories: ``(states, teacher_states, hidden, done)`` tuples in time order.
            next_states: currently unused; kept for interface compatibility.
            noise_strength: forwarded to ``anymal(...)``.
        """
        device = next(self.parameters()).device

        # unpack the memories into parallel lists
        states = []
        teacher_states = []
        hiddens = []
        dones = []

        for (state, teacher_state, hidden, done) in memories:
            states.append(state)
            teacher_states.append(teacher_state)
            hiddens.append(hidden)
            dones.append(torch.Tensor([done]))

        states = tuple(zip(*states))
        teacher_states = tuple(zip(*teacher_states))

        # convert to detached tensors on the training device
        to_torch_tensor = lambda t: torch.stack(t).to(device).detach()

        states = map(to_torch_tensor, states)
        teacher_states = map(to_torch_tensor, teacher_states)
        hiddens = to_torch_tensor(hiddens)
        dones = to_torch_tensor(dones)

        # the loader is not shuffled, so consecutive minibatches follow the rollout order
        dl = create_dataloader([*states, *teacher_states, hiddens, dones], self.minibatch_size)

        current_hiddens = self.anymal.student.get_gru_hiddens()
        current_hiddens = rearrange(current_hiddens, 'l d -> 1 l d')

        for _ in range(self.epochs):
            for ind, (proprio, extero, privileged, teacher_proprio, teacher_extero, episode_hiddens, batch_dones) in enumerate(dl):

                # straight-through: the forward value equals the stored hidden state,
                # while gradients still flow back into current_hiddens
                straight_through_hiddens = current_hiddens - current_hiddens.detach() + episode_hiddens

                loss, current_hiddens = self.anymal(
                    proprio,
                    extero,
                    privileged,
                    teacher_states = (teacher_proprio, teacher_extero),
                    hiddens = straight_through_hiddens,
                    noise_strength = noise_strength
                )

                loss.backward(retain_graph = True)

                # every `truncate_tpbtt` minibatches: apply the update and cut the graph
                tbptt_limit = not ((ind + 1) % self.truncate_tpbtt)
                if tbptt_limit:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    current_hiddens = current_hiddens.detach()

                # detach the hidden state of every row flagged done in *this* minibatch
                # TODO: restructure the dataloader to load one episode per batch row
                current_hiddens = torch.stack([
                    hidden.detach() if is_done else hidden
                    for hidden, is_done in zip(current_hiddens.unbind(dim = 0), batch_dones.unbind(dim = 0))
                ])

    def forward(
        self,
        noise_strength = 0.
    ):
        """Run one episode (at most ``max_timesteps`` steps), learning every ``update_timesteps`` steps."""
        device = next(self.parameters()).device

        done = False
        states = self.env.reset()
        memories = deque([])

        hidden = self.anymal.student.get_gru_hiddens()
        hidden = rearrange(hidden, 'l d -> 1 l d')

        self.running_proprio.clear()
        self.running_extero.clear()

        for timestep in range(1, self.max_timesteps + 1):
            states = [t.to(device) for t in states]
            anymal_states = [rearrange(t, '... -> 1 ...') for t in states]

            # the teacher needs normalized observations
            (proprio, extero, privileged) = states

            self.running_proprio.push(proprio)
            self.running_extero.push(extero)

            teacher_states = (
                self.running_proprio.norm(proprio),
                self.running_extero.norm(extero)
            )

            memories.append((
                states,
                teacher_states,
                rearrange(hidden, '1 ... -> ...'),
                done
            ))

            # acting only collects data; without no_grad, every step's hidden state
            # would be chained into one autograd graph spanning the whole rollout
            with torch.no_grad():
                dist, hidden = self.anymal.forward_student(
                    *anymal_states[:-1],  # student gets proprio and extero, not privileged
                    hiddens = hidden,
                    return_action_categorical_dist = True
                )
                action = dist.sample()

            next_states, _, done, _ = self.env.step(action.item())

            states = next_states

            if timestep % self.update_timesteps == 0:
                self.learn_from_memories(memories, next_states, noise_strength = noise_strength)
                memories.clear()

            if done:
                break

.\lucidrains\anymal-belief-state-encoder-decoder-pytorch\anymal_belief_state_encoder_decoder_pytorch\__init__.py

# 从anymal_belief_state_encoder_decoder_pytorch.networks模块中导入Student, Teacher, MLP, Anymal类
from anymal_belief_state_encoder_decoder_pytorch.networks import Student, Teacher, MLP, Anymal
# 从anymal_belief_state_encoder_decoder_pytorch.ppo模块中导入PPO, MockEnv类
from anymal_belief_state_encoder_decoder_pytorch.ppo import PPO, MockEnv

Belief State Encoder / Decoder (Anymal) - Pytorch

Implementation of the Belief State Encoder / Decoder in the new breakthrough robotics paper from ETH Zürich.

This paper is important as it seems their learned approach produced a policy that rivals Boston Dynamics' handcrafted algorithms (quadrupedal Spot).

The results speak for themselves in their video demonstration.

Install

$ pip install anymal-belief-state-encoder-decoder-pytorch

Usage

Teacher

import torch
from anymal_belief_state_encoder_decoder_pytorch import Teacher

teacher = Teacher(
    num_actions = 10,
    num_legs = 4,
    extero_dim = 52,
    proprio_dim = 133,
    privileged_dim = 50
)

proprio = torch.randn(1, 133)
extero = torch.randn(1, 4, 52)
privileged = torch.randn(1, 50)

action_logits, values = teacher(proprio, extero, privileged, return_values = True) # (1, 10)

Student

import torch
from anymal_belief_state_encoder_decoder_pytorch import Student

student = Student(
    num_actions = 10,
    num_legs = 4,
    extero_dim = 52,
    proprio_dim = 133,
    gru_num_layers = 2,
    gru_hidden_size = 50
)

proprio = torch.randn(1, 133)
extero = torch.randn(1, 4, 52)

action_logits, hiddens = student(proprio, extero) # (1, 10), (2, 1, 50)
action_logits, hiddens = student(proprio, extero, hiddens) # (1, 10), (2, 1, 50)
action_logits, hiddens = student(proprio, extero, hiddens) # (1, 10), (2, 1, 50)

# hiddens are in the shape (num gru layers, batch size, gru hidden dimension)
# train with truncated bptt

Full Anymal (which contains both Teacher and Student)

import torch
from anymal_belief_state_encoder_decoder_pytorch import Anymal

anymal = Anymal(
    num_actions = 10,
    num_legs = 4,
    extero_dim = 52,
    proprio_dim = 133,
    privileged_dim = 50,
    recon_loss_weight = 0.5
)

# mock data

proprio = torch.randn(1, 133)
extero = torch.randn(1, 4, 52)
privileged = torch.randn(1, 50)

# first train teacher

teacher_action_logits = anymal.forward_teacher(proprio, extero, privileged)

# teacher is trained with privileged information in simulation with domain randomization

# after the teacher has satisfactory performance, init the student with the teacher weights, excluding the privileged information encoder from the teacher (which the student does not have)

anymal.init_student_with_teacher()

# then train the student on the proprioception and noised exteroception, forcing it to reconstruct the privileged information that the teacher had access to (as well as learning to denoise the exteroception) - there is also a behavior loss between the policy logits of the teacher and those of the student

loss, hiddens = anymal(proprio, extero, privileged)
loss.backward()

# finally, you can deploy the student to the real world, zero-shot

anymal.eval()
dist, hiddens = anymal.forward_student(proprio, extero, return_action_categorical_dist = True)
action = dist.sample()

PPO training of the Teacher (using a mock environment; this needs to be substituted with an environment wrapper around a simulator)

import torch
from anymal_belief_state_encoder_decoder_pytorch import Anymal, PPO
from anymal_belief_state_encoder_decoder_pytorch.ppo import MockEnv

anymal = Anymal(
    num_actions = 10,
    num_legs = 4,
    extero_dim = 52,
    proprio_dim = 133,
    privileged_dim = 50,
    recon_loss_weight = 0.5
)

mock_env = MockEnv(
    proprio_dim = 133,
    extero_dim = 52,
    privileged_dim = 50
)

ppo = PPO(
    env = mock_env,
    anymal = anymal,
    epochs = 10,
    lr = 3e-4,
    eps_clip = 0.2,
    beta_s = 0.01,
    value_clip = 0.4,
    max_timesteps = 10000,
    update_timesteps = 5000,
)

# train for 10 episodes

for _ in range(10):
    ppo()

# save the weights of the teacher for student training

torch.save(anymal.state_dict(), './anymal-with-trained-teacher.pt')

To train the student

import torch
from anymal_belief_state_encoder_decoder_pytorch import Anymal
from anymal_belief_state_encoder_decoder_pytorch.trainer import StudentTrainer
from anymal_belief_state_encoder_decoder_pytorch.ppo import MockEnv

anymal = Anymal(
    num_actions = 10,
    num_legs = 4,
    extero_dim = 52,
    proprio_dim = 133,
    privileged_dim = 50,
    recon_loss_weight = 0.5
)

# first init student with teacher weights, at the very beginning
# if not resuming training

mock_env = MockEnv(
    proprio_dim = 133,
    extero_dim = 52,
    privileged_dim = 50
)

trainer = StudentTrainer(
    anymal = anymal,
    env = mock_env
)

# for 100 episodes

for _ in range(100):
    trainer()

... You've beaten Boston Dynamics and its team of highly paid control engineers!

But you probably haven't beaten a real quadrupedal "anymal" just yet 😃

Todo

Diagrams

Citations

@article{2022,
  title     = {Learning robust perceptive locomotion for quadrupedal robots in the wild},
  url       = {http://dx.doi.org/10.1126/scirobotics.abk2822},
  journal   = {Science Robotics},
  publisher = {American Association for the Advancement of Science (AAAS)},
  author    = {Miki, Takahiro and Lee, Joonho and Hwangbo, Jemin and Wellhausen, Lorenz and Koltun, Vladlen and Hutter, Marco},
  year      = {2022},
  month     = {Jan}
}

.\lucidrains\anymal-belief-state-encoder-decoder-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# Package metadata for the anymal-belief-state-encoder-decoder-pytorch distribution.
setup(
  name = 'anymal-belief-state-encoder-decoder-pytorch',  # distribution name (the one `pip install` uses)
  packages = find_packages(exclude=[]),  # include every package found
  version = '0.0.20',
  license='MIT',
  description = 'Anymal Belief-state Encoder Decoder - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/anymal-belief-state-encoder-decoder-pytorch',  # project home page
  keywords = [
    'artificial intelligence',
    'deep learning',
    'attention gating',
    'belief state',
    'robotics'
  ],
  install_requires=[
    'einops>=0.4',  # runtime dependencies
    'einops-exts',
    'torch>=1.6',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',  # trove classifiers
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\AoA-pytorch\aoa_pytorch\aoa_pytorch.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

def exists(val):
    """Return True when ``val`` is anything other than ``None`` (falsy values count)."""
    return not (val is None)

def default(val, d):
    """Return ``val``, falling back to ``d`` only when ``val`` is ``None``."""
    if val is None:
        return d
    return val

class AttentionOnAttention(nn.Module):
    """Multi-head attention followed by an "attention on attention" (AoA) gate.

    Queries come from ``x``. Keys and values come from ``context``, falling
    back to ``x`` for self-attention. The attended values are concatenated
    with the queries, projected to ``2 * dim`` and passed through a GLU, so
    the output width is ``dim``.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        aoa_dropout = 0.
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        # projections: queries from x, keys and values (stacked) from the context
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, 2 * inner_dim, bias = False)

        # dropout on the attention weights
        self.dropout = nn.Dropout(dropout)

        # AoA gate: GLU over [attended values ; queries]
        self.aoa = nn.Sequential(
            nn.Linear(inner_dim * 2, dim * 2),
            nn.GLU(),
            nn.Dropout(aoa_dropout)
        )

    def forward(self, x, context = None):
        heads = self.heads
        kv_input = default(context, x)

        queries = self.to_q(x)
        keys, values = self.to_kv(kv_input).chunk(2, dim = -1)

        split_heads = lambda t: rearrange(t, 'b n (h d) -> b h n d', h = heads)
        q, k, v = split_heads(queries), split_heads(keys), split_heads(values)

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = self.dropout(sim.softmax(dim = -1))

        attended = einsum('b h i j, b h j d -> b h i d', attn, v)
        attended = rearrange(attended, 'b h n d -> b n (h d)')

        # the gate sees the merged heads together with the un-split queries
        return self.aoa(torch.cat((attended, queries), dim = -1))

.\lucidrains\AoA-pytorch\aoa_pytorch\__init__.py

# 从 aoa_pytorch 模块中导入 AttentionOnAttention 类
from aoa_pytorch.aoa_pytorch import AttentionOnAttention
# Short alias so users can write `from aoa_pytorch import AoA`, as the README does.
AoA = AttentionOnAttention

Attention on Attention - Pytorch

A Pytorch implementation of the Attention on Attention module, from the paper An Improved Attention for Visual Question Answering. The repository will include both the Self and Guided (cross-attention) variants.

Install

$ pip install aoa-pytorch

Usage

Self Attention on Attention

import torch
from aoa_pytorch import AoA

attn = AoA(
    dim = 512,
    heads = 8
)

x = torch.randn(1, 1024, 512)
attn(x) + x # (1, 1024, 512)

Guided Attention on Attention

```python
import torch
from aoa_pytorch import AoA

attn = AoA(
    dim = 512,
    heads = 8
)

x = torch.randn(1, 1024, 512)
context = torch.randn(1, 1024, 512)

attn(x, context = context) + x # (1, 1024, 512)
```

## Citations

```bibtex
@misc{rahman2020improved,
    title   = {An Improved Attention for Visual Question Answering}, 
    author  = {Tanzila Rahman and Shih-Han Chou and Leonid Sigal and Giuseppe Carenini},
    year    = {2020},
    eprint  = {2011.02164},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

```bibtex
@misc{huang2019attention,
    title   = {Attention on Attention for Image Captioning}, 
    author  = {Lun Huang and Wenmin Wang and Jie Chen and Xiao-Yong Wei},
    year    = {2019},
    eprint  = {1908.06954},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

.\lucidrains\AoA-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# Package metadata for the aoa_pytorch distribution.
setup(
  name = 'aoa_pytorch',
  packages = find_packages(exclude=['examples']), # every package except `examples`
  version = '0.0.2',
  license='MIT',
  description = 'Attention on Attention - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/AoA-pytorch', # project home page (the AoA-pytorch repository)
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'visual question answering'
  ],
  install_requires=[
    'torch>=1.6', # runtime dependencies
    'einops>=0.3'
  ],
  classifiers=[
    'Development Status :: 4 - Beta', # trove classifiers
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\attention-tensorflow-mesh\attention_tensorflow_mesh\attention_tensorflow_mesh.py

# 导入必要的库
import math
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

# 辅助函数

def default(val, d):
    """Return ``val``, or ``d`` when ``val`` is ``None`` (other falsy values are kept)."""
    if val is None:
        return d
    return val

def linear(x, dim_out, scope = 'linear', bias = True):
    """Dense projection of the last dimension of ``x`` onto ``dim_out``.

    Kernel weights are initialized from a normal distribution with stddev
    ``1 / sqrt(dim_in.size)``. The layer lives under the TF variable scope
    ``scope``, which is also used as its name.
    """
    with tf.variable_scope(scope):
        dim_in = x.shape[-1]
        kernel_init = tf.random_normal_initializer(stddev = 1 / math.sqrt(dim_in.size), dtype = tf.float32)
        return mtf.layers.dense(
            x,
            new_dims = [dim_out],
            reduced_dims = [dim_in],
            name = scope,
            use_bias = bias,
            kernel_initializer = kernel_init
        )

def norm(x, axis = None, epsilon=1e-5):
    """Normalize ``x`` to zero mean and unit variance over ``axis`` (default: last dim).

    There is no learned scale or shift; ``epsilon`` stabilizes the rsqrt.
    """
    if axis is None:
        axis = x.shape[-1]

    mean = mtf.reduce_mean(x, reduced_dim = axis)
    var = mtf.reduce_mean(mtf.square(x - mean), reduced_dim = axis)

    mean_full, var_full = (mtf.broadcast(t, x.shape) for t in (mean, var))

    return (x - mean_full) * mtf.rsqrt(var_full + epsilon)

def scale_norm(x, scope, *, axis=None, epsilon=1e-5, params=None):
    """Normalize ``x`` (see ``norm``) and multiply by one learned scalar gain ``g``.

    ``g`` is created under TF variable scope ``scope`` and initialized to 1.
    ``params`` is accepted but unused.
    """
    axis = x.shape[-1] if axis is None else axis

    with tf.variable_scope(scope):
        dtype = tf.float32
        gain = mtf.get_variable(x.mesh, 'g', [], initializer = tf.constant_initializer(1, dtype = dtype), dtype = dtype)
        return norm(x, axis, epsilon) * gain

def prenorm(fn, scope):
    """Wrap ``fn`` so its first argument passes through ``scale_norm`` (under ``scope``) first."""
    def wrapped(x, *args, **kwargs):
        normed = scale_norm(x, scope)
        return fn(normed, *args, **kwargs)
    return wrapped

def residual(fn):
    """Wrap ``fn`` with a skip connection, returning ``fn(x, ...) + x``."""
    def wrapped(x, *args, **kwargs):
        out = fn(x, *args, **kwargs)
        return out + x
    return wrapped

# full multi-head attention
def attention(x, dim_head, dim_features_head, scope = 'attn', causal = False):
    """Multi-head (optionally causal) self-attention over the sequence dimension of ``x``.

    Args:
        x: mtf tensor, unpacked below as ``[batch, seq, dim]``.
        dim_head: mtf.Dimension whose size is the number of heads (it becomes
            the head axis in the reshape below).
        dim_features_head: mtf.Dimension, the number of features per head.
        scope: TF variable scope for the projection weights.
        causal: if True, query position i may only attend to memory positions
            j <= i + (memory_length - seq_len).

    Returns:
        mtf tensor projected back to the input's last dimension ``dim``.
    """
    with tf.variable_scope(scope):
        mesh, batch, seq, dim = x.mesh, *x.shape

        # total width of all heads, and the fused q/k/v projection width (3x)
        dim_heads = mtf.Dimension('dim_heads', dim_head.size * dim_features_head.size)
        dim_intermediate = mtf.Dimension('qkv_dimension', dim_heads.size * 3)
        qkv = linear(x, dim_intermediate, bias = False, scope='to_qkv')

        # split into q, k, v and lay each out as [batch, heads, seq, features]
        q, k, v = mtf.split(qkv, dim_intermediate, 3)
        q, k, v = map(lambda t: mtf.reshape(t, [batch, seq, dim_head, dim_features_head]), (q, k, v))
        q, k, v = map(lambda t: mtf.transpose(t, [batch, dim_head, seq, dim_features_head]), (q, k, v))

        # keys/values get a distinct sequence dimension name so the einsums can
        # keep the query axis (seq) and the key axis (memory_length) apart
        k, v = map(lambda t: mtf.rename_dimension(t, seq.name, 'memory_length'), (k, v))
        mem_len_dim = v.shape[-2]

        # NOTE(review): mtf.layers.us_einsum presumably scales by 1/sqrt of the
        # reduced dimension size (standard attention scaling); confirm in mesh_tensorflow
        dots = mtf.layers.us_einsum([q, k], [batch, dim_head, seq, mem_len_dim])

        if causal:
            # add -1e10 to logits where j > i + (memory_length - seq_len)
            i = mtf.range(mesh, seq, tf.int32)
            j = mtf.range(mesh, mem_len_dim, tf.int32)
            i, j = map(lambda t: mtf.broadcast(t, [seq, mem_len_dim]), (i, j))
            mask = mtf.less(i + mem_len_dim.size - seq.size, j)
            mask = mtf.cast(mask, tf.float32) * -1e10
            dots += mask

        attn = mtf.softmax(dots, mem_len_dim)
        out = mtf.einsum([attn, v], [batch, dim_head, seq, dim_features_head])

        # merge heads back to [batch, seq, heads * features]
        out = mtf.transpose(out, [batch, seq, dim_head, dim_features_head])
        out = mtf.reshape(out, [batch, seq, dim_heads])

        combined_out = linear(out, dim, scope='combine_output')
        return combined_out

def ff(x, mult = 4, scope = 'ff'):
    """Position-wise feed-forward: widen the last dim by ``mult``, apply GELU, project back."""
    dim = x.shape[-1]
    hidden_dim = mtf.Dimension('ff_intermediate', dim.size * mult)

    with tf.variable_scope(scope):
        hidden = mtf.gelu(linear(x, hidden_dim, scope = 'w1'))
        return linear(hidden, dim, scope = 'w2')

def transformer(x, *, depth, dim_head, dim_features_head, causal = False):
    """Apply ``depth`` layers of pre-norm residual attention then feed-forward.

    Layer ``i`` builds its variables under the TF variable scope ``layer_{i}``.
    """
    attn_block = residual(prenorm(attention, 'norm1'))
    ff_block = residual(prenorm(ff, 'norm2'))

    for layer_index in range(depth):
        with tf.variable_scope(f'layer_{layer_index}'):
            x = attn_block(x, dim_head, dim_features_head, causal = causal)
            x = ff_block(x)
    return x

# language model
def transformer_lm(x, *, dim, num_tokens, depth, max_seq_len, dim_head, dim_features_head, causal = False):
    """Transformer language model that returns per-position vocabulary logits.

    Args:
        x: mtf tensor of token ids, unpacked below as ``[batch, seq]``.
        dim: model width.
        num_tokens: vocabulary size.
        depth: number of transformer layers.
        max_seq_len: number of rows in the learned positional-embedding
            table. Must be at least the sequence length of ``x``.
        dim_head: number of attention heads.
        dim_features_head: features per attention head.
        causal: apply a causal mask inside attention.

    Returns:
        logits with a ``vocab_size`` dimension for every position of ``x``.

    Raises:
        ValueError: if the sequence of ``x`` is longer than ``max_seq_len``.
    """
    mesh, batch, seq_dim = x.mesh, *x.shape

    if seq_dim.size > max_seq_len:
        raise ValueError(f'sequence length {seq_dim.size} exceeds max_seq_len {max_seq_len}')

    dim = mtf.Dimension('dim', dim)
    dim_head = mtf.Dimension('dim_head', dim_head)
    dim_features_head = mtf.Dimension('dim_features_head', dim_features_head)
    dim_num_tokens = mtf.Dimension('vocab_size', num_tokens)
    dim_max_seq_len = mtf.Dimension('max_seq_len', max_seq_len)

    wte = mtf.get_variable(mesh, name='wte', shape=mtf.Shape([dim_num_tokens, dim]), dtype=tf.float32)
    # positions are gathered along `max_seq_len`, so that must be the table's leading
    # dimension (a [seq, dim] table has no `max_seq_len` dimension to gather from)
    wpe = mtf.get_variable(mesh, name='wpe', shape=mtf.Shape([dim_max_seq_len, dim]), dtype=tf.float32)

    x = mtf.gather(wte, x, dim_num_tokens)
    p = mtf.gather(wpe, mtf.range(mesh, seq_dim, dtype=tf.int32), dim_max_seq_len)
    x = x + p

    x = transformer(x, depth = depth, dim_head = dim_head, dim_features_head = dim_features_head, causal = causal)

    logits = linear(x, dim_num_tokens, scope='to_logits')
    return logits

.\lucidrains\attention-tensorflow-mesh\attention_tensorflow_mesh\__init__.py

# 从 attention_tensorflow_mesh 模块中导入 transformer_lm, transformer, attention 函数
from attention_tensorflow_mesh.attention_tensorflow_mesh import transformer_lm, transformer, attention

Attention for Tensorflow Mesh

A collection of attention-related functions for building and scaling large attention-based neural networks.

Install

$ pip install attention-tensorflow-mesh

Usage

from attention_tensorflow_mesh import transformer_lm

import tensorflow as tf
tf.compat.v1.enable_eager_execution()
import mesh_tensorflow as mtf
from mesh_tensorflow import placement_mesh_impl

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")

# setup dimensions
# only the input dimensions are declared here; transformer_lm builds its own
# model dimensions ('dim', 'dim_head', ...) from the integer arguments below

batch = mtf.Dimension('batch', 1)
seq_len = mtf.Dimension('sequence', 1024)

# input (integer token ids of shape [batch, sequence])

tokens = mtf.ones(mesh, mtf.Shape([batch, seq_len]), dtype=tf.int32)

# transformer

logits = transformer_lm(
    tokens,
    dim = 512,
    num_tokens = 20000,
    depth = 1,
    max_seq_len = seq_len.size,
    dim_head = 12,
    dim_features_head = 75,
    causal = True
)

# placement

mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})

# export

logits = lowering.export_to_tf_tensor(logits)
print(logits)

More tools to come

.\lucidrains\attention-tensorflow-mesh\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# Package metadata for attention-tensorflow-mesh.
# NOTE(review): the package code calls tf.variable_scope (a TF1-style API) while
# the tensorflow-gpu requirement below has no upper bound - confirm supported versions.
_INSTALL_REQUIRES = [
    'mesh-tensorflow',
    'tensorflow-gpu>=1.15',
]

_CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name='attention-tensorflow-mesh',
    packages=find_packages(),
    version='0.0.2',
    license='MIT',
    description='A bunch of attention related functions, for constructing transformers in tensorflow mesh',
    author='Phil Wang',
    author_email='lucidrains@gmail.com',
    url='https://github.com/lucidrains/attention-tensorflow-mesh',
    keywords=['transformers', 'artificial intelligence'],
    install_requires=_INSTALL_REQUIRES,
    classifiers=_CLASSIFIERS,
)

.\lucidrains\audiolm-pytorch\audiolm_pytorch\attend.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch 库中导入 nn.functional 模块,并重命名为 F
import torch.nn.functional as F

# 导入 namedtuple、wraps 函数
from collections import namedtuple
from functools import wraps
# 从 packaging 库中导入 version 模块
from packaging import version

# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat

# Backend switches for scaled-dot-product attention; unpacked via _asdict()
# into torch.backends.cuda.sdp_kernel by Attend.
Config = namedtuple('Config', 'enable_flash enable_math enable_mem_efficient')

# None-check helper
def exists(val):
    """Return True when ``val`` is anything other than None (falsy values count as existing)."""
    return not (val is None)

# Decorator: let the wrapped function run only on its first call
def once(fn):
    """Wrap ``fn`` so that only its first invocation executes.

    The first call forwards all positional and keyword arguments to ``fn``
    and returns its result; every later call does nothing and returns None.
    (Previously the wrapper accepted exactly one positional argument.)
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        called = True
        return fn(*args, **kwargs)

    return inner

# print that only emits the first message it is given
print_once = once(print)

# Core attention operation used by the Attention layers
class Attend(nn.Module):
    """Scaled dot-product attention between multi-head queries and one shared key/value head.

    Shapes (from the einsum patterns in ``forward``): q is (b, h, i, d), k and v
    are (b, j, d), and the output is (b, h, i, d). ``mask`` is a boolean (b, j)
    key mask where True means "may attend".

    Args:
        dropout: dropout probability applied to the attention weights.
        causal: forbid attending to later keys. Queries are aligned with the
            *last* i keys, so a query block shorter than the key sequence
            (e.g. decoding with a key/value cache) is masked correctly on both
            the flash and the regular path.
        flash: use ``F.scaled_dot_product_attention`` (requires torch >= 2.0);
            attention bias is not supported on this path.
    """
    def __init__(
        self,
        dropout = 0.,
        causal = False,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        # non-persistent placeholder buffer; nothing in this class reads it
        self.register_buffer("mask", None, persistent=False)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # backend flags for torch.backends.cuda.sdp_kernel, chosen per device type
        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        # compute capability 8.0 is treated as an A100
        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    @staticmethod
    def _causal_mask(i, j, device):
        """Boolean (i, j) mask, True where a query (aligned to the last i of j keys) must NOT attend."""
        return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)

    def flash_attn(self, q, k, v, mask = None):
        """Attention through F.scaled_dot_product_attention; same contract as ``forward`` minus attn_bias."""
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # SDPA expects one key/value head per query head
        k = repeat(k, 'b ... -> b h ...', h = heads)
        v = repeat(v, 'b ... -> b h ...', h = heads)

        causal = self.causal

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # SDPA's is_causal uses a top-left aligned mask (query r sees keys 0..r), which is
        # wrong once keys outnumber queries (kv cache). Build the end-aligned causal mask
        # explicitly whenever it must be merged with a key mask or k_len > q_len.
        if causal and (exists(mask) or k_len > q_len):
            allowed = ~self._causal_mask(q_len, k_len, q.device)
            mask = (mask & allowed) if exists(mask) else allowed
            causal = False

        config = self.cuda_config if is_cuda else self.cpu_config

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        return out

    def forward(self, q, k, v, mask = None, attn_bias = None):
        """
        einstein notation
        b - batch
        h - heads
        i - query positions, j - key/value positions
        d - feature dimension

        attn_bias, when given, is added to the (b, h, i, j) similarity logits.
        """
        if self.flash:
            assert not exists(attn_bias), 'attention bias not supported for flash attention'
            return self.flash_attn(q, k, v, mask = mask)

        scale = q.shape[-1] ** -0.5

        sim = einsum("b h i d, b j d -> b h i j", q, k) * scale

        if exists(attn_bias):
            sim = sim + attn_bias

        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, mask_value)

        if self.causal:
            i, j = sim.shape[-2:]
            sim = sim.masked_fill(self._causal_mask(i, j, sim.device), mask_value)

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        out = einsum("b h i j, b j d -> b h i d", attn, v)

        return out

.\lucidrains\audiolm-pytorch\audiolm_pytorch\audiolm_pytorch.py

# 导入数学库
import math
# 导入 functools 模块中的 partial 和 wraps 函数
from functools import partial, wraps

# 导入 beartype 库中的 Optional, Union, List 类型
from beartype.typing import Optional, Union, List
# 导入 beartype 库中的 beartype 装饰器
from beartype import beartype

# 导入 torch 库
import torch
# 导入 torch 库中的 nn, einsum, Tensor 模块
from torch import nn, einsum, Tensor
# 导入 torch 库中的 grad 函数,并重命名为 torch_grad
from torch.autograd import grad as torch_grad
# 导入 torch.nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 导入 torch.nn.utils.rnn 中的 pad_sequence 函数
from torch.nn.utils.rnn import pad_sequence

# 导入 torchaudio 库
import torchaudio

# 导入 einops 库中的 rearrange, repeat, reduce 函数
from einops import rearrange, repeat, reduce
# 导入 einops.layers.torch 中的 Rearrange 类
from einops.layers.torch import Rearrange

# 导入 audiolm_pytorch 库中的 FairseqVQWav2Vec 类
from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
# 导入 audiolm_pytorch 库中的 HubertWithKmeans 类
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans

# 导入 audiolm_pytorch 库中的 t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME 函数
from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# 导入 torchaudio.functional 中的 resample 函数
from torchaudio.functional import resample

# 导入 audiolm_pytorch 库中的 SoundStream 类
from audiolm_pytorch.soundstream import SoundStream
# 导入 audiolm_pytorch 库中的 EncodecWrapper 类
from audiolm_pytorch.encodec import EncodecWrapper
# 导入 audiolm_pytorch 库中的 AudioConditionerBase 类
from audiolm_pytorch.utils import AudioConditionerBase
# 导入 audiolm_pytorch 库中的 Attend 类
from audiolm_pytorch.attend import Attend

# 导入 tqdm 库中的 tqdm 函数
from tqdm import tqdm
# 导入 pathlib 库中的 Path 类
from pathlib import Path
# 导入 audiolm_pytorch.version 中的 __version__ 变量
from audiolm_pytorch.version import __version__
# 导入 packaging 库中的 version 模块
from packaging import version

# helper functions

def exists(val):
    """True when ``val`` is anything but None."""
    return not (val is None)

def default(val, d):
    """Return ``val`` unless it is None, in which case return the fallback ``d``."""
    if not exists(val):
        return d
    return val

def always(val):
    """Build a function that ignores all of its arguments and returns ``val``."""
    def constant(*args, **kwargs):
        return val
    return constant

def maybe(fn):
    """Lift ``fn`` so that a None first argument short-circuits and is returned as-is.

    If ``fn`` itself is None, the result is a function that always returns None.
    """
    if fn is None:
        return always(None)

    @wraps(fn)
    def inner(x, *args, **kwargs):
        return fn(x, *args, **kwargs) if exists(x) else x
    return inner

# integer rounding helpers

def ceil_div(numer, denom):
    """Divide and round up to the next whole quotient (for positive ``denom``)."""
    shifted = numer + denom - 1
    return shifted // denom

def remainder_needed_until_multiple(n, mult):
    """Amount that must be added to ``n`` to reach the next multiple of ``mult`` (0 if already one)."""
    next_multiple = ceil_div(n, mult) * mult
    return next_multiple - n

def round_down_nearest_multiple(val, mult):
    """Largest multiple of ``mult`` that does not exceed ``val`` (for positive ``mult``)."""
    whole_units = val // mult
    return whole_units * mult

# Decorator that runs a model method in eval mode, then restores the previous mode
def eval_decorator(fn):
    """Run ``fn(model, ...)`` with ``model`` switched to eval mode.

    The model's previous ``training`` flag is restored afterwards, even if
    ``fn`` raises. ``functools.wraps`` preserves ``fn``'s name and docstring.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # restore on both the normal and the exceptional path
            model.train(was_training)
    return inner

# tensor helpers

def generate_mask_with_prob(shape, mask_prob, device):
    """Return a boolean keep-mask of ``shape`` with random positions per row set to False.

    Per row, ``min(int(seq * mask_prob), seq - 1)`` distinct positions (seq = shape[-1])
    are chosen at random and set to False; position 0 is never chosen. The
    ``[:, 0]`` / scatter-on-dim-1 indexing assumes a 2-D ``shape``.
    """
    seq_len = shape[-1]
    scores = torch.randn(shape, device = device)
    # push column 0 to the very bottom so topk never selects it
    scores[:, 0] = -torch.finfo(scores.dtype).max
    masked_count = min(int(seq_len * mask_prob), seq_len - 1)
    chosen = scores.topk(masked_count, dim = -1).indices
    hits = torch.zeros(shape, device = device).scatter(1, chosen, 1.)
    return hits == 0

# attention-related utilities

def grad_shrink(t, alpha = 0.1):
    """Return (numerically) ``t`` while scaling the gradient flowing back into it by ``alpha``.

    Forward value is alpha * t + (1 - alpha) * t; only the ``alpha * t`` term is differentiable.
    """
    frozen = t.detach()
    return t * alpha + frozen * (1 - alpha)

# sampling helpers

def log(t, eps = 1e-20):
    """Natural log of ``t`` after adding ``eps`` so that log(0) stays finite."""
    shifted = t + eps
    return torch.log(shifted)

def l2norm(t):
    """Scale ``t`` to unit L2 norm along its last dimension."""
    return F.normalize(t, dim = -1)

def gumbel_noise(t):
    """Gumbel(0, 1) noise with the same shape, dtype and device as ``t``."""
    uniform = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(uniform))

def gumbel_sample(t, temperature = 1., dim = -1):
    """Gumbel-max sampling: argmax over ``dim`` of temperature-scaled logits plus Gumbel noise."""
    noisy_logits = (t / temperature) + gumbel_noise(t)
    return noisy_logits.argmax(dim = dim)

# Keep the k largest logits along the last dimension and set the rest to -inf
def top_k(logits, thres = 0.5):
    """Filter ``logits`` to their top-k entries along the last dimension.

    k = max(int((1 - thres) * num_logits), 1), so ``thres`` is roughly the
    fraction of logits discarded. Kept entries keep their values; all others
    become -inf. Works for logits of any rank.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k, dim = -1)
    probs = torch.full_like(logits, float('-inf'))
    # scatter along the same (last) dim topk used; a hard-coded dim 1 only
    # coincided with it for 2-D inputs
    probs.scatter_(-1, ind, val)
    return probs

# Replace every position after the first eos_id (optionally including it) with mask_value
def mask_out_after_eos_id(t, eos_id, mask_value = -1, keep_eos = True):
    """Mask out everything that follows the first ``eos_id`` in each row (last dim) of ``t``.

    With ``keep_eos=True`` the eos token itself survives and only later
    positions become ``mask_value``; with ``keep_eos=False`` the eos position
    is replaced as well.
    """
    eos_hits = (t == eos_id).float()

    if keep_eos:
        # shift the hits one step right (dropping the last column) so the eos slot itself stays
        eos_hits = F.pad(eos_hits, (1, -1))

    return t.masked_fill(eos_hits.cumsum(dim = -1) > 0, mask_value)

# Check whether every row contains eos_id
def all_rows_have_eos_id(t, eos_id):
    """Return a boolean tensor: True iff every row (last dim) of ``t`` contains ``eos_id``."""
    return (t == eos_id).any(dim = -1).all()

# Concatenate only the tensors that are not None
def safe_cat(*tensors, dim = -2):
    """Concatenate the given tensors along ``dim``, skipping any that are None.

    Returns None if none remain, and the single remaining tensor itself
    (no copy) if exactly one remains.
    """
    present = [tensor for tensor in tensors if exists(tensor)]

    if not present:
        return None
    if len(present) == 1:
        return present[0]
    return torch.cat(present, dim = dim)

# classifier-free guidance helpers

def prob_mask_like(shape, prob, device):
    """Boolean tensor of ``shape`` whose entries are True with probability ``prob``.

    ``prob`` equal to 1 or 0 returns an all-True / all-False tensor without
    drawing random numbers.
    """
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)

    uniform = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return uniform < prob

# Append eos_id as one extra final column to every row of ids
def append_eos_id(ids, eos_id):
    """Return ``ids`` (a 2-D (batch, seq) tensor) with a column of ``eos_id`` concatenated on the last dim."""
    batch_size, device = ids.shape[0], ids.device
    single_eos = torch.ones(1, device=device).long() * eos_id
    eos_column = repeat(single_eos, '1 -> b 1', b=batch_size)
    return torch.cat((ids, eos_column), dim=-1)

# Collapse consecutive duplicate ids in each row, right-padding rows back to equal length
def batch_unique_consecutive(t, pad_value=0.):
    """Apply ``torch.unique_consecutive`` to each row of ``t`` and re-batch the results.

    Rows shrink by different amounts, so they are right-padded with
    ``pad_value`` to the longest result.
    """
    collapsed_rows = list(map(torch.unique_consecutive, t.unbind(dim=0)))
    return pad_sequence(collapsed_rows, batch_first=True, padding_value=pad_value)

# Embedding lookup that tolerates a pad id lying outside the embedding table
def get_embeds(
    embeddings: nn.Embedding,
    codes: torch.Tensor,
    pad_id=-1,
    return_mask=False,
    mask_pad_pos_to=0
):
    """Look up ``codes`` in ``embeddings``, treating entries equal to ``pad_id`` specially.

    Pad positions are looked up as index 0 (so a pad id such as -1 is never
    used as an index). Unless ``mask_pad_pos_to`` is None, the embedding
    vectors at pad positions are then overwritten with ``mask_pad_pos_to``.

    Returns:
        The embeddings, or ``(embeddings, keep_mask)`` when ``return_mask`` is
        True, where ``keep_mask`` is True at non-pad positions.
    """
    is_pad = codes == pad_id
    embeds = embeddings(codes.masked_fill(is_pad, 0))

    if exists(mask_pad_pos_to):
        embeds = embeds.masked_fill(rearrange(is_pad, '... -> ... 1'), mask_pad_pos_to)

    return (embeds, ~is_pad) if return_mask else embeds

# Layer norm with a learnable gain and a fixed, non-trainable zero bias
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable ``gamma`` only.

    ``beta`` is registered as a buffer of zeros: it is saved in the state dict
    but is not a trainable parameter.
    """
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, weight = self.gamma, bias = self.beta)

# Relative position bias produced by a small MLP over relative distances
class RelativePositionBias(nn.Module):
    """ from https://arxiv.org/abs/2111.09883

    An MLP (``layers`` SiLU-activated linear layers, then a projection to
    ``heads``) maps each scalar relative distance to one bias value per head.
    """

    def __init__(
        self,
        *,
        dim,
        heads,
        layers=3
    ):
        super().__init__()
        self.net = nn.ModuleList([])
        self.net.append(nn.Sequential(nn.Linear(1, dim), nn.SiLU()))

        for _ in range(layers - 1):
            self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))

        # final projection to one value per head (this line was missing its
        # closing parenthesis, a SyntaxError)
        self.net.append(nn.Linear(dim, heads))

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, i, j):
        """Return a (heads, i, j) bias for i queries aligned to the last i of j keys."""
        assert j >= i
        device = self.device

        i_pos = torch.arange(i, device=device) + (j - i)
        j_pos = torch.arange(j, device=device)

        # relative distances lie in [-(j - 1), j - 1]; shift them to valid table indices
        rel_pos = (rearrange(i_pos, 'i -> i 1') - rearrange(j_pos, 'j -> 1 j'))
        rel_pos += (j - 1)

        # evaluate the MLP once for every possible distance, then gather
        x = torch.arange(-j + 1, j, device=device).float()
        x = rearrange(x, '... -> ... 1')

        for layer in self.net:
            x = layer(x)

        x = x[rel_pos]
        return rearrange(x, 'i j h -> h i j')

# Gated GELU activation
class GEGLU(nn.Module):
    """Split the last dimension into halves ``(value, gate)`` and return ``value * gelu(gate)``.

    The output's last dimension is half the input's.
    """
    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        return value * F.gelu(gate)

# Feed-forward block used inside each transformer layer
def FeedForward(dim, mult=4, dropout=0.1):
    """Build LayerNorm -> Linear(dim, 2*inner) -> GEGLU -> LayerNorm -> Dropout -> Linear(inner, dim).

    inner = int(dim * 2 * mult / 3); both linear layers are bias-free.
    """
    hidden = int(dim * 2 * mult / 3)
    layers = [
        LayerNorm(dim),
        nn.Linear(dim, hidden * 2, bias=False),
        GEGLU(),
        LayerNorm(hidden),
        nn.Dropout(dropout),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)

# Multi-head attention with a single key/value head shared by all query heads
class Attention(nn.Module):
    """Attention layer: ``heads`` query heads against one ``dim_head``-wide key/value head.

    Supports cross-attention (``context``), prefix conditioning
    (``prefix_context``), learned null key/values, and a key/value cache.

    Args:
        dim: input width.
        causal: forwarded to ``Attend``.
        dim_head: per-head width.
        dim_context: width of ``context``; defaults to ``dim``.
        heads: number of query heads.
        norm_context: apply a LayerNorm to ``context`` when it is given.
        num_null_kv: number of learned key/value pairs prepended to the keys/values.
        dropout: attention-weight and output dropout.
        scale: accepted for interface compatibility; not read by this class.
        flash: forwarded to ``Attend``.
    """
    def __init__(
        self,
        dim,
        causal=False,
        dim_head=64,
        dim_context=None,
        heads=8,
        norm_context=False,
        num_null_kv=0,
        dropout=0.1,
        scale=8,
        flash=False
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal
        inner_dim = dim_head * heads

        dim_context = default(dim_context, dim)

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(dim_context) if norm_context else nn.Identity()

        self.attn_dropout = nn.Dropout(dropout)

        # learned key/value pairs prepended to every key/value sequence
        self.num_null_kv = num_null_kv
        self.null_kv = nn.Parameter(torch.randn(2, num_null_kv, dim_head)) if num_null_kv > 0 else None

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # one key head and one value head, shared across all query heads
        self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)

        self.attend = Attend(
            flash = flash,
            dropout = dropout,
            causal = causal
        )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None,
        prefix_context = None,
        prefix_context_mask = None,
        return_kv_cache = False,
        kv_cache = None
    ):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: input of shape (b, n, dim).
            context: optional cross-attention source; normalized if ``norm_context``.
            mask: boolean key mask (True = attend); extended for prefix and null keys here.
            attn_bias: additive bias on the attention logits; extended for prefix and null keys here.
            prefix_context / prefix_context_mask: extra key/value inputs prepended before ``x``.
            kv_cache: stacked (k, v) from a previous call, prepended to the new keys/values.
            return_kv_cache: also return the stacked (k, v) (taken before null key/values are added).

        Returns:
            Output of shape (b, n, dim), or ``(output, kv_cache)``.
        """
        b, n, _, device = *x.shape, x.device

        if exists(context):
            context = self.context_norm(context)

        # NOTE(review): for self-attention, keys/values are projected from x *before*
        # self.norm is applied below, while queries use the normalized x. Changing this
        # would alter trained models, so it is left as is - confirm it is intended.
        kv_input = default(context, x)

        # prefix-based self-attention conditioning: prepend the prefix to the keys/values
        # and widen the mask / bias to match
        if exists(prefix_context):
            kv_input = torch.cat((prefix_context, kv_input), dim = -2)
            prefix_seq_len = prefix_context.shape[-2]

            if not exists(mask):
                mask = torch.ones((b, n), device = device, dtype = torch.bool)

            if exists(prefix_context_mask):
                mask = torch.cat((prefix_context_mask, mask), dim = -1)
            else:
                mask = F.pad(mask, (prefix_seq_len, 0), value = True)

            if exists(attn_bias):
                attn_bias = F.pad(attn_bias, (prefix_seq_len, 0), value = 0.)

        # prenorm
        x = self.norm(x)

        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)

        if exists(kv_cache):
            ck, cv = kv_cache
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        if return_kv_cache:
            kv_cache = torch.stack((k, v))

        # null key / values are prepended after caching, so the cache never contains them
        if self.num_null_kv > 0:
            null_k, null_v = repeat(self.null_kv, 'kv n d -> kv b n d', b = b).unbind(dim = 0)
            k = torch.cat((null_k, k), dim = -2)
            v = torch.cat((null_v, v), dim = -2)

            # keep the bias aligned with the widened keys; null keys get zero bias
            # (previously the bias was left unpadded and its shape no longer matched)
            if exists(attn_bias):
                attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value = 0.)

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # null keys are always attendable
        if exists(mask):
            mask = F.pad(mask, (self.num_null_kv, 0), value = True)

        out = self.attend(q, k, v, attn_bias = attn_bias, mask = mask)

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        if not return_kv_cache:
            return out

        return out, kv_cache
# Transformer stack used by the audiolm models
class Transformer(nn.Module):
    """Stack of (causal self-attention, optional cross-attention, feed-forward) layers.

    Each layer adds its sub-block outputs residually; a final LayerNorm is
    applied. Self-attention keys/values can be cached across calls.
    """
    def __init__(
        self,
        *,
        dim,  # model width
        depth,  # number of layers
        heads,  # number of attention heads
        dim_context = None,  # width of the conditioning context; defaults to dim
        cross_attend = False,  # add a cross-attention sub-block to every layer
        attn_dropout = 0.,  # dropout inside attention
        ff_dropout = 0.,  # dropout inside the feed-forward block
        grad_shrink_alpha = 0.1,  # gradient scale applied to the input (see grad_shrink)
        cond_as_self_attn_prefix = False,  # feed the condition as a self-attention prefix instead
        rel_pos_bias = True,  # use a learned relative position bias (ignored with flash attention)
        flash_attn = False,  # use torch's fused attention kernels
        **kwargs  # forwarded to every Attention
    ):
        super().__init__()
        # the flash path does not accept an attention bias, so relative bias is disabled there
        rel_pos_bias = rel_pos_bias and not flash_attn

        # the two conditioning modes are mutually exclusive
        assert not (cross_attend and cond_as_self_attn_prefix)

        self.dim_context = default(dim_context, dim)

        self.cond_as_self_attn_prefix = cond_as_self_attn_prefix

        self.grad_shrink = partial(grad_shrink, alpha = grad_shrink_alpha)

        self.layers = nn.ModuleList([])

        self.rel_pos_bias = RelativePositionBias(dim = dim // 2, heads = heads) if rel_pos_bias else None

        # per layer: causal self-attention, optional cross-attention (one null key/value,
        # normalized context), feed-forward
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, heads = heads, dropout = attn_dropout, flash = flash_attn, causal = True, **kwargs),
                Attention(dim = dim, heads = heads, dropout = attn_dropout, dim_context = dim_context, flash = flash_attn, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None,
                FeedForward(dim = dim, dropout = ff_dropout)
            ]))

        self.norm = LayerNorm(dim)

    def forward(
        self,
        x,  # input sequence, (batch, n, dim)
        self_attn_mask = None,  # boolean mask for self-attention keys; presumably covers cached + new positions - TODO confirm
        context = None,  # conditioning sequence (cross-attention or prefix)
        context_mask = None,  # boolean mask over context positions
        attn_bias = None,  # if given, used instead of the learned relative position bias
        return_kv_cache = False,  # also return the stacked per-layer self-attention caches
        kv_cache = None  # stacked per-layer caches from a previous call
    ):
        """Run the stack; returns normalized outputs for the positions not already cached.

        When ``kv_cache`` is given, only ``x[:, cache_len:]`` is processed, with
        cache_len read from ``kv_cache.shape[-2]``. The cache is ignored in
        prefix-conditioning mode.
        """
        assert not (self.cond_as_self_attn_prefix and not exists(context))
        assert not (exists(context) and context.shape[-1] != self.dim_context), f'you had specified a conditioning dimension of {self.dim_context}, yet what was received by the transformer has dimension of {context.shape[-1]}'

        n, device = x.shape[1], x.device

        # gradient shrink on the input: per the original comment, adopted from the CogView
        # paper (also used by GLM-130B) to reduce the chance of attention instability

        x = self.grad_shrink(x)

        # prefix conditioning does not support the key/value cache, so drop it
        if self.cond_as_self_attn_prefix:
            kv_cache = None

        # collect one self-attention cache per layer
        new_kv_cache = []

        if exists(kv_cache):
            cache_len = kv_cache.shape[-2]
            kv_cache = iter(kv_cache)
        else:
            cache_len = 0
            kv_cache = iter([])

        # only the positions not yet in the cache are computed
        x = x[:, cache_len:]

        # relative position bias for the full length n, then keep only the query rows being computed
        if exists(attn_bias):
            rel_pos_bias = attn_bias
        else:
            rel_pos_bias = maybe(self.rel_pos_bias)(n, n)

        if exists(rel_pos_bias):
            rel_pos_bias = rel_pos_bias[..., cache_len:, :]

        # extra self-attention kwargs used in prefix-conditioning mode
        self_attn_kwargs = dict()
        if self.cond_as_self_attn_prefix:
            self_attn_kwargs = dict(
                prefix_context = context,
                prefix_context_mask = context_mask
            )

        # layers
        for attn, cross_attn, ff in self.layers:

            residual = x

            # the cache iterator yields one entry per layer, in layer order (None when exhausted)
            x, layer_kv_cache = attn(x, attn_bias = rel_pos_bias, mask = self_attn_mask, kv_cache = next(kv_cache, None), return_kv_cache = True, **self_attn_kwargs)
            new_kv_cache.append(layer_kv_cache)

            x = x + residual

            if exists(cross_attn):
                assert exists(context)

                x = cross_attn(x, context = context, mask = context_mask) + x

            x = ff(x) + x

        x = self.norm(x)

        if not return_kv_cache:
            return x

        return x, torch.stack(new_kv_cache)

# Semantic-token transformer: an autoregressive model over semantic token ids,
# optionally conditioned on text (or other) embeddings
class SemanticTransformer(nn.Module):
    """Transformer that predicts semantic tokens, with optional conditioning.

    The vocabulary has ``num_semantic_tokens + 1`` ids; the extra id
    (``num_semantic_tokens``) is the end-of-sequence id ``self.eos_id``.
    Conditioning is injected by cross-attention or, when
    ``cond_as_self_attn_prefix`` is True, as a self-attention prefix.
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        depth,
        num_semantic_tokens,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        t5_name = DEFAULT_T5_NAME,
        cond_dim = None,
        has_condition = False,
        audio_text_condition = False,
        cond_as_self_attn_prefix = False,
        cond_drop_prob = 0.5,
        grad_shrink_alpha = 0.1,
        rel_pos_bias = True,
        flash_attn = False,
        **kwargs
    ):
        """
        Args:
            dim: model width.
            depth: number of transformer layers.
            num_semantic_tokens: size of the semantic codebook (an eos id is added on top).
            heads: attention heads.
            attn_dropout / ff_dropout: dropout inside attention / feed-forward.
            t5_name: T5 model used to embed raw ``text``.
            cond_dim: width of conditioning embeddings; defaults to the T5 encoding
                width, or to ``dim`` when ``audio_text_condition`` is set.
            has_condition: whether every forward call must supply text / text_embeds.
            audio_text_condition: forces ``has_condition`` on.
            cond_as_self_attn_prefix: condition via a self-attention prefix instead of cross-attention.
            cond_drop_prob: default per-sample probability of dropping the condition.
            grad_shrink_alpha: gradient scale applied to the transformer input.
            rel_pos_bias: use a learned relative position bias (disabled with flash attention).
            flash_attn: use torch's fused attention kernels.
            **kwargs: forwarded to ``Transformer``.
        """
        super().__init__()
        # the flash path does not accept an attention bias
        rel_pos_bias = rel_pos_bias and not flash_attn

        self.num_semantic_tokens = num_semantic_tokens

        if audio_text_condition:
            has_condition = True
            cond_dim = default(cond_dim, dim)

        self.has_condition = has_condition
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob

        self.start_token = nn.Parameter(torch.randn(dim))

        # semantic ids plus one extra slot for eos
        self.semantic_embedding = nn.Embedding(num_semantic_tokens + 1, dim)
        self.eos_id = num_semantic_tokens

        text_dim = default(cond_dim, get_encoded_dim(t5_name))
        self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity()

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            cross_attend = has_condition and not cond_as_self_attn_prefix,
            cond_as_self_attn_prefix = cond_as_self_attn_prefix,
            grad_shrink_alpha = grad_shrink_alpha,
            rel_pos_bias = rel_pos_bias,
            flash_attn = flash_attn,
            **kwargs
        )

        self.to_logits = nn.Linear(dim, num_semantic_tokens + 1)

    @property
    def device(self):
        return next(self.parameters()).device

    def load(self, path):
        """Load weights from a checkpoint dict whose 'model' entry is the state dict.

        Returns the whole loaded package, so a caller (e.g. a trainer) can read
        other entries from the same checkpoint. Prints a notice when the
        checkpoint's 'version' is older than the installed audiolm-pytorch.
        """
        device = self.device
        path = Path(path)
        assert path.exists()
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        pkg = torch.load(str(path), map_location = device)
        if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
            print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')
        self.load_state_dict(pkg['model'])
        return pkg

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        kv_cache = None,
        return_kv_cache = False,
        **kwargs
    ):
        """Classifier-free guidance: ``null + (cond - null) * cond_scale``.

        Runs ``forward`` with the condition kept (cond_drop_prob=0) and, unless
        ``cond_scale == 1`` or the model is unconditioned, again with it dropped
        (cond_drop_prob=1). ``kv_cache`` is consumed as [conditional, null]
        caches, and the returned cache is stacked in the same order.
        """
        kv_cache = iter(default(kv_cache, []))
        new_kv_caches = []

        logits, new_kv_cache = self.forward(*args, cond_drop_prob = 0., kv_cache = next(kv_cache, None), return_kv_cache = True, **kwargs)
        new_kv_caches.append(new_kv_cache)

        if cond_scale == 1 or not self.has_condition:
            if not return_kv_cache:
                return logits

            return logits, torch.stack(new_kv_caches)

        null_logits, null_new_kv_cache = self.forward(*args, cond_drop_prob = 1., kv_cache = next(kv_cache, None), return_kv_cache = True, **kwargs)
        new_kv_caches.append(null_new_kv_cache)

        scaled_logits = null_logits + (logits - null_logits) * cond_scale

        if not return_kv_cache:
            return scaled_logits

        return scaled_logits, torch.stack(new_kv_caches)

    @beartype
    def forward(
        self,
        *,
        ids = None,
        return_loss = False,
        text: Optional[List[str]] = None,
        text_embeds = None,
        self_attn_mask = None,
        cond_drop_prob = None,
        unique_consecutive = None,
        kv_cache = None,
        return_kv_cache = False
    ):
        """Compute next-token logits for semantic token ``ids``.

        Args:
            ids: integer token ids, (batch, seq); required despite the None default.
            return_loss: drop the last id before embedding (no loss is computed here;
                presumably the caller pairs the logits with the full ids).
            text: raw strings, embedded with T5 when ``text_embeds`` is not given.
            text_embeds: precomputed conditioning embeddings, (batch, text_len, features);
                all-zero positions are treated as padding.
            self_attn_mask: boolean mask over ``ids`` positions; a True is prepended for the start token.
            cond_drop_prob: per-sample probability of dropping the condition; defaults to ``self.cond_drop_prob``.
            unique_consecutive: accepted but not used here.
            kv_cache / return_kv_cache: key/value cache passed to / returned from the transformer.

        Returns:
            Logits over ``num_semantic_tokens + 1`` ids, or ``(logits, kv_cache)``.
        """
        device = self.device

        b = ids.shape[0]

        has_text = exists(text) or exists(text_embeds)
        # conditioning must be supplied exactly when the model was built with has_condition
        assert not (self.has_condition ^ has_text)

        text_mask = None
        if not exists(text_embeds) and exists(text):
            with torch.inference_mode():
                text_embeds = self.embed_text(text, output_device = device)
                text_mask = torch.any(text_embeds != 0, dim = -1)
        elif exists(text_embeds):
            # Precomputed embeddings: derive the padding mask with the same all-zero
            # criterion. Without a mask the conditional dropout below was skipped, so
            # cond_drop_prob (and the null branch of forward_with_cond_scale) had no effect.
            text_mask = torch.any(text_embeds != 0, dim = -1)

        if exists(text_embeds):
            text_embeds = self.proj_text_embed(text_embeds)

        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        if exists(text_mask) and cond_drop_prob > 0:
            # drop the whole condition for a random subset of the batch
            keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

        if return_loss:
            ids = ids[:, :-1]

        tokens = get_embeds(self.semantic_embedding, ids)

        start_tokens = repeat(self.start_token, 'd -> b 1 d', b = ids.shape[0])

        tokens = torch.cat((start_tokens, tokens), dim = 1)

        if exists(self_attn_mask):
            # the prepended start token is always attendable
            self_attn_mask = F.pad(self_attn_mask, (1, 0), value = True)

        tokens, kv_cache = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask, kv_cache = kv_cache, return_kv_cache = True)
        logits = self.to_logits(tokens)

        if not return_kv_cache:
            return logits

        return logits, kv_cache
class CoarseTransformer(nn.Module):
    # 定义一个名为CoarseTransformer的类,继承自nn.Module

    @beartype
    def __init__(
        self,
        *,
        codebook_size,
        num_coarse_quantizers,
        dim,
        depth,
        num_semantic_tokens,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_dim = None,
        audio_text_condition = False,
        cond_as_self_attn_prefix = False,
        cond_drop_prob = 0.5,
        grad_shrink_alpha = 0.1,
        project_semantic_logits = True,
        rel_pos_bias = True,
        flash_attn = False,
        **kwargs
    ):
        """Build the transformer that models semantic tokens followed by coarse acoustic tokens.

        Args:
            codebook_size: codebook size of each coarse quantizer. The id
                ``codebook_size`` is reserved as the coarse eos id.
            num_coarse_quantizers: number of coarse quantizer levels.
            dim: model width.
            depth, heads, attn_dropout, ff_dropout, grad_shrink_alpha: forwarded to ``Transformer``.
            num_semantic_tokens: semantic vocabulary size. The id
                ``num_semantic_tokens`` is reserved as the semantic eos id.
            t5_name: text-encoder name used by ``embed_text``. When ``cond_dim``
                is None, it is also used to look up the conditioning width.
            has_condition: enables conditioning. This means cross attention
                unless ``cond_as_self_attn_prefix`` is set.
            cond_dim: width of the conditioning embeddings.
            audio_text_condition: if True, forces ``has_condition`` on and
                defaults ``cond_dim`` to ``dim``.
            cond_as_self_attn_prefix: forwarded to ``Transformer``. When True,
                cross attention is disabled.
            cond_drop_prob: default conditioning-drop probability stored on the module.
            project_semantic_logits: when False, ``to_semantic_logits`` is None.
            rel_pos_bias: relative position bias. Forced off when ``flash_attn`` is True.
            flash_attn: forwarded to ``Transformer``.
            **kwargs: forwarded to ``Transformer``.
        """
        super().__init__()

        # relative position bias is disabled whenever flash attention is requested
        rel_pos_bias = rel_pos_bias and not flash_attn

        self.num_semantic_tokens = num_semantic_tokens

        if audio_text_condition:
            # audio/text joint conditioning implies a conditional model
            has_condition = True
            cond_dim = default(cond_dim, dim)

        self.has_condition = has_condition
        # text encoder bound to the configured T5 model name
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob

        # learned start tokens for the semantic and the coarse segments
        # NOTE: creation order of randomly-initialised parameters below is significant for RNG reproducibility
        self.semantic_start_token = nn.Parameter(torch.randn(dim))
        self.coarse_start_token = nn.Parameter(torch.randn(dim))

        # semantic vocabulary gets one extra slot for the eos id
        self.semantic_eos_id = num_semantic_tokens
        self.semantic_embedding = nn.Embedding(num_semantic_tokens + 1, dim)

        # coarse codebook gets one extra slot for the eos id
        self.coarse_eos_id = codebook_size
        codebook_size_with_eos = codebook_size + 1

        # one embedding table shared by all coarse quantizers, sized for (codebook_size + 1) ids per quantizer
        # (presumably indexed with a per-quantizer offset in forward -- TODO confirm)
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim)
        # embedding of the quantizer index itself
        self.coarse_quantize_embedding = nn.Embedding(num_coarse_quantizers, dim)

        # project conditioning embeddings to the model width (identity if widths already match)
        text_dim = default(cond_dim, get_encoded_dim(t5_name))
        self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity()

        # per-head scalar bias, only created when relative position bias is active
        self.cross_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1)) if rel_pos_bias else None

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            cross_attend = has_condition and not cond_as_self_attn_prefix,
            cond_as_self_attn_prefix = cond_as_self_attn_prefix,
            grad_shrink_alpha = grad_shrink_alpha,
            rel_pos_bias = rel_pos_bias,
            flash_attn = flash_attn,
            **kwargs
        )

        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers

        # semantic output head (optional) and per-quantizer coarse output weights (eos slot included)
        self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens + 1) if project_semantic_logits else None
        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim))

    @property
    def device(self):
        """Device on which this module's first registered parameter lives."""
        param_stream = self.parameters()
        first_param = next(param_stream)
        return first_param.device

    def load(self, path):
        """Load a checkpoint into this module and return the raw checkpoint.

        Args:
            path: str or path-like pointing at a file readable by ``torch.load``.
                The file must hold a dict with a ``'model'`` state dict. It may
                also hold a ``'version'`` string.

        Returns:
            The full checkpoint dict, so callers can read any extra entries.

        Raises:
            AssertionError: if ``path`` does not exist.
            KeyError: if the checkpoint has no ``'model'`` entry.
        """
        device = self.device
        path = Path(path)
        # name the offending path; a bare assert gives no clue which checkpoint was missing
        assert path.exists(), f'checkpoint not found at {path}'

        # NOTE(review): torch.load unpickles arbitrary objects -- only load checkpoints from trusted sources
        pkg = torch.load(str(path), map_location = device)

        # notify (via print) when the checkpoint was written by an older library version
        if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
            print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')

        self.load_state_dict(pkg['model'])
        return pkg

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        return_kv_cache = False,
        kv_cache = None,
        embed_cache = None,
        **kwargs
    ):
        """Classifier-free guidance wrapper around ``forward``.

        ``forward`` first runs with conditioning kept (``cond_drop_prob = 0.``).
        Unless ``cond_scale == 1`` or the model is unconditional, it runs a second
        time with conditioning dropped (``cond_drop_prob = 1.``). The two results
        are then combined as ``null + (cond - null) * cond_scale``.

        Caches, when given, are consumed one entry per pass: the first entry of
        ``kv_cache`` / ``embed_cache`` feeds the conditional pass and the second
        feeds the unconditional pass.

        Returns:
            ``(semantic_logits, coarse_logits)``. If ``return_kv_cache`` is True,
            returns ``((semantic_logits, coarse_logits), (stacked_kv_caches,
            stacked_embed_caches))``, with one stacked entry per pass.
        """
        kv_cache_source = iter(default(kv_cache, []))
        embed_cache_source = iter(default(embed_cache, []))

        def run_pass(drop_prob):
            # each call takes the next cached entry for its pass, or None once exhausted
            return self.forward(
                *args,
                cond_drop_prob = drop_prob,
                return_cache = True,
                kv_cache = next(kv_cache_source, None),
                embed_cache = next(embed_cache_source, None),
                **kwargs
            )

        (semantic_logits, coarse_logits), (kv, emb) = run_pass(0.)
        kv_caches = [kv]
        embed_caches = [emb]

        needs_guidance = cond_scale != 1 and self.has_condition

        if needs_guidance:
            (null_semantic_logits, null_coarse_logits), (null_kv, null_emb) = run_pass(1.)
            kv_caches.append(null_kv)
            embed_caches.append(null_emb)

            # semantic logits are optional; without an unconditional counterpart the result is None
            if exists(null_semantic_logits):
                semantic_logits = null_semantic_logits + (semantic_logits - null_semantic_logits) * cond_scale
            else:
                semantic_logits = None

            coarse_logits = null_coarse_logits + (coarse_logits - null_coarse_logits) * cond_scale

        logits = (semantic_logits, coarse_logits)

        if not return_kv_cache:
            return logits

        return logits, (torch.stack(kv_caches), torch.stack(embed_caches))

    @beartype
    def forward(
        self,
        *,
        semantic_token_ids,
        coarse_token_ids,
        self_attn_mask = None,
        text: Optional[List[str]] = None,
        text_embeds = None,
        cond_drop_prob = None,
        return_only_coarse_logits = False,
        return_cache = False,
        kv_cache = None,
        embed_cache = None
class FineTransformer(nn.Module):
    # 定义 FineTransformer 类,继承自 nn.Module
    def __init__(
        self,
        *,
        num_coarse_quantizers,
        num_fine_quantizers,
        codebook_size,
        dim,
        depth,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_dim = None,
        audio_text_condition = False,
        cond_as_self_attn_prefix = False,
        cond_drop_prob = 0.5,
        grad_shrink_alpha = 0.1,
        project_coarse_logits = True,
        pad_id = -1,
        rel_pos_bias = True,
        flash_attn = False,
        **kwargs
    ):
        """Build the transformer that models coarse followed by fine acoustic tokens.

        Args:
            num_coarse_quantizers: number of coarse quantizer levels.
            num_fine_quantizers: number of fine quantizer levels.
            codebook_size: codebook size per quantizer. It is also stored as ``eos_id``.
            dim: model width.
            depth, heads, attn_dropout, ff_dropout, grad_shrink_alpha: forwarded to ``Transformer``.
            t5_name: text-encoder name used by ``embed_text``. When ``cond_dim``
                is None, it is also used to look up the conditioning width.
            has_condition: enables conditioning. This means cross attention
                unless ``cond_as_self_attn_prefix`` is set.
            cond_dim: width of the conditioning embeddings.
            audio_text_condition: if True, forces ``has_condition`` on and
                defaults ``cond_dim`` to ``dim``.
            cond_as_self_attn_prefix: forwarded to ``Transformer``. When True,
                cross attention is disabled.
            cond_drop_prob: default conditioning-drop probability stored on the module.
            project_coarse_logits: when False, ``coarse_logit_weights`` is None.
            pad_id: stored as ``self.pad_id``.
            rel_pos_bias: builds this module's own position-bias MLP. It is
                forced off when ``flash_attn`` is True. The inner ``Transformer``
                always gets ``rel_pos_bias = False``.
            flash_attn: forwarded to ``Transformer``.
            **kwargs: forwarded to ``Transformer``.
        """
        super().__init__()

        # the position-bias MLP below is disabled whenever flash attention is requested
        rel_pos_bias = rel_pos_bias and not flash_attn

        if audio_text_condition:
            # audio/text joint conditioning implies a conditional model
            has_condition = True
            cond_dim = default(cond_dim, dim)

        self.has_condition = has_condition
        # text encoder bound to the configured T5 model name
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob

        self.num_coarse_quantizers = num_coarse_quantizers

        # learned start tokens for the coarse and fine segments
        # NOTE: creation order of randomly-initialised parameters below is significant for RNG reproducibility
        self.coarse_start_token = nn.Parameter(torch.randn(dim))
        self.fine_start_token = nn.Parameter(torch.randn(dim))

        # token embedding tables sized (num_quantizers * codebook_size)
        # NOTE(review): unlike CoarseTransformer these tables have no extra eos slot per quantizer
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size, dim)
        self.fine_embedding = nn.Embedding(num_fine_quantizers * codebook_size, dim)

        # embeddings of the quantizer index itself
        self.coarse_quantize_embedding = nn.Embedding(num_coarse_quantizers, dim)
        self.fine_quantize_embedding = nn.Embedding(num_fine_quantizers, dim)

        self.pad_id = pad_id
        self.eos_id = codebook_size

        # project conditioning embeddings to the model width (identity if widths already match)
        text_dim = default(cond_dim, get_encoded_dim(t5_name))
        self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity()

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            cross_attend = has_condition and not cond_as_self_attn_prefix,
            cond_as_self_attn_prefix = cond_as_self_attn_prefix,
            rel_pos_bias = False,
            grad_shrink_alpha = grad_shrink_alpha,
            flash_attn = flash_attn,
            **kwargs
        )

        # per-head bias parameter, only when this module's own position bias is active
        self.null_pos_bias = nn.Parameter(torch.randn(heads, 1, 1)) if rel_pos_bias else None

        # MLP mapping a 2-d input to one bias value per head
        # (presumably a pair of relative positions -- TODO confirm in forward)
        pos_bias_mlp_dim = dim // 2
        self.pos_bias_mlp = nn.Sequential(
            nn.Linear(2, pos_bias_mlp_dim),
            nn.SiLU(),
            nn.Linear(pos_bias_mlp_dim, pos_bias_mlp_dim),
            nn.SiLU(),
            nn.Linear(pos_bias_mlp_dim, heads)
        ) if rel_pos_bias else None

        # num_coarse_quantizers is already stored above; the former duplicate assignment here was dead code
        self.codebook_size = codebook_size
        self.num_fine_quantizers = num_fine_quantizers

        # per-quantizer output weights (coarse head is optional)
        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size, dim)) if project_coarse_logits else None
        self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size, dim))

    @property
    def device(self):
        """Device on which this module's first registered parameter lives."""
        first_param = next(iter(self.parameters()))
        return first_param.device

    def load(self, path):
        """Load a checkpoint into this module and return the raw checkpoint.

        Args:
            path: str or path-like pointing at a file readable by ``torch.load``.
                The file must hold a dict with a ``'model'`` state dict. It may
                also hold a ``'version'`` string.

        Returns:
            The full checkpoint dict, so callers can read any extra entries.

        Raises:
            AssertionError: if ``path`` does not exist.
            KeyError: if the checkpoint has no ``'model'`` entry.
        """
        device = self.device
        path = Path(path)
        # name the offending path; a bare assert gives no clue which checkpoint was missing
        assert path.exists(), f'checkpoint not found at {path}'

        # NOTE(review): torch.load unpickles arbitrary objects -- only load checkpoints from trusted sources
        pkg = torch.load(str(path), map_location = device)

        # notify (via print) when the checkpoint was written by an older library version
        if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
            print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')

        self.load_state_dict(pkg['model'])
        return pkg
    # Classifier-free guidance variant of forward
    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        return_kv_cache = False,
        kv_cache = None,
        embed_cache = None,
        **kwargs
    ):
        """Classifier-free guidance wrapper around ``forward``.

        ``forward`` first runs with conditioning kept (``cond_drop_prob = 0.``).
        Unless ``cond_scale == 1`` or the model is unconditional, it runs a second
        time with conditioning dropped (``cond_drop_prob = 1.``). The two results
        are then combined as ``null + (cond - null) * cond_scale``.

        The first logits returned by ``forward`` are the (optional) coarse logits
        and the second are the fine logits.

        Caches, when given, are consumed one entry per pass: the first entry of
        ``kv_cache`` / ``embed_cache`` feeds the conditional pass and the second
        feeds the unconditional pass.

        Returns:
            ``(coarse_logits, fine_logits)``. If ``return_kv_cache`` is True,
            returns ``((coarse_logits, fine_logits), (stacked_kv_caches,
            stacked_embed_caches))``.
        """
        pending_kv = iter(default(kv_cache, []))
        pending_embeds = iter(default(embed_cache, []))

        def guided_forward(drop_prob):
            # each call takes the next cached entry for its pass, or None once exhausted
            return self.forward(
                *args,
                cond_drop_prob = drop_prob,
                return_cache = True,
                kv_cache = next(pending_kv, None),
                embed_cache = next(pending_embeds, None),
                **kwargs
            )

        (coarse_logits, fine_logits), (kv, emb) = guided_forward(0.)
        collected_kv = [kv]
        collected_embeds = [emb]

        if cond_scale != 1 and self.has_condition:
            (null_coarse_logits, null_fine_logits), (null_kv, null_emb) = guided_forward(1.)
            collected_kv.append(null_kv)
            collected_embeds.append(null_emb)

            # coarse logits are optional; without an unconditional counterpart the result is None
            coarse_logits = (
                null_coarse_logits + (coarse_logits - null_coarse_logits) * cond_scale
                if exists(null_coarse_logits)
                else None
            )
            fine_logits = null_fine_logits + (fine_logits - null_fine_logits) * cond_scale

        if not return_kv_cache:
            return coarse_logits, fine_logits

        return (coarse_logits, fine_logits), (torch.stack(collected_kv), torch.stack(collected_embeds))

    # 定义一个前向传播函数
    def forward(
        self,
        coarse_token_ids,
        fine_token_ids,
        text: Optional[List[str]] = None,
        text_embeds = None,
        cond_drop_prob = None,
        self_attn_mask = None,
        kv_cache = None,
        embed_cache = None,
        return_cache = False,
        return_only_fine_logits = False
# Wrapper pairing a SemanticTransformer with optional audio tokenizer / conditioner
class SemanticTransformerWrapper(nn.Module):
    """Training and sampling wrapper around a ``SemanticTransformer``.

    It can optionally hold a ``wav2vec`` model, which turns raw waveforms into
    semantic token ids. It can also hold an ``audio_conditioner``, whose output
    is used in place of text embeddings.
    """
    @beartype
    def __init__(
        self,
        *,
        transformer: SemanticTransformer,  # the semantic transformer being wrapped
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,  # optional waveform -> semantic-token model
        audio_conditioner: Optional[AudioConditionerBase] = None,  # optional source of conditioning embeddings computed from audio
        pad_id = -1,  # padding id; also the ignore_index of the training loss
        unique_consecutive = True,  # collapse runs of repeated ids with batch_unique_consecutive
        mask_prob = 0.15  # probability passed to generate_mask_with_prob during training
    ):
        super().__init__()
        self.wav2vec = wav2vec
        self.transformer = transformer
        # move this wrapper to the transformer's device; audio_conditioner is attached
        # after this call, so this .to() does not move it
        self.to(transformer.device)
        self.audio_conditioner = audio_conditioner

        # audio-derived conditioning requires a conditional transformer
        assert not (exists(audio_conditioner) and not transformer.has_condition), 'if conditioning on audio embeddings from mulan, transformer has_condition must be set to True'

        # if a wav2vec model is given, its codebook must match the transformer's semantic vocabulary
        assert not exists(self.wav2vec) or self.wav2vec.codebook_size == transformer.num_semantic_tokens, f'num_semantic_tokens on SemanticTransformer must be set to {self.wav2vec.codebook_size}'

        self.unique_consecutive = unique_consecutive
        self.pad_id = pad_id
        self.eos_id = transformer.eos_id
        self.mask_prob = mask_prob

    # device of the first registered parameter
    @property
    def device(self):
        return next(self.parameters()).device

    # encode text with the transformer's text encoder, placing the result on this wrapper's device
    def embed_text(self, text):
        return self.transformer.embed_text(text, output_device = self.device)

    # autoregressively sample semantic token ids
    @eval_decorator
    @torch.inference_mode()
    @beartype
    def generate(
        self,
        *,
        max_length,  # maximum total length of the returned sequence (prompt included)
        text: Optional[List[str]] = None,  # raw text prompts, encoded with transformer.embed_text
        text_embeds = None,  # precomputed conditioning embeddings
        prime_wave = None,  # optional prompt waveform, tokenized with self.wav2vec
        prime_wave_input_sample_hz = None,  # sample rate of prime_wave, forwarded to wav2vec
        prime_ids = None,  # optional prompt token ids (mutually exclusive with prime_wave)
        batch_size = 1,  # batch size, used only when no prompt is given
        cond_scale = 3,  # classifier-free guidance scale
        filter_thres = 0.9,  # threshold passed to top_k
        temperature = 1.,  # temperature passed to gumbel_sample
        use_kv_cache = True,  # reuse the transformer's key/value cache between steps
        include_eos_in_output = True,  # NOTE(review): currently unused -- the output is always masked with keep_eos = False below
        **kwargs  # forwarded to transformer.forward_with_cond_scale
    ):
        """Sample semantic token ids, optionally primed by a waveform or token ids.

        Returns:
            A tensor of sampled ids (the prompt included), post-processed with
            ``mask_out_after_eos_id(..., keep_eos = False)``.
        """
        device = self.device

        # derive semantic token ids from the prompt waveform, if one is given
        if exists(prime_wave):
            assert not exists(prime_ids)
            assert exists(self.wav2vec)
            ids = self.wav2vec(
                prime_wave,
                flatten = False,
                input_sample_hz = prime_wave_input_sample_hz
            )
        elif exists(prime_ids):
            ids = prime_ids
        else:
            # no prompt: start from an empty (batch_size, 0) sequence
            ids = torch.empty((batch_size, 0), dtype = torch.long, device = device)

        if self.unique_consecutive:
            # collapse repeated consecutive ids, using pad_id as the pad value
            # (presumably for rows that become shorter -- see batch_unique_consecutive)
            ids = batch_unique_consecutive(ids, pad_value = self.pad_id)

        # derive joint audio-text embeddings from the prompt waveform via the audio conditioner
        if exists(self.audio_conditioner) and exists(prime_wave):
            assert not exists(text) and not exists(text_embeds)
            text_embeds = self.audio_conditioner(wavs = prime_wave, namespace = 'semantic')

        # conditioning must be supplied if and only if the transformer is conditional
        has_text = exists(text) or exists(text_embeds)
        assert not (self.transformer.has_condition ^ has_text)

        if not exists(text_embeds) and exists(text):
            # encode raw text into conditioning embeddings
            with torch.inference_mode():
                text_embeds = self.transformer.embed_text(text, output_device = device)

        # sampling state
        batch = ids.shape[0]
        start_length = ids.shape[-1]
        sample_semantic_ids = ids.clone()
        # per-row position of the logit to sample from: the count of non-pad prompt ids
        last_logit_indices = (ids != self.pad_id).sum(dim = -1).long()
        kv_cache = None
        logits = None

        # one new token per iteration, until max_length or every row has produced eos
        for ind in tqdm(range(start_length, max_length), desc = 'generating semantic'):

            new_logits, new_kv_cache = self.transformer.forward_with_cond_scale(
                ids = sample_semantic_ids,
                text_embeds = text_embeds,
                cond_scale = cond_scale,
                kv_cache = kv_cache,
                return_kv_cache = True,
                **kwargs
            )

            if use_kv_cache:
                # with a cache, new logits are accumulated along dim -2
                kv_cache = new_kv_cache
                logits = safe_cat(logits, new_logits, dim = -2)
            else:
                logits = new_logits

            # gather each row's logits at its current last position -> (batch, num_classes)
            last_logit_indices_expanded = repeat(last_logit_indices, 'b -> b 1 c', b = batch, c = logits.shape[-1])
            last_logits = logits.gather(1, last_logit_indices_expanded)
            last_logits = rearrange(last_logits, 'b 1 c -> b c')

            # top-k filtering followed by gumbel sampling
            filtered_logits = top_k(last_logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            sampled = rearrange(sampled, 'b -> b 1')
            sample_semantic_ids = torch.cat((sample_semantic_ids, sampled), dim = -1)

            # stop early once every row contains the eos id
            if all_rows_have_eos_id(sample_semantic_ids, self.eos_id):
                break

            last_logit_indices += 1

        # post-process ids at/after eos (eos itself not kept) -- see mask_out_after_eos_id for the exact fill value
        sample_semantic_ids = mask_out_after_eos_id(sample_semantic_ids, self.eos_id, keep_eos = False)

        return sample_semantic_ids

    # run the transformer on semantic ids; returns logits, or the cross-entropy loss when return_loss is True
    def forward(
        self,
        *,
        semantic_token_ids = None,
        raw_wave = None,
        text = None,
        text_embeds = None,
        return_loss = False,
        **kwargs
        ):
            """Compute logits (or the training loss) from semantic ids or a raw waveform.

            Args:
                semantic_token_ids: precomputed semantic ids. Required unless
                    ``raw_wave`` is given and ``self.wav2vec`` exists.
                raw_wave: raw waveform. It is tokenized with ``self.wav2vec``
                    when ``semantic_token_ids`` is None, and it is required when
                    an audio conditioner is set.
                text, text_embeds: conditioning, forwarded to the transformer.
                return_loss: if True, return the cross-entropy loss
                    (``ignore_index = pad_id``) instead of logits.
                **kwargs: forwarded to the transformer.
            """
            # at least one input source must be provided
            assert exists(raw_wave) or exists(semantic_token_ids), 'either raw waveform (raw_wave) is given or semantic token ids are given (semantic_token_ids)'

            if exists(self.audio_conditioner):
                # audio conditioning replaces text conditioning and needs the waveform
                assert exists(raw_wave)
                assert not exists(text) and not exists(text_embeds)
                text_embeds = self.audio_conditioner(wavs = raw_wave, namespace = 'semantic')

            if not exists(semantic_token_ids):
                # tokenize the raw waveform into semantic ids
                assert exists(self.wav2vec), 'VQWav2Vec must be be provided if given raw wave for training'
                semantic_token_ids = self.wav2vec(raw_wave, flatten = False)

            # flatten all non-batch dimensions
            semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)')

            if self.training:
                # in training mode, append the eos id to every sequence
                semantic_token_ids = append_eos_id(semantic_token_ids, self.transformer.eos_id)

            if self.unique_consecutive:
                semantic_token_ids = batch_unique_consecutive(semantic_token_ids, pad_value = self.pad_id)

            # for the loss, the input drops the last id; full ids serve as labels
            input_ids = semantic_token_ids
            if return_loss:
                input_ids = semantic_token_ids[:, :-1]

            # optional random self-attention mask, only while training
            self_attn_mask = None
            if self.mask_prob > 0. and self.training:
                self_attn_mask = generate_mask_with_prob(input_ids.shape, self.mask_prob, input_ids.device)

            logits = self.transformer(
                ids = input_ids,
                text = text,
                text_embeds = text_embeds,
                self_attn_mask = self_attn_mask,
                **kwargs
            )

            if not return_loss:
                return logits

            # cross_entropy expects (batch, classes, seq); padding positions are ignored
            loss = F.cross_entropy(
                rearrange(logits, 'b n c -> b c n'),
                semantic_token_ids,
                ignore_index = self.pad_id
            )

            return loss
class CoarseTransformerWrapper(nn.Module):
    """Training and sampling wrapper around a ``CoarseTransformer``.

    Optional companions:
        codec: acoustic codec (``SoundStream`` / ``EncodecWrapper``).
        wav2vec: semantic tokenizer.
        audio_conditioner: source of conditioning embeddings computed from audio.
    """
    @beartype
    def __init__(
        self,
        *,
        transformer: CoarseTransformer,
        codec: Optional[Union[SoundStream, EncodecWrapper]]  = None,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        pad_id = -1,
        unique_consecutive = True,
        semantic_cross_entropy_loss_weight = 1.,
        mask_prob = 0.15
    ):
        """
        Args:
            transformer: the coarse transformer being wrapped.
            codec: optional codec. Its ``rq_groups`` multiplies the number of
                coarse quantizers. Without a codec, a single group is assumed.
            wav2vec: optional semantic tokenizer.
            audio_conditioner: optional audio conditioner. It requires
                ``transformer.has_condition``.
            pad_id: padding id stored on the wrapper.
            unique_consecutive: stored flag for collapsing repeated semantic ids.
            semantic_cross_entropy_loss_weight: stored weight of the semantic loss term.
            mask_prob: stored self-attention masking probability.
        """
        super().__init__()
        self.codec = codec
        self.wav2vec = wav2vec

        self.transformer = transformer
        # move this wrapper to the transformer's device; audio_conditioner is attached
        # after this call, so this .to() does not move it
        self.to(transformer.device)
        self.audio_conditioner = audio_conditioner

        assert not (exists(audio_conditioner) and not transformer.has_condition), 'if conditioning on audio embeddings from mulan, transformer has_condition must be set to True'

        self.unique_consecutive = unique_consecutive
        self.pad_id = pad_id

        self.semantic_cross_entropy_loss_weight = semantic_cross_entropy_loss_weight

        # codec is Optional: previously `codec.rq_groups` raised AttributeError when codec was None.
        # Without a codec, fall back to a single residual-quantizer group.
        rq_groups = codec.rq_groups if exists(codec) else 1
        self.num_coarse_quantizers = transformer.num_coarse_quantizers * rq_groups

        self.semantic_eos_id = transformer.semantic_eos_id
        self.coarse_eos_id = transformer.coarse_eos_id

        self.mask_prob = mask_prob

    @property
    def device(self):
        # device of the first registered parameter
        return next(self.parameters()).device

    @eval_decorator
    @torch.inference_mode()
    @beartype
    def generate(
        self,
        *,
        semantic_token_ids,
        prime_wave: Optional[Tensor] = None,
        prime_wave_input_sample_hz = None,
        prime_coarse_token_ids: Optional[Tensor] = None,
        text: Optional[List[str]] = None,
        text_embeds = None,
        max_time_steps = 512,
        cond_scale = 3.,
        filter_thres = 0.9,
        temperature = 1.,
        reconstruct_wave = False,
        use_kv_cache = True,
        **kwargs
    ):
        # NOTE(review): the body is `pass` in this file, so calling generate returns None
        pass

    def forward(
        self,
        *,
        semantic_token_ids = None,
        raw_wave = None,
        raw_wave_for_codec = None,
        text = None,
        text_embeds = None,
        coarse_token_ids = None,
        return_loss = False,
        **kwargs
    ):
        # NOTE(review): the body is `pass` in this file, so calling forward returns None
        pass

class FineTransformerWrapper(nn.Module):
    # 定义一个名为FineTransformerWrapper的类,继承自nn.Module
    @beartype
    def __init__(
        self,
        *,
        transformer: FineTransformer,
        codec: Optional[Union[SoundStream, EncodecWrapper]] = None,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        coarse_cross_entropy_loss_weight = 1.,
        pad_id = -1,
        mask_prob = 0.15
    ):
        """
        Args:
            transformer: the fine transformer being wrapped.
            codec: optional codec. Its ``rq_groups`` multiplies the quantizer
                counts. Without a codec, a single group is assumed.
            audio_conditioner: optional audio conditioner. It requires
                ``transformer.has_condition``.
            coarse_cross_entropy_loss_weight: weight of the coarse loss term.
            pad_id: padding id; also the ignore_index of the losses.
            mask_prob: self-attention masking probability used during training.

        Raises:
            AssertionError: on a conditioning mismatch, when the quantizer counts
                disagree with the codec, or when there are no coarse quantizers.
        """
        super().__init__()
        self.codec = codec

        self.transformer = transformer
        # move this wrapper to the transformer's device; audio_conditioner is attached
        # after this call, so this .to() does not move it
        self.to(transformer.device)
        self.audio_conditioner = audio_conditioner

        assert not (exists(audio_conditioner) and not transformer.has_condition), 'if conditioning on audio embeddings from mulan, transformer has_condition must be set to True'

        # codec is Optional: previously `codec.rq_groups` raised AttributeError when codec was None,
        # which also made the `exists(codec)` guard below unreachable for that case.
        # Without a codec, fall back to a single residual-quantizer group.
        rq_groups = codec.rq_groups if exists(codec) else 1
        self.num_fine_quantizers = transformer.num_fine_quantizers * rq_groups
        self.num_coarse_quantizers = transformer.num_coarse_quantizers * rq_groups

        if exists(codec):
            assert (self.num_fine_quantizers + self.num_coarse_quantizers) == (codec.num_quantizers * codec.rq_groups), 'number of fine and coarse quantizers on fine transformer must add up to total number of quantizers on codec'

        self.eos_id = transformer.eos_id

        assert self.num_coarse_quantizers > 0

        self.pad_id = pad_id
        self.coarse_cross_entropy_loss_weight = coarse_cross_entropy_loss_weight

        self.mask_prob = mask_prob

    @property
    def device(self):
        """Device on which this wrapper's first registered parameter lives."""
        leading_param = next(self.parameters())
        return leading_param.device

    @eval_decorator
    @torch.inference_mode()
    @beartype
    # 装饰器,用于评估和推断模式
    # 定义一个生成函数,用于生成音频波形
    def generate(
        self,
        *,
        coarse_token_ids,  # 粗粒度音频标记的张量
        prime_wave: Optional[Tensor] = None,  # 初始波形张量,默认为None
        prime_wave_input_sample_hz = None,  # 初始波形输入采样率,默认为None
        prime_fine_token_ids: Optional[Tensor] = None,  # 初始细粒度音频标记的张量,默认为None
        text: Optional[List[str]] = None,  # 文本列表,默认为None
        text_embeds = None,  # 文本嵌入,默认为None
        cond_scale = 3.,  # 条件缩放,默认为3.0
        filter_thres = 0.9,  # 过滤阈值,默认为0.9
        temperature = 1.,  # 温度,默认为1.0
        reconstruct_wave = False,  # 是否重建波形,默认为False
        use_kv_cache = True,  # 是否使用键值缓存,默认为True
        mask_out_generated_fine_tokens = False,  # 是否屏蔽生成的细粒度标记,默认为False
        **kwargs  # 其他关键字参数
    # Forward pass: returns logits, or the weighted cross-entropy loss when return_loss is True
    def forward(
        self,
        *,
        raw_wave = None,  # raw waveform, tokenized with self.codec
        text = None,  # text prompts, forwarded to the transformer
        text_embeds = None,  # precomputed conditioning embeddings
        token_ids = None,  # full codec ids, split into coarse / fine along the last dim
        coarse_token_ids = None,  # coarse acoustic token ids
        fine_token_ids = None,  # fine acoustic token ids
        return_loss = False,  # return the loss instead of logits
        **kwargs  # forwarded to the transformer
    ):
        """Compute coarse/fine logits, or the training loss.

        Exactly one token source must be provided: ``raw_wave``, ``token_ids``,
        or both ``coarse_token_ids`` and ``fine_token_ids``.

        Returns:
            ``(coarse_logits, fine_logits)`` when ``return_loss`` is False.
            Otherwise, the coarse and fine cross-entropy losses averaged with
            weights equal to their numbers of sequence positions. The coarse
            term is additionally scaled by ``coarse_cross_entropy_loss_weight``.
        """
        # Exactly one source. The previous three-way xor (a ^ (b ^ c)) also
        # accepted the invalid case where all three sources were given.
        coarse_and_fine_given = exists(coarse_token_ids) and exists(fine_token_ids)
        num_token_sources = sum(bool(given) for given in (exists(raw_wave), exists(token_ids), coarse_and_fine_given))
        assert num_token_sources == 1, 'either raw waveform (raw_wave) is given, or token ids (token_ids), or coarse and fine token ids (coarse_token_ids, fine_token_ids)'

        if exists(self.audio_conditioner):
            # audio conditioning replaces text conditioning and needs the waveform
            assert exists(raw_wave)
            assert not exists(text) and not exists(text_embeds)
            text_embeds = self.audio_conditioner(wavs = raw_wave, namespace = 'fine') # technically audio embeds, but shared text-audio joint embedding space for mulan

        if exists(raw_wave):
            assert exists(self.codec), 'Codec must be provided if given raw wave for training'

            with torch.inference_mode():
                # codec is only used as a frozen tokenizer here
                self.codec.eval()
                _, token_ids, _ = self.codec(raw_wave, return_encoded = True)

                batch, num_timesteps = raw_wave.shape
                num_frames = int(num_timesteps / self.codec.seq_len_multiple_of)

                assert token_ids.shape == torch.Size((batch, num_frames, self.num_coarse_quantizers + self.num_fine_quantizers)), \
                    f'Expected token ids to have shape (batch, num_frames, num_coarse_quantizers + num_fine_quantizers), but got {token_ids.shape}'

        if exists(token_ids):
            # split the last dimension into the coarse and fine quantizer ids
            coarse_token_ids, fine_token_ids = token_ids[..., :self.num_coarse_quantizers], token_ids[..., self.num_coarse_quantizers:]

        # flatten all non-batch dimensions
        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
        fine_token_ids = rearrange(fine_token_ids, 'b ... -> b (...)')

        # for the loss, the full ids are the labels and the fine input drops its last id
        if return_loss:
            coarse_labels = coarse_token_ids
            fine_labels = fine_token_ids
            fine_token_ids = fine_token_ids[:, :-1]

        # optional random self-attention mask during training
        # (+ 2 presumably accounts for the coarse and fine start tokens -- TODO confirm in FineTransformer.forward)
        self_attn_mask = None

        if self.mask_prob > 0 and self.training:
            mask_shape = (
                coarse_token_ids.shape[0],
                coarse_token_ids.shape[-1] + fine_token_ids.shape[-1] + 2
            )

            self_attn_mask = generate_mask_with_prob(mask_shape, self.mask_prob, device = self.device)

        coarse_logits, fine_logits = self.transformer(
            coarse_token_ids = coarse_token_ids,
            fine_token_ids = fine_token_ids,
            self_attn_mask = self_attn_mask,
            text = text,
            text_embeds = text_embeds,
            **kwargs
        )

        if not return_loss:
            return coarse_logits, fine_logits

        # cross_entropy expects (batch, classes, seq); coarse logits may be None
        coarse_logits, fine_logits = map(lambda t: maybe(rearrange)(t, 'b n c -> b c n'), (coarse_logits, fine_logits))

        num_fine_logits = fine_logits.shape[-1]

        num_coarse_logits = 0
        coarse_loss = 0.

        if self.coarse_cross_entropy_loss_weight > 0 and exists(coarse_logits):
            num_coarse_logits = coarse_logits.shape[-1]

            coarse_loss = F.cross_entropy(
                coarse_logits,
                coarse_labels,
                ignore_index = self.pad_id
            )

        fine_loss = F.cross_entropy(
            fine_logits,
            fine_labels,
            ignore_index = self.pad_id
        )

        # average of the two losses weighted by their sequence lengths (coarse also by its loss weight)
        return (
            coarse_loss * num_coarse_logits * self.coarse_cross_entropy_loss_weight +
            fine_loss * num_fine_logits
        ) / (num_coarse_logits + num_fine_logits)
# AudioLM: chains the semantic, coarse and fine transformer stages into one generation module.
class AudioLM(nn.Module):
    """End-to-end AudioLM sampler.

    Builds one wrapper per stage and runs them in sequence at generation time:

    1. ``SemanticTransformerWrapper`` produces semantic token ids.
    2. ``CoarseTransformerWrapper`` turns the semantic token ids into coarse token ids
       (or, when requested, directly into a reconstructed wave).
    3. ``FineTransformerWrapper`` turns the coarse token ids into the final output,
       always called with ``reconstruct_wave = True``.
    """

    @beartype
    def __init__(
        self,
        *,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
        codec: Union[SoundStream, EncodecWrapper],
        semantic_transformer: SemanticTransformer,
        coarse_transformer: CoarseTransformer,
        fine_transformer: FineTransformer,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        unique_consecutive = True
    ):
        """Validate that the three stages are compatible and build their wrappers.

        Args:
            wav2vec: semantic tokenizer model (may be None); shared by the semantic
                and coarse wrappers.
            codec: acoustic codec; shared by the coarse and fine wrappers.
            semantic_transformer: first-stage model.
            coarse_transformer: second-stage model.
            fine_transformer: third-stage model.
            audio_conditioner: optional conditioner, stored on this module and
                passed to all three wrappers.
            unique_consecutive: forwarded to the semantic and coarse wrappers.

        Raises:
            AssertionError: if the transformers and the codec disagree on the number
                of semantic tokens, the codebook size, or the quantizer split.
        """
        super().__init__()

        self.audio_conditioner = audio_conditioner

        # Each stage consumes the previous stage's tokens, so vocabularies and the
        # coarse/fine quantizer split must agree across stages.
        assert semantic_transformer.num_semantic_tokens == coarse_transformer.num_semantic_tokens
        assert coarse_transformer.codebook_size == fine_transformer.codebook_size
        assert coarse_transformer.num_coarse_quantizers == fine_transformer.num_coarse_quantizers
        # Coarse + fine quantizers together must account for every quantizer of the codec.
        assert (fine_transformer.num_coarse_quantizers + fine_transformer.num_fine_quantizers) == codec.num_quantizers

        # Record which stages are conditioned; if any is, `forward` requires text
        # (or precomputed text embeddings).
        self.semantic_has_condition = semantic_transformer.has_condition
        self.coarse_has_condition = coarse_transformer.has_condition
        self.fine_has_condition = fine_transformer.has_condition
        self.needs_text = any([self.semantic_has_condition, self.coarse_has_condition, self.fine_has_condition])

        # stage 1: semantic tokens
        self.semantic = SemanticTransformerWrapper(
            wav2vec = wav2vec,
            transformer = semantic_transformer,
            audio_conditioner = audio_conditioner,
            unique_consecutive = unique_consecutive
        )

        # stage 2: coarse acoustic tokens
        self.coarse = CoarseTransformerWrapper(
            wav2vec = wav2vec,
            codec = codec,
            transformer = coarse_transformer,
            audio_conditioner = audio_conditioner,
            unique_consecutive = unique_consecutive
        )

        # stage 3: fine acoustic tokens
        self.fine = FineTransformerWrapper(
            codec = codec,
            transformer = fine_transformer,
            audio_conditioner = audio_conditioner
        )

    @property
    def device(self):
        """Device of this module's first parameter (raises StopIteration if it has none)."""
        return next(self.parameters()).device

    @eval_decorator
    @torch.inference_mode()
    def forward(
        self,
        *,
        batch_size = 1,
        text: Optional[List[str]] = None,
        text_embeds: Optional[Tensor] = None,
        prime_wave = None,
        prime_wave_input_sample_hz = None,
        prime_wave_path = None,
        max_length = 2048,
        return_coarse_generated_wave = False,
        mask_out_generated_fine_tokens = False
    ):
        """Generate audio by running the semantic -> coarse -> fine stages in order.

        Runs under ``torch.inference_mode()``. ``eval_decorator`` presumably switches
        the module to eval mode for the duration of the call (TODO confirm against
        its definition).

        Args:
            batch_size: forwarded to the semantic stage only.
            text: optional prompts. When a stage is conditioned and ``text`` is given,
                it is embedded with ``self.semantic.embed_text`` and replaces any
                ``text_embeds`` passed in.
            text_embeds: precomputed text embeddings, used when ``text`` is absent.
                Each stage receives them only if that stage is conditioned.
            prime_wave: optional prompt audio tensor; requires
                ``prime_wave_input_sample_hz``. Moved to ``self.device``.
            prime_wave_input_sample_hz: sample rate of ``prime_wave``.
            prime_wave_path: optional path to prompt audio, loaded with
                ``torchaudio.load``. Mutually exclusive with ``prime_wave``.
            max_length: forwarded to the semantic stage only.
            return_coarse_generated_wave: if True, stop after the coarse stage and
                return its output (generated with ``reconstruct_wave = True``).
            mask_out_generated_fine_tokens: forwarded to the fine stage.

        Returns:
            Output of ``self.fine.generate(..., reconstruct_wave = True)``, or the
            output of ``self.coarse.generate`` when ``return_coarse_generated_wave``
            is True.

        Raises:
            AssertionError: if conditioning text is required but missing, if both
                ``prime_wave`` and ``prime_wave_path`` are given, if ``prime_wave`` is
                given without its sample rate, or if ``prime_wave_path`` does not exist.
        """
        assert not (self.needs_text and (not exists(text) and not exists(text_embeds))), 'text needs to be passed in if one of the transformer requires conditioning'

        # Raw text, when given, takes precedence over precomputed embeddings.
        if self.needs_text and exists(text):
            text_embeds = self.semantic.embed_text(text)

        assert not (exists(prime_wave) and exists(prime_wave_path)), 'prompt audio must be given as either `prime_wave: Tensor` or `prime_wave_path: str`'

        if exists(prime_wave):
            assert exists(prime_wave_input_sample_hz), 'the input sample frequency for the prompt audio must be given as `prime_wave_input_sample_hz: int`'
            prime_wave = prime_wave.to(self.device)
        elif exists(prime_wave_path):
            prime_wave_path = Path(prime_wave_path)
            # `exists(...)` is only a not-None check (as used for the optional arguments
            # above), and a Path object is never None. Ask the filesystem instead, so a
            # missing file is reported here with a clear message rather than failing
            # inside torchaudio.
            assert prime_wave_path.exists(), f'file does not exist at {str(prime_wave_path)}'

            # The sample rate comes from the file itself.
            prime_wave, prime_wave_input_sample_hz = torchaudio.load(str(prime_wave_path))
            prime_wave = prime_wave.to(self.device)

        # stage 1: semantic token ids
        semantic_token_ids = self.semantic.generate(
            text_embeds = text_embeds if self.semantic_has_condition else None,
            batch_size = batch_size,
            prime_wave = prime_wave,
            prime_wave_input_sample_hz = prime_wave_input_sample_hz,
            max_length = max_length
        )

        # stage 2: coarse token ids, or a reconstructed wave when the caller stops here
        coarse_token_ids_or_recon_wave = self.coarse.generate(
            text_embeds = text_embeds if self.coarse_has_condition else None,
            semantic_token_ids = semantic_token_ids,
            prime_wave = prime_wave,
            prime_wave_input_sample_hz = prime_wave_input_sample_hz,
            reconstruct_wave = return_coarse_generated_wave
        )

        if return_coarse_generated_wave:
            return coarse_token_ids_or_recon_wave

        # stage 3: fine tokens, decoded into the final wave
        generated_wave = self.fine.generate(
            text_embeds = text_embeds if self.fine_has_condition else None,
            coarse_token_ids = coarse_token_ids_or_recon_wave,
            prime_wave = prime_wave,
            prime_wave_input_sample_hz = prime_wave_input_sample_hz,
            reconstruct_wave = True,
            mask_out_generated_fine_tokens = mask_out_generated_fine_tokens
        )

        return generated_wave
posted @ 2024-06-28 14:01  绝不原创的飞龙  阅读(5)  评论(0编辑  收藏  举报