Lucidrains Series Source Code Walkthrough - 41

Lucidrains Series Source Code Walkthrough (Part 41)

.\lucidrains\soundstorm-pytorch\soundstorm_pytorch\trainer.py

# import the necessary modules
from pathlib import Path
import re
from shutil import rmtree

# beartype and related types
from beartype import beartype
from beartype.typing import Optional

# PyTorch related modules
import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, random_split

# custom modules
from audiolm_pytorch.data import get_dataloader
from audiolm_pytorch.optimizer import get_optimizer

from soundstorm_pytorch.soundstorm import SoundStorm

# accelerator and distributed type
from accelerate import Accelerator, DistributedType

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# no-op function
def noop(*args, **kwargs):
    pass

# endlessly cycle over a dataloader
def cycle(dl):
    while True:
        for data in dl:
            yield data

# cast input to a tuple
def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

# ask the user a yes / no question
def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

# accumulate values into a log dict
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log
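
# e.g. accum_log({}, {'loss': 0.5}) -> {'loss': 0.5}; calling it again on the result
# with {'loss': 0.25} yields {'loss': 0.75}, summing across gradient accumulation steps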

# parse the number of training steps from a checkpoint filename
def checkpoint_num_steps(checkpoint_path):
    """Returns the number of steps trained from a checkpoint based on the filename.

    Filename format assumed to be something like "/path/to/soundstorm.20000.pt" which is
    for 20k train steps. Returns 20000 in that case.
    """
    results = re.findall(r'\d+', str(checkpoint_path))

    if len(results) == 0:
        return 0

    return int(results[-1])
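
# e.g. checkpoint_num_steps('/path/to/soundstorm.20000.pt') -> 20000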

# SoundStormTrainer
class SoundStormTrainer(nn.Module):
    @beartype
    def __init__(
        self,
        model: SoundStorm,
        *,
        num_train_steps,
        num_warmup_steps,
        batch_size,
        dataset: Optional[Dataset] = None,
        only_train_generator = False,
        only_train_critic = False,
        lr = 3e-4,
        initial_lr = 1e-5,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        split_batches = False,
        drop_last = False,
        force_clear_prev_results = None
    ):
        # call the parent class constructor
        super().__init__()

        # instantiate the accelerator
        self.accelerator = Accelerator(
            split_batches = split_batches,
            **accelerate_kwargs
        )

        # set the model
        self.model = model

        # register a buffer that tracks the number of training steps
        self.register_buffer('steps', torch.Tensor([0]))

        # training hyperparameters: total steps, warmup steps, batch size, gradient accumulation
        self.num_train_steps = num_train_steps
        self.num_warmup_steps = num_warmup_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        self.only_train_generator = only_train_generator
        self.only_train_critic = only_train_critic

        # instantiate the optimizer
        self.optim = get_optimizer(
            model.parameters(),
            lr = lr,
            wd = wd
        )

        self.lr = lr
        self.initial_lr = initial_lr
        # cosine annealing learning rate scheduler
        self.scheduler = CosineAnnealingLR(self.optim, T_max = num_train_steps)

        # max gradient norm for clipping
        self.max_grad_norm = max_grad_norm

        # dataset
        self.ds = dataset

        # split out a validation set
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # assert the datasets have enough samples
        assert len(self.ds) >= batch_size, 'dataset must have sufficient samples for training'
        assert len(self.valid_ds) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'

        # dataloaders
        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)
        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)

        # prepare the model, optimizer, scheduler and dataloaders with the accelerator
        (
            self.model,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.model,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        )

        # dataloader iterators
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        # how often to save the model and results
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        # results folder
        self.results_folder = Path(results_folder)

        # on the main process, optionally clear out results from previous experiments
        if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        # create the results folder
        self.results_folder.mkdir(parents = True, exist_ok = True)

        # initialize hyperparameter trackers
        hps = {"num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "learning_rate": lr, "initial_learning_rate": initial_lr}
        self.accelerator.init_trackers("soundstorm", config=hps)

    # save the model
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.model),
            optim = self.optim.state_dict(),
            scheduler = self.scheduler.state_dict()
        )
        torch.save(pkg, path)

    # load the model
    def load(self, path, restore_optimizer = True):
        model = self.accelerator.unwrap_model(self.model)
        pkg = model.load(path)

        # if restoring the optimizer, also load optimizer and scheduler states
        if restore_optimizer:
            self.optim.load_state_dict(pkg['optim'])
            self.scheduler.load_state_dict(pkg['scheduler'])

            # + 1 to start from the next step and avoid overwriting the last checkpoint
            self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)

    # print a message via the accelerator
    def print(self, msg):
        self.accelerator.print(msg)

    # generate, delegating to the model's generate method
    def generate(self, *args, **kwargs):
        return self.model.generate(*args, **kwargs)

    # device of the accelerator
    @property
    def device(self):
        return self.accelerator.device

    # whether training is distributed
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # whether this is the main process
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # whether this is the local main process
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # linear learning rate warmup
    def warmup(self, step):
        if step < self.num_warmup_steps:
            return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
        else:
            return self.lr
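
    # e.g. with initial_lr = 1e-5, lr = 3e-4 and num_warmup_steps = 1000:
    # warmup(500) = 1e-5 + (3e-4 - 1e-5) * 500 / 1000 = 1.55e-4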
    # training step
    def train_step(self):
        # current step count
        steps = int(self.steps.item())

        # put the model in training mode
        self.model.train()

        # adjust the learning rate according to the schedule
        if steps < self.num_warmup_steps:
            # apply linear warmup while within the warmup steps
            lr = self.warmup(steps)
            for param_group in self.optim.param_groups:
                param_group['lr'] = lr
        else:
            # after warmup, step the cosine annealing schedule
            self.scheduler.step()

        # logs
        logs = {}

        # update the generator (and / or critic)
        for _ in range(self.grad_accum_every):
            # fetch the next batch
            semantic_token_ids, acoustic_token_ids = next(self.dl_iter)

            # compute the loss and its breakdown
            loss, loss_breakdown = self.model(
                acoustic_token_ids,
                cond_semantic_token_ids = semantic_token_ids,
                only_train_generator = self.only_train_generator,
                only_train_critic = self.only_train_critic
            )

            generator_loss, critic_loss = loss_breakdown
            generator_loss = 0. if generator_loss is None else generator_loss
            critic_loss = 0. if critic_loss is None else critic_loss

            # backpropagate
            self.accelerator.backward(loss / self.grad_accum_every)

            # accumulate logs
            accum_log(logs, {'loss': loss.item() / self.grad_accum_every, 'generator_loss': generator_loss / self.grad_accum_every, 'critic_loss': critic_loss / self.grad_accum_every})

        # clip gradients if a max gradient norm is set
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        # optimizer step
        self.optim.step()
        self.optim.zero_grad()

        # log
        self.print(f"{steps}: loss: {logs['loss']:0.3f}, generator loss: {logs['generator_loss']:0.3f}, critic loss: {logs['critic_loss']:0.3f}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # sample results periodically
        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            # fetch a validation batch
            semantic_token_ids, acoustic_token_ids = next(self.valid_dl_iter)

            with torch.inference_mode():
                self.model.eval()
                # compute validation loss and its breakdown
                valid_loss, valid_loss_breakdown = self.model(acoustic_token_ids, cond_semantic_token_ids = semantic_token_ids)

                valid_generator_loss, valid_critic_loss = valid_loss_breakdown
                valid_generator_loss = 0. if valid_generator_loss is None else valid_generator_loss
                valid_critic_loss = 0. if valid_critic_loss is None else valid_critic_loss

            # log validation metrics
            self.print(f'{steps}: valid loss {valid_loss:0.3f}, valid generator loss {valid_generator_loss:0.3f}, valid critic loss {valid_critic_loss:0.3f}')
            self.accelerator.log({"valid_loss": valid_loss, "valid_generator_loss": valid_generator_loss, "valid_critic_loss": valid_critic_loss}, step=steps)

        # save the model periodically
        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'soundstorm.{steps}.pt')
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        # advance the step count and return logs
        self.steps += 1
        return logs

    # training loop
    def train(self, log_fn = noop):
        # loop until the target number of training steps is reached
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')
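
A minimal usage sketch of the trainer above. `model` and `ds` are placeholders here: `model` is a constructed SoundStorm, and `ds` is a Dataset yielding `(semantic_token_ids, acoustic_token_ids)` pairs, as consumed in `train_step`.

# minimal sketch; `model` and `ds` are assumed to be defined as described above
trainer = SoundStormTrainer(
    model = model,
    dataset = ds,
    num_train_steps = 10000,
    num_warmup_steps = 500,
    batch_size = 4,
    grad_accum_every = 2
)

trainer.train()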

.\lucidrains\soundstorm-pytorch\soundstorm_pytorch\__init__.py

# import SoundStorm, SoundStream, ConformerWrapper and Conformer from the soundstorm_pytorch package
from soundstorm_pytorch.soundstorm import (
    SoundStorm,
    SoundStream,
    ConformerWrapper,
    Conformer
)
# import SoundStormTrainer from the soundstorm_pytorch package
from soundstorm_pytorch.trainer import (
    SoundStormTrainer
)

.\lucidrains\spear-tts-pytorch\README.md

Spear-TTS - Pytorch

Implementation of Spear-TTS - multi-speaker text-to-speech attention network, in Pytorch

The text-to-semantic module built here will be used by SoundStorm for conditioning.

Appreciation

  • Stability for their generous sponsorships to work on and open source cutting edge artificial intelligence research

  • Lucas Newman for completing the backtranslation portion, as well as beam search decoding!

  • Lucas Newman for completing the final text to semantic transformer training code!

Install

$ pip install spear-tts-pytorch

Usage

import torch

from audiolm_pytorch import HubertWithKmeans

from spear_tts_pytorch import (
    TextToSemantic,
    SemanticToTextDatasetGenerator,
    GeneratedAudioTextDataset,
    MockDataset
)

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert_base_ls960.pt',
    kmeans_path = './hubert_base_ls960_L9_km500.bin'
)

model = TextToSemantic(
    wav2vec = wav2vec,
    dim = 512,
    num_text_token_ids = 256,
    heads = 8,
    target_kv_heads = 2, # grouped query attention, for memory efficient decoding
    source_depth = 1,
    target_depth = 1
)

ds = MockDataset(10)

dataset_generator = SemanticToTextDatasetGenerator(
    model = model,
    dataset = ds,
    folder = './output_folder'
)

dataset_generator(max_length = 2)

generated_dataset = GeneratedAudioTextDataset(
    folder = './output_folder'
)

assert len(generated_dataset) == 10

Todo

Citations

@misc{kharitonov2023speak,
    title   = {Speak, Read and Prompt: High-Fidelity Text-to-Speech with Minimal Supervision}, 
    author  = {Eugene Kharitonov and Damien Vincent and Zalán Borsos and Raphaël Marinier and Sertan Girgin and Olivier Pietquin and Matt Sharifi and Marco Tagliasacchi and Neil Zeghidour},
    year    = {2023},
    eprint  = {2302.03540},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@misc{shi2023enhance,
    title   = {Enhance audio generation controllability through representation similarity regularization}, 
    author  = {Yangyang Shi and Gael Le Lan and Varun Nagaraja and Zhaoheng Ni and Xinhao Mei and Ernie Chang and Forrest Iandola and Yang Liu and Vikas Chandra},
    year    = {2023},
    eprint  = {2309.08773},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}
@article{Ainslie2023GQATG,
    title   = {GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints},
    author  = {Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebr{\'o}n and Sumit K. Sanghai},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.13245},
    url     = {https://api.semanticscholar.org/CorpusID:258833177}
}
@inproceedings{Leviathan2022FastIF,
    title   = {Fast Inference from Transformers via Speculative Decoding},
    author  = {Yaniv Leviathan and Matan Kalman and Y. Matias},
    booktitle = {International Conference on Machine Learning},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:254096365}
}

.\lucidrains\spear-tts-pytorch\setup.py

# import setup and package-finding utilities
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'spear-tts-pytorch',
  # find all packages, excluding none
  packages = find_packages(exclude=[]),
  # version
  version = '0.4.8',
  # license
  license='MIT',
  # description
  description = 'Spear-TTS - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # long description content type
  long_description_content_type = 'text/markdown',
  # project url
  url = 'https://github.com/lucidrains/spear-tts-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'text-to-speech'
  ],
  # install dependencies
  install_requires=[
    'audiolm-pytorch>=1.2.8',
    'beartype',
    'einops>=0.6.1',
    'rotary-embedding-torch>=0.3.0',
    'torch>=1.6',
    'tqdm',
    'x-clip>=0.12.2'
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\attend.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version

from einops import rearrange, repeat

# a named tuple Config with three boolean flags for the sdp kernel backends
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper to check whether a value exists
def exists(val):
    return val is not None

# decorator that makes sure the wrapped function only ever runs once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# wrap print so a given message only prints once
print_once = once(print)
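# e.g. print_once('hello') prints 'hello'; a second print_once('hello') does nothing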

# main class: Attend
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        causal = False,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.register_buffer("mask", None, persistent=False)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    # get the causal mask
    def get_mask(self, i, j, device):
        n = max(i, j)

        if exists(self.mask) and self.mask.shape[-1] >= n:
            mask = self.mask[:n, :n]
        else:
            mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)
            self.register_buffer("mask", mask, persistent = False)

        return mask[-i:, :]

    # flash attention
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, causal, is_cuda, device = *q.shape, k.shape[-2], self.causal, q.is_cuda, q.device

        # check if mask exists and expand to a compatible shape
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # pick the sdp kernel config depending on whether the tensors are on cuda
        config = self.cuda_config if is_cuda else self.cpu_config

        # if q and k lengths differ (caching of key / values) and causal,
        # manually construct the causal attention mask, since this case
        # is not directly supported (flash attention 2 will eventually support it)
        row_is_entirely_masked = None
        if causal and q_len != k_len:
            causal_mask = self.get_mask(q_len, k_len, device = device)

            if exists(mask):
                mask = mask & ~causal_mask
            else:
                mask = ~causal_mask

            row_is_entirely_masked = ~mask.any(dim = -1)
            mask[..., 0] = mask[..., 0] | row_is_entirely_masked

            causal = False

        # apply pytorch 2.0 flash attention via torch.backends.cuda.sdp_kernel
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0., 
                is_causal = causal
            )

        if exists(row_is_entirely_masked):
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out
    # forward pass, taking queries (q), keys (k), values (v) and an optional mask
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # query sequence length and device
        n, device = q.shape[-2], q.device
        # number of query heads and key / value heads
        heads, kv_heads = q.shape[1], k.shape[1]

        # if there are fewer key / value heads than query heads (grouped query attention),
        # repeat keys and values to match the number of query heads
        if kv_heads < heads:
            k, v = map(lambda t: repeat(t, 'b h ... -> b (g h) ...', g = heads // kv_heads), (k, v))

        # scaling factor
        scale = q.shape[-1] ** -0.5

        # if flash attention is enabled, dispatch to flash_attn
        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask

        # if a mask exists, rearrange it and fill invalid positions of the similarity matrix with a large negative value
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask

        # if causal, build the causal mask and fill masked-out positions with a large negative value
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = self.get_mask(i, j, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        # softmax over the similarity matrix to get attention weights
        attn = sim.softmax(dim = -1)
        # dropout on the attention weights
        attn = self.attn_dropout(attn)

        # aggregate values

        # weighted sum of the values by the attention weights
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
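
A quick sanity-check sketch of Attend on random tensors; shapes are (batch, heads, seq, dim_head), and flash is left off so the plain einsum path runs on cpu.

import torch
from spear_tts_pytorch.attend import Attend

attend = Attend(causal = True, flash = False)

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

out = attend(q, k, v)   # (1, 8, 16, 64)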

.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\data.py

# import the necessary modules
from pathlib import Path
import torch
from torch.utils.data import Dataset
from beartype import beartype

# mock dataset class
class MockDataset(Dataset):
    # init, taking the dataset length
    def __init__(self, length: int):
        self.length = length

    # return the dataset length
    def __len__(self):
        return self.length

    # return the item at the given index
    def __getitem__(self, ind):
        return torch.randn(1024)

# generated audio-text dataset class
class GeneratedAudioTextDataset(Dataset):
    # init, taking the folder path and delimiter id
    @beartype
    def __init__(
        self,
        folder: str,
        delimiter_id: int = -1
    ):
        # convert the folder path to a Path object
        self.folder = Path(folder)
        # assert the folder exists and is a directory
        assert self.folder.exists() and self.folder.is_dir()
        # gather all '.pt' files in the folder
        self.paths = list(self.folder.glob('*.pt'))
        # set the delimiter id
        self.delimiter_id = delimiter_id

    # return the dataset length
    def __len__(self):
        return len(self.paths)

    # return the item at the given index
    def __getitem__(self, ind):
        # file path at the given index
        path = self.paths[ind]
        # load the file into a tensor
        tensor = torch.load(str(path))

        # boolean mask marking positions of the delimiter id
        delimiter_mask = tensor == self.delimiter_id
        # assert at least one delimiter exists
        assert delimiter_mask.any(), f'delimiter (<audio> <delimiter> <text>) not found'

        # index of the first delimiter
        ind = (delimiter_mask.cumsum(dim=-1) == 0).sum().item()

        # return the tokens before the delimiter and the tokens after it
        return tensor[:ind], tensor[(ind + 1):]
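
Tracing the delimiter split in __getitem__ on a tiny hypothetical row (with the default delimiter_id = -1):

import torch

row = torch.tensor([11, 12, 13, -1, 7, 8])                  # <audio> <delimiter> <text>
delimiter_mask = row == -1
ind = (delimiter_mask.cumsum(dim=-1) == 0).sum().item()     # 3, the index of the delimiter
audio_ids, text_ids = row[:ind], row[(ind + 1):]            # ([11, 12, 13], [7, 8])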

.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\distributed.py

# torch
import torch
# Function class from torch.autograd
from torch.autograd import Function
# torch.distributed module
import torch.distributed as distributed
# functional api, used by the padding helper below
import torch.nn.functional as F

# rearrange from einops
from einops import rearrange

# distributed helpers

# check whether a value exists
def exists(val):
    return val is not None

# pad a tensor along one dimension up to a target length
# (helper reconstructed from its usage in all_gather_variable_dim below)
def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

# gather tensors of variable size along a dimension across all processes
def all_gather_variable_dim(t, dim = 0, sizes = None):
    # current device, process rank, and world size
    device, rank, world_size = t.device, distributed.get_rank(), distributed.get_world_size()

    # if sizes were not passed in
    if not exists(sizes):
        # size of t along the gather dimension
        size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
        # a list to receive each process's size
        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for i in range(world_size)]
        # gather the sizes from all processes
        distributed.all_gather(sizes, size)
        # stack the gathered sizes into one tensor
        sizes = torch.stack(sizes)

    # the maximum size across processes
    max_size = sizes.amax().item()
    # pad t along the gather dimension up to the maximum size
    padded_t = pad_dim_to(t, max_size, dim = dim)

    # a list to receive the padded tensor from each process
    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for i in range(world_size)]
    # gather the padded tensors from all processes
    distributed.all_gather(gathered_tensors, padded_t)

    # concatenate the gathered tensors along the gather dimension
    gathered_tensor = torch.cat(gathered_tensors, dim = dim)
    # index sequence
    seq = torch.arange(max_size, device = device)

    # build a mask selecting the valid (unpadded) entries
    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    # select only the valid entries
    gathered_tensor = gathered_tensor.index_select(dim, indices)

    return gathered_tensor, sizes

# AllGather as an autograd Function
class AllGather(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes):
        # check whether we are in a distributed setting with more than one process
        is_dist = distributed.is_initialized() and distributed.get_world_size() > 1
        ctx.is_dist = is_dist

        # if not distributed, return the input unchanged
        if not is_dist:
            return x, None

        # gather the variable-sized tensors from all processes
        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        # if not distributed, pass the gradients straight through
        if not ctx.is_dist:
            return grads, None, None

        # per-process sizes and the current process rank
        batch_sizes, rank = ctx.batch_sizes, distributed.get_rank()
        # split the gradients according to each process's size
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

# expose AllGather as a function
all_gather = AllGather.apply
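
# a minimal sketch (only meaningful when torch.distributed is initialized with
# world_size > 1): each rank contributes a batch with a possibly different first dimension
#   x = torch.randn(local_batch_size, 512)
#   gathered, sizes = all_gather(x, 0, None)   # gathered has sizes.sum() rows in total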

.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\spear_tts_pytorch.py

# standard library
import math
from pathlib import Path
from functools import partial
from random import random

# torch
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor, nn, einsum, IntTensor, LongTensor

from torch.nn import Module, ModuleList

from torch.utils.data import Dataset

# einops
from einops import rearrange, repeat, pack, reduce
from einops.layers.torch import Rearrange

# audiolm_pytorch
from audiolm_pytorch import FairseqVQWav2Vec, HubertWithKmeans
from audiolm_pytorch.data import get_dataloader

# rotary embeddings
from rotary_embedding_torch import RotaryEmbedding

# beartype
from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Optional, Union, Callable, Literal, Tuple, List

# tokenizer from x_clip
from x_clip.tokenizer import tokenizer

# Attend and all_gather from this package
from spear_tts_pytorch.attend import Attend
from spear_tts_pytorch.distributed import all_gather

# tqdm
from tqdm import tqdm

# FloatTensor: either a cpu or cuda float tensor
FloatTensor = Union[
    torch.FloatTensor,
    torch.cuda.FloatTensor
]

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# check whether a tensor is empty
def empty(t: Tensor):
    return t.numel() == 0

# l2 normalize a tensor
def l2norm(t):
    return F.normalize(t, dim = -1)

# write an eos id at the first pad position of each row
def set_eos_id(t: Tensor, eos_id: int, pad_id: int):
    eos_indices = ((t == pad_id).cumsum(dim = -1) == 0).sum(dim = -1, keepdim = True).long()

    batch_range = torch.arange(t.shape[0], device = t.device, dtype = torch.long)
    batch_range = rearrange(batch_range, '... -> ... 1')

    t = F.pad(t, (0, 1), value = pad_id)
    t[batch_range, eos_indices] = eos_id
    return t
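
# e.g. set_eos_id(torch.tensor([[5, 3, -1, -1]]), eos_id = 9, pad_id = -1)
# -> tensor([[5, 3, 9, -1, -1]]), writing eos at the first pad position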

# keep only unique consecutive values per row, padding rows to equal length
def batch_unique_consecutive(t, pad_value = 0.):
    unique_arr = [torch.unique_consecutive(el) for el in t.unbind(dim = 0)]
    return pad_sequence(unique_arr, batch_first = True, padding_value = pad_value)

# mask out everything after the eos token
def mask_after_eos(target, eos_id, pad_id):
    mask = (target == eos_id).cumsum(dim = -1) > 0
    mask = F.pad(mask, (1, -1), value = False)
    return target.masked_fill(mask, pad_id)
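
# e.g. mask_after_eos(torch.tensor([[5, 9, 7, 7]]), eos_id = 9, pad_id = 0)
# -> tensor([[5, 9, 0, 0]]), keeping eos itself but padding everything after it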

# safe division
def safe_div(num, den, eps = 1e-10):
    return num / max(den, eps)

# index of the first True along a dimension
def find_first_true_index(bool_tensor, dim = -1):
    return (bool_tensor.cumsum(dim = dim) == 0).sum(dim = dim)

# freeze / unfreeze helpers

# set requires_grad on all parameters of a module
def set_requires_grad_(module: Module, requires_grad: bool):
    for p in module.parameters():
        p.requires_grad = requires_grad

# freeze a module's parameters
def freeze(module: Module):
    set_requires_grad_(module, False)

# unfreeze a module's parameters
def unfreeze(module: Module):
    set_requires_grad_(module, True)

# sampling helpers

# decorator that runs a method in eval mode, then restores the previous training mode
def eval_decorator(fn):
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        out = fn(self, *args, **kwargs)
        self.train(was_training)
        return out
    return inner

# log with clamping for numerical stability
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# gumbel sampling
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
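
# note: adding gumbel noise to the scaled logits and taking the argmax is
# equivalent to sampling from softmax(t / temperature) (the gumbel-max trick)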

# top-p (nucleus) filtering
def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = F.pad(cum_probs > thres, (1, -1), value = 0)
    sorted_logits[sorted_indices_to_remove] = float('-inf')
    sorted_logits = sorted_logits.scatter(-1, sorted_indices, sorted_logits)
    return sorted_logits

# top-k filtering
def top_k(logits, thres = 0.1, k = None):
    if not exists(k):
        k = math.ceil(thres * logits.shape[-1])
    val, ind = torch.topk(logits, k, dim = -1)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs
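
# e.g. top_k(torch.tensor([1., 2., 3., 4.]), thres = 0.5) keeps the top
# ceil(0.5 * 4) = 2 logits and fills the remaining positions with -inf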

# residual wrapper

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# RMSNorm

class RMSNorm(nn.Module):
    # init, taking the feature dimension
    def __init__(self, dim):
        super().__init__()
        # scale factor is the square root of the dimension
        self.scale = dim ** 0.5
        # learnable gamma parameter of the given dimension
        self.gamma = nn.Parameter(torch.ones(dim))

    # forward pass
    def forward(self, x):
        # normalize over the last dimension, then apply scale and gamma
        return F.normalize(x, dim=-1) * self.scale * self.gamma
# GEGLU activation
class GEGLU(nn.Module):
    def forward(self, x):
        # split the input in half along the last dimension
        x, gate = x.chunk(2, dim = -1)
        # gelu the gate and multiply with the other half
        return F.gelu(gate) * x

# feedforward layer factory
def FeedForward(dim, mult = 4, dropout = 0.):
    # inner dimension
    dim_inner = int(dim * mult * 2 / 3)
    return nn.Sequential(
        RMSNorm(dim),                     # RMSNorm
        nn.Linear(dim, dim_inner * 2),    # project up (doubled for the GEGLU gate)
        GEGLU(),                          # GEGLU activation
        nn.Dropout(dropout),              # dropout
        nn.Linear(dim_inner, dim)         # project back down
    )
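
# note: the 2/3 factor keeps the parameter count close to that of a plain
# `mult`-times feedforward, compensating for GEGLU consuming half the hidden width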

# attention
class Attention(nn.Module):
    # init
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        kv_heads = None,
        causal = False,
        dim_context = None,
        dropout = 0.,
        rotary_emb: Optional[RotaryEmbedding] = None,
        flash = False,
        add_null_kv = False
    ):
        super().__init__()
        dim_context = default(dim_context, dim)

        self.heads = heads
        self.kv_heads = default(kv_heads, heads)
        assert (self.heads % self.kv_heads) == 0, 'number of key value heads must be divisible by query heads'

        self.scale = dim_head ** -0.5
        dim_query_inner = heads * dim_head
        dim_kv_inner = self.kv_heads * dim_head

        self.rotary_emb = rotary_emb

        self.attend = Attend(
            causal = causal,
            flash = flash,
            dropout = dropout
        )

        self.norm = RMSNorm(dim)
        self.attn_dropout = nn.Dropout(dropout)

        # project the input to queries
        self.to_q = nn.Sequential(
            nn.Linear(dim, dim_query_inner, bias = False),
            Rearrange('b n (h d) -> b h n d', h = self.heads)
        )

        # project the context to keys and values
        self.to_kv = nn.Sequential(
            nn.Linear(dim_context, dim_kv_inner * 2, bias = False),
            Rearrange('b n (kv h d) -> kv b h n d', kv = 2, h = self.kv_heads)
        )

        # project the attention output back to the model dimension
        self.to_out = nn.Linear(dim_query_inner, dim, bias = False)

        self.add_null_kv = add_null_kv
        if add_null_kv:
            self.null_kv = nn.Parameter(torch.randn(2, self.kv_heads, 1, dim_head))

    # forward pass
    def forward(
        self,
        x,
        context = None,
        mask = None,
        cache = None,
        return_cached_key_values = False
    ):
        has_context = exists(context)
        b = x.shape[0]

        x = self.norm(x)

        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context))

        if exists(cache):
            ck, cv = cache.unbind(dim = 1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        new_cache = torch.stack((k, v), dim = 1)

        if exists(self.rotary_emb):
            assert not has_context
            q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)

        if self.add_null_kv:
            assert not exists(self.rotary_emb)
            nk, nv = map(lambda t: repeat(t, 'h 1 d -> b h 1 d', b = b), self.null_kv)
            k = torch.cat((nk, k), dim = -2)
            v = torch.cat((nv, v), dim = -2)

            if exists(mask):
                mask = F.pad(mask, (1, 0), value = True)

        out = self.attend(q, k, v, mask = mask)

        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)

        if not return_cached_key_values:
            return out

        return out, new_cache

# Transformer
class Transformer(nn.Module):
    # init
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        kv_heads = None,
        causal = False,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        cross_attend = False,
        attn_flash = False
    ):
        # call the parent constructor
        super().__init__()

        # rotary embedding, shared across layers
        rotary_emb = RotaryEmbedding(dim_head)

        # layer list
        self.layers = nn.ModuleList([])

        # create the requested number of layers
        for _ in range(depth):
            # each layer has self attention, optional cross attention, and a feedforward
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, causal = causal, dim_head = dim_head, heads = heads, kv_heads = kv_heads, dropout = attn_dropout, rotary_emb = rotary_emb, flash = attn_flash),
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = attn_flash, add_null_kv = True) if cross_attend else None,
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))

        # final normalization layer
        self.final_norm = RMSNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        context = None,
        context_mask = None,
        cache = None,
        return_cache = False,
        return_hiddens = False,
        early_exit_at_layer = None,
        seq_start_pos = None
    ):
        # whether context was passed in
        has_context = exists(context)

        # if a sequence start position is given, derive the mask from it
        if exists(seq_start_pos):
            assert not exists(mask)
            seq_len = x.shape[-2]
            seq_arange = torch.arange(seq_len, device = x.device, dtype = torch.long)
            mask = seq_arange >= seq_start_pos[..., None]

        # if a cache exists, only process the uncached suffix of the sequence
        if exists(cache):
            cached_length, seq_len = cache.shape[-2], x.shape[-2]
            assert seq_len > cached_length
            x = x[:, cached_length:]

        # new cache and hidden state lists
        new_cache = []
        hiddens = []

        # iterator over the cached key / values, if any
        if exists(cache):
            iter_cache = iter(cache.unbind(dim = 1))
        else:
            iter_cache = iter([])

        # iterate over the layers
        for ind, (self_attn, maybe_cross_attn, ff) in enumerate(self.layers):
            layer = ind + 1

            # self attention, updating the cache
            residual = x
            attn_out, key_values = self_attn(x, mask = mask, cache = next(iter_cache, None), return_cached_key_values = True)
            x = attn_out + residual
            new_cache.append(key_values)

            # cross attention, if present
            if exists(maybe_cross_attn):
                assert has_context
                x = maybe_cross_attn(x, context = context, mask = context_mask) + x

            # feedforward
            x = ff(x) + x
            hiddens.append(x)

            # if an early exit layer was set, break out of the loop at that layer
            if exists(early_exit_at_layer) and early_exit_at_layer == layer:
                break

        # on early exit, return the intermediate output (and optionally the cache)
        if exists(early_exit_at_layer):
            if return_cache:
                return x, torch.stack(new_cache, dim = 1)
            return x

        # final normalization
        out = self.final_norm(x)

        # optionally return the hidden states
        if return_hiddens:
            assert not return_cache
            return out, torch.stack(hiddens)

        # if no cache was requested, return the output
        if not return_cache:
            return out

        # return the output and the cache
        return out, torch.stack(new_cache, dim = 1)
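
# a minimal shape sketch (hypothetical sizes), running the transformer on random embeddings:
#   transformer = Transformer(dim = 512, depth = 2, causal = True)
#   x = torch.randn(1, 32, 512)
#   out = transformer(x)   # (1, 32, 512)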
# SpeechOrTextLiteral: either 'speech' or 'text'
SpeechOrTextLiteral = Union[
    Literal['speech'],
    Literal['text']
]

# SemanticModelType: either FairseqVQWav2Vec or HubertWithKmeans
SemanticModelType = Union[
    FairseqVQWav2Vec,
    HubertWithKmeans
]

# TextToSemantic
class TextToSemantic(Module):
    # init
    @beartype
    def __init__(
        self,
        dim,
        *,
        source_depth,
        target_depth,
        num_text_token_ids = None,
        tokenizer_encode: Optional[Callable] = None,
        use_openai_tokenizer = False,
        wav2vec: Optional[SemanticModelType] = None,
        num_semantic_token_ids = None,
        dim_head = 64,
        heads = 8,
        target_kv_heads = None,  # for grouped query attention, saving memory on decoder inference
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        semantic_pad_id = -1,
        text_pad_id = 0,
        autoset_semantic_eos_id = True,
        autoset_text_eos_id = True,
        attn_flash = False,
        cond_drop_prob = 0.,
        target_early_exit_layer = None,
        detach_early_exit_embed = False,
        align_reg_loss_weight = 0.1,
        align_reg_use_logsumexp_pool = True,
        align_reg_logsumexp_pool_temp = 0.1
    ):
        ...  # body elided

    @property
    def device(self):
        # device of the first parameter
        return next(self.parameters()).device

    # load a checkpoint
    def load(self, path, strict = True):
        # return pkg so that, if this function is called from within the Trainer,
        # the Trainer can also access the package loaded from the checkpoint
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path), map_location = 'cpu')
        self.load_state_dict(pkg['model'], strict = strict)
        return pkg

    # a set of freezing / unfreezing utils
    # then rely on get_optimizer to filter out the parameters that do not require grad from being exposed to the optimizer

    # unfreeze all parameters
    def unfreeze_all(self):
        unfreeze(self)

    # freeze the encoder
    def freeze_encoder(self):
        freeze(self.source_transformer)

    # freeze the encoder up to a given layer
    def freeze_encoder_below_layer(self, layer: int):
        """
        for the final training of text-to-semantic on the pseudo-labelled dataset
        they freeze the encoder portion up to a certain layer
        """
        unfreeze(self.source_transformer)

        for ind, module in enumerate(self.source_transformer.layers):
            current_layer = ind + 1

            if current_layer <= layer:
                freeze(module)

    # freeze the decoder
    def freeze_decoder(self):
        freeze(self.target_transformer)

    # freeze the speech embedding
    def freeze_speech_emb(self):
        freeze(self.token_emb['speech'])
        self.start_token['speech'].requires_grad = False

    # freeze the text embedding
    def freeze_text_emb(self):
        freeze(self.token_emb['text'])
        self.start_token['text'].requires_grad = False

    # sampling

    @torch.no_grad()
    @eval_decorator
    @beartype
    def generate(
        self,
        source: Union[List[str], Tensor],
        *,
        source_type: SpeechOrTextLiteral,
        target_type: SpeechOrTextLiteral,
        temperature = 1.,
        filter_logits_fn = top_k,
        filter_fn_kwargs: dict = dict(),
        source_mask: Optional[Tensor] = None,
        max_length = 2048,
        beam_search_decode = False,
        spec_decode = False,
        spec_decode_gamma = 5,
        spec_decode_lenience = 1.,
        beam_size = 4,
        return_source = False,
        return_target_mask = False,
        cond_scale = 1.
    ):
        ...  # body elided

    @beartype
    def forward(
        self,
        source: Union[List[str], Tensor],
        target: Union[List[str], Tensor],
        *,
        source_type: SpeechOrTextLiteral,
        target_type: SpeechOrTextLiteral,
        source_mask: Optional[Tensor] = None,
        target_mask: Optional[Tensor] = None,
        return_loss = False,
        return_logits = False,
        cond_drop_prob: Optional[float] = None,
        should_sim_regularize = True,
        return_early_exit_loss = False
    ):
        ...  # body elided

# pretraining modules

# get a random subset of a mask, selected with a given probability
def get_mask_subset_prob(mask, prob, min_mask = 0):
    batch, seq, device = *mask.shape, mask.device
    # number of positions to mask per row: the mask sum times prob, clamped at min_mask
    num_to_mask = (mask.sum(dim=-1, keepdim=True) * prob).clamp(min=min_mask)
    # random logits used to rank positions
    logits = torch.rand((batch, seq), device=device)
    # fill non-mask positions with -1 so they rank last
    logits = logits.masked_fill(~mask, -1)

    # argsort of random logits gives a random permutation over positions
    randperm = logits.argsort(dim=-1).float()

    # number of padding positions per row
    num_padding = (~mask).sum(dim=-1, keepdim=True)
    # shift the ranks down by the padding count so padding cannot be selected
    randperm -= num_padding

    # select the positions whose rank falls below the number to mask
    subset_mask = randperm < num_to_mask
    # make sure non-mask positions stay False
    subset_mask.masked_fill_(~mask, False)
    # return the subset mask
    return subset_mask
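
# e.g. for a fully-true mask of shape (1, 10) and prob = 0.3, exactly 3 randomly
# chosen positions per row come back True; padding (False) positions are never selected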
# wrapper for the speech-to-speech pretraining task
class SpeechSpeechPretrainWrapper(nn.Module):
    # init
    @beartype
    def __init__(
        self,
        model: TextToSemantic,                        # the text-to-semantic model
        wav2vec: Optional[SemanticModelType] = None,  # optional speech model
        deletion_prob: float = 0.6,                   # deletion probability
        reconstruct_seq: bool = False,                # whether to reconstruct the full sequence
        mask_id = None                                # mask token id
    ):
        super().__init__()

        self.model = model
        self.wav2vec = default(wav2vec, model.wav2vec)  # default to the model's own wav2vec

        self.deletion_prob = deletion_prob
        self.reconstruct_seq = reconstruct_seq
        self.mask_id = mask_id

    # forward pass
    def forward(
        self,
        x,                              # input batch
        return_early_exit_loss = False  # whether to also return the early exit loss
    ):
        is_raw_audio = x.dtype == torch.float  # float input is treated as raw audio

        if is_raw_audio:
            assert exists(self.wav2vec)  # a speech model is needed to tokenize raw audio

            with torch.no_grad():
                self.wav2vec.eval()                   # speech model in eval mode
                x = self.wav2vec(x, flatten = False)  # encode to semantic tokens

        batch = x.shape[0]

        mask = torch.ones_like(x, dtype = torch.bool, device = self.model.device)  # all-true mask

        if exists(self.mask_id):
            assert self.reconstruct_seq, 'reconstruct_seq must be true if mask id is provided'

            mask = mask.masked_fill(x == self.model.semantic_pad_id, False)  # exclude pad positions
            delete_mask = get_mask_subset_prob(mask, self.deletion_prob)     # sample deletions

            source = x.masked_fill(delete_mask, self.mask_id)  # replace deleted tokens with the mask id
        else:
            delete_mask = get_mask_subset_prob(mask, self.deletion_prob)     # sample deletions

            source = rearrange(x[~delete_mask], '(b n) -> b n', b = batch)   # keep only the surviving tokens

        if self.reconstruct_seq:
            target = x                                                       # reconstruct the full sequence
        else:
            target = rearrange(x[delete_mask], '(b n) -> b n', b = batch)    # predict only the deleted tokens

        loss, logits = self.model(
            source, target,
            source_type = 'speech',
            target_type = 'speech',
            return_loss = True,
            return_logits = True,
            return_early_exit_loss = return_early_exit_loss,
        )

        return loss, logits

# wrapper for the backtranslation task (semantic -> text)
class SemanticToTextWrapper(nn.Module):
    # init
    @beartype
    def __init__(
        self,
        model: TextToSemantic
    ):
        super().__init__()

        self.model = model

    # forward pass
    def forward(
        self,
        semantic_token_ids,
        grapheme_token_ids,
    ):
        source = semantic_token_ids
        target = grapheme_token_ids

        loss, logits = self.model(
            source, target,
            source_type = 'speech',
            target_type = 'text',
            return_loss = True,
            return_logits = True
        )

        return loss, logits

# wrapper for the text -> semantic task
class TextToSemanticWrapper(nn.Module):
    # init
    @beartype
    def __init__(
        self,
        model: TextToSemantic
    ):
        super().__init__()

        self.model = model

    # forward pass
    def forward(
        self,
        grapheme_token_ids,
        semantic_token_ids,
        return_early_exit_loss = True
    ):
        source = grapheme_token_ids
        target = semantic_token_ids

        loss, logits = self.model(
            source, target,
            source_type = 'text',
            target_type = 'speech',
            return_loss = True,
            return_logits = True,
            return_early_exit_loss = return_early_exit_loss
        )

        return loss, logits

# wrapper for generating the pseudo-labelled audio-to-text dataset
class SemanticToTextDatasetGenerator(nn.Module):
    # init, setting up the model, dataset, dataloader and output folder
    @beartype
    def __init__(
        self,
        model,                                      # the model
        *,
        dataset: Dataset,                           # the dataset
        folder = './generated-audio-text-pairs',    # output folder
        batch_size = 4,                             # batch size
        delimiter_id: int = -1,                     # delimiter id
        audio_pad_id = None,                        # audio pad id
        text_pad_id = 0                             # text pad id
    ):
        # call the parent constructor
        super().__init__()
        # the model
        self.model = model

        # the dataset
        self.dataset = dataset
        # dataloader over the dataset with the given batch size
        self.dl = get_dataloader(dataset, batch_size=batch_size)
        # delimiter id
        self.delimiter_id = delimiter_id

        # audio pad id
        self.audio_pad_id = audio_pad_id
        # text pad id
        self.text_pad_id = text_pad_id

        # output folder, created if it does not exist
        self.folder = Path(folder)
        self.folder.mkdir(exist_ok=True, parents=True)

    # forward pass: generate and save audio-text pairs
    def forward(
        self,
        max_length=2048,
        beam_search_decode=True,
        **generate_kwargs
    ):
        # tensor holding the delimiter id
        delimiter = torch.tensor([self.delimiter_id], device=self.model.device)

        # counter used for filenames
        counter = 0

        # iterate over the audio batches in the dataloader
        for audio, in self.dl:
            # generate audio semantic ids and the corresponding text ids
            audio_semantic_ids, text_ids = self.model.generate(
                source=audio,
                source_type='speech',
                target_type='text',
                return_source=True,
                max_length=max_length,
                beam_search_decode=beam_search_decode,
                **generate_kwargs
            )

            # iterate over the generated pairs
            for audio_semantic_id, text_id in zip(audio_semantic_ids, text_ids):

                # if an audio pad id is set, strip padding from the audio tokens
                if exists(self.audio_pad_id):
                    audio_pad_mask = audio_semantic_id == self.audio_pad_id
                    audio_semantic_id = audio_semantic_id[~audio_pad_mask]

                # if a text pad id is set, strip padding from the text tokens
                if exists(self.text_pad_id):
                    text_pad_mask = text_id == self.text_pad_id
                    text_id = text_id[~text_pad_mask]

                # pack audio tokens, delimiter, and text tokens into one row
                row, _ = pack([audio_semantic_id, delimiter, text_id], '*')
                # build the save path
                path = str(self.folder / f'{counter}.pt')

                # save the row and advance the counter
                torch.save(row, path)
                counter += 1

.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\trainer.py

# import the necessary libraries
import re
from pathlib import Path
from shutil import rmtree

# beartype helpers and types
from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Union, Optional, Tuple

# PyTorch
import torch
from torch import nn, LongTensor, IntTensor
from torch.utils.data import ConcatDataset
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, random_split

# models and functions from audiolm_pytorch
from audiolm_pytorch import FairseqVQWav2Vec, HubertWithKmeans
from audiolm_pytorch.data import get_dataloader
from audiolm_pytorch.optimizer import get_optimizer

# models and dataset from spear_tts_pytorch
from spear_tts_pytorch.spear_tts_pytorch import SpeechSpeechPretrainWrapper, TextToSemantic, SemanticToTextWrapper, TextToSemanticWrapper
from spear_tts_pytorch.data import GeneratedAudioTextDataset

# accelerator and distributed type from accelerate
from accelerate import Accelerator, DistributedType

# type alias
IndicesTensor = Union[LongTensor, IntTensor]

# make sure only one Trainer is instantiated at a time
ONE_TRAINER_INSTANTIATED = False

def check_one_trainer():
    global ONE_TRAINER_INSTANTIATED
    assert not ONE_TRAINER_INSTANTIATED, 'only one Trainer can be instantiated at a time for training'
    ONE_TRAINER_INSTANTIATED = True

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# no-op function
def noop(*args, **kwargs):
    pass

# endlessly cycle over a dataloader
def cycle(dl):
    while True:
        for data in dl:
            yield data

# cast input to a tuple
def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

# ask the user a yes / no question
def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

# accumulate values into a log dict
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# parse the number of training steps from a checkpoint filename
def checkpoint_num_steps(checkpoint_path):
    """Returns the number of steps trained from a checkpoint based on the filename.

    Filename format assumed to be something like "/path/to/speech.speech.20000.pt" which is
    for 20k train steps. Returns 20000 in that case.
    """
    results = re.findall(r'\d+', str(checkpoint_path))

    if len(results) == 0:
        return 0

    return int(results[-1])

# SpeechSpeechPretrainer
class SpeechSpeechPretrainer(nn.Module):
    @beartype
    def __init__(
        self,
        model: TextToSemantic,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
        *,
        num_train_steps,
        num_warmup_steps,
        batch_size,
        dataset: Optional[Dataset] = None,
        deletion_prob: float = 0.6,
        reconstruct_seq: bool = False,
        mask_id = None,
        lr = 3e-4,
        initial_lr = 1e-5,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        log_every = 10,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        split_batches = False,
        drop_last = False,
        force_clear_prev_results = None
    ):
        # call the parent constructor
        super().__init__()
        # make sure only one trainer exists
        check_one_trainer()

        # instantiate the accelerator
        self.accelerator = Accelerator(
            split_batches = split_batches,
            **accelerate_kwargs
        )

        # set the model and wav2vec
        self.model = model
        self.wav2vec = wav2vec

        # the pretraining wrapper
        self.train_wrapper = SpeechSpeechPretrainWrapper(
            model = model,
            wav2vec = wav2vec,
            deletion_prob = deletion_prob,
            reconstruct_seq = reconstruct_seq,
            mask_id = mask_id
        )

        # register a buffer that tracks the step count
        self.register_buffer('steps', torch.Tensor([0]))

        # training steps, warmup steps, batch size, gradient accumulation
        self.num_train_steps = num_train_steps
        self.num_warmup_steps = num_warmup_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # optimizer
        self.lr = lr
        self.initial_lr = initial_lr
        self.optim = get_optimizer(model.parameters(), lr = lr, wd = wd)
        self.scheduler = CosineAnnealingLR(self.optim, T_max = num_train_steps)

        # max gradient norm
        self.max_grad_norm = max_grad_norm

        # dataset
        self.ds = dataset

        # split out a validation set
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # assert the datasets have enough samples
        assert len(self.ds) >= batch_size, 'dataset must have sufficient samples for training'
        assert len(self.valid_ds) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'

        # dataloaders
        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)
        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)

        # prepare everything needed for training with the accelerator
        (
            self.train_wrapper,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.train_wrapper,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        )

        # dataloader iterators
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        # logging, model saving, and result saving frequencies
        self.log_every = log_every
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        # results folder
        self.results_folder = Path(results_folder)

        # on the main process, optionally clear out results from previous experiments
        if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        # create the results folder
        self.results_folder.mkdir(parents = True, exist_ok = True)

        # initialize hyperparameter trackers
        hps = {"num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "learning_rate": lr, "initial_learning_rate": initial_lr}
        self.accelerator.init_trackers("speechspeech", config=hps)

    # save the model
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.model),
            optim = self.optim.state_dict(),
            scheduler = self.scheduler.state_dict()
        )
        torch.save(pkg, path)
    # 加载模型参数和优化器状态
    def load(self, path):
        # 获取未封装的模型
        model = self.accelerator.unwrap_model(self.model)
        # 加载模型
        pkg = model.load(path)

        # restore the optimizer state
        self.optim.load_state_dict(pkg['optim'])
        # restore the scheduler state
        self.scheduler.load_state_dict(pkg['scheduler'])

        # + 1 to start from the next step and avoid overwriting the last checkpoint
        self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
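
    # a note on resuming (paths below are hypothetical, for illustration only):
    # checkpoint_num_steps pulls the last run of digits out of the filename, so the
    # step counter is recovered from the path alone, e.g.
    #   checkpoint_num_steps('./results/speech.speech.20000.pt')  # -> 20000
    #   checkpoint_num_steps('./results/model.pt')                # -> 0 (no digits, start from scratch)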

    # print via the accelerator (only the main process actually prints)
    def print(self, msg):
        self.accelerator.print(msg)

    # generate with the training wrapper
    def generate(self, *args, **kwargs):
        return self.train_wrapper.generate(*args, **kwargs)

    # device being trained on
    @property
    def device(self):
        return self.accelerator.device

    # whether training is distributed
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # whether this is the main process
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # whether this is the local main process
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # linear learning rate warmup
    def warmup(self, step):
        if step < self.num_warmup_steps:
            return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
        else:
            return self.lr
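
    # a worked example of the schedule above, assuming num_warmup_steps = 1000 with
    # the default lr = 3e-4 and initial_lr = 1e-5:
    #   step 0    -> 1e-5
    #   step 500  -> 1e-5 + (3e-4 - 1e-5) * 0.5 = 1.55e-4
    #   step 1000 -> 3e-4, after which CosineAnnealingLR decays it toward 0 over T_max = num_train_steps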
    
    # a single training step
    def train_step(self):
        steps = int(self.steps.item())

        self.model.train()
        
        # adjust the learning rate according to the schedule
        
        if steps < self.num_warmup_steps:
            # apply linear warmup
            lr = self.warmup(steps)
            for param_group in self.optim.param_groups:
                param_group['lr'] = lr
        else:
            # after warmup, step the cosine annealing scheduler
            self.scheduler.step()

        # logs

        logs = {}

        # forward and backward, with gradient accumulation

        for _ in range(self.grad_accum_every):
            x, = next(self.dl_iter)

            loss, _ = self.train_wrapper(x)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()

        # logging

        if not (steps % self.log_every):
            self.print(f"{steps}: loss: {logs['loss']:0.3f}")

        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # sample results every so often

        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            x, = next(self.valid_dl_iter)

            with torch.inference_mode():
                self.train_wrapper.eval()
                valid_loss, _ = self.train_wrapper(x)

            self.print(f'{steps}: valid loss {valid_loss:0.3f}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # save the model every so often

        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'speech.speech.{steps}.pt')
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    # training loop
    def train(self, log_fn = noop):
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')
# trainer for backtranslation, turning semantic tokens back into text
class SemanticToTextTrainer(nn.Module):
    # initialization
    @beartype
    def __init__(
        self,
        model: TextToSemantic,  # the TextToSemantic model, trained here in the semantic-to-text direction
        *,
        num_train_steps,  # total number of training steps
        num_warmup_steps,  # number of warmup steps
        batch_size,  # batch size
        dataset: Optional[Dataset] = None,  # dataset, defaults to None
        lr = 3e-4,  # learning rate, defaults to 3e-4
        initial_lr = 1e-5,  # initial (warmup) learning rate, defaults to 1e-5
        grad_accum_every = 1,  # gradient accumulation frequency, defaults to 1
        wd = 0.,  # weight decay, defaults to 0
        max_grad_norm = 0.5,  # max gradient norm, defaults to 0.5
        valid_frac = 0.05,  # validation fraction, defaults to 0.05
        random_split_seed = 42,  # seed for the random split, defaults to 42
        log_every = 10,  # log every this many steps, defaults to 10
        save_results_every = 100,  # save results every this many steps, defaults to 100
        save_model_every = 1000,  # save the model every this many steps, defaults to 1000
        results_folder = './results',  # results folder, defaults to './results'
        accelerate_kwargs: dict = dict(),  # kwargs forwarded to Accelerator, defaults to an empty dict
        split_batches = False,  # whether the accelerator splits batches, defaults to False
        drop_last = False,  # whether to drop the last incomplete batch, defaults to False
        force_clear_prev_results = None  # whether to force-clear previous results, defaults to None
        ):
        # initialize the parent module
        super().__init__()
        # make sure only one trainer is instantiated
        check_one_trainer()

        # accelerator
        self.accelerator = Accelerator(
            split_batches = split_batches,
            **accelerate_kwargs
        )

        # model
        self.model = model

        # training wrapper
        self.train_wrapper = SemanticToTextWrapper(model = model)

        # register a buffer for the step counter
        self.register_buffer('steps', torch.Tensor([0]))

        # training steps, warmup steps, batch size, gradient accumulation frequency
        self.num_train_steps = num_train_steps
        self.num_warmup_steps = num_warmup_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # when backtranslating, freeze the encoder and the speech embeddings
        model.unfreeze_all()
        model.freeze_speech_emb()
        model.freeze_encoder()

        # optimizer
        # get_optimizer should filter out frozen parameters (those with requires_grad set to False)
        self.optim = get_optimizer(
            model.parameters(),
            lr = lr,
            wd = wd,
            filter_by_requires_grad = True
        )

        self.lr = lr
        self.initial_lr = initial_lr
        self.scheduler = CosineAnnealingLR(self.optim, T_max = num_train_steps)

        # max gradient norm
        self.max_grad_norm = max_grad_norm

        # dataset
        self.ds = dataset

        # split out a validation set
        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        assert len(self.ds) >= batch_size, 'dataset must have sufficient samples for training'
        assert len(self.valid_ds) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'

        # dataloaders
        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)

        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, drop_last = drop_last)

        # prepare everything with the accelerator
        (
            self.train_wrapper,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.train_wrapper,
            self.optim,
            self.scheduler,
            self.dl,
            self.valid_dl
        )

        # dataloader iterators
        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        self.log_every = log_every
        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.results_folder = Path(results_folder)

        # on the main process, optionally clear out previous checkpoints and results
        if self.is_main and (force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'))):
            rmtree(str(self.results_folder))

        # create the results folder
        self.results_folder.mkdir(parents = True, exist_ok = True)
        
        # initialize the hyperparameter trackers
        hps = {"num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "learning_rate": lr, "initial_learning_rate": initial_lr}
        self.accelerator.init_trackers("semantictext", config=hps)

    # save the model, optimizer and scheduler state
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.model),
            optim = self.optim.state_dict(),
            scheduler = self.scheduler.state_dict()
        )
        torch.save(pkg, path)
    # load model parameters, optionally restoring optimizer state
    def load(self, path, restore_optimizer = True):
        # get the unwrapped model
        model = self.accelerator.unwrap_model(self.model)
        # load the checkpoint package
        pkg = model.load(path)

        # if the optimizer state should be restored
        if restore_optimizer:
            # restore the optimizer state
            self.optim.load_state_dict(pkg['optim'])
            # restore the scheduler state
            self.scheduler.load_state_dict(pkg['scheduler'])

            # + 1 to start from the next step and avoid overwriting the last checkpoint
            self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)

    # print via the accelerator
    def print(self, msg):
        self.accelerator.print(msg)

    # generate with the training wrapper
    def generate(self, *args, **kwargs):
        return self.train_wrapper.generate(*args, **kwargs)

    # device being trained on
    @property
    def device(self):
        return self.accelerator.device

    # whether training is distributed
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # whether this is the main process
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # whether this is the local main process
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # linear learning rate warmup
    def warmup(self, step):
        if step < self.num_warmup_steps:
            return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
        else:
            return self.lr
    
    # a single training step
    def train_step(self):
        steps = int(self.steps.item())

        # put the model in training mode
        self.model.train()
        
        # adjust the learning rate according to the schedule

        if steps < self.num_warmup_steps:
            # apply linear warmup
            lr = self.warmup(steps)
            for param_group in self.optim.param_groups:
                param_group['lr'] = lr
        else:
            # after warmup, step the cosine annealing scheduler
            self.scheduler.step()

        # logs

        logs = {}

        # forward and backward, with gradient accumulation

        for _ in range(self.grad_accum_every):
            semantic_token_ids, grapheme_token_ids = next(self.dl_iter)

            loss, _ = self.train_wrapper(semantic_token_ids = semantic_token_ids, grapheme_token_ids = grapheme_token_ids)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        # clip gradients if a max norm is set
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()

        # logging

        if not (steps % self.log_every):
            self.print(f"{steps}: loss: {logs['loss']:0.3f}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # sample results every so often

        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            semantic_token_ids, grapheme_token_ids = next(self.valid_dl_iter)

            with torch.inference_mode():
                self.train_wrapper.eval()
                valid_loss, _ = self.train_wrapper(semantic_token_ids = semantic_token_ids, grapheme_token_ids = grapheme_token_ids)

            self.print(f'{steps}: valid loss {valid_loss:0.3f}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # save the model every so often

        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'semantic.text.{steps}.pt')
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    # training loop
    def train(self, log_fn = noop):
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')
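
# a minimal usage sketch for the trainer above; the TextToSemantic hyperparameters and
# the MockDataset size are illustrative assumptions, not prescriptive values

from spear_tts_pytorch.data import MockDataset

model = TextToSemantic(
    dim = 512,
    source_depth = 6,
    target_depth = 6,
    num_text_token_ids = 256,
    num_semantic_token_ids = 500
)

trainer = SemanticToTextTrainer(
    model,
    num_train_steps = 10_000,
    num_warmup_steps = 1_000,
    batch_size = 4,
    dataset = MockDataset(100)
)

trainer.train()
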
# trainer for the text-to-semantic model
class TextToSemanticTrainer(nn.Module):
    # initialization, taking the model, number of training steps, warmup steps, etc.
    @beartype
    def __init__(
        self,
        model: TextToSemantic,
        *,
        num_train_steps,
        num_warmup_steps,
        batch_size,
        dataset: Optional[Dataset] = None,
        generated_audio_text_dataset_folder = None,
        dataset_delimiter_id = -1,
        lr = 3e-4,
        initial_lr = 1e-5,
        grad_accum_every = 1,
        wd = 0.,
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        log_every = 10,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        split_batches = False,
        drop_last = False,
        force_clear_prev_results = None,
        freeze_encoder_layers_below = 2,
        should_train_early_exit_layer_if_available = True
    # save model parameters to the given path
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.model),
            optim = self.optim.state_dict(),
            scheduler = self.scheduler.state_dict()
        )
        torch.save(pkg, path)

    # load model parameters from the given path, optionally restoring the optimizer state
    def load(self, path, restore_optimizer = True):
        model = self.accelerator.unwrap_model(self.model)
        pkg = model.load(path)

        if restore_optimizer:
            self.optim.load_state_dict(pkg['optim'])
            self.scheduler.load_state_dict(pkg['scheduler'])

            # + 1 to start from the next step and avoid overwriting the last checkpoint
            self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)

    # print via the accelerator
    def print(self, msg):
        self.accelerator.print(msg)

    # generate with the training wrapper
    def generate(self, *args, **kwargs):
        return self.train_wrapper.generate(*args, **kwargs)

    # device being trained on
    @property
    def device(self):
        return self.accelerator.device

    # whether training is distributed
    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    # whether this is the main process
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # whether this is the local main process
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # learning rate for the current step (linear warmup)
    def warmup(self, step):
        if step < self.num_warmup_steps:
            return self.initial_lr + (self.lr - self.initial_lr) * step / self.num_warmup_steps
        else:
            return self.lr
    # a single training step
    def train_step(self):
        # current step count
        steps = int(self.steps.item())

        # put the model in training mode
        self.model.train()
        
        # adjust the learning rate according to the schedule
        
        if steps < self.num_warmup_steps:
            # apply linear warmup while below the warmup step count
            lr = self.warmup(steps)
            for param_group in self.optim.param_groups:
                param_group['lr'] = lr
        else:
            # after the warmup period, step the cosine annealing scheduler
            self.scheduler.step()

        # logs

        logs = {}

        # forward and backward, with gradient accumulation

        for _ in range(self.grad_accum_every):
            semantic_token_ids, grapheme_token_ids = next(self.dl_iter)

            # compute the loss, optionally including the early exit loss
            loss, _ = self.train_wrapper(semantic_token_ids=semantic_token_ids, grapheme_token_ids=grapheme_token_ids, return_early_exit_loss=self.train_early_exit)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        # clip gradients if a max norm is set
        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

        # optimizer step
        self.optim.step()
        self.optim.zero_grad()

        # logging

        if not (steps % self.log_every):
            self.print(f"{steps}: loss: {logs['loss']:0.3f}")
        
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # sample results every so often

        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            semantic_token_ids, grapheme_token_ids = next(self.valid_dl_iter)

            with torch.inference_mode():
                self.train_wrapper.eval()
                valid_loss, _ = self.train_wrapper(semantic_token_ids=semantic_token_ids, grapheme_token_ids=grapheme_token_ids, return_early_exit_loss=self.train_early_exit)

            self.print(f'{steps}: valid loss {valid_loss:0.3f}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # save the model every so often

        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'text.semantic.{steps}.pt')
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        # increment the step count and return the logs
        self.steps += 1
        return logs

    # training loop
    def train(self, log_fn=noop):
        # run training steps until the target number of steps is reached
        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')

.\lucidrains\spear-tts-pytorch\spear_tts_pytorch\__init__.py

# import TextToSemantic, SpeechSpeechPretrainWrapper, SemanticToTextWrapper,
# TextToSemanticWrapper and SemanticToTextDatasetGenerator from spear_tts_pytorch.spear_tts_pytorch
# import the trainers from spear_tts_pytorch.trainer and the datasets from spear_tts_pytorch.data
from spear_tts_pytorch.spear_tts_pytorch import (
    TextToSemantic,
    SpeechSpeechPretrainWrapper,
    SemanticToTextWrapper,
    TextToSemanticWrapper,
    SemanticToTextDatasetGenerator
)

from spear_tts_pytorch.trainer import (
    SpeechSpeechPretrainer,
    SemanticToTextTrainer,
    TextToSemanticTrainer
)

from spear_tts_pytorch.data import (
    GeneratedAudioTextDataset,
    MockDataset
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

Speculative Decoding

Explorations into some recent techniques surrounding speculative decoding

Also have a few ideas of my own that I will try and share in this repository, if they work. The goal is to initially use it to speed up the text-to-semantic decoder in Spear-TTS

Appreciation

  • StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source current artificial intelligence techniques.

Todo

Citations

@inproceedings{Leviathan2022FastIF,
    title   = {Fast Inference from Transformers via Speculative Decoding},
    author  = {Yaniv Leviathan and Matan Kalman and Y. Matias},
    booktitle = {International Conference on Machine Learning},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:254096365}
}
@inproceedings{sun2023spectr,
    title     = {SpecTr: Fast Speculative Decoding via Optimal Transport},
    author    = {Ziteng Sun and Ananda Theertha Suresh and Jae Hun Ro and Ahmad Beirami and Himanshu Jain and Felix Yu and Michael Riley and Sanjiv Kumar},
    booktitle = {Workshop on Efficient Systems for Foundation Models @ ICML2023},
    year      = {2023},
    url       = {https://openreview.net/forum?id=d0mGsaheuT}
}
@article{Chen2023AcceleratingLL,
    title     = {Accelerating Large Language Model Decoding with Speculative Sampling},
    author    = {Charlie Chen and Sebastian Borgeaud and Geoffrey Irving and Jean-Baptiste Lespiau and L. Sifre and John M. Jumper},
    journal   = {ArXiv},
    year      = {2023},
    volume    = {abs/2302.01318},
    url       = {https://api.semanticscholar.org/CorpusID:256503945}
}
@article{Yan2020ProphetNetPF,
    title   = {ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training},
    author  = {Yu Yan and Weizhen Qi and Yeyun Gong and Dayiheng Liu and Nan Duan and Jiusheng Chen and Ruofei Zhang and Ming Zhou},
    journal = {ArXiv},
    year    = {2020},
    volume  = {abs/2001.04063},
    url     = {https://api.semanticscholar.org/CorpusID:210164665}
}
@article{Zhang2023DraftV,
    title     = {Draft \& Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding},
    author    = {Jinchao Zhang and Jue Wang and Huan Li and Lidan Shou and Ke Chen and Gang Chen and Sharad Mehrotra},
    journal   = {ArXiv},
    year      = {2023},
    volume    = {abs/2309.08168},
    url       = {https://api.semanticscholar.org/CorpusID:262013673}
}
@misc{medusa,
    author     = {Tianle Cai and Yuhong Li and Zhengyang Geng and Hongwu Peng and Tri Dao},
    title      = {Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads},
    year       = {2023},
    publisher  = {GitHub},
    journal    = {GitHub repository},
    howpublished = {\url{https://github.com/FasterDecoding/Medusa}},
}

.\lucidrains\speculative-decoding\setup.py

# import the setup and package-finding utilities
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'speculative-decoding', # package name
  packages = find_packages(exclude=[]), # find all packages, excluding none
  version = '0.1.2', # version number
  license='MIT', # license
  description = 'Speculative Decoding', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  long_description_content_type = 'text/markdown', # long description content type
  url = 'https://github.com/lucidrains/speculative-decoding', # project url
  keywords = [ # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'efficient decoding'
  ],
  install_requires=[ # dependencies
    'beartype',
    'einops>=0.6.1',
    'torch>=1.12',
  ],
  classifiers=[ # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\speculative-decoding\speculative_decoding\speculative_decoding.py

import math
# math library

import torch
# PyTorch
from torch.nn import Module, ModuleList
# Module and ModuleList from PyTorch
from torch import nn, einsum, Tensor
# nn, einsum and Tensor from PyTorch
import torch.nn.functional as F
# torch.nn.functional, aliased as F

from rotary_embedding_torch import RotaryEmbedding
# the RotaryEmbedding module
from beartype import beartype
# beartype, for runtime type checking

from collections import namedtuple
# namedtuple

from einops import rearrange
# rearrange from einops

# constants

Cache = namedtuple('Cache', ['cached_kvs', 'embeds'])
# named tuple Cache with two fields, cached_kvs and embeds

# helper functions

def exists(val):
    return val is not None
# exists: whether a value is not None

def default(val, d):
    return val if exists(val) else d
# default: return the value if it exists, otherwise the fallback

# sampling helpers

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))
# log with clamping, to avoid taking the log of zero

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))
# generate Gumbel noise

def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
# temperature-controlled sampling via the Gumbel-max trick

def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs
# keep only the top k logits, setting the rest to -inf
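
# a worked check of the thresholding above (the vocabulary size is illustrative):
# with thres = 0.9, k = ceil(0.1 * 256) = 26 logits survive and the rest become -inf

example_logits = torch.randn(1, 256)
filtered = top_k(example_logits, thres = 0.9)
assert (filtered > float('-inf')).sum().item() == 26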

# rotary embeddings

# RotaryEmbedding: generates rotary positional embeddings
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    # forward: produce the rotary frequencies for a given sequence length
    def forward(self, seq_len):
        t = torch.arange(seq_len, device = self.inv_freq.device).type_as(self.inv_freq)
        freqs = einsum('i, j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)
        return freqs

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
# rotate one half of the tensor into the other

def apply_rotary_pos_emb(pos, t):
    seq_len = t.shape[-2]
    pos = pos[-seq_len:, :]
    return t * pos.cos() + rotate_half(t) * pos.sin()
# apply the rotary positional embedding to a tensor
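
# a quick shape sketch of how the rotary pieces compose (dimensions are illustrative)

rotary = RotaryEmbedding(dim = 64)
pos = rotary(128)                  # (128, 64) frequencies for a length-128 sequence
q = torch.randn(1, 8, 128, 64)     # (batch, heads, seq, dim_head)
assert apply_rotary_pos_emb(pos, q).shape == q.shape  # rotation preserves shape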

# different decoding strategies

@torch.no_grad()
def base_decoding(
    net: Module,
    prompt: Tensor,
    seq_len: int,
    temperature = 1.,
    filter_thres = 0.9,
):
    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
    sample_num_times = max(0, seq_len - prompt_seq_len)

    cache = None

    for _ in range(sample_num_times):
        logits, cache = net(out, cache = cache, return_cache = True)
        logits = logits[:, -1]

        logits = top_k(logits, thres = filter_thres)
        sample = gumbel_sample(logits, temperature = temperature, dim = -1)

        out = torch.cat((out, sample[..., None]), dim = -1)

    return out[..., prompt_seq_len:]
# base_decoding: plain autoregressive decoding, one token at a time, reusing the kv cache
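
# a minimal usage sketch of base_decoding with the Decoder defined further down in
# this file (hyperparameters are illustrative)

net = Decoder(num_tokens = 256, dim = 128, depth = 2)
prompt = torch.randint(0, 256, (1, 16))      # batch of 1, 16 prompt tokens
generated = base_decoding(net, prompt, seq_len = 64)
assert generated.shape == (1, 48)            # only the newly sampled tokens are returned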

# speculative decoding functions

def safe_div(num, den, eps = 1e-10):
    return num / max(den, eps)
# safe_div: division guarded against a zero denominator

def find_first_true_index(bool_tensor, dim = -1):
    return (bool_tensor.cumsum(dim = dim) == 0).sum(dim = dim)
# find_first_true_index: index of the first True along a dimension
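
# a worked check: the cumulative sum stays at zero until the first True

mask = torch.tensor([[False, False, True, True]])
assert find_first_true_index(mask).item() == 2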

@torch.no_grad()
def speculative_decoding(
    net: Module,
    small_net: Module,
    prompt: Tensor,
    seq_len: int,
    gamma: int = 5,
    temperature = 1.,
    filter_thres = 0.9,
    lenience = 1.,
    pad_id = 0
):
    """
    eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192
    """
    # speculative decoding, following algorithm 1 of the paper

    batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device
    sample_num_times = max(0, seq_len - prompt_seq_len)

    cache = None
    small_cache = None

    num_steps = 0
    total_accepted = 0

    batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None]
    seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long)

    # now left align

    num_pad_left = out.shape[-1] - seq_lens
    max_pad_left = num_pad_left.amax()
    out = F.pad(out, (0, max_pad_left), value = pad_id)

    seq_len_range = torch.arange(seq_len, device = device, dtype = torch.long)
    out = out[batch_range, seq_len_range + num_pad_left[..., None]]

    return out[..., prompt_seq_len:], total_accepted / num_steps
# speculative_decoding: draft gamma tokens with the small net, then verify them with the large net
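
# the heart of algorithm 1 is the accept/reject rule: a token x drafted by the small
# model is accepted with probability min(1, p(x) / q(x)), where p and q are the large
# and small models' probabilities. a minimal standalone sketch of just that rule (not
# the full routine, which also manages the kv caches and batch realignment):

def num_accepted_drafts(p, q):
    # p, q: (batch, gamma) probabilities of the drafted tokens under each model
    r = torch.zeros_like(q).uniform_(0, 1)     # u ~ Uniform(0, 1)
    rejected = r > (p / q)                     # reject where u > p(x) / q(x)
    return find_first_true_index(rejected)     # drafts before the first rejection are kept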

@torch.no_grad()
def speculative_decoding_with_same_model(
    net: Module,
    prompt: Tensor,
    seq_len: int,
    gamma: int = 5,
    temperature = 1.,
    filter_thres = 0.9,
    lenience = 1.,
    pad_id = 0
):
    """
    eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192
    """
    # speculative decoding where the small net is an early-exit head of the same model
    # unpack prompt into batch, prompt_seq_len, out, device
    batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device
    # number of tokens still to be sampled
    sample_num_times = max(0, seq_len - prompt_seq_len)

    # initialize the caches
    cache = None
    small_cache = None

    # counters for steps taken and tokens accepted
    num_steps = 0
    total_accepted = 0

    # batch index helper and per-sample sequence lengths
    batch_range = torch.arange(batch, device=device, dtype=torch.long)[..., None]
    seq_lens = torch.full((batch,), prompt_seq_len, device=device, dtype=torch.long)

    # left-align the output with padding
    num_pad_left = out.shape[-1] - seq_lens
    max_pad_left = num_pad_left.amax()
    out = F.pad(out, (0, max_pad_left), value=pad_id)

    # gather the left-aligned output
    seq_len_range = torch.arange(seq_len, device=device, dtype=torch.long)
    out = out[batch_range, seq_len_range + num_pad_left[..., None]]

    # return the generated tokens and the acceptance rate
    return out[..., prompt_seq_len:], total_accepted / num_steps
# RMS normalization module
class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

# causal self-attention module
class CausalAttention(Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        dim_inner = dim_head * heads

        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(
        self,
        x,
        cache = None,
        context_mask = None,
        rotary_emb = None
    ):
        h, device = self.heads, x.device

        x = self.norm(x)

        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)

        if exists(cache):
            ck, cv = cache.unbind(dim = 1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        cached_kv = torch.stack((k, v), dim = 1)

        if exists(rotary_emb):
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)

        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        if exists(context_mask):
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~context_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        return out, cached_kv
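
# the cached_kv returned above is what makes incremental decoding cheap: each step
# feeds in only the new token, while keys and values are concatenated onto the cache.
# a shape sketch with illustrative dimensions:

attn = CausalAttention(dim = 128, dim_head = 32, heads = 4)
out, cached_kv = attn(torch.randn(1, 10, 128))                    # process a 10 token prefix
out, cached_kv = attn(torch.randn(1, 1, 128), cache = cached_kv)  # then a single new token
assert cached_kv.shape[-2] == 11                                  # the cache now spans 11 positions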

# feedforward module
def FeedForward(dim, mult = 4):
    dim_inner = dim * mult
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, dim_inner),
        nn.GELU(),
        nn.Linear(dim_inner, dim)
    )

# main decoder class
class Decoder(Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        ignore_index = -1,
        early_exit_layer = None,
        early_exit_extra_transformer_blocks = 0,
        detach_early_exit_hiddens = False
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)

        self.layers = ModuleList([])

        self.rotary_emb = RotaryEmbedding(dim = dim_head)

        # stack decoder layers, each with causal attention followed by a feedforward block
        for _ in range(depth):
            self.layers.append(ModuleList([
                CausalAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        # output head, mapping decoder output to token logits
        self.to_logits = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, num_tokens, bias = False)
        )

        self.detach_early_exit_hiddens = detach_early_exit_hiddens
        self.early_exit_layer = early_exit_layer
        self.to_early_exit_logits = None
        self.early_exit_transformer_blocks = ModuleList([])

        # if an early exit layer is specified, build its extra transformer blocks
        if exists(early_exit_layer):
            for _ in range(early_exit_extra_transformer_blocks):
                self.early_exit_transformer_blocks.append(ModuleList([
                    CausalAttention(dim = dim, dim_head = dim_head, heads = heads),
                    FeedForward(dim = dim, mult = ff_mult)
                ]))

            # output head for the early exit
            self.to_early_exit_logits = nn.Sequential(
                RMSNorm(dim),
                nn.Linear(dim, num_tokens, bias = False)
            )

        self.ignore_index = ignore_index
    # forward pass
    def forward(
        self,
        x,
        return_loss = False,  # whether to return the loss
        return_cache = False,  # whether to return the cache
        seq_start_pos = None,  # sequence start positions, defaults to None
        cache = None,  # kv cache, defaults to None
        early_exit_cache = None,  # early exit cache, defaults to None
        return_early_exit_only = False,  # whether to return only the early exit logits
        start_from_early_exit_hiddens = False  # whether to start from the early exit hidden states

.\lucidrains\speculative-decoding\speculative_decoding\speculative_decoding_with_prophet.py

import math
import torch
from torch.nn import Module, ModuleList
from torch import nn, einsum, Tensor
import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding
from beartype import beartype
from collections import namedtuple
from einops import rearrange

# named tuple Cache with two fields, cached_kvs and embeds
Cache = namedtuple('Cache', ['cached_kvs', 'embeds'])

# helper functions

# whether a value is not None
def exists(val):
    return val is not None

# return val if it exists, otherwise the fallback d
def default(val, d):
    return val if exists(val) else d

# sampling helpers

# log with clamping, to avoid taking the log of zero
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# generate Gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# sample via the Gumbel-max trick
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

# keep the top-k logits, setting the rest to -inf
def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs

# rotary embeddings

# rotary embedding class
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        t = torch.arange(seq_len, device = self.inv_freq.device).type_as(self.inv_freq)
        freqs = einsum('i, j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)
        return freqs

# rotate one half of the tensor into the other
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# apply the rotary positional embedding
def apply_rotary_pos_emb(pos, t):
    seq_len = t.shape[-2]
    pos = pos[-seq_len:, :]
    return t * pos.cos() + rotate_half(t) * pos.sin()

# different decoding strategies

# plain autoregressive decoding
@torch.no_grad()
def base_decoding(
    net: Module,
    prompt: Tensor,
    seq_len: int,
    temperature = 1.,
    filter_thres = 0.9,
):
    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
    sample_num_times = max(0, seq_len - prompt_seq_len)

    cache = None

    for _ in range(sample_num_times):
        logits, cache = net(out, cache = cache, return_cache = True)
        logits = logits[:, -1]

        logits = top_k(logits, thres = filter_thres)
        sample = gumbel_sample(logits, temperature = temperature, dim = -1)

        out = torch.cat((out, sample[..., None]), dim = -1)

    return out[..., prompt_seq_len:]

# normalization

# root-mean-square normalization
class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

# attention and feedforward

# causal attention
class CausalAttention(Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        dim_inner = dim_head * heads

        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(
        self,
        x,
        cache = None,
        context_mask = None,
        rotary_emb = None
        ):
        # number of heads and device of the input
        h, device = self.heads, x.device

        # pre-normalize the input
        x = self.norm(x)

        # project to queries, keys and values, splitting out the heads
        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)

        # if a cache exists, concatenate the cached keys and values with the current ones
        if exists(cache):
            ck, cv = cache.unbind(dim = 1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # stack the keys and values to return as the cache
        cached_kv = torch.stack((k, v), dim = 1)

        # if rotary positional embeddings are given, apply them to queries and keys
        if exists(rotary_emb):
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        # attention similarity matrix
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        i, j = sim.shape[-2:]
        # build the causal mask
        causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)

        # mask out future positions
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # if a context mask is given, mask with it as well
        if exists(context_mask):
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~context_mask, -torch.finfo(sim.dtype).max)

        # attention weights
        attn = sim.softmax(dim = -1)

        # aggregate the values
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge the heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        # final output projection
        out = self.to_out(out)

        # return the output and the cached key/values
        return out, cached_kv
# feedforward module: RMSNorm, linear, GELU, linear
def FeedForward(dim, mult = 4):
    # inner dimension
    dim_inner = dim * mult
    return nn.Sequential(
        RMSNorm(dim),  # normalize the input
        nn.Linear(dim, dim_inner),  # project up to the inner dimension
        nn.GELU(),  # GELU activation
        nn.Linear(dim_inner, dim)  # project back down to the model dimension
    )

# main classes

class Decoder(Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        ignore_index = -1
    ):
        super().__init__()
        self.dim = dim
        self.token_emb = nn.Embedding(num_tokens, dim)  # token embedding

        self.layers = ModuleList([])  # empty module list for the layers

        self.rotary_emb = RotaryEmbedding(dim = dim_head)  # rotary embedding for relative positions

        for _ in range(depth):
            self.layers.append(ModuleList([
                CausalAttention(dim = dim, dim_head = dim_head, heads = heads),  # causal attention layer
                FeedForward(dim = dim, mult = ff_mult)  # feedforward block
            ]))

        self.to_logits = nn.Sequential(
            RMSNorm(dim),  # normalize the final hidden states
            nn.Linear(dim, num_tokens, bias = False)  # project to token logits, without bias
        )

        self.ignore_index = ignore_index  # label index to ignore in the loss

    def forward(
        self,
        x,
        start_tokens = None,
        return_loss = False,
        return_cache = False,
        seq_start_pos = None,
        cache = None
    ):
        has_start_tokens = exists(start_tokens)  # whether start tokens were passed in

        start_token_len = 0
        if exists(start_tokens):
            if start_tokens.ndim == 2:
                start_tokens = rearrange(start_tokens, 'b d -> b 1 d')  # give the start tokens a sequence dimension

            start_token_len = start_tokens.shape[-2]  # length of the start tokens

        if return_loss:
            x, labels = x[:, start_token_len:-1], x[:, 1:]  # for the loss, shift inputs and labels by one

        x = self.token_emb(x)  # embed the input tokens

        if exists(start_tokens):
            x = torch.cat((start_tokens, x), dim = 1)  # prepend the start tokens to the input

        # handle sequence start position offsets

        self_attn_kv_mask = None  # key/value mask for self attention
        if exists(seq_start_pos):
            batch, seq_len = x.shape[:2]
            seq_range = torch.arange(seq_len, device = x.device, dtype = torch.long)
            self_attn_kv_mask = seq_range >= seq_start_pos[..., None]  # mask out positions before the sequence start

        # relative positional encoding

        rotary_emb = self.rotary_emb(x.shape[-2])  # rotary embedding for the current sequence length

        # set up the caches

        new_cached_kvs = []  # per-layer cached key/values to return

        cache_kvs = cache_embeds = None  # cached key/values and embeddings

        if exists(cache):
            cache_kvs, cache_embeds = cache  # unpack the cache

        if exists(cache_kvs):
            iter_cache_kvs = iter(cache_kvs.unbind(dim = 1))  # iterator over the per-layer cached key/values
        else:
            iter_cache_kvs = iter([])  # otherwise an empty iterator

        # if a cache was passed in, only the new tokens need processing

        if exists(cache):
            num_tokens_keep = x.shape[-2] - cache_kvs.shape[-2]  # number of new tokens
            x = x[:, -num_tokens_keep:]  # keep only the new tokens

        # main transformer body

        for ind, (attn, ff) in enumerate(self.layers):
            layer = ind + 1  # current layer index (1-based)

            residual = x  # residual connection
            attn_out, cached_kv = attn(x, rotary_emb = rotary_emb, cache = next(iter_cache_kvs, None))  # attention
            x = residual + attn_out  # add back the residual

            new_cached_kvs.append(cached_kv)  # collect this layer's cached key/values

            x = ff(x) + x  # feedforward block with residual (without this, the feedforward modules would go unused)

        new_cached_kvs = torch.stack(new_cached_kvs, dim = 1)  # stack the new cached key/values

        logits = self.to_logits(x)  # output logits

        if not return_loss:
            if not return_cache:
                return logits  # return just the logits

            return logits, Cache(new_cached_kvs, x)  # return the logits and the cache

        loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),  # cross entropy expects (batch, classes, seq)
            labels,  # shifted labels
            ignore_index = self.ignore_index  # label index to ignore
        )

        return loss, Cache(new_cached_kvs, x)  # return the loss and the cache

class ModelWithProphetWrapper(Module):
    def __init__(
        self,
        model: Decoder,
        prophet: Decoder,
        prophet_train_length = 8,  # prophet training length, should exceed the main model's decoding gamma, since the main model's cached embeddings lag one step behind
        detach_model_embed_for_prophet = False,
        num_leading_start_tokens = 1
    ):
        super().__init__()
        # set the model and the prophet
        self.model = model
        self.prophet = prophet

        # whether the model and the prophet share the same dimension
        model_prophet_same_dim = model.dim == prophet.dim
        # if they match, use nn.Identity(), otherwise project with nn.Linear()
        self.to_prophet_start_token = nn.Identity() if model_prophet_same_dim else nn.Linear(model.dim, prophet.dim, bias = False)

        # there must be at least one leading start token
        assert num_leading_start_tokens >= 1
        self.num_leading_start_tokens = num_leading_start_tokens

        # prophet training length, and whether to detach the model embeddings fed to the prophet
        self.prophet_train_length = prophet_train_length
        self.detach_model_embed_for_prophet = detach_model_embed_for_prophet

    # forward pass
    def forward(self, x):
        # unpack num_start_tokens, batch, seq_len, device
        num_start_tokens = self.num_leading_start_tokens
        batch, seq_len, device = *x.shape, x.device
        prophet_seq_len = self.prophet_train_length
        # the sequence must be at least as long as the prophet training length
        assert seq_len >= prophet_seq_len

        total_loss = 0.

        # run the main model, returning the main loss plus cached key/values and embeddings
        main_loss, (cached_kvs, embeds) = self.model(x, return_loss = True)

        # accumulate the main loss
        total_loss = total_loss + main_loss

        # optionally detach the model embeddings before feeding them to the prophet
        if self.detach_model_embed_for_prophet:
            embeds = embeds.detach()

        # project the embeddings to the prophet's start tokens
        prophet_start_tokens = self.to_prophet_start_token(embeds)

        # index helpers over the batch and the prophet sequence length
        batch_arange = torch.arange(batch, device = device, dtype = torch.long)
        prophet_seq_arange = torch.arange(prophet_seq_len, device = device, dtype = torch.long)

        # number of subsequences to train the prophet on
        num_seq_train_prophet = seq_len - prophet_seq_len - (num_start_tokens - 1)

        # offsets of each subsequence
        offsets = torch.arange(num_seq_train_prophet, device = device, dtype = torch.long)

        # gather the prophet input subsequences
        prophet_input = x[
            batch_arange[:, None, None],
            offsets[..., None] + prophet_seq_arange
        ]

        # fold the subsequences into the batch dimension
        prophet_input = rearrange(prophet_input, '... n -> (...) n')

        # index helper over the leading start tokens
        start_tokens_arange = torch.arange(num_start_tokens, device = device, dtype = torch.long)

        # gather the prophet start tokens
        prophet_start_tokens = prophet_start_tokens[
            batch_arange[:, None, None],
            offsets[..., None] + start_tokens_arange
        ]

        # fold the start tokens into the batch dimension
        prophet_start_tokens = rearrange(prophet_start_tokens[:, :num_seq_train_prophet], 'b n l d -> (b n) l d')

        # run the prophet, returning the prophet loss
        prophet_loss, _ = self.prophet(prophet_input, start_tokens = prophet_start_tokens, return_loss = True)

        # accumulate the prophet loss
        total_loss = total_loss + prophet_loss

        # return the total loss along with the main and prophet losses
        return total_loss, (main_loss, prophet_loss)
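
# a minimal training-step sketch for the wrapper (model sizes and token counts are illustrative)

model = Decoder(num_tokens = 256, dim = 512, depth = 8)
prophet = Decoder(num_tokens = 256, dim = 512, depth = 2)
wrapper = ModelWithProphetWrapper(model, prophet, prophet_train_length = 8)

tokens = torch.randint(0, 256, (2, 64))
total_loss, (main_loss, prophet_loss) = wrapper(tokens)
total_loss.backward()   # trains the main model and the prophet jointly
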
# division guarded against a zero denominator
def safe_div(num, den, eps = 1e-10):
    return num / max(den, eps)

# index of the first True along a dimension
def find_first_true_index(bool_tensor, dim = -1):
    return (bool_tensor.cumsum(dim = dim) == 0).sum(dim = dim)

# speculative decoding with the prophet model
@torch.no_grad()
def speculative_decoding_with_prophet_model(
    net: ModelWithProphetWrapper,
    prompt: Tensor,
    seq_len: int,
    gamma: int = 5,
    temperature = 1.,
    filter_thres = 0.9,
    lenience = 1.,
    pad_id = 0
):
    """
    eq. algorithm 1 in paper https://arxiv.org/abs/2211.17192
    """

    # unpack the model, the prophet, and the model-to-prophet projection (used when their dimensions differ)

    model = net.model
    to_prophet_start_token = net.to_prophet_start_token
    prophet = net.prophet
    num_start_tokens = net.num_leading_start_tokens

    batch, prompt_seq_len, out, device = *prompt.shape, prompt.clone(), prompt.device

    if (seq_len - prompt_seq_len) <= 0:
        return prompt, None

    cache = None
    small_cache = None

    num_steps = 0
    total_accepted = 0

    batch_range = torch.arange(batch, device = device, dtype = torch.long)[..., None]
    seq_lens = torch.full((batch,), prompt_seq_len, device = device, dtype = torch.long)

    # sample the first token(s) from the main model

    for _ in range(max(1, num_start_tokens - prompt_seq_len)):
        logits, cache = model(out, cache = cache, return_cache = True)
        logits = logits[:, -1:]
        logits = top_k(logits, thres = filter_thres)
        sample = gumbel_sample(logits, temperature = temperature, dim = -1)
        out = torch.cat((out, sample), dim = -1)
        seq_lens += 1

    # the first cached embeddings now exist, to be used as start tokens for the prophet during speculative sampling

    _, embeds = cache
    next_prophet_start_tokens = to_prophet_start_token(embeds[:, -num_start_tokens:])
    # loop while any sequence is shorter than the target length
    while (seq_lens < seq_len).any():

        # draft with the smaller network (the prophet)

        # collect the prophet logits and sampled outputs
        all_small_logits = []
        q_sampled_out = []

        small_cache = None
        num_tokens = 2  # the main model's embeddings lag one step behind the main sequence

        # draft gamma tokens
        for _ in range(gamma):
            # run the prophet
            small_logits, small_cache = prophet(
                out[..., -num_tokens:],
                start_tokens = next_prophet_start_tokens,
                cache = small_cache,
                return_cache = True
            )

            small_logits = small_logits[:, -1:]

            # top-k filter the logits
            small_logits = top_k(small_logits, thres = filter_thres)
            all_small_logits.append(small_logits)

            # sample with Gumbel noise
            sample = gumbel_sample(small_logits, temperature = temperature, dim = -1)
            out = torch.cat((out, sample), dim = -1)

            seq_lens += 1
            num_tokens += 1

            q_sampled_out.append(rearrange(sample, '... -> ... 1'))

        q_sampled_out = torch.cat(q_sampled_out, dim = -2)
        small_logits = torch.cat(all_small_logits, dim = -2)

        # verify with the larger network

        logits, cache = model(
            out,
            cache = cache,
            return_cache = True,
            seq_start_pos = out.shape[-1] - seq_lens
        )

        logits = logits[..., -(gamma + 1):, :]
        logits = top_k(logits, thres = filter_thres)

        # probabilities under the larger and smaller networks (p(x) and q(x) in algorithm 1)

        prob = safe_div(logits, temperature).softmax(dim = -1)
        small_prob = safe_div(small_logits, temperature).softmax(dim = -1)

        p, prob_next = prob[:, :-1], prob[:, -1]

        p = p.gather(-1, q_sampled_out)
        q = small_prob.gather(-1, q_sampled_out) * lenience

        p, q = [rearrange(t, 'b n 1 -> b n') for t in (p, q)]

        r = random_uniform = torch.zeros_like(q).float().uniform_(0, 1)

        accepted = find_first_true_index(r > (p / q))

        total_accepted += accepted.float().mean()
        num_steps += 1

        num_rejected = gamma - accepted
        has_rejected = num_rejected > 0

        accepted = rearrange(accepted, 'b -> b 1')
        accepted.clamp_(max = gamma - 1)

        adjusted_prob = F.relu(prob[batch_range, accepted] - small_prob[batch_range, accepted])
        adjusted_prob = adjusted_prob / adjusted_prob.sum(dim = -1, keepdim = True)
        adjusted_prob = rearrange(adjusted_prob, 'b 1 d -> b d')

        prob_next = torch.where(
            rearrange(has_rejected, '... -> ... 1'),
            adjusted_prob,
            prob_next
        )

        # slice everything, including the kv cache, so sequences stay right-aligned

        max_num_rejected = num_rejected.amax()
        seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long)
        seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None]

        seq_lens -= num_rejected
        max_seq_len = seq_lens.amax()

        if batch > 1:
            out = F.pad(out, (0, max_num_rejected), value = pad_id)
            out = out[batch_range, seq_offset_indices]

            cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache)
            cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache)
            cache = tuple(t[batch_range, seq_offset_indices] for t in cache)
            cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache)

            if out.shape[-1] > max_seq_len:
                left_index = out.shape[-1] - max_seq_len
                out = out[:, left_index:]
                cache = tuple(t[..., left_index:, :] for t in cache)

        # sample one extra token, a trick from the paper to better bound the worst case

        next_token = torch.multinomial(prob_next, 1)

        out = torch.cat((out, next_token), dim = -1)
        seq_lens += 1

        _, embeds = cache
        next_prophet_start_tokens = to_prophet_start_token(embeds[:, -num_start_tokens:])
    # finally, left-align the output

    # number of positions to pad on the left
    num_pad_left = out.shape[-1] - seq_lens
    # maximum left padding across the batch
    max_pad_left = num_pad_left.amax()
    # pad the last dimension of the output by the maximum padding, with pad_id
    out = F.pad(out, (0, max_pad_left), value=pad_id)

    # range tensor over the target sequence length
    seq_len_range = torch.arange(seq_len, device=device, dtype=torch.long)
    # gather the needed portion of out, indexed by batch_range and seq_len_range
    out = out[batch_range, seq_len_range + num_pad_left[..., None]]

    # return out without the prompt_seq_len prefix, along with total_accepted / num_steps
    return out[..., prompt_seq_len:], total_accepted / num_steps
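
# once trained, the wrapper feeds straight into this routine; a minimal sketch,
# reusing the illustrative `wrapper` from the training sketch above

prompt = torch.randint(0, 256, (1, 16))
out, accept_rate = speculative_decoding_with_prophet_model(wrapper, prompt, seq_len = 64, gamma = 5)
assert out.shape == (1, 48)   # only the newly decoded tokens are returned, with the mean acceptance rate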

.\lucidrains\speculative-decoding\speculative_decoding\__init__.py

# import Decoder, base_decoding, speculative_decoding and speculative_decoding_with_same_model
# from speculative_decoding.speculative_decoding
from speculative_decoding.speculative_decoding import (
    Decoder,
    base_decoding,
    speculative_decoding,
    speculative_decoding_with_same_model
)

.\lucidrains\speculative-decoding\train.py

# imports
import gzip
import random
import tqdm
import numpy as np
import time
from functools import wraps, partial
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.cuda import synchronize, Event
from torch.utils.data import DataLoader, Dataset
from speculative_decoding import (
    Decoder,
    base_decoding,
    speculative_decoding
)

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRAD_ACCUM_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512
GAMMA = 5
DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'

# helper functions

# CUDA event factory used by the benchmark decorator below
timer = partial(Event, enable_timing = True)

def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

def benchmark(fn):
    @wraps(fn)
    def inner(*args, **kwargs):
        start_event = timer()
        end_event = timer()
        start_event.record()

        out = fn(*args, **kwargs)

        end_event.record()
        torch.cuda.synchronize()
        elapsed_time_ms = start_event.elapsed_time(end_event)
        return out, elapsed_time_ms
    return inner

# instantiate the transformer
device = torch.device(DEVICE_STR)
model = Decoder(
    num_tokens = 256,
    dim = 512,
    depth = 10
).to(device)

# instantiate the small draft transformer
small_model = Decoder(
    num_tokens = 256,
    dim = 512,
    depth = 2
).to(device)

# prepare the enwik8 data
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

# training and validation datasets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizers
optim = Adam(model.parameters(), lr = LEARNING_RATE)
small_optim = Adam(small_model.parameters(), lr = LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()
    small_model.train()

    for _ in range(GRAD_ACCUM_EVERY):
        data = next(train_loader)

        loss = model(data, return_loss = True)
        small_loss = small_model(data, return_loss = True)

        (loss / GRAD_ACCUM_EVERY).backward()
        (small_loss / GRAD_ACCUM_EVERY).backward()

    print(f"training loss: {loss.item():.3f}")
    print(f"training small loss: {small_loss.item():.3f}")

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    torch.nn.utils.clip_grad_norm_(small_model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    small_optim.step()
    small_optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            valid_data = next(val_loader)

            loss = model(valid_data, return_loss = True)
            print(f"validation loss: {loss.item():.3f}")

            small_loss = small_model(valid_data, return_loss = True)
            print(f"validation small loss: {small_loss.item():.3f}")
    # every so often, generate
    if i % GENERATE_EVERY == 0:
        # put the models in eval mode
        model.eval()
        small_model.eval()

        # choose a random sample from the validation set as the prompt
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        # decode the prompt to text
        prime = decode_tokens(inp)
        # print the prompt and a separator
        print(f"%s \n\n %s" % (prime, "*" * 100))

        # add a batch dimension to the prompt
        prompt = inp[None, ...]

        # generate with base decoding, timing it
        sampled, base_decode_elapsed = benchmark(base_decoding)(model, prompt, GENERATE_LENGTH)

        # generate with speculative decoding, timing it and recording the number of accepted tokens
        (spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding)(model, small_model, prompt, GENERATE_LENGTH, GAMMA)

        # decode both outputs to text
        base_decode_output = decode_tokens(sampled[0])
        spec_decode_output = decode_tokens(spec_decode_sampled[0])

        # print the base decoding output
        print("\nbase decoding:\n\n", base_decode_output, "\n")
        # print the speculative decoding output
        print("\nspec decoding:\n\n", spec_decode_output, "\n")

        # print the elapsed time for base decoding
        print(f'base decoding in: {base_decode_elapsed:.3f}ms\n')
        # print the elapsed time for speculative decoding
        print(f'spec decoding in: {spec_decode_elapsed:.3f}ms\n')
        # print the average number of accepted tokens
        print(f'average num accepted: {num_accepted:.1f} / {GAMMA}\n')

.\lucidrains\speculative-decoding\train_early_exit.py

# imports
import gzip
import random
import tqdm
import numpy as np
import time
from functools import wraps, partial

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.cuda import synchronize, Event
from torch.utils.data import DataLoader, Dataset

# CUDA event factory for timing
timer = partial(Event, enable_timing = True)

# import from the local package
from speculative_decoding import (
    Decoder,
    base_decoding,
    speculative_decoding_with_same_model
)

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRAD_ACCUM_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512
GAMMA = 5
EARLY_EXIT_LOSS_WEIGHT = 1.

DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'

# infinite iterator over a dataloader
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a single token id to a printable character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of token ids to text
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# timing decorator based on CUDA events (requires a CUDA device)
def benchmark(fn):
    @wraps(fn)
    def inner(*args, **kwargs):
        start_event = timer()
        end_event = timer()
        start_event.record()

        out = fn(*args, **kwargs)

        end_event.record()
        torch.cuda.synchronize()
        elapsed_time_ms = start_event.elapsed_time(end_event)
        return out, elapsed_time_ms
    return inner

# instantiate transformer
device = torch.device(DEVICE_STR)

model = Decoder(
    num_tokens = 256,
    dim = 512,
    depth = 10,
    early_exit_layer = 2   # the early layers double as the small draft model; could later cache those layers' hidden states
).to(device)

# prepare enwik8 data
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create train and validation datasets and loaders
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRAD_ACCUM_EVERY):
        data = next(train_loader)

        loss, small_loss = model(data, return_loss = True)

        ((loss + small_loss * EARLY_EXIT_LOSS_WEIGHT) / GRAD_ACCUM_EVERY).backward()

    print(f"training loss: {loss.item():.3f}")
    print(f"training small loss: {small_loss.item():.3f}")

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            valid_data = next(val_loader)

            loss, small_loss = model(valid_data, return_loss = True)
            print(f"validation loss: {loss.item():.3f}")
            print(f"validation small loss: {small_loss.item():.3f}")
    # periodically generate and benchmark samples
    if i % GENERATE_EVERY == 0:
        # put the model in eval mode
        model.eval()

        # pick a random prime from the validation set
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        print("%s \n\n %s" % (prime, "*" * 100))

        # add a batch dimension to form the prompt
        prompt = inp[None, ...]

        # generate with plain autoregressive decoding, timing it
        sampled, base_decode_elapsed = benchmark(base_decoding)(model, prompt, GENERATE_LENGTH)

        # generate with speculative decoding, the early-exit layers of the same model serving as the draft
        (spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding_with_same_model)(model, prompt, GENERATE_LENGTH, GAMMA)

        # decode both outputs back to text
        base_decode_output = decode_tokens(sampled[0])
        spec_decode_output = decode_tokens(spec_decode_sampled[0])

        print("\nbase decoding:\n\n", base_decode_output, "\n")
        print("\nspec decoding:\n\n", spec_decode_output, "\n")

        # report wall-clock times and the average number of accepted draft tokens
        print(f'base decoding in: {base_decode_elapsed:.3f}ms\n')
        print(f'spec decoding in: {spec_decode_elapsed:.3f}ms\n')
        print(f'average num accepted: {num_accepted:.1f} / {GAMMA}\n')

.\lucidrains\speculative-decoding\train_prophet.py

# imports
import gzip
import random
import tqdm
import numpy as np
import time
from functools import wraps, partial
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.cuda import synchronize, Event
from torch.utils.data import DataLoader, Dataset

# factory for CUDA timing events
timer = partial(Event, enable_timing = True)

# import the decoder, the prophet wrapper and the decoding strategies
from speculative_decoding.speculative_decoding_with_prophet import (
    Decoder,
    ModelWithProphetWrapper,
    base_decoding,
    speculative_decoding_with_prophet_model
)

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRAD_ACCUM_EVERY = 4
LEARNING_RATE = 1e-4
PRIME_LENGTH = 128
GENERATE_EVERY = 100
GENERATE_LENGTH = 512
SEQ_LEN = 512
GAMMA = 5
TRAIN_PROPHET = True

DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'

# helpers

# infinite iterator over a dataloader
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a single token id to a printable character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of token ids to text
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# timing decorator based on CUDA events (requires a CUDA device)
def benchmark(fn):
    @wraps(fn)
    def inner(*args, **kwargs):
        start_event = timer()
        end_event = timer()
        start_event.record()

        out = fn(*args, **kwargs)

        end_event.record()
        torch.cuda.synchronize()
        elapsed_time_ms = start_event.elapsed_time(end_event)
        return out, elapsed_time_ms
    return inner

# instantiate transformer and prophet

device = torch.device(DEVICE_STR)

model = Decoder(
    num_tokens = 256,
    dim = 512,
    depth = 10
)

prophet = Decoder(
    num_tokens = 256,
    dim = 512,
    depth = 2
)

model_and_prophet = ModelWithProphetWrapper(
    model,
    prophet,
    prophet_train_length = GAMMA + 2,
    num_leading_start_tokens = 2,
    detach_model_embed_for_prophet = False
).to(device)
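
# reading off the wrapper's arguments (an interpretation from the names, not
# verified against the library internals): the prophet is trained to continue
# GAMMA + 2 tokens, seeded with 2 leading start tokens, and with
# detach_model_embed_for_prophet = False its gradients may flow back into the
# main model's embeddings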

# prepare enwik8 data

with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.to(device)

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create train and validation datasets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))

# choose which parameters to optimize
params = model_and_prophet.parameters() if TRAIN_PROPHET else model.parameters()

# optimizer
optim = Adam(params, lr = LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model_and_prophet.train()

    for _ in range(GRAD_ACCUM_EVERY):
        data = next(train_loader)

        total_loss, (loss, prophet_loss) = model_and_prophet(data)

        (total_loss / GRAD_ACCUM_EVERY).backward()

    print(f"training loss: {loss.item():.3f}")
    print(f"training prophet loss: {prophet_loss.item():.3f}")

    torch.nn.utils.clip_grad_norm_(model_and_prophet.parameters(), 0.5)

    optim.step()
    optim.zero_grad()
    # periodically generate and benchmark samples
    if i % GENERATE_EVERY == 0:
        # put model and prophet in eval mode
        model_and_prophet.eval()

        # pick a random prime from the validation set
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        print("%s \n\n %s" % (prime, "*" * 100))

        # add a batch dimension to form the prompt
        prompt = inp[None, ...]

        # generate with plain autoregressive decoding, timing it
        sampled, base_decode_elapsed = benchmark(base_decoding)(model, prompt, GENERATE_LENGTH)

        # generate with speculative decoding, the prophet serving as the draft model
        (spec_decode_sampled, num_accepted), spec_decode_elapsed = benchmark(speculative_decoding_with_prophet_model)(model_and_prophet, prompt, GENERATE_LENGTH, GAMMA)

        # decode both outputs back to text
        base_decode_output = decode_tokens(sampled[0])
        spec_decode_output = decode_tokens(spec_decode_sampled[0])

        print("\nbase decoding:\n\n", base_decode_output, "\n")
        print("\nspec decoding:\n\n", spec_decode_output, "\n")

        # report wall-clock times and the average number of accepted draft tokens
        print(f'base decoding in: {base_decode_elapsed:.3f}ms\n')
        print(f'spec decoding in: {spec_decode_elapsed:.3f}ms\n')
        print(f'average num accepted: {num_accepted:.1f} / {GAMMA}\n')

.\lucidrains\st-moe-pytorch\assert.py

# imports
import os
from copy import deepcopy
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from st_moe_pytorch.st_moe_pytorch import Experts, Expert
from st_moe_pytorch.distributed import all_gather_variable_dim

# initialize the distributed process group
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

# destroy the process group
def cleanup():
    dist.destroy_process_group()

# worker run in each spawned process
def start(
    rank,
    world_size,
    batch_size,
    batch_size_var_len,
    num_experts,
    tokens_per_expert,
    dim,
    use_cuda
):
    # initialize the distributed environment
    setup(rank, world_size)

    # create the experts
    net = Experts([Expert(dim) for _ in range(num_experts)])

    # optionally vary the batch size per rank
    if batch_size_var_len:
        batch_size = batch_size + rank

    # random input sequence
    seq = torch.randn(batch_size, num_experts, tokens_per_expert, dim)

    # local, single-process computation

    # deep copy of the experts
    local_net = deepcopy(net)

    # gather the inputs from all ranks
    local_inputs, _ = all_gather_variable_dim(seq)

    # forward through the local network
    local_out = local_net(
        local_inputs,
        is_distributed=False
    )

    # backward from the mean of the local output
    local_out.mean().backward()

    # distributed computation

    # wrap the experts in DistributedDataParallel
    model = DDP(net)
    ddp_inputs = seq

    # move model and inputs onto the right device if using CUDA
    if use_cuda:
        model.cuda(rank)
        ddp_inputs = seq.cuda(rank)

    # forward through the distributed model
    out = model(ddp_inputs)
    out.mean().backward()

    # gather the outputs from all ranks
    ddp_all_out, _ = all_gather_variable_dim(out)

    if rank == 0:
        # validate that local and distributed outputs agree

        # move model and gathered output back to cpu
        model.cpu()
        ddp_all_out = ddp_all_out.cpu()

        assert torch.allclose(local_out, ddp_all_out, atol=1e-3), 'output is not the same'

        # validate that the first expert's gradients agree between local and distributed

        get_first_expert_grad = lambda t: t.experts[0].net[0].weight.grad

        assert torch.allclose(
            get_first_expert_grad(net).cpu(),
            get_first_expert_grad(local_net),
            atol=1e-2
        ), 'grad is not the same'

        print('✅ outputs and gradients are same between local and ddp')

    # tear down the process group
    cleanup()

# main entrypoint
if __name__ == '__main__':
# parameters
    world_size = 8
    num_experts = 3
    batch_size = 2
    batch_size_var_len = True
    use_cuda = False

    # if using CUDA, the number of visible devices must not exceed the world size
    assert not use_cuda or torch.cuda.device_count() <= world_size

    seq_len = 32
    dim = 8

    # spawn one process per rank
    mp.spawn(
        start,
        args=(
            world_size,
            batch_size,
            batch_size_var_len,
            num_experts,
            seq_len,
            dim,
            use_cuda
        ),
        nprocs=world_size,
        join=True
    )

ST-MoE - Pytorch

Implementation of ST-MoE, the latest incarnation of mixture of experts after years of research at Brain, in Pytorch. Will be largely a transcription of the official Mesh Tensorflow implementation. If you have any papers you think should be added, while I have my attention on mixture of experts, please open an issue.

This should be SOTA for mixture-of-experts for autoregressive transformers. It is rumored that GPT4 is using 16 experts with top2 gating.

For non-autoregressive, would recommend going with the simpler and better Soft MoE.

Install

$ pip install st-moe-pytorch

Appreciation

  • StabilityAI for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

  • Aran Komatsuzaki for consultation on mixture-of-experts, for removal of 2-level MoE and simplifications to code

Usage

import torch
from st_moe_pytorch import MoE

moe = MoE(
    dim = 512,
    num_experts = 16,               # increase the experts (# parameters) of your model without increasing computation
    gating_top_n = 2,               # default to top 2 gating, but can also be more (3 was tested in the paper with a lower threshold)
    threshold_train = 0.2,          # at what threshold to accept a token to be routed to second expert and beyond - 0.2 was optimal for 2 expert routing, and apparently should be lower for 3
    threshold_eval = 0.2,
    capacity_factor_train = 1.25,   # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
    capacity_factor_eval = 2.,      # capacity_factor_* should be set to a value >=1
    balance_loss_coef = 1e-2,       # multiplier on the auxiliary expert balancing auxiliary loss
    router_z_loss_coef = 1e-3,      # loss weight for router z-loss
)

inputs = torch.randn(4, 1024, 512)
out, total_aux_loss, balance_loss, router_z_loss = moe(inputs) # (4, 1024, 512), (1,), (1,), (1,)

# for the entire mixture of experts block, in context of transformer

from st_moe_pytorch import SparseMoEBlock

moe_block = SparseMoEBlock(
    moe,
    add_ff_before = True,
    add_ff_after = True
)

out, total_aux_loss, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,), (1,)

# the total auxiliary loss will need to be summed and then added to the main loss

# the other two losses are the unweighted breakdown for logging purposes
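
A minimal training-step sketch (`lm_loss` and `compute_language_modeling_loss` are hypothetical stand-ins for your main objective, not part of the library):

lm_loss = compute_language_modeling_loss(out)  # hypothetical main objective
total_loss = lm_loss + total_aux_loss          # add the weighted auxiliary loss
total_loss.backward()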

Todo

Citations

@inproceedings{Zoph2022STMoEDS,
    title   = {ST-MoE: Designing Stable and Transferable Sparse Expert Models},
    author  = {Barret Zoph and Irwan Bello and Sameer Kumar and Nan Du and Yanping Huang and Jeff Dean and Noam M. Shazeer and William Fedus},
    year    = {2022}
}

.\lucidrains\st-moe-pytorch\setup.py

# import setuptools helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'st-moe-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.1.7',
  license='MIT',
  description = 'ST - Mixture of Experts - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/st-moe-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'mixture of experts'
  ],
  install_requires=[
    'beartype',
    'CoLT5-attention>=0.10.15',
    'einops>=0.6',
    'torch>=2.0',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\st-moe-pytorch\st_moe_pytorch\distributed.py

# imports
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function

import torch.distributed as dist

from einops import rearrange, pack, unpack

# returns True if the value is not None
def exists(val):
    return val is not None

# return val if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# True if num is divisible by den
def divisible_by(num, den):
    return (num % den) == 0

# pad a tensor along the given dimension up to a target length
def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
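
# e.g. pad_dim_to(torch.ones(2, 3), 5, dim = 0).shape == (5, 3), zero padded at the end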

# all-gather tensors that have the same shape on every rank
def all_gather_same_dim(t):
    t = t.contiguous()
    world_size = dist.get_world_size()
    gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
    dist.all_gather(gathered_tensors, t)
    return gathered_tensors

# gather the size of a given dimension from every rank
def gather_sizes(t, *, dim):
    size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
    sizes = all_gather_same_dim(size)
    return torch.stack(sizes)

# True if every element of the tensor has the same value
def has_only_one_value(t):
    return (t == t[0]).all()

# all-gather tensors whose sizes may differ along one dimension
def all_gather_variable_dim(t, dim = 0, sizes = None):
    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()

    if not exists(sizes):
        sizes = gather_sizes(t, dim = dim)

    if has_only_one_value(sizes):
        gathered_tensors = all_gather_same_dim(t)
        gathered_tensors = torch.cat(gathered_tensors, dim = dim)
        return gathered_tensors, sizes

    max_size = sizes.amax().item()

    padded_t = pad_dim_to(t, max_size, dim = dim)
    gathered_tensors = all_gather_same_dim(padded_t)

    gathered_tensors = torch.cat(gathered_tensors, dim = dim)
    seq = torch.arange(max_size, device = device)

    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensors = gathered_tensors.index_select(dim, indices)

    return gathered_tensors, sizes
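
# a minimal sketch of how this could be exercised (assumes an initialized process
# group, e.g. spawned with torch.multiprocessing as in assert.py above):
#
#   t = torch.randn(dist.get_rank() + 1, 8)         # each rank holds a different number of rows
#   gathered, sizes = all_gather_variable_dim(t)    # every rank gets sum(sizes) rows back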

# autograd Function wrapping the variable-dim all-gather
class AllGatherFunction(Function):
    @staticmethod
    def forward(ctx, x, dim, sizes):
        x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        ctx.batch_sizes = batch_sizes.tolist()
        ctx.dim = dim
        return x, batch_sizes

    @staticmethod
    def backward(ctx, grads, _):
        batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
        grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
        return grads_by_rank[rank], None, None

# nn.Module wrapper around AllGatherFunction
class AllGather(nn.Module):
    def __init__(self, *, dim = 0):
        super().__init__()
        self.dim = dim

    def forward(self, x, sizes = None):
        return AllGatherFunction.apply(x, self.dim, sizes)

# split a gathered tensor (or tuple of tensors) back out by rank
def split_by_rank(x):
    rank = dist.get_rank()
    out = x[rank]

    if isinstance(x, tuple):
        sizes = tuple(map(lambda t: t.shape[0], x))
    else:
        sizes = (x.shape[1],) * x.shape[0]

    sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
    return out, sizes

.\lucidrains\st-moe-pytorch\st_moe_pytorch\st_moe_pytorch.py

# imports
from functools import partial
from collections import namedtuple
from typing import Optional, Tuple, Union

import torch
from torch.nn import Module, ModuleList
from torch import nn, einsum
import torch.nn.functional as F

# third-party and project imports
from beartype import beartype
from einops import rearrange, repeat, reduce, pack, unpack
from colt5_attention import topk as maybe_differentiable_topk
import torch.distributed as dist
from st_moe_pytorch.distributed import (
    AllGather,
    split_by_rank,
    gather_sizes,
    pad_dim_to,
    has_only_one_value
)

# constants
MIN_EXPERT_CAPACITY = 4
MixtureOfExpertsReturn = namedtuple('MixtureOfExpertsReturn', [
    'outputs',
    'total_aux_loss',
    'balance_loss',
    'router_z_loss'
])

# helper functions

# returns True if the value is not None
def exists(val):
    return val is not None

# return val if it exists, otherwise the default (which may be callable)
def default(val, default):
    if exists(val):
        return val
    return default() if callable(default) else default

# True if num is divisible by den
def divisible_by(num, den):
    return (num % den) == 0

# split num into the given number of chunks, as evenly as possible
def chunk_num(num, chunks):
    num_per_chunk, remainder = divmod(num, chunks)
    out = []
    for i in range(chunks):
        n = num_per_chunk
        out.append(n + int(i < remainder))
    return out
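
# e.g. chunk_num(10, 3) == [4, 3, 3] - the remainder is spread over the leading chunks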

# pack a single tensor according to an einops pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to an einops pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# cast an element to a tuple of the given length
def cast_tuple(el, len = 1):
    return el if isinstance(el, tuple) else ((el,) * len)

# nn.Sequential that filters out None modules
def Sequential(*modules):
    return nn.Sequential(*filter(exists, modules))

# tensor related helper functions

# exclusive cumulative sum along a (negative) dimension
def cumsum_exclusive(t, dim = -3):
    assert dim < 0
    num_pad_dims = -dim - 1
    pre_padding = (0, 0) * num_pad_dims
    return F.pad(t, (*pre_padding, 1, -1)).cumsum(dim = dim)
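
# e.g. cumsum_exclusive(torch.tensor([1., 2., 3.]), dim = -1) gives tensor([0., 1., 3.]) -
# each position sums everything strictly before it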

# log with clamping for numerical safety
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# sample gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))
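
# presumably used below to perturb the router logits when noise_gates = True,
# along the lines of: noised = gate_logits + gumbel_noise(gate_logits) * noise_mult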

# one-hot that guards against indices falling outside max_length
def safe_one_hot(indexes, max_length):
    max_index = indexes.max() + 1
    one_hot_classes = max(max_index + 1, max_length)
    return F.one_hot(indexes, one_hot_classes)[..., :max_length]
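
# e.g. safe_one_hot(torch.tensor([0, 2, 5]), 4) returns rows of length 4, and the
# out-of-range index 5 becomes an all-zero row instead of raising an error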

# rms normalization

class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.gamma * self.scale

# expert class
# best performing was a feedforward geglu with a multiplicative bias after the gating

class GEGLU(Module):
    def __init__(
        self,
        dim,
        mult_bias = True
    ):
        super().__init__()
        self.mult_bias = nn.Parameter(torch.ones(dim)) if mult_bias else 1.

    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x * self.mult_bias
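
# e.g. GEGLU(dim = 64)(torch.randn(2, 10, 128)).shape == (2, 10, 64) -
# the last dimension is halved, one half gating the other through a gelu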

class Expert(Module):
    def __init__(
        self,
        dim,
        hidden_mult = 4,
        mult_bias = True,
        prenorm = False
    ):
        super().__init__()
        dim_hidden = int(dim * hidden_mult * 2 / 3)

        self.net = Sequential(
            RMSNorm(dim) if prenorm else None,
            nn.Linear(dim, dim_hidden * 2),
            GEGLU(dim_hidden, mult_bias = mult_bias),
            nn.Linear(dim_hidden, dim)
        )

        self.apply(self.init_)

    def init_(self, module):
        if isinstance(module, nn.Linear):
            dim = module.weight.shape[0]
            std = dim ** -0.5

            module.weight.data.uniform_(-std, std)
            module.bias.data.uniform_(-std, std)

    def forward(self, x):
        return self.net(x)

class Experts(nn.Module):
    def __init__(
        self,
        experts,
        is_distributed = None,
        allow_var_seq_len = False # whether to handle variable sequence lengths
    ):
        super().__init__()

        # number of experts and the expert modules themselves
        self.num_experts = len(experts)
        self.experts = nn.ModuleList(experts)

        # distributed related settings

        # whether running distributed; if unspecified, infer from the runtime
        self.is_distributed = is_distributed
        if not exists(self.is_distributed):
            self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1

        self.all_gather = AllGather()

        # whether to allow variable sequence lengths
        self.allow_var_seq_len = allow_var_seq_len

        # device tracker, since experts not in use need to be moved to cpu manually

        self.register_buffer('dummy', torch.ones(1), persistent = False)

    # the device is derived from the dummy buffer
    @property
    def device(self):
        return self.dummy.device

    # move all experts to cpu besides the selected ones
    def all_experts_to_cpu_besides(self, selection):
        # resolve the selection (index, slice, or explicit list) into a list of experts
        if isinstance(selection, int):
            experts = [self.experts[selection]]
        elif isinstance(selection, slice):
            experts = self.experts[selection]
        else:
            experts = selection

        experts_set = set(experts)

        # keep selected experts on the tracked device, move the rest to cpu
        for expert in self.experts:
            device = self.device if expert in experts_set else 'cpu'
            expert.to(device)

    # forward pass (body omitted in this excerpt)
    def forward(
        self,
        x,
        is_distributed = None
    ):
        ...
# top-n gating
class TopNGating(Module):

    @beartype
    def __init__(
        self,
        dim,
        num_gates,
        eps = 1e-9,
        top_n = 2,
        threshold_train: Union[float, Tuple[float, ...]] = 0.2,
        threshold_eval: Union[float, Tuple[float, ...]] = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.,
        straight_through_dispatch_tensor = True,
        differentiable_topk = False,
        differentiable_topk_fused = True
    ):
        super().__init__()
        self.eps = eps
        self.num_gates = num_gates
        self.to_gates = nn.Linear(dim, num_gates, bias = False)

        self.differentiable_topk = differentiable_topk

        # partially apply maybe_differentiable_topk
        self.topk = partial(
            maybe_differentiable_topk,
            non_differentiable = not differentiable_topk,
            fused = differentiable_topk_fused  # use the triton fused coordinate descent by default
        )

        assert top_n >= 2, 'must be 2 or more experts'
        self.top_n = top_n
        top_n_minus_1 = top_n - 1

        # one routing threshold per expert beyond the first
        threshold_train = cast_tuple(threshold_train, top_n_minus_1)
        threshold_eval = cast_tuple(threshold_eval, top_n_minus_1)

        assert len(threshold_train) == len(threshold_eval) == top_n_minus_1

        # register the thresholds as buffers
        self.register_buffer('threshold_train', torch.tensor([eps, *threshold_train]))
        self.register_buffer('threshold_eval', torch.tensor([eps, *threshold_eval]))

        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval

        self.straight_through_dispatch_tensor = straight_through_dispatch_tensor

        # register a zero constant as a buffer
        self.register_buffer('zero', torch.zeros((1,)), persistent = False)

    # forward pass (body omitted in this excerpt)
    def forward(
        self,
        x,                    # input tensor
        noise_gates = False,  # whether to add noise to the gates
        noise_mult = 1.       # noise multiplier
    ):
        ...



# mixture of experts
class MoE(Module):

    @beartype
    def __init__(self,
        dim,
        num_experts = 16,
        expert_hidden_mult = 4,
        threshold_train = 0.2,
        threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.,
        gating_top_n = 2,
        balance_loss_coef = 1e-2,
        router_z_loss_coef = 1e-3,
        experts: Optional[Module] = None,
        straight_through_dispatch_tensor = True,
        differentiable_topk = False,
        differentiable_topk_fused = True,
        is_distributed = None,
        allow_var_seq_len = False
    ):
        super().__init__()
        self.dim = dim
        self.num_experts = num_experts

        # top-n gating over the experts
        self.gate = TopNGating(
            dim,
            top_n = gating_top_n,
            num_gates = num_experts,
            straight_through_dispatch_tensor = straight_through_dispatch_tensor,
            differentiable_topk = differentiable_topk,
            threshold_train = threshold_train,
            threshold_eval = threshold_eval,
            capacity_factor_train = capacity_factor_train,
            capacity_factor_eval = capacity_factor_eval
        )

        # default to a list of feedforward experts if none given
        experts = default(experts, lambda: [Expert(dim = dim, hidden_mult = expert_hidden_mult) for _ in range(num_experts)])

        self.experts = Experts(
            experts,
            is_distributed = is_distributed,
            allow_var_seq_len = allow_var_seq_len
        )

        self.balance_loss_coef = balance_loss_coef
        self.router_z_loss_coef = router_z_loss_coef

    # forward pass
    def forward(
        self,
        x,
        noise_gates = False,
        noise_mult = 1.
    ):
        # the gate returns the dispatch / combine tensors along with the auxiliary losses
        dispatch_tensor, combine_tensor, balance_loss, router_z_loss = self.gate(x, noise_gates = noise_gates, noise_mult = noise_mult)

        # dispatch tokens to their experts
        expert_inputs = einsum('b n d, b n e c -> b e c d', x, dispatch_tensor)

        # feed the expert inputs through the experts
        expert_outputs = self.experts(expert_inputs)

        # combine the expert outputs back into token order
        output = einsum('b e c d, b n e c -> b n d', expert_outputs, combine_tensor)

        # weight the auxiliary losses
        weighted_balance_loss = balance_loss * self.balance_loss_coef
        weighted_router_z_loss = router_z_loss * self.router_z_loss_coef

        # combine the losses
        total_aux_loss = weighted_balance_loss + weighted_router_z_loss

        return MixtureOfExpertsReturn(output, total_aux_loss, balance_loss, router_z_loss)
# sparse mixture of experts block
# in particular, they found that adding a feedforward before and after greatly stabilizes training and improves results

class SparseMoEBlock(Module):

    @beartype
    def __init__(
        self,
        moe: MoE,
        *,
        add_ff_before = False,
        add_ff_after = True
    ):
        super().__init__()
        dim = moe.dim

        # store the moe module and its prenorm
        self.moe = moe
        self.moe_prenorm = RMSNorm(dim)

        # optionally add feedforwards before and after the moe
        self.ff_before = Expert(dim, prenorm = True) if add_ff_before else None
        self.ff_after = Expert(dim, prenorm = True) if add_ff_after else None

    def forward(
        self,
        x,
        noise_gates = False,
        noise_mult = 1.
    ):

        # feedforward before

        if exists(self.ff_before):
            x = self.ff_before(x) + x

        # mixture of experts layer

        residual = x

        # route the prenormed input through the moe
        moe_out, total_aux_loss, balance_loss, router_z_loss = self.moe(self.moe_prenorm(x), noise_gates = noise_gates, noise_mult = noise_mult)

        x = moe_out + residual

        # feedforward after

        if exists(self.ff_after):
            x = self.ff_after(x) + x

        # return the output along with the auxiliary losses
        return MixtureOfExpertsReturn(x, total_aux_loss, balance_loss, router_z_loss)