Lucidrains Series: Source Code Walkthrough (Part 43)

Lucidrains Series: Source Code Walkthrough (Part 43)

.\lucidrains\tf-bind-transformer\scripts\fetch_factor_fastas.py

# import the required libraries
import requests
from pathlib import Path
import click
import polars as pl
from tqdm import tqdm
from tf_bind_transformer.gene_utils import parse_gene_name
from tf_bind_transformer.data import read_bed

# constants

# URL of the Uniprot database
UNIPROT_URL = 'http://www.uniprot.org'

# default remap file paths
DEFAULT_REMAP_PATH = dict(
    HUMAN = './remap2022_crm_macs2_hg38_v1_0.bed',
    MOUSE = './remap2022_crm_macs2_mm10_v1_0.bed',
)

# overrides mapping certain gene names directly to Uniprot IDs
GENE_NAME_TO_ID_OVERRIDE = {
    'SS18-SSX': ['Q8IZH1'],
    'TFIIIC': ['A6ZV34']        # todo: figure out where the human entry is in Uniprot
}

# helper functions

# fetch the Uniprot mapping for a given from-type, to-type, and identifier
def uniprot_mapping(fromtype, totype, identifier):
    params = {
        'from': fromtype,
        'to': totype,
        'format': 'tab',
        'query': identifier,
    }

    response = requests.get(f'{UNIPROT_URL}/mapping', params = params)
    return response.text
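
As a quick aside, a minimal usage sketch for this helper (not part of the original script; the gene symbol is illustrative, and it assumes the legacy Uniprot /mapping endpoint still returns a tab-separated table with a From/To header row):

resp_text = uniprot_mapping('GENENAME', 'ID', 'GATA4')   # hypothetical gene symbol
for line in resp_text.strip().splitlines()[1:]:          # skip the From/To header row
    gene_symbol, uniprot_id = line.split('\t')[:2]
    print(gene_symbol, uniprot_id)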

# main functions

# command line entrypoint
@click.command()
@click.option('--species', help = 'Species', default = 'human', type = click.Choice(['human', 'mouse']))
@click.option('--remap-bed-path', help = 'Path to species specific remap file')
@click.option('--fasta-folder', help = 'Path to factor fastas', default = './tfactor.fastas')
def fetch_factors(
    species,
    remap_bed_path,
    fasta_folder
):
    species = species.upper()

    # fall back to the species default if --remap-bed-path was not given
    if remap_bed_path is None:
        remap_bed_path = DEFAULT_REMAP_PATH[species]

    remap_bed_path = Path(remap_bed_path)

    # make sure the remap file exists
    assert remap_bed_path.exists(), f'remap file does not exist at {str(remap_bed_path)}'

    # load the bed file and collect all unique targets from the fourth column (index 3)
    df = read_bed(remap_bed_path)
    genes = set([target for targets in df[:, 3] for target in targets.split(',')])

    print(f'{len(genes)} factors found')

    # collect any fasta files already saved, so the script can resume gracefully
    fasta_files = [str(path) for path in Path('./').glob('*.fasta')]
    processed_genes = set([*map(lambda t: str(t).split('.')[0], fasta_files)])

    results_folder = Path(fasta_folder)
    results_folder.mkdir(exist_ok = True, parents = True)

    # iterate over the genes and process each one
    for unparsed_gene_name in tqdm(genes):
        for gene_name in parse_gene_name(unparsed_gene_name):

            if gene_name in processed_genes:
                continue

            # look up the Uniprot ID for the gene name
            if gene_name not in GENE_NAME_TO_ID_OVERRIDE:
                uniprot_resp = uniprot_mapping('GENENAME', 'ID', gene_name)

                # only keep entries for the chosen species (todo: make species agnostic)
                entries = list(filter(lambda t: f'_{species}' in t, uniprot_resp.split('\n')))
                entries = list(map(lambda t: t.split('\t')[1], entries))
            else:
                entries = GENE_NAME_TO_ID_OVERRIDE[gene_name]

            if len(entries) == 0:
                print(f'no entries found for {gene_name}')
                continue

            # save all the results
            for entry in entries:
                response = requests.get(f'{UNIPROT_URL}/uniprot/{entry}.fasta')

                if response.status_code != 200:
                    print(f'<{response.status_code}> error fetching fasta file from gene {gene_name} {entry}')
                    continue

                fasta_path = str(results_folder / f'{gene_name}.{entry}.fasta')

                with open(fasta_path, 'w') as f:
                    f.write(response.text)

            print(f'gene {gene_name} written')

# run the main function
if __name__ == '__main__':
    fetch_factors()

.\lucidrains\tf-bind-transformer\scripts\negative_peak_to_bool_npy.py

#!/usr/bin/python

# 导入必要的库
import polars as pl
import numpy as np
from pathlib import Path
import sys

# 从命令行参数中获取负峰文件路径和行数
NEGATIVE_PEAK_PATH = sys.argv[1]
NUMROWS = int(sys.argv[2])
ID_COLUMN = 'column_6'

# 读取以制表符分隔的无标题负峰文件
df = pl.read_csv(NEGATIVE_PEAK_PATH, sep = '\t', has_headers = False)

# 获取指定列的数据并转换为 NumPy 数组
np_array = df.get_column(ID_COLUMN).to_numpy()

# 创建一个布尔数组,用于标记需要保存的行
to_save = np.full((NUMROWS,), False)
to_save[np_array - 1] = True

# 获取文件路径的 stem 部分,并创建保存布尔数组的文件名
p = Path(NEGATIVE_PEAK_PATH)
filename = f'{p.stem}.bool'

# 将布尔数组保存为 NumPy 文件
np.save(filename, to_save)

# 打印保存文件的信息
print(f'{filename} saved')
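
A hedged follow-up sketch (not part of the original script; the filename is illustrative and mirrors the f'{p.stem}.bool.npy' file that np.save produces): the saved mask has one entry per row of the remap file, True where that row appears in the negative-peak file.

import numpy as np

mask = np.load('my-negatives.bool.npy')   # hypothetical output of the script above
print(mask.shape, int(mask.sum()), 'rows flagged as negatives')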

.\lucidrains\tf-bind-transformer\scripts\remap_to_separate_exp_target_cell_beds.py

# 导入必要的库
import polars as pl
from pathlib import Path
from tf_bind_transformer.data import read_bed, save_bed

# 定义函数,用于生成分离的实验目标细胞类型的 BED 文件
def generate_separate_exp_target_cell_beds(
    remap_file,
    *,
    output_folder = './negative-peaks-per-target',
    exp_target_cell_type_col = 'column_4'
):
    # 将输出文件夹路径转换为 Path 对象,并确保文件夹存在
    output_folder = Path(output_folder)
    output_folder.mkdir(exist_ok = True, parents = True)

    # 读取 remap 文件内容到 DataFrame
    df = read_bed(remap_file)
    # 获取目标实验的唯一值列表
    target_experiments = df.get_column(exp_target_cell_type_col).unique().to_list()

    # 遍历每个目标实验
    for target_experiment in target_experiments:
        # 根据目标实验筛选 DataFrame
        filtered_df = df.filter(pl.col(exp_target_cell_type_col) == target_experiment)

        # 构建目标实验的 BED 文件路径
        target_bed_path = str(output_folder / f'{target_experiment}.bed')
        # 保存筛选后的 DataFrame 到 BED 文件
        save_bed(filtered_df, target_bed_path)

    # 打印成功信息
    print('success')
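
A minimal usage sketch for the function above (the remap path is illustrative):

generate_separate_exp_target_cell_beds(
    './remap2022_all_macs2_hg38_v1_0.bed',            # hypothetical remap peaks file
    output_folder = './negative-peaks-per-target'     # one bed file written per experiment.target.cell_type
)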

.\lucidrains\tf-bind-transformer\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'tf-bind-transformer',  # 包的名称
  packages = find_packages(exclude=[]),  # 查找所有包
  version = '0.0.118',  # 版本号
  license='MIT',  # 许可证
  description = 'Transformer for Transcription Factor Binding',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/tf-bind-transformer',  # 项目链接
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  keywords = [  # 关键词列表
    'artificial intelligence',
    'deep learning',
    'attention mechanism',
    'transformers',
    'transcription factors',
    'gene expression'
  ],
  install_requires=[  # 安装依赖列表
    'bidirectional-cross-attention',
    'biopython',
    'click',
    'einops>=0.3',
    'enformer-pytorch>=0.5',
    'fair-esm',
    'logavgexp-pytorch',
    'polars',
    'python-dotenv',
    'sentencepiece',
    'torch>=1.6',
    'transformers>=4.0',
    'tqdm'
  ],
  classifiers=[  # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\tf-bind-transformer\tf_bind_transformer\attention.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 从 torch 库中导入 einsum 函数
from torch import einsum
# import the BidirectionalCrossAttention class from the bidirectional_cross_attention package
from bidirectional_cross_attention import BidirectionalCrossAttention

# 定义函数,判断变量是否存在
def exists(val):
    return val is not None

# 定义函数,返回默认值
def default(val, d):
    return val if exists(val) else d

# 定义前馈神经网络类
def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim)
    )

# 自注意力机制类
class SelfAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        mask = None,
    ):
        h = self.heads
        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q = q * self.scale

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        if exists(mask):
            mask_value = -torch.finfo(sim.dtype).max
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, mask_value)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 自注意力块类
class SelfAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dropout = 0.,
        ff_mult = 4,
        **kwargs
    ):
        super().__init__()
        self.attn = SelfAttention(dim = dim, dropout = dropout, **kwargs)
        self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout)

    def forward(self, x, mask = None):
        x = self.attn(x, mask = mask) + x
        x = self.ff(x) + x
        return x
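
A quick shape check for the block above (a sketch, not part of the original module; dimensions are arbitrary):

import torch

block = SelfAttentionBlock(dim = 256, heads = 4, dim_head = 32)
x = torch.randn(2, 128, 256)          # (batch, sequence length, dim)
mask = torch.ones(2, 128).bool()      # True marks tokens to keep
out = block(x, mask = mask)           # residual attention + feedforward, shape preserved
print(out.shape)                      # torch.Size([2, 128, 256])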

# cross attention class
class CrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        context_dim = None,
        dropout = 0.
    ):
        super().__init__()
        context_dim = default(context_dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(context_dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        context,
        mask = None,
        context_mask = None
    ):
        h = self.heads

        x = self.norm(x)
        context = self.context_norm(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q = q * self.scale

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        if exists(context_mask):
            mask_value = -torch.finfo(sim.dtype).max
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~context_mask, mask_value)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class JointCrossAttentionBlock(nn.Module):
    # 初始化函数,设置模型参数
    def __init__(
        self,
        *,
        dim,  # 维度
        context_dim = None,  # 上下文维度,默认为None
        ff_mult = 4,  # FeedForward模块的倍数,默认为4
        dropout = 0.,  # dropout概率,默认为0
        **kwargs  # 其他参数
    ):
        super().__init__()  # 调用父类的初始化函数
        context_dim = default(context_dim, dim)  # 如果上下文维度为None,则设置为维度值

        # 创建双向交叉注意力模块
        self.attn = BidirectionalCrossAttention(dim = dim, context_dim = context_dim, dropout = dropout, prenorm = True, **kwargs)
        # 创建FeedForward模块
        self.ff = FeedForward(dim, mult = ff_mult, dropout = dropout)
        # 创建上下文的FeedForward模块
        self.context_ff = FeedForward(context_dim, mult = ff_mult, dropout = dropout)

    # 前向传播函数
    def forward(
        self,
        x,  # 输入数据
        context,  # 上下文数据
        mask = None,  # 掩码,默认为None
        context_mask = None  # 上下文掩码,默认为None
    ):
        # 使用注意力模块处理输入数据和上下文数据
        attn_out, context_attn_out = self.attn(x, context, mask = mask, context_mask = context_mask)

        # 更新输入数据
        x = x + attn_out
        # 更新上下文数据
        context = context + context_attn_out

        # 使用FeedForward模块处理输入数据
        x = self.ff(x) + x
        # 使用上下文的FeedForward模块处理上下文数据
        context = self.context_ff(context) + context

        # 返回更新后的输入数据和上下文数据
        return x, context
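
And a corresponding sketch for the joint block (assumes the bidirectional-cross-attention package is installed; sizes are arbitrary):

import torch

joint_block = JointCrossAttentionBlock(dim = 256, context_dim = 384)
seq_tokens = torch.randn(1, 512, 256)           # e.g. DNA track embeddings
protein_tokens = torch.randn(1, 100, 384)       # e.g. protein residue embeddings
seq_tokens, protein_tokens = joint_block(seq_tokens, protein_tokens)
print(seq_tokens.shape, protein_tokens.shape)   # both input shapes are preserved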

.\lucidrains\tf-bind-transformer\tf_bind_transformer\cache_utils.py

# 导入必要的库
import os
from shutil import rmtree
import torch
import hashlib
from functools import wraps
from pathlib import Path

# 检查值是否存在的辅助函数
def exists(val):
    return val is not None

# 常量定义

# 设置缓存路径,默认为用户主目录下的.cache.tf.bind.transformer文件夹
CACHE_PATH = Path(os.getenv('TF_BIND_CACHE_PATH', os.path.expanduser('~/.cache.tf.bind.transformer')))
# 如果缓存路径不存在,则创建
CACHE_PATH.mkdir(exist_ok=True, parents=True)

# 检查是否需要清除缓存
CLEAR_CACHE = exists(os.getenv('CLEAR_CACHE', None))
# 检查是否需要输出详细信息
VERBOSE = exists(os.getenv('VERBOSE', None))

# 日志输出函数
def log(s):
    if not VERBOSE:
        return
    print(s)

# 计算字符串的 MD5 哈希值
def md5_hash_fn(s):
    encoded = s.encode('utf-8')
    return hashlib.md5(encoded).hexdigest()

# 仅运行一次的函数

# 全局运行记录字典
GLOBAL_RUN_RECORDS = dict()

# 仅运行一次的装饰器函数
def run_once(global_id=None):
    def outer(fn):
        has_ran_local = False
        output = None

        @wraps(fn)
        def inner(*args, **kwargs):
            nonlocal has_ran_local
            nonlocal output

            has_ran = GLOBAL_RUN_RECORDS.get(global_id, False) if exists(global_id) else has_ran_local

            if has_ran:
                return output

            output = fn(*args, **kwargs)

            if exists(global_id):
                GLOBAL_RUN_RECORDS[global_id] = True

            has_ran = True
            return output

        return inner
    return outer
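
A small demonstration of the intended behavior (a sketch; the function and global id below are hypothetical): the wrapped function runs once, and every later call returns the memoized output.

@run_once('demo_setup')
def expensive_setup():
    print('running setup')
    return 42

expensive_setup()   # prints 'running setup' and returns 42
expensive_setup()   # returns 42 again without re-running the body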

# 缓存函数

# 缓存函数的装饰器
def cache_fn(
    fn,
    path='',
    hash_fn=md5_hash_fn,
    clear=False or CLEAR_CACHE,
    should_cache=True
):
    if not should_cache:
        return fn

    # 创建缓存路径
    (CACHE_PATH / path).mkdir(parents=True, exist_ok=True)

    # 清除缓存文件夹的函数
    @run_once(path)
    def clear_cache_folder_():
        cache_path = rmtree(str(CACHE_PATH / path))
        (CACHE_PATH / path).mkdir(parents=True, exist_ok=True)

    @wraps(fn)
    def inner(t, *args, __cache_key=None, **kwargs):
        if clear:
            clear_cache_folder_()

        cache_str = __cache_key if exists(__cache_key) else t
        key = hash_fn(cache_str)

        entry_path = CACHE_PATH / path / f'{key}.pt'

        if entry_path.exists():
            log(f'cache hit: fetching {t} from {str(entry_path)}')
            return torch.load(str(entry_path))

        out = fn(t, *args, **kwargs)

        log(f'saving: {t} to {str(entry_path)}')
        torch.save(out, str(entry_path))
        return out
    return inner
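
A hedged usage sketch (the wrapped function and cache sub-path are hypothetical): the first call computes the value and writes a .pt file under CACHE_PATH / 'demo', and later calls with the same string load it back from disk.

def embed_string(s):
    return torch.randn(4)                 # stand-in for an expensive embedding computation

cached_embed = cache_fn(embed_string, path = 'demo')
a = cached_embed('hello world')           # computed, then saved with torch.save
b = cached_embed('hello world')           # cache hit, loaded with torch.load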

.\lucidrains\tf-bind-transformer\tf_bind_transformer\context_utils.py

# 导入所需的库
import torch
import os
import logging
from transformers import AutoTokenizer, AutoModelForMaskedLM, logging
from tf_bind_transformer.cache_utils import cache_fn, run_once

# 设置日志级别为错误
logging.set_verbosity_error()

# 检查值是否存在
def exists(val):
    return val is not None

# 对字典中的值应用函数
def map_values(fn, dictionary):
    return {k: fn(v) for k, v in dictionary.items()}

# 检查是否在环境变量中设置了使用 CPU 进行上下文嵌入
CONTEXT_EMBED_USE_CPU = os.getenv('CONTEXT_EMBED_USE_CPU', None) is not None

# 如果设置了使用 CPU 进行上下文嵌入,则打印提示信息
if CONTEXT_EMBED_USE_CPU:
    print('calculating context embed only on cpu')

# 预定义模型的维度和路径
MODELS = dict(
    pubmed = dict(
        dim = 768,
        path = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',
    )
)

# 全局变量,用于存储模型和分词器
GLOBAL_VARIABLES = dict(model = None, tokenizer = None)

# 获取指定模型的上下文维度
def get_contextual_dim(model_name):
    assert model_name in MODELS
    return MODELS[model_name]['dim']

# 初始化模型和分词器,只运行一次
@run_once('init_transformer')
def init_transformer(model_name):
    path = MODELS[model_name]['path']
    GLOBAL_VARIABLES['tokenizer'] = AutoTokenizer.from_pretrained(path)

    model = AutoModelForMaskedLM.from_pretrained(path)

    # 如果未设置使用 CPU 进行上下文嵌入,则将模型移至 GPU
    if not CONTEXT_EMBED_USE_CPU:
        model = model.cuda()

    GLOBAL_VARIABLES['model'] = model

# 对文本进行分词和编码
@torch.no_grad()
def tokenize_text(
    text,
    max_length = 256,
    model_name = 'pubmed',
    hidden_state_index = -1,
    return_cls_token = True
):
    init_transformer(model_name)

    model = GLOBAL_VARIABLES['model']
    tokenizer = GLOBAL_VARIABLES['tokenizer']

    encoding = tokenizer.batch_encode_plus(
        [text],
        add_special_tokens = True,
        padding = True,
        truncation = True,
        max_length = max_length,
        return_attention_mask = True,
        return_tensors = 'pt'
    )

    # 如果未设置使用 CPU 进行上下文嵌入,则将编码移至 GPU
    if not CONTEXT_EMBED_USE_CPU:
        encoding = map_values(lambda t: t.cuda(), encoding)

    model.eval()
    with torch.no_grad():
        outputs = model(**encoding, output_hidden_states = True)

    hidden_state = outputs.hidden_states[hidden_state_index][0]

    if return_cls_token:
        return hidden_state[0]

    return hidden_state.mean(dim = 0)

# 获取文本表示
def get_text_repr(
    texts,
    *,
    device,
    max_length = 256,
    model_name = 'pubmed',
    hidden_state_index = -1,
    return_cls_token = True,
):
    assert model_name in MODELS, f'{model_name} not found in available text transformers to use'

    # 如果输入为字符串,则转换为列表
    if isinstance(texts, str):
        texts = [texts]

    # 缓存文本表示函数
    get_context_repr_fn = cache_fn(tokenize_text, path = f'contexts/{model_name}')

    # 获取文本的表示
    representations = [get_context_repr_fn(text, max_length = max_length, model_name = model_name, hidden_state_index = hidden_state_index, return_cls_token = return_cls_token) for text in texts]

    return torch.stack(representations).to(device)
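
A usage sketch for the function above (the cell-type descriptions are illustrative; the PubMedBERT checkpoint is downloaded on first use, and setting CONTEXT_EMBED_USE_CPU keeps the encoder on cpu):

embeds = get_text_repr(
    ['liver hepatocellular carcinoma cell line', 'embryonic stem cell'],
    device = torch.device('cpu')
)
print(embeds.shape)   # torch.Size([2, 768]) - one CLS embedding per input string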

.\lucidrains\tf-bind-transformer\tf_bind_transformer\data.py

# 导入所需的模块
from Bio import SeqIO
from random import choice, randrange
from pathlib import Path
import functools
import polars as pl
from collections import defaultdict

import os
import json
import shutil
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from tf_bind_transformer.gene_utils import parse_gene_name
from enformer_pytorch import FastaInterval

from pyfaidx import Fasta
import pybedtools

# 定义函数判断值是否存在
def exists(val):
    return val is not None

# 定义函数返回默认值
def default(val, d):
    return val if exists(val) else d

# 定义函数查找满足条件的第一个元素的索引
def find_first_index(cond, arr):
    for ind, el in enumerate(arr):
        if cond(el):
            return ind
    return -1

# 定义函数将值转换为列表
def cast_list(val = None):
    if not exists(val):
        return []
    return [val] if not isinstance(val, (tuple, list)) else val

# 读取 BED 文件并返回 Polars 数据框
def read_bed(path):
    return pl.read_csv(path, sep = '\t', has_headers = False)

# 将 Polars 数据框保存为 BED 文件
def save_bed(df, path):
    df.to_csv(path, sep = '\t', has_header = False)

# 解析实验、目标和细胞类型
def parse_exp_target_cell(exp_target_cell):
    experiment, target, *cell_type = exp_target_cell.split('.')
    cell_type = '.'.join(cell_type) # 处理细胞类型包含句点的情况
    return experiment, target, cell_type
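
For example (the identifier below is illustrative), an exp-target-cell string splits on the first two periods, with any remaining periods kept inside the cell type:

experiment, target, cell_type = parse_exp_target_cell('ENCSR000ABC.GATA4.Hep-G2.rep1')
# experiment == 'ENCSR000ABC', target == 'GATA4', cell_type == 'Hep-G2.rep1'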

# 获取数据集的索引,用于提供辅助读取值预测的测序 reads
def fetch_experiments_index(path):
    if not exists(path):
        return dict()

    exp_path = Path(path)
    assert exp_path.exists(), 'path to experiments json must exist'

    root_json = json.loads(exp_path.read_text())
    experiments = root_json['experiments']

    index = {}
    for experiment in experiments:
        exp_id = experiment['accession']

        if 'details' not in experiment:
            continue

        details = experiment['details']

        if 'datasets' not in details:
            continue

        datasets = details['datasets']

        for dataset in datasets:
            dataset_name = dataset['dataset_name']
            index[dataset_name] = dataset['peaks_NR']

    return index

# 根据基因名和 Uniprot ID 获取蛋白质序列
class FactorProteinDatasetByUniprotID(Dataset):
    def __init__(
        self,
        folder,
        species_priority = ['human', 'mouse']
    ):
        super().__init__()
        fasta_paths = [*Path(folder).glob('*.fasta')]
        assert len(fasta_paths) > 0, f'no fasta files found at {folder}'
        self.paths = fasta_paths
        self.index_by_id = dict()

        for path in fasta_paths:
            gene, uniprotid, *_ = path.stem.split('.')
            self.index_by_id[uniprotid] = path

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, uid):
        index = self.index_by_id

        if uid not in index:
            return None

        entry = index[uid]
        fasta = SeqIO.read(entry, 'fasta')
        return str(fasta.seq)

# 获取蛋白质数据集
class FactorProteinDataset(Dataset):
    def __init__(
        self,
        folder,
        species_priority = ['human', 'mouse', 'unknown'],
        return_tuple_only = False
    ):
        # the constructor takes the folder of factor fasta files
        super().__init__()
        # 获取文件夹中所有以 .fasta 结尾的文件路径
        fasta_paths = [*Path(folder).glob('*.fasta')]
        # 断言至少找到一个 .fasta 文件,否则抛出异常
        assert len(fasta_paths) > 0, f'no fasta files found at {folder}'
        # 将找到的文件路径列表保存在 self.paths 中

        self.paths = fasta_paths

        # 使用 defaultdict 创建一个以基因名为键,文件路径列表为值的字典
        index_by_gene = defaultdict(list)
        # 是否只返回元组,即使只有一个亚单位
        self.return_tuple_only = return_tuple_only 

        # 遍历每个 .fasta 文件路径
        for path in fasta_paths:
            # 从文件名中提取基因名和 Uniprot ID
            gene, uniprotid, *_ = path.stem.split('.')
            # 将文件路径添加到对应基因名的列表中
            index_by_gene[gene].append(path)

        # 用于从文件路径中提取物种信息的 lambda 函数
        get_species_from_path = lambda p: p.stem.split('_')[-1].lower() if '_' in p.stem else 'unknown'

        # 使用 defaultdict 创建一个以基因名为键,经过物种筛选后的文件路径列表为值的字典
        filtered_index_by_gene = defaultdict(list)

        # 遍历每个基因及其对应的文件路径列表
        for gene, gene_paths in index_by_gene.items():
            # 计算每个物种在该基因下的文件数量
            species_count = list(map(lambda specie: len(list(filter(lambda p: get_species_from_path(p) == specie, gene_paths))), species_priority))
            # 找到第一个文件数量不为零的物种索引
            species_ind_non_zero = find_first_index(lambda t: t > 0, species_count)

            # 如果没有找到文件数量不为零的物种索引,则跳过该基因
            if species_ind_non_zero == -1:
                continue

            # 获取该基因下文件数量不为零的物种
            species = species_priority[species_ind_non_zero]
            # 将该基因下属于指定物种的文件路径添加到筛选后的字典中
            filtered_index_by_gene[gene] = list(filter(lambda p: get_species_from_path(p) == species, gene_paths))

        # 将筛选后的字典保存在 self.index_by_gene 中

        self.index_by_gene = filtered_index_by_gene

    # 返回文件路径列表的长度
    def __len__(self):
        return len(self.paths)

    # 根据未解析的基因名获取对应的序列
    def __getitem__(self, unparsed_gene_name):
        # 获取基因名对应的文件路径列表
        index = self.index_by_gene

        # 解析基因名
        genes = parse_gene_name(unparsed_gene_name)
        seqs = []

        # 遍历每个基因
        for gene in genes:
            entry = index[gene]

            # 如果该基因没有对应的文件路径,则打印提示信息并继续下一个基因
            if len(entry) == 0:
                print(f'no entries for {gene}')
                continue

            # 从文件路径列表中随机选择一个文件路径
            path = choice(entry) if isinstance(entry, list) else entry

            # 读取 fasta 文件中的序列
            fasta = SeqIO.read(path, 'fasta')
            seqs.append(str(fasta.seq))

        # 将序列列表转换为元组
        seqs = tuple(seqs)

        # 如果只有一个序列且不要求返回元组,则返回该序列
        if len(seqs) == 1 and not self.return_tuple_only:
            return seqs[0]

        # 否则返回序列元组
        return seqs
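
A hypothetical usage sketch, assuming a folder of fasta files named like GENE.UNIPROTID.fasta, as produced by the fetch_factor_fastas.py script above:

factor_ds = FactorProteinDataset('./tfactor.fastas')   # illustrative folder path
aa_seq = factor_ds['GATA4']                            # gene symbol -> amino acid sequence string
print(len(aa_seq))                                     # multi-subunit gene names return a tuple of sequences instead
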
# 重新映射数据框函数

# 获取染色体名称集合
def get_chr_names(ids):
    return set(map(lambda t: f'chr{t}', ids))

# 定义染色体编号集合和染色体名称集合
CHR_IDS = set([*range(1, 23), 'X'])
CHR_NAMES = get_chr_names(CHR_IDS)

# 重新映射数据框并添加实验、目标和细胞类型信息
def remap_df_add_experiment_target_cell(df, col = 'column_4'):
    df = df.clone()

    # 提取实验信息
    exp_id = df.select([pl.col(col).str.extract(r"^([\w\-]+)\.*")])
    exp_id = exp_id.rename({col: 'experiment'}).to_series(0)
    df.insert_at_idx(3, exp_id)

    # 提取目标信息
    targets = df.select([pl.col(col).str.extract(r"[\w\-]+\.([\w\-]+)\.[\w\-]+")])
    targets = targets.rename({col: 'target'}).to_series(0)
    df.insert_at_idx(3, targets)

    # 提取细胞类型信息
    cell_type = df.select([pl.col(col).str.extract(r"^.*\.([\w\-]+)$")])
    cell_type = cell_type.rename({col: 'cell_type'}).to_series(0)
    df.insert_at_idx(3, cell_type)

    return df

# 判断列中元素是否在数组中
def pl_isin(col, arr):
    equalities = list(map(lambda t: pl.col(col) == t, arr))
    return functools.reduce(lambda a, b: a | b, equalities)

# 判断列中元素是否不在数组中
def pl_notin(col, arr):
    equalities = list(map(lambda t: pl.col(col) != t, arr))
    return functools.reduce(lambda a, b: a & b, equalities)

# 根据列中元素是否在数组中进行过滤数据框
def filter_by_col_isin(df, col, arr, chunk_size = 25):
    """
    polars appears to have a bug
    where it freezes when there are more than 25 OR conditions (in pl_isin)
    so split the filter into chunks of 25 and concatenate the results
    """
    dataframes = []
    for i in range(0, len(arr), chunk_size):
        sub_arr = arr[i:(i + chunk_size)]
        filtered_df = df.filter(pl_isin(col, sub_arr))
        dataframes.append(filtered_df)
    return pl.concat(dataframes)

# 根据 BED 文件进行过滤
def filter_bed_file_by_(bed_file_1, bed_file_2, output_file):
    # 由 OpenAI Codex 生成

    bed_file_1_bedtool = pybedtools.BedTool(bed_file_1)
    bed_file_2_bedtool = pybedtools.BedTool(bed_file_2)
    bed_file_1_bedtool_intersect_bed_file_2_bedtool = bed_file_1_bedtool.intersect(bed_file_2_bedtool, v = True)
    bed_file_1_bedtool_intersect_bed_file_2_bedtool.saveas(output_file)

# 根据 TF 蛋白质序列文件进行过滤数据框
def filter_df_by_tfactor_fastas(df, folder):
    files = [*Path(folder).glob('**/*.fasta')]
    present_target_names = set([f.stem.split('.')[0] for f in files])
    all_df_targets = df.get_column('target').unique().to_list()

    all_df_targets_with_parsed_name = [(target, parse_gene_name(target)) for target in all_df_targets]
    unknown_targets = [target for target, parsed_target_name in all_df_targets_with_parsed_name for parsed_target_name_sub_el in parsed_target_name if parsed_target_name_sub_el not in present_target_names]

    if len(unknown_targets) > 0:
        df = df.filter(pl_notin('target', unknown_targets))
    return df

# 从 FASTA 文件生成随机范围
def generate_random_ranges_from_fasta(
    fasta_file,
    *,
    output_filename = 'random-ranges.bed',
    context_length,
    filter_bed_files = [],
    num_entries_per_key = 10,
    keys = None,
):
    fasta = Fasta(fasta_file)
    tmp_file = f'/tmp/{output_filename}'

    with open(tmp_file, 'w') as f:
        for chr_name in sorted(CHR_NAMES):
            print(f'generating ranges for {chr_name}')

            if chr_name not in fasta:
                print(f'{chr_name} not found in fasta file')
                continue

            chromosome = fasta[chr_name]
            chromosome_length = len(chromosome)

            start = np.random.randint(0, chromosome_length - context_length, (num_entries_per_key,))
            end = start + context_length
            start_and_end = np.stack((start, end), axis = -1)

            for row in start_and_end.tolist():
                start, end = row
                f.write('\t'.join((chr_name, str(start), str(end))) + '\n')

    for file in filter_bed_files:
        filter_bed_file_by_(tmp_file, file, tmp_file)

    shutil.move(tmp_file, f'./{output_filename}')

    print('success')

# 上下文字符串创建类

class ContextDataset(Dataset):
    def __init__(
        self,
        biotypes_metadata_path = None,
        include_biotypes_metadata_in_context = False,
        include_biotypes_metadata_columns = [],
        biotypes_metadata_delimiter = ' | ',
    ):
        # store whether biotype metadata should be folded into the context string,
        # along with which metadata columns to include and the delimiter used to join them
        self.include_biotypes_metadata_in_context = include_biotypes_metadata_in_context
        self.include_biotypes_metadata_columns = include_biotypes_metadata_columns
        self.biotypes_metadata_delimiter = biotypes_metadata_delimiter

        # 如果要在上下文中包含生物类型元数据
        if include_biotypes_metadata_in_context:
            # 确保要包含的生物类型元数据列数大于0
            assert len(self.include_biotypes_metadata_columns) > 0, 'must have more than one biotype metadata column to include'
            # 确保生物类型元数据路径存在
            assert exists(biotypes_metadata_path), 'biotypes metadata path must be supplied if to be included in context string'

            # 创建路径对象
            p = Path(biotypes_metadata_path)

            # 根据文件后缀选择分隔符
            if p.suffix == '.csv':
                sep = ','
            elif p.suffix == '.tsv':
                sep = '\t'
            else:
                raise ValueError(f'invalid suffix {p.suffix} for biotypes')

            # 读取CSV或TSV文件并存储为DataFrame
            self.df = pl.read_csv(str(p), sep = sep)

    # 返回DataFrame的长度或-1(如果不包含生物类型元数据)
    def __len__():
        return len(self.df) if self.include_biotypes_metadata_in_context else -1

    # 根据生物类型获取上下文字符串
    def __getitem__(self, biotype):
        # 如果不包含生物类型元数据,直接返回生物类型
        if not self.include_biotypes_metadata_in_context:
            return biotype

        # 获取要包含的生物类型元数据列的索引
        col_indices = list(map(self.df.columns.index, self.include_biotypes_metadata_columns))
        # 根据生物类型筛选DataFrame
        filtered = self.df.filter(pl.col('biotype') == biotype)

        # 如果没有找到匹配的行,打印消息并返回生物类型
        if len(filtered) == 0:
            print(f'no rows found for {biotype} in biotype metadata file')
            return biotype

        # 获取匹配行的数据
        row = filtered.row(0)
        # 获取要包含的列的值
        columns = list(map(lambda t: row[t], col_indices))

        # 组合上下文字符串
        context_string = self.biotypes_metadata_delimiter.join([biotype, *columns])
        return context_string
# 定义一个用于重新映射数据的数据集类 RemapAllPeakDataset,继承自 Dataset 类
class RemapAllPeakDataset(Dataset):
    # 初始化函数,接收多个参数
    def __init__(
        self,
        *,
        factor_fasta_folder,  # 因子 fasta 文件夹
        bed_file = None,  # bed 文件,默认为 None
        remap_df = None,  # 重新映射数据框,默认为 None
        filter_chromosome_ids = None,  # 过滤染色体 ID,默认为 None
        exclude_targets = None,  # 排除目标,默认为 None
        include_targets = None,  # 包含目标,默认为 None
        exclude_cell_types = None,  # 排除细胞类型,默认为 None
        include_cell_types = None,  # 包含细胞类型,默认为 None
        remap_df_frac = 1.,  # 重新映射数据框比例,默认为 1
        experiments_json_path = None,  # 实验 JSON 路径,默认为 None
        include_biotypes_metadata_in_context = False,  # 在上下文中包含生物类型元数据,默认为 False
        biotypes_metadata_path = None,  # 生物类型元数据路径,默认为 None
        include_biotypes_metadata_columns = [],  # 包含生物类型元数据列,默认为空列表
        biotypes_metadata_delimiter = ' | ',  # 生物类型元数据分隔符,默认为 ' | '
        balance_sampling_by_target = False,  # 按目标平衡采样,默认为 False
        **kwargs  # 其他关键字参数
    ):
        super().__init__()  # 调用父类的初始化函数
        assert exists(remap_df) ^ exists(bed_file), 'either remap bed file or remap dataframe must be passed in'

        if not exists(remap_df):
            remap_df = read_bed(bed_file)  # 读取 bed 文件并赋值给 remap_df

        if remap_df_frac < 1:
            remap_df = remap_df.sample(frac = remap_df_frac)  # 如果 remap_df_frac 小于 1,则对 remap_df 进行采样

        dataset_chr_ids = CHR_IDS  # 数据集染色体 ID

        if exists(filter_chromosome_ids):
            dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids))  # 如果存在过滤染色体 ID,则取交集

        remap_df = remap_df.filter(pl_isin('column_1', get_chr_names(dataset_chr_ids)))  # 过滤 remap_df 中染色体名称
        remap_df = filter_df_by_tfactor_fastas(remap_df, factor_fasta_folder)  # 根据因子 fasta 文件夹过滤 remap_df

        self.factor_ds = FactorProteinDataset(factor_fasta_folder)  # 初始化因子蛋白数据集

        # 根据包含和排除目标列表过滤数据集
        # (<所有可用目标> 交集 <包含目标>) 减去 <排除目标>

        include_targets = cast_list(include_targets)  # 将包含目标转换为列表
        exclude_targets = cast_list(exclude_targets)  # 将排除目标转换为列表

        if include_targets:
            remap_df = remap_df.filter(pl_isin('target', include_targets))  # 如果包含目标非空,则过滤 remap_df

        if exclude_targets:
            remap_df = remap_df.filter(pl_notin('target', exclude_targets))  # 如果排除目标非空,则过滤 remap_df

        # 根据包含和排除细胞类型列表过滤数据集
        # 与目标相同的逻辑

        include_cell_types = cast_list(include_cell_types)  # 将包含细胞类型转换为列表
        exclude_cell_types = cast_list(exclude_cell_types)  # 将排除细胞类型转换为列表

        if include_cell_types:
            remap_df = remap_df.filter(pl_isin('cell_type', include_cell_types))  # 如果包含细胞类型非空,则过滤 remap_df

        if exclude_cell_types:
            remap_df = remap_df.filter(pl_notin('cell_type', exclude_cell_types))  # 如果排除细胞类型非空,则过滤 remap_df

        assert len(remap_df) > 0, 'dataset is empty by filter criteria'  # 断言数据集不为空

        self.df = remap_df  # 将过滤后的数据集赋值给 self.df
        self.fasta = FastaInterval(**kwargs)  # 初始化 FastaInterval 对象

        self.experiments_index = fetch_experiments_index(experiments_json_path)  # 获取实验索引

        # 平衡目标采样逻辑

        self.balance_sampling_by_target = balance_sampling_by_target  # 平衡目标采样标志

        if self.balance_sampling_by_target:
            self.df_indexed_by_target = []  # 初始化按目标索引的数据集列表

            for target in self.df.get_column('target').unique().to_list():
                df_by_target = self.df.filter(pl.col('target') == target)  # 根据目标过滤数据集
                self.df_indexed_by_target.append(df_by_target)  # 将按目标过滤后的数据集添加到列表中

        # 上下文字符串创建器

        self.context_ds = ContextDataset(
            include_biotypes_metadata_in_context = include_biotypes_metadata_in_context,  # 是否在上下文中包含生物类型元数据
            biotypes_metadata_path = biotypes_metadata_path,  # 生物类型元数据路径
            include_biotypes_metadata_columns = include_biotypes_metadata_columns,  # 包含生物类型元数据列
            biotypes_metadata_delimiter = biotypes_metadata_delimiter  # 生物类型元数据分隔符
        )

    # 返回数据集长度
    def __len__(self):
        if self.balance_sampling_by_target:
            return len(self.df_indexed_by_target)  # 如果按目标平衡采样,则返回按目标索引的数据集长度
        else:
            return len(self.df)  # 否则返回数据集长度
    # 定义特殊方法,用于通过索引获取数据样本
    def __getitem__(self, ind):
        # 如果按目标平衡采样,则从索引数据帧中随机抽取样本
        if self.balance_sampling_by_target:
            # 从按目标索引的数据帧中筛选数据
            filtered_df = self.df_indexed_by_target[ind]
            # 随机选择索引
            rand_ind = randrange(0, len(filtered_df))
            # 获取随机样本
            sample = filtered_df.row(rand_ind)
        else:
            # 否则直接从数据帧中获取样本
            sample = self.df.row(ind)

        # 解包样本数据
        chr_name, begin, end, _, _, _, experiment_target_cell_type, reading, *_ = sample

        # 解析实验、目标和细胞类型
        experiment, target, cell_type = parse_exp_target_cell(experiment_target_cell_type)

        # 获取序列数据
        seq = self.fasta(chr_name, begin, end)
        # 获取氨基酸序列数据
        aa_seq = self.factor_ds[target]
        # 获取上下文字符串数据
        context_str = self.context_ds[cell_type]

        # 将读数转换为张量
        read_value = torch.Tensor([reading])

        # 获取峰值数量
        peaks_nr = self.experiments_index.get(experiment_target_cell_type, 0.)
        # 将峰值数量转换为张量
        peaks_nr = torch.Tensor([peaks_nr])

        # 创建标签张量
        label = torch.Tensor([1.])

        # 返回序列数据、氨基酸序列数据、上下文字符串数据、峰值数量、读数值和标签
        return seq, aa_seq, context_str, peaks_nr, read_value, label
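
A hypothetical construction of this dataset (paths are illustrative; the remap dataframe must already carry the experiment/target/cell_type columns added by remap_df_add_experiment_target_cell defined earlier in this file):

remap_df = remap_df_add_experiment_target_cell(read_bed('./remap2022_all_macs2_hg38_v1_0.bed'))

ds = RemapAllPeakDataset(
    factor_fasta_folder = './tfactor.fastas',
    remap_df = remap_df,
    fasta_file = './hg38.ml.fa',          # forwarded to FastaInterval through **kwargs
    filter_chromosome_ids = [1, 2, 3]
)
seq, aa_seq, context_str, peaks_nr, read_value, label = ds[0]
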
# 为基于保留值的 exp-target-cells 过滤函数

def filter_exp_target_cell(
    arr,
    *,
    exclude_targets = None,  # 排除的目标列表
    include_targets = None,  # 包含的目标列表
    exclude_cell_types = None,  # 排除的细胞类型列表
    include_cell_types = None,  # 包含的细胞类型列表
):
    out = []  # 输出列表

    for el in arr:  # 遍历输入数组
        experiment, target, cell_type = parse_exp_target_cell(el)  # 解析实验、目标和细胞类型

        # 如果包含的目标列表存在且不为空,并且目标不在包含的目标列表中,则跳过
        if exists(include_targets) and len(include_targets) > 0 and target not in include_targets:
            continue

        # 如果排除的目标列表存在且目标在排除的目标列表中,则跳过
        if exists(exclude_targets) and target in exclude_targets:
            continue

        # 如果包含的细胞类型列表存在且不为空,并且细胞类型不在包含的细胞类型列表中,则跳过
        if exists(include_cell_types) and len(include_cell_types) > 0 and cell_type not in include_cell_types:
            continue

        # 如果排除的细胞类型列表存在且细胞类型在排除的细胞类型列表中,则跳过
        if exists(exclude_cell_types) and cell_type in exclude_cell_types:
            continue

        out.append(el)  # 将符合条件的元素添加到输出列表中

    return out  # 返回输出列表


# 为特定 exp-target-celltype 范围的负样本数据集

class ScopedNegativePeakDataset(Dataset):
    def __init__(
        self,
        *,
        fasta_file,
        factor_fasta_folder,
        numpy_folder_with_scoped_negatives,
        exts = '.bed.bool.npy',
        remap_bed_file = None,
        remap_df = None,
        filter_chromosome_ids = None,
        experiments_json_path = None,
        exclude_targets = None,  # 排除的目标列表
        include_targets = None,  # 包含的目标列表
        exclude_cell_types = None,  # 排除的细胞类型列表
        include_cell_types = None,  # 包含的细胞类型列表
        include_biotypes_metadata_in_context = False,
        biotypes_metadata_path = None,
        include_biotypes_metadata_columns = [],
        biotypes_metadata_delimiter = ' | ',
        balance_sampling_by_target = False,
        **kwargs
    ):
        # the constructor accepts either remap_df or remap_bed_file
        super().__init__()
        # 断言只能传入 remap_df 或 remap_bed_file 中的一个
        assert exists(remap_df) ^ exists(remap_bed_file), 'either remap bed file or remap dataframe must be passed in'

        # 如果 remap_df 不存在,则从 remap_bed_file 中读取数据
        if not exists(remap_df):
            remap_df = read_bed(remap_bed_file)

        # 初始化 dataset_chr_ids 为全局变量 CHR_IDS
        dataset_chr_ids = CHR_IDS

        # 如果存在 filter_chromosome_ids,则更新 dataset_chr_ids
        if exists(filter_chromosome_ids):
            dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids))

        # 根据 dataset_chr_ids 过滤 remap_df,生成 mask
        filter_map_df = remap_df.with_column(pl.when(pl_isin('column_1', get_chr_names(dataset_chr_ids))).then(True).otherwise(False).alias('mask'))
        mask = filter_map_df.get_column('mask').to_numpy()

        # 统计 mask 中为 True 的数量
        num_scoped_negs = mask.sum()

        # 打印找到的 scoped negative 行数
        print(f'{num_scoped_negs} scoped negative rows found for training')

        # 断言找到的 scoped negative 行数大于 0
        assert num_scoped_negs > 0, 'all remap rows filtered out for scoped negative peak dataset'

        # 设置 self.df 和 self.chromosome_mask
        self.df = remap_df
        self.chromosome_mask = mask

        # 获取 exp-target-cell 到布尔值 numpy 的字典,指示哪些是负样本

        # 获取所有 numpy 文件的路径
        npys_paths = [*Path(numpy_folder_with_scoped_negatives).glob('**/*.npy')]
        exp_target_cell_negatives = [(path.name.rstrip(exts), path) for path in npys_paths]

        # 获取所有 exp_target_cells
        exp_target_cells = [el[0] for el in exp_target_cell_negatives]

        # 根据条件过滤 exp_target_cells
        exp_target_cells = filter_exp_target_cell(
            exp_target_cells,
            include_targets = include_targets,
            exclude_targets = exclude_targets,
            include_cell_types = include_cell_types,
            exclude_cell_types = exclude_cell_types
        )

        # 根据过滤后的 exp_target_cells 过滤 exp_target_cell_negatives
        filtered_exp_target_cell_negatives = list(filter(lambda el: el[0] in exp_target_cells, exp_target_cell_negatives))

        # 设置 self.exp_target_cell_negatives
        self.exp_target_cell_negatives = filtered_exp_target_cell_negatives
        # 断言筛选后的 exp_target_cell_negatives 数量大于 0
        assert len(self.exp_target_cell_negatives) > 0, 'no experiment-target-cell scoped negatives to select from after filtering'

        # 平衡目标采样

        self.balance_sampling_by_target = balance_sampling_by_target

        # 如果需要平衡采样
        if balance_sampling_by_target:
            # 初始化 exp_target_cell_by_target 字典
            self.exp_target_cell_by_target = defaultdict(list)

            # 根据 target 对 exp_target_cell_negatives 进行分组
            for exp_target_cell, filepath in self.exp_target_cell_negatives:
                _, target, *_ = parse_exp_target_cell(exp_target_cell)
                self.exp_target_cell_by_target[target].append((exp_target_cell, filepath))

        # tfactor 数据集

        self.factor_ds = FactorProteinDataset(factor_fasta_folder)

        # 初始化 fasta 文件和 experiments_index
        self.fasta = FastaInterval(fasta_file = fasta_file, **kwargs)
        self.experiments_index = fetch_experiments_index(experiments_json_path)

        # 上下文字符串创建器

        self.context_ds = ContextDataset(
            include_biotypes_metadata_in_context = include_biotypes_metadata_in_context,
            biotypes_metadata_path = biotypes_metadata_path,
            include_biotypes_metadata_columns = include_biotypes_metadata_columns,
            biotypes_metadata_delimiter = biotypes_metadata_delimiter
        )

    # 返回数据集长度
    def __len__(self):
        # 如果需要按目标平衡采样,则返回 exp_target_cell_by_target 的长度
        if self.balance_sampling_by_target:
            return len(self.exp_target_cell_by_target)
        # 否则返回 exp_target_cell_negatives 的长度
        else:
            return len(self.exp_target_cell_negatives)
    # 通过索引获取样本数据
    def __getitem__(self, idx):
        # 如果按目标进行平衡采样
        if self.balance_sampling_by_target:
            # 获取指定索引下的负样本列表
            negatives = list(self.exp_target_cell_by_target.values())[idx]
            # 从负样本列表中随机选择一个样本
            sample = choice(negatives)
        else:
            # 获取指定索引下的负样本
            sample = self.exp_target_cell_negatives[idx]

        # 解析实验、目标和细胞类型
        exp_target_cell, bool_numpy_path = sample
        experiment, target, cell_type = parse_exp_target_cell(exp_target_cell)

        # 加载布尔类型的 numpy 数组,并添加随机噪声
        np_arr = np.load(str(bool_numpy_path))
        np_arr_noised = np_arr.astype(np.float32) + np.random.uniform(low=-1e-1, high=1e-1, size=np_arr.shape[0])

        # 使用染色体掩码进行掩盖
        np_arr_noised *= self.chromosome_mask.astype(np.float32)

        # 选择随机的负峰值
        random_neg_peak_index = np_arr_noised.argmax()

        # 获取染色体名称、起始位置、结束位置和序列
        chr_name, begin, end, *_ = self.df.row(random_neg_peak_index)
        seq = self.fasta(chr_name, begin, end)

        # 获取目标对应的氨基酸序列和细胞类型对应的上下文字符串
        aa_seq = self.factor_ds[target]
        context_str = self.context_ds[cell_type]

        # 获取实验目标细胞对应的峰值数量,并转换为张量
        peaks_nr = self.experiments_index.get(exp_target_cell, 0.)
        peaks_nr = torch.Tensor([peaks_nr])

        # 初始化读取值和标签,并转换为张量
        read_value = torch.Tensor([0.])
        label = torch.Tensor([0.])

        # 返回序列、氨基酸序列、上下文字符串、峰值数量、读取值和标签
        return seq, aa_seq, context_str, peaks_nr, read_value, label
# 定义一个负样本数据集类 NegativePeakDataset,继承自 Dataset 类
class NegativePeakDataset(Dataset):
    # 初始化函数,接收多个参数
    def __init__(
        self,
        *,
        factor_fasta_folder,  # 因子 fasta 文件夹
        negative_bed_file = None,  # 负样本 bed 文件,默认为 None
        remap_bed_file = None,  # 重映射 bed 文件,默认为 None
        remap_df = None,  # 重映射数据框,默认为 None
        negative_df = None,  # 负样本数据框,默认为 None
        filter_chromosome_ids = None,  # 过滤染色体 ID 列表,默认为 None
        exclude_targets = None,  # 排除目标列表,默认为 None
        include_targets = None,  # 包含目标列表,默认为 None
        exclude_cell_types = None,  # 排除细胞类型列表,默认为 None
        include_cell_types = None,  # 包含细胞类型列表,默认为 None
        exp_target_cell_column = 'column_4',  # 实验-目标-细胞列,默认为 'column_4'
        experiments_json_path = None,  # 实验 JSON 路径,默认为 None
        include_biotypes_metadata_in_context = False,  # 在上下文中包含生物类型元数据,默认为 False
        biotypes_metadata_path = None,  # 生物类型元数据路径,默认为 None
        include_biotypes_metadata_columns = [],  # 包含生物类型元数据列,默认为空列表
        biotypes_metadata_delimiter = ' | ',  # 生物类型元数据分隔符,默认为 ' | '
        balance_sampling_by_target = False,  # 按目标平衡采样,默认为 False
        **kwargs  # 其他关键字参数
    ):
        super().__init__()  # 调用父类的初始化函数
        # 断言语句,判断 remap_df 和 remap_bed_file 必须有一个存在
        assert exists(remap_df) ^ exists(remap_bed_file), 'either remap bed file or remap dataframe must be passed in'
        # 断言语句,判断 negative_df 和 negative_bed_file 必须有一个存在
        assert exists(negative_df) ^ exists(negative_bed_file), 'either negative bed file or negative dataframe must be passed in'

        # 如果 remap_df 不存在,则从 remap_bed_file 读取数据框
        if not exists(remap_df):
            remap_df = read_bed(remap_bed_file)

        # 如果 negative_df 不存在,则从 negative_bed_file 读取数据框
        neg_df = negative_df
        if not exists(negative_df):
            neg_df = read_bed(negative_bed_file)

        # 过滤 remap 数据框
        remap_df = filter_df_by_tfactor_fastas(remap_df, factor_fasta_folder)

        # 设置数据集的染色体 ID
        dataset_chr_ids = CHR_IDS

        # 如果存在过滤染色体 ID,则更新数据集的染色体 ID
        if exists(filter_chromosome_ids):
            dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids))

        # 根据染色体名过滤负样本数据框
        neg_df = neg_df.filter(pl_isin('column_1', get_chr_names(dataset_chr_ids)))

        # 断言语句,确保负样本数据框不为空
        assert len(neg_df) > 0, 'dataset is empty by filter criteria'

        self.neg_df = neg_df  # 设置负样本数据框

        # 获取所有实验-目标-细胞,并根据条件过滤
        exp_target_cells = remap_df.get_column(exp_target_cell_column).unique().to_list()

        self.filtered_exp_target_cells = filter_exp_target_cell(
            exp_target_cells,
            include_targets = include_targets,
            exclude_targets = exclude_targets,
            include_cell_types = include_cell_types,
            exclude_cell_types = exclude_cell_types
        )

        # 断言语句,确保还有实验-目标-细胞用于硬负样本集
        assert len(self.filtered_exp_target_cells), 'no experiment-target-cell left for hard negative set'

        # 如果需要按目标平衡采样
        self.balance_sampling_by_target = balance_sampling_by_target

        if balance_sampling_by_target:
            self.exp_target_cell_by_target = defaultdict(list)

            # 根据目标将实验-目标-细胞分组
            for exp_target_cell in self.filtered_exp_target_cells:
                _, target, *_ = parse_exp_target_cell(exp_target_cell)
                self.exp_target_cell_by_target[target].append(exp_target_cell)

        # 因子数据集
        self.factor_ds = FactorProteinDataset(factor_fasta_folder)
        self.fasta = FastaInterval(**kwargs)

        # 获取实验索引
        self.experiments_index = fetch_experiments_index(experiments_json_path)

        # 上下文字符串创建器
        self.context_ds = ContextDataset(
            include_biotypes_metadata_in_context = include_biotypes_metadata_in_context,
            biotypes_metadata_path = biotypes_metadata_path,
            include_biotypes_metadata_columns = include_biotypes_metadata_columns,
            biotypes_metadata_delimiter = biotypes_metadata_delimiter
        )

    # 返回负样本数据集的长度
    def __len__(self):
        return len(self.neg_df)
    # 重载 __getitem__ 方法,用于获取指定索引处的数据
    def __getitem__(self, ind):
        # 从 neg_df 数据框中获取指定索引处的染色体名称、起始位置和终止位置
        chr_name, begin, end = self.neg_df.row(ind)

        # 如果按目标平衡采样
        if self.balance_sampling_by_target:
            # 从 exp_target_cell_by_target 字典中随机选择一个目标细胞类型
            rand_ind = randrange(0, len(self.exp_target_cell_by_target))
            exp_target_cell_by_target_list = list(self.exp_target_cell_by_target.values())
            random_exp_target_cell_type = choice(exp_target_cell_by_target_list[rand_ind])
        else:
            # 从 filtered_exp_target_cells 列表中随机选择一个目标细胞类型
            random_exp_target_cell_type = choice(self.filtered_exp_target_cells)

        # 解析实验、目标和细胞类型
        experiment, target, cell_type = parse_exp_target_cell(random_exp_target_cell_type)

        # 获取指定染色体区间的序列
        seq = self.fasta(chr_name, begin, end)
        # 获取目标对应的氨基酸序列
        aa_seq = self.factor_ds[target]
        # 获取细胞类型对应的上下文字符串
        context_str = self.context_ds[cell_type]

        # 初始化读取值为 0 的张量
        read_value = torch.Tensor([0.])

        # 获取指定目标细胞类型的峰值数
        peaks_nr = self.experiments_index.get(random_exp_target_cell_type, 0.)
        # 将峰值数转换为张量
        peaks_nr = torch.Tensor([peaks_nr])

        # 初始化标签为 0 的张量
        label = torch.Tensor([0.])

        # 返回获取的序列、氨基酸序列、上下文字符串、峰值数、读取值和标签
        return seq, aa_seq, context_str, peaks_nr, read_value, label
# dataloader相关函数

# 将数据集中的数据按照不同的类型解压缩
def collate_fn(data):
    seq, aa_seq, context_str, peaks_nr, read_values, labels = list(zip(*data))
    return torch.stack(seq), tuple(aa_seq), tuple(context_str), torch.stack(peaks_nr, dim=0), torch.stack(read_values, dim=0), torch.cat(labels, dim=0)

# 将多个dataloader的输出合并为一个元组
def collate_dl_outputs(*dl_outputs):
    outputs = list(zip(*dl_outputs))
    ret = []
    for entry in outputs:
        if isinstance(entry[0], torch.Tensor):
            entry = torch.cat(entry, dim=0)
        else:
            entry = (sub_el for el in entry for sub_el in el)
        ret.append(entry)
    return tuple(ret)

# 无限循环生成dataloader中的数据
def cycle(loader):
    while True:
        for data in loader:
            yield data

# 获取dataloader对象
def get_dataloader(ds, cycle_iter=False, **kwargs):
    dataset_len = len(ds)
    batch_size = kwargs.get('batch_size')
    drop_last = dataset_len > batch_size

    # 创建DataLoader对象
    dl = DataLoader(ds, collate_fn=collate_fn, drop_last=drop_last, **kwargs)
    wrapper = cycle if cycle_iter else iter
    return wrapper(dl)
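
A hedged training-loop sketch tying these helpers together (pos_ds and neg_ds are hypothetical instances of the positive and negative datasets defined above; batch sizes are arbitrary):

pos_dl = get_dataloader(pos_ds, cycle_iter = True, batch_size = 8, shuffle = True)
neg_dl = get_dataloader(neg_ds, cycle_iter = True, batch_size = 8, shuffle = True)

# one merged batch of positives and negatives per training step
seq, aa_seq, context_str, peaks_nr, read_value, label = collate_dl_outputs(next(pos_dl), next(neg_dl))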

.\lucidrains\tf-bind-transformer\tf_bind_transformer\data_bigwig.py

# 导入必要的库
from pathlib import Path
import polars as pl
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader

# 导入自定义的数据集和函数
from tf_bind_transformer.data import FactorProteinDataset, ContextDataset, cast_list, filter_df_by_tfactor_fastas
from tf_bind_transformer.data import pl_isin, pl_notin, fetch_experiments_index, parse_exp_target_cell, read_bed, cycle, filter_by_col_isin
from tf_bind_transformer.data import CHR_IDS, CHR_NAMES, get_chr_names
from enformer_pytorch import FastaInterval

# 尝试导入 pyBigWig 库,如果导入失败则打印提示信息并退出程序
try:
    import pyBigWig
except ImportError:
    print('pyBigWig needs to be installed - conda install pyBigWig')
    exit()

# 定义一个函数,用于检查变量是否存在
def exists(val):
    return val is not None

# 定义一个函数,用于处理 CHIP ATLAS 数据集中的实验、目标和细胞类型信息
def chip_atlas_add_experiment_target_cell(
    df,
    col_target = 'column_4',
    col_cell_type = 'column_5'
):
    df = df.clone()

    # 提取目标信息并转换为大写格式
    targets = df.select(col_target)
    targets = targets.to_series(0).str.to_uppercase().rename('target')
    df.insert_at_idx(2, targets)

    # 提取细胞类型信息
    cell_type = df.select(col_cell_type)
    cell_type = cell_type.rename({col_cell_type: 'cell_type'}).to_series(0)
    df.insert_at_idx(2, cell_type)

    return df

# 定义一个数据集类,用于处理 BigWig 数据
class BigWigDataset(Dataset):
    def __init__(
        self,
        *,
        factor_fasta_folder,
        bigwig_folder,
        enformer_loci_path,
        fasta_file,
        annot_file = None,
        filter_chromosome_ids = None,
        exclude_targets = None,
        include_targets = None,
        exclude_cell_types = None,
        include_cell_types = None,
        df_frac = 1.,
        experiments_json_path = None,
        include_biotypes_metadata_in_context = False,
        biotypes_metadata_path = None,
        filter_sequences_by = None,
        include_biotypes_metadata_columns = [],
        biotypes_metadata_delimiter = ' | ',
        only_ref = ['mm10', 'hg38'],
        factor_species_priority = ['human', 'mouse'],
        downsample_factor = 128,
        target_length = 896,
        bigwig_reduction_type = 'sum',
        **kwargs
    ):
        super().__init__()
        # an annotation file must be supplied
        assert exists(annot_file)

        # 如果 bigwig 文件夹不存在,则设置为无效,目标数量为 0
        if not exists(bigwig_folder):
            self.invalid = True
            self.ntargets = 0
            return

        # 将 bigwig 文件夹路径转换为 Path 对象
        bigwig_folder = Path(bigwig_folder)
        # 断言 bigwig 文件夹存在
        assert bigwig_folder.exists(), 'bigwig folder does not exist'

        # 获取 bigwig 文件夹下所有的 .bw 文件名列表
        bw_experiments = [p.stem for p in bigwig_folder.glob('*.bw')]
        # 断言至少有一个 bigwig 文件存在
        assert len(bw_experiments) > 0, 'no bigwig files found in bigwig folder'

        # 读取 enformer_loci_path 中的 loci 数据
        loci = read_bed(enformer_loci_path)
        # 读取 annot_file 中的注释数据
        annot_df = pl.read_csv(annot_file, sep = "\t", has_headers = False, columns = list(map(lambda i: f'column_{i + 1}', range(17))))

        # 根据 only_ref 列表中的值筛选 annot_df
        annot_df = annot_df.filter(pl_isin('column_2', only_ref))
        # 根据 bw_experiments 列表中的值筛选 annot_df
        annot_df = filter_by_col_isin(annot_df, 'column_1', bw_experiments)

        # 如果 df_frac 小于 1,则对 annot_df 进行采样
        if df_frac < 1:
            annot_df = annot_df.sample(frac = df_frac)

        # 初始化 dataset_chr_ids 为 CHR_IDS
        dataset_chr_ids = CHR_IDS

        # 如果 filter_chromosome_ids 存在,则更新 dataset_chr_ids
        if exists(filter_chromosome_ids):
            dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids))

        # 根据 dataset_chr_ids 中的值筛选 loci
        loci = loci.filter(pl_isin('column_1', get_chr_names(dataset_chr_ids)))

        # 如果 filter_sequences_by 存在,则根据其值筛选 loci
        if exists(filter_sequences_by):
            col_name, col_val = filter_sequences_by
            loci = loci.filter(pl.col(col_name) == col_val)

        # 初始化 FactorProteinDataset 对象
        self.factor_ds = FactorProteinDataset(factor_fasta_folder, species_priority = factor_species_priority)

        # 获取 annot_df 中 column_1 列的唯一值集合
        exp_ids = set(annot_df.get_column('column_1').to_list())

        # 添加实验目标细胞到 annot_df
        annot_df = chip_atlas_add_experiment_target_cell(annot_df)
        # 根据 factor_fasta_folder 筛选 annot_df
        annot_df = filter_df_by_tfactor_fastas(annot_df, factor_fasta_folder)

        # 获取筛选后的 exp_ids
        filtered_exp_ids = set(annot_df.get_column('column_1').to_list())

        # 计算被筛选掉的 exp_ids
        filtered_out_exp_ids = exp_ids - filtered_exp_ids
        print(f'{", ".join(only_ref)} - {len(filtered_out_exp_ids)} experiments filtered out by lack of transcription factor fastas', filtered_out_exp_ids)

        # 根据 include_targets 和 exclude_targets 筛选 annot_df
        include_targets = cast_list(include_targets)
        exclude_targets = cast_list(exclude_targets)

        if include_targets:
            annot_df = annot_df.filter(pl_isin('target', include_targets))

        if exclude_targets:
            annot_df = annot_df.filter(pl_notin('target', exclude_targets))

        # 根据 include_cell_types 和 exclude_cell_types 筛选 annot_df
        include_cell_types = cast_list(include_cell_types)
        exclude_cell_types = cast_list(exclude_cell_types)

        if include_cell_types:
            annot_df = annot_df.filter(pl_isin('cell_type', include_cell_types))

        if exclude_cell_types:
            annot_df = annot_df.filter(pl_notin('cell_type', exclude_cell_types))

        # 初始化 FastaInterval 对象
        self.fasta = FastaInterval(fasta_file = fasta_file, **kwargs)

        # 设置 self.df 和 self.annot
        self.df = loci
        self.annot = annot_df
        self.ntargets = self.annot.shape[0]

        # 初始化 bigwigs 列表
        self.bigwigs = [pyBigWig.open(str(bigwig_folder / f'{str(i)}.bw')) for i in self.annot.get_column("column_1")]

        # 设置 downsample_factor 和 target_length
        self.downsample_factor = downsample_factor
        self.target_length = target_length

        # 设置 bigwig_reduction_type 和 invalid
        self.bigwig_reduction_type = bigwig_reduction_type
        self.invalid = False

    # 返回数据集的长度
    def __len__(self):
        # 如果数据集无效,则长度为 0
        if self.invalid:
            return 0

        # 返回数据集的长度
        return len(self.df) * self.ntargets
    # 从自定义类中获取指定索引的元素
    def __getitem__(self, ind):
        # TODO 返回一个个体的所有目标
        # 从数据框中获取指定索引的染色体名称、起始位置、终止位置和其他信息
        chr_name, begin, end, _ = self.df.row(ind % self.df.shape[0])

        # 从注释中选择目标和细胞类型,并转换为 Series 对象
        targets = self.annot.select('target').to_series(0)
        cell_types = self.annot.select('cell_type').to_series(0)

        # 计算目标在列表中的索引
        ix_target = ind // self.df.shape[0]
    
        # 从列表中获取目标、细胞类型和 bigwig 对象
        target = targets[ix_target]
        context_str = cell_types[ix_target]
        exp_bw = self.bigwigs[ix_target]

        # 获取目标对应的氨基酸序列和基因组序列
        aa_seq = self.factor_ds[target]
        seq = self.fasta(chr_name, begin, end)

        # 计算 bigwig 数据
        output = np.array(exp_bw.values(chr_name, begin, end))
        output = output.reshape((-1, self.downsample_factor))

        # 根据指定的 bigwig 缩减类型进行处理
        if self.bigwig_reduction_type == 'mean':
            om = np.nanmean(output, axis = 1)
        elif self.bigwig_reduction_type == 'sum':
            om = np.nansum(output, axis = 1)
        else:
            raise ValueError(f'unknown reduction type {self.bigwig_reduction_type}')

        # 获取输出数据的长度
        output_length = output.shape[0]

        # 检查输出长度是否小于目标长度
        assert output_length >= self.target_length, f'output length {output_length} cannot be less than the target length {self.target_length}'

        # 计算需要裁剪的部分
        trim = (output.shape[0] - self.target_length) // 2
        om = om[trim:-trim] if trim > 0 else om

        # 将 NaN 值替换为 0
        np.nan_to_num(om, copy = False)

        # 创建 PyTorch 张量作为标签
        label = torch.Tensor(om)
        return seq, aa_seq, context_str, label
# BigWig 数据集,仅包含轨迹

class BigWigTracksOnlyDataset(Dataset):
    def __init__(
        self,
        *,
        bigwig_folder,  # BigWig 文件夹路径
        enformer_loci_path,  # Enformer loci 路径
        fasta_file,  # FASTA 文件路径
        ref,  # 参考
        annot_file = None,  # 注释文件,默认为空
        filter_chromosome_ids = None,  # 过滤染色体 ID,默认为空
        downsample_factor = 128,  # 下采样因子,默认为 128
        target_length = 896,  # 目标长度,默认为 896
        bigwig_reduction_type = 'sum',  # BigWig 减少类型,默认为 'sum'
        filter_sequences_by = None,  # 过滤序列,默认为空
        **kwargs
    ):
        super().__init__()
        assert exists(annot_file)

        if not exists(bigwig_folder):
            self.invalid = True
            self.ntargets = 0
            return

        bigwig_folder = Path(bigwig_folder)
        assert bigwig_folder.exists(), 'bigwig folder does not exist'

        bw_experiments = [p.stem for p in bigwig_folder.glob('*.bw')]
        assert len(bw_experiments) > 0, 'no bigwig files found in bigwig folder'

        loci = read_bed(enformer_loci_path)

        annot_df = pl.read_csv(annot_file, sep = "\t", has_headers = False, columns = list(map(lambda i: f'column_{i + 1}', range(17))))

        annot_df = annot_df.filter(pl.col('column_2') == ref)
        annot_df = filter_by_col_isin(annot_df, 'column_1', bw_experiments)

        dataset_chr_ids = CHR_IDS

        if exists(filter_chromosome_ids):
            dataset_chr_ids = dataset_chr_ids.intersection(set(filter_chromosome_ids))

        # filtering loci by chromosomes
        # as well as training or validation

        loci = loci.filter(pl_isin('column_1', get_chr_names(dataset_chr_ids)))

        if exists(filter_sequences_by):
            col_name, col_val = filter_sequences_by
            loci = loci.filter(pl.col(col_name) == col_val)

        self.fasta = FastaInterval(fasta_file = fasta_file, **kwargs)

        self.df = loci
        self.annot = annot_df
        self.ntargets = self.annot.shape[0]

        # bigwigs

        self.bigwigs = [(str(i), pyBigWig.open(str(bigwig_folder / f'{str(i)}.bw'))) for i in self.annot.get_column("column_1")]
        
        self.downsample_factor = downsample_factor
        self.target_length = target_length

        self.bigwig_reduction_type = bigwig_reduction_type
        self.invalid = False

    def __len__(self):
        if self.invalid:
            return 0

        return len(self.df) * int(self.ntargets > 0)

    def __getitem__(self, ind):
        chr_name, begin, end, _ = self.df.row(ind)

        # figure out ref and fetch appropriate sequence

        seq = self.fasta(chr_name, begin, end)

        # calculate bigwig
        # properly downsample and then crop

        all_bw_values = []

        for bw_path, bw in self.bigwigs:
            try:
                bw_values = bw.values(chr_name, begin, end)
                all_bw_values.append(bw_values)
            except:
                print(f'hitting invalid range for {bw_path} - ({chr_name}, {begin}, {end})')
                exit()

        output = np.stack(all_bw_values, axis = -1)
        output = output.reshape((-1, self.downsample_factor, self.ntargets))

        if self.bigwig_reduction_type == 'mean':
            om = np.nanmean(output, axis = 1)
        elif self.bigwig_reduction_type == 'sum':
            om = np.nansum(output, axis = 1)
        else:
            raise ValueError(f'unknown reduction type {self.bigwig_reduction_type}')

        output_length = output.shape[0]

        assert output_length >= self.target_length, f'output length {output_length} cannot be less than the target length {self.target_length}'

        trim = (output.shape[0] - self.target_length) // 2
        om = om[trim:-trim] if trim > 0 else om

        np.nan_to_num(om, copy = False)

        label = torch.Tensor(om)
        return seq, label

# 数据加载器

def bigwig_collate_fn(data):
    seq, aa_seq, context_str, labels = list(zip(*data))
    return torch.stack(seq), tuple(aa_seq), tuple(context_str), torch.stack(labels)

def get_bigwig_dataloader(ds, cycle_iter = False, **kwargs):
    dataset_len = len(ds)
    # 从参数中获取批量大小
    batch_size = kwargs.get('batch_size')
    # 检查数据集长度是否大于批量大小,以确定是否丢弃最后一批数据
    drop_last = dataset_len > batch_size

    # 使用DataLoader加载数据集,指定数据集、数据处理函数、是否丢弃最后一批数据以及其他参数
    dl = DataLoader(ds, collate_fn = bigwig_collate_fn, drop_last = drop_last, **kwargs)
    # 根据是否循环迭代选择返回迭代器或循环迭代器
    wrapper = cycle if cycle_iter else iter
    # 返回包装后的数据加载器
    return wrapper(dl)
# 定义一个函数,用于获取 BigWig 轨迹数据的数据加载器
def get_bigwig_tracks_dataloader(ds, cycle_iter = False, **kwargs):
    # 获取数据集的长度
    dataset_len = len(ds)
    # 获取批处理大小
    batch_size = kwargs.get('batch_size')
    # 如果数据集长度大于批处理大小,则设置为True,否则为False
    drop_last = dataset_len > batch_size

    # 创建一个数据加载器,根据是否丢弃最后一批数据进行设置
    dl = DataLoader(ds, drop_last = drop_last, **kwargs)
    # 根据cycle_iter参数选择返回数据加载器的迭代器类型
    wrapper = cycle if cycle_iter else iter
    # 返回迭代器类型的数据加载器
    return wrapper(dl)
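
下面给出一个示意性的用法草图,把 BigWigTracksOnlyDataset 与 get_bigwig_tracks_dataloader 串联起来取一个批次。其中的文件路径、参考基因组名均为假设的占位符,仅用于说明各参数如何对应:

# 示意性草图:以下路径均为假设的占位符
ds = BigWigTracksOnlyDataset(
    bigwig_folder = './bigwig_tracks',       # 假设:存放 <实验ID>.bw 文件的文件夹
    enformer_loci_path = './sequences.bed',  # 假设:Enformer loci 的 bed 文件
    fasta_file = './hg38.ml.fa',             # 假设:基因组 fasta 文件
    ref = 'hg38',                            # 注释文件第二列中的参考基因组名
    annot_file = './experiments.tab',        # 假设:制表符分隔的实验注释文件
    downsample_factor = 128,
    target_length = 896
)

dl = get_bigwig_tracks_dataloader(ds, cycle_iter = True, batch_size = 4, shuffle = True)
seq, label = next(dl)   # label 形状约为 (batch, target_length, ntargets)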

.\lucidrains\tf-bind-transformer\tf_bind_transformer\gene_utils.py

# 用于获取转录因子序列的代码

# 定义基因标识映射,将'RXR'映射为'RXRA'
GENE_IDENTIFIER_MAP = {
    'RXR': 'RXRA'
}

# 包含连字符的基因名称集合
NAMES_WITH_HYPHENS = {
    'NKX3-1',
    'NKX2-1',
    'NKX2-5',
    'SS18-SSX'
}

# 解析基因名称的函数
def parse_gene_name(name):
    # 如果名称中不包含连字符或者名称在NAMES_WITH_HYPHENS中,则直接返回名称
    if '-' not in name or name in NAMES_WITH_HYPHENS:
        name = GENE_IDENTIFIER_MAP.get(name, name)

        # 如果名称中包含下划线,则只搜索下划线左侧的目标因子名称
        if '_' in name:
            name, *_ = name.split('_')

        return (name,)

    # 如果名称中包含连字符,则按照一定规则解析名称
    first, *rest = name.split('-')

    parsed_rest = []

    for name in rest:
        if len(name) == 1:
            name = f'{first[:-1]}{name}'
        parsed_rest.append(name)

    return tuple([first, *parsed_rest])
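
结合以上解析规则,下面给出几个示意性的调用结果(示例基因名仅用于说明,结果可直接由上面的逻辑推出):

parse_gene_name('RXR')          # ('RXRA',)            通过 GENE_IDENTIFIER_MAP 归一化
parse_gene_name('NKX2-5')       # ('NKX2-5',)          在 NAMES_WITH_HYPHENS 中,保留连字符
parse_gene_name('STAT5A-B')     # ('STAT5A', 'STAT5B') 单字母后缀补全为完整基因名
parse_gene_name('PAX3-FOXO1')   # ('PAX3', 'FOXO1')    融合名拆分为两个亚基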

.\lucidrains\tf-bind-transformer\tf_bind_transformer\optimizer.py

# 从 torch.optim 模块中导入 AdamW 优化器
from torch.optim import AdamW

# 将参数分为可进行权重衰减和不可进行权重衰减的参数
def separate_weight_decayable_params(params):
    # 找出参数中维度小于 2 的参数,即不可进行权重衰减的参数
    no_wd_params = set([param for param in params if param.ndim < 2])
    # 可进行权重衰减的参数为所有参数减去不可进行权重衰减的参数
    wd_params = set(params) - no_wd_params
    return wd_params, no_wd_params

# 根据参数和超参数创建 AdamW 优化器
def get_optimizer(params, lr = 3e-4, wd = 1e-1, filter_by_requires_grad = False):
    # 如果需要根据 requires_grad 过滤参数,则只保留 requires_grad 为 True 的参数
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # 将参数转换为集合
    params = set(params)
    # 将参数分为可进行权重衰减和不可进行权重衰减的参数
    wd_params, no_wd_params = separate_weight_decayable_params(params)

    # 构建参数组,其中可进行权重衰减的参数使用默认权重衰减,不可进行权重衰减的参数不使用权重衰减
    param_groups = [
        {'params': list(wd_params)},
        {'params': list(no_wd_params), 'weight_decay': 0},
    ]

    # 返回使用 AdamW 优化器的参数组和超参数 lr 和 wd 的优化器
    return AdamW(param_groups, lr = lr, weight_decay = wd)
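
下面是一个最小的用法草图(模型仅为示意),用来说明维度小于 2 的参数(偏置、LayerNorm 的缩放等)不会被施加权重衰减:

import torch
from torch import nn

# 示意模型:线性层的权重进入权重衰减组,偏置与 LayerNorm 参数进入零衰减组
model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 1))

opt = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-1)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
opt.step()
opt.zero_grad()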

.\lucidrains\tf-bind-transformer\tf_bind_transformer\protein_utils.py

# 导入所需的库
import torch
import os
import re
from pathlib import Path
from functools import partial
import esm
from torch.nn.utils.rnn import pad_sequence
from transformers import AlbertTokenizer, AutoModelForMaskedLM, logging
from tf_bind_transformer.cache_utils import cache_fn, run_once, md5_hash_fn

# 定义函数,判断值是否存在
def exists(val):
    return val is not None

# 定义函数,对字典中的值应用给定函数
def map_values(fn, dictionary):
    return {k: fn(v) for k, v in dictionary.items()}

# 定义函数,将张量移动到指定设备
def to_device(t, *, device):
    return t.to(device)

# 定义函数,将输入转换为元组
def cast_tuple(t):
    return (t,) if not isinstance(t, tuple) else t

# 检查是否设置了环境变量 PROTEIN_EMBED_USE_CPU
PROTEIN_EMBED_USE_CPU = os.getenv('PROTEIN_EMBED_USE_CPU', None) is not None

# 如果设置了 PROTEIN_EMBED_USE_CPU,则打印提示信息
if PROTEIN_EMBED_USE_CPU:
    print('calculating protein embed only on cpu')

# 全局变量
GLOBAL_VARIABLES = {
    'model': None,
    'tokenizer': None
}

# 计算蛋白质表示与亚单位
def calc_protein_representations_with_subunits(proteins, get_repr_fn, *, device):
    representations = []

    for subunits in proteins:
        subunits = cast_tuple(subunits)
        subunits_representations = list(map(get_repr_fn, subunits))
        subunits_representations = list(map(partial(to_device, device=device), subunits_representations))
        subunits_representations = torch.cat(subunits_representations, dim=0)
        representations.append(subunits_representations)

    lengths = [seq_repr.shape[0] for seq_repr in representations]
    masks = torch.arange(max(lengths), device=device)[None, :] < torch.tensor(lengths, device=device)[:, None]
    padded_representations = pad_sequence(representations, batch_first=True)

    return padded_representations.to(device), masks.to(device)

# ESM 相关函数
ESM_MAX_LENGTH = 1024
ESM_EMBED_DIM = 1280

# 映射整数到氨基酸字符串的字典
INT_TO_AA_STR_MAP = {
    0: 'A',
    1: 'C',
    2: 'D',
    3: 'E',
    4: 'F',
    5: 'G',
    6: 'H',
    7: 'I',
    8: 'K',
    9: 'L',
    10: 'M',
    11: 'N',
    12: 'P',
    13: 'Q',
    14: 'R',
    15: 'S',
    16: 'T',
    17: 'V',
    18: 'W',
    19: 'Y',
    20: '_'
}

# 将张量转换为氨基酸字符串
def tensor_to_aa_str(t):
    str_seqs = []
    for int_seq in t.unbind(dim=0):
        str_seq = list(map(lambda t: INT_TO_AA_STR_MAP[t] if t != 20 else '', int_seq.tolist()))
        str_seqs.append(''.join(str_seq))
    return str_seqs
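
一个示意性的小例子:索引 20(对应 '_')被当作填充符丢弃,其余索引映射回氨基酸字符:

t = torch.tensor([[10, 8, 16, 20, 20]])   # M, K, T, 填充, 填充
tensor_to_aa_str(t)                       # ['MKT']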

# 初始化 ESM 模型
@run_once('init_esm')
def init_esm():
    model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
    batch_converter = alphabet.get_batch_converter()

    if not PROTEIN_EMBED_USE_CPU:
        model = model.cuda()

    GLOBAL_VARIABLES['model'] = (model, batch_converter)

# 获取单个蛋白质的 ESM 表示
def get_single_esm_repr(protein_str):
    init_esm()
    model, batch_converter = GLOBAL_VARIABLES['model']

    data = [('protein', protein_str)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    if batch_tokens.shape[1] > ESM_MAX_LENGTH:
        print(f'warning max length protein esm: {protein_str}')

    batch_tokens = batch_tokens[:, :ESM_MAX_LENGTH]

    if not PROTEIN_EMBED_USE_CPU:
        batch_tokens = batch_tokens.cuda()

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33])

    token_representations = results['representations'][33]
    representation = token_representations[0][1: len(protein_str) + 1]
    return representation

# 获取多个蛋白质的 ESM 表示
def get_esm_repr(proteins, device):
    if isinstance(proteins, torch.Tensor):
        proteins = tensor_to_aa_str(proteins)

    get_protein_repr_fn = cache_fn(get_single_esm_repr, path='esm/proteins')

    return calc_protein_representations_with_subunits(proteins, get_protein_repr_fn, device=device)

# PROT-ALBERT 2048 上下文长度
PROT_ALBERT_PATH = 'Rostlab/prot_albert'
PROT_ALBERT_DIM = 4096
PROT_ALBERT_MAX_LENGTH = 2048

# 将蛋白质字符串中的特殊字符替换为空格
def protein_str_with_spaces(protein_str):
    protein_str = re.sub(r"[UZOB]", 'X', protein_str)
    return ' '.join([*protein_str])

# 初始化 PROT-ALBERT 模型
@run_once('init_prot_albert')
def init_prot_albert():
    GLOBAL_VARIABLES['tokenizer'] = AlbertTokenizer.from_pretrained(PROT_ALBERT_PATH, do_lower_case=False)
    # 从预训练的 ALBERT 模型中加载用于 Masked Language Modeling 的模型
    model = AutoModelForMaskedLM.from_pretrained(PROT_ALBERT_PATH)
    
    # 如果不使用 CPU 运行蛋白质嵌入模型,则将模型移动到 GPU 上
    if not PROTEIN_EMBED_USE_CPU:
        model = model.cuda()
    
    # 将加载的模型存储在全局变量中
    GLOBAL_VARIABLES['model'] = model
# 获取单个蛋白质的 ALBERT 表示
def get_single_prot_albert_repr(
    protein_str,
    max_length = PROT_ALBERT_MAX_LENGTH,
    hidden_state_index = -1
):
    # 初始化 ALBERT 模型
    init_prot_albert()
    # 获取全局变量中的模型和分词器
    model = GLOBAL_VARIABLES['model']
    tokenizer = GLOBAL_VARIABLES['tokenizer']

    # 对蛋白质字符串进行编码
    encoding = tokenizer.batch_encode_plus(
        [protein_str_with_spaces(protein_str)],
        add_special_tokens = True,
        padding = True,
        truncation = True,
        max_length = max_length,
        return_attention_mask = True,
        return_tensors = 'pt'
    )

    # 如果不使用 CPU 进行蛋白质嵌入
    if not PROTEIN_EMBED_USE_CPU:
        encoding = map_values(lambda t: t.cuda(), encoding)

    # 将模型设置为评估模式
    model.eval()
    # 禁用梯度计算
    with torch.no_grad():
        # 获取模型输出
        outputs = model(**encoding, output_hidden_states = True)

    # 获取隐藏状态
    hidden_state = outputs.hidden_states[hidden_state_index][0]
    return hidden_state

# 获取蛋白质 ALBERT 表示
def get_prot_albert_repr(
    proteins,
    device,
    max_length = PROT_ALBERT_MAX_LENGTH,
    hidden_state_index = -1
):
    # 如果输入为字符串,则转换为列表
    if isinstance(proteins, str):
        proteins = [proteins]

    # 如果输入为张量,则转换为氨基酸字符串
    if isinstance(proteins, torch.Tensor):
        proteins = tensor_to_aa_str(proteins)

    # 缓存单个蛋白质 ALBERT 表示的函数
    get_protein_repr_fn = cache_fn(get_single_prot_albert_repr, path = f'proteins/prot_albert')

    # 计算蛋白质表示
    return calc_protein_representations_with_subunits(proteins, get_protein_repr_fn, device = device)

# alphafold2 函数

# 定义最大长度和嵌入维度
AF2_MAX_LENGTH = 2500
AF2_EMBEDDING_DIM = 384

# 设置 AF2_DIRECTORY 路径
AF2_DIRECTORY = os.getenv('TF_BIND_AF2_DIRECTORY', os.path.expanduser('~/.cache.tf.bind.transformer/.af2_embeddings'))
AF2_DIRECTORY_PATH = Path(AF2_DIRECTORY)

# 获取单个 alphafold2 表示
def get_single_alphafold2_repr(
    protein_str,
    max_length = AF2_MAX_LENGTH,
):
    # 计算蛋白质字符串的 MD5 哈希值
    md5 = md5_hash_fn(protein_str)
    embedding_path = AF2_DIRECTORY_PATH / f'{md5}.pt'
    assert embedding_path.exists(), f'af2 embedding not found for {protein_str}'

    # 加载嵌入张量
    tensor = torch.load(str(embedding_path))
    return tensor[:max_length]

# 获取 alphafold2 表示
def get_alphafold2_repr(
    proteins,
    device,
    max_length = AF2_MAX_LENGTH,
    **kwargs
):
    representations = []

    for subunits in proteins:
        subunits = cast_tuple(subunits)
        subunits = list(map(lambda t: get_single_alphafold2_repr(t, max_length = max_length), subunits))
        subunits = torch.cat(subunits, dim = 0)
        representations.append(subunits)

    lengths = [seq_repr.shape[0] for seq_repr in representations]
    masks = torch.arange(max(lengths), device = device)[None, :] <  torch.tensor(lengths, device = device)[:, None]
    padded_representations = pad_sequence(representations, batch_first = True)

    return padded_representations.to(device), masks.to(device)

# 工厂函数

# 定义蛋白质表示配置
PROTEIN_REPR_CONFIG = {
    'esm': {
        'dim': ESM_EMBED_DIM,
        'fn': get_esm_repr
    },
    'protalbert': {
        'dim': PROT_ALBERT_DIM,
        'fn': get_prot_albert_repr
    },
    'alphafold2': {
        'dim': AF2_EMBEDDING_DIM,
        'fn': get_alphafold2_repr
    }
}

# 获取蛋白质嵌入器
def get_protein_embedder(name):
    allowed_protein_embedders = list(PROTEIN_REPR_CONFIG.keys())
    assert name in allowed_protein_embedders, f"must be one of {', '.join(allowed_protein_embedders)}"

    config = PROTEIN_REPR_CONFIG[name]
    return config
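
一个简单的示意用法:按名称取出配置后,分别得到嵌入维度与对应的表示函数(真正调用表示函数需要事先准备好相应的预训练权重):

config = get_protein_embedder('esm')
embed_dim = config['dim']     # 1280
get_repr_fn = config['fn']    # 即上面的 get_esm_repr

# embeddings, mask = get_repr_fn(['MKTAYIAK'], device = torch.device('cpu'))  # 需要先下载 ESM 预训练权重,序列仅为示意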

.\lucidrains\tf-bind-transformer\tf_bind_transformer\tf_bind_transformer.py

# 导入必要的库
import copy
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from functools import wraps

# 导入 einops 库中的函数
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

# 导入 contextlib 库中的 contextmanager 函数
from contextlib import contextmanager

# 导入自定义的 Enformer 模型和相关函数
from enformer_pytorch import Enformer
from enformer_pytorch.modeling_enformer import poisson_loss, pearson_corr_coef
from enformer_pytorch.finetune import freeze_batchnorms_, freeze_all_but_layernorms_, unfreeze_last_n_layers_, unfreeze_all_layers_

# 导入 logavgexp 库中的函数
from logavgexp_pytorch import logavgexp

# 导入自定义的缓存函数和一些工具函数
from tf_bind_transformer.cache_utils import cache_fn
from tf_bind_transformer.protein_utils import get_protein_embedder
from tf_bind_transformer.context_utils import get_text_repr, get_contextual_dim

# 导入自定义的注意力机制相关类
from tf_bind_transformer.attention import FeedForward, JointCrossAttentionBlock, CrossAttention, SelfAttentionBlock

# 辅助函数

# 判断变量是否存在
def exists(val):
    return val is not None

# 返回默认值
def default(val, d):
    return val if exists(val) else d

# 返回函数本身
def identity(fn, *args, **kwargs):
    return fn

# 空上下文管理器
@contextmanager
def null_context():
    yield

# 张量操作函数

# 对张量进行 L2 归一化
def l2norm(t):
    return F.normalize(t, dim = -1)

# 根据概率生成掩码
def prob_mask_like(t, prob):
    return torch.zeros_like(t).float().uniform_(0, 1) < prob

# 对输入进行傅立叶编码
def fourier_encode(x, dims, theta = 20000):
    device, dtype = x.device, x.dtype
    emb = math.log(theta) / (dims // 2)
    emb = torch.exp(torch.arange(dims // 2, device = device) * -emb)
    emb = rearrange(x, 'n -> n 1') * rearrange(emb, 'd -> 1 d')
    emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
    return emb
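
fourier_encode 对一维张量中的每个标量生成 dims//2 个正弦与 dims//2 个余弦特征,下面是一个形状示意:

x = torch.tensor([0.5, 100.0, 8000.0])
emb = fourier_encode(x, dims = 256)
print(emb.shape)   # torch.Size([3, 256])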

# 计算相关系数损失
def corr_coef_loss(pred, target):
    return 1 - pearson_corr_coef(pred, target).mean()

# 缓存 Enformer 前向传播结果的装饰器

def cache_enformer_forward(fn):
    cached_forward = cache_fn(fn, clear = True, path = 'genetic')

    @wraps(fn)
    def inner(seqs, *args, **kwargs):
        if seqs.ndim == 3:
            seqs = seqs.argmax(dim = -1)

        seq_list = seqs.unbind(dim = 0)
        seq_cache_keys = [''.join(list(map(str, one_seq.tolist()))) for one_seq in seq_list]
        outputs = [cached_forward(one_seq, *args, __cache_key = seq_cache_key, **kwargs) for one_seq, seq_cache_key in zip(seq_list, seq_cache_keys)]
        return torch.stack(outputs)

    return inner

# 模型

# FiLM 模块
class FiLM(nn.Module):
    def __init__(
        self,
        dim,
        conditioned_dim
    ):
        super().__init__()
        self.to_gamma = nn.Linear(dim, conditioned_dim)
        self.to_bias = nn.Linear(dim, conditioned_dim)

    def forward(self, x, condition, mask = None):
        gamma = self.to_gamma(condition)
        bias = self.to_bias(condition)

        x = x * rearrange(gamma, 'b d -> b 1 d')
        x = x + rearrange(bias, 'b d -> b 1 d')
        return x
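
FiLM 用上下文向量为序列特征生成逐通道的缩放与偏置,下面给出一个维度示意(输入均为随机张量,仅用于说明形状):

film = FiLM(dim = 32, conditioned_dim = 64)

x = torch.randn(2, 10, 64)        # (batch, seq, conditioned_dim) 的序列特征
condition = torch.randn(2, 32)    # (batch, dim) 的上下文嵌入

out = film(x, condition)
print(out.shape)                  # torch.Size([2, 10, 64])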

# SqueezeExcitation 模块
class SqueezeExcitation(nn.Module):
    def __init__(
        self,
        dim,
        conditioned_dim,
        eps = 1e-8
    ):
        super().__init__()
        self.eps = eps
        self.to_gate = nn.Linear(dim + conditioned_dim, conditioned_dim)

    def forward(self, x, condition, mask = None):
        if exists(mask):
            numer = x.masked_fill(mask[..., None], 0.).sum(dim = 1)
            denom = mask.sum(dim = 1)[..., None].clamp(min = self.eps)
            mean_x = numer / denom
        else:
            mean_x = x.mean(dim = 1)

        condition = torch.cat((condition, mean_x), dim = -1)
        gate = self.to_gate(condition)

        x = x * rearrange(gate, 'b d -> b 1 d').sigmoid()
        return x

# 用于计算辅助损失的 ReadValueMLP 类
class ReadValueMLP(nn.Module):
    def __init__(
        self,
        dim,
        *,
        fourier_dims = 256,
        norm_factor_fourier = 50,
        norm_factor_linear = 8000,
        eps = 1e-20
    ):
        # 调用父类初始化函数
        super().__init__()
        # 设置模型参数
        self.eps = eps
        self.fourier_dims = fourier_dims
        self.norm_factor_fourier = norm_factor_fourier
        self.norm_factor_linear = norm_factor_linear

        # 定义 logits 的归一化层
        self.logits_norm = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),  # 对 logits 进行平均池化
            nn.LayerNorm(dim)  # 对结果进行 LayerNorm
        )

        # 定义 MLP 网络
        self.mlp = nn.Sequential(
            nn.Linear(dim + fourier_dims + 2, dim * 2),  # 线性层
            nn.GELU(),  # GELU 激活函数
            nn.Linear(dim * 2, 1),  # 线性层
            Rearrange('... 1 -> ...')  # 重新排列维度
        )

    # 前向传播函数
    def forward(self, logits, peaks_nr, read_value):
        # 对 logits 进行归一化
        logits = self.logits_norm(logits)

        # 对 peaks_nr 进行对数变换
        peaks_nr_log_space = torch.log(peaks_nr + self.eps)

        # 重新排列 peaks_nr 的维度
        peaks_nr = rearrange(peaks_nr, '... -> (...)')
        # 对 peaks_nr 进行傅立叶编码
        peaks_nr_encoded = fourier_encode(peaks_nr / self.norm_factor_fourier, self.fourier_dims)
        # 对 peaks_nr 进行归一化
        peaks_nr_normed = rearrange(peaks_nr, '... -> ... 1') / self.norm_factor_linear

        # 将 peaks_nr_normed、peaks_nr_log_space、peaks_nr_encoded 拼接在一起
        peaks_nr_encoded_with_self = torch.cat((peaks_nr_normed, peaks_nr_log_space, peaks_nr_encoded), dim = -1)

        # 将 logits 和 peaks_nr_encoded_with_self 拼接在一起
        logits_with_peaks = torch.cat((logits, peaks_nr_encoded_with_self), dim = -1)

        # 通过 MLP 网络得到预测值
        pred = self.mlp(logits_with_peaks)
        # 重新排列 read_value 的维度
        read_value = rearrange(read_value, '... -> (...)')

        # 返回 Smooth L1 损失
        return F.smooth_l1_loss(pred, read_value)
# 定义一个名为 HypergridLinear 的类,继承自 nn.Module
class HypergridLinear(nn.Module):
    # 初始化函数,接受输入维度 dim、输出维度 dim_out 和上下文维度 context_dim
    def __init__(
        self,
        dim,
        dim_out,
        *,
        context_dim
    ):
        super().__init__()
        # 定义权重参数,使用随机初始化
        self.weights = nn.Parameter(torch.randn(dim, dim_out))
        # 定义上下文投影层,使用线性变换
        self.contextual_projection = nn.Linear(context_dim, dim * dim_out)

    # 前向传播函数,接受输入 x 和上下文 context
    def forward(self, x, context):
        # 推导上下文门控,参考超网格论文
        gating = self.contextual_projection(context).sigmoid()
        gating = rearrange(gating, 'b (i o) -> b i o', i = int(math.sqrt(gating.shape[-1])))
        
        # 门控交互投影与上下文
        to_logits_w = rearrange(self.weights, 'i o -> 1 i o') * gating
        return einsum('b n d, b d e -> b n e', x, to_logits_w)
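
HypergridLinear 用上下文向量生成一个门控矩阵来调制投影权重;由于门控的 reshape 依赖 sqrt,下面的示意假设 dim == dim_out(与模型中 latent_heads -> latent_heads 的用法一致):

hyper = HypergridLinear(32, 32, context_dim = 128)

x = torch.randn(2, 100, 32)        # (batch, seq, dim) 的交互特征
context = torch.randn(2, 128)      # (batch, context_dim) 的上下文嵌入

out = hyper(x, context)
print(out.shape)                   # torch.Size([2, 100, 32])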

# 定义一个名为 FILIP 的类,继承自 nn.Module
class FILIP(nn.Module):
    # 初始化函数,接受输入维度 dim、上下文维度 context_dim、头数 heads、头维度 dim_head、dropout 概率
    def __init__(
        self,
        dim,
        context_dim,
        heads,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        inner_latent_dim = heads * dim_head

        # 定义转换到潜在空间的权重和偏置
        self.to_latent_w = nn.Parameter(torch.randn(dim, inner_latent_dim))
        self.to_latent_b = nn.Parameter(torch.randn(inner_latent_dim))

        self.pre_attn_dropout = dropout

        # 定义空上下文和上下文到潜在空间的权重和偏置
        self.null_context = nn.Parameter(torch.randn(heads, dim_head))
        self.context_to_latent_w = nn.Parameter(torch.randn(context_dim, inner_latent_dim))
        self.context_to_latent_b = nn.Parameter(torch.randn(inner_latent_dim))

    # 前向传播函数,接受输入 x、上下文 context 和上下文掩码 context_mask
    def forward(
        self,
        x,
        context,
        context_mask = None
    ):
        b, heads, device = x.shape[0], self.heads, x.device

        x = einsum('b n d, d e -> b n e', x, self.to_latent_w)
        x = x + self.to_latent_b

        x = rearrange(x, 'b n (h d) -> b h n d', h = heads)

        context = einsum('b n d, d e -> b n e', context, self.context_to_latent_w)
        context = context + self.context_to_latent_b

        context = rearrange(context, 'b n (h d) -> b h n d', h = heads)

        context, x = map(l2norm, (context, x))

        # DNA 和蛋白质序列之间的细粒度交互,参考 FILIP 论文
        if x.shape[0] == 1:
            x = rearrange(x, '1 ... -> ...')
            einsum_eq = 'h i d, b h j d -> b h i j'
        else:
            einsum_eq = 'b h i d, b h j d -> b h i j'

        # 如果上下文掩码不存在,则创建一个全为 True 的掩码
        if not exists(context_mask):
            context_mask = torch.ones((b, context.shape[-2]), device = device).bool()

        # 根据 dropout 概率生成掩码
        if self.training:
            keep_mask = prob_mask_like(context_mask, 1 - self.pre_attn_dropout)
            context_mask = context_mask & keep_mask

        # 添加空上下文并修改掩码
        context_mask = F.pad(context_mask, (1, 0), value = True)
        context_mask = rearrange(context_mask, 'b j -> b 1 1 j')

        null_context = repeat(self.null_context, 'h d -> b h 1 d', b = b)
        context = torch.cat((null_context, context), dim = -2)

        # 可微分最大化,参考 FILIP 论文
        interactions = einsum(einsum_eq, x, context)
        interactions = logavgexp(interactions, mask = context_mask, dim = -1, temp = 0.05)
        interactions = rearrange(interactions, 'b h i -> b i h')
        return interactions
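
FILIP 模块输出 DNA 序列每个位置与蛋白质上下文之间逐头的交互分数,下面给出一个维度示意(未显式传入 context_mask,输入均为随机张量):

filip = FILIP(dim = 128, context_dim = 64, heads = 8, dim_head = 32)

x = torch.randn(2, 100, 128)        # 基因组序列特征 (batch, n, dim)
context = torch.randn(2, 50, 64)    # 蛋白质序列特征 (batch, m, context_dim)

interactions = filip(x, context)
print(interactions.shape)           # torch.Size([2, 100, 8]) -> (batch, n, heads)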

# 定义一个名为 AdapterModel 的类,继承自 nn.Module
class AdapterModel(nn.Module):
    # 初始化函数,设置模型的各种参数
    def __init__(
        self,
        *,
        enformer,  # enformer 模型
        latent_dim = 64,  # 潜在维度,默认为 64
        latent_heads = 32,  # 潜在头数,默认为 32
        aa_embed_dim = None,  # 氨基酸嵌入维度,默认为 None
        aa_embed_encoder = 'esm',  # 氨基酸嵌入编码器,默认为 'esm'
        contextual_embed_dim = None,  # 上下文嵌入维度,默认为 None
        use_aa_embeds = False,  # 是否使用氨基酸嵌入,默认为 False
        use_free_text_context = False,  # 是否使用自由文本上下文,默认为 False
        free_text_context_encoder = 'pubmed',  # 自由文本上下文编码器,默认为 'pubmed'
        free_text_embed_method = 'cls',  # 自由文本嵌入方法,默认为 'cls'
        dropout = 0.,  # 丢弃率,默认为 0
        binary_target = False,  # 是否为二进制目标,默认为 False
        target_mse_loss = False,  # 是否使用均方误差损失,默认为 False
        aux_read_value_loss = False,  # 是否使用辅助读值损失,默认为 False
        read_value_aux_loss_weight = 0.05,  # 读值辅助损失权重,默认为 0.05
        joint_cross_attn_depth = 1,  # 联合交叉注意力深度,默认为 1
        genome_self_attn_depth = 0,  # 基因组自注意力深度,默认为 0
        fourier_dims = 256,  # 傅立叶维度,默认为 256
        condition_squeeze_excite = False,  # 是否条件挤压激活,默认为 False
        condition_film = False,  # 是否条件 FILM,默认为 False
        condition_hypergrid = True,  # 是否条件超网格,默认为 True
        use_corr_coef_loss = False,  # 是否使用相关系数损失,默认为 False
        finetune_output_heads = None,  # 微调输出头,默认为 None
        **kwargs  # 其他参数
        ):
            # 调用父类的构造函数
            super().__init__()
            # 断言 enformer 是 Enformer 的实例
            assert isinstance(enformer, Enformer), 'enformer must be an instance of Enformer'
            # 设置 self.enformer 为传入的 enformer
            self.enformer = enformer
            # 计算 enformer_dim 为 enformer.dim 的两倍
            enformer_dim = enformer.dim * 2

            # 如果 finetune_output_heads 存在,则为 enformer 添加头部
            if exists(finetune_output_heads):
                self.enformer.add_heads(**finetune_output_heads)

            # 初始化 norm_seq_embed 为 LayerNorm 层,输入维度为 enformer_dim
            self.norm_seq_embed = nn.LayerNorm(enformer_dim)

            # 上下文嵌入相关变量

            # 断言 free_text_embed_method 只能是 'cls' 或 'mean_pool'
            assert free_text_embed_method in {'cls', 'mean_pool'}, 'must be either cls or mean_pool'
            # 设置 self.free_text_embed_method 为传入的 free_text_embed_method
            self.free_text_embed_method = free_text_embed_method
            # 设置 self.use_free_text_context 为传入的 use_free_text_context
            self.use_free_text_context = use_free_text_context

            if use_free_text_context:
                # 如果使用自由文本上下文,则计算上下文嵌入维度
                contextual_embed_dim = get_contextual_dim(free_text_context_encoder)
            else:
                # 否则,断言必须给出上下文嵌入维度
                assert exists(contextual_embed_dim), 'contextual embedding dimension must be given if not using transformer encoder'

            # 蛋白质嵌入相关变量

            # 设置 self.use_aa_embeds 为传入的 use_aa_embeds
            self.use_aa_embeds = use_aa_embeds
            # 获取蛋白质嵌入器的配置
            self.aa_embed_config = get_protein_embedder(aa_embed_encoder)
            # 获取蛋白质嵌入函数
            self.get_aa_embed = self.aa_embed_config['fn']

            if use_aa_embeds:
                # 如果使用蛋白质嵌入,则设置 aa_embed_dim 为蛋白质嵌入维度
                aa_embed_dim = self.aa_embed_config['dim']
            else:
                # 否则,断言必须设置 AA 嵌入维度
                assert exists(aa_embed_dim), 'AA embedding dimensions must be set if not using ESM'

            # 条件

            self.cond_genetic = None
            self.cond_protein = None

            if condition_squeeze_excite or condition_film:
                # 根据条件选择 SqueezeExcitation 或 FiLM 类
                condition_klass = SqueezeExcitation if condition_squeeze_excite else FiLM

                # 如果需要条件激活,则为 genetic 和 protein 设置条件
                self.cond_genetic  = condition_klass(contextual_embed_dim, enformer_dim)
                self.cond_protein  = condition_klass(contextual_embed_dim, aa_embed_dim)

            # 基因组自注意力

            # 初始化 genome_self_attns 为空的 ModuleList
            self.genome_self_attns = nn.ModuleList([])

            for _ in range(genome_self_attn_depth):
                # 循环创建 SelfAttentionBlock,并添加到 genome_self_attns 中
                attn = SelfAttentionBlock(
                    dim = enformer_dim,
                    dropout = dropout
                )
                self.genome_self_attns.append(attn)

            # 联合注意力

            # 初始化 joint_cross_attns 为空的 ModuleList
            self.joint_cross_attns = nn.ModuleList([])

            for _ in range(joint_cross_attn_depth):
                # 循环创建 JointCrossAttentionBlock,并添加到 joint_cross_attns 中
                attn = JointCrossAttentionBlock(
                    dim = enformer_dim,
                    context_dim = aa_embed_dim,
                    dropout = dropout
                )

                self.joint_cross_attns.append(attn)

            # 潜变量

            # 初始化 filip 为 FILIP 模块
            self.filip = FILIP(
                dim = enformer_dim,
                context_dim = aa_embed_dim,
                dim_head = latent_dim,
                heads = latent_heads,
                dropout = dropout
            )

            # 超网格条件

            if condition_hypergrid:
                # 如果需要超网格条件,则初始化 linear_with_hypergrid 为 HypergridLinear
                self.linear_with_hypergrid = HypergridLinear(latent_heads, latent_heads, context_dim = contextual_embed_dim)
            else:
                # 否则,初始化 linear_to_logits 为 Linear 层
                self.linear_to_logits = nn.Linear(latent_heads, latent_heads)

            # 到预测

            # 设置 binary_target 和 aux_read_value_loss 为传入的值
            self.binary_target = binary_target
            self.aux_read_value_loss = aux_read_value_loss
            self.read_value_aux_loss_weight = read_value_aux_loss_weight

            if binary_target:
                # 如果是二进制目标,则设置损失函数为二进制交叉熵或均方误差
                self.loss_fn = F.binary_cross_entropy_with_logits if not target_mse_loss else F.mse_loss

                # 设置 to_pred 为 Sequential 模块,用于预测
                self.to_pred = nn.Sequential(
                    Reduce('... n d -> ... d', 'mean'),
                    nn.LayerNorm(latent_heads),
                    nn.Linear(latent_heads, 1),
                    Rearrange('... 1 -> ...')
                )

                # 设置 to_read_value_aux_loss 为 ReadValueMLP 模块
                self.to_read_value_aux_loss = ReadValueMLP(
                    dim = latent_heads,
                    fourier_dims = fourier_dims
                )

            else:
                # 如果不是二进制目标,则设置损失函数为泊松损失或相关系数损失
                self.loss_fn = poisson_loss if not use_corr_coef_loss else corr_coef_loss

                # 设置 to_pred 为 Sequential 模块,用于预测
                self.to_pred = nn.Sequential(
                    nn.Linear(latent_heads, 1),
                    Rearrange('... 1 -> ...'),
                    nn.Softplus()
                )
    # 合并主要损失和辅助损失,如果不需要辅助损失则返回主要损失
    def combine_losses(self, loss, aux_loss):
        if not self.aux_read_value_loss:
            return loss

        return loss + self.read_value_aux_loss_weight * aux_loss

    # 前向传播函数,用于处理 Enformer 模型的头部
    def forward_enformer_head(
        self,
        seq_embed,
        *,
        head,
        target = None,
        return_corr_coef = False
    ):
        # 检查是否开启二进制目标训练,如果是则无法在轨道上微调
        assert not self.binary_target, 'cannot finetune on tracks if binary_target training is turned on'

        # 解冻 Enformer 模型的所有层
        unfreeze_all_layers_(self.enformer._heads)

        # 检查指定的头部是否存在于 Enformer 模型中
        assert head in self.enformer._heads, f'{head} head not found in enformer'

        # 使用指定的头部对序列嵌入进行预测
        pred = self.enformer._heads[head](seq_embed)

        # 如果没有提供目标数据,则直接返回预测结果
        if not exists(target):
            return pred

        # 检查预测结果和目标数据的维度是否匹配
        assert pred.shape[-1] == target.shape[-1], f'{head} head on enformer produced {pred.shape[-1]} tracks, but the supplied target only has {target.shape[-1]}'

        # 如果提供了目标数据并且需要返回相关系数,则计算并返回相关系数
        if exists(target) and return_corr_coef:
            return pearson_corr_coef(pred, target)

        # 计算并返回损失函数的结果
        return self.loss_fn(pred, target)

    # 前向传播函数,用于处理多个输入和参数的情况
    def forward(
        self,
        seq,
        *,
        aa = None,
        aa_embed = None,
        contextual_embed = None,
        contextual_free_text = None,
        aa_mask = None,
        target = None,
        read_value = None,
        peaks_nr = None,
        return_corr_coef = False,
        finetune_enformer = False,
        finetune_enformer_ln_only = False,
        unfreeze_enformer_last_n_layers = 0,
        head = None

.\lucidrains\tf-bind-transformer\tf_bind_transformer\training_utils.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 tf_bind_transformer.optimizer 模块中导入 get_optimizer 函数
from tf_bind_transformer.optimizer import get_optimizer
# 从 tf_bind_transformer.data 模块中导入 read_bed, collate_dl_outputs, get_dataloader, remap_df_add_experiment_target_cell 函数
from tf_bind_transformer.data import read_bed, collate_dl_outputs, get_dataloader, remap_df_add_experiment_target_cell
# 从 tf_bind_transformer.data 模块中导入 RemapAllPeakDataset, NegativePeakDataset, ScopedNegativePeakDataset 类
from tf_bind_transformer.data import RemapAllPeakDataset, NegativePeakDataset, ScopedNegativePeakDataset

# 定义 exists 函数,用于判断值是否存在
def exists(val):
    return val is not None

# 定义 default 函数,用于返回默认值
def default(val, d):
    return val if exists(val) else d

# 定义 accum_log 函数,用于记录和累积梯度步骤中的值
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# 定义简单的 Trainer 类
class Trainer(nn.Module):
    def __init__(
        self,
        model,
        *,
        remap_bed_file,
        negative_bed_file,
        factor_fasta_folder,
        fasta_file,
        train_chromosome_ids,
        valid_chromosome_ids,
        batch_size,
        context_length,
        lr = 3e-4,
        wd = 0.1,
        validate_every = 250,
        grad_clip_norm = None,
        grad_accum_every = 1,
        held_out_targets = [],
        held_out_cell_types = [],
        exclude_targets = [],
        exclude_cell_types = [],
        shuffle = False,
        train_sample_frac = 1.,
        valid_sample_frac = 1.,
        remap_sample_frac = 1.,
        shift_aug_range = (-2, 2),
        rc_aug = False,
        experiments_json_path = None,
        read_value_aux_loss = False,
        checkpoint_filename = './checkpoint.pt',
        include_scoped_negs = False,
        scoped_negs_remap_bed_path = None,
        scoped_negs_path = None,
        scoped_negs_exts = '.bed.bool.npy',
        include_biotypes_metadata_in_context = False,
        biotypes_metadata_path = None,
        include_biotypes_metadata_columns = ['germ_layer', 'cellline_cat'],
        biotypes_metadata_delimiter = ' | ',
        balance_sampling_by_target = True,
        valid_balance_sampling_by_target = None,
    # 定义 forward 方法,用于前向传播
    def forward(
        self,
        finetune_enformer_ln_only = True,
        **kwargs
        ):
            # 获取当前的梯度累积步数
            grad_accum_every = self.grad_accum_every
            # 获取当前步数
            curr_step = int(self.steps.item())
            # 设置模型为训练模式
            self.model.train()

            # 初始化日志字典
            log = {}

            # 循环执行梯度累积步数次
            for _ in range(self.grad_accum_every):
                # 从数据加载器中获取数据
                dl_outputs = [next(self.dl), next(self.neg_dl)]

                # 如果包含了作用域负样本,则继续获取数据
                if self.include_scoped_negs:
                    dl_outputs.append(next(self.scoped_neg_dl))

                # 将数据整理成模型所需的格式
                seq, tf_aa, contextual_texts, peaks_nr, read_value, binary_target = collate_dl_outputs(*dl_outputs)
                seq, binary_target, read_value, peaks_nr = seq.cuda(), binary_target.cuda(), read_value.cuda(), peaks_nr.cuda()

                # 计算模型的损失
                loss, aux_loss = self.model(
                    seq,
                    target = binary_target,
                    aa = tf_aa,
                    contextual_free_text = contextual_texts,
                    finetune_enformer_ln_only = finetune_enformer_ln_only,
                    read_value = read_value,
                    peaks_nr = peaks_nr,
                    **kwargs
                )

                # 计算总损失
                total_loss = self.model.combine_losses(loss, aux_loss)

                # 更新日志
                log = accum_log(log, {
                    'loss': loss.item() / grad_accum_every,
                    'aux_loss': aux_loss.item() / grad_accum_every,
                    'total_loss': total_loss.item() / grad_accum_every
                })

                # 反向传播
                (total_loss / self.grad_accum_every).backward()

            # 打印当前步数的总损失
            print(f'{curr_step} loss: {log["total_loss"]}')

            # 如果设置了梯度裁剪阈值,则进行梯度裁剪
            if exists(self.grad_clip_norm):
                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_norm)

            # 更新优化器
            self.optim.step()
            self.optim.zero_grad()

            # 每隔一定步数进行验证
            if (curr_step % self.validate_every) == 0:
                # 设置模型为评估模式
                self.model.eval()

                # 循环执行梯度累积步数次验证
                for _ in range(self.grad_accum_every):
                    # 从验证数据加载器中获取数据
                    seq, tf_aa, contextual_texts, peaks_nr, read_value, binary_target = collate_dl_outputs(next(self.valid_dl), next(self.valid_neg_dl))
                    seq, binary_target = seq.cuda(), binary_target.cuda()

                    # 获取验证集的预测结果
                    valid_logits = self.model(
                        seq,
                        aa = tf_aa,
                        contextual_free_text = contextual_texts,
                    )

                    # 计算验证集的损失和准确率
                    valid_loss = self.model.loss_fn(valid_logits, binary_target.float())
                    valid_accuracy = ((valid_logits.sigmoid() > 0.5).int() == binary_target).sum() / (binary_target.numel())

                    # 更新日志
                    log = accum_log(log, {
                        'valid_loss': valid_loss.item() / grad_accum_every,
                        'valid_accuracy': valid_accuracy.item() / grad_accum_every
                    })

                # 打印验证集的损失和准确率
                print(f'{curr_step} valid loss: {log["valid_loss"]}')
                print(f'{curr_step} valid accuracy: {log["valid_accuracy"]}')

                # 如果当前步数大于0,则保存模型参数
                if curr_step > 0:
                    torch.save(self.model.state_dict(), self.checkpoint_filename)

            # 更新步数
            self.steps += 1
            # 返回日志
            return log

.\lucidrains\tf-bind-transformer\tf_bind_transformer\training_utils_bigwig.py

import torch
from torch import nn
from tf_bind_transformer.optimizer import get_optimizer
from tf_bind_transformer.data_bigwig import BigWigDataset, BigWigTracksOnlyDataset, get_bigwig_dataloader, get_bigwig_tracks_dataloader
from enformer_pytorch.modeling_enformer import poisson_loss, pearson_corr_coef

def exists(val):
    # 检查值是否存在
    return val is not None

def default(val, d):
    # 如果值存在则返回该值,否则返回默认值
    return val if exists(val) else d

# helpers for logging and accumulating values across gradient steps

def accum_log(log, new_logs):
    # 累积日志中的值
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# simple Trainer class

class BigWigTrainer(nn.Module):
    def __init__(
        self,
        model,
        *,
        human_factor_fasta_folder,
        annot_file_path,
        human_loci_path,
        mouse_loci_path,
        human_fasta_file,
        mouse_fasta_file,
        batch_size,
        bigwig_tracks_only_folder_path = None,
        bigwig_folder_path = None,
        train_chromosome_ids = None,
        valid_chromosome_ids = None,
        mouse_factor_fasta_folder = None,
        downsample_factor = 128,
        target_length = 896,
        lr = 3e-4,
        wd = 0.1,
        validate_every = 250,
        grad_clip_norm = None,
        grad_accum_every = 1,
        held_out_targets_human = [],
        held_out_targets_mouse = [],
        held_out_cell_types_human = [],
        held_out_cell_types_mouse = [],
        context_length = 4096,
        shuffle = False,
        shift_aug_range = (-2, 2),
        rc_aug = False,
        checkpoint_filename = './checkpoint.pt',
        include_biotypes_metadata_in_context = False,
        biotypes_metadata_path = None,
        include_biotypes_metadata_columns = ['germ_layer', 'cellline_cat'],
        biotypes_metadata_delimiter = ' | ',
        bigwig_reduction_type = 'sum',
        enformer_train_valid_split = True
    def forward(
        self,
        finetune_enformer_ln_only = True,
        **kwargs

.\lucidrains\tf-bind-transformer\tf_bind_transformer\__init__.py

# 从 tf_bind_transformer 库中导入 AdapterModel 类
from tf_bind_transformer.tf_bind_transformer import AdapterModel
# 从 tf_bind_transformer 库中导入 Trainer 类
from tf_bind_transformer.training_utils import Trainer
# 从 tf_bind_transformer 库中导入 BigWigTrainer 类
from tf_bind_transformer.training_utils_bigwig import BigWigTrainer

TimeSformer - Pytorch

Implementation of TimeSformer, from Facebook AI. A pure and simple attention-based solution for reaching SOTA on video classification. This repository will only house the best performing variant, 'Divided Space-Time Attention', which is nothing more than attention along the time axis before the spatial.

Press release

Install

$ pip install timesformer-pytorch

Usage

import torch
from timesformer_pytorch import TimeSformer

model = TimeSformer(
    dim = 512,
    image_size = 224,
    patch_size = 16,
    num_frames = 8,
    num_classes = 10,
    depth = 12,
    heads = 8,
    dim_head =  64,
    attn_dropout = 0.1,
    ff_dropout = 0.1
)

video = torch.randn(2, 8, 3, 224, 224) # (batch x frames x channels x height x width)
mask = torch.ones(2, 8).bool() # (batch x frame) - use a mask if there are variable length videos in the same batch

pred = model(video, mask = mask) # (2, 10)

Citations

@misc{bertasius2021spacetime,
    title   = {Is Space-Time Attention All You Need for Video Understanding?}, 
    author  = {Gedas Bertasius and Heng Wang and Lorenzo Torresani},
    year    = {2021},
    eprint  = {2102.05095},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@article{tokshift2021,
    title   = {Token Shift Transformer for Video Classification},
    author  = {Hao Zhang and Yanbin Hao and Chong-Wah Ngo},
    journal = {ACM Multimedia 2021},
}

.\lucidrains\TimeSformer-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'timesformer-pytorch',  # 包的名称
  packages = find_packages(),  # 查找并包含所有包
  version = '0.4.1',  # 版本号
  license='MIT',  # 许可证
  description = 'TimeSformer - Pytorch',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/TimeSformer-pytorch',  # 项目链接
  keywords = [  # 关键词列表
    'artificial intelligence',
    'attention mechanism',
    'transformers',
    'video classification',
  ],
  install_requires=[  # 安装依赖
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[  # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\TimeSformer-pytorch\timesformer_pytorch\rotary.py

# 从 math 模块中导入 log 和 pi 函数
# 从 torch 模块中导入 nn, einsum 和 F
# 从 einops 模块中导入 rearrange 和 repeat 函数
from math import log, pi
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat

# 定义函数,用于将输入张量中的每两个元素进行旋转
def rotate_every_two(x):
    # 重新排列输入张量的维度,将每两个元素组成一组
    x = rearrange(x, '... (d j) -> ... d j', j = 2)
    # 将每组中的两个元素拆分为两个张量
    x1, x2 = x.unbind(dim = -1)
    # 对每组中的两个元素进行旋转操作
    x = torch.stack((-x2, x1), dim = -1)
    # 重新排列张量的维度,恢复原始形状
    return rearrange(x, '... d j -> ... (d j)')

# 定义函数,应用旋转嵌入到查询和键中
def apply_rot_emb(q, k, rot_emb):
    # 解包旋转嵌入
    sin, cos = rot_emb
    # 获取旋转维度的大小
    rot_dim = sin.shape[-1]
    # 将查询和键张量分为旋转部分和非旋转部分
    (q, q_pass), (k, k_pass) = map(lambda t: (t[..., :rot_dim], t[..., rot_dim:]), (q, k))
    # 对查询和键张量的旋转部分进行旋转操作
    q, k = map(lambda t: t * cos + rotate_every_two(t) * sin, (q, k))
    # 将旋转后的查询和键张量与非旋转部分拼接
    q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
    return q, k

# 定义类,实现轴向旋转嵌入
class AxialRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_freq = 10):
        super().__init__()
        self.dim = dim
        # 计算频率范围
        scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
        # 将频率范围作为缓冲区存储
        self.register_buffer('scales', scales)

    def forward(self, h, w, device):
        # 重新排列频率范围的维度
        scales = rearrange(self.scales, '... -> () ...')
        # 将频率范围移动到指定设备
        scales = scales.to(device)

        # 生成高度序列
        h_seq = torch.linspace(-1., 1., steps = h, device = device)
        h_seq = h_seq.unsqueeze(-1)

        # 生成宽度序列
        w_seq = torch.linspace(-1., 1., steps = w, device = device)
        w_seq = w_seq.unsqueeze(-1)

        # 对高度和宽度序列应用频率范围和 pi
        h_seq = h_seq * scales * pi
        w_seq = w_seq * scales * pi

        # 生成正弦序列
        x_sinu = repeat(h_seq, 'i d -> i j d', j = w)
        y_sinu = repeat(w_seq, 'j d -> i j d', i = h)

        # 拼接正弦和余弦序列
        sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
        cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)

        # 重新排列正弦和余弦序列的维度
        sin, cos = map(lambda t: rearrange(t, 'i j d -> (i j) d'), (sin, cos))
        sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos))
        return sin, cos

# 定义类,实现旋转嵌入
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # 计算频率的倒数
        inv_freqs = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # 将频率的倒数作为缓冲区存储
        self.register_buffer('inv_freqs', inv_freqs)

    def forward(self, n, device):
        # 生成序列
        seq = torch.arange(n, device = device)
        # 计算频率
        freqs = einsum('i, j -> i j', seq, self.inv_freqs)
        freqs = torch.cat((freqs, freqs), dim = -1)
        freqs = rearrange(freqs, 'n d -> () n d')
        return freqs.sin(), freqs.cos()
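
一个示意性的小例子:RotaryEmbedding 为长度 n 的序列生成形状为 (1, n, dim) 的正弦/余弦表,随后用 apply_rot_emb 注入到查询与键中:

rot = RotaryEmbedding(dim = 64)
sin, cos = rot(16, device = torch.device('cpu'))
print(sin.shape, cos.shape)    # torch.Size([1, 16, 64]) torch.Size([1, 16, 64])

q = torch.randn(8, 16, 64)     # (batch * heads, n, dim_head)
k = torch.randn(8, 16, 64)
q, k = apply_rot_emb(q, k, (sin, cos))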

.\lucidrains\TimeSformer-pytorch\timesformer_pytorch\timesformer_pytorch.py

import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat

from timesformer_pytorch.rotary import apply_rot_emb, AxialRotaryEmbedding, RotaryEmbedding

# 导入所需的库

# helpers

def exists(val):
    return val is not None

# 定义一个辅助函数,用于检查变量是否存在

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)

# 定义一个预正则化层,包含一个 LayerNorm 层和一个传入的函数

# time token shift

def shift(t, amt):
    if amt == 0:
        return t
    return F.pad(t, (0, 0, 0, 0, amt, -amt))

# 定义一个函数,用于在时间维度上进行平移

class PreTokenShift(nn.Module):
    def __init__(self, frames, fn):
        super().__init__()
        self.frames = frames
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        f, dim = self.frames, x.shape[-1]
        cls_x, x = x[:, :1], x[:, 1:]
        x = rearrange(x, 'b (f n) d -> b f n d', f = f)

        # shift along time frame before and after

        dim_chunk = (dim // 3)
        chunks = x.split(dim_chunk, dim = -1)
        chunks_to_shift, rest = chunks[:3], chunks[3:]
        shifted_chunks = tuple(map(lambda args: shift(*args), zip(chunks_to_shift, (-1, 0, 1))))
        x = torch.cat((*shifted_chunks, *rest), dim = -1)

        x = rearrange(x, 'b f n d -> b (f n) d')
        x = torch.cat((cls_x, x), dim = 1)
        return self.fn(x, *args, **kwargs)

# 定义一个预 Token 平移层,用于在时间维度上进行平移操作

# feedforward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

# 定义一个 GEGLU 激活函数

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

# 定义一个前馈神经网络层,包含线性层、GEGLU激活函数和线性层

# attention

def attn(q, k, v, mask = None):
    sim = einsum('b i d, b j d -> b i j', q, k)

    if exists(mask):
        max_neg_value = -torch.finfo(sim.dtype).max
        sim.masked_fill_(~mask, max_neg_value)

    attn = sim.softmax(dim = -1)
    out = einsum('b i j, b j d -> b i d', attn, v)
    return out

# 定义一个注意力机制函数,计算注意力权重并应用到值上

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

# 定义一个注意力层,包含线性层用于计算查询、键、值,以及输出线性层和 Dropout
    # 定义一个前向传播函数,接受输入 x,从 einops_from 重排到 einops_to,可选参数 mask 用于掩码,cls_mask 用于分类掩码,rot_emb 用于旋转嵌入,**einops_dims 用于指定维度
    def forward(self, x, einops_from, einops_to, mask = None, cls_mask = None, rot_emb = None, **einops_dims):
        # 获取头数
        h = self.heads
        # 将输入 x 分解为查询、键、值
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        # 将查询、键、值重排为 (b h) n d 的形式
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        # 对查询进行缩放
        q = q * self.scale

        # 分离出索引为 1 的分类令牌
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k, v))

        # 让分类令牌关注所有时间和空间的补丁的键/值
        cls_out = attn(cls_q, k, v, mask = cls_mask)

        # 根据给定的 einops_from 和 einops_to 重排时间或空间
        q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))

        # 如果存在旋转嵌入,则应用旋转嵌入
        if exists(rot_emb):
            q_, k_ = apply_rot_emb(q_, k_, rot_emb)

        # 将分类令牌的键和值在时间或空间上扩展并连接
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r = r), (cls_k, cls_v))

        k_ = torch.cat((cls_k, k_), dim = 1)
        v_ = torch.cat((cls_v, v_), dim = 1)

        # 注意力机制
        out = attn(q_, k_, v_, mask = mask)

        # 将时间或空间合并回原始形状
        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

        # 将分类令牌连接回输出
        out = torch.cat((cls_out, out), dim = 1)

        # 将头部合并回输出
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)

        # 合并头部输出
        return self.to_out(out)
# 主要类

class TimeSformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_frames,
        num_classes,
        image_size = 224,
        patch_size = 16,
        channels = 3,
        depth = 12,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        ff_dropout = 0.,
        rotary_emb = True,
        shift_tokens = False
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_size // patch_size) ** 2
        num_positions = num_frames * num_patches
        patch_dim = channels * patch_size ** 2

        self.heads = heads
        self.patch_size = patch_size
        self.to_patch_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, dim))

        self.use_rotary_emb = rotary_emb
        if rotary_emb:
            self.frame_rot_emb = RotaryEmbedding(dim_head)
            self.image_rot_emb = AxialRotaryEmbedding(dim_head)
        else:
            self.pos_emb = nn.Embedding(num_positions + 1, dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            ff = FeedForward(dim, dropout = ff_dropout)
            time_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
            spatial_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)

            if shift_tokens:
                time_attn, spatial_attn, ff = map(lambda t: PreTokenShift(num_frames, t), (time_attn, spatial_attn, ff))

            time_attn, spatial_attn, ff = map(lambda t: PreNorm(dim, t), (time_attn, spatial_attn, ff))

            self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff]))

        self.to_out = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, video, mask = None):
        b, f, _, h, w, *_, device, p = *video.shape, video.device, self.patch_size
        assert h % p == 0 and w % p == 0, f'height {h} and width {w} of video must be divisible by the patch size {p}'

        # 计算高度和宽度维度中的补丁数量,以及总补丁数(n)

        hp, wp = (h // p), (w // p)
        n = hp * wp

        # 视频转换为补丁嵌入

        video = rearrange(video, 'b f c (h p1) (w p2) -> b (f h w) (p1 p2 c)', p1 = p, p2 = p)
        tokens = self.to_patch_embedding(video)

        # 添加类别标记

        cls_token = repeat(self.cls_token, 'n d -> b n d', b = b)
        x =  torch.cat((cls_token, tokens), dim = 1)

        # 位置嵌入

        frame_pos_emb = None
        image_pos_emb = None
        if not self.use_rotary_emb:
            x += self.pos_emb(torch.arange(x.shape[1], device = device))
        else:
            frame_pos_emb = self.frame_rot_emb(f, device = device)
            image_pos_emb = self.image_rot_emb(hp, wp, device = device)

        # 计算不同帧数的掩码

        frame_mask = None
        cls_attn_mask = None
        if exists(mask):
            mask_with_cls = F.pad(mask, (1, 0), value = True)

            frame_mask = repeat(mask_with_cls, 'b f -> (b h n) () f', n = n, h = self.heads)

            cls_attn_mask = repeat(mask, 'b f -> (b h) () (f n)', n = n, h = self.heads)
            cls_attn_mask = F.pad(cls_attn_mask, (1, 0), value = True)

        # 时间和空间注意力

        for (time_attn, spatial_attn, ff) in self.layers:
            x = time_attn(x, 'b (f n) d', '(b n) f d', n = n, mask = frame_mask, cls_mask = cls_attn_mask, rot_emb = frame_pos_emb) + x
            x = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f, cls_mask = cls_attn_mask, rot_emb = image_pos_emb) + x
            x = ff(x) + x

        cls_token = x[:, 0]
        return self.to_out(cls_token)

.\lucidrains\TimeSformer-pytorch\timesformer_pytorch\__init__.py

# 从 timesformer_pytorch.timesformer_pytorch 模块中导入 TimeSformer 类
from timesformer_pytorch.timesformer_pytorch import TimeSformer

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

Token Shift GPT

Implementation of Token Shift GPT - An autoregressive model that relies solely on shifting along the sequence dimension and feedforwards.

Update: Inexplicably, it actually works quite well. The feedforward module follows the same design as gMLP, except the feature dimension of the gate tensor is divided up into log2(seq_len) chunks, and the mean pool of the past consecutive segments (length 1, 2, 4, 8, etc. into the past) are shifted into each chunk before a projection along the feature dimension.
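
To make the gating mechanism concrete, below is a minimal, deliberately unoptimized sketch of that chunked shift (illustrative only; the real implementation is the cumsum-based shift_tokens function shown later in token_shift_gpt.py):

import torch

def segment_mean_shift(gate, num_chunks):
    # gate: (batch, seq, dim). chunk i of the feature dimension is replaced by the mean of
    # a past segment of length 2**i, placed so that position t only ever sees tokens strictly
    # before it (causal); the final chunk passes through untouched
    b, n, d = gate.shape
    *chunks, passthrough = gate.chunk(num_chunks + 1, dim = -1)

    shifted = []
    for i, chunk in enumerate(chunks):
        w = 2 ** i
        out = torch.zeros_like(chunk)
        for t in range(n):
            end = t - w + 1            # segment covers positions [t - 2w + 1, t - w]
            if end <= 0:
                continue
            start = max(end - w, 0)
            out[:, t] = chunk[:, start:end].mean(dim = 1)
        shifted.append(out)

    return torch.cat((*shifted, passthrough), dim = -1)

gate = torch.randn(1, 16, 12)
out = segment_mean_shift(gate, num_chunks = 3)   # (1, 16, 12)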

Install

$ pip install token-shift-gpt

Usage

import torch
from token_shift_gpt import TokenShiftGPT

model = TokenShiftGPT(
    num_tokens = 256,
    dim = 512,
    max_seq_len = 1024,
    depth = 12,
    ff_mult = 8   # when working with small model dimensions, you may want to increase the intermediate feedforward dimension (here, 8x instead of the usual 4x), so the learning is not bottlenecked by the dimensions of the shifted chunk
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x) # (1, 1024, 256)

To use the discounted cumulative sum approach (which only uses one chunk and seems to be just as effective as the above), just set use_discounted_cumsum = True

First install an additional library

$ pip install torch-discounted-cumsum

Then

import torch
from token_shift_gpt import TokenShiftGPT

model = TokenShiftGPT(
    num_tokens = 256,
    dim = 512,
    max_seq_len = 1024,
    depth = 12,
    ff_mult = 8,
    use_discounted_cumsum = True,
    discounted_gamma = 0.9              # gamma factor for discount
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x) # (1, 1024, 256)
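
For reference, the discounted cumulative sum replaces the segment means with a single exponentially decayed summary of the entire past, y_t = x_t + gamma * y_{t-1}. A naive reference version is sketched below (the torch-discounted-cumsum package computes the same recurrence with a fast kernel):

import torch

def naive_discounted_cumsum(x, gamma):
    # x: (batch, seq, dim) - computes y_t = x_t + gamma * y_{t-1}
    out = torch.zeros_like(x)
    running = torch.zeros_like(x[:, 0])
    for t in range(x.shape[1]):
        running = x[:, t] + gamma * running
        out[:, t] = running
    return out

x = torch.ones(1, 4, 1)
print(naive_discounted_cumsum(x, 0.9).squeeze())   # tensor([1.0000, 1.9000, 2.7100, 3.4390])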

Citations

@misc{yu2021s2mlp,
    title   = {S$^2$-MLP: Spatial-Shift MLP Architecture for Vision}, 
    author  = {Tan Yu and Xu Li and Yunfeng Cai and Mingming Sun and Ping Li},
    year    = {2021},
    eprint  = {2106.07477},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@software{peng_bo_2021_5196578,
    author       = {PENG Bo},
    title        = {BlinkDL/RWKV-LM: 0.01},
    month        = {aug},
    year         = {2021},
    publisher    = {Zenodo},
    version      = {0.01},
    doi          = {10.5281/zenodo.5196578},
    url          = {https://doi.org/10.5281/zenodo.5196578}
}

.\lucidrains\token-shift-gpt\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'token-shift-gpt',  # 包的名称
  packages = find_packages(),  # 查找所有包
  version = '0.0.3',  # 版本号
  license='MIT',  # 许可证
  description = 'Token Shift GPT - Pytorch',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/token-shift-gpt',  # 项目链接
  keywords = [
    'artificial intelligence',  # 关键词
    'deep learning',  # 关键词
    'autoregressive language modeling'  # 关键词
  ],
  install_requires=[
    'einops>=0.3',  # 安装所需的依赖
    'torch>=1.6'  # 安装所需的依赖
  ],
  classifiers=[
    'Development Status :: 4 - Beta',  # 分类器
    'Intended Audience :: Developers',  # 分类器
    'Topic :: Scientific/Engineering :: Artificial Intelligence',  # 分类器
    'License :: OSI Approved :: MIT License',  # 分类器
    'Programming Language :: Python :: 3.6',  # 分类器
  ],
)

.\lucidrains\token-shift-gpt\token_shift_gpt\autoregressive_wrapper.py

import torch
from torch import nn
import torch.nn.functional as F

# 定义一个装饰器函数,用于在模型评估时切换为eval模式
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# 定义一个函数用于对logits进行top k过滤
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# 定义一个包装类,用于自回归模型
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.max_seq_len = net.seq_len

    # 生成函数,用于生成序列
    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        device = start_tokens.device
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        out = start_tokens

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]

            logits = self.net(x, **kwargs)[:, -1, :]

            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        return out

    # 前向传播函数,用于计算损失
    def forward(self, x, **kwargs):
        xi, xo = x[:, :-1], x[:, 1:]
        out = self.net(xi, **kwargs)
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss

.\lucidrains\token-shift-gpt\token_shift_gpt\token_shift_gpt.py

# 从 math 模块中导入 log2 和 ceil 函数
# 从 torch 模块中导入 nn, einsum 和 nn.functional 模块
from math import log2, ceil
import torch
from torch import nn, einsum
import torch.nn.functional as F

# 从 einops 模块中导入 rearrange 函数
from einops import rearrange

# 定义一个辅助函数,用于检查值是否存在
def exists(val):
    return val is not None

# 定义一个函数,用于在指定维度上对输入进行平移
def shift(x, amt, dim = -1):
    return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)
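
# example: shift pads the front of the chosen dimension and trims the end by the same amount
#   shift(torch.tensor([[1., 2., 3., 4.]]), 1, dim = -1) -> tensor([[0., 1., 2., 3.]])
# with dim = -2 the shift runs along the sequence dimension, which is how past-token
# information is moved into the current position in shift_tokens below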

# 定义一个函数,用于在 tokens 上进行平移
def shift_tokens(x, amt, eps = 1e-5):
    n, device = x.shape[1], x.device

    # 计算累积和
    cumsum = x.cumsum(dim = 1)
    *x, x_pass = x.chunk(amt + 1, dim = -1)
    *x_cumsum, _ = cumsum.chunk(amt + 1, dim = -1)

    # 计算平移量
    amts = 2 ** torch.arange(amt)
    amts = amts.tolist()

    shifts = []
    denom = torch.arange(n, device = device)

    for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
        # 计算平移后的值
        shifted_chunk = shift(x_cumsum_chunk, amt, dim = -2) - shift(x_cumsum_chunk, 2 * amt, dim = -2)
        shifted_denom = shift(denom, amt, dim = -1) - shift(denom, 2 * amt, dim = -1)
        shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
        normed_shifted_x = shifted_chunk /  (shifted_denom + eps)
        shifts.append(normed_shifted_x)

    return torch.cat((*shifts, x_pass), dim = -1)

# 定义一个函数,用于计算折扣累积和
def discounted_cumsum(t, gamma):
    try:
        from torch_discounted_cumsum import discounted_cumsum_left
    except ImportError:
        print('unable to import torch_discounted_cumsum - please run `pip install torch-discounted-cumsum`')
        raise

    b, n, d = t.shape
    t = rearrange(t, 'b n d -> (b d) n')
    t = discounted_cumsum_left(t, gamma)
    t = rearrange(t, '(b d) n -> b n d', b = b)
    return t

# 定义一个残差模块
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

# 定义一个前馈神经网络模块
class FeedForward(nn.Module):
    def __init__(
        self,
        *,
        dim,
        max_seq_len,
        num_shifts,
        mult = 4,
        eps = 1e-3,
        use_discounted_cumsum = False,
        discount_gamma = 0.9
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

        self.project_in = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU()
        )

        self.num_shifts = num_shifts
        hidden_dim = dim * mult // 2

        self.gate_norm = nn.LayerNorm(hidden_dim)
        self.to_gate = nn.Linear(hidden_dim, hidden_dim)

        nn.init.constant_(self.to_gate.weight, eps)
        nn.init.constant_(self.to_gate.bias, 1.)

        self.project_out = nn.Linear(hidden_dim, dim)

        # 用于使用折扣累积和方法

        self.use_discounted_cumsum = use_discounted_cumsum
        self.discount_gamma = discount_gamma

    def forward(self, x):
        x = self.norm(x)

        x = self.project_in(x)

        x, gate = x.chunk(2, dim = -1)

        gate = self.gate_norm(gate)

        if self.use_discounted_cumsum:
            gate = shift(gate, 1, dim = -2)
            gate = discounted_cumsum(gate, self.discount_gamma)
        else:
            gate = shift_tokens(gate, self.num_shifts)

        x = x * self.to_gate(gate)
        return self.project_out(x)

# 定义一个 TokenShiftGPT 模块
class TokenShiftGPT(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        max_seq_len,
        depth,
        ff_mult = 4,
        use_discounted_cumsum = False,
        discount_gamma = 0.9
    ):
        super().__init__()
        self.seq_len = max_seq_len
        num_shifts = ceil(log2(max_seq_len)) - 1

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.net = nn.Sequential(
            *[Residual(FeedForward(dim = dim, num_shifts = num_shifts, mult = ff_mult, max_seq_len = max_seq_len, use_discounted_cumsum = use_discounted_cumsum, discount_gamma = discount_gamma)) for _ in range(depth)],
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )
    # 定义一个前向传播函数,接收输入 x
    def forward(self, x):
        # 对输入 x 进行 token embedding
        x = self.token_emb(x)
        # 生成位置编码,长度为 x 的第二维度,设备为 x 所在的设备
        pos_emb = self.pos_emb(torch.arange(x.shape[1], device = x.device))
        # 将位置编码与 token embedding 相加,并重新排列维度
        x = x + rearrange(pos_emb, 'n d -> () n d')
        # 将处理后的输入 x 输入到神经网络中进行计算
        return self.net(x)

.\lucidrains\token-shift-gpt\token_shift_gpt\__init__.py

# 从 token_shift_gpt 包中导入 TokenShiftGPT 类
from token_shift_gpt.token_shift_gpt import TokenShiftGPT

.\lucidrains\token-shift-gpt\train.py

# 导入所需的模块
from token_shift_gpt import TokenShiftGPT
from token_shift_gpt.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# 定义常量
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 768
SEQ_LEN = 768

# 定义辅助函数

# 从 token 解码为字符
def decode_token(token):
    return str(chr(max(32, token)))

# 从 tokens 解码为字符串
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# 实例化类似 GPT 的解码器模型
model = TokenShiftGPT(
    num_tokens = 256,
    max_seq_len = SEQ_LEN,
    dim = 512,
    depth = 8,
    ff_mult = 8
)

model = AutoregressiveWrapper(model)
model.cuda()

# 准备 enwik8 数据

with gzip.open('./data/enwik8.gz') as file:
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# 定义数据集类
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# 定义一个无限循环产出批次的数据加载辅助函数
def cycle(loader):
    while True:
        for data in loader:
            yield data

# 创建训练集和验证集的 DataLoader
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# 定义优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# 训练模型
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

Toolformer - Pytorch (wip)

Implementation of Toolformer, Language Models That Can Use Tools, by MetaAI

Appreciation

  • Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research

  • Enrico for getting the ball rolling with the initial commit of different tools!

  • Thanks goes out to ChatGPT for doing all the regular expressions in this repository for parsing the functions and parameters for the API calls. I am terrible at regular expressions, so this was enormous help from the AI (with no hitches, it was perfect).

Install

$ pip install toolformer-pytorch

Usage

Example usage with giving language models awareness of current date and time.

import torch
from toolformer_pytorch import Toolformer, PaLM

# simple calendar api call - function that returns a string

def Calendar():
    import datetime
    from calendar import day_name, month_name
    now = datetime.datetime.now()
    return f'Today is {day_name[now.weekday()]}, {month_name[now.month]} {now.day}, {now.year}.'

# prompt for teaching it to use the Calendar function from above

prompt = f"""
Your task is to add calls to a Calendar API to a piece of text.
The API calls should help you get information required to complete the text.
You can call the API by writing "[Calendar()]"
Here are some examples of API calls:
Input: Today is the first Friday of the year.
Output: Today is the first [Calendar()] Friday of the year.
Input: The president of the United States is Joe Biden.
Output: The president of the United States is [Calendar()] Joe Biden.
Input: [input]
Output: 
"""

data = [
    "The store is never open on the weekend, so today it is closed.",
    "The number of days from now until Christmas is 30",
    "The current day of the week is Wednesday."
]

# model - here using PaLM, but any nn.Module that returns logits in the shape (batch, seq, num_tokens) is fine

model = PaLM(
    dim = 512,
    depth = 2,
    heads = 8,
    dim_head = 64
).cuda()

# toolformer

toolformer = Toolformer(
    model = model,
    model_seq_len = 256,
    teach_tool_prompt = prompt,
    tool_id = 'Calendar',
    tool = Calendar,
    finetune = True
)

# invoking this will
# (1) prompt the model with your inputs (data), inserted into [input] tag
# (2) with the sampled outputs, filter out the ones that made proper API calls
# (3) execute the API calls with the `tool` given
# (4) filter with the specialized filter function (which can be used independently as shown in the next section)
# (5) fine-tune on the filtered results

filtered_stats = toolformer(data)

# then, once you see the 'finetune complete' message

response = toolformer.sample_model_with_api_calls("How many days until the next new years?")

# hopefully you see it invoke the calendar and utilize the response of the api call...

The main novelty of the paper is defining a fitness score for the outputs from a transformer instructed to insert API calls. The score is used to filter the sampled outputs for finetuning the transformer to make API calls that decrease the perplexity of the text that follows them.
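
In pseudocode, the criterion implemented by filter_tokens_with_api_response (shown further down in toolformer_pytorch.py) boils down to comparing weighted losses over the text that follows the call:

# pseudocode for the filtering criterion
loss_plus  = loss(tokens_with_api_response)                        # call + its returned result
loss_minus = min(loss(tokens_without_api_response), loss(tokens))  # call alone, or no call at all

keep_for_finetuning = (loss_minus - loss_plus) >= filter_threshold

The snippet below exercises the actual function on mock token sequences.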

import torch

from toolformer_pytorch import (
    Toolformer,
    PaLM,
    filter_tokens_with_api_response
)

# model

palm = PaLM(
    dim = 512,
    num_tokens = 20000,
    depth = 2,
    heads = 8,
    dim_head = 64
).cuda()

# mock some tokens

mock_start_pos = 512
mock_api_call_length = 10
mock_api_start_id = 19998
mock_api_stop_id = 19999

tokens = torch.randint(0, 20000, (10, 1024)).cuda()
tokens_with_api_response = torch.randint(0, 20000, (10, 1024)).cuda()
tokens_without_api_response = torch.randint(0, 20000, (10, 1024)).cuda()

tokens_with_api_response[:, mock_start_pos] = mock_api_start_id
tokens_with_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id

tokens_without_api_response[:, mock_start_pos] = mock_api_start_id
tokens_without_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id

# filter

filtered_results = filter_tokens_with_api_response(
    model = palm,
    tokens = tokens,
    tokens_with_api_response = tokens_with_api_response,
    tokens_without_api_response = tokens_without_api_response,
    filter_threshold = 1.,
    api_start_token_id = mock_api_start_id,
    api_end_token_id = mock_api_stop_id
)

To invoke the tools on a string generated by the language model, use invoke_tools

from toolformer_pytorch import invoke_tools

def inc(i):
    return i + 1

def dec(i):
    return i - 1

function_registry = dict(
    inc = inc,
    dec = dec
)

text = 'make the following api calls: [inc(1)] and [dec(2)] and [ignored(3)]'

invoke_tools(function_registry, text)

# make the following api calls: [inc(1) → 2] and [dec(2) → 1] and [ignored(3)]

Todo

Citations

@inproceedings{Schick2023ToolformerLM,
    title   = {Toolformer: Language Models Can Teach Themselves to Use Tools},
    author  = {Timo Schick and Jane Dwivedi-Yu and Roberto Dessi and Roberta Raileanu and Maria Lomeli and Luke Zettlemoyer and Nicola Cancedda and Thomas Scialom},
    year    = {2023}
}
@article{Gao2022PALPL,
    title   = {PAL: Program-aided Language Models},
    author  = {Luyu Gao and Aman Madaan and Shuyan Zhou and Uri Alon and Pengfei Liu and Yiming Yang and Jamie Callan and Graham Neubig},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2211.10435}
}

Reality is that which, when you stop believing it, doesn't go away. – Philip K. Dick.

.\lucidrains\toolformer-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'toolformer-pytorch',  # 包的名称
  packages = find_packages(exclude=[]),  # 查找所有包
  version = '0.0.30',  # 版本号
  license='MIT',  # 许可证
  description = 'Toolformer - Pytorch',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  url = 'https://github.com/lucidrains/toolformer-pytorch',  # 项目链接
  keywords = [  # 关键词列表
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'automated-tool-use'
  ],
  install_requires=[  # 安装依赖
    'beartype',
    'einops>=0.4',
    'torch>=1.6',
    'tqdm',
    'x-clip>=0.14.3'
  ],
  classifiers=[  # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\toolformer-pytorch\toolformer_pytorch\optimizer.py

# 从 torch.optim 模块中导入 AdamW 和 Adam 优化器
from torch.optim import AdamW, Adam

# 将参数分为需要权重衰减和不需要权重衰减的两个列表
def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []
    for param in params:
        # 根据参数的维度判断是否需要权重衰减
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params

# 获取优化器
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    filter_by_requires_grad = False,
    group_wd_params = True,
    **kwargs
):
    # 判断是否有权重衰减
    has_weight_decay = wd > 0

    # 根据 filter_by_requires_grad 参数过滤出需要梯度的参数
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # 如果需要对参数进行分组并应用权重衰减
    if group_wd_params and has_weight_decay:
        wd_params, no_wd_params = separate_weight_decayable_params(params)

        # 将参数分为需要权重衰减和不需要权重衰减的两组
        params = [
            {'params': wd_params},
            {'params': no_wd_params, 'weight_decay': 0},
        ]

    # 设置 Adam 优化器的参数
    adam_kwargs = dict(lr = lr, betas = betas, eps = eps)

    # 如果不需要权重衰减,则返回 Adam 优化器
    if not has_weight_decay:
        return Adam(params, **adam_kwargs)

    # 如果需要权重衰减,则返回 AdamW 优化器
    return AdamW(params, weight_decay = wd, **adam_kwargs)
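
A hedged usage sketch of the helper above (importing directly from the module file shown in the header; whether it is also re-exported at package level is not shown here):

import torch.nn as nn
from toolformer_pytorch.optimizer import get_optimizer

model = nn.Sequential(nn.Linear(512, 512), nn.LayerNorm(512))

# parameters with ndim < 2 (biases, LayerNorm gains) land in the zero-weight-decay group
opt = get_optimizer(model.parameters(), lr = 1e-4, wd = 1e-2)
print(type(opt).__name__)   # AdamW, since wd > 0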

.\lucidrains\toolformer-pytorch\toolformer_pytorch\palm.py

import torch
from torch import nn, einsum
from einops import rearrange

from x_clip.tokenizer import tokenizer

# 导入所需的库

# helpers

# 定义一个辅助函数,用于检查值是否存在
def exists(val):
    return val is not None

# normalization

# 定义一个 RMS 归一化层
class RMSNorm(nn.Module):
    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
        return x / norm.clamp(min = self.eps) * self.g

# rotary positional embedding
# https://arxiv.org/abs/2104.09864

# 定义一个旋转位置嵌入层
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)

# 旋转半个位置
def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)

# 应用旋转位置嵌入
def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())
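
# shape note: RotaryEmbedding(dim_head) yields `pos` of shape (n, dim_head), which broadcasts
# against queries of shape (batch, heads, n, dim_head) and the single-head keys of shape
# (batch, n, dim_head) below, rotating pairs of features by a position-dependent angle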

# all we need

# 定义并行 Transformer 块
class ParallelTransformerBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = RMSNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim))
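        # note: queries get heads * dim_head features, while keys and values each get a
        # single dim_head - multi-query attention, so k / v are shared across all heads
        # (hence the einsum in forward below, where k and v carry no head axis)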

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            nn.GELU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # for caching causal mask and rotary embeddings

        # 注册缓存的因果掩码和旋转嵌入
        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)

    # 获取因果掩码
    def get_mask(self, n, device):
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    # 获取旋转嵌入
    def get_rotary_embedding(self, n, device):
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb
    # 定义前向传播函数,接受输入张量 x
    def forward(self, x):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # 获取输入张量 x 的形状信息
        n, device, h = x.shape[1], x.device, self.heads

        # 对输入张量 x 进行 LayerNorm 处理
        x = self.norm(x)

        # 使用融合的注意力和前馈神经网络投影层对输入张量 x 进行投影
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # 将投影后的张量按照指定维度进行分割,用于多头注意力
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # 获取旋转位置嵌入
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # 缩放
        q = q * self.scale

        # 计算相似度
        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # 获取因果掩码
        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # 注意力权重计算
        attn = sim.softmax(dim=-1)

        # 聚合值
        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # 合并多头
        out = rearrange(out, "b h n d -> b n (h d)")
        # 返回注意力输出和前馈网络输出的和
        return self.attn_out(out) + self.ff_out(ff)
# Transformer 类定义
class Transformer(nn.Module):
    # 初始化函数,接受维度、深度、头数、头维度和前馈网络倍数作为参数
    def __init__(
        self, 
        dim, 
        depth, 
        heads, 
        dim_head, 
        ff_mult = 4,
    ):
        super().__init__()
        # 初始化一个空的模块列表
        self.layers = nn.ModuleList([])

        # 循环创建指定深度的 ParallelTransformerBlock,并添加到模块列表中
        for _ in range(depth):
            self.layers.append(
                ParallelTransformerBlock(dim, dim_head, heads, ff_mult), 
            )

    # 前向传播函数
    def forward(self, x):
        # 遍历模块列表中的每个块,对输入进行变换并加上原始输入
        for block in self.layers:
            x = block(x) + x
        return x


# PaLM 类定义
class PaLM(nn.Module):
    # 初始化函数,接受维度、深度、标记数、头维度、头数和前馈网络倍数作为参数
    def __init__(
        self, 
        dim, 
        depth, 
        num_tokens=tokenizer.vocab_size,
        dim_head=64, 
        heads=8, 
        ff_mult=4,
    ):
        super().__init__()
        # 创建一个嵌入层,将标记映射到指定维度
        self.emb = nn.Embedding(num_tokens, dim)

        # 创建一个 Transformer 模型
        self.transformer = Transformer(dim, depth, heads, dim_head, ff_mult)

        # 创建一个输出层,包括 RMSNorm 层和线性层
        self.to_logits = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    # 前向传播函数
    def forward(self, x):
        # 将输入通过嵌入层映射到指定维度
        x = self.emb(x)
        # 将映射后的输入通过 Transformer 模型进行变换
        x = self.transformer(x)
        # 将变换后的结果通过输出层得到最终的 logits
        return self.to_logits(x)

# 主函数入口
if __name__ == "__main__":
    # 创建一个 PaLM 模型实例
    palm = PaLM(
        num_tokens = 20000,
        dim = 512,
        depth = 6,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
    )

    # 生成一个随机的标记序列
    tokens = torch.randint(0, 20000, (1, 512))
    # 将标记序列输入到 PaLM 模型中得到 logits
    logits = palm(tokens)
    # 打印 logits 的形状
    print(logits.shape)

.\lucidrains\toolformer-pytorch\toolformer_pytorch\prompts.py

# 默认提示输入标签
DEFAULT_PROMPT_INPUT_TAG = '[input]'

# 计算器提示信息
calculator_prompt = f"""
Your task is to add calls to a Calculator API to a piece of text.
The calls should help you get information required to complete the text. 
You can call the API by writing "[Calculator(expression)]" where "expression" is the expression to be computed. 
Here are some examples of API calls:
Input: The number in the next term is 18 + 12 x 3 = 54.
Output: The number in the next term is 18 + 12 x 3 = [Calculator(18 + 12 * 3)] 54.
Input: The population is 658,893 people. This is 11.4% of the national average of 5,763,868 people.
Output: The population is 658,893 people. This is 11.4% of the national average of [Calculator(658,893 / 11.4%)] 5,763,868 people.
Input: A total of 252 qualifying matches were played, and 723 goals were scored (an average of 2.87 per match). This is three times less than the 2169 goals last year.
Output: A total of 252 qualifying matches were played, and 723 goals were scored (an average of [Calculator(723 / 252)] 2.87 per match). This is twenty goals more than the [Calculator(723 - 20)] 703 goals last year.
Input: I went to Paris in 1994 and stayed there until 2011, so in total, it was 17 years.
Output: I went to Paris in 1994 and stayed there until 2011, so in total, it was [Calculator(2011 - 1994)] 17 years.
Input: From this, we have 4 * 30 minutes = 120 minutes.
Output: From this, we have 4 * 30 minutes = [Calculator(4 * 30)] 120 minutes.
Input: {DEFAULT_PROMPT_INPUT_TAG}
Output:
"""

# 维基百科搜索提示信息
wikipedia_search_prompt = f"""
Your task is to complete a given piece of text. 
You can use a Wikipedia Search API to look up information. 
You can do so by writing "[WikiSearch(term)]" where "term" is the search term you want to look up. 
Here are some examples of API calls:
Input: The colors on the flag of Ghana have the following meanings: red is for the blood of martyrs, green for forests, and gold for mineral wealth.
Output: The colors on the flag of Ghana have the following meanings: red is for [WikiSearch("Ghana flag red meaning")] the blood of martyrs, green for forests, and gold for mineral wealth.
Input: But what are the risks during production of nanomaterials? Some nanomaterials may give rise to various kinds of lung damage.
Output: But what are the risks during production of nanomaterials? [WikiSearch("nanomaterial production risks")] Some nanomaterials may give rise to various kinds of lung damage.
Input: Metformin is the first-line drug for patients with type 2 diabetes and obesity.
Output: Metformin is the first-line drug for [WikiSearch("Metformin first-line drug")] patients with type 2 diabetes and obesity.
Input: {DEFAULT_PROMPT_INPUT_TAG}
Output:
"""

# 机器翻译提示信息
machine_translation_prompt = f"""
Your task is to complete a given piece of text by using a Machine Translation API.
You can do so by writing "[MT(text)]" where text is the text to be translated into English.
Here are some examples:
Input: He has published one book: O homem suprimido (“The Supressed Man”)
Output: He has published one book: O homem suprimido [MT(O homem suprimido)] (“The Supressed Man”)
Input: In Morris de Jonge’s Jeschuah, der klassische jüdische Mann, there is a description of a Jewish writer
Output: In Morris de Jonge’s Jeschuah, der klassische jüdische Mann [MT(der klassische jüdische Mann)], there is a description of a Jewish writer
Input: 南 京 高 淳 县 住 房 和 城 乡 建 设 局 城 市 新 区 设 计 a plane of reference Gaochun is one of seven districts of the provincial capital Nanjing
Output: [MT(南京高淳县住房和城乡建设局 城市新 区 设 计)] a plane of reference Gaochun is one of seven districts of the provincial capital Nanjing
Input: {DEFAULT_PROMPT_INPUT_TAG}
Output:
"""

# 日历提示信息
calendar_prompt = f"""
Your task is to add calls to a Calendar API to a piece of text. 
The API calls should help you get information required to complete the text. 
You can call the API by writing "[Calendar()]" 
Here are some examples of API calls:
Input: Today is the first Friday of the year.
Output: Today is the first [Calendar()] Friday of the year.
Input: The president of the United States is Joe Biden.
Output: The president of the United States is [Calendar()] Joe Biden.
Input: The current day of the week is Wednesday.
Output: The current day of the week is [Calendar()] Wednesday.
Input: The number of days from now until Christmas is 30.
Output: The number of days from now until Christmas is [Calendar()] 30.
Input: The store is never open on the weekend, so today it is closed.
Output: The store is never open on the weekend, so today [Calendar()] it is closed.
Input: {DEFAULT_PROMPT_INPUT_TAG}
Output:
"""

.\lucidrains\toolformer-pytorch\toolformer_pytorch\toolformer_pytorch.py

# 导入所需的库
import re
import math

from functools import partial, wraps
from collections import namedtuple

import torch
from torch import nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from einops import rearrange, reduce

# 导入自定义模块
from toolformer_pytorch.palm import PaLM
from toolformer_pytorch.optimizer import get_optimizer
from toolformer_pytorch.prompts import DEFAULT_PROMPT_INPUT_TAG

# 导入类型提示相关库
from beartype import beartype
from beartype.typing import Callable, Optional, Union, List, Tuple

from tqdm import tqdm
from x_clip.tokenizer import tokenizer

# 设置 pad_sequence 函数的 batch_first 参数为 True
pad_sequence = partial(pad_sequence, batch_first = True)

# 辅助函数

# 判断值是否存在
def exists(val):
    return val is not None

# 返回默认值
def default(val, d):
    return val if exists(val) else d

# 返回输入值
def identity(t):
    return t

# 返回固定值的函数
def always(val):
    def inner(*args, **kwargs):
        return val
    return inner

# 尝试执行函数,捕获异常并执行回调函数
def try_except(fn, callback = identity):
    @wraps(fn)
    def inner(*args):
        try:
            return fn(*args)
        except Exception as e:
            return callback(e)
    return inner

# 张量操作函数

# 对数函数
def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

# 生成 Gumbel 噪声
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# 生成 Gumbel 分布采样
def gumbel_sample(t, temperature = 1., dim = -1, eps = 1e-10):
    if temperature == 0:
        return t.argmax(dim = dim)

    return ((t / max(temperature, eps)) + gumbel_noise(t)).argmax(dim = dim)

# 保留前 k 个最大值,其余设为负无穷
def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, indices = torch.topk(logits, k)
    probs = torch.full_like(logits, -torch.finfo(logits.dtype).max)
    probs.scatter_(1, indices, val)
    return probs

# 检查张量是否包含指定值
def all_contains_id(t: torch.Tensor, token_id: int):
    mask = t == token_id
    return mask.any(dim = -1).all()

# 查找指定值在张量中的索引
def find_indices_of(t: torch.Tensor, token_id: int, occurrence = 1):
    assert occurrence > 0
    mask = (t == token_id)

    has_occurred = mask.cumsum(dim = -1)
    has_occurred = F.pad(has_occurred, (1, 0), value = 0.)

    return (has_occurred < occurrence).sum(dim = -1).long()
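
# example: find_indices_of(torch.tensor([[5, 9, 7, 9]]), 9) returns tensor([2]), i.e. one
# position past the first occurrence of token 9; with occurrence = 2 it returns tensor([4])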

# 调用 API 调用函数

# 检查字符串是否为有效格式
def is_valid_string(s):
    return exists(re.fullmatch(r"'[^']*'|\"[^\"]*\"", s))

# 检查整数是否为有效格式
def is_valid_integer(s):
    return exists(re.fullmatch(r"[+-]?\d+", s))

# 检查浮点数是否为有效格式
def is_valid_float(s):
    return exists(re.fullmatch(r"[+-]?\d+(\.\d+)?", s))

# 解析参数字符串为整数、浮点数或字符串
def parse_param(s: str) -> Optional[Union[int, float, str]]:
    if is_valid_string(s):
        return str(s)
    elif is_valid_integer(s):
        return int(s)
    elif is_valid_float(s):
        return float(s)

    return None
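
# examples: parse_param("42") -> 42, parse_param("3.14") -> 3.14,
# parse_param("'foo'") -> "'foo'" (quotes kept), and parse_param("bar") -> None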

# 替换函数,根据注册的函数执行相应的 API 调用
@beartype
def replace_fn(
    registry: dict[str, Callable],
    matches,
    delimiter = '→'
):
    orig_text = matches.group(0)

    text_without_end_api_token = matches.group(1)
    end_api_token = matches.group(4)
    function_name = matches.group(2)

    # 如果注册表中找不到函数,则返回原始文本
    if function_name not in registry:
        return orig_text

    fn = registry[function_name]

    params = matches.group(3).split(',')
    params = list(map(lambda s: s.strip(), params))
    params = list(filter(len, params))
    params = list(map(parse_param, params))

    # 如果参数中有无法解析的部分,则返回原始文本
    if any([(not exists(p)) for p in params]):
        return orig_text

    # 尝试执行函数,如果出现错误则返回 None
    out = try_except(fn, always(None))(*params)

    # 如果输出为 None,则返回原始文本
    if not exists(out):
        return orig_text

    # 返回带有输出分隔符和字符串化输出的原始文本
    return f'{text_without_end_api_token} {delimiter} {str(out)} {end_api_token}'

# 主函数,接受函数注册表、文本和进行 API 调用并附加输出
def create_function_regex(
    api_start = ' [',
    api_stop = ']'
):
    # 将 api_start 和 api_stop 进行转义,得到转义后的正则表达式字符串
    api_start_regex, api_stop_regex = map(re.escape, (api_start, api_stop))
    # 返回一个包含转义后的 api_start 和 api_stop 的正则表达式字符串
    return rf'({api_start_regex}(\w+)\(([^)]*)\))({api_stop_regex})'
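
# example: with the default delimiters, the pattern matches calls such as " [Calendar()]" or
# " [Calculator(1 + 1)]"; group 2 captures the function name and group 3 the raw argument
# string, which replace_fn above parses with parse_param before dispatching to the registry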
# 计算子字符串在文本中出现的次数
def num_matches(substr: str, text: str):
    return len(re.findall(re.escape(substr), text))

# 检查文本中是否存在 API 调用
def has_api_calls(
    text,
    api_start = ' [',
    api_stop = ']'
):
    # 创建 API 调用的正则表达式
    regex = create_function_regex(api_start, api_stop)
    # 查找匹配的 API 调用
    matches = re.findall(regex, text)
    return len(matches) > 0

# 替换除第一个外的所有 API 调用
def replace_all_but_first(
    text: str,
    api_start = ' [',
    api_stop = ']'
) -> str:
    # 创建 API 调用的正则表达式
    regex = create_function_regex(api_start, api_stop)

    count = 0

    def replace_(matches):
        orig_text = matches.group(0)
        nonlocal count
        count += 1
        if count > 1:
            return ''
        return orig_text

    return re.sub(regex, replace_, text)

# 在文本中调用工具函数
def invoke_tools(
    registry: dict[str, Callable],
    text: str,
    delimiter: str = '→',
    api_start = ' [',
    api_stop = ' ]'
) -> str:
    # 创建 API 调用的正则表达式
    regex = create_function_regex(api_start, api_stop)
    replace_ = partial(replace_fn, registry, delimiter = delimiter)
    return re.sub(regex, replace_, text)

# 在批量序列上调用工具函数
def invoke_tools_on_batch_sequences(
    registry: dict[str, Callable],
    token_ids: torch.Tensor,
    *,
    encode: Callable,
    decode: Callable,
    delimiter: str = '→',
    api_start = ' [',
    api_stop = ']'
) -> torch.Tensor:
    regex = create_function_regex(api_start, api_stop)
    all_texts = [decode(one_seq_token_ids) for one_seq_token_ids in token_ids]

    invoke_tools_ = partial(invoke_tools, api_start = api_start, api_stop = api_stop)
    all_texts_with_api_calls = [invoke_tools_(registry, text, delimiter) for text in all_texts]

    return encode(all_texts_with_api_calls)

# 采样 API 相关函数
# 它们进行贪婪采样,但通过在前 k = 10 中自动选择 <api> 标记来鼓励采样 API 调用

@beartype
@torch.no_grad()
def sample(
    model: nn.Module,
    *,
    seq_len,
    prime: Optional[torch.Tensor] = None,
    positions: Optional[torch.Tensor] = None,
    batch_size = 1,
    eos_token_id = None,
    sos_token_id = 1,
    temperature = 0.,
    pad_id = 0,
    call_api_only_once = False,
    api_start_token_id = None,
    auto_select_api_start_token_when_topk = False,
    select_api_start_id_top_k = 10,
):
    device = next(model.parameters()).device
    max_seq_len = seq_len + 1

    # 验证

    if call_api_only_once:
        assert exists(api_start_token_id)

    # 初始化

    if exists(prime):
        batch_size, prime_length = prime.shape
    else:
        prime_length = 1
        prime = torch.full((batch_size, 1), sos_token_id, device = device, dtype = torch.long)

    prime = prime.to(device)

    # 采样位置 - 不同序列有不同的游标

    if exists(positions):
        positions = positions.clone()
    else:
        positions = torch.zeros((batch_size,), device = device, dtype = torch.long)

    assert (positions <= (prime_length + 1)).all() and (positions <= max_seq_len).all(), '所有位置必须小于初始主长度以及总序列长度 + 1(如果一个序列在另一个序列之前完成采样,则加一)'

    # 评估模型

    model.eval()

    # 将主长度延长到整个序列长度

    remain_iterations = seq_len - prime_length
    output = F.pad(prime, (0, max_seq_len - prime_length), value = 0.)

    batch_indices = torch.arange(batch_size, device = device)
    batch_indices = rearrange(batch_indices, 'b -> b 1')
    position_indices = rearrange(positions, 'b -> b 1')

    # 确定 <api> 标记掩码,以确保只调用一次 API,屏蔽对数以防止它被选择为那些已经包含 <api> 标记的行

    api_token_mask = None # 懒惰创建,因为不知道对数维度

    def create_api_token_mask(num_tokens, api_start_token_id):
        mask = torch.zeros((1, 1, num_tokens), dtype = torch.bool)
        assert api_start_token_id < num_tokens
        mask[..., api_start_token_id] = True
        return mask

    # 开始迭代
    # 对于剩余的迭代次数,循环执行以下操作
    for iteration in tqdm(range(remain_iterations)):
        # 使用模型预测输出
        logits = model(output)
        # 获取最后一个位置的logits
        last_logits = logits[batch_indices, position_indices]

        # 确保每个批次的令牌序列最多只有一个<api>令牌
        if call_api_only_once:
            # 如果api_token_mask不存在,则创建一个
            if not exists(api_token_mask):
                num_tokens = last_logits.shape[-1]
                api_token_mask = create_api_token_mask(num_tokens, api_start_token_id)
                api_token_mask = api_token_mask.to(device)

            # 检查是否调用了api
            api_called = (output == api_start_token_id).any(dim=-1)

            # 创建logit_mask,用于标记需要被替换的位置
            logit_mask = api_token_mask & rearrange(api_called, 'b -> b 1 1')
            last_logits = last_logits.masked_fill(logit_mask, -torch.finfo(last_logits.dtype).max)

        # 使用贪婪采样(也可以选择非贪婪)
        sampled = gumbel_sample(last_logits, temperature=temperature)

        # 对于那些没有api调用的序列,如果api_start_token_id在logits的前k个(设置为10)中,则自动选择
        if auto_select_api_start_token_when_topk:
            top_token_ids = last_logits.topk(select_api_start_id_top_k, dim=-1).indices
            has_api_token_in_topk = (top_token_ids == api_start_token_id).any(dim=-1)
            should_auto_select_api_token = has_api_token_in_topk & ~rearrange(api_called, 'b -> b 1')

            sampled = sampled.masked_fill(should_auto_select_api_token, api_start_token_id)

        # 将采样的令牌放置在正确的光标位置
        output[batch_indices, position_indices] = sampled

        # 增加位置索引
        position_indices += 1
        position_indices.clamp_(max=seq_len)  # 如果一个序列更靠后且接近结尾,则不执行任何操作

        # 如果使用<eos>令牌,查找所有包含它的序列并终止,<eos>之后的内容将被填充
        if exists(eos_token_id):
            eos_mask = (output == eos_token_id)
            all_rows_have_eos = eos_mask.any(dim=-1).all()

            if all_rows_have_eos:
                keep_mask = eos_mask.cumsum(dim=-1) == 0
                keep_mask = F.pad(keep_mask, (1, 0), value=True)
                output = output.masked_fill(~keep_mask, pad_id)
                break

    # 移除输出中的最后一个令牌(作为无操作占位符)
    output = output[:, :-1]
    return output
# 使用 beartype 装饰器对函数进行类型检查
# 使用 torch.no_grad() 上下文管理器,禁用梯度计算
@beartype
@torch.no_grad()
# 从模型中生成序列,调用 API 并返回结果
def sample_with_api_call(
    model: nn.Module,
    *,
    seq_len,  # 序列长度
    call_apis: Callable,  # 调用 API 的函数
    prime: torch.Tensor,  # 初始张量
    api_end_token_id: int,  # API 结束标记的 ID
    occurrence = 1,  # API 出现次数
    **kwargs  # 其他关键字参数
):
    # 生成初始序列
    sampled = sample(
        model = model,
        prime = prime,
        seq_len = seq_len,
        **kwargs
    )

    # 调用 API 处理生成的序列
    sampled = call_apis(sampled)

    # 获取处理后序列的长度
    sampled_seq_len = sampled.shape[-1]
    null_positions = sampled_seq_len  # 处理不包含 API 调用的序列

    # 查找 API 结束标记的位置
    pos_starting_at_end_of_api = find_indices_of(
        sampled,
        api_end_token_id,
        occurrence = occurrence
    )

    # 重新生成序列,从 API 结束位置开始
    resample_after_api_calls = sample(
        model = model,
        prime = sampled,
        seq_len = sampled_seq_len,
        positions = (pos_starting_at_end_of_api + 1).clamp(max = null_positions),  # 从 </api> 后的位置开始
        **kwargs
    )

    return resample_after_api_calls

# 论文的主要贡献在于第 2 节中提出的过滤方程

# 默认的权重函数
def default_weight_fn(t):
    # 根据第 4.1 节中的公式计算权重,不确定分母中的 w_s 是什么
    # 如果 t 代表每个时间步,则在 5 个标记内会减少到 0?
    return (1. - t * 0.2).clamp(min = 0.)
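
# worked example: t = 0, 1, 2, 3, 4, 5 gives weights 1.0, 0.8, 0.6, 0.4, 0.2, 0.0, so tokens
# five or more positions past the API call no longer contribute to the weighted loss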

# 获取预测概率
def get_pred_prob(token_ids, logits):
    logits = logits[:, :-1]  # 每个标记的 logits(省略最后一个 logits)
    token_ids = token_ids[:, 1:]  # 预测下一个标记的 ID(省略第一个标记的 ID)

    token_ids = rearrange(token_ids, 'b n -> b n 1')
    probs = logits.softmax(dim = -1)
    correct_token_id_pred_prob = probs.gather(-1, token_ids)
    return rearrange(correct_token_id_pred_prob, 'b n 1 -> b n')

# 获取从特定标记开始的索引
def get_arange_start_at_token_id(
    token_ids: torch.Tensor,
    token_id: int,
    pad_id = -1
):
    is_token_id_mask = token_ids == token_id
    arange = (is_token_id_mask.cumsum(dim = -1) > 0).cumsum(dim = -1)
    before_token_mask = arange == 0
    arange = arange - 1
    arange = arange.masked_fill(before_token_mask, pad_id)
    return arange

# 计算权重和掩码
def weight_and_mask(
    token_ids: torch.Tensor,
    token_id: int,
    pad_id = -1,
    weighting_fn: Callable = default_weight_fn
):
    t = get_arange_start_at_token_id(token_ids, token_id, pad_id)
    weights = weighting_fn(t)
    return weights.masked_fill(t == pad_id, 0.)

# 定义过滤结果的命名元组
FilteredResults = namedtuple('FilteredResults', [
    'num_passed',
    'num_failed',
    'selected_indices',
    'selected_mask',
    'filtered_tokens',
    'filtered_tokens_without_api_response',
    'filtered_tokens_with_api_response'
])

# 过滤带有 API 响应的标记
@beartype
def filter_tokens_with_api_response(
    model: nn.Module,  # 语言模型应接受下面的标记并返回形状为 (batch, seq, num tokens) 的 logits
    *,
    tokens: torch.Tensor,  # 原始段落的标记 ID(不包含 API 调用)
    tokens_without_api_response: torch.Tensor,  # 包含 API 调用但没有填充响应的段落的标记 ID - <api>tool1(x, y)</api>
    tokens_with_api_response: torch.Tensor,  # 包含 API 调用和响应的段落的标记 ID - <api>tool1(x, y) → {response}</api>
    api_start_token_id: int,  # <api> 标记的 ID
    api_end_token_id: int,  # </api> 标记的 ID
    filter_threshold: float = 1.,  # 接受采样的 API 调用的阈值(tokens_with_api_response)用于微调
    weighting_fn: Callable = default_weight_fn  # 权重函数
) -> FilteredResults:

    # 验证

    assert all([*map(lambda t: t.dtype == torch.long, (tokens, tokens_with_api_response, tokens_without_api_response))])

    assert all_contains_id(tokens_without_api_response, api_start_token_id)
    assert all_contains_id(tokens_without_api_response, api_end_token_id)
    # 确保所有的 tokens_with_api_response 中包含 api_start_token_id
    assert all_contains_id(tokens_with_api_response, api_start_token_id)
    # 确保所有的 tokens_with_api_response 中包含 api_end_token_id
    assert all_contains_id(tokens_with_api_response, api_end_token_id)

    # 自动设置设备

    # 获取模型参数的设备
    device = next(model.parameters()).device
    # 将 tokens, tokens_without_api_response, tokens_with_api_response 移动到指定设备上
    tokens, tokens_without_api_response, tokens_with_api_response = map(lambda t: t.to(device), (tokens, tokens_without_api_response, tokens_with_api_response))

    # 获取所有的 logits

    with torch.no_grad():
        # 设置模型为评估模式
        model.eval()
        # 获取 logits, logits_without_api_response, logits_with_api_response
        logits, logits_without_api_response, logits_with_api_response = map(model, (tokens, tokens_without_api_response, tokens_with_api_response))

    # 推导出序列中实际下一个 token id 的所有预测概率

    probs                       = get_pred_prob(tokens, logits)
    probs_without_api_response  = get_pred_prob(tokens_without_api_response, logits_without_api_response)
    probs_with_api_response     = get_pred_prob(tokens_with_api_response, logits_with_api_response)

    weight_and_mask_fn = partial(weight_and_mask, weighting_fn = weighting_fn)

    # 推导权重

    weight_without_api_response = weight_and_mask_fn(tokens_without_api_response[:, :-1], api_end_token_id)
    weight_with_api_response = weight_and_mask_fn(tokens_with_api_response[:, :-1], api_end_token_id)

    # 推导原始 passage 的权重更加复杂
    # 需要从 <api> 开始标记的位置开始计数
    # 这也假设语言模型完美地复制了 passage,并且两个 token id 对齐,除了插入的 API 调用 - 但最终可以通过自定义过滤函数完成

    weight = weight_and_mask_fn(tokens_without_api_response[:, 1:], api_start_token_id) # 左移一个位置,因为原始序列中不存在 <api>
    weight = weight[:, :probs.shape[-1]]

    # 获取所有三种序列的损失 L

    def loss_fn(weight, probs):
        return (weight * -log(probs)).sum(dim = -1)

    loss = loss_fn(weight, probs)
    loss_without_api_response = loss_fn(weight_without_api_response, probs_without_api_response)
    loss_with_api_response = loss_fn(weight_with_api_response, probs_with_api_response)

    # 计算论文中的主要公式

    # loss+ = 带有 api 响应的损失
    # loss- = 最小值(没有 api 响应的损失, 没有 api 的损失)

    loss_plus = loss_with_api_response
    loss_minus = torch.minimum(loss_without_api_response, loss)

    selected_mask = (loss_minus - loss_plus) >= filter_threshold

    # 现在我们可以选择并返回经过过滤阶段幸存下来的条目
    # 同时返回正在处理的批次的选定索引
    # 用于将模型微调为 toolformer

    batch = tokens.shape[0]
    indices = torch.arange(batch, device = tokens.device)

    selected_indices = indices[selected_mask]

    ret = FilteredResults(
        selected_mask.sum().item(),
        (~selected_mask).sum().item(),
        selected_indices,
        selected_mask,
        tokens[selected_mask],
        tokens_without_api_response[selected_mask],
        tokens_with_api_response[selected_mask]
    )

    return ret
# datasets and dataloaders

# 用于通过 API 调用引导初始数据集以及最终微调

# 定义 PromptDataset 类,继承自 Dataset 类
@beartype
class PromptDataset(Dataset):
    # 初始化方法
    def __init__(
        self,
        prompt: str,
        prompt_input_tag: str,
        data: List[str],
        tokenizer_encode: Callable
    ):
        # 初始化数据集、提示、提示输入标签的正则表达式、编码器
        self.data = data
        self.prompt = prompt
        self.prompt_input_tag_regex = re.escape(prompt_input_tag)
        self.tokenizer_encode = tokenizer_encode

    # 返回数据集长度
    def __len__(self):
        return len(self.data)

    # 获取指定索引的数据
    def __getitem__(self, idx):
        data_string = self.data[idx]
        data_with_prompt = re.sub(self.prompt_input_tag_regex, data_string, self.prompt)
        token_ids = self.tokenizer_encode(data_with_prompt)
        return torch.tensor(token_ids).long(), torch.tensor(len(token_ids)).long()

# 定义 prompt_collate_fn 函数,用于数据集的填充
def prompt_collate_fn(data, padding_value = 0):
    prompts, prompt_lengths = zip(*data)
    prompts = pad_sequence(prompts, padding_value = padding_value)
    return prompts, torch.stack(prompt_lengths)

# 定义 PromptDataloader 函数,用于创建数据加载器
def PromptDataloader(ds: Dataset, *args, padding_value = 0, **kwargs):
    collate_fn = partial(prompt_collate_fn, padding_value = padding_value)
    return DataLoader(ds, *args, collate_fn = collate_fn, **kwargs)

# 定义 FinetuneDataset 类,继承自 Dataset 类
class FinetuneDataset(Dataset):
    # 初始化方法
    def __init__(
        self,
        tokens: torch.Tensor
    ):
        # 初始化 tokens
        self.tokens = tokens

    # 返回数据集长度
    def __len__(self):
        return len(self.tokens)

    # 获取指定索引的数据
    def __getitem__(self, idx):
        return self.tokens[idx]

# 定义 FinetuneDataloader 函数,用于创建微调数据加载器
def FinetuneDataloader(ds: Dataset, *args, padding_value = 0, **kwargs):
    return DataLoader(ds, *args, collate_fn = partial(pad_sequence, padding_value = padding_value), **kwargs)

# classes

# 定义 Toolformer 类,继承自 nn.Module 类
@beartype
class Toolformer(nn.Module):
    # 初始化方法
    def __init__(
        self,
        model: nn.Module,
        *,
        tool_id: str,
        tool: Callable,
        api_start_str = ' [',
        api_stop_str = ']',
        api_response_delimiter = '→',
        api_start_id = None,
        api_stop_id = None,
        teach_tool_prompt: str,
        filter_threshold = 1.,
        pad_id = 0,
        prompt_batch_size = 4,
        model_seq_len = 2048,
        tokenizer_encode: Callable = tokenizer.encode,
        tokenizer_decode: Callable = tokenizer.decode,
        post_prompt_callback: Callable = identity,
        prompt_input_tag: str = DEFAULT_PROMPT_INPUT_TAG,
        exclude_filters: dict[str, Callable[[str], bool]] = dict(),
        finetune = False,
        finetune_lr = 1e-4,
        finetune_wd = 1e-2,
        finetune_betas = (0.9, 0.99),
        finetune_eps = 1e-8,
        finetune_epochs = 3,
        finetune_batch_size = 16
    # 初始化函数,设置模型、模型序列长度、教学工具提示、提示批量大小、提示输入标签等参数
    ):
        super().__init__()
        self.model = model
        self.model_seq_len = model_seq_len

        self.teach_tool_prompt = teach_tool_prompt
        self.prompt_batch_size = prompt_batch_size
        self.prompt_input_tag = prompt_input_tag

        self.post_prompt_callback = post_prompt_callback # for easy mocking

        self.tokenizer_encode = tokenizer_encode
        self.tokenizer_decode = tokenizer_decode
        self.tokenizer_encode_to_tensor = lambda s: torch.tensor(tokenizer_encode(s)).long()

        self.filter_threshold = filter_threshold

        self.api_start_str = api_start_str
        self.api_stop_str = api_stop_str
        self.api_response_delimiter = api_response_delimiter

        # 如果不存在api_start_id,则根据api_start_str进行编码
        if not exists(api_start_id):
            api_start_id = tokenizer_encode(api_start_str)
            assert len(api_start_id) == 1
            api_start_id = api_start_id[0]

        self.api_start_id = api_start_id

        # 如果不存在api_stop_id,则根据api_stop_str进行编码
        if not exists(api_stop_id):
            api_stop_id = tokenizer_encode(api_stop_str)
            assert len(api_stop_id) == 1
            api_stop_id = api_stop_id[0]

        self.api_stop_id = api_stop_id

        self.pad_id = pad_id

        self.tool_id = tool_id
        self.tool = tool
        self.registry = {tool_id: tool}

        # 确保在提示中只有一个指定的提示输入标签
        assert num_matches(prompt_input_tag, teach_tool_prompt) == 1, f'there must be exactly one prompt input tag `{prompt_input_tag}` in your prompt to encourage the language model to use the designated tool'

        self.teach_tool_prompt = teach_tool_prompt
        self.exclude_filters = exclude_filters

        self.should_finetune = finetune

        # 如果不需要微调,则直接返回
        if not finetune:
            return

        self.finetune_batch_size = finetune_batch_size
        self.finetune_epochs = finetune_epochs

        # 获取优化器
        self.optimizer = get_optimizer(
            model.parameters(),
            lr = finetune_lr,
            wd = finetune_wd,
            betas = finetune_betas,
            eps = finetune_eps
        )

    # 生成带有API调用的数据
    def generate_data_with_api_calls(
        self,
        data: List[str],
        temperature: float = 0.9
    ) -> List[str]:

        # 创建PromptDataset对象
        dataset = PromptDataset(
            data = data,
            prompt_input_tag = self.prompt_input_tag,
            prompt = self.teach_tool_prompt,
            tokenizer_encode = self.tokenizer_encode
        )

        # 创建PromptDataloader对象
        dl = PromptDataloader(
            dataset,
            batch_size = self.prompt_batch_size
        )

        prompted_outputs = []

        # 遍历数据加载器
        for prime, positions in dl:

            # 对模型进行采样
            sampled_outputs = sample(
                model = self.model,
                prime = prime,
                positions = positions,
                seq_len = self.model_seq_len,
                pad_id = self.pad_id,
                temperature = temperature
            )

            # 解码采样输出并添加到结果列表中
            for sample_output, position in zip(sampled_outputs, positions):
                start_position = position.item()

                prompted_output = self.tokenizer_decode(sample_output[start_position:])
                prompted_outputs.append(prompted_output)

        # 调用后处理回调函数
        return self.post_prompt_callback(prompted_outputs)

    # 过滤并仅保留第一个API调用
    def filter_and_keep_only_first_api_call(
        self,
        data,
        data_with_api_calls: List[str],
        return_excluded = False
    ):
        # 初始化包含数据和包含 API 调用数据的空列表
        included_data = []
        included_data_with_api_calls = []

        # 将包含数据和包含 API 调用数据组成元组
        included = (included_data, included_data_with_api_calls)

        # 初始化排除数据和排除 API 调用数据的空列表
        excluded_data = []
        excluded_data_with_api_calls = []

        # 将排除数据和排除 API 调用数据组成元组
        excluded = (excluded_data, excluded_data_with_api_calls)

        # 设置 API 调用开始和结束参数
        api_start_stop_kwargs = dict(api_start=self.api_start_str, api_stop=self.api_stop_str)

        # 创建部分函数,用于检查是否存在 API 调用和替换除第一个外的所有 API 调用
        has_api_calls_ = partial(has_api_calls, **api_start_stop_kwargs)
        replace_all_but_first_ = partial(replace_all_but_first, **api_start_stop_kwargs)

        # 遍历数据和数据中包含 API 调用的元组
        for datum, data_with_api_call in zip(data, data_with_api_calls):
            # 如果数据中包含 API 调用
            if has_api_calls_(data_with_api_call):
                # 替换除第一个外的所有 API 调用
                data_with_api_call = replace_all_but_first_(data_with_api_call)

                # 将数据和数据中包含 API 调用添加到包含列表中
                included_data.append(datum)
                included_data_with_api_calls.append(data_with_api_call)
            else:
                # 将数据和数据中包含 API 调用添加到排除列表中
                excluded_data.append(datum)
                excluded_data_with_api_calls.append(data_with_api_call)

        # 如果不返回排除数据,则返回包含数据
        if not return_excluded:
            return included

        # 返回包含数据和排除数据
        return included, excluded

    # sample from the model, making API calls along the way
    @torch.no_grad()
    def sample_model_with_api_calls(
        self,
        prime: Union[torch.Tensor, str],
        occurrence = 1,
        **kwargs
    ):
        # put the model in eval mode
        self.model.eval()

        # check whether the prime is a string
        prime_is_str = isinstance(prime, str)

        # if the prime is a string
        if prime_is_str:
            # encode the prime and convert it to a tensor
            prime = self.tokenizer_encode(prime)
            prime = torch.tensor(prime).long()
            prime = rearrange(prime, 'n -> 1 n')

        # only a single sequence (batch size of 1) is supported for now
        assert prime.shape[0] == 1, 'only one at a time for now'

        # partial for invoking the registered tools
        invoke_tools_ = partial(invoke_tools, self.registry)

        # callback that decodes the tokens, invokes the tools, and re-encodes the result
        def call_apis(t: torch.Tensor):
            t = self.tokenizer_decode(t[0])
            t = invoke_tools_(t)
            t = self.tokenizer_encode_to_tensor(t)
            return rearrange(t, 'n -> 1 n')

        # sample from the model, pausing to make API calls
        output = sample_with_api_call(
            model = self.model,
            prime = prime,
            seq_len = self.model_seq_len,
            call_apis = call_apis,
            api_end_token_id = self.api_stop_id,
            occurrence = occurrence,
            **kwargs
        )

        # if the prime was not a string, return the raw token output
        if not prime_is_str:
            return output

        # otherwise decode the output back into a string
        return self.tokenizer_decode(output[0])
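
    # note (illustrative, not part of the original file): sample_with_api_call samples until the
    # api_end_token_id is emitted, hands the partial sequence to the call_apis callback above
    # (decode, execute the tool via the registry, re-encode with the response spliced in), then
    # resumes sampling from the extended sequence; see sample_with_api_call earlier in this file
    # for the exact control flow and the meaning of `occurrence`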

    # execute the API calls
    def make_api_calls(
        self,
        filtered_data_with_api_calls: List[str]
    ):
        # partial for invoking the registered tools with the configured delimiters
        invoke_tools_ = partial(
            invoke_tools,
            self.registry,
            api_start = self.api_start_str,
            api_stop = self.api_stop_str,
            delimiter = self.api_response_delimiter
        )

        # run the API calls on the filtered data
        data_with_api_responses = []
        for data in filtered_data_with_api_calls:
            output = invoke_tools_(data)
            data_with_api_responses.append(output)

        # return the data with the API responses spliced in
        return data_with_api_responses
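
    # note (illustrative, not part of the original file): invoke_tools scans each string for spans
    # delimited by api_start_str / api_stop_str, executes the corresponding tool registered in
    # self.registry, and splices the textual response back into the string, separated by
    # api_response_delimiter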

    # filter the data based on the API responses
    def filter_by_api_responses(
        self,
        data: List[str],
        data_with_api_calls: List[str],
        data_with_api_responses: List[str]
    ) -> FilteredResults:

        # helper that tokenizes a list of strings into a padded tensor
        to_token_ids = lambda l: pad_sequence([*map(self.tokenizer_encode_to_tensor, l)], padding_value = self.pad_id)

        # tokenize the three variants of the data
        tokens, tokens_without_api_response, tokens_with_api_response = map(to_token_ids, (data, data_with_api_calls, data_with_api_responses))

        # keep only the API calls whose responses sufficiently reduce the loss
        filtered_results = filter_tokens_with_api_response(
            model = self.model,
            tokens = tokens,
            tokens_with_api_response = tokens_with_api_response,
            tokens_without_api_response = tokens_without_api_response,
            filter_threshold = self.filter_threshold,
            api_start_token_id = self.api_start_id,
            api_end_token_id = self.api_stop_id
        )

        # return the filtered results
        return filtered_results
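
    # note (illustrative, not part of the original file): following the Toolformer filtering
    # criterion, an API call is kept only when conditioning on the API response lowers the loss
    # on the subsequent tokens by enough, roughly
    #   min(loss_without_api, loss_without_response) - loss_with_response >= filter_threshold
    # the exact per-token weighting lives inside filter_tokens_with_api_response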

    # finetune the model on the filtered samples
    def finetune(
        self,
        filtered_results: Union[FilteredResults, torch.Tensor]
    ):
        # put the model in training mode
        self.model.train()

        # if filtered_results is a FilteredResults instance, take the filtered tokens without the API response
        if isinstance(filtered_results, FilteredResults):
            filtered_results = filtered_results.filtered_tokens_without_api_response

        # dataset for finetuning
        dataset = FinetuneDataset(tokens = filtered_results)
        # dataloader for finetuning
        dl = FinetuneDataloader(dataset, batch_size = self.finetune_batch_size, shuffle = True)

        # loop over the finetuning epochs
        for epoch in tqdm(range(self.finetune_epochs), desc = 'finetune epochs'):
            # iterate over each batch from the dataloader
            for batch in dl:
                # shift the batch by one token: inputs are all but the last token, labels are all but the first
                inp, labels = batch[:, :-1], batch[:, 1:]

                # forward pass through the model
                logits = self.model(inp)
                # rearrange the logits to (batch, classes, seq) for cross entropy
                logits = rearrange(logits, 'b n c -> b c n')

                # compute the cross entropy loss, ignoring padding
                loss = F.cross_entropy(logits, labels, ignore_index = self.pad_id)
                # backpropagate the gradients
                loss.backward()

                # print the loss
                print(f'loss: {loss.item()}')
                # take an optimizer step
                self.optimizer.step()
                # zero the gradients
                self.optimizer.zero_grad()

        # report that finetuning has finished
        print(f'finished finetuning on {len(dataset)} filtered samples')
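
        # note (illustrative, not part of the original file): the shift above is the standard
        # next-token language modeling setup, e.g. for a row of tokens [t0, t1, t2, t3]
        #   inp    = [t0, t1, t2]
        #   labels = [t1, t2, t3]
        # so the finetuning objective is plain cross entropy over the filtered, API-annotated sequences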

    # forward pass: runs the full self-supervised tool-use pipeline
    def forward(
        self,
        data: List[str],
        return_after_generating_api_calls = False,
        return_after_making_api_calls = False,
        return_after_filtering_api_calls = False,
        return_after_filtering_by_api_response = False
    ):
        # generate data with API calls by prompting the model
        data_with_api_calls = self.generate_data_with_api_calls(data)

        # optionally return right after generating the API calls
        if return_after_generating_api_calls:
            return data_with_api_calls

        # filter the data, keeping only the first API call per sample
        filtered_data, filtered_data_with_api_calls = self.filter_and_keep_only_first_api_call(data, data_with_api_calls)

        # optionally return right after filtering the API calls
        if return_after_filtering_api_calls:
            return filtered_data, filtered_data_with_api_calls

        # make sure at least one sample contains an API call
        assert len(filtered_data_with_api_calls) > 0, 'your model failed to follow instructions and make API calls. please try a better model or do some better prompt engineering'

        # execute the API calls
        data_with_responses = self.make_api_calls(filtered_data_with_api_calls)

        # optionally return right after making the API calls
        if return_after_making_api_calls:
            return filtered_data, filtered_data_with_api_calls, data_with_responses

        # filter the data based on the API responses
        filtered_results = self.filter_by_api_responses(filtered_data, filtered_data_with_api_calls, data_with_responses)

        # optionally return right after filtering by API response
        if return_after_filtering_by_api_response:
            return filtered_results

        # finetune the model if configured to do so
        if self.should_finetune:
            # make sure at least one sequence with an API call passed the filter
            assert filtered_results.num_passed > 0, f'none of the sequences with API calls passed the filtering criteria with threshold {self.filter_threshold}'

            # run the finetuning
            self.finetune(filtered_results)

        # return the filtered results
        return filtered_results
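
A minimal usage sketch of the class above (not part of the original file): the constructor arguments shown here (tool, tool_id, teach_tool_prompt, finetune, and the PaLM hyperparameters) are assumptions based on the project README and the earlier parts of this file, so they may not match the actual signatures exactly; the final call mirrors the forward method listed above.

from toolformer_pytorch import Toolformer, PaLM

# a toy tool: any function that returns a string can be registered
def Calendar() -> str:
    import datetime
    return datetime.datetime.now().strftime('Today is %A, %B %d, %Y.')

# prompt teaching the model to insert Calendar API calls; the [input] tag is replaced per sample
prompt = """Your task is to add calls to a Calendar API to a piece of text.
Input: [input]
Output:"""

data = [
    'The store is never open on the weekend, so today it is closed.',
    'The number of days from now until Christmas is 30.'
]

# any autoregressive model returning logits of shape (batch, seq, num_tokens) should work
model = PaLM(dim = 512, depth = 2, heads = 8, dim_head = 64)

toolformer = Toolformer(
    model = model,
    model_seq_len = 256,
    teach_tool_prompt = prompt,
    tool_id = 'Calendar',
    tool = Calendar,
    finetune = True
)

# runs the full pipeline: prompt the model to generate API calls, keep only the first call per
# sample, execute the calls, filter by whether the response lowers the loss, then finetune
filtered_stats = toolformer(data)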