Lucidrains-系列项目源码解析-三十七-

Lucidrains 系列项目源码解析(三十七)

.\lucidrains\RETRO-pytorch\retro_pytorch\training.py

import numpy as np
from functools import partial
import json
from pathlib import Path

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from retro_pytorch import RETRO, RETRODataset
from retro_pytorch.data import knn_to_retrieved_chunks
from retro_pytorch.optimizer import get_optimizer
from retro_pytorch.retrieval import text_folder_to_chunks_, chunks_to_precalculated_knn_, bert_embed, SOS_ID, EOS_ID
from retro_pytorch.utils import memmap, is_true_env_flag

from einops import rearrange

# helpers

# helper: presence test
def exists(val):
    """Return True unless ``val`` is None."""
    if val is None:
        return False
    return True

# Decorator: run a model method in eval mode, then restore the previous mode.
def eval_decorator(fn):
    """Wrap ``fn(model, *args, **kwargs)`` so it runs with ``model`` in eval mode.

    The model's ``training`` flag from before the call is restored afterwards,
    even when ``fn`` raises. The wrapper keeps ``fn``'s name and docstring.
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # restore the caller's mode even if fn raised
            model.train(was_training)
    return inner

# Concatenate onto an accumulator that may not exist yet.
def safe_cat(accum, t, dim = -1):
    """Return ``t`` when ``accum`` is None, else ``torch.cat((accum, t), dim = dim)``."""
    if accum is None:
        return t
    pieces = (accum, t)
    return torch.cat(pieces, dim = dim)

# sampling helpers

# Numerically safe natural log.
def log(t, eps = 1e-20):
    """Return ``log(t)`` with ``t`` clamped to at least ``eps``, so zeros give a finite value."""
    return torch.log(t.clamp(min = eps))

# Gumbel(0, 1) noise with the same shape as the input.
def gumbel_noise(t):
    """Sample Gumbel noise shaped like ``t`` as ``-log(-log(u))`` with ``u ~ U(0, 1)``."""
    uniform = torch.empty_like(t).uniform_(0, 1)
    inner = log(uniform)
    return -log(-inner)

# Gumbel-max sampling.
def gumbel_sample(t, temperature = 1., dim = -1):
    """Return indices along ``dim`` of ``argmax(t / temperature + gumbel_noise(t))``."""
    scaled = t / temperature
    perturbed = scaled + gumbel_noise(t)
    return perturbed.argmax(dim = dim)

# Top-k logit filtering.
def top_k(logits, thres = 0.9):
    """Keep the ``k`` largest logits along the last dimension; set the rest to ``-inf``.

    ``k = max(int((1 - thres) * num_logits), 1)``: ``thres`` is the fraction of
    logits discarded, and at least one logit is always kept. Any number of
    leading (batch) dimensions is supported.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    # torch.topk works on the last dimension by default, so scatter back along it too
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs

# Nucleus (top-p) logit filtering.
def top_p(logits, thres = 0.9):
    """Nucleus filtering along the last dimension.

    Tokens are sorted by logit. The smallest prefix whose cumulative probability
    first exceeds ``1 - thres`` is kept, including the token that crosses the
    threshold, and everything else is set to ``-inf``. The most likely token is
    always kept. Works for any number of leading (batch) dimensions.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending = True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)

    sorted_indices_to_remove = cum_probs > (1 - thres)
    # shift right by one so the token that crosses the threshold is itself kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    # always keep the single most likely token
    sorted_indices_to_remove[..., 0] = False

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    # undo the sort: write every (possibly masked) logit back to its original position
    return sorted_logits.scatter(-1, sorted_indices, sorted_logits)

# Retrieve nearest-neighbour chunks for a batch of token chunks.
def knn_chunks_from_seq_chunks(
    seq_chunks,
    *,
    knn,
    faiss_index,
    num_chunks,
    chunk_size,
    chunks_memmap_path,
):
    """Look up the ``knn`` nearest-neighbour chunks for every row of ``seq_chunks``.

    Each row is wrapped as ``[SOS_ID, *row, EOS_ID]``, and a CPU copy is embedded
    with ``bert_embed``. The embeddings are searched in ``faiss_index`` for ``knn``
    neighbours. ``knn_to_retrieved_chunks`` (called with ``add_continuations = True``)
    then reads the hits from the int32 memmap at ``chunks_memmap_path``, of shape
    ``(num_chunks + 1, chunk_size + 1)``.

    Returns:
        The retrieved chunks as a torch tensor on ``seq_chunks``'s device.
    """
    batch, device = seq_chunks.shape[0], seq_chunks.device

    # boolean column of ones; multiplying by a token id yields an integer id column
    marker = torch.ones((batch, 1), dtype = torch.bool, device = device)
    wrapped = torch.cat((marker * SOS_ID, seq_chunks, marker * EOS_ID), dim = 1)

    # embeddings from the frozen BERT, computed from a CPU copy
    embeds = bert_embed(wrapped.cpu())

    # faiss consumes numpy arrays
    _, knn_indices = faiss_index.search(embeds.cpu().numpy(), k = knn)

    memmap_shape = (num_chunks + 1, chunk_size + 1)
    with memmap(chunks_memmap_path, dtype = np.int32, shape = memmap_shape) as chunk_memmap:
        knn_chunks = knn_to_retrieved_chunks(
            knn_indices,
            chunk_memmap,
            add_continuations = True,
            num_chunks = num_chunks
        )
        return torch.from_numpy(knn_chunks).to(device)

# Training wrapper
class TrainingWrapper(nn.Module):
    """Bundle a RETRO model with its preprocessed retrieval corpus.

    On construction this:
      1. chunks the text files under ``documents_path`` into memmaps. This step is
         skipped when ``processed_stats_json_path`` already exists and the
         ``REPROCESS`` environment flag is not set.
      2. precomputes the nearest neighbours of every chunk with faiss.
      3. builds a ``RETRODataset`` over the results.

    ``generate`` samples from the model and retrieves fresh neighbours every
    ``chunk_size`` tokens.
    """

    def __init__(
        self,
        *,
        retro,
        chunk_size,
        documents_path,
        knn,
        glob = '**/*.txt',
        chunks_memmap_path = './train.chunks.dat',
        seqs_memmap_path = './train.seq.dat',
        doc_ids_memmap_path = './train.doc_ids.dat',
        max_chunks = 1_000_000,
        max_seqs = 100_000,
        knn_extra_neighbors = 100,
        processed_stats_json_path = './processed-stats.json',
        faiss_index_filename = 'knn.index',
        **index_kwargs
    ):
        """
        Args:
            retro: the ``RETRO`` model to wrap.
            chunk_size: number of tokens per retrieval chunk.
            documents_path: folder containing the training text files.
            knn: number of neighbours retrieved per chunk.
            glob: file pattern, passed to ``text_folder_to_chunks_``.
            chunks_memmap_path, seqs_memmap_path, doc_ids_memmap_path: memmap files
                written by preprocessing and read by the dataset.
            max_chunks, max_seqs: limits passed to ``text_folder_to_chunks_``.
            knn_extra_neighbors: passed as ``num_extra_neighbors`` to
                ``chunks_to_precalculated_knn_``.
            processed_stats_json_path: JSON cache of the preprocessing statistics.
            faiss_index_filename: passed as ``index_file`` to ``chunks_to_precalculated_knn_``.
            **index_kwargs: extra keyword arguments for ``chunks_to_precalculated_knn_``.
        """
        super().__init__()
        assert isinstance(retro, RETRO), 'retro must be instance of RETRO'
        self.retro = retro

        # the REPROCESS environment flag forces every preprocessing step to rerun
        force_reprocess = is_true_env_flag('REPROCESS')

        # cached statistics of the processed training data (number of chunks and sequences)
        stats_path = Path(processed_stats_json_path)

        if not stats_path.exists() or force_reprocess:
            self.stats = text_folder_to_chunks_(
                folder = documents_path,
                glob = glob,
                chunks_memmap_path = chunks_memmap_path,
                seqs_memmap_path = seqs_memmap_path,
                doc_ids_memmap_path = doc_ids_memmap_path,
                chunk_size = chunk_size,
                seq_len = retro.seq_len,
                max_chunks = max_chunks,
                max_seqs = max_seqs
            )
            with open(processed_stats_json_path, 'w') as f:
                json.dump(self.stats, f)
        else:
            print(f'found to be previously processed at {str(stats_path)}')
            self.stats = json.loads(stats_path.read_text())

        num_chunks = self.stats['chunks']
        num_seqs = self.stats['seqs']

        # precompute the knn memmap and obtain the faiss index
        knn_memmap_path, faiss_index = chunks_to_precalculated_knn_(
            num_chunks = num_chunks,
            chunk_size = chunk_size,
            chunk_memmap_path = chunks_memmap_path,
            doc_ids_memmap_path = doc_ids_memmap_path,
            num_nearest_neighbors = knn,
            num_extra_neighbors = knn_extra_neighbors,
            index_file = faiss_index_filename,
            force_reprocess = force_reprocess,
            **index_kwargs
        )

        self.ds = RETRODataset(
            num_sequences = num_seqs,
            num_chunks = num_chunks,
            num_neighbors = knn,
            chunk_size = chunk_size,
            seq_len = retro.seq_len,
            chunk_memmap_path = chunks_memmap_path,
            chunk_nn_memmap_path = knn_memmap_path,
            seq_memmap_path = seqs_memmap_path
        )

        # settings needed by generate()
        self.chunk_size = chunk_size
        self.max_seq_len = self.retro.seq_len

        # retrieval closure used during generation: token chunks -> neighbour chunks
        self.fetch_knn_chunks_fn = partial(
            knn_chunks_from_seq_chunks,
            knn = knn,
            chunk_size = chunk_size,
            num_chunks = num_chunks,
            chunks_memmap_path = chunks_memmap_path,
            faiss_index = faiss_index
        )

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        start = None,
        retrieved = None,
        filter_fn = top_k,
        filter_thres = 0.9,
        temperature = 1.0,
    ):
        """Autoregressively sample token ids from RETRO.

        Args:
            start: optional ``(batch, seq)`` prompt of token ids. Defaults to a single
                ``SOS_ID`` token, i.e. batch size 1.
            retrieved: optional pre-retrieved neighbour chunks. They are replaced by
                fresh retrieval when ``start`` already covers at least one full chunk.
            filter_fn: ``top_k`` or ``top_p``.
            filter_thres: threshold passed to ``filter_fn``.
            temperature: Gumbel sampling temperature.

        Returns:
            The prompt followed by the sampled tokens. Once every row contains
            ``EOS_ID``, sampling stops and every position after each row's first EOS
            is set to ``self.retro.pad_id``.
        """
        assert filter_fn in {top_k, top_p}, 'filter function must be either top-k or nucleus'

        device = next(self.retro.parameters()).device

        # without a prompt, start from a single SOS token (batch size 1)
        if not exists(start):
            start = torch.full((1, 1), SOS_ID, device=device).long()

        b, start_seq_len = start.shape

        start = start.to(device)

        # if the prompt already spans whole chunks, retrieve their neighbours up front
        if start_seq_len >= self.chunk_size:
            seq_index = (start_seq_len // self.chunk_size) * self.chunk_size
            past_seq_chunks = rearrange(start[:, :seq_index], 'b (n c) -> (b n) c', c=self.chunk_size)

            retrieved = self.fetch_knn_chunks_fn(past_seq_chunks)
            retrieved = rearrange(retrieved, '(b n) k c -> b n k c', b=b)

        out = start

        # sampling loop: position i is the last position of the current `out`
        for i in range(start_seq_len - 1, self.max_seq_len):

            logits = self.retro(out, retrieved=retrieved)
            logits = logits[:, i]

            logits = filter_fn(logits, thres=filter_thres)
            sampled = gumbel_sample(logits, temperature=temperature, dim=-1)
            sampled = rearrange(sampled, 'b -> b 1')

            out = torch.cat((out, sampled), dim=1)

            # stop early once every row has produced an EOS token
            is_eos_tokens = (out == EOS_ID)

            if is_eos_tokens.any(dim=-1).all():

                # pad out everything after each row's first EOS (the EOS itself is kept)
                shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
                out = out.masked_fill(mask, self.retro.pad_id)
                break

            # whenever the sequence completes a chunk, retrieve neighbours for that chunk
            curr_seq_len = out.shape[-1]

            if (curr_seq_len % self.chunk_size) == 0:
                last_chunk = rearrange(out, 'b (c n) -> b c n', n=self.chunk_size)[:, -1]

                knn_chunks = self.fetch_knn_chunks_fn(last_chunk)

                # append as a new chunk position so the next forward pass can
                # cross-attend to it
                knn_chunks = rearrange(knn_chunks, 'b k r -> b 1 k r')
                retrieved = safe_cat(retrieved, knn_chunks, dim=1)

                print(f'retrieved at {curr_seq_len} / {self.max_seq_len}')

        return out

    def get_dataloader(self, **kwargs):
        """Return a ``DataLoader`` over the preprocessed dataset; ``kwargs`` go to ``DataLoader``."""
        return DataLoader(self.ds, **kwargs)

    def get_optimizer(self, **kwargs):
        """Build an optimizer over the wrapped RETRO model's parameters."""
        return get_optimizer(self.retro.parameters(), **kwargs)

    def forward(self):
        # this wrapper defines no forward pass of its own
        raise NotImplementedError

.\lucidrains\RETRO-pytorch\retro_pytorch\utils.py

# 导入 os 模块
import os
# 导入 numpy 模块并重命名为 np
import numpy as np

# 从 pathlib 模块中导入 Path 类
from pathlib import Path
# 从 shutil 模块中导入 rmtree 函数
from shutil import rmtree
# 从 contextlib 模块中导入 contextmanager 装饰器
from contextlib import contextmanager

# Environment flag parsing.
def is_true_env_flag(env_flag):
    """Return True when environment variable ``env_flag`` is 'true', '1' or 't' (case-insensitive).

    An unset variable counts as false.
    """
    value = os.getenv(env_flag, 'false')
    return value.lower() in {'true', '1', 't'}

# Folder reset.
def reset_folder_(p):
    """Delete folder ``p`` (if present) with all its contents, then recreate it empty.

    Missing parent directories are created as well.
    """
    folder = Path(p)
    rmtree(folder, ignore_errors = True)
    folder.mkdir(parents = True, exist_ok = True)

# Context manager around np.memmap.
@contextmanager
def memmap(*args, **kwargs):
    """Open ``np.memmap(*args, **kwargs)`` for the duration of a ``with`` block.

    The context manager drops its own reference to the map on exit, including
    when the ``with`` body raises. The caller's ``as`` variable is unaffected.
    """
    pointer = np.memmap(*args, **kwargs)
    try:
        yield pointer
    finally:
        # runs on both normal exit and exceptions; without `finally` an exception
        # would skip this and the map would stay referenced by the traceback's frame
        del pointer

.\lucidrains\RETRO-pytorch\retro_pytorch\__init__.py

# 从 retro_pytorch.retro_pytorch 模块中导入 RETRO 类
# 从 retro_pytorch.data 模块中导入 RETRODataset 类
# 从 retro_pytorch.training 模块中导入 TrainingWrapper 类
from retro_pytorch.retro_pytorch import RETRO
from retro_pytorch.data import RETRODataset
from retro_pytorch.training import TrainingWrapper

.\lucidrains\RETRO-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# Package metadata for the retro-pytorch distribution.
# NOTE(review): long_description_content_type is set but no long_description is passed.
setup(
  name = 'retro-pytorch',  # distribution name
  packages = find_packages(exclude=[]),  # include every package found
  version = '0.3.9',  # release version
  license='MIT',  # license
  description = 'RETRO - Retrieval Enhanced Transformer - Pytorch',  # short description
  long_description_content_type = 'text/markdown',  # format of the long description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/RETRO-pytorch',  # project URL
  keywords = [  # search keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention-mechanism',
    'retrieval',
  ],
  install_requires=[  # runtime dependencies
    'autofaiss',
    'einops>=0.3',
    'numpy',
    'sentencepiece',
    'torch>=1.6',
    'tqdm'
  ],
  classifiers=[  # trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\ring-attention-pytorch\assert.py

# 导入必要的库
import os
from math import ceil
from copy import deepcopy
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from ring_attention_pytorch.ring_attention import RingTransformer
from ring_attention_pytorch.distributed import all_gather_variable_dim

# Distributed setup for the test.
def setup(rank, world_size):
    """Initialise a localhost gloo process group as ``rank`` of ``world_size``."""
    os.environ.update(MASTER_ADDR = 'localhost', MASTER_PORT = '12355')
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

# Destroy the default process group (start() calls this once its checks finish).
def cleanup():
    dist.destroy_process_group()

# Per-process body: run identical weights through ring and non-ring attention and compare.
def start(
    rank,
    world_size,
    batch_size,
    batch_size_var_len,
    seq_len,
    num_sharded_batches,
    causal,
    striped_ring_attn,
    dim,
    use_cuda
):
    """Check on one rank that ring attention reproduces plain attention.

    Two ``RingTransformer`` models with identical weights are built, one with
    ``ring_attn=True`` and one without. The same random token batch is run through
    both under DDP, and rank 0 asserts that outputs and token-embedding gradients
    match.

    Args:
        rank: this process's rank; also its CUDA device index when ``use_cuda``.
        world_size: total number of processes.
        batch_size: per-rank batch size. With ``batch_size_var_len``, rank ``r``
            uses ``batch_size + r`` instead.
        batch_size_var_len: give each rank a different batch size.
        seq_len: token sequence length.
        num_sharded_batches: multiplier on the per-rank ring sequence size.
        causal: use causal attention.
        striped_ring_attn: use striped ring attention in the ring model.
        dim: model width.
        use_cuda: run on GPU ``rank`` instead of the CPU.
    """
    setup(rank, world_size)

    # per-rank ring segment length; flash buckets are half of it
    ring_seq_size = ceil(seq_len / world_size) * num_sharded_batches
    bucket_size = ring_seq_size // 2

    ring_attention_net = RingTransformer(
        num_tokens=256,
        dim=dim,
        causal=causal,
        depth=2,
        dim_head=8,
        ring_attn=True,
        striped_ring_attn=striped_ring_attn,
        ring_seq_size=ring_seq_size,
        bucket_size=bucket_size
    )

    flash_attention_net = RingTransformer(
        num_tokens=256,
        dim=dim,
        causal=causal,
        depth=2,
        dim_head=8,
        ring_attn=False,
        bucket_size=bucket_size
    )

    # give both models identical weights
    flash_attention_net.load_state_dict(ring_attention_net.state_dict())

    if batch_size_var_len:
        batch_size = batch_size + rank

    seq = torch.randint(0, 256, (batch_size, seq_len))

    # Move data and models to this rank's GPU *before* wrapping them in DDP:
    # DDP expects the module to already live on its device at construction time.
    if use_cuda:
        seq = seq.cuda(rank)
        flash_attention_net.cuda(rank)
        ring_attention_net.cuda(rank)

    device_ids = [rank] if use_cuda else None

    ddp_ring_attention_net = DDP(ring_attention_net, device_ids=device_ids)
    ddp_flash_attention_net = DDP(flash_attention_net, device_ids=device_ids)

    # forward + backward through the non-ring model
    flash_out = ddp_flash_attention_net(seq)
    flash_out.mean().backward()

    # forward + backward through the ring model
    ring_out = ddp_ring_attention_net(seq)
    ring_out.mean().backward()

    # rank 0 checks that sequence-sharded (ring) and unsharded outputs agree
    if rank == 0:
        ring_attention_net = ring_attention_net.cpu()
        flash_attention_net = flash_attention_net.cpu()
        ring_out = ring_out.cpu()
        flash_out = flash_out.cpu()

        assert torch.allclose(ring_out, flash_out, atol=1e-6), 'output is not the same'

        # token-embedding gradients must match as well
        get_embed_grad = lambda model: model.token_emb.weight.grad
        ring_embed_grad = get_embed_grad(ring_attention_net)
        flash_embed_grad = get_embed_grad(flash_attention_net)

        assert torch.allclose(
            ring_embed_grad,
            flash_embed_grad,
            atol=1e-2
        ), 'grad is not the same'

        print('✅ outputs and gradients are same between ring attention and non-ring attention')

    cleanup()

# Script entry point: spawn one process per rank and run `start` in each.
if __name__ == '__main__':
    world_size = 8
    batch_size = 2
    num_sharded_batches = 1
    batch_size_var_len = False
    use_cuda = False
    causal = True
    striped_ring_attn = True

    # every rank moves its work to `cuda(rank)`, so each rank needs its own visible GPU
    assert not use_cuda or world_size <= torch.cuda.device_count(), 'use_cuda requires at least world_size visible GPUs'

    seq_len = 31
    dim = 8

    mp.spawn(
        start,
        args=(
            world_size,
            batch_size,
            batch_size_var_len,
            seq_len,
            num_sharded_batches,
            causal,
            striped_ring_attn,
            dim,
            use_cuda
        ),
        nprocs=world_size,
        join=True
    )

.\lucidrains\ring-attention-pytorch\assert_flash.py

# 导入 torch 库
import torch

# 从 ring_attention_pytorch 模块中导入 default_attention 和 ring_flash_attn 函数
from ring_attention_pytorch import (
    default_attention,
    ring_flash_attn
)

# Configuration: causal masking, sequence length and flash bucket size.
# (seq_len = 62 is not a multiple of bucket_size = 4.)
causal = True
seq_len = 62
bucket_size = 4

# shared random q, k, v of shape (2, seq_len, 2, 16), drawn in the order q, k, v
q, k, v = (torch.randn(2, seq_len, 2, 16) for _ in range(3))

# independent leaf copies: f* feed ring_flash_attn, r* feed default_attention
fq, fk, fv = (t.clone().requires_grad_() for t in (q, k, v))
rq, rk, rv = (t.clone().requires_grad_() for t in (q, k, v))

# forward passes must agree
o = default_attention(rq, rk, rv, causal=causal)
fo = ring_flash_attn(fq, fk, fv, bucket_size=bucket_size, causal=causal)

assert torch.allclose(o, fo, atol=1e-6)

# backward through a sum reduction; gradients for q, k and v must agree
o.sum().backward()
fo.sum().backward()

for reference_grad, flash_grad in ((rq.grad, fq.grad), (rk.grad, fk.grad), (rv.grad, fv.grad)):
    assert torch.allclose(reference_grad, flash_grad, atol=1e-6)

Ring Attention - Pytorch

Explorations into Ring Attention, from Liu et al. at Berkeley AI.

It basically splits the data across the sequence dimension (instead of batch) and applies ring reduce to the processing of the tiles of the attention matrix, flash attention style.

I believe this is what powers the 1–10 million token context of the latest Gemini, or at least some form of it; the other possibility would be unpublished improvements on top of RMT (Recurrent Memory Transformer).

In addition, the repository also contains the logic for Striped Attention, a follow up paper that permutes the sequence for better workload balancing for autoregressive transformers.

Appreciation

  • A16Z Open Source AI Grant Program for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research

Install

$ pip install ring-attention-pytorch

Usage

import torch
from ring_attention_pytorch import RingAttention

attn = RingAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    causal = True,
    auto_shard_seq = True,
    ring_attn = True,
    ring_seq_size = 512
)

tokens = torch.randn(1, 1024, 512)
attended = attn(tokens)

assert attended.shape == tokens.shape

Test

$ python assert.py

Todo

Citations

@article{Liu2023RingAW,
    title    = {Ring Attention with Blockwise Transformers for Near-Infinite Context},
    author   = {Hao Liu and Matei Zaharia and Pieter Abbeel},
    journal  = {ArXiv},
    year     = {2023},
    volume   = {abs/2310.01889},
    url      = {https://api.semanticscholar.org/CorpusID:263608461}
}
@article{Brandon2023StripedAF,
    title   = {Striped Attention: Faster Ring Attention for Causal Transformers},
    author  = {William Brandon and Aniruddha Nrusimha and Kevin Qian and Zachary Ankner and Tian Jin and Zhiye Song and Jonathan Ragan-Kelley},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2311.09431},
    url     = {https://api.semanticscholar.org/CorpusID:265220849}
}
@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R{\'e}},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}
@article{dao2023flashattention2,
    title   = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
    author  = {Dao, Tri},
    year    = {2023}
}
@article{Tillet2019TritonAI,
    title   = {Triton: an intermediate language and compiler for tiled neural network computations},
    author  = {Philippe Tillet and H. Kung and D. Cox},
    journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
    year    = {2019}
}

.\lucidrains\ring-attention-pytorch\ring_attention_pytorch\distributed.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch.nn 模块中导入 Module 类
from torch.nn import Module
# 从 torch.nn.functional 模块中导入 F 函数
import torch.nn.functional as F
# 从 torch.autograd 模块中导入 Function 类
from torch.autograd import Function

# 导入 torch.distributed 模块
import torch.distributed as dist

# helper: presence test
def exists(val):
    """Return True unless ``val`` is None."""
    if val is None:
        return False
    return True

# helper: fallback for None
def default(val, d):
    """Return ``val`` unless it is None, in which case return ``d``."""
    if val is None:
        return d
    return val

# helper: divisibility test
def divisible_by(num, den):
    """Return whether ``num`` is an exact multiple of ``den``."""
    remainder = num % den
    return remainder == 0

# Right-pad one dimension of a tensor.
def pad_dim_to(t, length, dim = 0):
    """Zero-pad ``t`` at the end of dimension ``dim`` so that it has size ``length``.

    ``F.pad`` takes (left, right) pairs starting from the last dimension, so every
    dimension after ``dim`` gets a ``(0, 0)`` pair before the pair for ``dim``.
    """
    trailing_dims = t.ndim - dim - 1 if dim >= 0 else -dim - 1
    padding = (0, 0) * trailing_dims + (0, length - t.shape[dim])
    return F.pad(t, padding)

# All-gather of equally shaped tensors.
def all_gather_same_dim(t):
    """Gather ``t`` from every rank and return a list with one tensor per rank.

    Assumes every rank passes a tensor of the same shape and dtype.
    """
    t = t.contiguous()
    buffers = [
        torch.empty_like(t, device = t.device, dtype = t.dtype)
        for _ in range(dist.get_world_size())
    ]
    dist.all_gather(buffers, t)
    return buffers

# Per-rank sizes along one dimension.
def gather_sizes(t, *, dim):
    """Return a 1-D long tensor holding ``t.shape[dim]`` from every rank, in rank order."""
    local_size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
    return torch.stack(all_gather_same_dim(local_size))

# Uniformity check.
def has_only_one_value(t):
    """Return a boolean tensor that is True when every element of ``t`` equals ``t[0]``."""
    first = t[0]
    return torch.all(t == first)

# All-gather along `dim` when ranks may hold different sizes along it.
def all_gather_variable_dim(t, dim = 0, sizes = None):
    """Gather ``t`` from every rank and concatenate the results along ``dim``.

    Ranks may differ in size along ``dim``. In that case each tensor is zero-padded
    to the largest size before the collective, and the padding is removed afterwards.

    Args:
        t: local tensor.
        dim: dimension to gather along.
        sizes: optional 1-D tensor of every rank's size along ``dim``; gathered
            when not given.

    Returns:
        ``(gathered, sizes)``.
    """
    device = t.device

    if not exists(sizes):
        sizes = gather_sizes(t, dim = dim)

    # fast path: all ranks share the same size, no padding needed
    if has_only_one_value(sizes):
        gathered_tensors = all_gather_same_dim(t)
        gathered_tensors = torch.cat(gathered_tensors, dim = dim)
        return gathered_tensors, sizes

    max_size = sizes.amax().item()

    padded_t = pad_dim_to(t, max_size, dim = dim)
    gathered_tensors = all_gather_same_dim(padded_t)

    gathered_tensors = torch.cat(gathered_tensors, dim = dim)
    seq = torch.arange(max_size, device = device)

    # mask[i * max_size + j] is True when position j is real (not padding) for rank i.
    # Pure torch is used here because this module does not import einx.
    mask = (seq.unsqueeze(0) < sizes.unsqueeze(1)).reshape(-1)
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensors = gathered_tensors.index_select(dim, indices)

    return gathered_tensors, sizes

# Autograd-aware variable-size all-gather.
class AllGatherFunction(Function):
    """Gather ``x`` from all ranks along ``dim``.

    The backward pass returns only this rank's slice of the incoming gradient.
    """

    @staticmethod
    def forward(ctx, x, dim, sizes):
        gathered, gathered_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
        # remember how to split the gradient back into per-rank pieces
        ctx.batch_sizes = gathered_sizes.tolist()
        ctx.dim = dim
        return gathered, gathered_sizes

    @staticmethod
    def backward(ctx, grads, _):
        own_rank = dist.get_rank()
        pieces = grads.split(ctx.batch_sizes, dim = ctx.dim)
        return pieces[own_rank], None, None

# Module front end for AllGatherFunction.
class AllGather(Module):
    """Differentiable all-gather along a fixed dimension ``dim``."""

    def __init__(self, *, dim = 0):
        super().__init__()
        self.dim = dim

    def forward(self, x, sizes = None):
        """Gather ``x`` from all ranks; returns ``(gathered, sizes)``."""
        gather_dim = self.dim
        return AllGatherFunction.apply(x, gather_dim, sizes)

# Select this rank's piece.
def split_by_rank(x):
    """Return ``(x[rank], sizes)`` for the current rank.

    For a tuple ``x``, ``sizes`` holds each element's ``shape[0]``. For a tensor,
    it is ``x.shape[1]`` repeated ``x.shape[0]`` times. ``sizes`` is a long tensor
    on the selected piece's device.
    """
    out = x[dist.get_rank()]

    if isinstance(x, tuple):
        sizes = tuple(piece.shape[0] for piece in x)
    else:
        num_pieces, piece_len = x.shape[0], x.shape[1]
        sizes = (piece_len,) * num_pieces

    return out, torch.tensor(sizes, device = out.device, dtype = torch.long)

.\lucidrains\ring-attention-pytorch\ring_attention_pytorch\ring.py

# 导入必要的模块
from typing import Optional
from functools import lru_cache, partial, wraps
from collections import namedtuple

import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList
from torch.autograd import Function

import torch.distributed as dist

# 辅助函数

# helper: presence test
def exists(v):
    """Return True unless ``v`` is None."""
    if v is None:
        return False
    return True

# helper: fallback for None
def default(v, d):
    """Return ``v`` unless it is None, in which case return ``d``."""
    if v is None:
        return d
    return v

# helper: coerce to tuple
def cast_tuple(t, length = 1):
    """Return ``t`` unchanged if it is a tuple, otherwise ``(t,) * length``."""
    if isinstance(t, tuple):
        return t
    return tuple(t for _ in range(length))

# Unbounded memoisation decorator factory: `@cache()` is `@lru_cache(maxsize = None)`.
cache = partial(lru_cache, maxsize = None)

# 分布式全局变量

# Rank of this process.
@cache()
def get_rank():
    """Return this process's rank, or 0 when no process group is initialised.

    Memoised via ``cache``: the first result is reused for the rest of the
    process, so a call made before ``dist.init_process_group`` keeps returning 0.
    """
    if not dist.is_initialized():
        return 0
    return dist.get_rank()

# Number of processes.
@cache()
def get_world_size():
    """Return the world size, or 1 when no process group is initialised.

    Memoised via ``cache``: the first result is reused for the rest of the process.
    """
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()

# Distributed-mode check.
@cache()
def is_distributed():
    """Return True when a process group is initialised with more than one process.

    Memoised via ``cache``: the first result is reused for the rest of the process.
    """
    if not dist.is_initialized():
        return False
    return dist.get_world_size() > 1

# 环函数

# Ring index `num` steps to the left.
def circular_index_left(pos, ring_size, num = 1):
    """Return ``pos - num`` wrapped into ``[0, ring_size)``."""
    return (pos - num) % ring_size

# Ring index `num` steps to the right.
def circular_index_right(pos, ring_size, num = 1):
    """Return ``pos + num`` wrapped into ``[0, ring_size)``."""
    _, wrapped = divmod(pos + num, ring_size)
    return wrapped

# 分布式环

# Global rank of the left neighbour within this rank's ring.
def circular_rank_left(rank = None, ring_size = None, num = 1):
    """Return the global rank ``num`` steps to the left inside the ring holding ``rank``.

    Ranks are grouped into consecutive rings of ``ring_size``. The neighbour is
    found within the group and shifted back by the group's first rank.
    """
    rank = default(rank, get_rank())
    ring_size = default(ring_size, get_world_size())
    ring_start = (rank // ring_size) * ring_size
    return ring_start + circular_index_left(rank, ring_size, num)

# Global rank of the right neighbour within this rank's ring.
def circular_rank_right(rank = None, ring_size = None, num = 1):
    """Return the global rank ``num`` steps to the right inside the ring holding ``rank``.

    Ranks are grouped into consecutive rings of ``ring_size``. The neighbour is
    found within the group and shifted back by the group's first rank.
    """
    rank = default(rank, get_rank())
    ring_size = default(ring_size, get_world_size())
    ring_start = (rank // ring_size) * ring_size
    return ring_start + circular_index_right(rank, ring_size, num)

# 单次环传递

# One blocking point-to-point exchange.
def send_and_receive_(x, receive_buffer, send_to_rank, receive_from_rank):
    """Send ``x`` to ``send_to_rank`` while receiving into ``receive_buffer`` from ``receive_from_rank``.

    Both operations are batched, waited on, and followed by a barrier across the
    default group. ``receive_buffer`` is filled in place.
    """
    ops = [
        dist.P2POp(dist.isend, x, send_to_rank),
        dist.P2POp(dist.irecv, receive_buffer, receive_from_rank),
    ]

    for request in dist.batch_isend_irecv(ops):
        request.wait()

    dist.barrier()

# Ring pass: shift a tensor `num_ring_passes` positions around the ring in one exchange.
def ring_pass(
    num_ring_passes: int,
    x: Tensor,
    receive_buffer: Optional[Tensor] = None,
    ring_size: Optional[int] = None
):
    """Send ``x`` ``num_ring_passes`` ranks to the right and receive from as many to the left.

    Args:
        num_ring_passes: how many ring positions to shift. Previously this argument
            was ignored and a single-step shift was always performed; 1 behaves
            exactly as before.
        x: tensor to send; made contiguous first.
        receive_buffer: optional tensor to receive into. A zero tensor shaped like
            ``x`` is allocated when omitted.
        ring_size: ring size; defaults to the world size.

    Returns:
        ``(received, x)``. Callers can reuse the returned ``x`` as the next
        receive buffer.
    """
    ring_size = default(ring_size, get_world_size())
    x = x.contiguous()

    if not exists(receive_buffer):
        receive_buffer = torch.zeros_like(x)
    else:
        receive_buffer = receive_buffer.contiguous()

    send_to_rank = circular_rank_right(ring_size = ring_size, num = num_ring_passes)
    receive_from_rank = circular_rank_left(ring_size = ring_size, num = num_ring_passes)

    send_and_receive_(x, receive_buffer, send_to_rank, receive_from_rank)
    return receive_buffer, x

# `ring_pass` with `num_ring_passes` fixed to 1: one hop to the right / from the left.
one_ring_pass = partial(ring_pass, 1)

# 迭代器,用于所有张量的所有环传递

# Per-iteration ring metadata: ring_rank of the data currently held, and (is_first, is_last).
RingInfo = namedtuple('RingInfo', 'ring_rank iter_info')

# No-communication counterpart of all_ring_pass.
def null_ring_pass(*tensors, max_iters = None, receive_buffers = None, ring_size = None):
    """Same signature as ``all_ring_pass``, but performs no communication.

    Yields exactly once: ``RingInfo(0, (True, True))`` together with the inputs
    unchanged. ``max_iters`` and ``ring_size`` are accepted and ignored.
    """
    info = RingInfo(0, (True, True))
    payload = (tensors, receive_buffers)
    yield info, payload

# Iterate over every ring position, passing the tensors one hop between iterations.
def all_ring_pass(*tensors, max_iters = None, receive_buffers = None, ring_size = None):
    """Generator that rotates ``tensors`` around the ring.

    Runs ``max(1, min(ring_size, max_iters))`` iterations. Each iteration yields
    ``(RingInfo(ring_pos, (is_first, is_last)), (tensors, receive_buffers))`` and,
    unless it was the last, then passes every non-None tensor one hop with
    ``one_ring_pass``. The received tensor becomes the next tensor, and the tensor
    just sent becomes the next receive buffer. ``ring_pos`` starts at this process's
    rank and moves one step left per iteration, matching the data received from
    the left neighbour.

    NOTE(review): ``ring_pos`` starts from the global rank, but later values are
    wrapped into ``[0, ring_size)``. Confirm this is intended when several rings exist.
    """
    ring_size = default(ring_size, get_world_size())
    max_iters = default(max_iters, ring_size)

    receive_buffers = cast_tuple(receive_buffers, len(tensors))

    # clamp the iteration count to [1, ring_size]
    total_iters = max(1, min(ring_size, max_iters))

    curr_ring_pos = get_rank()

    for ind in range(total_iters):
        is_first = ind == 0
        is_last = ind == (total_iters - 1)

        yield RingInfo(curr_ring_pos, (is_first,  is_last)), (tensors, receive_buffers)

        curr_ring_pos = circular_index_left(curr_ring_pos, ring_size)

        # no communication is needed after the final iteration
        if is_last:
            continue

        new_tensors = []
        new_receive_buffers = []

        for tensor, receive_buffer in zip(tensors, receive_buffers):
            if exists(tensor):
                new_tensor, new_receive_buffer = one_ring_pass(tensor, receive_buffer, ring_size)
            else:
                new_tensor, new_receive_buffer = None, None

            new_tensors.append(new_tensor)
            new_receive_buffers.append(new_receive_buffer)

        tensors = new_tensors
        receive_buffers = new_receive_buffers

.\lucidrains\ring-attention-pytorch\ring_attention_pytorch\ring_attention.py

# 导入必要的库
from typing import Optional, Tuple, Union

import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.nn import Module, ModuleList

import einx
from einx import rearrange

from beartype import beartype

# 导入自定义模块和函数
from ring_attention_pytorch.ring import (
    all_ring_pass,
    is_distributed,
    get_rank,
    get_world_size
)

from ring_attention_pytorch.ring_flash_attention import (
    ring_flash_attn
)

from ring_attention_pytorch.distributed import (
    split_by_rank,
    AllGather
)

# 辅助函数

# helper: presence test
def exists(v):
    """Return True unless ``v`` is None."""
    if v is None:
        return False
    return True

# helper: fallback for None
def default(v, d):
    """Return ``v`` unless it is None, in which case return ``d``."""
    if v is None:
        return d
    return v

# helper: coerce to tuple
def cast_tuple(t, length = 1):
    """Return ``t`` unchanged if it is a tuple, otherwise ``(t,) * length``."""
    if isinstance(t, tuple):
        return t
    return tuple(t for _ in range(length))

# helper: divisibility test
def divisible_by(num, den):
    """Return whether ``num`` is an exact multiple of ``den``."""
    remainder = num % den
    return remainder == 0

# Reference (non-ring, non-flash) attention.
def default_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Optional[Tensor] = None,
    causal: bool = False
):
    """Plain scaled dot-product multi-head attention.

    ``q``, ``k`` and ``v`` use the layout ``(batch, seq, heads, dim_head)``, as fixed
    by the einsum patterns below. The output has the same layout as ``q``.

    Args:
        mask: optional boolean key mask of shape ``(batch, seq_k)``; ``True`` marks
            keys that may be attended to. Ignored when ``causal`` is set.
        causal: mask future keys. The mask is aligned so the last query sees every
            key (``triu(j - i + 1)``).
    """
    q = q * (q.shape[-1] ** -0.5)

    mask_value = -torch.finfo(q.dtype).max

    sim = einsum('b i h d, b j h d -> b h i j', q, k)

    if causal:
        i, j = sim.shape[-2:]
        # build the mask on sim's device; a CPU mask cannot be mixed with CUDA tensors in torch.where
        causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
        sim = torch.where(causal_mask, mask_value, sim)

    elif exists(mask):
        # broadcast the (b, j) key mask over heads and queries
        sim = torch.where(mask[:, None, None, :], sim, mask_value)

    attn = sim.softmax(dim = -1)

    out = einsum('b h i j, b j h d -> b i h d', attn, v)

    return out

# Rotary embedding whose positions account for ring / striped sequence sharding.
class RingRotaryEmbedding(Module):
    """Rotary frequencies for the local sequence shard.

    Position layouts produced by ``forward``:
      - non-ring: ``0 .. seq_len - 1``.
      - ring, contiguous: ``seq_len * rank + (0 .. seq_len - 1)``.
      - ring, striped: see ``forward``.
    """

    def __init__(
        self,
        dim,
        ring: bool = False,
        striped: bool = False,
        buckets: int = 1,        # for striped attention with flash buckets > 1, the number of buckets per machine must be given
        theta = 10000
    ):
        super().__init__()
        self.ring = ring
        self.striped = striped
        self.buckets = buckets

        # standard rotary inverse frequencies, one per pair of channels
        inv_freq = theta ** -(torch.arange(0, dim, 2).float() / dim)
        self.register_buffer('inv_freq', inv_freq)

    @property
    def device(self):
        return self.inv_freq.device

    @autocast(enabled = False)
    def forward(
        self,
        seq_len: int,
        offset = 0
    ):
        """Return rotary angles of shape ``(num_positions, 2 * len(inv_freq))``.

        Striped layout: local index ``(b, n)`` (bucket-major after the final
        rearrange) maps to global position
        ``n * world_size * buckets + rank * buckets + b``.

        NOTE(review): ``offset`` and the local ``ring_offset`` are unused.
        """
        device = self.device
        pos = None

        if self.ring:
            if self.striped:
                buckets = self.buckets
                ring_stride = get_world_size() * buckets
                ring_offset = buckets

                pos = torch.arange(seq_len // buckets, device = device)
                pos = rearrange('n -> n b', pos, b = buckets)

                pos = pos * ring_stride
                pos += torch.arange(buckets, device = device) + (get_rank() * buckets)
                pos = rearrange('n b -> (b n)', pos)

            else:
                # contiguous sharding: this rank holds the rank-th slice of seq_len tokens
                pos = torch.arange(seq_len, device = device)
                pos += seq_len * get_rank()
        else:
            pos = torch.arange(seq_len, device = device)

        pos = pos.type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', pos, self.inv_freq)
        return torch.cat((freqs, freqs), dim = -1)

# Rotary helper: swap halves of the last dimension and negate the new first half.
def rotate_half(x):
    """Map ``[x1, x2]`` (halves of the last dim) to ``[-x2, x1]``."""
    first, second = x.chunk(2, dim = -1)
    return torch.cat([second.neg(), first], dim = -1)

@autocast(enabled = False)
def apply_rotary_pos_emb(pos, t):
    pos = rearrange('n d -> n 1 d', pos)
    return t * pos.cos() + rotate_half(t) * pos.sin()

# 批量到序列分片和反向操作

# Right-pad the last dimension up to a multiple of `length`.
def pad_to_multiple(
    x: Tensor,
    length: int,
    pad_value = 0
):
    """Pad the last dim of ``x`` with ``pad_value`` up to the next multiple of ``length``.

    Returns:
        ``(padded, pad_length)``. When no padding is needed, ``x`` itself is
        returned with 0.
    """
    remainder = x.shape[-1] % length
    if not remainder:
        return x, 0

    pad_length = length - remainder
    padded = F.pad(x, (0, pad_length), value = pad_value)
    return padded, pad_length

# Pad sequence (and mask) to a multiple of seq_size when needed.
def maybe_pad_seq_and_mask(
    x: Tensor,
    mask: Optional[Tensor],
    seq_size: int
):
    """Pad ``x`` along its last dim to a multiple of ``seq_size`` and keep ``mask`` in sync.

    Ring passing assumes every shard has the same shape, hence the padding. When
    padding is added and no ``mask`` was given, an all-True mask shaped like the
    unpadded ``x`` is created. The mask is then padded with ``False``. When no
    padding is needed, ``mask`` is returned unchanged (possibly None).
    """
    original = x
    padded, num_padded = pad_to_multiple(x, seq_size)

    if num_padded == 0:
        return padded, mask

    base_mask = mask if exists(mask) else torch.ones_like(original).bool()
    padded_mask, _ = pad_to_multiple(base_mask, seq_size, pad_value = False)

    return padded, padded_mask
def sharded_batch_to_sharded_seq(
    x: Tensor,
    mask: Optional[Tensor],
    seq_size: int
):
    """
    Turn batch-sharded inputs into sequence shards for ring attention.

    ``x`` (and ``mask``, if given) are all-gathered across ranks along the
    batch dim. Each group of ``num_sharded_batches`` consecutive batch rows is
    then laid end to end along the last dim, the result is cut into
    ``seq_size`` chunks, and this rank's chunk is selected via split_by_rank.

    Returns ``((x_shard, mask_shard_or_None), sizes, num_sharded_batches)``.
    ``sizes`` is what AllGather reports; sharded_seq_to_sharded_batch uses it
    to split the batch back per rank.
    """
    assert is_distributed()

    gather_batch = AllGather(dim = 0)

    x, sizes = gather_batch(x)

    if exists(mask):
        mask, _ = gather_batch(mask)

    # the world size must be divisible by the number of seq_size chunks per row
    world_size = get_world_size()
    total_split_seq = x.shape[-1] // seq_size
    assert divisible_by(world_size, total_split_seq)

    num_sharded_batches = world_size // total_split_seq

    def to_rank_shard(t):
        # fold batch rows into one long row, chunk it, keep this rank's chunk
        folded = rearrange('(b s) n -> b (s n)', t, s = num_sharded_batches)
        shard, _ = split_by_rank(folded.split(seq_size, dim = -1))
        return shard

    x = to_rank_shard(x)

    if exists(mask):
        mask = to_rank_shard(mask)

    return (x, mask), sizes, num_sharded_batches

def sharded_seq_to_sharded_batch(
    logits: Tensor,
    sizes,
    num_sharded_batches = 1
):
    """
    Undo sharded_batch_to_sharded_seq for per-token outputs.

    All-gathers ``logits`` along the sequence dim (-2), unfolds each long
    row back into ``num_sharded_batches`` batch rows, splits the batch dim
    by ``sizes`` and returns this rank's piece (selected via split_by_rank).
    """
    gathered, _ = AllGather(dim = -2)(logits)

    unfolded = rearrange('b (s n) c -> (b s) n c', gathered, s = num_sharded_batches)

    per_rank = unfolded.split(sizes.tolist(), dim = 0)
    local, _ = split_by_rank(per_rank)

    return local

# Main RingAttention module
class RingAttention(Module):
    """
    Multi-head self-attention layer supporting three execution paths:
    regular attention (``force_regular_attn``), the flash attention CUDA
    kernel, or the PyTorch flash attention implementation. The flash paths
    can optionally perform ring reduction across distributed ranks.

    Input ``x`` is (batch, seq, dim). q/k/v are split into heads as
    (batch, seq, heads, dim_head).
    """

    @beartype
    def __init__(
        self,
        dim: int,
        *,
        dim_head: int = 64,
        heads: int = 8,
        causal: bool = False,
        eps: float = 1e-10,
        bucket_size: int = 512,
        ring_attn: bool = False,
        ring_seq_size: int = 512,
        max_lookback_seq_len: Optional[int] = None,
        striped_ring_attn: bool = False,
        auto_shard_seq: Optional[bool] = None,
        prenorm: bool = True,
        force_regular_attn: bool = False,
        rotary_embed: bool = False,
        rotary_embed_theta: int = 10000,
        use_cuda_kernel: bool = None
    ):
        super().__init__()
        self.eps = eps
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal

        # each ring shard must consist of whole flash-attention buckets
        assert divisible_by(ring_seq_size, bucket_size)

        self.ring_attn = ring_attn
        self.max_lookback_seq_len = max_lookback_seq_len
        self.striped_ring_attn = striped_ring_attn

        self.force_regular_attn = force_regular_attn

        # sharding should really be done at the transformer level on token ids; this is here for testing
        self.auto_shard_seq = default(auto_shard_seq, ring_attn)

        # auto sharding only makes sense together with ring attention
        assert not (not self.ring_attn and self.auto_shard_seq)

        self.ring_seq_size = ring_seq_size
        self.bucket_size = bucket_size

        # optional rotary embedding that accounts for ring / striped layouts
        self.rotary_embed = None
        if rotary_embed:
            self.rotary_embed = RingRotaryEmbedding(
                dim = dim_head,
                ring = ring_attn,
                striped = striped_ring_attn,
                theta = rotary_embed_theta,
                buckets = ring_seq_size // bucket_size
            )

        # projections
        dim_inner = dim_head * heads

        self.to_qkv = nn.Sequential(
            RMSNorm(dim) if prenorm else nn.Identity(),
            nn.Linear(dim, dim_inner * 3, bias = False)
        )

        self.to_out = nn.Linear(dim_inner, dim, bias = False)

        # whether to use the flash attention CUDA kernel (defaults to CUDA availability)
        self.use_cuda_kernel = default(use_cuda_kernel, torch.cuda.is_available())
        assert not (use_cuda_kernel and not torch.cuda.is_available())

    def forward(
        self,
        x,
        mask = None,
        rotary_emb = None,
        force_ring_reduce_off = False,
        ring_size = None,
        ):
        """
        einstein notation

        b - batch
        h - heads
        d - feature dimension
        n, i, j - sequence

        Args:
            x: (b, n, d) input features.
            mask: optional (b, n) mask.
            rotary_emb: optional precomputed rotary frequencies for the local
                sequence. When omitted and the layer owns a rotary embedding,
                they are computed here.
            force_ring_reduce_off: disable ring reduction for this call.
            ring_size: ring size; defaults to the world size.

        Returns:
            (b, n, d) output.
        """
        ring_size = default(ring_size, get_world_size())
        ring_attn = self.ring_attn & is_distributed()
        auto_shard_seq = self.auto_shard_seq & is_distributed()

        # x is (b, n, d): the sequence length is dim -2 (dim -1 is the feature dim)
        seq_len = x.shape[-2]

        if auto_shard_seq:
            # NOTE(review): maybe_pad_seq_and_mask / sharded_batch_to_sharded_seq pad and split
            # along the LAST dim and rearrange '(b s) n', i.e. they appear to assume 2-D token
            # tensors. For (b, n, d) features this path likely still needs helper support for a
            # trailing feature dim -- TODO confirm before relying on auto_shard_seq here.
            x, mask = maybe_pad_seq_and_mask(x, mask, self.ring_seq_size)

            # striped layout for workload balancing
            if self.striped_ring_attn:
                x = rearrange('b (i j) d -> b (j i) d', x, i = self.bucket_size)

                if exists(mask):
                    mask = rearrange('b (i j) -> b (j i)', mask, i = self.bucket_size)

            # the helper returns ((x, mask), sizes, num_sharded_batches)
            (x, mask), batch_sizes, num_sharded_batches = sharded_batch_to_sharded_seq(x, mask, self.ring_seq_size)

            # ring size depends on how many batch groups share the world (as in RingTransformer)
            ring_size = get_world_size() // num_sharded_batches

        qkv = self.to_qkv(x)
        q, k, v = rearrange('b n (qkv h d) -> qkv b n h d', qkv, qkv = 3, h = self.heads)

        # rotary relative positions; q is (b, n, h, d), so the sequence length is dim -3
        if not exists(rotary_emb) and exists(self.rotary_embed):
            rotary_emb = self.rotary_embed(q.shape[-3])

        if exists(rotary_emb):
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        # regular attention vs flash attention (with or without kv ring reduction)
        any_cuda_inputs = any(t.is_cuda for t in (q, k, v))
        ring_reduce = not force_ring_reduce_off

        if self.force_regular_attn:
            out = default_attention(q, k, v, mask = mask, causal = self.causal)
        else:
            if any_cuda_inputs and self.use_cuda_kernel:
                # imported lazily: the CUDA module asserts flash-attn / triton at import time
                from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda
                attn_fn = ring_flash_attn_cuda
            else:
                attn_fn = ring_flash_attn

            # both implementations take the same positional arguments
            out = attn_fn(
                q, k, v,
                mask,
                self.causal,
                self.bucket_size,
                ring_attn and ring_reduce,
                self.striped_ring_attn and ring_reduce,
                self.max_lookback_seq_len,
                ring_size
            )

        # merge heads
        out = rearrange('b n h d -> b n (h d)', out)
        out = self.to_out(out)

        if auto_shard_seq:
            # gather sequence shards back into this rank's batch (returns a single tensor)
            out = sharded_seq_to_sharded_batch(out, batch_sizes, num_sharded_batches)

            # undo the striping applied above, as RingTransformer does for its logits
            if self.striped_ring_attn:
                out = rearrange('b (i j) d -> b (j i) d', out, j = self.bucket_size)

            # drop right padding
            out = out[:, :seq_len]

        return out
# 定义一个简单的端到端测试的转换器

class RMSNorm(Module):
    """
    Normalization layer: L2-normalize the last dimension, rescale by
    sqrt(dim), and apply a learned per-feature gain ``gamma`` (init 1).
    """

    def __init__(self, dim):
        super().__init__()
        # sqrt(dim) turns the unit-norm vector into one with unit RMS
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return unit * self.scale * self.gamma

# Pre-norm feed-forward block
def FeedForward(dim, mult = 4):
    """
    Build RMSNorm -> Linear(dim, dim * mult) -> GELU -> Linear(dim * mult, dim)
    as an nn.Sequential (layer order is kept so state_dict keys are stable).
    """
    hidden = int(dim * mult)
    layers = [
        RMSNorm(dim),
        nn.Linear(dim, hidden),
        nn.GELU(),
        nn.Linear(hidden, dim),
    ]
    return nn.Sequential(*layers)

# Transformer built from RingAttention layers (simple end-to-end test model)
class RingTransformer(Module):
    """
    Token embedding -> depth x [RingAttention + FeedForward] (residual) -> RMSNorm + linear logits.

    With ``auto_shard_seq`` (default: same as ``ring_attn``) and a distributed
    run, token ids are padded, optionally striped, and sharded across ranks
    along the sequence dimension. Logits are gathered back at the end.
    """

    @beartype
    def __init__(
        self,
        *,
        num_tokens: int,
        dim: int,
        depth: int,
        causal: bool = False,
        dim_head: int = 64,
        heads: int = 8,
        ff_mult: int = 4,
        bucket_size: int = 512,
        ring_attn: bool = False,
        striped_ring_attn: bool = False,
        ring_seq_size: int = 512,
        auto_shard_seq: Optional[bool] = None,
        max_lookback_seq_len: Optional[Union[Tuple[int, ...], int]] = None,
        rotary_embed_theta: int = 10000,    # may need changing for contexts of millions of tokens
        ignore_index: int = -1
    ):
        super().__init__()
        # ring attention settings
        self.ring_attn = ring_attn
        self.striped_ring_attn = striped_ring_attn

        self.ring_seq_size = ring_seq_size
        self.bucket_size = bucket_size
        assert divisible_by(ring_seq_size, bucket_size)

        # with ring attention on, shard along the sequence automatically; can be turned off and done in data loading instead
        self.auto_shard_seq = default(auto_shard_seq, ring_attn)

        assert not (not self.ring_attn and self.auto_shard_seq)
        assert not (not self.ring_attn and self.striped_ring_attn)
        assert not (self.striped_ring_attn and not causal), 'striped ring attention only applies to autoregressive models'

        # token embedding
        self.token_emb = nn.Embedding(num_tokens, dim)

        # rotary embedding shared by all layers
        self.rotary_emb = RingRotaryEmbedding(
            dim = dim_head,
            ring = ring_attn,
            striped = striped_ring_attn,
            theta = rotary_embed_theta,
            buckets = ring_seq_size // bucket_size
        )

        # layers
        self.layers = ModuleList([])

        max_lookback_seq_len = cast_tuple(max_lookback_seq_len, depth)
        assert len(max_lookback_seq_len) == depth

        for layer_max_lookback_seq_len in max_lookback_seq_len:

            self.layers.append(ModuleList([
                RingAttention(
                    dim = dim,
                    causal = causal,
                    dim_head = dim_head,
                    heads = heads,
                    bucket_size = bucket_size,
                    ring_attn = ring_attn,
                    ring_seq_size = ring_seq_size,
                    max_lookback_seq_len = layer_max_lookback_seq_len,
                    striped_ring_attn = striped_ring_attn,
                    auto_shard_seq = False,
                ),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        # output head
        self.to_logits = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, num_tokens, bias = False)
        )

        # training

        self.ignore_index = ignore_index

    def forward(
        self,
        x,
        mask = None,
        labels = None,
        return_loss = False,
        force_ring_reduce_off = False,
        ring_size = None
        ):
        """
        Args:
            x: (b, n) token ids.
            mask: optional (b, n) boolean mask.
            labels: optional targets. If omitted while ``return_loss`` is True,
                ``x`` is split into inputs ``x[:, :-1]`` and labels ``x[:, 1:]``.
            return_loss: return the cross-entropy loss (implied by ``labels``).
            force_ring_reduce_off: disable ring reduction and auto sharding.
            ring_size: ring size; defaults to the world size.

        Returns:
            Scalar cross-entropy loss, or logits of shape (b, n, num_tokens).
        """
        seq_len = x.shape[-1]

        auto_shard_seq = not force_ring_reduce_off and self.auto_shard_seq and is_distributed()

        # passing labels implies returning the loss
        return_loss |= exists(labels)

        # derive autoregressive inputs / labels from x
        if return_loss and not exists(labels):
            x, labels = x[:, :-1], x[:, 1:]

        ring_size = default(ring_size, get_world_size())

        if auto_shard_seq:
            # pad on the right to a multiple of the ring sequence size
            x, mask = maybe_pad_seq_and_mask(x, mask, self.ring_seq_size)

            if exists(labels):
                if exists(mask):
                    # NOTE(review): when `mask` was synthesized just above (aligned with x),
                    # the [:, 1:] shift appears to mark the last real label as ignored --
                    # confirm the intended alignment of user masks vs. labels.
                    labels, label_mask = maybe_pad_seq_and_mask(labels, mask[:, 1:], self.ring_seq_size)

                    # out-of-place: labels may be a view of the caller's tensor (e.g. x[:, 1:])
                    labels = labels.masked_fill(~label_mask, self.ignore_index)

                # else: no mask exists, so x needed no padding; labels are expected to
                # match x's length and need neither padding nor masking

            # striped layout for workload balancing
            if self.striped_ring_attn:
                x = rearrange('b (i j) -> b (j i)', x, i = self.bucket_size)

                if exists(labels):
                    labels = rearrange('b (i j) -> b (j i)', labels, i = self.bucket_size)

                if exists(mask):
                    mask = rearrange('b (i j) -> b (j i)', mask, i = self.bucket_size)

            # gather across the batch and split across the world along the sequence
            (x, mask), batch_sizes, num_sharded_batches = sharded_batch_to_sharded_seq(x, mask, self.ring_seq_size)

            if exists(labels):
                (labels, _), *_ = sharded_batch_to_sharded_seq(labels, None, self.ring_seq_size)

            # ring size depends on the number of sharded batches
            ring_size = get_world_size() // num_sharded_batches

        # rotary positions, aware of ring and striping
        rotary_emb = self.rotary_emb(x.shape[-1])

        # main transformer
        x = self.token_emb(x)

        for attn, ff in self.layers:
            x = attn(
                x,
                mask = mask,
                rotary_emb = rotary_emb,
                force_ring_reduce_off = force_ring_reduce_off,
                ring_size = ring_size
            ) + x

            x = ff(x) + x

        logits = self.to_logits(x)

        if return_loss:
            logits = rearrange('b n c -> b c n', logits)

            ce_loss = F.cross_entropy(
                logits,
                labels,
                ignore_index = self.ignore_index
            )

            return ce_loss

        if not auto_shard_seq:
            return logits

        # gather sequence shards from all machines and re-shard along the batch
        logits = sharded_seq_to_sharded_batch(logits, batch_sizes, num_sharded_batches)

        # undo striping
        if self.striped_ring_attn:
            logits = rearrange('b (i j) d -> b (j i) d', logits, j = self.bucket_size)

        return logits[:, :seq_len]

.\lucidrains\ring-attention-pytorch\ring_attention_pytorch\ring_flash_attention.py

# 导入数学库
import math
# 导入 functools 库中的 partial 函数
from functools import partial
# 导入 typing 库中的 Optional 类型
from typing import Optional

# 导入 torch 库
import torch
# 从 torch 库中导入 nn、einsum、Tensor 类
from torch import nn, einsum, Tensor
# 从 torch.autograd.function 中导入 Function 类
from torch.autograd.function import Function

# 导入 einx 库
import einx
# 从 einx 库中导入 rearrange 函数
from einx import rearrange

# 导入 ring_attention_pytorch.ring 模块中的函数
from ring_attention_pytorch.ring import (
    ring_pass,
    all_ring_pass,
    null_ring_pass,
    one_ring_pass,
    get_rank,
    get_world_size
)

# 导入 beartype 库中的 beartype 装饰器
from beartype import beartype

# Constants
# small numerical epsilon (not referenced in the visible portion of this file)
EPSILON = 1e-10

# 辅助函数

# True when a value is present (not None)
def exists(val):
    """Return True unless ``val`` is None (falsy values such as 0 still count)."""
    if val is None:
        return False
    return True

# Fall back to a default when a value is missing
def default(val, d):
    """Return ``val`` if it exists (is not None), otherwise ``d``."""
    if exists(val):
        return val
    return d

# Divisibility check
def divisible_by(num, den):
    """Return whether ``num`` is an exact multiple of ``den``."""
    remainder = num % den
    return remainder == 0

# Endless stream of None
def none_iterator():
    """Generator that yields ``None`` forever."""
    placeholder = None
    while True:
        yield placeholder

# Split an optional tensor into chunks
def maybe_split(t, size, dim = -2):
    """
    Split ``t`` into chunks of ``size`` along ``dim`` (default -2).

    If ``t`` is None, return an endless iterator of None instead, presumably
    so callers can iterate it in lockstep with real chunks.
    """
    if exists(t):
        return t.split(size, dim = dim)

    return none_iterator()

# ring + (flash) attention 前向和后向

# flash attention v1 - https://arxiv.org/abs/2205.14135
# flash attention v2 - https://tridao.me/publications/flash2/flash2.pdf
# ring attention - https://arxiv.org/abs/2310.01889

# Autograd Function implementing ring + flash attention
class RingFlashAttentionFunction(Function):

    # forward pass, executed under torch.no_grad()
    @staticmethod
    @torch.no_grad()
    def forward(
        ctx,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Optional[Tensor],
        causal: bool,
        bucket_size: int,
        ring_reduce_col: bool,
        striped_ring_attn: bool,
        max_lookback_seq_len: Optional[int],
        ring_size: Optional[int]
        # NOTE(review): this listing is truncated here -- the end of forward's signature,
        # its body, and the body of the method decorated below (presumably backward) are
        # missing, so this span is not valid Python as shown. Restore from upstream.
    @staticmethod
    @torch.no_grad()
# Expose the autograd Function's `apply` as a plain callable
ring_flash_attn_ = RingFlashAttentionFunction.apply

# Type-checked functional entry point for ring flash attention (PyTorch implementation)
@beartype
def ring_flash_attn(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Optional[Tensor] = None,
    causal: bool = False,
    bucket_size: int = 1024,
    ring_reduce_col: bool = False,
    striped_ring_attn: bool = False,
    max_lookback_seq_len: Optional[int] = None,
    ring_size: Optional[int] = None
):
    """
    Validate argument types (via beartype) and forward everything,
    positionally and in signature order, to ``RingFlashAttentionFunction.apply``.
    """
    forwarded = (
        q, k, v,
        mask,
        causal,
        bucket_size,
        ring_reduce_col,
        striped_ring_attn,
        max_lookback_seq_len,
        ring_size,
    )
    return ring_flash_attn_(*forwarded)

.\lucidrains\ring-attention-pytorch\ring_attention_pytorch\ring_flash_attention_cuda.py

# 导入数学库
import math
# 导入 functools 库中的 partial 函数
from functools import partial
# 导入 typing 库中的 Optional 和 Tuple 类型
from typing import Optional, Tuple
# 导入 packaging 库中的 version 模块
import packaging.version as pkg_version

# 导入 torch 库
import torch
# 从 torch 库中导入 nn, einsum, Tensor 模块
from torch import nn, einsum, Tensor
# 从 torch 库中导入 F 模块
import torch.nn.functional as F
# 从 torch.autograd.function 中导入 Function 类
from torch.autograd.function import Function

# 从 ring_attention_pytorch.ring 模块中导入相关函数
from ring_attention_pytorch.ring import (
    ring_pass,
    all_ring_pass,
    null_ring_pass,
    one_ring_pass,
    get_rank,
    get_world_size
)

# 从 beartype 库中导入 beartype 函数
from beartype import beartype

# 从 einops 库中导入 repeat, rearrange 函数
from einops import repeat, rearrange

# True when a value is present (not None)
def exists(v):
    """Return True unless ``v`` is None."""
    return not (v is None)

# Pad a tensor on both sides of a single dimension
def pad_at_dim(t, pad: Tuple[int, int], *, dim = -1, value = 0.):
    """
    Pad ``t`` by ``pad = (before, after)`` along ``dim`` with ``value``.

    F.pad takes (before, after) pairs starting from the LAST dimension, so
    one (0, 0) pair is emitted for every dimension to the right of ``dim``.
    """
    axis = dim if dim >= 0 else t.ndim + dim
    trailing_dims = t.ndim - axis - 1
    padding = (0,) * (2 * trailing_dims) + tuple(pad)
    return F.pad(t, padding, value = value)

# Check that the innermost dimension is densely packed
def is_contiguous(x):
    """
    Return True when the last dimension of ``x`` has unit stride.

    This only checks the innermost stride, not full tensor contiguity.
    """
    innermost_stride = x.stride(-1)
    return innermost_stride == 1

# Make sure flash-attn is installed (its backward kernels are imported below)
import importlib
# `import importlib` alone does not guarantee the `importlib.util` submodule is
# loaded, so import it explicitly before calling `importlib.util.find_spec`
import importlib.util
from importlib.metadata import version

# flash-attn must be installed
assert exists(importlib.util.find_spec('flash_attn')), 'flash-attn must be installed. `pip install flash-attn --no-build-isolation` first'

# flash-attn must be version 2.5.1 or newer
flash_attn_version = version('flash_attn')
assert pkg_version.parse(flash_attn_version) >= pkg_version.parse('2.5.1'), f'flash-attn>=2.5.1 is required, found {flash_attn_version}'

# backward kernels from flash-attn
from flash_attn.flash_attn_interface import (
    _flash_attn_varlen_backward,
    _flash_attn_backward
)

# Make sure triton is installed (used for the forward kernel)
assert exists(importlib.util.find_spec('triton')), 'latest triton must be installed. `pip install triton -U` first'

# triton must be version 2.1 or newer
triton_version = version('triton')
assert pkg_version.parse(triton_version) >= pkg_version.parse('2.1'), f'triton>=2.1 is required, found {triton_version}'

import triton

import triton.language as tl

# Flash attention forward kernel taken from Tri Dao's flash_attn repo, modified to
# return the unnormalized accumulated output, the row max and the row lse, so that
# partial results can be reduced as they are passed around the ring.
#
# One program instance handles a BLOCK_M tile of query rows for one (batch, head)
# pair: program_id(0) selects the query tile, program_id(1) the flattened
# batch * nheads index.
#
# EVEN_M / EVEN_N: seqlen_q / seqlen_k are multiples of BLOCK_M / BLOCK_N, so the
#   boundary masks on loads/stores can be skipped.
# EVEN_HEADDIM: headdim == BLOCK_HEADDIM, so the feature dim needs no masking.
# LOAD_ACCUMULATED: resume from the output / row max / lse already stored in
#   Out, M and Lse instead of starting from scratch.
# RETURN_NORMALIZED_OUTPUT: rescale the output by exp(m - lse) at the end and
#   do not write M back.
# CAUSAL_MASK_DIAGONAL: use a strict `>` causal comparison, so the diagonal is
#   masked as well.

@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    Bias,
    Out,
    M,
    Lse,
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_bb,
    stride_bh,
    stride_bm,
    stride_ob,
    stride_oh,
    stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    HAS_BIAS: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    CAUSAL_MASK_DIAGONAL: tl.constexpr,
    LOAD_ACCUMULATED: tl.constexpr,
    RETURN_NORMALIZED_OUTPUT: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # which query tile and which (batch, head) this program handles
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads

    # query-row offsets of this tile, key offsets within one key block, feature offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    q_ptrs = (
        Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    )

    # bias pointers advance only along keys (offs_n); stride_bm is never used,
    # so a single bias value per key is applied to every query row
    if HAS_BIAS:
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n

    # running row max, stored in M as a flat (batch * nheads, seqlen_q_rounded) buffer
    # NOTE(review): M / Lse loads are unmasked -- presumably safe because the buffers are
    # padded to seqlen_q_rounded; confirm it is always a multiple of BLOCK_M

    m_ptrs = M + off_hb * seqlen_q_rounded + offs_m

    if LOAD_ACCUMULATED:
        m_i = tl.load(m_ptrs)
    else:
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

    # running log-sum-exp, same layout as M

    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m

    if LOAD_ACCUMULATED:
        lse_i = tl.load(lse_ptrs)
    else:
        # start from -inf (empty softmax denominator), float32 of shape [BLOCK_M]
        lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

    # feature offsets for the output (same values as computed above)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # output pointers for this query tile
    out_ptrs = (
        Out
        + off_b * stride_ob
        + off_h * stride_oh
        + (offs_m[:, None] * stride_om + offs_d[None, :])
    )

    # resume from the previously accumulated (unnormalized) output, or start at zero
    if LOAD_ACCUMULATED:
        # seqlen_q is a multiple of BLOCK_M: no row mask needed
        if EVEN_M:
            # headdim == BLOCK_HEADDIM: no feature mask needed
            if EVEN_HEADDIM:
                acc_o = tl.load(out_ptrs)
            else:
                acc_o = tl.load(out_ptrs, mask=offs_d[None, :] < headdim)
        else:
            # rows beyond seqlen_q must be masked
            if EVEN_HEADDIM:
                acc_o = tl.load(out_ptrs, mask=offs_m[:, None] < seqlen_q)
            else:
                acc_o = tl.load(
                    out_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
                )

        acc_o = acc_o.to(tl.float32)
    else:
        # fresh float32 accumulator of shape [BLOCK_M, BLOCK_HEADDIM]
        acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)

    # load the query tile (out-of-range rows / features read as 0)
    if EVEN_M & EVEN_N:
        # no row masking needed
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        # mask rows beyond seqlen_q
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(
                q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
            )

    # with causal masking, key blocks entirely past the end of this query tile are
    # skipped (query and key positions are compared on the same index axis)
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    # iterate over key / value blocks of BLOCK_N keys
    for start_n in range(0, end_n, BLOCK_N):
        # compiler hint: start_n is a multiple of BLOCK_N
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # load the key block (out-of-range keys / features read as 0)
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        # raw scores q @ k^T for this block, in float32
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))

        # mask key positions beyond seqlen_k
        if not EVEN_N:
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))

        # causal masking
        if IS_CAUSAL:
            if CAUSAL_MASK_DIAGONAL:
                # strict comparison also masks the diagonal (needed for striped attention)
                qk += tl.where(offs_m[:, None] > (start_n + offs_n)[None, :], 0, float("-inf"))
            else:
                qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))

        # scale scores (adding the per-key bias if present), take the new row max
        # against the running lse_i (mirrors the upstream Triton flash-attn kernel),
        # and exponentiate relative to it
        if HAS_BIAS:
            if EVEN_N:
                bias = tl.load(b_ptrs + start_n)
            else:
                bias = tl.load(
                    b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
                )
            bias = bias[None, :]

            bias = bias.to(tl.float32)
            qk = qk * softmax_scale + bias
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])

        # row sums of the unnormalized probabilities for this block
        l_ij = tl.sum(p, 1)

        # rescale the existing accumulator from the old row max to the new one
        acc_o_scale = tl.exp(m_i - m_ij)
        acc_o = acc_o * acc_o_scale[:, None]

        # load the value block (out-of-range keys / features read as 0)
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )

        # cast probabilities to v's dtype for the matmul, then accumulate p @ v
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # -- update running statistics (row max and log-sum-exp)

        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)

    # normalized output = unnormalized accumulator * exp(m - lse)
    if RETURN_NORMALIZED_OUTPUT:
        acc_o_scale = tl.exp(m_i - lse_i)
        acc_o = acc_o * acc_o_scale[:, None]

    # recompute start_m / offs_m (same values as at the top; offs_m is used by the store masks below)

    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # write back lse (always) and the row max (only when the output stays unnormalized)

    tl.store(lse_ptrs, lse_i)

    if not RETURN_NORMALIZED_OUTPUT:
        tl.store(m_ptrs, m_i)

    # write the output tile

    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )
# Flash attention forward pass (Triton) that can accumulate into existing (o, m, lse)
def flash_attn_forward(
    q,
    k,
    v,
    bias = None,
    causal = False,
    o = None,
    m = None,
    lse = None,
    softmax_scale = None,
    causal_mask_diagonal = False,
    return_normalized_output = False,
    load_accumulated = True
):
    """
    Launch ``_fwd_kernel`` on q, k, v laid out as (batch, seq, heads, dim_head).

    Args:
        q: (batch, seqlen_q, nheads, d) fp16/bf16 CUDA tensor, d <= 128.
        k, v: (batch, seqlen_k, nheads, d), same dtype as q.
        bias: optional additive per-key bias, either (batch, seqlen_k) or
            (batch, nheads, 1, seqlen_k), dtype q.dtype or float32. The kernel
            reads one bias value per key (it never uses the query-row stride),
            so the bias is broadcast over query rows.
        causal: apply causal masking inside the kernel.
        o, m, lse: optional output / running row-max / log-sum-exp buffers.
            When ``load_accumulated`` is True the kernel reads them and
            continues from their contents; missing ones are initialized to
            zeros (o) or a very negative value (m, lse).
        softmax_scale: defaults to d ** -0.5.
        causal_mask_diagonal: use a strict causal comparison (diagonal masked too).
        return_normalized_output: rescale the output by exp(m - lse) in the
            kernel; m is then not written back.
        load_accumulated: if False, missing buffers are allocated uninitialized
            (torch.empty / torch.empty_like).

    Returns:
        (o, m, lse). m and lse are float32 of shape
        (batch, nheads, ceil(seqlen_q / 128) * 128).
    """
    # the kernel assumes unit stride on the last dim
    q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]

    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape

    assert k.shape == (batch, seqlen_k, nheads, d)
    assert v.shape == (batch, seqlen_k, nheads, d)
    assert d <= 128, "FlashAttention only support head dimensions up to 128"
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
    assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.is_cuda and k.is_cuda and v.is_cuda

    softmax_scale = default(softmax_scale, d ** -0.5)

    has_bias = exists(bias)

    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda

        # A (batch, seqlen_k) bias is lifted to (batch, 1, 1, seqlen_k) and broadcast
        # with `expand` below. Repeating it over heads and query rows would produce
        # shape[-2:] == (seqlen_q, seqlen_k) and always trip the assert below unless
        # seqlen_q == 1; the kernel only reads one value per key anyway.
        if bias.ndim == 2:
            bias = rearrange(bias, 'b j -> b 1 1 j')

        if not is_contiguous(bias):
            bias = bias.contiguous()

        assert bias.shape[-2:] == (1, seqlen_k)
        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)

    # after expand, broadcast dims have stride 0, so heads share one bias row
    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)

    # m / lse buffers are padded to a multiple of 128 along the query dim
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128

    if not exists(lse):
        max_neg_value = -torch.finfo(torch.float32).max
        init_fn = partial(torch.full, fill_value = max_neg_value) if load_accumulated else torch.empty
        lse = init_fn((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)

    if not exists(m):
        max_neg_value = -torch.finfo(torch.float32).max
        init_fn = partial(torch.full, fill_value = max_neg_value) if load_accumulated else torch.empty
        m = init_fn((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)

    if not exists(o):
        init_fn = torch.zeros_like if load_accumulated else torch.empty_like
        o = init_fn(q)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    BLOCK = 128
    num_warps = 4 if d <= 64 else 8
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)

    # strides are passed as (batch, head, seq) even though the layout is (batch, seq, head, dim)
    _fwd_kernel[grid](
        q,
        k,
        v,
        bias,
        o,
        m,
        lse,
        softmax_scale,
        q.stride(0),
        q.stride(2),
        q.stride(1),
        k.stride(0),
        k.stride(2),
        k.stride(1),
        v.stride(0),
        v.stride(2),
        v.stride(1),
        *bias_strides,
        o.stride(0),
        o.stride(2),
        o.stride(1),
        nheads,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        d,
        seqlen_q // 32,
        seqlen_k // 32,
        has_bias,
        causal,
        causal_mask_diagonal,
        load_accumulated,
        return_normalized_output,
        BLOCK_HEADDIM,
        BLOCK_M = BLOCK,
        BLOCK_N = BLOCK,
        num_warps = num_warps,
        num_stages = 1,
    )

    return o, m, lse

# helpers

# True when a value is present (not None)
def exists(val):
    """Return True unless ``val`` is None."""
    if val is None:
        return False
    return True

# Fall back to a default when a value is missing
def default(val, d):
    """Return ``val`` if it exists (is not None), otherwise ``d``."""
    if not exists(val):
        return d
    return val

# Divisibility check
def divisible_by(num, den):
    """Return whether ``num`` leaves no remainder when divided by ``den``."""
    leftover = num % den
    return leftover == 0

# ring + (flash) attention forwards and backwards

# flash attention v1 - https://arxiv.org/abs/2205.14135
# flash attention v2 - https://tridao.me/publications/flash2/flash2.pdf
# ring attention - https://arxiv.org/abs/2310.01889

# Autograd Function implementing ring + flash attention on CUDA (Triton forward, flash-attn backward kernels)
class RingFlashAttentionCUDAFunction(Function):

    # forward pass, executed under torch.no_grad()
    @staticmethod
    @torch.no_grad()
    def forward(
        ctx,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Optional[Tensor],
        causal: bool,
        bucket_size: int,
        ring_reduce_col: bool,
        striped_ring_attn: bool,
        max_lookback_seq_len: Optional[int],
        ring_size: Optional[int]
        # NOTE(review): this listing is truncated here -- the end of forward's signature,
        # its body, and the body of the method decorated below (presumably backward) are
        # missing, so this span is not valid Python as shown. Restore from upstream.
    @staticmethod
    @torch.no_grad()
# Expose the CUDA autograd Function's `apply` as a plain callable
ring_flash_attn_cuda_ = RingFlashAttentionCUDAFunction.apply

# Type-checked functional entry point for CUDA ring flash attention
@beartype
def ring_flash_attn_cuda(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Optional[Tensor] = None,
    causal: bool = False,
    bucket_size: int = 1024,
    ring_reduce_col: bool = False,
    striped_ring_attn: bool = False,
    max_lookback_seq_len: Optional[int] = None,
    ring_size: Optional[int] = None
):
    """
    Validate argument types (via beartype) and forward everything, positionally
    and in signature order, to ``RingFlashAttentionCUDAFunction.apply``.
    """
    forwarded = (
        q, k, v,
        mask,
        causal,
        bucket_size,
        ring_reduce_col,
        striped_ring_attn,
        max_lookback_seq_len,
        ring_size,
    )
    return ring_flash_attn_cuda_(*forwarded)

.\lucidrains\ring-attention-pytorch\ring_attention_pytorch\__init__.py

# 从ring_attention_pytorch.ring_attention模块中导入RingAttention、RingTransformer、RingRotaryEmbedding、apply_rotary_pos_emb、default_attention等类或函数
from ring_attention_pytorch.ring_attention import (
    RingAttention,
    RingTransformer,
    RingRotaryEmbedding,
    apply_rotary_pos_emb,
    default_attention
)

# 从ring_attention_pytorch.ring_flash_attention模块中导入ring_flash_attn、ring_flash_attn_等函数
from ring_attention_pytorch.ring_flash_attention import (
    ring_flash_attn,
    ring_flash_attn_
)

.\lucidrains\ring-attention-pytorch\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# Package metadata for ring-attention-pytorch.
setup(
    name = 'ring-attention-pytorch',
    packages = find_packages(exclude = []),
    version = '0.2.14',
    license = 'MIT',
    description = 'Ring Attention - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    long_description_content_type = 'text/markdown',
    url = 'https://github.com/lucidrains/ring-attention-pytorch',
    # package keywords
    keywords = [
        'artificial intelligence',
        'deep learning',
        'distributed attention',
    ],
    # runtime dependencies
    install_requires = [
        'beartype',
        'einx[torch]>=0.1.3',
        'torch>=2.0',
    ],
    # trove classifiers
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)

Robotic Transformer - Pytorch

Implementation of RT1 (Robotic Transformer), from the Robotics at Google team, in Pytorch

Install

$ pip install robotic-transformer-pytorch

Usage

import torch
from robotic_transformer_pytorch import MaxViT, RT1

vit = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,
    dim = 96,
    dim_head = 32,
    depth = (2, 2, 5, 2),
    window_size = 7,
    mbconv_expansion_rate = 4,
    mbconv_shrinkage_rate = 0.25,
    dropout = 0.1
)

model = RT1(
    vit = vit,
    num_actions = 11,
    depth = 6,
    heads = 8,
    dim_head = 64,
    cond_drop_prob = 0.2
)

video = torch.randn(2, 3, 6, 224, 224)

instructions = [
    'bring me that apple sitting on the table',
    'please pass the butter'
]

train_logits = model(video, instructions) # (2, 6, 11, 256) # (batch, frames, actions, bins)

# after much training

model.eval()
eval_logits = model(video, instructions, cond_scale = 3.) # classifier free guidance with conditional scale of 3

Appreciation

  • Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research

Todo

Citations

@inproceedings{rt12022arxiv,
    title    = {RT-1: Robotics Transformer for Real-World Control at Scale},
    author   = {Anthony Brohan and Noah Brown and Justice Carbajal and  Yevgen Chebotar and Joseph Dabis and Chelsea Finn and Keerthana Gopalakrishnan and Karol Hausman and Alex Herzog and Jasmine Hsu and Julian Ibarz and Brian Ichter and Alex Irpan and Tomas Jackson and  Sally Jesmonth and Nikhil Joshi and Ryan Julian and Dmitry Kalashnikov and Yuheng Kuang and Isabel Leal and Kuang-Huei Lee and  Sergey Levine and Yao Lu and Utsav Malla and Deeksha Manjunath and  Igor Mordatch and Ofir Nachum and Carolina Parada and Jodilyn Peralta and Emily Perez and Karl Pertsch and Jornell Quiambao and  Kanishka Rao and Michael Ryoo and Grecia Salazar and Pannag Sanketi and Kevin Sayed and Jaspiar Singh and Sumedh Sontakke and Austin Stone and Clayton Tan and Huong Tran and Vincent Vanhoucke and Steve Vega and Quan Vuong and Fei Xia and Ted Xiao and Peng Xu and Sichun Xu and Tianhe Yu and Brianna Zitkovich},
    booktitle = {arXiv preprint arXiv:2212.06817},
    year      = {2022}
}
@inproceedings{Tu2022MaxViTMV,
    title   = {MaxViT: Multi-Axis Vision Transformer},
    author  = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
    year    = {2022}
}
@misc{peebles2022scalable,
    title   = {Scalable Diffusion Models with Transformers},
    author  = {William Peebles and Saining Xie},
    year    = {2022},
    eprint  = {2212.09748},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\robotic-transformer-pytorch\robotic_transformer_pytorch\robotic_transformer_pytorch.py

from functools import partial
from typing import List, Optional, Callable, Tuple

import torch
import torch.nn.functional as F
from torch import nn, einsum, Tensor

from beartype import beartype

from einops import pack, unpack, repeat, reduce, rearrange
from einops.layers.torch import Rearrange, Reduce

# text conditioning + classifier-free guidance used by RT1
from classifier_free_guidance_pytorch import TextConditioner, AttentionTextConditioner, classifier_free_guidance

# helpers

# 定义函数 exists,判断值是否存在
def exists(val):
    return val is not None

def default(val, d):
    """Return `val`, or the fallback `d` only when `val` is None."""
    if val is None:
        return d
    return val

def cast_tuple(val, length = 1):
    """Return `val` unchanged if it is a tuple; otherwise a tuple of `length` copies of it."""
    if isinstance(val, tuple):
        return val
    single = (val,)
    return single * length

def pack_one(x, pattern):
    """Pack a single tensor with einops `pack`; returns `(packed, packed_shapes)`."""
    tensors = [x]
    return pack(tensors, pattern)

def unpack_one(x, ps, pattern):
    """Counterpart of `pack_one`: unpack with einops and return the first (sole) tensor."""
    unpacked = unpack(x, ps, pattern)
    return unpacked[0]

# sinusoidal positions

def posemb_sincos_1d(seq, dim, temperature = 10000, device = None, dtype = torch.float32):
    """Build a fixed 1D sine/cosine positional embedding.

    Args:
        seq: number of positions; row i of the output encodes position i.
        dim: embedding width; must be a positive even number
            (first half sine features, second half cosine features).
        temperature: base of the geometric frequency progression.
        device: device to create the embedding on.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape (seq, dim).

    Raises:
        ValueError: if `dim` is odd or smaller than 2 (the result could not
            have `dim` columns).
    """
    if dim < 2 or dim % 2 != 0:
        raise ValueError(f'dim must be a positive even number, got {dim}')

    half_dim = dim // 2

    # exponents evenly spaced over [0, 1]; guard the denominator so dim == 2
    # gives a single frequency of 1 instead of 0 / 0 = NaN
    omega = torch.arange(half_dim, device = device) / max(half_dim - 1, 1)
    omega = 1. / (temperature ** omega)

    angles = torch.arange(seq, device = device)[:, None] * omega[None, :]
    pos_emb = torch.cat((angles.sin(), angles.cos()), dim = 1)
    return pos_emb.type(dtype)

# helper classes

class Residual(nn.Module):
    """Wrap `fn` with an identity skip connection: returns fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x

class LayerNorm(nn.Module):
    """Layer norm over the last dimension with a learnable gain and a fixed zero bias.

    The bias is registered as a buffer rather than a Parameter, so it is saved
    with the module's state but is not returned by `parameters()`.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, weight = self.gamma, bias = self.beta)

class FeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> optional conditioning hook -> Linear/GELU/Dropout/Linear/Dropout.

    Only the MLP branch is returned; callers add the residual themselves.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = int(dim * mult)
        self.norm = LayerNorm(dim)

        layers = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x, cond_fn = None):
        normed = self.norm(x)

        # optional adaptive-layernorm hook, applied after normalization
        if cond_fn is not None:
            normed = cond_fn(normed)

        return self.net(normed)

# MBConv

class SqueezeExcitation(nn.Module):
    """Channel gating: spatial mean-pool, bottleneck MLP with sigmoid, then rescale each input channel."""

    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        bottleneck = int(dim * shrinkage_rate)

        squeeze = Reduce('b c h w -> b c', 'mean')
        excite = [
            nn.Linear(dim, bottleneck, bias = False),
            nn.SiLU(),
            nn.Linear(bottleneck, dim, bias = False),
            nn.Sigmoid(),
        ]
        # restore broadcastable spatial axes so the gate multiplies (b, c, h, w)
        expand = Rearrange('b c -> b c 1 1')
        self.gate = nn.Sequential(squeeze, *excite, expand)

    def forward(self, x):
        weights = self.gate(x)
        return x * weights

class MBConvResidual(nn.Module):
    """Skip connection around an MBConv body, with per-sample `Dropsample` on the branch."""

    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        branch = self.dropsample(self.fn(x))
        return branch + x

class Dropsample(nn.Module):
    """Stochastic depth for a residual branch: drop whole samples while training.

    In training mode each sample (indexed along the first dimension of `x`) is
    kept with probability ``1 - prob`` and, when kept, scaled by
    ``1 / (1 - prob)``. In eval mode, or when ``prob == 0``, `x` is returned
    unchanged.
    """

    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        if self.prob == 0. or (not self.training):
            return x

        # every sample is dropped: avoid the 0 / 0 of the rescale below
        if self.prob >= 1.:
            return torch.zeros_like(x)

        # One keep/drop decision per sample, broadcast over the remaining dims.
        # (torch.FloatTensor((b, 1, 1, 1)) would create a 4-element tensor holding
        # those numbers rather than a tensor of that shape, and it rejects
        # non-CPU devices.)
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = torch.rand(mask_shape, device = x.device) > self.prob
        return x * keep_mask / (1 - self.prob)

def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    """Build an MBConv block (inverted bottleneck with squeeze-excitation).

    Layout: 1x1 expand -> BN -> GELU -> 3x3 depthwise (stride 2 when `downsample`)
    -> BN -> GELU -> squeeze-excitation -> 1x1 project -> BN.

    The body is wrapped in `MBConvResidual` (which uses `dropout` for per-sample
    dropping) only when input and output widths match and there is no
    downsampling; otherwise `dropout` has no effect.
    """
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    expand = [
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
    ]
    depthwise = [
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
    ]
    project = [
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        nn.BatchNorm2d(dim_out),
    ]
    body = nn.Sequential(*expand, *depthwise, *project)

    keeps_shape = dim_in == dim_out and not downsample
    return MBConvResidual(body, dropout = dropout) if keeps_shape else body
class Attention(nn.Module):
    """Multi-head self-attention inside fixed-size windows.

    Adds learned memory key/values (prepended to every window) and a learned
    2D relative position bias. Input layout is
    (batch, x, y, window_height, window_width, dim), i.e. already split into
    windows; the output keeps that layout.
    """

    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7,
        num_mem_kv = 4
    ):
        super().__init__()
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        self.norm = LayerNorm(dim)

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        # learned key/value slots shared by every window: (2, heads, num_mem_kv, dim_head)
        self.mem_kv = nn.Parameter(torch.randn(2, self.heads, num_mem_kv, dim_head))

        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),
            nn.Dropout(dropout)
        )

        # one learned bias per head for each possible (row, col) offset within a window
        offsets_per_axis = 2 * window_size - 1
        self.rel_pos_bias = nn.Embedding(offsets_per_axis ** 2, self.heads)

        # map every (query position, key position) pair to a flat offset index
        coords = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(coords, coords, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        offsets = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        offsets = offsets + (window_size - 1)  # shift offsets to be non-negative
        rel_pos_indices = (offsets * torch.tensor([offsets_per_axis, 1])).sum(dim = -1)

        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        _, height, width, window_height, window_width, _ = x.shape
        heads = self.heads

        x = self.norm(x)

        # every window becomes its own sequence of window_height * window_width tokens
        x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = heads) for t in qkv)
        q = q * self.scale

        # prepend the memory key/values to every window's keys/values
        num_windows = q.shape[0]
        mem_k, mem_v = (repeat(t, 'h n d -> b h n d', b = num_windows) for t in self.mem_kv)
        num_mem = mem_k.shape[-2]
        k = torch.cat((mem_k, k), dim = -2)
        v = torch.cat((mem_v, v), dim = -2)

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # relative position bias; the leading num_mem (memory) keys get zero bias
        bias = self.rel_pos_bias(self.rel_pos_indices)
        bias = F.pad(bias, (0, 0, num_mem, 0), value = 0.)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        attn = self.attend(sim)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads and restore the windowed layout
        out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)
        out = self.to_out(out)
        return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)

# MaxViT backbone: a convolutional stem followed by stages of blocks, each block
# being MBConv + block (local-window) attention + grid attention, and finally a
# mean-pooled classification head.
class MaxViT(nn.Module):
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        dim_head = 32,
        dim_conv_stem = None,
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1,
        channels = 3
        ):
        super().__init__()
        # depth must be a tuple giving the number of blocks in each stage
        assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'

        # convolutional stem

        # stem width defaults to `dim`
        dim_conv_stem = default(dim_conv_stem, dim)

        # stride-2 3x3 conv halves the spatial resolution, then a same-size 3x3 conv
        self.conv_stem = nn.Sequential(
            nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
            nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
        )

        # variables

        num_stages = len(depth)

        # stage i has width dim * 2**i; pair each stage's input width with its output width
        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (dim_conv_stem, *dims)
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        # one entry per block (blocks of all stages, flattened)
        self.layers = nn.ModuleList([])

        # shorthand for the window size used by both block and grid attention

        w = window_size

        # iterate through stages

        # input channel count of each block; forward() applies an optional
        # conditioning hook to each block's input, and RT1 reads this list
        cond_hidden_dims = []

        for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
            for stage_ind in range(layer_depth):
                # only the first block of a stage changes width and downsamples
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim

                cond_hidden_dims.append(stage_dim_in)

                block = nn.Sequential(
                    MBConv(
                        stage_dim_in,
                        layer_dim,
                        downsample = is_first,
                        expansion_rate = mbconv_expansion_rate,
                        shrinkage_rate = mbconv_shrinkage_rate
                    ),
                    Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),  # block attention: contiguous w x w windows
                    Residual(Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    Residual(FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),

                    Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),  # grid attention: w x w positions strided across the map
                    Residual(Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
                    Residual(FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                )

                self.layers.append(block)

        # channel width of the final feature map
        embed_dim = dims[-1]
        self.embed_dim = dims[-1]

        self.cond_hidden_dims = cond_hidden_dims

        # mlp head: spatial mean-pool -> LayerNorm -> class logits

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    @beartype
    def forward(
        self,
        x,
        texts: Optional[List[str]] = None,
        cond_fns: Optional[Tuple[Callable, ...]] = None,
        cond_drop_prob = 0.,
        return_embeddings = False
    ):
        # NOTE(review): `texts` and `cond_drop_prob` are accepted but unused here.
        x = self.conv_stem(x)

        # at most one conditioning hook per block, consumed in order
        cond_fns = iter(default(cond_fns, []))

        for stage in self.layers:
            cond_fn = next(cond_fns, None)

            # the hook (if any) transforms the block's input
            if exists(cond_fn):
                x = cond_fn(x)

            x = stage(x)

        # return the (b, d, h, w) feature map instead of class logits
        if return_embeddings:
            return x

        return self.mlp_head(x)
class TransformerAttention(nn.Module):
    """Multi-query attention: `heads` query heads share one key/value head.

    Supports cross-attention via `context`, a key padding `mask`, an additive
    `attn_bias`, a boolean `attn_mask` (True = may attend), causal masking, and
    an optional conditioning hook applied after the pre-norm. Only the attention
    output is returned; the caller adds the residual.
    """

    def __init__(
        self,
        dim,
        causal = False,
        dim_head = 64,
        dim_context = None,
        heads = 8,
        norm_context = False,
        dropout = 0.1
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal
        inner_dim = dim_head * heads

        dim_context = dim if dim_context is None else dim_context

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(dim_context) if norm_context else nn.Identity()

        self.attn_dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # a single key/value head (multi-query attention)
        self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None,
        attn_mask = None,
        cond_fn: Optional[Callable] = None
    ):
        if context is not None:
            kv_input = self.context_norm(context)
        else:
            # NOTE(review): for self-attention, keys/values are projected from the
            # raw input while queries use the normed (and conditioned) input —
            # confirm this is intentional
            kv_input = x

        x = self.norm(x)
        if cond_fn is not None:
            x = cond_fn(x)  # adaptive layer-norm

        q = self.to_q(x)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) * self.scale

        # k and v have no head axis: the single kv head is shared by all query heads
        sim = einsum('b h i d, b j d -> b h i j', q, k)

        if attn_bias is not None:
            sim = sim + attn_bias

        mask_value = -torch.finfo(sim.dtype).max

        if attn_mask is not None:
            sim = sim.masked_fill(~attn_mask, mask_value)

        if mask is not None:
            sim = sim.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), mask_value)

        if self.causal:
            i, j = sim.shape[-2:]
            future = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
            sim = sim.masked_fill(future, mask_value)

        attn = self.attn_dropout(sim.softmax(dim = -1))

        out = einsum('b h i j, b j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Stack of pre-norm attention + feed-forward layers with residual connections.
@beartype
class Transformer(nn.Module):
    """`depth` layers of (TransformerAttention, FeedForward), each with a residual.

    `cond_fns`, when given, is consumed in order: one hook for each attention
    layer, then one for its feed-forward layer (2 * depth hooks in total).
    Missing hooks are treated as None.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        depth = 6,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                # forward dim_head so the constructor argument actually takes effect
                TransformerAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                FeedForward(dim = dim, dropout = ff_dropout)
            ]))

    def forward(
        self,
        x,
        cond_fns: Optional[Tuple[Callable, ...]] = None,
        attn_mask = None
    ):
        cond_fns = iter(default(cond_fns, []))

        for attn, ff in self.layers:
            x = attn(x, attn_mask = attn_mask, cond_fn = next(cond_fns, None)) + x
            x = ff(x, cond_fn = next(cond_fns, None)) + x
        return x

class TokenLearner(nn.Module):
    """
    https://arxiv.org/abs/2106.11297
    using the 1.1 version with the MLP (2 dense layers with gelu) for generating attention map

    Maps a (..., c, h, w) feature map to (..., c, num_output_tokens): a
    per-token map is predicted, multiplied into the features, and the result
    is mean-pooled over space.
    """

    def __init__(
        self,
        *,
        dim,
        ff_mult = 2,
        num_output_tokens = 8,
        num_layers = 2
    ):
        super().__init__()
        groups = num_output_tokens
        hidden = dim * ff_mult * groups

        self.num_output_tokens = num_output_tokens
        # grouped 1x1 convs act as an independent 2-layer MLP for each output token
        # NOTE(review): num_layers is accepted but not used
        self.net = nn.Sequential(
            nn.Conv2d(dim * groups, hidden, 1, groups = groups),
            nn.GELU(),
            nn.Conv2d(hidden, groups, 1, groups = groups),
        )

    def forward(self, x):
        g = self.num_output_tokens
        x, packed_shape = pack_one(x, '* c h w')

        # one copy of the feature map per output token, stacked along channels
        x = repeat(x, 'b c h w -> b (g c) h w', g = g)
        attn_maps = rearrange(self.net(x), 'b g h w -> b 1 g h w')
        features = rearrange(x, 'b (g c) h w -> b c g h w', g = g)

        # weight the features by each token's map, then average over space
        tokens = reduce(features * attn_maps, 'b c g h w -> b c g', 'mean')
        return unpack_one(tokens, packed_shape, '* c n')
# Robotic Transformer

# Robotic Transformer (RT-1): MaxViT frame features -> TokenLearner tokens ->
# causal Transformer over frames -> per-frame logits over discretized actions.
# beartype checks the annotated constructor arguments (e.g. `vit` must be a MaxViT).
@beartype
class RT1(nn.Module):
    # Build the policy around an existing MaxViT backbone.
    def __init__(
        self,
        *,
        vit: MaxViT,  # image backbone; its cond_hidden_dims / embed_dim size the other parts
        num_actions = 11,  # number of action dimensions predicted per frame
        action_bins = 256,  # number of discrete bins (logits) per action dimension
        depth = 6,  # number of Transformer layers
        heads = 8,  # attention heads in the Transformer
        dim_head = 64,  # per-head dimension, passed to the Transformer
        token_learner_ff_mult = 2,  # hidden-width multiplier of the TokenLearner MLP
        token_learner_num_layers = 2,  # forwarded to TokenLearner, which does not use it
        token_learner_num_output_tokens = 8,  # tokens produced per frame by the TokenLearner
        cond_drop_prob = 0.2,  # passed to the conditioner and stored as self.cond_drop_prob
        use_attn_conditioner = False,  # True: AttentionTextConditioner, False: TextConditioner
        conditioner_kwargs: dict = dict()  # extra keyword arguments for the conditioner (only read, never mutated here)
    ):
        super().__init__()
        self.vit = vit

        # number of conditioning hooks the ViT consumes (one per ViT block)
        self.num_vit_stages = len(vit.cond_hidden_dims)

        # choose the text conditioner implementation
        conditioner_klass = AttentionTextConditioner if use_attn_conditioner else TextConditioner

        # one hook per ViT block (channel-first feature maps), then two per
        # Transformer layer (attention + feed-forward, channel-last tokens)
        self.conditioner = conditioner_klass(
            hidden_dims = (*tuple(vit.cond_hidden_dims), *((vit.embed_dim,) * depth * 2)),
            hiddens_channel_first = (*((True,) * self.num_vit_stages), *((False,) * depth * 2)),
            cond_drop_prob = cond_drop_prob,
            **conditioner_kwargs
        )

        # compresses each frame's feature map to a few tokens
        self.token_learner = TokenLearner(
            dim = vit.embed_dim,
            ff_mult = token_learner_ff_mult,
            num_output_tokens = token_learner_num_output_tokens,
            num_layers = token_learner_num_layers
        )

        self.num_learned_tokens = token_learner_num_output_tokens

        self.transformer_depth = depth

        self.transformer = Transformer(
            dim = vit.embed_dim,
            dim_head = dim_head,
            heads = heads,
            depth = depth
        )

        self.cond_drop_prob = cond_drop_prob

        # per-frame features -> (num_actions, action_bins) logits
        self.to_logits = nn.Sequential(
            LayerNorm(vit.embed_dim),
            nn.Linear(vit.embed_dim, num_actions * action_bins),
            Rearrange('... (a b) -> ... a b', b = action_bins)
        )

    # Embed texts with the conditioner (e.g. to precompute `text_embeds`).
    def embed_texts(self, texts: List[str]):
        return self.conditioner.embed_texts(texts)

    # Forward pass, wrapped by the classifier-free-guidance decorator from
    # classifier_free_guidance_pytorch (the README calls it with `cond_scale`).
    @classifier_free_guidance
    def forward(
        self,
        video,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[Tensor] = None,
        cond_drop_prob = 0.
        ):
        # exactly one of raw texts or precomputed text embeddings must be given
        assert exists(texts) ^ exists(text_embeds)
        cond_kwargs = dict(texts = texts, text_embeds = text_embeds)

        depth = self.transformer_depth
        # NOTE(review): cond_drop_prob defaults to 0. rather than None, so this
        # default() only falls back to self.cond_drop_prob if None is passed
        # explicitly — confirm this is the intended training behavior.
        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        # video is indexed as (batch, channels, frames, height, width) below
        frames, device = video.shape[2], video.device

        # Build the conditioning hooks. repeat_batch presumably tells the
        # conditioner to repeat each ViT hook's conditioning `frames` times
        # (frames are folded into the batch for the ViT) and the Transformer
        # hooks once — semantics defined by the conditioner library.
        cond_fns, _ = self.conditioner(
            **cond_kwargs,
            cond_drop_prob = cond_drop_prob,
            repeat_batch = (*((frames,) * self.num_vit_stages), *((1,) * self.transformer_depth * 2))
        )

        # the last 2 * depth hooks go to the Transformer, the rest to the ViT
        vit_cond_fns, transformer_cond_fns = cond_fns[:-(depth * 2)], cond_fns[-(depth * 2):]

        # (b, c, f, h, w) -> (b, f, c, h, w), then fold frames into the batch
        video = rearrange(video, 'b c f h w -> b f c h w')
        images, packed_shape = pack_one(video, '* c h w')

        # per-frame feature maps (embeddings, not class logits)
        # NOTE(review): MaxViT.forward does not use `texts` or `cond_drop_prob`
        tokens = self.vit(
            images,
            texts = texts,
            cond_fns = vit_cond_fns,
            cond_drop_prob = cond_drop_prob,
            return_embeddings = True
        )

        # restore the (batch, frames) leading dims
        tokens = unpack_one(tokens, packed_shape, '* c h w')
        # a few learned tokens per frame: (b, f, c, n)
        learned_tokens = self.token_learner(tokens)

        # one sequence of frames * tokens per video: (b, f * n, c)
        learned_tokens = rearrange(learned_tokens, 'b f c n -> b (f n) c')

        # frame-level causal mask (True marks future frames), expanded so all
        # tokens of a frame share that frame's mask
        attn_mask = torch.ones((frames, frames), dtype = torch.bool, device = device).triu(1)
        attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1 = self.num_learned_tokens, r2 = self.num_learned_tokens)

        # sinusoidal embedding per frame, shared by every token of that frame
        pos_emb = posemb_sincos_1d(frames, learned_tokens.shape[-1], dtype = learned_tokens.dtype, device = learned_tokens.device)
        learned_tokens = learned_tokens + repeat(pos_emb, 'n d -> (n r) d', r = self.num_learned_tokens)

        # the Transformer's attn_mask uses True = may attend, hence the inversion
        attended_tokens = self.transformer(learned_tokens, cond_fns = transformer_cond_fns, attn_mask = ~attn_mask)

        # average each frame's tokens back to one vector per frame
        pooled = reduce(attended_tokens, 'b (f n) d -> b f d', 'mean', f = frames)

        # logits of shape (b, f, num_actions, action_bins)
        logits = self.to_logits(pooled)
        return logits

.\lucidrains\robotic-transformer-pytorch\robotic_transformer_pytorch\__init__.py

# 从 robotic_transformer_pytorch.robotic_transformer_pytorch 模块中导入 RT1, TokenLearner, MaxViT 类
from robotic_transformer_pytorch.robotic_transformer_pytorch import RT1, TokenLearner, MaxViT

.\lucidrains\robotic-transformer-pytorch\setup.py

# 导入设置工具和查找包函数
from setuptools import setup, find_packages

# Package metadata for robotic-transformer-pytorch.
setup(
  name = 'robotic-transformer-pytorch', # package name
  packages = find_packages(exclude=[]), # include every package in the repo
  version = '0.2.1', # release version
  license='MIT', # license
  description = 'Robotic Transformer - Pytorch', # short description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  long_description_content_type = 'text/markdown', # long description format
  url = 'https://github.com/lucidrains/robotic-transformer-pytorch', # project URL
  keywords = [ # package keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'robotics'
  ],
  install_requires=[ # runtime dependencies
    'beartype', # imported directly by robotic_transformer_pytorch.py, so declare it explicitly
    'classifier-free-guidance-pytorch>=0.4.0',
    'einops>=0.7',
    'torch>=2.0',
  ],
  classifiers=[ # trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Rotary Embeddings - Pytorch

A standalone library for adding rotary embeddings to transformers in Pytorch, following its success as relative positional encoding. Specifically it will make rotating information into any axis of a tensor easy and efficient, whether they be fixed positional or learned. This library will give you state of the art results for positional embedding, at little cost.

My gut also tells me there is something more to rotations that can be exploited in artificial neural networks.

Install

$ pip install rotary-embedding-torch

Usage

import torch
from rotary_embedding_torch import RotaryEmbedding

# instantiate the positional embedding in your transformer and pass to all your attention layers

rotary_emb = RotaryEmbedding(dim = 32)

# mock queries and keys - dimensions should end with (seq_len, feature dimension), and any number of preceding dimensions (batch, heads, etc)

q = torch.randn(1, 8, 1024, 64) # queries - (batch, heads, seq len, dimension of head)
k = torch.randn(1, 8, 1024, 64) # keys

# apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)

q = rotary_emb.rotate_queries_or_keys(q)
k = rotary_emb.rotate_queries_or_keys(k)

# then do your attention with your queries (q) and keys (k) as usual

If you do all the steps above correctly, you should see a dramatic improvement during training

Axial Rotary Embeddings

For easy use of n-dimensional axial relative positional embedding, ie. video transformers

import torch

from rotary_embedding_torch import (
    RotaryEmbedding,
    apply_rotary_emb
)

pos_emb = RotaryEmbedding(
    dim = 16,
    freqs_for = 'pixel',
    max_freq = 256
)

# queries and keys for frequencies to be rotated into
# say for a video with 8 frames, and rectangular image (feature dimension comes last)

q = torch.randn(1, 8, 64, 32, 64)
k = torch.randn(1, 8, 64, 32, 64)

# get axial frequencies - (8, 64, 32, 16 * 3 = 48)
# will automatically do partial rotary

freqs = pos_emb.get_axial_freqs(8, 64, 32)

# rotate in frequencies

q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)

Length Extrapolatable Rotary Embeddings

In this paper, they were able to fix the length extrapolation issue with rotary embeddings by giving it a decay similar to ALiBi. They named this technique XPos, and you can use it by setting use_xpos = True on initialization.

This can only be used for autoregressive transformers.

import torch
from rotary_embedding_torch import RotaryEmbedding

# instantiate the positional embedding in your transformer and pass to all your attention layers

rotary_emb = RotaryEmbedding(
    dim = 32,
    use_xpos = True   # set this to True to make rotary embeddings extrapolate better to sequence lengths greater than the one used at training time
)

# mock queries and keys - dimensions should end with (seq_len, feature dimension), and any number of preceding dimensions (batch, heads, etc)

q = torch.randn(1, 8, 1024, 64) # queries - (batch, heads, seq len, dimension of head)
k = torch.randn(1, 8, 1024, 64) # keys

# apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)

# instead of using `rotate_queries_or_keys`, you will use `rotate_queries_and_keys`, the rest is taken care of

q, k = rotary_emb.rotate_queries_and_keys(q, k)

Interpolating Sequence Positions

This MetaAI paper proposes simply fine-tuning on interpolations of the sequence positions for extending to longer context length for pretrained models. They show this performs much better than simply fine-tuning on the same sequence positions but extended further.

You can use this by setting the interpolate_factor on initialization to a value greater than 1 (e.g. if the pretrained model was trained on a context length of 2048, setting interpolate_factor = 2. would allow fine-tuning to a context length of 2048 x 2 = 4096).

Update: someone in the community has reported that it does not work well. Please email me if you see either a positive or negative result.

import torch
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(
    dim = 32,
    interpolate_factor = 2.    # add this line of code to pretrained model and fine-tune for ~1000 steps, as shown in paper
)

Citations

@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{Sun2022ALT,
    title     = {A Length-Extrapolatable Transformer},
    author    = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
    year      = {2022}
}
@inproceedings{Chen2023ExtendingCW,
    title   = {Extending Context Window of Large Language Models via Positional Interpolation},
    author  = {Shouyuan Chen and Sherman Wong and Liangjian Chen and Yuandong Tian},
    year    = {2023}
}
@misc{bloc97-2023,
    title   = {NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.},
    author  = {/u/bloc97},
    url     = {https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/}
}

.\lucidrains\rotary-embedding-torch\rotary_embedding_torch\rotary_embedding_torch.py

# 从 math 模块中导入 pi 和 log 函数
from math import pi, log

# 导入 torch 模块
import torch
# 从 torch.nn 模块中导入 Module 和 ModuleList 类
from torch.nn import Module, ModuleList
# 从 torch.cuda.amp 模块中导入 autocast 函数
from torch.cuda.amp import autocast
# 从 torch 模块中导入 nn, einsum, broadcast_tensors, Tensor 类
from torch import nn, einsum, broadcast_tensors, Tensor

# 从 einops 模块中导入 rearrange, repeat 函数
from einops import rearrange, repeat

# 从 beartype 模块中导入 beartype 函数和相关类型
from beartype import beartype
from beartype.typing import Literal, Union, Optional

# helpers

def exists(val):
    """Return True when *val* is anything other than None."""
    return not (val is None)

def default(val, d):
    """Return *val*, or the fallback *d* when *val* is None."""
    if val is None:
        return d
    return val

# broadcast-then-concatenate helper, used for tortoise-tts

def broadcat(tensors, dim = -1):
    """Broadcast every tensor in *tensors* to a common shape, then concatenate them along *dim*."""
    return torch.cat(broadcast_tensors(*tensors), dim = dim)

# rotary embedding helpers

def rotate_half(x):
    """Rotate every adjacent feature pair by 90 degrees.

    The last dimension is read as interleaved pairs ``(x1, x2)``. Each pair is
    replaced by ``(-x2, x1)``, e.g. ``[a, b, c, d] -> [-b, a, -d, c]``. The
    last dimension must therefore have even size.
    """
    half = x.shape[-1] // 2
    pairs = x.reshape(*x.shape[:-1], half, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim = -1)
    return rotated.reshape(x.shape)

# apply rotary embeddings

@autocast(enabled = False)
def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2):
    """Rotate the feature slice ``t[..., start_index:start_index + freqs.shape[-1]]`` by `freqs`.

    Features before and after the rotated slice pass through unchanged.

    Args:
        freqs: rotation angles, broadcastable against the rotated slice of `t`.
            When `t` is 3-dimensional, only the last ``t.shape[seq_dim]`` rows of
            `freqs` are used (and they are cast to `t`'s dtype/device).
        t: tensor to rotate.
        start_index: first feature index to rotate.
        scale: multiplier applied to both the cos and the sin term (xpos scale).
        seq_dim: sequence dimension of `t`; used to trim `freqs` for 3-d input.

    Returns:
        The rotated tensor, with the same shape and dtype as `t`.
    """
    dtype = t.dtype

    if t.ndim == 3:
        seq_len = t.shape[seq_dim]
        freqs = freqs[-seq_len:].to(t)

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'

    t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)

    # for non-3-d inputs `freqs` keeps its own (usually fp32) precision, so the
    # product may be promoted; cast back so the output dtype always matches `t`
    return torch.cat((t_left, t, t_right), dim = -1).type(dtype)

# apply learned rotations

def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
    """Rotate `t` by learned per-position `rotations`.

    When `freq_ranges` is given, every rotation is first multiplied by each
    frequency in it (an outer product over a new trailing axis), and the last
    two axes ``(r, f)`` are merged into one. Each angle is then duplicated so
    it covers both features of a pair, and the result is passed to
    `apply_rotary_emb`.
    """
    if freq_ranges is not None:
        scaled = einsum('..., f -> ... f', rotations, freq_ranges)
        rotations = rearrange(scaled, '... r f -> ... (r f)')

    angles = repeat(rotations, '... n -> ... (n r)', r = 2)
    return apply_rotary_emb(angles, t, start_index = start_index)

# classes

class RotaryEmbedding(Module):
    """Rotary position embedding (RoPE).

    Computes rotation angles for the 'lang', 'pixel' or 'constant' frequency
    schemes (or user-supplied frequencies). Optionally supports:
    - xpos (length-extrapolatable) scaling,
    - position interpolation,
    - NTK-aware theta rescaling,
    - caching of computed angles / scales in non-persistent buffers.
    """

    @beartype
    def __init__(
        self,
        dim,
        custom_freqs: Optional[Tensor] = None,
        freqs_for: Union[
            Literal['lang'],
            Literal['pixel'],
            Literal['constant']
        ] = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
        learned_freq = False,
        use_xpos = False,
        xpos_scale_base = 512,
        interpolate_factor = 1.,
        theta_rescale_factor = 1.,
        seq_before_head_dim = False,
        cache_if_possible = True
    ):
        """
        Args:
            dim: feature dimension to rotate. 'lang' and 'pixel' generate dim // 2
                frequencies, each later repeated for the two features of a pair.
            custom_freqs: explicit frequency tensor; overrides `freqs_for`.
            freqs_for: frequency scheme.
                'lang' -> 1 / theta ** (2i / dim);
                'pixel' -> linspace(1, max_freq / 2, dim // 2) * pi;
                'constant' -> `num_freqs` ones.
            theta: base of the 'lang' frequencies.
            max_freq: the largest 'pixel' frequency is max_freq / 2 (times pi).
            num_freqs: number of 'constant' frequencies.
            learned_freq: whether the frequencies receive gradients
                (this also disables the angle cache).
            use_xpos: enable xpos scaling. Queries and keys must then be rotated
                together via `rotate_queries_and_keys`.
            xpos_scale_base: divisor of the centred positions in the xpos exponent.
            interpolate_factor: positions are divided by this value (must be >= 1).
            theta_rescale_factor: NTK-aware rescaling applied to `theta`.
            seq_before_head_dim: inputs are laid out (..., seq, heads, dim)
                instead of (..., seq, dim).
            cache_if_possible: cache computed angles / scales for reuse.
        """
        super().__init__()

        # NTK-aware rescaling, proposed by reddit user bloc97. It rescales the
        # rotary embeddings to longer sequence lengths without fine-tuning and
        # has some connection to the NTK literature:
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        theta *= theta_rescale_factor ** (dim / (dim - 2))

        self.freqs_for = freqs_for

        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()

        self.cache_if_possible = cache_if_possible

        # caches start empty; they are non-persistent buffers, so they follow the
        # module's device but are left out of the state dict
        self.tmp_store('cached_freqs', None)
        self.tmp_store('cached_scales', None)

        self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)

        self.learned_freq = learned_freq

        # dummy buffer whose only job is to report the module's device
        self.tmp_store('dummy', torch.tensor(0))

        # default sequence dimension: -3 for (..., seq, heads, dim) inputs, else -2
        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        # position interpolation factor
        assert interpolate_factor >= 1.
        self.interpolate_factor = interpolate_factor

        # xpos (length-extrapolatable) scaling
        self.use_xpos = use_xpos
        if not use_xpos:
            self.tmp_store('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base
        self.tmp_store('scale', scale)

    @property
    def device(self):
        """Device the module currently lives on (read from the dummy buffer)."""
        return self.dummy.device

    def tmp_store(self, key, value):
        """Register `value` as a non-persistent buffer named `key`."""
        self.register_buffer(key, value, persistent = False)

    def get_seq_pos(self, seq_len, device, dtype, offset = 0):
        """Positions ``offset .. offset + seq_len - 1`` divided by `interpolate_factor`."""
        return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, freq_seq_len = None):
        """Rotate queries or keys `t` (only valid when xpos is disabled).

        Args:
            t: tensor whose sequence axis is `seq_dim`.
            seq_dim: defaults to `self.default_seq_dim`.
            offset: position of the first element of `t`.
            freq_seq_len: if given (must be >= t's sequence length), angles are
                computed for this many positions and `t` is aligned with the
                LAST of them. Used when queries are a suffix of longer cached keys.

        Returns:
            The rotated tensor, with the same shape and dtype as `t`.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'

        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]

        freq_len = seq_len
        if exists(freq_seq_len):
            assert freq_seq_len >= seq_len
            freq_len = freq_seq_len

        # build positions at the frequencies' precision rather than t's:
        # fp16 / bf16 cannot represent integers exactly beyond 2048 / 256
        seq = self.get_seq_pos(freq_len, device = device, dtype = self.freqs.dtype, offset = offset)
        freqs = self.forward(seq, seq_len = freq_len, offset = offset)

        # keep only the positions `t` occupies (the last `seq_len` of `freq_len`).
        # apply_rotary_emb only trims like this for 3-d inputs; doing it here also
        # makes 4-d (batch, heads, seq, dim) queries work with `freq_seq_len`
        freqs = freqs[(freq_len - seq_len):]

        if seq_dim == -3:
            freqs = rearrange(freqs, 'n d -> n 1 d')

        return apply_rotary_emb(freqs, t, seq_dim = seq_dim).type(dtype)

    def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
        """Rotate queries `q` that correspond to the last q_len positions of keys `k`.

        NOTE(review): `offset` is accepted but currently not used.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
        assert q_len <= k_len
        rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, freq_seq_len = k_len)
        rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def rotate_queries_and_keys(self, q, k, seq_dim = None):
        """Rotate queries and keys together with xpos scaling.

        Queries are multiplied by the xpos scale and keys by its inverse.
        Returns ``(rotated_q, rotated_k)`` in the input dtypes.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        # positions at the frequencies' precision (see rotate_queries_or_keys)
        seq = self.get_seq_pos(seq_len, dtype = self.freqs.dtype, device = device)

        freqs = self.forward(seq, seq_len = seq_len)
        scale = self.get_scale(seq, seq_len = seq_len).to(dtype)

        if seq_dim == -3:
            freqs = rearrange(freqs, 'n d -> n 1 d')
            scale = rearrange(scale, 'n d -> n 1 d')

        rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    @beartype
    def get_scale(
        self,
        t: Tensor,
        seq_len: Optional[int] = None,
        offset = 0
    ):
        """xpos scale for positions `t`, shape ``(len(t), 2 * len(self.scale))``.

        Feature pair i gets ``self.scale[i] ** ((t - len(t) // 2) / self.scale_base)``
        on both of its features.

        When `seq_len` is given and caching is enabled, a result computed from
        offset 0 is cached. Later calls covering ``offset .. offset + seq_len``
        are served from that cache.
        """
        assert self.use_xpos

        should_cache = (
            self.cache_if_possible and
            exists(seq_len)
        )

        if (
            should_cache and \
            exists(self.cached_scales) and \
            (seq_len + offset) <= self.cached_scales.shape[0]
        ):
            # NOTE(review): cached rows were centred on the caching call's len(t) // 2.
            # Queries use scale and keys its inverse, so the centre appears to cancel
            # in q·k; only the absolute magnitudes differ.
            return self.cached_scales[offset:(offset + seq_len)]

        scale = 1.
        if self.use_xpos:
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, 'n -> n 1')
            # interleave (s0, s0, s1, s1, ...) to match the pair layout of the angles
            # (repeat '... n -> ... (n r)') and of rotate_half, so both features
            # of a rotated pair share one scale
            scale = repeat(scale, 'n d -> n (d r)', r = 2)

        # the cache is indexed by absolute position, so only store results that start at 0
        if should_cache and offset == 0:
            self.tmp_store('cached_scales', scale)

        return scale

    def get_axial_freqs(self, *dims):
        """Rotation angles for a multi-axis grid of sizes `dims`.

        Positions on each axis are ``linspace(-1, 1)`` for 'pixel' frequencies,
        otherwise ``0 .. dim - 1``. Each axis's angles are broadcast over the
        other axes, and all of them are concatenated on the last dimension.
        The result has shape ``(*dims, len(dims) * 2 * len(self.freqs))``.
        """
        Colon = slice(None)
        all_freqs = []

        for ind, dim in enumerate(dims):
            if self.freqs_for == 'pixel':
                pos = torch.linspace(-1, 1, steps = dim, device = self.device)
            else:
                pos = torch.arange(dim, device = self.device)

            # NOTE(review): these positions are not divided by interpolate_factor, yet
            # they share the angle cache with rotate_queries_or_keys; mixing both with
            # interpolate_factor != 1 may return mismatched cached angles
            freqs = self.forward(pos, seq_len = dim)

            # index with None on every axis except `ind`, so this axis's angles
            # broadcast against the other axes
            all_axis = [None] * len(dims)
            all_axis[ind] = Colon

            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])

        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim = -1)

    @autocast(enabled = False)
    def forward(
        self,
        t: Tensor,
        seq_len = None,
        offset = 0
    ):
        """Rotation angles for positions `t`, shape ``(*t.shape, 2 * len(self.freqs))``.

        Each frequency is repeated for both features of its pair. Autocast is
        disabled so the angles are computed at the frequencies' own precision.

        Args:
            t: positions (cast to the frequencies' dtype).
            seq_len: number of positions. Enables caching when given (and the
                frequencies are neither learned nor 'pixel').
            offset: position of t[0] (as produced by `get_seq_pos`); used to
                index into the cache.
        """
        should_cache = (
            self.cache_if_possible and \
            not self.learned_freq and \
            exists(seq_len) and \
            self.freqs_for != 'pixel'
        )

        if (
            should_cache and \
            exists(self.cached_freqs) and \
            (offset + seq_len) <= self.cached_freqs.shape[0]
        ):
            return self.cached_freqs[offset:(offset + seq_len)].detach()

        freqs = self.freqs

        freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)

        # cache row i must hold the angles for position i. A computation that
        # started at offset > 0 would be stored at row 0 and corrupt later lookups,
        # so only cache results that start at position 0
        if should_cache and offset == 0:
            self.tmp_store('cached_freqs', freqs.detach())

        return freqs

.\lucidrains\rotary-embedding-torch\rotary_embedding_torch\__init__.py

# 从 rotary_embedding_torch.rotary_embedding_torch 模块中导入以下函数和类
from rotary_embedding_torch.rotary_embedding_torch import (
    apply_rotary_emb,  # 应用旋转嵌入的函数
    RotaryEmbedding,   # 旋转嵌入类
    apply_learned_rotations,  # 应用学习到的旋转的函数
    broadcat  # 广播函数
)

.\lucidrains\rotary-embedding-torch\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# package metadata for rotary-embedding-torch
setup(
  # distribution name on PyPI
  name = 'rotary-embedding-torch',
  # include every package found under this directory
  packages = find_packages(),
  # package version
  version = '0.5.3',
  # license
  license='MIT',
  # short description
  description = 'Rotary Embedding - Pytorch',
  # content type of the long description
  # NOTE(review): no `long_description` is passed, so this setting currently has no effect
  long_description_content_type = 'text/markdown',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project homepage
  url = 'https://github.com/lucidrains/rotary-embedding-torch',
  # search keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'positional embedding'    
  ],
  # runtime dependencies
  install_requires=[
    'beartype',
    'einops>=0.7',
    'torch>=2.0'
  ],
  # trove classifiers
  # NOTE(review): torch>=2.0 does not support Python 3.6; this classifier (and the
  # absence of python_requires) looks stale — confirm the supported Python versions
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

Routing Transformer with Deepspeed for Enwik8

DeepSpeed is the framework Microsoft used to train what was, at the time, the world's largest attention model (Turing-NLG, 17 billion parameters). It is open source, and it works with Routing Transformers!

  1. First install Deepspeed following instructions from their official repository https://github.com/microsoft/DeepSpeed

  2. Run the following command in this folder

$ deepspeed train.py --deepspeed --deepspeed_config ds_config.json

.\lucidrains\routing-transformer\examples\enwik8_deepspeed\train.py

import deepspeed  # 导入deepspeed库

from routing_transformer import RoutingTransformerLM  # 从routing_transformer库中导入RoutingTransformerLM类
from routing_transformer.autoregressive_wrapper import AutoregressiveWrapper  # 从routing_transformer.autoregressive_wrapper库中导入AutoregressiveWrapper类

import argparse  # 导入argparse库,用于解析命令行参数
import random  # 导入random库,用于生成随机数
import tqdm  # 导入tqdm库,用于显示进度条
import gzip  # 导入gzip库,用于处理gzip文件
import numpy as np  # 导入numpy库,用于数值计算
import torch  # 导入torch库,用于构建神经网络
import torch.optim as optim  # 导入torch.optim库,用于优化器
from torch.nn import functional as F  # 从torch.nn库中导入functional模块
from torch.utils.data import DataLoader, Dataset  # 从torch.utils.data库中导入DataLoader和Dataset类

def add_argument():
    """Parse the command line for the enwik8 DeepSpeed example.

    Registers this script's own flags, lets DeepSpeed add its configuration
    flags, then parses ``sys.argv``.

    Returns:
        argparse.Namespace holding the parsed arguments.
    """
    parser = argparse.ArgumentParser(description='enwik8')

    # (flags, keyword options) for each script-specific argument, in registration order
    script_args = (
        (('--with_cuda',), dict(default=False, action='store_true',
                                help='use CPU in case there\'s no GPU support')),
        (('--use_ema',), dict(default=False, action='store_true',
                              help='whether use exponential moving average')),
        (('-b', '--batch_size'), dict(default=32, type=int,
                                      help='mini-batch size (default: 32)')),
        (('-e', '--epochs'), dict(default=30, type=int,
                                  help='number of total epochs (default: 30)')),
        (('--local_rank',), dict(type=int, default=-1,
                                 help='local rank passed from distributed launcher')),
    )
    for flags, options in script_args:
        parser.add_argument(*flags, **options)

    # let DeepSpeed register its own options (e.g. --deepspeed, --deepspeed_config)
    parser = deepspeed.add_config_arguments(parser)
    return parser.parse_args()

# constants

VALIDATE_EVERY  = 100  # run a validation pass every 100 training iterations
GENERATE_EVERY  = 500  # sample text every 500 training iterations (rank 0 only)
GENERATE_LENGTH = 1024  # number of tokens to generate per sample
SEQ_LEN = 4096  # training sequence length (also the model's max_seq_len)

# helpers

def decode_token(token):
    """Turn one byte value into a character; values below 32 (control codes) become a space."""
    code = token if token > 32 else 32
    return str(chr(code))

def decode_tokens(tokens):
    """Decode an iterable of byte values into a string, one character per token."""
    return ''.join(decode_token(tok) for tok in tokens)

# instantiate model: a byte-level (256-token) causal Routing Transformer LM

model = RoutingTransformerLM(  # language model over byte tokens
    num_tokens = 256,  # one token per byte value (data is read as uint8)
    dim = 512,  # model dimension
    depth = 8,  # number of layers
    max_seq_len = SEQ_LEN,  # longest sequence the model accepts
    heads = 8,  # attention heads per layer
    causal = True,  # autoregressive (causal) attention
    window_size = 128,  # target window size of each k-means cluster
    reversible = True,  # reversible layers for memory savings
    ff_chunks = 2,  # feed-forward chunking
    attn_dropout = 0.1,  # dropout after attention
    rel_pos_emb = False,  # disable relative positional embedding
    n_local_attn_heads = (8, 8, 8, 8, 4, 4, 2, 2)  # local attention heads; presumably one entry per layer (length matches depth)
)

model = AutoregressiveWrapper(model)  # wrapper providing return_loss=... and generate()
model.cuda()  # move the model to the GPU

# prepare enwik8 data: read the first 95M bytes, split into 90M for training
# and the rest (up to 5M) for validation

with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring's binary mode is deprecated in favour of np.frombuffer.
    # frombuffer returns a read-only view of the bytes, so copy it to get a
    # writable array that torch.from_numpy can wrap without warnings.
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    """Dataset of random contiguous windows of ``seq_len + 1`` tokens.

    The index passed to ``__getitem__`` is ignored: every access draws a fresh
    random start position. Each item is ``(tokens, mask)``, where `tokens` is an
    int64 tensor and `mask` is an all-True bool tensor of the same shape.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data, self.seq_len = data, seq_len

    def __getitem__(self, index):
        window = self.seq_len + 1
        # torch.randint's upper bound is exclusive, so the window always fits
        start = torch.randint(0, self.data.size(0) - window, (1,))
        tokens = self.data[start:start + window].long()
        mask = torch.ones_like(tokens).bool()
        return tokens, mask

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)  # training windows
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)  # validation windows

# setup deepspeed

cmd_args = add_argument()  # script flags plus DeepSpeed's config flags
# wrap the model with DeepSpeed; passing training_data also yields a loader over train_dataset
model_engine, optimizer, trainloader, _ = deepspeed.initialize(args=cmd_args, model=model, model_parameters=model.parameters(),  training_data=train_dataset)

# training

for i, (data, mask) in enumerate(trainloader):
    model_engine.train()

    data = data.to(model_engine.local_rank)  # move the batch to this process's device
    # random truncation helps the model generalize to shorter sequence lengths
    loss = model_engine(data, return_loss = True, randomly_truncate_sequence = True)
    model_engine.backward(loss)
    model_engine.step()
    print(loss.item())

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            inp, _ = random.choice(val_dataset)  # one random validation window
            loss = model(inp[None, :].cuda(), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i != 0 and model_engine.local_rank == 0 and i % GENERATE_EVERY == 0:
        model.eval()
        inp, _ = random.choice(val_dataset)
        print(inp.shape, inp)
        prime = decode_tokens(inp)
        # print the primer text followed by a separator line; the original passed a
        # %-style template plus a tuple to print(), which never substituted anything
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp.cuda(), GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\routing-transformer\examples\enwik8_simple\train.py

# 导入所需的库和模块
from routing_transformer import RoutingTransformerLM
from routing_transformer.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants
NUM_BATCHES = int(1e5)  # number of optimizer steps in the training loop
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4  # batches whose gradients are accumulated before each optimizer step
LEARNING_RATE = 1e-4  # Adam learning rate
VALIDATE_EVERY  = 100  # validate every this many steps
GENERATE_EVERY  = 500  # sample text every this many steps
GENERATE_LENGTH = 512  # number of tokens to generate per sample
SEQ_LEN = 4096  # training sequence length (also the model's max_seq_len)

# helpers

def decode_token(token):
    """Map one byte value to a character; control codes (< 32) become a space."""
    code = token if token > 32 else 32
    return str(chr(code))

def decode_tokens(tokens):
    """Decode an iterable of byte values into a string."""
    return ''.join(decode_token(tok) for tok in tokens)

# instantiate the model: a byte-level (256-token) causal Routing Transformer

model = RoutingTransformerLM(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    max_seq_len = SEQ_LEN,
    heads = 8,
    causal = True,
    window_size = 128,
    n_local_attn_heads = (8, 8, 8, 4, 4, 4)
)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data: the first 95M bytes, 90M for training and the rest for validation

with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring's binary mode is deprecated in favour of np.frombuffer.
    # frombuffer returns a read-only view of the bytes, so copy it to get a
    # writable array that torch.from_numpy can wrap without warnings.
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# dataset that samples random fixed-length windows from a 1-d token tensor
class TextSamplerDataset(Dataset):
    """Random-window dataset.

    Each access returns ``seq_len + 1`` consecutive tokens (int64, moved to the
    GPU) starting at a random position; the index argument is ignored.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        return self.data.size(0) // self.seq_len

    def __getitem__(self, index):
        span = self.seq_len + 1
        begin = torch.randint(0, self.data.size(0) - span, (1,))
        return self.data[begin:begin + span].long().cuda()

# training / validation datasets and endless loaders over them

def cycle(loader):
    """Yield batches from `loader` forever, restarting it each time it is exhausted."""
    while True:
        yield from loader

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer (named `optimizer` so it does not shadow the `torch.optim as optim` import)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # accumulate gradients over several batches before each optimizer step
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    optimizer.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # print the primer text followed by a separator line
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

.\lucidrains\routing-transformer\examples\toy_tasks\enc_dec_copy_task.py

# 导入必要的库
import tqdm
import torch
import torch.optim as optim

# 导入自定义的模型类RoutingTransformerEncDec

from routing_transformer import RoutingTransformerEncDec

# constants

NUM_BATCHES = int(1e5)  # total number of training batches
BATCH_SIZE = 32  # examples per batch
LEARNING_RATE = 1e-4  # Adam learning rate
GENERATE_EVERY  = 100  # run a generation check every this many batches
NUM_TOKENS = 256 + 2  # vocabulary size: ids 0 and 1 are reserved (1 is the start token); data ids are 2..257
ENC_SEQ_LEN = 128  # encoder (source) sequence length
DEC_SEQ_LEN = 256  # decoder max sequence length

# helpers

def cycle():
    """Yield random copy-task batches on the GPU, forever.

    Each batch is ``(src, tgt, src_mask, tgt_mask)``:
        src      -- (BATCH_SIZE, ENC_SEQ_LEN) random ids in [2, NUM_TOKENS)
        tgt      -- start token 1 followed by `src` twice along the sequence axis
        src_mask -- all-True bool mask shaped like `src`
        tgt_mask -- all-True bool mask shaped like `tgt`
    """
    while True:
        start = torch.ones((BATCH_SIZE, 1), dtype=torch.long).cuda()
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        tgt = torch.cat((start, src.repeat(1, 2)), dim=1)
        src_mask = torch.ones_like(src, dtype=torch.bool)
        tgt_mask = torch.ones_like(tgt, dtype=torch.bool)
        yield src, tgt, src_mask, tgt_mask

# instantiate the encoder-decoder model (encoder and decoder share NUM_TOKENS)

model = RoutingTransformerEncDec(
    dim=512,
    enc_num_tokens=NUM_TOKENS,
    enc_depth=3,
    enc_heads=8,
    enc_max_seq_len=ENC_SEQ_LEN,
    enc_window_size=32,
    dec_num_tokens = NUM_TOKENS,
    dec_depth = 3,
    dec_heads = 8,
    dec_max_seq_len=DEC_SEQ_LEN,
    dec_window_size=32,
).cuda()

# optimizer
# NOTE(review): this rebinds `optim`, shadowing the `torch.optim as optim` module import

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training loop

# create the batch generator once instead of building a new one for every batch
batches = cycle()

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask, tgt_mask = next(batches)
    # with return_loss=True the model returns (loss, aux_loss); only the main loss is used here
    loss, _ = model(src, tgt, enc_input_mask=src_mask, dec_input_mask=tgt_mask, return_loss = True, randomly_truncate_sequence = True)
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    # every GENERATE_EVERY batches, check how well the model copies a source sequence
    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask, _ = next(batches)
        # keep only the first example of the batch
        src, src_mask = src[0:1], src_mask[0:1]
        # start token id 1, the same prefix cycle() puts in front of every target
        start_tokens = torch.ones((1, 1)).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, enc_input_mask=src_mask)
        # number of positions where the generated copy differs from the source;
        # a bool mask can be summed directly, so no abs() is needed
        incorrects = (src != sample).sum()

        print(f"input:  ", src)
        print(f"predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

.\lucidrains\routing-transformer\examples\toy_tasks\increment.py

# 导入所需的库
import torch
import numpy as np
import math
import time
import random
from torch.optim import Adam
from routing_transformer.routing_transformer import RoutingTransformerLM
from routing_transformer.autoregressive_wrapper import AutoregressiveWrapper

# causal Routing Transformer LM; the 4 extra ids beyond 256 cover the special tokens
# defined below (0 = pad / ignored, 1 = bos, 2 = eos, 3 = separator)
s = RoutingTransformerLM(
    num_tokens = 256 + 4,
    dim = 1024,
    depth = 2,
    heads = 8,
    max_seq_len = 256,
    causal = True,
    window_size = 128
).cuda()

# wrap for autoregressive training; id 0 is used for padding and ignored by the loss
s = AutoregressiveWrapper(s, ignore_index = 0, pad_value = 0)
# Adam optimizer over the wrapped model's parameters
opt = Adam(s.parameters(), lr=1e-4)

# batch size and source / target sequence lengths
# NOTE(review): TGT_SEQ_LEN is not used anywhere in this script
N_BATCH = 32
SRC_SEQ_LEN = 128
TGT_SEQ_LEN = 128

# special-token columns: begin-of-sequence (1), end-of-sequence (2) and a
# separator (3) placed between the input and the output (not a positional encoding)
bos = 1*torch.ones(N_BATCH, 1).long()
eos = 2*torch.ones(N_BATCH, 1).long()
pos = 3*torch.ones(N_BATCH, 1).long()

# training loop: learn to emit every input token incremented by one
for i in range(10000):
    # random input tokens drawn from {4, 5}; the target is each token plus one
    train_seq_in = torch.randint(4, 6, (N_BATCH, SRC_SEQ_LEN - 2)).long()
    train_seq_out = train_seq_in + 1

    # layout: bos | input | separator x3 | output | eos
    separator = pos.repeat(1, 3)
    train_seq = torch.cat((bos, train_seq_in, separator, train_seq_out, eos), dim=1).cuda()

    # forward, backward, optimizer step, then clear the gradients
    loss = s(train_seq, return_loss = True)
    loss.backward()
    opt.step()
    opt.zero_grad()

    print(i, loss.item())

Routing Transformer

PyPI version

A fully featured implementation of Routing Transformer. The paper proposes using k-means to route similar queries / keys into the same cluster for attention.

Open In Colab 131k tokens

Install

$ pip install routing_transformer

Usage

A simple language model

import torch
from routing_transformer import RoutingTransformerLM

model = RoutingTransformerLM(
    num_tokens = 20000,
    dim = 512,
    heads = 8,
    depth = 12,
    max_seq_len = 8192,
    causal = True,           # auto-regressive or not
    emb_dim = 128,           # embedding factorization, from Albert
    weight_tie = False,      # weight tie layers, from Albert
    tie_embedding = False,   # multiply final embeddings with token weights for logits
    dim_head = 64,           # be able to fix the dimension of each head, making it independent of the embedding dimension and the number of heads
    attn_dropout = 0.1,      # dropout after attention
    attn_layer_dropout = 0., # dropout after self attention layer
    ff_dropout = 0.1,        # feedforward dropout
    layer_dropout = 0.,      # layer dropout
    window_size = 128,       # target window size of each cluster
    n_local_attn_heads = 4,  # number of local attention heads
    reversible = True,       # reversible networks for memory savings, from Reformer paper
    ff_chunks = 10,          # feed forward chunking, from Reformer paper
    ff_glu = True,           # use GLU variant in feedforward
    pkm_layers = (4, 7),     # specify layers to use product key memory. paper shows 1 or 2 modules near the middle of the transformer is best
    pkm_num_keys = 128,      # defaults to 128, but can be increased to 256 or 512 as memory allows
    moe_layers = (3, 6),     # specify which layers to use mixture of experts
    moe_num_experts = 4,     # number of experts in the mixture of experts layer, defaults to 4. increase for adding more parameters to model
    moe_loss_coef = 1e-2,    # the weight for the auxiliary loss in mixture of experts to keep expert usage balanced
    num_mem_kv = 8,          # number of memory key/values to append to each cluster of each head, from the 'All-Attention' paper. defaults to 1 in the causal case for unshared QK to work
    use_scale_norm = False,  # use scale norm, simplified normalization from 'Transformers without Tears' paper
    use_rezero = False,      # use Rezero with no normalization
    shift_tokens = True      # shift tokens by one along sequence dimension, for a slight improvement in convergence
).cuda()

x = torch.randint(0, 20000, (1, 8192)).long().cuda()
input_mask = torch.ones_like(x).bool().cuda()

y, aux_loss = model(x, input_mask = input_mask) # (1, 8192, 20000)
aux_loss.backward() # add auxiliary loss to main loss before backprop

A simple transformer

import torch
from routing_transformer import RoutingTransformer

model = RoutingTransformer(
    dim = 512,
    heads = 8,
    depth = 12,
    max_seq_len = 8192,
    window_size = 128,
    n_local_attn_heads = 4
).cuda()

x = torch.randn(1, 8192, 512).cuda()
input_mask = torch.ones(1, 8192).bool().cuda()

y, aux_loss = model(x, input_mask = input_mask) # (1, 8192, 512)
aux_loss.backward() # add auxiliary loss to main loss before backprop

Encoder Decoder

To use a full encoder-decoder, simply import the RoutingTransformerEncDec class. Except for the dim keyword, all other keywords must be prefixed with enc_ or dec_ to configure the encoder and decoder RoutingTransformerLM respectively.

import torch
from routing_transformer import RoutingTransformerEncDec

model = RoutingTransformerEncDec(
    dim=512,
    enc_num_tokens = 20000,
    enc_depth = 4,
    enc_heads = 8,
    enc_max_seq_len = 4096,
    enc_window_size = 128,
    dec_num_tokens = 20000,
    dec_depth = 4,
    dec_heads = 8,
    dec_max_seq_len = 4096,
    dec_window_size = 128,
    dec_reversible = True
).cuda()

src = torch.randint(0, 20000, (1, 4096)).cuda()
tgt = torch.randint(0, 20000, (1, 4096)).cuda()
src_mask = torch.ones_like(src).bool().cuda()
tgt_mask = torch.ones_like(tgt).bool().cuda()

loss, aux_loss = model(src, tgt, enc_input_mask = src_mask, dec_input_mask = tgt_mask, return_loss = True, randomly_truncate_sequence = True)
loss.backward()
aux_loss.backward()

# do your training, then to sample up to 2048 tokens based on the source sequence
src = torch.randint(0, 20000, (1, 4096)).cuda()
start_tokens = torch.ones(1, 1).long().cuda() # assume starting token is 1

sample = model.generate(src, start_tokens, seq_len = 2048, eos_token = 2) # (1, <= 2048, 20000)

Product Key Memory

To see the benefits of using PKM, the learning rate of the values must be set higher than the rest of the parameters. (Recommended to be 1e-2)

You can follow the instructions here to set it correctly https://github.com/lucidrains/product-key-memory#learning-rates

Kmeans Hyperparameters

  1. kmeans_ema_decay = {defaults to 0.999}

This is the exponential moving average decay for updating the k-means. The lower this is, the faster the means will adjust, but at the cost of stability.

  2. commitment_factor = {defaults to 1e-4}

The weight of the auxiliary loss that encourages tokens to get closer (commit) to the k-mean centroids that were chosen for them.

Updating kmeans manually

The following instructions will allow you to update the kmeans manually. By default the kmeans are updated automatically on every backward pass.

import torch
from routing_transformer import RoutingTransformerLM, AutoregressiveWrapper

model = RoutingTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 6,
    window_size = 256,
    max_seq_len = 8192,
    causal = True,
    _register_kmeans_update = False # set to False to disable auto-updating
)

model = AutoregressiveWrapper(model)

x = torch.randint(0, 20000, (1, 8192))
loss = model(x, return_loss = True)
loss.backward()

# update kmeans with this call
model.update_kmeans()

Issues

This architecture has trouble generalizing to shorter sequence lengths when decoding tokens from a single token up to the maximum sequence length. The simplest and surest solution is to randomly truncate sequences during training. This helps the network and the k-means generalize to a variable number of tokens, at the cost of longer training.

If you are priming the network with the full sequence length at start, then you will not face this problem, and you can skip this training procedure.

import torch
from routing_transformer import RoutingTransformerLM, AutoregressiveWrapper

model = RoutingTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    window_size = 256,
    max_seq_len = 8192,
    causal = True
)

model = AutoregressiveWrapper(model)

x = torch.randint(0, 20000, (1, 8192))
loss = model(x, return_loss = True, randomly_truncate_sequence = True) # (1, 8192, 20000)

Appreciation

Special thanks to Aran Komatsuzaki for bootstrapping the initial implementation in Pytorch that evolved into this library.

Citation

@misc{roy*2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://arxiv.org/pdf/2003.05997.pdf}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}    
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@inproceedings{fan2020reducing,
    title     ={Reducing Transformer Depth on Demand with Structured Dropout},
    author    ={Angela Fan and Edouard Grave and Armand Joulin},
    booktitle ={International Conference on Learning Representations},
    year      ={2020},
    url       ={https://openreview.net/forum?id=SylO2yStDr}
}
@misc{lan2019albert,
    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year        = {2019},
    url         = {https://arxiv.org/abs/1909.11942}
}
@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
@article{DBLP:journals/corr/abs-1907-01470,
    author    = {Sainbayar Sukhbaatar and
               Edouard Grave and
               Guillaume Lample and
               Herv{\'{e}} J{\'{e}}gou and
               Armand Joulin},
    title     = {Augmenting Self-attention with Persistent Memory},
    journal   = {CoRR},
    volume    = {abs/1907.01470},
    year      = {2019},
    url       = {http://arxiv.org/abs/1907.01470}
}
@misc{bhojanapalli2020lowrank,
    title   = {Low-Rank Bottleneck in Multi-head Attention Models},
    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
    year    = {2020},
    eprint  = {2002.07028}
}
@article{1910.05895,
    author  = {Toan Q. Nguyen and Julian Salazar},
    title   = {Transformers without Tears: Improving the Normalization of Self-Attention},
    year    = {2019},
    eprint  = {arXiv:1910.05895},
    doi     = {10.5281/zenodo.3525484},
}
@misc{bachlechner2020rezero,
    title   = {ReZero is All You Need: Fast Convergence at Large Depth},
    author  = {Thomas Bachlechner and Bodhisattwa Prasad Majumder and Huanru Henry Mao and Garrison W. Cottrell and Julian McAuley},
    year    = {2020},
    url     = {https://arxiv.org/abs/2003.04887}
}
@misc{vaswani2017attention,
    title   = {Attention Is All You Need},
    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year    = {2017},
    eprint  = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@software{peng_bo_2021_5196578,
    author       = {PENG Bo},
    title        = {BlinkDL/RWKV-LM: 0.01},
    month        = {aug},
    year         = {2021},
    publisher    = {Zenodo},
    version      = {0.01},
    doi          = {10.5281/zenodo.5196578},
    url          = {https://doi.org/10.5281/zenodo.5196578}
}

.\lucidrains\routing-transformer\routing_transformer\autopadder.py

# 导入数学库和 PyTorch 库
import math
import torch
# 从 torch 模块中导入 nn 模块
from torch import nn
# 从 routing_transformer 模块中导入 RoutingTransformer 类
from routing_transformer.routing_transformer import RoutingTransformer
# 从 torch.nn.functional 模块中导入 F 别名
import torch.nn.functional as F

# Locate the first sub-module of a given type.
def find_module(nn_module, type):
    """Return the first module in `nn_module.modules()` that is an instance of `type`.

    `nn_module.modules()` yields `nn_module` itself first, then its descendants.
    Returns None when no module matches.
    """
    matches = (candidate for candidate in nn_module.modules() if isinstance(candidate, type))
    return next(matches, None)

# Pad a tensor along one dimension up to the next multiple of a given size.
def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    """Right-pad `tensor` along `dim` so its length is a multiple of `multiple`.

    Args:
        tensor: tensor to pad.
        multiple: the length along `dim` is rounded up to a multiple of this.
        dim: dimension to pad; negative and positive indices are both supported.
        value: fill value for the padded positions.

    Returns:
        `tensor` itself when no padding is needed, otherwise a new padded tensor.
    """
    seqlen = tensor.shape[dim]
    # integer modulo is exact; float division can misreport divisibility for huge lengths
    remainder = seqlen % multiple
    if remainder == 0:
        return tensor

    padding = multiple - remainder
    # F.pad takes (left, right) pairs starting from the LAST dimension, so emit one
    # (0, 0) pair for every dimension to the right of `dim`. Normalizing `dim`
    # first makes positive indices work for tensors of any rank.
    dims_from_right = tensor.ndim - 1 - (dim % tensor.ndim)
    pre_pad_offset = (0, 0) * dims_from_right
    return F.pad(tensor, (*pre_pad_offset, 0, padding), value=value)

# Wrapper that pads inputs to the length multiple required by a RoutingTransformer.
class Autopadder(nn.Module):
    """Pad inputs to the wrapped RoutingTransformer's `pad_to_multiple` and crop outputs back.

    Args:
        net: a module that contains a `RoutingTransformer`; that module's
            `pad_to_multiple` attribute sets the padding multiple (a value
            <= 0 disables padding).

    Raises:
        ValueError: if `net` contains no `RoutingTransformer` module.
    """
    def __init__(self, net):
        super().__init__()
        transformer = find_module(net, RoutingTransformer)
        if transformer is None:
            # fail loudly here instead of with an AttributeError on None below
            raise ValueError('Autopadder requires a network containing a RoutingTransformer module')
        self.net = net
        self.pad_multiple = transformer.pad_to_multiple

    def forward(self, x, **kwargs):
        """Pad `x` (and its `input_mask`) along dim 1, run the net, crop the output.

        `x` is unpacked as `(b, t)`, so it must be 2-D (presumably token ids).
        When padding is active the wrapped net is expected to return an
        `(out, loss)` pair; `out` is cropped back to the original length `t`.
        """
        # padding disabled: pass straight through
        if self.pad_multiple <= 0:
            return self.net(x, **kwargs)

        b, t, device = *x.shape, x.device

        # without a caller-supplied mask every original position is valid
        input_mask = kwargs.get('input_mask')
        if input_mask is None:
            input_mask = torch.full((b, t), True, device=device, dtype=torch.bool)

        # padded positions are masked out (value=False)
        x = pad_to_multiple(x, self.pad_multiple, dim=1)
        new_mask = pad_to_multiple(input_mask, self.pad_multiple, dim=1, value=False)
        kwargs.update(input_mask=new_mask)

        # run the network and drop the outputs for padded positions
        out, loss = self.net(x, **kwargs)
        return out[:, 0:t], loss

.\lucidrains\routing-transformer\routing_transformer\autoregressive_wrapper.py

# 导入所需的库
from functools import partial
import torch
import random
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from routing_transformer.routing_transformer import RoutingTransformerLM
from routing_transformer.autopadder import Autopadder

# Fallback helper.
def default(value, default):
    """Return `value`, or `default` when `value` is None."""
    if value is None:
        return default
    return value

# Nucleus (top-p) logit filtering.
def top_p(logits, thres = 0.9):
    """Keep the highest-probability logits until their cumulative mass exceeds `1 - thres`.

    Indexing fixes `logits` as 2-D `(batch, vocab)`. Filtered entries become
    `-inf`. The first token that crosses the threshold is kept, so at least one
    logit always survives. Returns a new tensor in the original vocabulary order.
    """
    desc_logits, desc_idx = logits.sort(dim = -1, descending = True)
    cumulative = desc_logits.softmax(dim = -1).cumsum(dim = -1)
    exceeded = cumulative > 1.0 - thres

    # shift right by one: drop only tokens that come AFTER the threshold crossing
    remove = torch.zeros_like(exceeded)
    remove[:, 1:] = exceeded[:, :-1]

    desc_logits = desc_logits.masked_fill(remove, float('-inf'))
    # undo the sort by writing each value back to its original vocabulary index
    return desc_logits.scatter(1, desc_idx, desc_logits)

# Top-k logit filtering.
def top_k(logits, thres = 0.9):
    """Keep the top `(1 - thres)` fraction of logits (at least one); set the rest to `-inf`.

    Args:
        logits: 2-D `(batch, vocab)` logits (the scatter along dim 1 requires this).
        thres: fraction of the vocabulary to discard.

    Returns:
        A new tensor of the same shape with non-top-k entries set to `-inf`.
    """
    # clamp to 1: k == 0 would mask every logit, and softmax would then yield NaNs
    k = max(int((1 - thres) * logits.shape[-1]), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# Right-pad a list of 1-D tensors to a common length and stack them.
def pad_sequence_right(seqs, value):
    """Right-pad every tensor in `seqs` with `value` to the longest length, then stack.

    Args:
        seqs: non-empty sequence of 1-D tensors.
        value: fill value for the padded positions (previously ignored; padding
            was always 0).

    Returns:
        A 2-D tensor of shape `(len(seqs), max_len)`.
    """
    max_len = max(len(s) for s in seqs)
    return torch.stack([F.pad(s, (0, max_len - len(s)), value = value) for s in seqs])

# Random-length truncation used for training on variable sequence lengths.
def truncate_sequence(inputs, mask = None, pad_value=0):
    """Cut `inputs` (and its mask) to a random length drawn uniformly from [2, t].

    `inputs` must be 2-D `(batch, t)`. When `mask` is None an all-True mask is
    created. `pad_value` is accepted for API compatibility and is not used.
    Returns `(truncated_inputs, truncated_mask)`.
    """
    _, seq_len = inputs.shape
    if mask is None:
        mask = torch.ones_like(inputs).bool()
    cut = random.randint(2, seq_len)
    return inputs[:, :cut], mask[:, :cut]

# Autoregressive training / sampling wrapper.
class AutoregressiveWrapper(nn.Module):
    """Wrap a RoutingTransformerLM for next-token training and autoregressive sampling.

    The LM is wrapped in an `Autopadder`, so inputs are padded to the
    transformer's required length multiple and outputs are cropped back.

    Args:
        net: the `RoutingTransformerLM` to wrap.
        ignore_index: target id excluded from the loss; defaults to `pad_value`.
        pad_value: value used when padding lists of variable-length sequences.
    """
    def __init__(self, net, ignore_index = None, pad_value = 0):
        super().__init__()
        assert isinstance(net, RoutingTransformerLM), 'generative trainer wrapper can only accept RoutingTransformerLM class'
        self.pad_value = pad_value
        self.ignore_index = default(ignore_index, pad_value)

        self.net = Autopadder(net)
        self.max_seq_len = net.max_seq_len
        self.base_net = net

    def update_kmeans(self):
        """Manually trigger the k-means update on the underlying LM."""
        self.base_net.update_kmeans()

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        """Sample up to `seq_len` tokens after `start_tokens`.

        Args:
            start_tokens: prompt ids, 1-D `(t,)` or 2-D `(batch, t)`.
            seq_len: maximum number of tokens to sample.
            eos_token: stop early once every sequence's latest sample equals it.
            temperature: softmax temperature applied after filtering.
            filter_logits_fn: logit filter, called as `fn(logits, thres=filter_thres)`.
            filter_thres: threshold passed to `filter_logits_fn`.
            **kwargs: forwarded to the network; an optional `input_mask` for the
                prompt is popped and extended as tokens are sampled.

        Returns:
            Only the newly sampled tokens, with the same rank as `start_tokens`.
        """
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        _, t = start_tokens.shape

        self.net.eval()
        try:
            out = start_tokens
            input_mask = kwargs.pop('input_mask', None)

            if input_mask is None:
                input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

            for _ in range(seq_len):
                # condition on at most the last max_seq_len tokens
                x = out[:, -self.max_seq_len:]
                input_mask = input_mask[:, -self.max_seq_len:]
                logits, _ = self.net(x, input_mask=input_mask, **kwargs)
                logits = logits[:, -1, :]
                filtered_logits = filter_logits_fn(logits, thres = filter_thres)
                probs = F.softmax(filtered_logits / temperature, dim=-1)
                sample = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)
                # the new token is appended on the RIGHT, so its mask entry must be
                # too; padding on the left misaligned any caller-supplied mask
                input_mask = F.pad(input_mask, (0, 1), value=True)
                if eos_token is not None and (sample == eos_token).all():
                    break
        finally:
            # restore the caller's train/eval mode even if sampling raises
            self.net.train(was_training)

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        return out

    def forward(self, x, return_loss = False, randomly_truncate_sequence = False, **kwargs):
        """Run the network, or compute the next-token training loss.

        Args:
            x: token ids as a tensor, or a list of 1-D tensors that is padded
                with `pad_value` via `pad_sequence`.
            return_loss: if False, return the wrapped network's output unchanged.
            randomly_truncate_sequence: in loss mode, truncate the batch to a
                random length first (see `truncate_sequence`).
            **kwargs: forwarded to the network; an `input_mask` is shifted to
                align with the shifted inputs.

        Returns:
            The network output when `return_loss` is False; otherwise the
            cross-entropy loss plus the network's auxiliary loss.
        """
        pad = partial(pad_sequence, batch_first = True, padding_value = self.pad_value)

        if not return_loss:
            if not isinstance(x, torch.Tensor):
                x = pad(x)
            return self.net(x, **kwargs)

        m = kwargs.get('input_mask', None)

        if randomly_truncate_sequence:
            x, m = truncate_sequence(x, m, pad_value = self.pad_value)

        # inputs are positions [0, n-1), targets are positions [1, n)
        if isinstance(x, torch.Tensor):
            xi, xo = x[:, :-1], x[:, 1:]
        else:
            xi = pad(list(map(lambda t: t[:-1], x)))
            xo = pad(list(map(lambda t: t[1:], x)))

        if m is not None:
            assert m.shape == x.shape[0:2], 'input mask must be the same shape as the input of the auto-regressive wrapper to automatically handle'
            kwargs['input_mask'] = m[:, :-1]

        out, aux_loss = self.net(xi, **kwargs)

        # cross_entropy expects the class dimension second: (batch, vocab, seq)
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        loss = loss + aux_loss
        return loss

.\lucidrains\routing-transformer\routing_transformer\encoder_decoder.py

# 导入 re 模块,用于正则表达式操作
import re
# 导入 isfunction 函数,用于检查对象是否为函数
from inspect import isfunction
# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn 模块
from torch import nn
# 从 routing_transformer.routing_transformer 模块中导入 RoutingTransformerLM 类和 update_kmeans_on_backwards 函数
from routing_transformer.routing_transformer import RoutingTransformerLM, update_kmeans_on_backwards
# 从 routing_transformer.autoregressive_wrapper 模块中导入 AutoregressiveWrapper 类
from routing_transformer.autoregressive_wrapper import AutoregressiveWrapper

# Keyword-argument prefixes that route constructor / call options to the
# encoder (`enc_*`) or the decoder (`dec_*`).
ENC_PREFIX, DEC_PREFIX = 'enc_', 'dec_'

# Fallback helper that supports lazily computed defaults.
def default(x, d):
    """Return `x` unless it is None; otherwise return `d`, calling it first if it is a function."""
    if x is not None:
        return x
    return d() if isfunction(d) else d

# Partition a dict by a predicate on its keys.
def group_dict_by_key(cond, d):
    """Split `d` into `(matching, non_matching)` dicts according to `cond(key)`.

    Key order within each dict follows the iteration order of `d`.
    """
    matching, non_matching = {}, {}
    for key in d.keys():
        bucket = matching if cond(key) else non_matching
        bucket[key] = d[key]
    return matching, non_matching

# Literal prefix test.
def string_begins_with(prefix, str):
    """Return True if `str` starts with the literal text `prefix`.

    The previous regex-based check (`re.match(f'^{prefix}', ...)`) treated
    regex metacharacters in `prefix` specially (e.g. `'a.'` matched `'ab'`);
    `str.startswith` compares the text exactly.
    """
    # the parameter name shadows the builtin; it is kept for API compatibility
    text = str
    return text.startswith(prefix)

# Partition a dict by key prefix.
def group_by_key_prefix(prefix, d):
    """Split `d` into `(keys starting with prefix, all other keys)` dicts."""
    def has_prefix(key):
        return string_begins_with(prefix, key)

    return group_dict_by_key(has_prefix, d)

# Partition a dict by key prefix and strip the prefix from the matching keys.
def group_by_key_prefix_and_remove_prefix(prefix, d):
    """Return `(prefixed entries with the prefix removed, all other entries)`."""
    prefixed, rest = group_by_key_prefix(prefix, d)
    stripped = {key[len(prefix):]: val for key, val in prefixed.items()}
    return stripped, rest

# Split keyword arguments into encoder, decoder and shared groups.
def extract_enc_dec_kwargs(kwargs):
    """Split `kwargs` into `(enc_kwargs, dec_kwargs, remaining_kwargs)`.

    Keys prefixed with `ENC_PREFIX` / `DEC_PREFIX` go to the encoder / decoder
    dicts with the prefix stripped; everything else is returned unchanged.
    """
    groups = []
    remaining = kwargs
    for prefix in (ENC_PREFIX, DEC_PREFIX):
        matched, remaining = group_by_key_prefix_and_remove_prefix(prefix, remaining)
        groups.append(matched)
    return (*groups, remaining)

# Split keyword arguments and propagate the encoder mask to the decoder.
def extract_and_set_enc_dec_kwargs(kwargs):
    """Like `extract_enc_dec_kwargs`, but reuse the encoder's `input_mask` as the decoder's `context_mask`.

    A `context_mask` explicitly given to the decoder is never overwritten.
    """
    enc_kwargs, dec_kwargs, kwargs = extract_enc_dec_kwargs(kwargs)
    if 'input_mask' not in enc_kwargs:
        return enc_kwargs, dec_kwargs, kwargs
    if 'context_mask' not in dec_kwargs:
        dec_kwargs['context_mask'] = enc_kwargs['input_mask']
    return enc_kwargs, dec_kwargs, kwargs

# Encoder-decoder model built from two RoutingTransformerLMs.
class RoutingTransformerEncDec(nn.Module):
    """Encoder-decoder pair of RoutingTransformerLMs sharing one model dimension.

    Keyword arguments prefixed with `enc_` / `dec_` configure the encoder /
    decoder. The encoder returns embeddings that the (causal) decoder receives
    as `context`. The per-LM k-means hooks are disabled, and a single hook is
    registered on this module instead (see `register_kmeans_update`).

    Args:
        dim: model dimension used by both encoder and decoder.
        ignore_index: target id excluded from the decoder loss; defaults to `pad_value`.
        pad_value: padding value for decoder sequences.
    """
    def __init__(self, dim, ignore_index = None, pad_value = 0, **kwargs):
        super().__init__()
        ignore_index = default(ignore_index, pad_value)
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        # the flag actually set below is 'return_embeddings'; the singular spelling
        # is also rejected to keep the previous behavior for that key
        assert 'return_embeddings' not in enc_kwargs and 'return_embedding' not in enc_kwargs, 'you cannot manually set the return embeddings flag for the encoder'
        # the dimension comes from the shared `dim` argument, not enc_dim / dec_dim
        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder'

        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['return_embeddings'] = True
        dec_kwargs['causal'] = True
        dec_kwargs['receives_context'] = True
        enc_kwargs['_register_kmeans_update'] = dec_kwargs['_register_kmeans_update'] = False

        enc_kwargs.setdefault('window_size', 256)
        dec_kwargs.setdefault('window_size', 256)

        enc = RoutingTransformerLM(**enc_kwargs)
        dec = RoutingTransformerLM(**dec_kwargs)

        self.enc = enc
        self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index, pad_value = pad_value)

        # With a reversible decoder the encoder auxiliary loss must be backpropagated
        # by the caller; it is returned separately from forward in that case.
        self.dec_reversible = dec_kwargs.get('reversible', False)

        if self.dec_reversible:
            print('Warning! Due to an issue with reversible nets and encoder auxiliary losses, you must explicitly call backwards on the encoder auxiliary loss, which is supplied as the second element of the returned tuple on forward')

        self._handle = None
        self.register_kmeans_update()

    def cancel_kmeans_update(self):
        """Remove the k-means update hook registered by `register_kmeans_update`, if any."""
        if self._handle is None:
            return
        self._handle.remove()
        self._handle = None

    def register_kmeans_update(self):
        """(Re-)register the k-means update hook and return its handle.

        The handle is now stored in `self._handle`. Previously it was discarded,
        so `cancel_kmeans_update` could never remove the hook and every
        re-registration added another one. `cancel_kmeans_update` calls
        `.remove()` on the stored value, so `update_kmeans_on_backwards` is
        assumed to return a removable hook handle.
        """
        self.cancel_kmeans_update()
        self._handle = update_kmeans_on_backwards(self)
        return self._handle

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, max_seq_len = None, **kwargs):
        """Encode `seq_in` and sample a target sequence starting from `seq_out_start`.

        Runs without gradient tracking, so the encoder pass builds no autograd graph.
        `max_seq_len` defaults to the decoder's maximum sequence length. Prefixed
        `enc_` / `dec_` kwargs go to encoder / decoder; the rest go to the decoder.
        """
        max_seq_len = default(max_seq_len, self.dec.max_seq_len)
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        context, _ = self.enc(seq_in, **enc_kwargs)
        return self.dec.generate(seq_out_start, max_seq_len, context = context, **{**dec_kwargs, **kwargs})

    def forward(self, seq_in, seq_out, return_loss = False, randomly_truncate_sequence = False, **kwargs):
        """Encode `seq_in`, then run the decoder on `seq_out` with the encoding as context.

        Returns:
            `(decoder_loss, encoder_aux_loss)` when the decoder is reversible
            (the caller must backpropagate the second element itself); otherwise
            `(decoder_loss + encoder_aux_loss, zero_tensor)`.

        NOTE(review): with `return_loss=False` the decoder wrapper returns the raw
        network output (a tuple from Autopadder), so the addition below presumably
        fails for that case. Confirm the intended non-loss usage.
        """
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        context, enc_aux_loss = self.enc(seq_in, **enc_kwargs)
        loss = self.dec(seq_out, return_loss = return_loss, randomly_truncate_sequence = randomly_truncate_sequence, context = context, aux_loss = enc_aux_loss, **dec_kwargs)

        if self.dec_reversible:
            return loss, enc_aux_loss

        # the encoder aux loss is already folded into `loss`; return a zero
        # placeholder that still participates in autograd
        aux_loss = torch.tensor(0., requires_grad = True)
        loss = loss + enc_aux_loss
        return loss, aux_loss

.\lucidrains\routing-transformer\routing_transformer\reversible.py

import torch
import torch.nn as nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

# Route keyword arguments to the f / g functions of each reversible layer.
def route_args(router, args, depth):
    """Distribute keyword arguments to the f and g halves of each layer.

    Args:
        router: maps a keyword name to a per-layer sequence of `(to_f, to_g)`
            flag pairs.
        args: available keyword arguments; keys absent from `router` are ignored.
        depth: number of layers.

    Returns:
        A list of `depth` `(f_kwargs, g_kwargs)` dict pairs.
    """
    routed_args = [(dict(), dict()) for _ in range(depth)]
    routable = [key for key in args.keys() if key in router]

    for key in routable:
        val = args[key]
        per_layer = zip(routed_args, router[key])
        for layer_ind, ((f_args, g_args), (to_f, to_g)) in enumerate(per_layer):
            f_extra = {key: val} if to_f else {}
            g_extra = {key: val} if to_g else {}
            routed_args[layer_ind] = ({**f_args, **f_extra}, {**g_args, **g_extra})
    return routed_args

def layer_drop(layers, prob):
    """Randomly drop each element of `layers` with probability `prob`.

    Uses one uniform draw per element from torch's global RNG. If every
    element is dropped, `layers[:1]` is returned so at least one survives.
    """
    dropped = torch.empty(len(layers)).uniform_(0, 1) < prob
    survivors = [layer for layer, is_dropped in zip(layers, dropped) if not is_dropped]
    if not survivors:
        return layers[:1]
    return survivors

def cast_return(ret, requires_grad = True):
    """Normalize a layer's return value to an `(output, aux_loss)` pair.

    An exact `tuple` is returned unchanged. Anything else (assumed to be a
    tensor) is paired with a zero scalar loss that has the tensor's device and
    dtype and the requested `requires_grad`.
    """
    if type(ret) is tuple:
        return ret
    zero_loss = torch.tensor(0., device=ret.device, dtype=ret.dtype, requires_grad=requires_grad)
    return ret, zero_loss

# RNG save/restore follows https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    """Wrap `net` so a forward pass can be replayed with identical randomness.

    Call once with `record_rng=True` to snapshot the CPU RNG state (and, when
    CUDA has been initialized, the RNG states of the devices of the tensor
    arguments). A later call with `set_rng=True` restores that snapshot inside
    `torch.random.fork_rng`, so the global RNG is left untouched afterwards.
    """
    def __init__(self, net):
        super().__init__()
        self.net = net
        # RNG snapshot, filled in by record_rng
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        """Save the current CPU RNG state, plus GPU RNG states if CUDA is initialized."""
        self.cpu_state = torch.get_rng_state()
        if not torch.cuda._initialized:
            return
        self.cuda_in_fwd = True
        self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        """Run `net`, optionally recording the RNG state first or replaying a recorded one."""
        if record_rng:
            self.record_rng(*args)

        if set_rng:
            devices = self.gpu_devices if self.cuda_in_fwd else []
            with torch.random.fork_rng(devices=devices, enabled=True):
                torch.set_rng_state(self.cpu_state)
                if self.cuda_in_fwd:
                    set_device_states(self.gpu_devices, self.gpu_states)
                return self.net(*args, **kwargs)

        return self.net(*args, **kwargs)

# Inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# Once multi-GPU works correctly, refactor and send a PR back to the source.
class ReversibleBlock(nn.Module):
    """Reversible residual block: y1 = x1 + f(x2), y2 = x2 + g(y1).

    The input is split in half along dim 2 (so it must have at least 3 dims;
    presumably (batch, seq, 2 * dim), TODO confirm). Because the mapping is
    invertible, `backward_pass` reconstructs the inputs from the outputs instead
    of storing activations. `f` and `g` may return a tensor or an
    `(output, aux_loss)` tuple. Both are wrapped in `Deterministic` so their
    randomness replays identically during the backward pass.
    """
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = None, g_args = None):
        """Compute the block output without building a graph.

        `f_args` / `g_args` are extra keyword arguments for f / g. They are
        copied before `_reverse` is added, so neither the caller's dicts nor a
        shared default dict is mutated (the old `{}` defaults were mutated).

        Returns:
            `(y, f_loss, g_loss)` where `y` is `[y1, y2]` concatenated on dim 2.
        """
        f_args = {**(f_args or {}), '_reverse': False}
        g_args = {**(g_args or {}), '_reverse': False}

        x1, x2 = torch.chunk(x, 2, dim=2)

        with torch.no_grad():
            # record RNG state while training so backward_pass can replay f / g exactly
            f_out, f_loss = cast_return(self.f(x2, record_rng=self.training, **f_args), requires_grad = False)
            y1 = x1 + f_out

            g_out, g_loss = cast_return(self.g(y1, record_rng=self.training, **g_args), requires_grad = False)
            y2 = x2 + g_out

        return torch.cat([y1, y2], dim=2), f_loss, g_loss

    def backward_pass(self, y, dy, dl_f, dl_g, f_args = None, g_args = None):
        """Reconstruct this block's input from its output and backpropagate through it.

        Args:
            y: block output (`[y1, y2]` concatenated on dim 2).
            dy: gradient with respect to `y`.
            dl_f, dl_g: gradients with respect to the auxiliary losses of f and g.
            f_args, g_args: extra keyword arguments for f / g (copied, not mutated).

        Returns:
            `(x, dx)`: the reconstructed input and the gradient with respect to it.
            Parameter gradients of f and g accumulate as a side effect of the
            `torch.autograd.backward` calls.
        """
        f_args = {**(f_args or {}), '_reverse': True}
        g_args = {**(g_args or {}), '_reverse': True}

        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            # replay g(y1) with the RNG state recorded in forward
            gy1, g_loss = cast_return(self.g(y1, set_rng=True, **g_args))
            torch.autograd.backward((gy1, g_loss), (dy2, dl_g))

        with torch.no_grad():
            # invert y2 = x2 + g(y1)
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2, f_loss = cast_return(self.f(x2, set_rng=True, **f_args))
            torch.autograd.backward((fx2, f_loss), (dx1, dl_f), retain_graph=True)

        with torch.no_grad():
            # invert y1 = x1 + f(x2)
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx
class _ReversibleFunction(Function):
    """Autograd function that runs a stack of reversible blocks without storing activations.

    Only the final output is saved. In `backward` each block's input is
    reconstructed from its output via the block's `backward_pass`.
    """
    @staticmethod
    def forward(ctx, x, blocks, args):
        """Apply `blocks` in order; return `(output, stacked f losses, stacked g losses)`.

        `args[i]` holds keyword arguments (`f_args` / `g_args`) for `blocks[i]`.
        """
        ctx.args = args

        f_losses, g_losses = [], []
        for block, block_kwargs in zip(blocks, args):
            x, f_loss, g_loss = block(x, **block_kwargs)
            f_losses.append(f_loss)
            g_losses.append(g_loss)

        # keep only the final activation; earlier ones are rebuilt in backward
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x, torch.stack(f_losses), torch.stack(g_losses)

    @staticmethod
    def backward(ctx, dy, dl_f, dl_g):
        """Walk the blocks in reverse, reconstructing inputs and propagating gradients."""
        y = ctx.y
        args = ctx.args
        num_blocks = len(ctx.blocks)
        walk_back = zip(reversed(range(num_blocks)), ctx.blocks[::-1], args[::-1])
        for ind, block, block_kwargs in walk_back:
            y, dy = block.backward_pass(y, dy, dl_f[ind], dl_g[ind], **block_kwargs)
        # no gradients for the non-tensor `blocks` and `args` inputs
        return dy, None, None

class SequentialSequence(nn.Module):
    """Run (f, g) residual pairs one after another and sum their auxiliary losses.

    Args:
        layers: sequence of `(f, g)` pairs. Each callable returns either an
            output tensor or an `(output, aux_loss)` tuple.
        args_route: maps a keyword name to one `(to_f, to_g)` flag pair per
            layer, deciding which calls receive that keyword.
        layer_dropout: probability of skipping each layer pair while training.
    """
    def __init__(self, layers, args_route = {}, layer_dropout = 0.):
        super().__init__()
        depth = len(layers)
        assert all(len(route) == depth for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        """Return `(output, aux_loss)`, where `aux_loss` is a shape-(1,) tensor summing all layer losses."""
        routed = route_args(self.args_route, kwargs, len(self.layers))
        pairs = list(zip(self.layers, routed))

        if self.training and self.layer_dropout > 0:
            pairs = layer_drop(pairs, self.layer_dropout)

        aux_loss = torch.zeros(1, device=x.device, dtype=x.dtype)

        for (f, g), (f_args, g_args) in pairs:
            # residual update with f, then with g (g sees the f-updated x)
            for fn, fn_args in ((f, f_args), (g, g_args)):
                residual, loss = cast_return(fn(x, **fn_args))
                aux_loss += loss
                x = x + residual
        return x, aux_loss

class ReversibleSequence(nn.Module):
    """Stack of `ReversibleBlock`s evaluated through `_ReversibleFunction`.

    Args:
        blocks: iterable of `(f, g)` pairs, each wrapped in a `ReversibleBlock`.
        args_route: maps a keyword name to one `(to_f, to_g)` flag pair per block.
        layer_dropout: probability of skipping each block while training.
    """
    def __init__(self, blocks, args_route = {}, layer_dropout = 0.):
        super().__init__()
        self.args_route = args_route
        self.layer_dropout = layer_dropout
        self.blocks = nn.ModuleList([ReversibleBlock(f, g) for f, g in blocks])

    def forward(self, x, **kwargs):
        """Return `(output, aux_loss)` where `aux_loss` sums all f and g auxiliary losses."""
        # duplicate the input along the last dim so each block can split it into two streams
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        routed = route_args(self.args_route, kwargs, len(blocks))
        args = [{'f_args': f_args, 'g_args': g_args} for f_args, g_args in routed]

        if self.training and self.layer_dropout > 0:
            kept = layer_drop(list(zip(blocks, args)), self.layer_dropout)
            blocks = [block for block, _ in kept]
            args = [block_args for _, block_args in kept]

        out, f_loss, g_loss = _ReversibleFunction.apply(x, blocks, args)
        # merge the two streams back by averaging them
        halves = out.chunk(2, dim=-1)
        out = torch.stack(halves).mean(dim=0)
        return out, f_loss.sum() + g_loss.sum()
posted @ 2024-06-28 14:03  绝不原创的飞龙  阅读(22)  评论(0编辑  收藏  举报