Lucidrains Series: Source Code Walkthrough (Part 26)

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\memory_efficient_cosine_sim_attention.py

import math
import torch
import torch.nn.functional as F
from functools import partial
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint

from einops import rearrange

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# l2 normalize a tensor along its last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# regular attention

# standard full attention, kept as a reference implementation
def attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    **kwargs
):
    # compute query / key similarities
    sim = einsum('b h i d, b h j d -> b h i j', q, k)

    # add the attention bias if provided
    if exists(attn_bias):
        sim = sim + attn_bias

    mask_value = -torch.finfo(sim.dtype).max

    # apply the padding mask
    if exists(mask):
        mask = rearrange(mask, 'b j -> b 1 1 j')
        sim = sim.masked_fill(~mask, mask_value)

    # apply the causal mask
    if causal:
        i, j = sim.shape[-2:]
        mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
        sim = sim.masked_fill(mask, mask_value)

    # attention weights
    attn = sim.softmax(dim = -1)

    # aggregate the values
    out = einsum('b h i j, b h j d -> b h i d', attn, v)
    return out
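
# illustrative usage sketch (not from the original source): the reference function
# above expects q, k, v shaped (batch, heads, seq_len, dim_head)
#
# q = k = v = torch.randn(2, 8, 1024, 64)
# out = attention(q, k, v, causal = True)   # (2, 8, 1024, 64)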

# memory efficient attention

# summarize one (query chunk, key chunk) pair into unnormalized attention statistics
def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices):
    q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device

    weight = einsum('b h i d, b h j d -> b h i j', q, k)

    if exists(attn_bias_chunk):
        weight = weight + attn_bias_chunk

    mask_value = -torch.finfo(weight.dtype).max

    if exists(mask):
        mask = rearrange(mask, 'b j -> b 1 1 j')
        weight = weight.masked_fill(~mask, mask_value)

    if causal and q_start_index < (k_start_index + k_chunk_size - 1):
        causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
        weight = weight.masked_fill(causal_mask, mask_value)

    exp_weight = weight.exp()
    weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)

    return exp_weight.sum(dim = -1), weighted_value

# checkpointed version of the chunk summarizer, used whenever gradients are needed
checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)

# memory efficient attention (numerically unstable, since exp is taken without subtracting the max)
def numerically_unstable_memory_efficient_attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    q_bucket_size = 512,
    k_bucket_size = 1024,
    eps = 1e-8
):
    needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
    summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    # chunk all the inputs

    q_chunks = q.split(q_bucket_size, dim = -2)
    k_chunks = k.split(k_bucket_size, dim = -2)
    v_chunks = v.split(k_bucket_size, dim = -2)
    mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))

    if exists(attn_bias):
        i, j = attn_bias.shape[-2:]
        attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
        attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks))

    # loop through all chunks and accumulate

    out = []
    for q_index, q_chunk in enumerate(q_chunks):
        q_start_index = q_index * q_bucket_size
        exp_weights = []
        weighted_values = []

        for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
            k_start_index = k_index * k_bucket_size

            # with causal masking, skip key chunks that lie entirely in the future of this query chunk
            if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1):
                continue

            attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None

            # summarize this (query chunk, key chunk) pair into exp-weights and weighted values
            exp_weight_chunk, weighted_value_chunk = summarize_qkv_fn(
                q_chunk,
                k_chunk,
                v_chunk,
                mask_chunk,
                attn_bias_chunk,
                causal,
                (q_start_index, k_start_index)
            )

            exp_weights.append(exp_weight_chunk)
            weighted_values.append(weighted_value_chunk)

        # sum the accumulated statistics and normalize once per query chunk
        all_values = sum(weighted_values)
        all_weights = sum(exp_weights)

        normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
        out.append(normalized_values)

    return torch.cat(out, dim = -2)
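
# sanity check sketch (not from the original source): on random inputs the chunked
# computation should agree with the reference `attention` up to numerical error, since
# each query chunk accumulates exp-weights and weighted values over all key chunks
# before a single normalization
#
# q = k = v = torch.randn(1, 8, 2048, 64)
# out_full    = attention(q, k, v, causal = True)
# out_chunked = numerically_unstable_memory_efficient_attention(q, k, v, causal = True)
# assert torch.allclose(out_full, out_chunked, atol = 1e-5)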

# main class

class CosineSimAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        seq_len,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False,
        memory_efficient = False,
        q_bucket_size = 512,
        k_bucket_size = 1024
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal

        inner_dim = heads * dim_head

        # learned temperature (scale), initialized per the cosine sim attention scheme
        scale_init_value = -math.log(math.log2(seq_len ** 2 - seq_len))
        self.scale = nn.Parameter(torch.full((1, heads, 1, 1), scale_init_value))

        # projections for queries, keys / values, and output
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # memory efficient attention related parameters
        # can be overridden on forward
        self.memory_efficient = memory_efficient
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None,
        memory_efficient = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        memory_efficient = default(memory_efficient, self.memory_efficient)
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)

        # project to queries, keys, values
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        # split out heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # l2 normalize queries and keys
        q, k = map(l2norm, (q, k))

        # scale the queries by the learned temperature
        q = q * self.scale.exp()

        # choose the attention function based on the memory efficient flag
        attn_fn = attention if not memory_efficient else numerically_unstable_memory_efficient_attention

        out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)

        # merge heads back
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
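
# illustrative usage sketch (not from the original source); `seq_len` only sets the
# initialization of the learned temperature, so shorter inputs are fine
#
# attn = CosineSimAttention(dim = 512, seq_len = 8192, causal = True, memory_efficient = True)
# x = torch.randn(1, 4096, 512)
# out = attn(x)   # (1, 4096, 512)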

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\reversible.py

import torch
import torch.nn as nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

# helper for routing arguments into the forward functions of the reversible layers
def route_args(router, args, depth):
    routed_args = [(dict(), dict()) for _ in range(depth)]
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            # for each depth, a pair of booleans decides whether the argument goes to f, to g, or to both
            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
            routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args
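
# illustrative example (not from the original source): a router of
# {'mask': ((True, False),) * depth} sends the `mask` keyword to every f block and
# withholds it from every g block
#
# route_args({'mask': ((True, False),) * 2}, {'mask': mask}, 2)
# # -> [({'mask': mask}, {}), ({'mask': mask}, {})]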

# following the rng saving / restoring example at https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send a PR back to the source
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx

class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, args):
        ctx.args = args
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        args = ctx.args
        # walk the blocks in reverse, reconstructing activations and accumulating gradients
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            y, dy = block.backward_pass(y, dy, **kwargs)
        return dy, None, None

# reversible sequence container
class ReversibleSequence(nn.Module):
    def __init__(self, blocks, args_route = {}):
        super().__init__()
        self.args_route = args_route
        self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])

    def forward(self, x, **kwargs):
        # duplicate the input along the feature dimension for the two reversible streams
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        args = route_args(self.args_route, kwargs, len(blocks))
        args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))

        out = _ReversibleFunction.apply(x, blocks, args)
        # split the two streams back out and sum them
        return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)
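
# illustrative usage sketch (not from the original source): the input is duplicated
# along the feature dimension on the way in and the two streams are summed on the way
# out, so the module maps (b, n, dim) -> (b, n, dim) while each (f, g) pair sees the
# two halves of a (b, n, 2 * dim) tensor
#
# blocks = [(nn.Linear(512, 512), nn.Linear(512, 512))]   # stand-ins for attention / feedforward
# seq = ReversibleSequence(blocks)
# out = seq(torch.randn(1, 1024, 512))   # (1, 1024, 512)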

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\transformer.py

import torch
from torch import nn, einsum
import torch.nn.functional as F
from functools import partial
from einops import rearrange
from memory_efficient_attention_pytorch import FlashAttention, Attention
from memory_efficient_attention_pytorch.reversible import ReversibleSequence

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# pre-layernorm wrapper
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        # normalize, then apply the wrapped function
        x = self.norm(x)
        return self.fn(x, **kwargs)

# feedforward with optional chunking along the sequence dimension
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, chunks = 1):
        super().__init__()
        self.chunks = chunks

        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        # with a single chunk, run the whole sequence at once
        if self.chunks <= 1:
            return self.net(x)

        # otherwise process the sequence in chunks to save memory, then concatenate
        chunks = x.chunk(self.chunks, dim = 1)
        out = [self.net(chunk) for chunk in chunks]
        return torch.cat(out, dim = 1)

class Transformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        dim,
        depth,
        causal = False,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        ff_chunks = 1,
        use_flash_attn = True,
        **kwargs
    ):
        super().__init__()
        self.max_seq_len = max_seq_len

        # token and position embeddings
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # pick the attention class based on use_flash_attn
        attn_klass = FlashAttention if use_flash_attn else partial(Attention, memory_efficient = True)

        # build depth layers, each an attention block followed by a feedforward block
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, attn_klass(dim = dim, dim_head = dim_head, heads = heads, causal = causal, **kwargs)),
                PreNorm(dim, FeedForward(dim = dim, mult = ff_mult, chunks = ff_chunks)),
            ]))

        # wrap the layers in a reversible sequence
        self.net = ReversibleSequence(self.layers)

        # project to logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(self, x, labels = None):
        device = x.device
        x = self.token_emb(x)

        # add positional embeddings
        pos_emb = self.pos_emb(torch.arange(x.shape[-2], device = device))
        x = x + pos_emb

        x = self.net(x)

        logits = self.to_logits(x)

        # if no labels are given, return the logits
        if not exists(labels):
            return logits

        # otherwise return the cross entropy loss
        return F.cross_entropy(rearrange(logits, 'b n d -> b d n'), labels)

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\__init__.py

# expose the public classes and functions of the package
from memory_efficient_attention_pytorch.memory_efficient_attention import Attention, memory_efficient_attention
from memory_efficient_attention_pytorch.memory_efficient_cosine_sim_attention import CosineSimAttention, numerically_unstable_memory_efficient_attention
from memory_efficient_attention_pytorch.flash_attention import FlashAttention

Memory Efficient Attention Pytorch (obsolete)

Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(n²) Memory. In addition, the module will take care of masking, causal masking, as well as cross attention.

This repository also contains a naive non-CUDA implementation of the improvements made by Tri Dao with his Flash Attention 2 paper, for educational purposes. It is a game changer for attention and building long-context transformers.

Update: going forward, you should just use the F.scaled_dot_product_attention function built into PyTorch 2.0 for Flash Attention v1 support, or use Flash Attention v2 from the official repository.
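
A minimal sketch of that replacement (assuming PyTorch >= 2.0; shapes follow the (batch, heads, seq, dim_head) convention used in this repository):

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 1024, 64).cuda()

out = F.scaled_dot_product_attention(q, k, v, is_causal = True) # dispatches to a fused flash kernel when available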

Install

$ pip install memory-efficient-attention-pytorch

Usage

For autoregressive language model

import torch
from memory_efficient_attention_pytorch import Attention

attn = Attention(
    dim = 512,
    dim_head = 64,                # dimension per head
    heads = 8,                    # number of attention heads
    causal = True,                # autoregressive or not
    memory_efficient = True,      # whether to use memory efficient attention (can be turned off to test against normal attention)
    q_bucket_size = 1024,         # bucket size along queries dimension
    k_bucket_size = 2048          # bucket size along key / values dimension
).cuda()

x = torch.randn(1, 65536, 512).cuda()
out = attn(x) # (1, 65536, 512)

Cross attention

import torch
from memory_efficient_attention_pytorch import Attention

cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

x = torch.randn(1, 65536, 512).cuda()
context = torch.randn(1, 65536, 512).cuda()
mask = torch.ones(1, 65536).bool().cuda()

out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)

Citations

@misc{rabe2021selfattention,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory}, 
    author  = {Markus N. Rabe and Charles Staats},
    year    = {2021},
    eprint  = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}
@article{dao2023flashattention2,
    title   = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
    author  = {Dao, Tri},
    year    = {2023}
}

.\lucidrains\memory-efficient-attention-pytorch\setup.py

from setuptools import setup, find_packages

# package metadata
setup(
  name = 'memory-efficient-attention-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.1.6',
  license='MIT',
  description = 'Memory Efficient Attention - Pytorch',
  long_description_content_type = 'text/markdown',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/memory-efficient-attention-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'attention-mechanism'
  ],
  install_requires=[
    'einops>=0.4.1',
    'torch>=1.6'
  ],
  setup_requires=[
    'pytest-runner',
  ],
  tests_require=[
    'pytest'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.8',
  ],
)

.\lucidrains\memory-efficient-attention-pytorch\tests\test.py

import torch
from memory_efficient_attention_pytorch import Attention

from memory_efficient_attention_pytorch.memory_efficient_attention import attention
from memory_efficient_attention_pytorch.flash_attention import FlashAttention, FlashAttentionFunction

# helpers

# check whether two tensors are close within an absolute tolerance
def isclose(a, b, atol = 1e-6):
    diff = (a - b).abs().amax()
    return diff < atol

# test that memory efficient output matches regular attention output

def test_output_equal():
    attn = Attention(
        dim = 512,
        dim_head = 64,
        heads = 8,
        q_bucket_size = 64,
        k_bucket_size = 64,
        causal = True
    )

    x = torch.randn(2, 2048, 512)
    mask = torch.ones(2, 2048).bool()

    out = attn(x, mask = mask)
    mem_efficient_out = attn(x, mask = mask, memory_efficient = True)

    assert isclose(mem_efficient_out, out, atol = 1e-6)

# test that gradients match

def test_gradients_equal():
    attn = Attention(
        dim = 512,
        dim_head = 64,
        heads = 8,
        q_bucket_size = 64,
        k_bucket_size = 64,
        causal = True
    )

    def loss_fn(inp, **kwargs):
        return attn(inp, **kwargs).sum()

    x = torch.randn(2, 2048, 512).requires_grad_()
    mask = torch.ones(2, 2048).bool()

    loss_fn(x, mask = mask).backward()
    out_grad = x.grad.clone()

    x.grad.zero_()
    loss_fn(x, mask = mask, memory_efficient = True).backward()
    mem_efficient_out_grad = x.grad.clone()

    assert isclose(out_grad, mem_efficient_out_grad, atol = 1e-5)

# test flash attention output against regular attention

def test_flash_attn_output_equal():
    attn_kwargs = dict(
        dim = 512,
        dim_head = 64,
        heads = 8,
        q_bucket_size = 64,
        k_bucket_size = 64,
        causal = True
    )

    attn = Attention(**attn_kwargs)
    flash_attn = FlashAttention(**attn_kwargs)

    # share weights between the two modules so the outputs are comparable
    flash_attn.to_q = attn.to_q
    flash_attn.to_kv = attn.to_kv
    flash_attn.to_out = attn.to_out

    x = torch.randn(2, 2048, 512)
    mask = torch.ones(2, 2048).bool()

    out = attn(x, mask = mask)
    mem_efficient_out = flash_attn(x, mask = mask)

    assert isclose(mem_efficient_out, out, atol = 1e-6)

# test flash attention gradients against the reference attention function

def test_flash_attn_gradients_equal():
    q = torch.randn(1, 8, 1024, 512).requires_grad_()
    k = torch.randn(1, 8, 1024, 512).requires_grad_()
    v = torch.randn(1, 8, 1024, 512).requires_grad_()

    mask = torch.ones(1, 1024).bool()

    o = attention(q, k, v, mask = mask, causal = True)
    o.sum().backward()

    dq_grad = q.grad.clone()
    dk_grad = k.grad.clone()
    dv_grad = v.grad.clone()

    q.grad.zero_()
    k.grad.zero_()
    v.grad.zero_()

    flash_o = FlashAttentionFunction.apply(q, k, v, mask, True, 64, 64)
    flash_o.sum().backward()

    flash_dq_grad = q.grad.clone()
    flash_dk_grad = k.grad.clone()
    flash_dv_grad = v.grad.clone()

    assert isclose(flash_dq_grad, dq_grad, atol = 1e-5)
    assert isclose(flash_dk_grad, dk_grad, atol = 1e-5)
    assert isclose(flash_dv_grad, dv_grad, atol = 1e-5)

# test flash attention with a full (2d) attention mask

def test_flash_attn_full_attn_mask_output_equal():
    attn_kwargs = dict(
        dim = 512,
        dim_head = 64,
        heads = 8,
        q_bucket_size = 64,
        k_bucket_size = 64,
        causal = True
    )

    attn = Attention(**attn_kwargs)
    flash_attn = FlashAttention(**attn_kwargs)

    # share weights between the two modules so the outputs are comparable
    flash_attn.to_q = attn.to_q
    flash_attn.to_kv = attn.to_kv
    flash_attn.to_out = attn.to_out

    x = torch.randn(2, 2048, 512)
    mask = torch.ones(2, 1, 2048, 2048).bool()

    out = attn(x, mask = mask)
    mem_efficient_out = flash_attn(x, mask = mask)

    assert isclose(mem_efficient_out, out, atol = 1e-6)

# test gradients with a full attention mask

def test_flash_attn_full_attn_mask_gradients_equal():
    q = torch.randn(1, 8, 1024, 512).requires_grad_()
    k = torch.randn(1, 8, 1024, 512).requires_grad_()
    v = torch.randn(1, 8, 1024, 512).requires_grad_()

    mask = torch.ones(1, 1, 1024, 1024).bool()

    o = attention(q, k, v, mask = mask, causal = True)
    o.sum().backward()

    dq_grad = q.grad.clone()
    dk_grad = k.grad.clone()
    dv_grad = v.grad.clone()

    q.grad.zero_()
    k.grad.zero_()
    v.grad.zero_()

    flash_o = FlashAttentionFunction.apply(q, k, v, mask, True, 64, 64)
    flash_o.sum().backward()

    flash_dq_grad = q.grad.clone()
    flash_dk_grad = k.grad.clone()
    flash_dv_grad = v.grad.clone()

    # gradients from FlashAttentionFunction should match the reference within tolerance
    assert isclose(flash_dq_grad, dq_grad, atol = 1e-5)
    assert isclose(flash_dk_grad, dk_grad, atol = 1e-5)
    assert isclose(flash_dv_grad, dv_grad, atol = 1e-5)

.\lucidrains\memory-efficient-attention-pytorch\train.py

from memory_efficient_attention_pytorch.transformer import Transformer
from memory_efficient_attention_pytorch.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 1024
SEQ_LEN = 4096

# helpers

# cycle a dataloader indefinitely (used by the loaders below)
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a single token to a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens to a string
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate a GPT-like decoder model

model = Transformer(
    num_tokens = 256,
    dim = 512,
    max_seq_len = SEQ_LEN,
    depth = 6,
    heads = 8,
    causal = True,
    q_bucket_size = 256,
    k_bucket_size = 256,
    ff_chunks = 5,
    use_flash_attn = True
)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# dataset that samples random crops of seq_len + 1 tokens
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create dataset instances and cycled dataloaders for training and validation
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # gradient accumulation
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    # clip gradients and step the optimizer
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\memory-transformer-xl\examples\enwik8_simple\train.py

# imports
from memory_transformer_xl import MemoryTransformerXL
from memory_transformer_xl.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 16
MAX_BATCH_SIZE = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY  = 100

GENERATE_EVERY  = 500
PRIME_LENGTH    = 512
GENERATE_LENGTH = 1024

SEQ_LEN = 512
NUM_SEGMENTS = 4

# helper functions
def cycle(loader):
    # loop a dataloader indefinitely
    while True:
        for data in loader:
            yield data

def decode_token(token):
    # decode a single token to a character
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    # decode a sequence of tokens to a string
    return ''.join(list(map(decode_token, tokens)))

# instantiate model
model = MemoryTransformerXL(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    seq_len = SEQ_LEN,
    mem_len = SEQ_LEN,
    lmem_len = SEQ_LEN // 4,
    heads = 8,
    memory_layers = [6,7,8]
)

model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data
with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# dataset that samples segments * seq_len + 1 tokens at a time
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len, segments):
        super().__init__()
        self.data = data
        self.seq_len = seq_len
        self.segments = segments
        self.total_len = seq_len * segments

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.total_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.total_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.total_len

# create training and validation sets with cycled dataloaders
train_dataset = TextSamplerDataset(data_train, SEQ_LEN, NUM_SEGMENTS)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN, NUM_SEGMENTS)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    grad_accum_every = BATCH_SIZE / MAX_BATCH_SIZE

    for loss, is_last in model(next(train_loader), max_batch_size = MAX_BATCH_SIZE, return_loss = True):
        (loss / grad_accum_every).backward(retain_graph = True)

        print(f'training loss: {loss.item():.4f}')

        if is_last:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optim.step()
            optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            for loss, _ in model(next(val_loader), return_loss = True):
                print(f'validation loss: {loss.item():.4f}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        inp = inp[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

.\lucidrains\memory-transformer-xl\memory_transformer_xl\autoregressive_wrapper.py

import math
from functools import partial
from collections import namedtuple

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

# return value carrying the loss and a flag for the last micro-batch of an accumulation cycle
Return = namedtuple('Return', ['loss', 'is_last_batch'])

# nucleus (top-p) filtering of logits
def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # mark indices past the cumulative probability threshold for removal
    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    # set the removed logits to -inf and scatter back into the original order
    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# top-k filtering of logits
def top_k(logits, thres = 0.9):
    # number of logits to keep
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    # everything outside the top k becomes -inf
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
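
# worked example (not from the original source): with thres = 0.5 and a vocabulary of 4,
# k = int((1 - 0.5) * 4) = 2, so only the two largest logits survive
#
# top_k(torch.tensor([[1., 3., 2., 0.]]), thres = 0.5)
# # -> tensor([[-inf, 3., 2., -inf]])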

class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.seq_len = net.seq_len

    # autoregressive generation
    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        # remember the training state so it can be restored afterwards
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        self.net.eval()

        out = start_tokens

        # take care of default masking

        full_mask_like = lambda x: torch.full_like(x, True, dtype=torch.bool, device=x.device)

        mask = kwargs.pop('mask', None)
        if mask is None:
            mask = full_mask_like(out)

        # handle arbitrary length primed sequences by first feeding full segments through to build up memories

        mem = None
        *primes, out = out.split(self.seq_len, dim=1)
        *prime_masks, mask = mask.split(self.seq_len, dim=1)

        for prime, prime_mask in zip(primes, prime_masks):
            _, mem = self.net(prime, memories = mem, mask = prime_mask, **kwargs)

        # generate until the requested number of tokens is reached

        input_len = out.shape[1]

        for _ in range(seq_len):
            logits, mem = self.net(out[:, -input_len:], memories = mem, mask = mask[:, -input_len:], **kwargs)
            logits = logits[:, -1, :]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

            # unlike most models, once the full sequence length is filled, the input starts over from a length of one

            out = torch.cat((out, sample), dim=-1)
            mask = F.pad(mask, (0, 1), value=True)

            # append the sample to the accumulated output

            input_len = input_len % self.seq_len
            input_len += 1

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out
        return out
    def forward(self, x, max_batch_size = None, return_loss = False, truncate_every = None, **kwargs):
        # pad along the sequence dimension so variable length sequences can be batched
        pad = partial(pad_sequence, batch_first = True, padding_value = self.pad_value)

        if not return_loss:
            if not isinstance(x, torch.Tensor):
                x = pad(x)
            return self.net(x, **kwargs)

        # build input / target pairs shifted by one token
        if isinstance(x, torch.Tensor):
            xi = x[:, :-1]
            xo = x[:, 1:]
        else:
            xi = pad(list(map(lambda t: t[:-1], x)))
            xo = pad(list(map(lambda t: t[1:], x)))

        # the input mask may cover the full source length; trim it to match the shifted inputs
        mask = kwargs.pop('mask', None)
        if mask is not None and mask.shape[1] == x.shape[1]:
            mask = mask[:, :-1]

        # split the sequence into segments of seq_len
        segment_fn = lambda x: x.split(self.seq_len, dim=1)
        (xi, xo) = map(segment_fn, (xi, xo))

        num_segments = len(xi)
        mask = segment_fn(mask) if mask is not None else ((None,) * num_segments)

        # split each segment along the batch dimension for gradient accumulation
        max_batch_size = x.shape[0] if max_batch_size is None else max_batch_size
        split_batch_fn = lambda x: x.split(max_batch_size, dim=0)

        grad_accumulate_every = math.ceil(x.shape[0] / max_batch_size)
        mems = [None] * grad_accumulate_every

        for ind, (xi_seg, xo_seg, mask_seg) in enumerate(zip(xi, xo, mask)):
            xi_seg, xo_seg = map(split_batch_fn, (xi_seg, xo_seg))
            mask_seg = split_batch_fn(mask_seg) if mask_seg is not None else ((None,) * grad_accumulate_every)
            # truncate backpropagation through the long term memory every `truncate_every` segments
            truncate = truncate_every is not None and ((ind + 1) % truncate_every) == 0

            new_mems = []
            for ind, (xi_seg_b, xo_seg_b, mask_seg_b, mem) in enumerate(zip(xi_seg, xo_seg, mask_seg, mems)):
                is_last = ind == (grad_accumulate_every - 1)

                logits, new_mem = self.net(xi_seg_b, mask = mask_seg_b, memories = mem, detach_lmem = truncate, **kwargs)
                new_mems.append(new_mem)

                # cross entropy against the shifted targets, yielded once per micro-batch
                loss = F.cross_entropy(logits.transpose(1, 2), xo_seg_b, ignore_index = self.ignore_index)
                yield Return(loss, is_last)

            mems = new_mems

.\lucidrains\memory-transformer-xl\memory_transformer_xl\memory_transformer_xl.py

import torch
from torch import nn
import torch.nn.functional as F

from mogrifier import Mogrifier

import math
from collections import namedtuple
from functools import partial
from inspect import isfunction

# the pair of memories: short term (the transformer-xl fifo memory) and long term
Memory = namedtuple('Memory', ['short', 'long'])

# helper functions

# dtype / device kwargs of a tensor, for constructing new tensors like it
def to(t):
    return {'dtype': t.dtype, 'device': t.device}

# cast a single element to a tuple if it is not one already
def cast_tuple(el):
    return el if isinstance(el, tuple) else (el,)

# return x if it is not None, otherwise the default value (called if it is a function)
def default(x, val):
    if x is not None:
        return x
    return val if not isfunction(val) else val()

# most negative value representable in the tensor's dtype
def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

# reshape one dimension of a tensor into the given split dims
def reshape_dim(t, dim, split_dims):
    shape = list(t.shape)
    num_dims = len(shape)
    dim = (dim + num_dims) % num_dims
    shape[dim:dim+1] = split_dims
    return t.reshape(shape)

# split a tensor in two at an index along the given dimension
def split_at_index(dim, index, t):
    pre_slices = (slice(None),) * dim
    l = (*pre_slices, slice(None, index))
    r = (*pre_slices, slice(index, None))
    return t[l], t[r]

# fifo queue along a dimension, returning the (discarded, kept) portions
def queue_fifo(*args, length, dim=-2):
    queue = torch.cat(args, dim=dim)
    if length > 0:
        return split_at_index(dim, -length, queue)

    device = queue.device
    shape = list(queue.shape)
    shape[dim] = 0
    return queue, torch.empty(shape, device=device)

# transformer-xl style relative position shift along the last dimension
def shift(x):
    *_, i, j = x.shape
    zero_pad = torch.zeros((*_, i, i), **to(x))
    x = torch.cat([x, zero_pad], -1)
    l = i + j - 1
    x = x.view(*_, -1)
    zero_pad = torch.zeros(*_, -x.size(-1) % l, **to(x))
    shifted = torch.cat([x, zero_pad], -1).view(*_, -1, l)
    return shifted[..., :i, i - 1:]
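
# worked example (not from the original source): for an (i = 2, j = 3) input, row r is
# shifted left by (i - 1 - r) positions and zero padded on the right, realigning the
# relative position scores produced in SelfAttention below
#
# shift(torch.tensor([[1., 2., 3.],
#                     [4., 5., 6.]]))
# # -> tensor([[2., 3., 0.],
# #            [4., 5., 6.]])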

# iterate over the first dimension of a tensor
def iterate_tensor(t):
    length = t.shape[0]
    for ind in range(length):
        yield t[ind]

# initialize a parameter tensor uniformly with std 1 / sqrt(dim)
def init_parameter(shape, dim):
    t = torch.zeros(shape)
    std = 1 / math.sqrt(dim)
    t.uniform_(-std, std)
    return nn.Parameter(t)

# helper classes

# residual wrapper
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# pre-layernorm wrapper
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# neuromodulated bistable recurrent cell (nBRC) and other gating classes

class nBRC(nn.Module):
    def __init__(self, dims, hidden_dims):
        super().__init__()
        self.Ua = nn.Linear(dims, hidden_dims)
        self.Wa = nn.Linear(dims, hidden_dims)
        self.Uc = nn.Linear(dims, hidden_dims)
        self.Wc = nn.Linear(dims, hidden_dims)
        self.U  = nn.Linear(dims, hidden_dims)

    def forward(self, x, h):
        l = lambda linear, tensor: F.linear(tensor, linear.weight.clone(), linear.bias.clone())

        a = 1 + torch.tanh(l(self.Ua, x) + l(self.Wa, h))
        c = torch.sigmoid(l(self.Uc, x) + l(self.Wc, h))
        return c * h + (1 - c) * torch.tanh(l(self.U, x) + a * h)
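
# annotation (not from the original source): the update above follows the bistable
# recurrent cell of Vecoven et al. (cited in the readme):
#
#   a  = 1 + tanh(Ua x + Wa h)    # feedback gate, ranging over (0, 2); a > 1 gives bistability
#   c  = sigmoid(Uc x + Wc h)     # update gate
#   h' = c * h + (1 - c) * tanh(U x + a * h)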

# gated residual, using the nBRC above as the gate and optionally a Mogrifier
class GRUGating(nn.Module):
    def __init__(self, dim, fn, mogrify=False):
        super().__init__()
        self.dim = dim
        self.fn = fn
        self.gru = nBRC(dim, dim)
        self.mogrify = Mogrifier(dim, factorize_k=dim // 4) if mogrify else None

    def forward(self, x, **kwargs):
        shape = x.shape
        dim = self.dim

        y = self.fn(x, **kwargs)

        if self.mogrify is not None:
            y, x = self.mogrify(y, x)

        gated_output = self.gru(
            y.reshape(-1, dim),
            x.reshape(-1, dim)
        )

        return gated_output.reshape(shape)

# feedforward

# GELU activation (tanh approximation), as a fallback for older torch versions
class GELU_(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

# use nn.GELU if available, otherwise fall back to the approximation above
GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False):
        super().__init__()
        # default activation is GELU
        activation = default(activation, GELU)

        self.glu = glu
        # the first projection is doubled when using GLU, to produce both values and gates
        self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
        self.act = activation()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(dim * mult, dim)

    def forward(self, x, **kwargs):
        if not self.glu:
            x = self.w1(x)
            x = self.act(x)
        else:
            # GLU: half the projection gates the activation of the other half
            x, v = self.w1(x).chunk(2, dim=-1)
            x = self.act(x) * v

        x = self.dropout(x)
        x = self.w2(x)
        return x

# attention

class SelfAttention(nn.Module):
    def __init__(self, dim, seq_len, mem_len, lmem_len, heads = 8, attn_dropout = 0., dropout = 0., memory_attn_dropout = 0., one_kv_head = False, num_mem_kv = 4):
        super().__init__()
        assert (dim % heads) == 0, 'dimension must be divisible by the number of heads'

        self.heads = heads
        self.dim_head = dim // heads
        self.seq_len = seq_len
        self.mem_len = mem_len
        self.lmem_len = lmem_len
        self.scale = self.dim_head ** (-0.5)

        self.to_q = nn.Linear(dim, dim, bias = False)

        kv_dim = self.dim_head if one_kv_head else dim
        self.to_kv = nn.Linear(dim, kv_dim * 2, bias = False)
        self.to_out = nn.Linear(dim, dim)

        self.mem_kv = init_parameter((1, num_mem_kv, dim), dim)

        self.attn_dropout = nn.Dropout(attn_dropout)
        self.dropout = nn.Dropout(dropout)

        self.memory_attn_dropout = nn.Dropout(memory_attn_dropout)

    def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_memory = True, **kwargs):
        b, t, e, h, dim_h = *x.shape, self.heads, self.dim_head

        memories = default(memories, (None, None))
        mem, lmem = memories

        init_mem = lambda: torch.empty(b, 0, e, **to(x))
        mem = default(mem, init_mem)
        lmem = default(lmem, init_mem)
        mem_kv = self.mem_kv.expand(b, -1, -1)

        mem_len, lmem_len, mem_kv_len = map(lambda t: t.shape[1], (mem, lmem, mem_kv))

        q = self.to_q(x)

        kv_input = torch.cat((mem_kv, lmem, mem, x), dim=1)
        kv_len = kv_input.shape[1]
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        merge_heads = lambda x: reshape_dim(x, -1, (-1, dim_h)).transpose(1, 2)
        q, k, v = map(merge_heads, (q, k, v))

        k, v = map(lambda x: x.expand(-1, h, -1, -1), (k, v))

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = max_neg_value(dots)

        if pos_emb is not None:
            pos_emb = pos_emb[:, -kv_len:].type(q.dtype)
            pos_dots = torch.einsum('bhid,hjd->bhij', q, pos_emb) * self.scale
            pos_dots = shift(pos_dots)
            pos_dots = F.pad(pos_dots, (dots.shape[-1] - pos_dots.shape[-1], 0), value = 0.)
            dots = dots + pos_dots

        if input_mask is not None:
            mask = input_mask[:, None, :, None] * input_mask[:, None, None, :]
            mask = F.pad(mask, (mem_len + lmem_len + mem_kv_len, 0), value = True)
            dots.masked_fill_(~mask, mask_value)

        total_mem_len = mem_len + lmem_len + mem_kv_len
        mask = torch.ones(t, t + total_mem_len, **to(x)).triu_(diagonal = 1 + total_mem_len).bool()
        dots.masked_fill_(mask[None, None, ...], mask_value)

        attn = dots.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = out.transpose(1, 2).reshape(b, t, -1)
        out = self.to_out(out)

        return self.dropout(out)

# linear attention, linear in the sequence length
def linear_attn(q, k, v):
    q, k = q.softmax(dim=-1), k.softmax(dim=-2)
    context = torch.einsum('bhnd,bhne->bhde', k, v)
    out = torch.einsum('bhnd,bhde->bhne', q, context)
    return out

# full quadratic attention
def full_attn(q, k, v):
    dots = torch.einsum('bhid,bhjd->bhij', q, k) * q.shape[-1] ** -0.5
    dots = dots.softmax(dim=-1)
    out = torch.einsum('bhij,bhjd->bhid', dots, v)
    return out
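
# note (not from the original source): full_attn materializes the (n x n) score matrix,
# costing O(n^2 * d) time and O(n^2) memory, while linear_attn first contracts keys
# against values into a (d x e) context, costing O(n * d * e). this is why the memory
# network below uses the linear variant when attending over all the hidden states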

# linear self attention, used by the memory attention network
class LinearSelfAttention(nn.Module):
    def __init__(self, dim, depth, heads = 8):
        super().__init__()
        self.dim_head = dim // heads
        self.norm = nn.LayerNorm(dim, elementwise_affine = False)

        self.to_q = init_parameter((dim, dim), dim)
        self.to_kv = init_parameter((dim, 2 * dim), dim)
        self.to_out = init_parameter((dim, dim), dim)
    def forward(self, x, hiddens = None):
        dim_head = self.dim_head
        # clone the weight matrices before use
        w_q, w_kv, w_out = map(torch.clone, (self.to_q, self.to_kv, self.to_out))

        # normalize the long term memory queries
        normed_lmem = self.norm(x)
        q = torch.einsum('bnd,de->bne', normed_lmem, w_q)

        # keys / values attend over the memory itself plus all the hidden states
        kv_input = torch.cat((normed_lmem, hiddens), dim=1)
        k, v = torch.einsum('bnd,de->bne', kv_input, w_kv).chunk(2, dim=-1)

        # split out heads
        q, k, v = map(lambda t: reshape_dim(t, -1, (-1, dim_head)).transpose(-2, -3), (q, k, v))

        # efficient linear attention
        out = linear_attn(q, k, v)

        # merge heads back and project out
        out = out.transpose(2, 3).reshape_as(x)
        out = torch.einsum('bnd,de->bne', out, w_out)
        return out

# memory attention network
class MemoryAttentionNetwork(nn.Module):
    def __init__(self, dim, num_memory_depth, mem_len, lmem_len, heads = 4, num_attn_steps = 2, num_mem_kv = 4, mem_write_iters = 2):
        super().__init__()
        self.num_memory_depth = num_memory_depth
        self.mem_len = mem_len
        self.lmem_len = lmem_len

        self.dim = dim
        dim_head = dim // heads
        self.dim_head = dim_head

        # depth embedding, initial long term memory, and long term memory positional embedding
        self.depth_emb = init_parameter((num_memory_depth, 1, 1, 1), dim)
        self.init_lmem = init_parameter((1, 1, dim), dim)
        self.lmem_pos_emb = init_parameter((1, lmem_len, dim), dim)

        self.mem_kv = init_parameter((1, num_mem_kv, dim), dim)

        # linear self attention layer and nBRC gate for writing to memory
        self.attn = LinearSelfAttention(dim, num_memory_depth, heads = heads)
        self.gate = nBRC(dim, dim)
        self.mem_write_iters = mem_write_iters

    def forward(self, lmem, smem, hiddens, detach_lmem = False):
        batch, dim, dim_head, mem_depth, lmem_len = lmem.shape[0], self.dim, self.dim_head, self.num_memory_depth, self.lmem_len

        # properly detach the hidden state, and detach the long term memory if the truncate signal is given
        hiddens = hiddens.detach()

        if detach_lmem:
            lmem = lmem.detach()

        # initialize the long term memory state if none was provided
        if lmem is None or lmem.shape[1] == 0:
            lmem = self.init_lmem.clone().expand(batch, lmem_len, -1)

        # update the long term memory with efficient linear attention
        next_lmem = lmem + self.lmem_pos_emb

        hiddens_and_smem = torch.cat((smem, hiddens), dim=-2)
        all_hiddens = (hiddens_and_smem + self.depth_emb).transpose(0, 1).reshape(batch, -1, dim)
        all_hiddens = torch.cat((all_hiddens, self.mem_kv.expand(batch, -1, -1)), dim=1)

        # iterate the memory write: attend, then gate the new memory against the old
        for _ in range(self.mem_write_iters):
            attn_out = self.attn(next_lmem, hiddens = all_hiddens)
            next_lmem = self.gate(attn_out, next_lmem)

        # fifo queue the short term memory
        _, next_mem = queue_fifo(smem, hiddens, length = self.mem_len, dim = 2)

        # return the updated short term and long term memories
        return Memory(short = next_mem.detach(), long = next_lmem)

# transformer

class MemoryTransformerXL(nn.Module):
    def __init__(self, num_tokens, dim, seq_len, depth, emb_dim = None, memory_layers = None, mem_len = None, lmem_len = None, heads = 8, gru_gated_residual = True, mogrify_gru = False, attn_dropout = 0., ff_glu = False, ff_dropout = 0., attn_layer_dropout = 0., one_kv_head = False, num_mem_kv = 0, mem_write_iters = 2):
        super().__init__()
        emb_dim = default(emb_dim, dim)
        # short term memory length defaults to the sequence length, long term memory to the short term length
        mem_len = default(mem_len, seq_len)
        lmem_len = default(lmem_len, mem_len)

        # by default, every layer uses memory
        memory_layers = default(memory_layers, list(range(1, depth + 1)))

        assert all([layer > 0 and layer <= depth for layer in memory_layers]), 'one of the indicated memory layers is invalid'

        self.mem_len = mem_len
        self.seq_len = seq_len

        self.depth = depth
        self.memory_layers = list(memory_layers)

        # token embedding, with an optional projection when emb_dim differs from the model dimension
        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        self.to_model_dim = nn.Identity() if emb_dim == dim else nn.Linear(emb_dim, dim)

        seq_and_mem_len = seq_len + mem_len + lmem_len
        # relative positional embeddings spanning the sequence plus both memories
        self.pos_emb = nn.Parameter(torch.zeros(heads, seq_and_mem_len, dim // heads))

        self.to_logits = nn.Sequential(
            nn.Identity() if emb_dim == dim else nn.Linear(dim, emb_dim),
            nn.Linear(emb_dim, num_tokens)
        )

        # choose the residual wrapper depending on whether gru gated residuals are used
        wrapper = partial(GRUGating, dim, mogrify = mogrify_gru) if gru_gated_residual else Residual

        # attention and feedforward layers
        self.attn_layers = nn.ModuleList([wrapper(PreNorm(dim, SelfAttention(dim, seq_len, mem_len, lmem_len, heads, dropout = attn_layer_dropout, attn_dropout = attn_dropout, one_kv_head = one_kv_head, num_mem_kv = num_mem_kv))) for _ in range(depth)])
        self.ff_layers = nn.ModuleList([wrapper(PreNorm(dim, FeedForward(dim, dropout = ff_dropout, glu = ff_glu))) for _ in range(depth)])

        # the memory attention network that writes hidden states into long term memory
        self.memory_network = MemoryAttentionNetwork(dim, len(self.memory_layers), mem_len, lmem_len, num_mem_kv = num_mem_kv, mem_write_iters = mem_write_iters)

    def forward(self, x, memories = None, mask = None, detach_lmem = False):
        x = self.token_emb(x)
        x = self.to_model_dim(x)
        b, t, d = x.shape

        assert t <= self.seq_len, f'input contains a sequence length {t} that is greater than the designated maximum sequence length {self.seq_len}'

        memories = default(memories, (None, None))
        mem, lmem = memories

        num_memory_layers = len(self.memory_layers)

        # initialize empty short and long term memories if none were passed in
        mem = default(mem, lambda: torch.empty(num_memory_layers, b, 0, d, **to(x)))
        lmem = default(lmem, lambda: torch.empty(b, 0, d, **to(x)))

        mem_len, lmem_len = map(lambda t: t.shape[2], (mem, lmem))
        total_len = mem_len + lmem_len + self.seq_len

        # slice out the relative positional embeddings for this sequence
        pos_emb = self.pos_emb[:, (self.seq_len - t):total_len]

        mem_iter = iterate_tensor(mem)

        hiddens = []

        for ind, (attn, ff) in enumerate(zip(self.attn_layers, self.ff_layers)):
            layer_num = ind + 1
            use_memory = layer_num in self.memory_layers
            memories = (next(mem_iter), lmem) if use_memory else None

            # collect the hidden states entering every memory layer, for the memory write below
            if use_memory:
                hiddens.append(x)

            x = attn(x, memories = memories, input_mask = mask, pos_emb = pos_emb)
            x = ff(x)

        hiddens = torch.stack(hiddens)
        out = self.to_logits(x)

        # calculate the next memory state
        # only push hidden states to short term memory once the full sequence length is reached

        if t < self.mem_len:
            return out, Memory(short = mem, long = lmem)

        next_memory = self.memory_network(lmem, mem, hiddens, detach_lmem = detach_lmem)
        return out, next_memory

.\lucidrains\memory-transformer-xl\memory_transformer_xl\__init__.py

# expose the MemoryTransformerXL class at the package level
from memory_transformer_xl.memory_transformer_xl import MemoryTransformerXL
from memory_transformer_xl.memory_transformer_xl import MemoryTransformerXL

Memory Transformer-XL

A combination of Transformer-XL with ideas from Memory Transformers. While in Transformer-XL the memory is just a FIFO queue, this repository will attempt to update the memory (queries) against the incoming hidden states (keys / values) with a memory attention network. The memory attention network will utilize linear attention to be performant, followed by GRU gating, and will be backpropagated through time to learn how to properly store and discard new/old memory.

Install

$ pip install memory-transformer-xl

Usage

import torch
from memory_transformer_xl import MemoryTransformerXL

model = MemoryTransformerXL(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 8,
    seq_len = 512,
    mem_len = 256,            # short term memory (the memory from transformer-xl)
    lmem_len = 256,           # long term memory (memory attention network attending to short term memory and hidden activations)
    mem_write_iters = 2,      # number of iterations of attention for writing to memory
    memory_layers = [6,7,8],  # which layers to use memory, only the later layers are actually needed
    num_mem_kv = 128,         # number of memory key/values, from All-attention paper

).cuda()

x1 = torch.randint(0, 20000, (1, 512)).cuda()
logits1, mem1 = model(x1)

x2 = torch.randint(0, 20000, (1, 512)).cuda()
logits2, mem2 = model(x2, memories = mem1)

# and so on with carrying over memories...
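
Beyond the single-step usage above, a minimal (illustrative, not from the repository) training sketch shows how memories are carried across consecutive segments of one long sequence, with gradients flowing back through the memory attention network. A real setup would truncate backpropagation through time, e.g. by passing detach_lmem = True periodically, rather than accumulating the whole graph as done here for brevity.

import torch
import torch.nn.functional as F
from memory_transformer_xl import MemoryTransformerXL

model = MemoryTransformerXL(num_tokens = 20000, dim = 512, heads = 8, depth = 8, seq_len = 512, mem_len = 256, lmem_len = 256)
optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

segments = torch.randint(0, 20000, (4, 1, 512))  # stand-in for chunks of one long sequence

total_loss = 0.
memories = None

for segment in segments:
    logits, memories = model(segment, memories = memories)
    total_loss = total_loss + F.cross_entropy(logits.transpose(1, 2), segment)

total_loss.backward()  # gradients flow back through the carried memories (BPTT)
optim.step()
optim.zero_grad()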

Citations

@article{Dai_2019,
   title     = {Transformer-XL: Attentive Language Models beyond a Fixed-Length Context},
   url       = {http://dx.doi.org/10.18653/v1/P19-1285},
   DOI       = {10.18653/v1/p19-1285},
   journal   = {Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
   publisher = {Association for Computational Linguistics},
   author    = {Dai, Zihang and Yang, Zhilin and Yang, Yiming and Carbonell, Jaime and Le, Quoc and Salakhutdinov, Ruslan},
   year      = {2019}
}
@misc{burtsev2020memory,
    title   = {Memory Transformer},
    author  = {Mikhail S. Burtsev and Grigory V. Sapunov},
    year    = {2020},
    eprint  = {2006.11527},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{parisotto2019stabilizing,
    title     = {Stabilizing Transformers for Reinforcement Learning},
    author    = {Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
    year      = {2019},
    eprint    = {1910.06764},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@article{shen2019efficient,
  author    = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
  title     = {Efficient Attention: Attention with Linear Complexities},
  journal   = {CoRR},
  volume    = {abs/1812.01243},
  year      = {2018},
  url       = {http://arxiv.org/abs/1812.01243}
}
@article{DBLP:journals/corr/abs-1907-01470,
    author    = {Sainbayar Sukhbaatar and
               Edouard Grave and
               Guillaume Lample and
               Herv{\'{e}} J{\'{e}}gou and
               Armand Joulin},
    title     = {Augmenting Self-attention with Persistent Memory},
    journal   = {CoRR},
    volume    = {abs/1907.01470},
    year      = {2019},
    url       = {http://arxiv.org/abs/1907.01470}
}
@misc{vecoven2020bioinspired,
    title   = {A bio-inspired bistable recurrent cell allows for long-lasting memory},
    author  = {Nicolas Vecoven and Damien Ernst and Guillaume Drion},
    year    = {2020},
    eprint  = {2006.05252},
    archivePrefix = {arXiv},
    primaryClass = {cs.NE}
}

Memory is attention through time - Alex Graves

.\lucidrains\memory-transformer-xl\setup.py

# import setup and package-finding utilities
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'memory-transformer-xl',
  # find packages, excluding the examples folder
  packages = find_packages(exclude=['examples']),
  # version number
  version = '0.1.0',
  # license
  license='MIT',
  # description
  description = 'Memory Transformer-XL, a variant of Transformer-XL that uses linear attention to update long term memory',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project url
  url = 'https://github.com/lucidrains/memory-transformer-xl',
  # keywords
  keywords = ['attention mechanism', 'artificial intelligence', 'transformer', 'deep learning'],
  # install dependencies
  install_requires=[
      'torch',
      'mogrifier'
  ],
  # classifiers
  classifiers=[
      'Development Status :: 4 - Beta',
      'Intended Audience :: Developers',
      'Topic :: Scientific/Engineering :: Artificial Intelligence',
      'License :: OSI Approved :: MIT License',
      'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\meshgpt-pytorch\meshgpt_pytorch\data.py

# imports
from pathlib import Path
from functools import partial
import torch
from torch import Tensor
from torch import is_tensor
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

import numpy as np
from numpy.lib.format import open_memmap

from einops import rearrange, reduce

from beartype import beartype
from beartype.typing import Tuple, List, Union, Optional, Callable, Dict

from torchtyping import TensorType

from pytorch_custom_utils.utils import pad_or_slice_to

# helper functions

def exists(v):
    return v is not None

def identity(t):
    return t

# constants

Vertices = TensorType['nv', 3, float]   # 3 coordinates
Faces = TensorType['nf', 3, int]        # 3 vertices

# decorator for auto-caching text -> text embeddings

# you can decorate your Dataset class with this
# and then change your `data_kwargs = ["text_embeds", "vertices", "faces"]`

@beartype
def cache_text_embeds_for_dataset(
    embed_texts_fn: Callable[[List[str]], Tensor],
    max_text_len: int,
    cache_path: str = './text_embed_cache'
):
    # create the cache folder path

    path = Path(cache_path)
    path.mkdir(exist_ok = True, parents = True)
    assert path.is_dir()

    # global memmap handles

    text_embed_cache = None
    is_cached = None

    # caching function

    def get_maybe_cached_text_embed(
        idx: int,
        dataset_len: int,
        text: str,
        memmap_file_mode = 'w+'
    ):
        nonlocal text_embed_cache
        nonlocal is_cached

        # initialize the cache on the first call

        if not exists(text_embed_cache):
            test_embed = embed_texts_fn(['test'])
            feat_dim = test_embed.shape[-1]
            shape = (dataset_len, max_text_len, feat_dim)

            text_embed_cache = open_memmap(str(path / 'cache.text_embed.memmap.npy'), mode = memmap_file_mode, dtype = 'float32', shape = shape)
            is_cached = open_memmap(str(path / 'cache.is_cached.memmap.npy'), mode = memmap_file_mode, dtype = 'bool', shape = (dataset_len,))

        # determine whether to fetch from the cache or to invoke the text model

        if is_cached[idx]:
            text_embed = torch.from_numpy(text_embed_cache[idx])
        else:
            # cache

            text_embed = get_text_embed(text)
            text_embed = pad_or_slice_to(text_embed, max_text_len, dim = 0, pad_value = 0.)

            is_cached[idx] = True
            text_embed_cache[idx] = text_embed.cpu().numpy()

        mask = ~reduce(text_embed == 0, 'n d -> n', 'all')
        return text_embed[mask]

    # get the text embedding

    def get_text_embed(text: str):
        text_embeds = embed_texts_fn([text])
        return text_embeds[0]

    # inner function
    # the decorator below receives a Dataset class as its argument
    def inner(dataset_klass):
        # the decorated class must be a subclass of Dataset
        assert issubclass(dataset_klass, Dataset)

        # keep references to the original __init__ and __getitem__
        orig_init = dataset_klass.__init__
        orig_get_item = dataset_klass.__getitem__

        # define the new __init__
        def __init__(
            self,
            *args,
            cache_memmap_file_mode = 'w+',
            **kwargs
        ):
            # invoke the original __init__
            orig_init(self, *args, **kwargs)

            # store the file mode for the cache memmap
            self._cache_memmap_file_mode = cache_memmap_file_mode

            # if the dataset class has a data_kwargs attribute, swap 'texts' for 'text_embeds'
            if hasattr(self, 'data_kwargs'):
                self.data_kwargs = [('text_embeds' if data_kwarg == 'texts' else data_kwarg) for data_kwarg in self.data_kwargs]

        # define the new __getitem__
        def __getitem__(self, idx):
            # invoke the original __getitem__
            items = orig_get_item(self, idx)

            # local helper for fetching a possibly cached text embedding
            get_text_embed_ = partial(get_maybe_cached_text_embed, idx, len(self), memmap_file_mode = self._cache_memmap_file_mode)

            # if items is a dict
            if isinstance(items, dict):
                # and it contains a 'texts' key
                if 'texts' in items:
                    # fetch the text embedding and replace the 'texts' key with 'text_embeds'
                    text_embed = get_text_embed_(items['texts'])
                    items['text_embeds'] = text_embed
                    del items['texts']

            # if items is a tuple
            elif isinstance(items, tuple):
                new_items = []

                # iterate over each element of the tuple
                for maybe_text in items:
                    # pass non-strings through unchanged
                    if not isinstance(maybe_text, str):
                        new_items.append(maybe_text)
                        continue

                    # replace strings with their text embedding
                    new_items.append(get_text_embed_(maybe_text))

                # rebuild items as a tuple
                items = tuple(new_items)

            # return the processed items
            return items

        # swap in the new __init__ and __getitem__
        dataset_klass.__init__ = __init__
        dataset_klass.__getitem__ = __getitem__

        # return the modified dataset class
        return dataset_klass

    # return the decorator
    return inner
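
To make the decorator concrete, here is a hypothetical usage sketch. `my_embed_fn` is a stand-in for any `Callable[[List[str]], Tensor]` returning a (num_texts, seq, dim) tensor (e.g. a frozen T5 encoder); the random embedder below exists only so the sketch runs.

import torch
from torch.utils.data import Dataset

# hypothetical embedder, only for illustration
def my_embed_fn(texts):
    return torch.randn(len(texts), 64, 512)

@cache_text_embeds_for_dataset(
    embed_texts_fn = my_embed_fn,
    max_text_len = 64,
    cache_path = './text_embed_cache'
)
class TextMeshDataset(Dataset):
    def __init__(self, samples):
        self.samples = samples
        self.data_kwargs = ['texts', 'vertices', 'faces']  # becomes ['text_embeds', ...] after decoration

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        text, vertices, faces = self.samples[idx]
        return dict(texts = text, vertices = vertices, faces = faces)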
# decorator for auto-caching face edges

# you can decorate your Dataset class with this function
# and then change your `data_kwargs = ["vertices", "faces", "face_edges"]`

@beartype
def cache_face_edges_for_dataset(
    max_edges_len: int,
    cache_path: str = './face_edges_cache',
    assert_edge_len_lt_max: bool = True,
    pad_id = -1
):
    # create the cache folder path

    path = Path(cache_path)
    path.mkdir(exist_ok = True, parents = True)
    assert path.is_dir()

    # global memmap handles

    face_edges_cache = None
    is_cached = None

    # caching function

    def get_maybe_cached_face_edges(
        idx: int,
        dataset_len: int,
        faces: Tensor,
        memmap_file_mode = 'w+'
    ):
        nonlocal face_edges_cache
        nonlocal is_cached

        if not exists(face_edges_cache):
            # initialize the cache on the first call

            shape = (dataset_len, max_edges_len, 2)
            face_edges_cache = open_memmap(str(path / 'cache.face_edges_embed.memmap.npy'), mode = memmap_file_mode, dtype = 'float32', shape = shape)
            is_cached = open_memmap(str(path / 'cache.is_cached.memmap.npy'), mode = memmap_file_mode, dtype = 'bool', shape = (dataset_len,))

        # determine whether to fetch from the cache or to invoke the face edge derivation function

        if is_cached[idx]:
            face_edges = torch.from_numpy(face_edges_cache[idx])
        else:
            # cache

            face_edges = derive_face_edges_from_faces(faces, pad_id = pad_id)

            edge_len = face_edges.shape[0]
            assert not assert_edge_len_lt_max or (edge_len <= max_edges_len), f'mesh #{idx} has {edge_len} which exceeds the cache length of {max_edges_len}'

            face_edges = pad_or_slice_to(face_edges, max_edges_len, dim = 0, pad_value = pad_id)

            is_cached[idx] = True
            face_edges_cache[idx] = face_edges.cpu().numpy()

        mask = reduce(face_edges != pad_id, 'n d -> n', 'all')
        return face_edges[mask]

    # inner decorator function

    def inner(dataset_klass):
        assert issubclass(dataset_klass, Dataset)

        orig_init = dataset_klass.__init__
        orig_get_item = dataset_klass.__getitem__

        def __init__(
            self,
            *args,
            cache_memmap_file_mode = 'w+',
            **kwargs
        ):
            orig_init(self, *args, **kwargs)

            self._cache_memmap_file_mode = cache_memmap_file_mode

            if hasattr(self, 'data_kwargs'):
                self.data_kwargs.append('face_edges')

        def __getitem__(self, idx):
            items = orig_get_item(self, idx)

            get_face_edges_ = partial(get_maybe_cached_face_edges, idx, len(self), memmap_file_mode = self._cache_memmap_file_mode)

            if isinstance(items, dict):
                face_edges = get_face_edges_(items['faces'])
                items['face_edges'] = face_edges

            elif isinstance(items, tuple):
                _, faces, *_ = items
                face_edges = get_face_edges_(faces)
                items = (*items, face_edges)

            return items

        dataset_klass.__init__ = __init__
        dataset_klass.__getitem__ = __getitem__

        return dataset_klass

    return inner
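
A matching hypothetical sketch for this second decorator: face edges are derived once per sample and memmapped, with `max_edges_len` chosen generously (the value 6000 below is arbitrary).

from torch.utils.data import Dataset

@cache_face_edges_for_dataset(max_edges_len = 6000)
class MeshDatasetWithEdges(Dataset):
    def __init__(self, meshes):
        self.meshes = meshes  # list of (vertices, faces) tensor pairs
        self.data_kwargs = ['vertices', 'faces']  # 'face_edges' is appended by the decorator

    def __len__(self):
        return len(self.meshes)

    def __getitem__(self, idx):
        vertices, faces = self.meshes[idx]
        return dict(vertices = vertices, faces = faces)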

# dataset

class DatasetFromTransforms(Dataset):
    @beartype
    def __init__(
        self,
        folder: str,
        transforms: Dict[str, Callable[[Path], Tuple[Vertices, Faces]]],
        data_kwargs: Optional[List[str]] = None,
        augment_fn: Callable = identity
    ):
        folder = Path(folder)
        assert folder.exists() and folder.is_dir()
        self.folder = folder

        exts = transforms.keys()
        self.paths = [p for ext in exts for p in folder.glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')
        assert len(self.paths) > 0

        self.transforms = transforms
        self.data_kwargs = data_kwargs
        self.augment_fn = augment_fn
    # number of samples in the dataset, i.e. the length of the path list
    def __len__(self):
        return len(self.paths)

    # fetch the sample at the given index
    def __getitem__(self, idx):
        # path at the given index
        path = self.paths[idx]
        # file extension of the path
        ext = path.suffix[1:]
        # transform function registered for that extension
        fn = self.transforms[ext]

        # load the data at the path with the transform function
        out = fn(path)
        # apply the augmentation to the loaded data
        return self.augment_fn(out)
# tensor helper functions

# derive the edges between faces from the face tensor
def derive_face_edges_from_faces(
    faces: TensorType['b', 'nf', 3, int],  # input faces, shape [batch, num faces, 3]
    pad_id = -1,  # padding value, defaults to -1
    neighbor_if_share_one_vertex = False,  # whether sharing a single vertex already counts as neighboring
    include_self = True  # whether to include self edges
) -> TensorType['b', 'e', 2, int]:  # returned face edges, shape [batch, num edges, 2]

    is_one_face, device = faces.ndim == 2, faces.device  # detect a single unbatched face tensor, grab the device

    if is_one_face:
        faces = rearrange(faces, 'nf c -> 1 nf c')  # add a batch dimension of 1

    max_num_faces = faces.shape[1]  # maximum number of faces
    face_edges_vertices_threshold = 1 if neighbor_if_share_one_vertex else 2  # threshold depends on whether one shared vertex suffices

    all_edges = torch.stack(torch.meshgrid(
        torch.arange(max_num_faces, device = device),
        torch.arange(max_num_faces, device = device),
    indexing = 'ij'), dim = -1)  # all possible pairs of faces

    face_masks = reduce(faces != pad_id, 'b nf c -> b nf', 'all')  # mask of valid (non-padded) faces
    face_edges_masks = rearrange(face_masks, 'b i -> b i 1') & rearrange(face_masks, 'b j -> b 1 j')  # mask of valid face pairs

    face_edges = []  # accumulator for the face edges

    for face, face_edge_mask in zip(faces, face_edges_masks):

        shared_vertices = rearrange(face, 'i c -> i 1 c 1') == rearrange(face, 'j c -> 1 j 1 c')  # which vertices are shared
        num_shared_vertices = shared_vertices.any(dim = -1).sum(dim = -1)  # count shared vertices per face pair

        is_neighbor_face = (num_shared_vertices >= face_edges_vertices_threshold) & face_edge_mask  # neighboring faces

        if not include_self:
            is_neighbor_face &= num_shared_vertices != 3  # exclude each face paired with itself

        face_edge = all_edges[is_neighbor_face]  # gather the edges to neighboring faces
        face_edges.append(face_edge)  # accumulate

    face_edges = pad_sequence(face_edges, padding_value = pad_id, batch_first = True)  # pad the edge lists to equal length

    if is_one_face:
        face_edges = rearrange(face_edges, '1 e ij -> e ij')  # remove the batch dimension again

    return face_edges  # return the face edges
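
A small sanity check (illustrative, not from the repository): two triangles sharing the edge between vertices 1 and 2 are marked as neighbors, and the self pairs are kept since include_self defaults to True.

import torch

faces = torch.tensor([
    [0, 1, 2],
    [1, 2, 3],
])  # (nf, 3), unbatched

edges = derive_face_edges_from_faces(faces)
print(edges)
# tensor([[0, 0],
#         [0, 1],
#         [1, 0],
#         [1, 1]])  -- (0,1)/(1,0) share two vertices; (i,i) are the self edges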

# custom collater

# first element of a list
def first(it):
    return it[0]

# custom collate function for batching
def custom_collate(data, pad_id = -1):
    is_dict = isinstance(first(data), dict)  # whether each datum is a dict

    if is_dict:
        keys = first(data).keys()  # keys of the dict
        data = [d.values() for d in data]  # values of each dict

    output = []  # accumulator for the batched output

    for datum in zip(*data):
        if is_tensor(first(datum)):
            datum = pad_sequence(datum, batch_first = True, padding_value = pad_id)  # pad tensors to equal length
        else:
            datum = list(datum)  # otherwise keep as a list

        output.append(datum)  # accumulate

    output = tuple(output)  # convert to a tuple

    if is_dict:
        output = dict(zip(keys, output))  # reassemble into a dict if the input was dicts

    return output  # return the collated batch
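
An illustrative sketch of wiring custom_collate into a DataLoader, so variable-length vertices and faces get padded with pad_id = -1 per batch; `dataset` here is assumed to yield dicts of tensors, like the dataset sketches above.

from functools import partial
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size = 16,
    shuffle = True,
    collate_fn = partial(custom_collate, pad_id = -1)
)

for batch in dataloader:
    vertices, faces = batch['vertices'], batch['faces']  # padded to the longest sample in the batch
    break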

.\lucidrains\meshgpt-pytorch\meshgpt_pytorch\meshgpt_pytorch.py

# imports
from pathlib import Path
from functools import partial
from math import ceil, pi, sqrt

import torch
from torch import nn, Tensor, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch.cuda.amp import autocast

# tensor type annotations
from torchtyping import TensorType

# custom utility functions
from pytorch_custom_utils import save_load

# runtime type checking
from beartype import beartype
from beartype.typing import Union, Tuple, Callable, Optional, List, Dict, Any

# einops functions
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

# einx function
from einx import get_at

# x_transformers modules and functions
from x_transformers import Decoder
from x_transformers.attend import Attend
from x_transformers.x_transformers import RMSNorm, FeedForward, LayerIntermediates

# autoregressive wrapper helpers
from x_transformers.autoregressive_wrapper import (
    eval_decorator,
    top_k,
    top_p,
)

# local attention
from local_attention import LocalMHA

# vector quantization
from vector_quantize_pytorch import (
    ResidualVQ,
    ResidualLFQ
)

# meshgpt_pytorch functions
from meshgpt_pytorch.data import derive_face_edges_from_faces
from meshgpt_pytorch.version import __version__

# Taylor series linear attention
from taylor_series_linear_attention import TaylorSeriesLinearAttn

# classifier-free guidance
from classifier_free_guidance_pytorch import (
    classifier_free_guidance,
    TextEmbeddingReturner
)

# torch_geometric graph convolution
from torch_geometric.nn.conv import SAGEConv

# gateloop_transformer
from gateloop_transformer import SimpleGateLoopLayer

# progress bar
from tqdm import tqdm

# helper functions

# check whether a value exists
def exists(v):
    return v is not None

# return a default value
def default(v, d):
    return v if exists(v) else d

# first element of an iterable
def first(it):
    return it[0]

# check whether one number divides another
def divisible_by(num, den):
    return (num % den) == 0

# check whether a number is odd
def is_odd(n):
    return not divisible_by(n, 2)

# check whether a list is empty
def is_empty(l):
    return len(l) == 0

# check whether a tensor is empty
def is_tensor_empty(t: Tensor):
    return t.numel() == 0

# set requires_grad on all parameters of a module
def set_module_requires_grad_(
    module: Module,
    requires_grad: bool
):
    for param in module.parameters():
        param.requires_grad = requires_grad

# L1-normalize a tensor
def l1norm(t):
    return F.normalize(t, dim = -1, p = 1)

# L2-normalize a tensor
def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)

# safely concatenate tensors, skipping any that do not exist
def safe_cat(tensors, dim):
    tensors = [*filter(exists, tensors)]

    if len(tensors) == 0:
        return None
    elif len(tensors) == 1:
        return first(tensors)

    return torch.cat(tensors, dim = dim)

# pad a tensor at a given dimension
def pad_at_dim(t, padding, dim = -1, value = 0):
    ndim = t.ndim
    right_dims = (ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * right_dims
    return F.pad(t, (*zeros, *padding), value = value)

# pad a tensor up to a given length
def pad_to_length(t, length, dim = -1, value = 0, right = True):
    curr_length = t.shape[dim]
    remainder = length - curr_length

    if remainder <= 0:
        return t

    padding = (0, remainder) if right else (remainder, 0)
    return pad_at_dim(t, padding, dim = dim, value = value)
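
A quick illustration of the two padding helpers above (values chosen arbitrarily):

import torch

t = torch.ones(2, 3)

padded = pad_at_dim(t, (1, 2), dim = -1, value = 0)    # pad 1 left, 2 right -> shape (2, 6)
lengthened = pad_to_length(t, 5, dim = -1, value = 0)  # right-pad up to length 5 -> shape (2, 5)

assert padded.shape == (2, 6) and lengthened.shape == (2, 5)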

# continuous embedding

def ContinuousEmbed(dim_cont):
    return nn.Sequential(
        Rearrange('... -> ... 1'),
        nn.Linear(1, dim_cont),
        nn.SiLU(),
        nn.Linear(dim_cont, dim_cont),
        nn.LayerNorm(dim_cont)
    )

# derived face features
# 1. angles (3), 2. area (1), 3. normals (3)

# angle between two vectors
def derive_angle(x, y, eps = 1e-5):
    z = einsum('... d, ... d -> ...', l2norm(x), l2norm(y))
    return z.clip(-1 + eps, 1 - eps).arccos()

# compute the derived face features
@torch.no_grad()
def get_derived_face_features(
    face_coords: TensorType['b', 'nf', 'nvf', 3, float]  # 3 or 4 vertices with 3 coordinates
):
    shifted_face_coords = torch.cat((face_coords[:, :, -1:], face_coords[:, :, :-1]), dim = 2)

    angles  = derive_angle(face_coords, shifted_face_coords)

    edge1, edge2, *_ = (face_coords - shifted_face_coords).unbind(dim = 2)

    cross_product = torch.cross(edge1, edge2, dim = -1)

    normals = l2norm(cross_product)
    # the area must come from the unnormalized cross product - taking the norm of the
    # already unit-length normals would always yield a constant 0.5
    area = cross_product.norm(dim = -1, keepdim = True) * 0.5

    return dict(
        angles = angles,
        area = area,
        normals = normals
    )

# tensor helper functions

# discretize continuous values
@beartype
def discretize(
    t: Tensor,
    *,
    continuous_range: Tuple[float, float],
    num_discrete: int = 128
) -> Tensor:
    lo, hi = continuous_range
    # the range must be valid
    assert hi > lo

    # normalize t to [0, 1] over the given range
    t = (t - lo) / (hi - lo)
    # scale to the discrete range
    t *= num_discrete
    # shift so the bin centers line up (the range starts at 0)
    t -= 0.5

    # round to the nearest bin and clamp into range
    return t.round().long().clamp(min = 0, max = num_discrete - 1)
# invert the discretization, mapping bin indices back to continuous values
@beartype
def undiscretize(
    t: Tensor,  # input tensor of bin indices
    *,
    continuous_range: Tuple[float, float],  # continuous range
    num_discrete: int = 128  # number of discrete bins
) -> Tensor:
    lo, hi = continuous_range  # unpack the continuous range
    assert hi > lo  # the range must be valid

    t = t.float()  # cast to float

    t += 0.5  # shift to the bin center
    t /= num_discrete  # scale back to [0, 1]
    return t * (hi - lo) + lo  # map back to the continuous range
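
A round-trip sanity check (illustrative): discretizing then undiscretizing recovers each value up to half a bin width, i.e. (hi - lo) / (2 * num_discrete).

import torch

coords = torch.tensor([-0.73, 0.0, 0.99])
bins = discretize(coords, continuous_range = (-1., 1.), num_discrete = 128)
recovered = undiscretize(bins, continuous_range = (-1., 1.), num_discrete = 128)

assert (recovered - coords).abs().max() <= (2. / 128) / 2 + 1e-6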

# 1d gaussian blur
@beartype
def gaussian_blur_1d(
    t: Tensor,  # input tensor
    *,
    sigma: float = 1.  # standard deviation of the gaussian
) -> Tensor:

    _, _, channels, device, dtype = *t.shape, t.device, t.dtype  # unpack the shape, device and dtype

    width = int(ceil(sigma * 5))  # width of the blur kernel
    width += (width + 1) % 2  # make sure the width is odd
    half_width = width // 2  # half the kernel width

    distance = torch.arange(-half_width, half_width + 1, dtype = dtype, device = device)  # distances from the center

    gaussian = torch.exp(-(distance ** 2) / (2 * sigma ** 2))  # gaussian weights
    gaussian = l1norm(gaussian)  # L1-normalize so the kernel sums to 1

    kernel = repeat(gaussian, 'n -> c 1 n', c = channels)  # repeat the kernel across channels

    t = rearrange(t, 'b n c -> b c n')  # move channels to the convolution dimension
    out = F.conv1d(t, kernel, padding = half_width, groups = channels)  # depthwise 1d convolution
    return rearrange(out, 'b c n -> b n c')  # restore the original layout
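
This is how the reconstruction target gets label-smoothed later in the forward pass. An illustrative call: blurring a one-hot distribution spreads probability mass onto the neighboring bins while preserving the total mass (away from the boundaries).

import torch

one_hot = torch.zeros(1, 10, 1)  # (b, bins, positions)
one_hot[0, 4, 0] = 1.

smoothed = gaussian_blur_1d(one_hot, sigma = 0.4)
assert torch.allclose(smoothed.sum(), torch.tensor(1.), atol = 1e-5)  # mass is preserved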

# scatter mean over a tensor
@beartype
def scatter_mean(
    tgt: Tensor,  # target tensor
    indices: Tensor,  # index tensor
    src: Tensor,  # source tensor
    *,
    dim: int = -1,  # dimension to scatter over
    eps: float = 1e-5  # small value to avoid division by zero
):
    """
    todo: update to pytorch 2.1 and try https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_
    """
    num = tgt.scatter_add(dim, indices, src)  # scatter-add the source values into the target
    den = torch.zeros_like(tgt).scatter_add(dim, indices, torch.ones_like(src))  # count how many values land in each slot
    return num / den.clamp(min = eps)  # divide to get the mean
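
A small illustration: two source values that scatter into the same slot get averaged, untouched slots stay zero.

import torch

tgt = torch.zeros(4)
indices = torch.tensor([1, 1, 3])
src = torch.tensor([2., 4., 5.])

out = scatter_mean(tgt, indices, src, dim = -1)
# slot 1 receives (2 + 4) / 2 = 3, slot 3 receives 5
assert torch.allclose(out, torch.tensor([0., 3., 0., 5.]), atol = 1e-3)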

# resnet block

# pixel norm module
class PixelNorm(Module):
    def __init__(self, dim, eps = 1e-4):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        dim = self.dim
        return F.normalize(x, dim = dim, eps = self.eps) * sqrt(x.shape[dim])

# squeeze-and-excitation module
class SqueezeExcite(Module):
    def __init__(
        self,
        dim,
        reduction_factor = 4,  # reduction factor
        min_dim = 16  # minimum inner dimension
    ):
        super().__init__()
        dim_inner = max(dim // reduction_factor, min_dim)  # inner dimension

        self.net = nn.Sequential(
            nn.Linear(dim, dim_inner),  # linear layer
            nn.SiLU(),  # SiLU activation
            nn.Linear(dim_inner, dim),  # linear layer
            nn.Sigmoid(),  # sigmoid gate
            Rearrange('b c -> b c 1')  # add back the length dimension
        )

    def forward(self, x, mask = None):
        if exists(mask):  # if a mask is given
            x = x.masked_fill(~mask, 0.)  # zero out the masked positions

            num = reduce(x, 'b c n -> b c', 'sum')  # sum over the length dimension
            den = reduce(mask.float(), 'b 1 n -> b 1', 'sum')  # number of unmasked positions
            avg = num / den.clamp(min = 1e-5)  # masked mean
        else:
            avg = reduce(x, 'b c n -> b c', 'mean')  # plain mean over the length dimension

        return x * self.net(avg)  # gate the input by the excitation weights

# basic block
class Block(Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        dropout = 0.
    ):
        super().__init__()
        dim_out = default(dim_out, dim)  # output dimension defaults to the input dimension

        self.proj = nn.Conv1d(dim, dim_out, 3, padding = 1)  # 1d convolution
        self.norm = PixelNorm(dim = 1)  # pixel norm
        self.dropout = nn.Dropout(dropout)  # dropout
        self.act = nn.SiLU()  # SiLU activation

    def forward(self, x, mask = None):
        if exists(mask):  # zero out the masked positions before the convolution
            x = x.masked_fill(~mask, 0.)

        x = self.proj(x)  # convolution

        if exists(mask):  # zero out the masked positions again after the convolution
            x = x.masked_fill(~mask, 0.)

        x = self.norm(x)  # pixel norm
        x = self.act(x)  # activation
        x = self.dropout(x)  # dropout

        return x  # return the processed tensor

# resnet block
class ResnetBlock(Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        dropout = 0.
    ):
        super().__init__()
        dim_out = default(dim_out, dim)  # output dimension defaults to the input dimension
        self.block1 = Block(dim, dim_out, dropout = dropout)  # first basic block
        self.block2 = Block(dim_out, dim_out, dropout = dropout)  # second basic block
        self.excite = SqueezeExcite(dim_out)  # squeeze-and-excitation module
        self.residual_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()  # residual projection

    def forward(
        self,
        x,
        mask = None
    ):
        res = self.residual_conv(x)  # residual branch
        h = self.block1(x, mask = mask)  # first basic block
        h = self.block2(h, mask = mask)  # second basic block
        h = self.excite(h, mask = mask)  # squeeze-and-excitation
        return h + res  # add the residual

# gateloop layers

# gateloop block
class GateLoopBlock(Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        use_heinsen = True
    ):
        super().__init__()
        # module list of gateloop layers
        self.gateloops = ModuleList([])

        # create one SimpleGateLoopLayer per level of depth
        for _ in range(depth):
            gateloop = SimpleGateLoopLayer(dim = dim, use_heinsen = use_heinsen)
            self.gateloops.append(gateloop)

    # forward pass
    def forward(
        self,
        x,
        cache = None
    ):
        # whether a cache was passed in
        received_cache = exists(cache)

        # early return on an empty input tensor
        if is_tensor_empty(x):
            return x, None

        # if a cache was received, split the input into everything prior and the last token
        if received_cache:
            prev, x = x[:, :-1], x[:, -1:]

        # default the cache to an empty list
        cache = default(cache, [])
        # turn the cache into an iterator
        cache = iter(cache)

        # collect the new caches
        new_caches = []
        # iterate over the gateloop layers
        for gateloop in self.gateloops:
            # fetch this layer's cache
            layer_cache = next(cache, None)
            # run the layer, returning the output and the new cache
            out, new_cache = gateloop(x, cache = layer_cache, return_cache = True)
            new_caches.append(new_cache)
            # residual update
            x = x + out

        # if a cache was received, re-attach the previously split-off tokens
        if received_cache:
            x = torch.cat((prev, x), dim = -2)

        # return the updated tensor and the new cache list
        return x, new_caches
# main classes

# the @save_load(version = __version__) decorator handles saving and loading with version info
# (the decorator line was dropped from this excerpt and is restored here to match the class below)
@save_load(version = __version__)
class MeshAutoencoder(Module):
    # initialization
    @beartype
    def __init__(
        self,
        num_discrete_coors = 128,  # number of discrete coordinate bins
        coor_continuous_range: Tuple[float, float] = (-1., 1.),  # continuous coordinate range
        dim_coor_embed = 64,  # coordinate embedding dimension
        num_discrete_area = 128,  # number of discrete area bins
        dim_area_embed = 16,  # area embedding dimension
        num_discrete_normals = 128,  # number of discrete normal bins
        dim_normal_embed = 64,  # normal embedding dimension
        num_discrete_angle = 128,  # number of discrete angle bins
        dim_angle_embed = 16,  # angle embedding dimension
        encoder_dims_through_depth: Tuple[int, ...] = (  # encoder dimensions through depth
            64, 128, 256, 256, 576
        ),
        init_decoder_conv_kernel = 7,  # initial decoder convolution kernel size
        decoder_dims_through_depth: Tuple[int, ...] = (  # decoder dimensions through depth
            128, 128, 128, 128,
            192, 192, 192, 192,
            256, 256, 256, 256, 256, 256,
            384, 384, 384
        ),
        dim_codebook = 192,  # codebook dimension
        num_quantizers = 2,  # number of quantizers
        codebook_size = 16384,  # codebook size
        use_residual_lfq = True,  # whether to use the latest lookup-free quantization
        rq_kwargs: dict = dict(  # residual quantizer keyword arguments
            quantize_dropout = True,
            quantize_dropout_cutoff_index = 1,
            quantize_dropout_multiple_of = 1,
        ),
        rvq_kwargs: dict = dict(  # residual VQ keyword arguments
            kmeans_init = True,
            threshold_ema_dead_code = 2,
        ),
        rlfq_kwargs: dict = dict(  # residual LFQ keyword arguments
            frac_per_sample_entropy = 1.
        ),
        rvq_stochastic_sample_codes = True,  # whether residual VQ samples codes stochastically
        sageconv_kwargs: dict = dict(  # SAGEConv keyword arguments
            normalize = True,
            project = True
        ),
        commit_loss_weight = 0.1,  # commitment loss weight
        bin_smooth_blur_sigma = 0.4,  # sigma for blurring the one-hot of the discretized coordinate positions
        attn_encoder_depth = 0,  # encoder attention depth
        attn_decoder_depth = 0,  # decoder attention depth
        local_attn_kwargs: dict = dict(  # local attention keyword arguments
            dim_head = 32,
            heads = 8
        ),
        local_attn_window_size = 64,  # local attention window size
        linear_attn_kwargs: dict = dict(  # linear attention keyword arguments
            dim_head = 8,
            heads = 16
        ),
        use_linear_attn = True,  # whether to use linear attention
        pad_id = -1,  # padding id
        flash_attn = True,  # flash attention
        attn_dropout = 0.,  # attention dropout
        ff_dropout = 0.,  # feedforward dropout
        resnet_dropout = 0,  # resnet dropout
        checkpoint_quantizer = False,  # whether to gradient-checkpoint the quantizer
        quads = False  # whether faces are quads instead of triangles
    ):
        ...  # (the initialization body is omitted in this excerpt)

    @beartype
    def encode(
        self,
        *,
        vertices:         TensorType['b', 'nv', 3, float],  # vertices
        faces:            TensorType['b', 'nf', 'nvf', int],  # faces
        face_edges:       TensorType['b', 'e', 2, int],  # face edges
        face_mask:        TensorType['b', 'nf', bool],  # face mask
        face_edges_mask:  TensorType['b', 'e', bool],  # face edges mask
        return_face_coordinates = False  # whether to also return the discrete face coordinates
    ):
        """
        einops:
        b - batch
        nf - number of faces
        nv - number of vertices (3)
        nvf - number of vertices per face (3 or 4) - triangles vs quads
        c - coordinates (3)
        d - embed dim
        """

        # batch, number of vertices, number of coordinates, device
        batch, num_vertices, num_coors, device = *vertices.shape, vertices.device
        # batch, number of faces, vertices per face
        _, num_faces, num_vertices_per_face = faces.shape

        # the number of vertices per face must match the configured value
        assert self.num_vertices_per_face == num_vertices_per_face

        # zero out the vertex indices of padded faces, using face_mask
        face_without_pad = faces.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1'), 0)

        # gather the continuous face coordinates
        face_coords = get_at('b [nv] c, b nf mv -> b nf mv c', vertices, face_without_pad)

        # compute the derived features and embed
        derived_features = get_derived_face_features(face_coords)

        # discretize and embed the angles
        discrete_angle = self.discretize_angle(derived_features['angles'])
        angle_embed = self.angle_embed(discrete_angle)

        # discretize and embed the areas
        discrete_area = self.discretize_area(derived_features['area'])
        area_embed = self.area_embed(discrete_area)

        # discretize and embed the normals
        discrete_normal = self.discretize_normals(derived_features['normals'])
        normal_embed = self.normal_embed(discrete_normal)

        # discretize the vertices for the face coordinate embedding
        discrete_face_coords = self.discretize_face_coords(face_coords)
        discrete_face_coords = rearrange(discrete_face_coords, 'b nf nv c -> b nf (nv c)')

        # embed the discrete face coordinates (this step was dropped from the excerpt;
        # restored from the upstream source so face_coor_embed is defined before it is packed)
        face_coor_embed = self.coor_embed(discrete_face_coords)
        face_coor_embed = rearrange(face_coor_embed, 'b nf c d -> b nf (c d)')

        # combine all features and project into the model dimension
        face_embed, _ = pack([face_coor_embed, angle_embed, area_embed, normal_embed], 'b nf *')
        face_embed = self.project_in(face_embed)

        # handle variable lengths with masked_select and masked_scatter
        face_index_offsets = reduce(face_mask.long(), 'b nf -> b', 'sum')
        face_index_offsets = F.pad(face_index_offsets.cumsum(dim = 0), (1, -1), value = 0)
        face_index_offsets = rearrange(face_index_offsets, 'b -> b 1 1')

        face_edges = face_edges + face_index_offsets
        face_edges = face_edges[face_edges_mask]
        face_edges = rearrange(face_edges, 'be ij -> ij be')

        orig_face_embed_shape = face_embed.shape[:2]

        face_embed = face_embed[face_mask]

        # initial sage conv followed by activation and norm
        face_embed = self.init_sage_conv(face_embed, face_edges)
        face_embed = self.init_encoder_act_and_norm(face_embed)

        # run each graph convolution encoder
        for conv in self.encoders:
            face_embed = conv(face_embed, face_edges)

        shape = (*orig_face_embed_shape, face_embed.shape[-1])

        face_embed = face_embed.new_zeros(shape).masked_scatter(rearrange(face_mask, '... -> ... 1'), face_embed)

        # run the encoder attention blocks
        for linear_attn, attn, ff in self.encoder_attn_blocks:
            if exists(linear_attn):
                face_embed = linear_attn(face_embed, mask = face_mask) + face_embed

            face_embed = attn(face_embed, mask = face_mask) + face_embed
            face_embed = ff(face_embed) + face_embed

        # return just the face embedding if the face coordinates are not requested
        if not return_face_coordinates:
            return face_embed

        # otherwise also return the discrete face coordinates
        return face_embed, discrete_face_coords

    @beartype
    def quantize(
        self,
        *,
        faces: TensorType['b', 'nf', 'nvf', int],
        face_mask: TensorType['b', 'n', bool],
        face_embed: TensorType['b', 'nf', 'd', float],
        pad_id = None,
        rvq_sample_codebook_temp = 1.
    ):
        # pad_id defaults to self.pad_id
        pad_id = default(pad_id, self.pad_id)
        # batch, number of faces, device
        batch, num_faces, device = *faces.shape[:2], faces.device

        # largest vertex index referenced by the faces
        max_vertex_index = faces.amax()
        # number of vertices
        num_vertices = int(max_vertex_index.item() + 1)

        # project face_embed into the codebook dimension
        face_embed = self.project_dim_codebook(face_embed)
        # split out the per-vertex embeddings
        face_embed = rearrange(face_embed, 'b nf (nvf d) -> b nf nvf d', nvf = self.num_vertices_per_face)

        # per-vertex dimension
        vertex_dim = face_embed.shape[-1]
        # zero-initialized vertex tensor
        vertices = torch.zeros((batch, num_vertices, vertex_dim), device = device)

        # create a pad vertex, for variable-length faces
        pad_vertex_id = num_vertices
        vertices = pad_at_dim(vertices, (0, 1), dim = -2, value = 0.)

        # point the padded faces at the pad vertex, using face_mask
        faces = faces.masked_fill(~rearrange(face_mask, 'b n -> b n 1'), pad_vertex_id)

        # prepare faces_with_dim for the scatter mean
        faces_with_dim = repeat(faces, 'b nf nvf -> b (nf nvf) d', d = vertex_dim)

        # flatten the per-vertex embeddings
        face_embed = rearrange(face_embed, 'b ... d -> b (...) d')

        # scatter mean
        averaged_vertices = scatter_mean(vertices, faces_with_dim, face_embed, dim = -2)

        # mask out the pad vertex token
        mask = torch.ones((batch, num_vertices + 1), device = device, dtype = torch.bool)
        mask[:, -1] = False

        # rvq specific kwargs
        quantize_kwargs = dict(mask = mask)

        if isinstance(self.quantizer, ResidualVQ):
            quantize_kwargs.update(sample_codebook_temp = rvq_sample_codebook_temp)

        # a quantize function wrapper, so it can be memory checkpointed
        def quantize_wrapper_fn(inp):
            unquantized, quantize_kwargs = inp
            return self.quantizer(unquantized, **quantize_kwargs)

        # maybe checkpoint the quantize function
        if self.checkpoint_quantizer:
            quantize_wrapper_fn = partial(checkpoint, quantize_wrapper_fn, use_reentrant = False)

        # residual VQ
        quantized, codes, commit_loss = quantize_wrapper_fn((averaged_vertices, quantize_kwargs))

        # gather the quantized vertices back onto the faces for decoding
        face_embed_output = get_at('b [n] d, b nf nvf -> b nf (nvf d)', quantized, faces)

        # the vertex codes also need to be gathered into face order, for autoregressive learning
        codes_output = get_at('b [n] q, b nf nvf -> b (nf nvf) q', codes, faces)

        # make sure the output codes carry the padding
        face_mask = repeat(face_mask, 'b nf -> b (nf nvf) 1', nvf = self.num_vertices_per_face)
        codes_output = codes_output.masked_fill(~face_mask, self.pad_id)

        # output the quantized embeddings, the codes, and the commitment loss
        return face_embed_output, codes_output, commit_loss

    @beartype
    def decode(
        self,
        quantized: TensorType['b', 'n', 'd', float],
        face_mask:  TensorType['b', 'n', bool]
    ):
        # reshape the face mask for the convolutions
        conv_face_mask = rearrange(face_mask, 'b n -> b 1 n')

        x = quantized

        for linear_attn, attn, ff in self.decoder_attn_blocks:
            if exists(linear_attn):
                x = linear_attn(x, mask = face_mask) + x

            x = attn(x, mask = face_mask) + x
            x = ff(x) + x

        # move the feature dimension into convolution layout
        x = rearrange(x, 'b n d -> b d n')
        x = x.masked_fill(~conv_face_mask, 0.)
        x = self.init_decoder_conv(x)

        for resnet_block in self.decoders:
            x = resnet_block(x, mask = conv_face_mask)

        return rearrange(x, 'b d n -> b n d')

    @beartype
    @torch.no_grad()
    def decode_from_codes_to_faces(
        self,
        codes: Tensor,
        face_mask: Optional[TensorType['b', 'n', bool]] = None,
        return_discrete_codes = False
    ):
        # flatten the codes, 'b ...' -> 'b (...)'
        codes = rearrange(codes, 'b ... -> b (...)')

        # if face_mask was not given, derive it from the codes that are not equal to self.pad_id
        if not exists(face_mask):
            face_mask = reduce(codes != self.pad_id, 'b (nf nvf q) -> b nf', 'all', nvf = self.num_vertices_per_face, q = self.num_quantizers)

        # handle the different code shapes

        # regroup the codes, 'b (n q)' -> 'b n q'
        codes = rearrange(codes, 'b (n q) -> b n q', q = self.num_quantizers)

        # decode

        # fetch the quantized values from the indices
        quantized = self.quantizer.get_output_from_indices(codes)
        # regroup per face, 'b (nf nvf) d' -> 'b nf (nvf d)'
        quantized = rearrange(quantized, 'b (nf nvf) d -> b nf (nvf d)', nvf = self.num_vertices_per_face)

        # decode
        decoded = self.decode(
            quantized,
            face_mask = face_mask
        )

        # zero out everything outside the face mask
        decoded = decoded.masked_fill(~face_mask[..., None], 0.)
        # project the decoded features to coordinate logits
        pred_face_coords = self.to_coor_logits(decoded)

        # take the argmax over the bins
        pred_face_coords = pred_face_coords.argmax(dim = -1)

        # regroup, '... (v c)' -> '... v c', with v = self.num_vertices_per_face
        pred_face_coords = rearrange(pred_face_coords, '... (v c) -> ... v c', v = self.num_vertices_per_face)

        # back to continuous space

        continuous_coors = undiscretize(
            pred_face_coords,
            num_discrete = self.num_discrete_coors,
            continuous_range = self.coor_continuous_range
        )

        # mask out the padded faces with nan

        continuous_coors = continuous_coors.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1 1'), float('nan'))

        # if the discrete codes are not requested, return continuous_coors and face_mask
        if not return_discrete_codes:
            return continuous_coors, face_mask

        # otherwise return continuous_coors, pred_face_coords and face_mask

        return continuous_coors, pred_face_coords, face_mask

    @torch.no_grad()
    def tokenize(self, vertices, faces, face_edges = None, **kwargs):
        # 'return_codes' must not be passed through kwargs
        assert 'return_codes' not in kwargs

        inputs = [vertices, faces, face_edges]
        inputs = [*filter(exists, inputs)]
        ndims = {i.ndim for i in inputs}

        # all input tensors must have the same number of dimensions
        assert len(ndims) == 1
        batch_less = first(list(ndims)) == 2

        # if the inputs are unbatched, add a batch dimension of 1
        if batch_less:
            inputs = [rearrange(i, '... -> 1 ...') for i in inputs]

        input_kwargs = dict(zip(['vertices', 'faces', 'face_edges'], inputs))

        self.eval()

        # call forward, returning the codes

        codes = self.forward(
            **input_kwargs,
            return_codes = True,
            **kwargs
        )

        # if the inputs were unbatched, remove the batch dimension again
        if batch_less:
            codes = rearrange(codes, '1 ... -> ...')

        return codes

    @beartype
    def forward(
        self,
        *,
        vertices:       TensorType['b', 'nv', 3, float],
        faces:          TensorType['b', 'nf', 'nvf', int],
        face_edges:     Optional[TensorType['b', 'e', 2, int]] = None,
        return_codes = False,
        return_loss_breakdown = False,
        return_recon_faces = False,
        only_return_recon_faces = False,
        rvq_sample_codebook_temp = 1.
    ):
        # if face edges were not provided, derive them from the faces
        if not exists(face_edges):
            face_edges = derive_face_edges_from_faces(faces, pad_id = self.pad_id)

        # number of faces, number of face edges, device
        num_faces, num_face_edges, device = faces.shape[1], face_edges.shape[1], faces.device

        # mask marking the valid faces
        face_mask = reduce(faces != self.pad_id, 'b nf c -> b nf', 'all')
        # mask marking the valid face edges
        face_edges_mask = reduce(face_edges != self.pad_id, 'b e ij -> b e', 'all')

        # encode the inputs, getting the encodings and the face coordinates
        encoded, face_coordinates = self.encode(
            vertices = vertices,
            faces = faces,
            face_edges = face_edges,
            face_edges_mask = face_edges_mask,
            face_mask = face_mask,
            return_face_coordinates = True
        )

        # quantize the encodings, getting the quantized embeddings, codes and commitment loss
        quantized, codes, commit_loss = self.quantize(
            face_embed = encoded,
            faces = faces,
            face_mask = face_mask,
            rvq_sample_codebook_temp = rvq_sample_codebook_temp
        )

        # if the raw codes were requested
        if return_codes:
            assert not return_recon_faces, 'cannot return reconstructed faces when just returning raw codes'

            # fill positions outside the mask with the pad id
            codes = codes.masked_fill(~repeat(face_mask, 'b nf -> b (nf nvf) 1', nvf = self.num_vertices_per_face), self.pad_id)
            return codes

        # decode the quantized embeddings
        decode = self.decode(
            quantized,
            face_mask = face_mask
        )

        # project the decodings to coordinate logits
        pred_face_coords = self.to_coor_logits(decode)

        # if the reconstructed faces are needed
        if return_recon_faces or only_return_recon_faces:

            # un-discretize the argmax of the coordinate logits
            recon_faces = undiscretize(
                pred_face_coords.argmax(dim = -1),
                num_discrete = self.num_discrete_coors,
                continuous_range = self.coor_continuous_range,
            )

            # regroup the reconstructed faces
            recon_faces = rearrange(recon_faces, 'b nf (nvf c) -> b nf nvf c', nvf = self.num_vertices_per_face)
            face_mask = rearrange(face_mask, 'b nf -> b nf 1 1')
            recon_faces = recon_faces.masked_fill(~face_mask, float('nan'))
            face_mask = rearrange(face_mask, 'b nf 1 1 -> b nf')

        # if only the reconstructed faces are needed
        if only_return_recon_faces:
            return recon_faces

        # prepare the reconstruction loss
        pred_face_coords = rearrange(pred_face_coords, 'b ... c -> b c (...)')
        face_coordinates = rearrange(face_coordinates, 'b ... -> b 1 (...)')

        # reconstruction loss on the discretized coordinates, with some bin smoothing
        with autocast(enabled = False):
            pred_log_prob = pred_face_coords.log_softmax(dim = 1)

            target_one_hot = torch.zeros_like(pred_log_prob).scatter(1, face_coordinates, 1.)

            if self.bin_smooth_blur_sigma >= 0.:
                target_one_hot = gaussian_blur_1d(target_one_hot, sigma = self.bin_smooth_blur_sigma)

            # cross entropy with the smoothed one-hot target
            recon_losses = (-target_one_hot * pred_log_prob).sum(dim = 1)

            face_mask = repeat(face_mask, 'b nf -> b (nf r)', r = self.num_vertices_per_face * 3)
            recon_loss = recon_losses[face_mask].mean()

        # total loss
        total_loss = recon_loss + \
                     commit_loss.sum() * self.commit_loss_weight

        # loss breakdown, if requested
        loss_breakdown = (recon_loss, commit_loss)

        # return logic
        if not return_loss_breakdown:
            if not return_recon_faces:
                return total_loss

            return recon_faces, total_loss

        if not return_recon_faces:
            return total_loss, loss_breakdown

        return recon_faces, total_loss, loss_breakdown
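
Before moving on to the transformer, here is an illustrative training-step sketch for the autoencoder on a toy random mesh, mirroring the style of the repository README. The shapes only need to satisfy the type annotations of forward; the values are dummies.

import torch

autoencoder = MeshAutoencoder(num_discrete_coors = 128)

vertices = torch.randn(1, 100, 3).clamp(-1., 1.)  # (b, nv, 3), within the (-1, 1) coordinate range
faces = torch.randint(0, 100, (1, 64, 3))         # (b, nf, 3) vertex indices

loss = autoencoder(vertices = vertices, faces = faces)  # face_edges are derived internally
loss.backward()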
# the MeshTransformer class, which models the mesh tokens autoregressively
@save_load(version = __version__)
class MeshTransformer(Module):
    # initialization
    @beartype
    def __init__(
        self,
        autoencoder: MeshAutoencoder, # a MeshAutoencoder instance
        *,
        dim: Union[int, Tuple[int, int]] = 512, # model dimension, defaults to 512
        max_seq_len = 8192, # maximum sequence length, defaults to 8192
        flash_attn = True, # whether to use flash attention, defaults to True
        attn_depth = 12, # attention depth, defaults to 12
        attn_dim_head = 64, # attention head dimension, defaults to 64
        attn_heads = 16, # number of attention heads, defaults to 16
        attn_kwargs: dict = dict( # attention keyword arguments, defaults to ff_glu and num_mem_kv
            ff_glu = True,
            num_mem_kv = 4
        ),
        cross_attn_num_mem_kv = 4, # number of cross attention memory key/values, to prevent NaNs when dropping out the text condition
        dropout = 0., # dropout probability, defaults to 0
        coarse_pre_gateloop_depth = 2, # depth of the coarse pre-gateloop block, defaults to 2
        fine_pre_gateloop_depth = 2, # depth of the fine pre-gateloop block, defaults to 2
        gateloop_use_heinsen = False, # whether to use the Heinsen gateloop, defaults to False
        fine_attn_depth = 2, # fine attention depth, defaults to 2
        fine_attn_dim_head = 32, # fine attention head dimension, defaults to 32
        fine_attn_heads = 8, # number of fine attention heads, defaults to 8
        pad_id = -1, # padding id, defaults to -1
        condition_on_text = False, # whether to condition on text, defaults to False
        text_condition_model_types = ('t5',), # text condition model types, defaults to ('t5',)
        text_condition_cond_drop_prob = 0.25, # text condition dropout probability, defaults to 0.25
        quads = False # whether to use quads, defaults to False
    ):
        # parent constructor
        super().__init__()
        # 3 vertices per face for triangles, otherwise 4 for quads
        self.num_vertices_per_face = 3 if not quads else 4

        # the autoencoder and the transformer must support the same mesh type (all triangles, or all quads)
        assert autoencoder.num_vertices_per_face == self.num_vertices_per_face, 'autoencoder and transformer must both support the same type of mesh (either all triangles, or all quads)'

        # if dim is an int, use it for both the coarse and fine dimension
        dim, dim_fine = (dim, dim) if isinstance(dim, int) else dim

        # store the autoencoder and freeze its gradients
        self.autoencoder = autoencoder
        set_module_requires_grad_(autoencoder, False)

        # codebook size and number of quantizers come from the autoencoder
        self.codebook_size = autoencoder.codebook_size
        self.num_quantizers = autoencoder.num_quantizers

        # start-of-sequence token and end-of-sequence token id
        self.sos_token = nn.Parameter(torch.randn(dim_fine))
        self.eos_token_id = self.codebook_size

        # max_seq_len must be divisible by (num_vertices_per_face x num_quantizers)
        assert divisible_by(max_seq_len, self.num_vertices_per_face * self.num_quantizers), f'max_seq_len ({max_seq_len}) must be divisible by (3 x {self.num_quantizers}) = {3 * self.num_quantizers}' # 3 or 4 vertices per face, with D codes per vertex

        # token embedding
        self.token_embed = nn.Embedding(self.codebook_size + 1, dim)

        # quantize level embedding and vertex embedding
        self.quantize_level_embed = nn.Parameter(torch.randn(self.num_quantizers, dim))
        self.vertex_embed = nn.Parameter(torch.randn(self.num_vertices_per_face, dim))

        # absolute positional embedding
        self.abs_pos_emb = nn.Embedding(max_seq_len, dim)

        # maximum sequence length
        self.max_seq_len = max_seq_len

        # text conditioning
        self.condition_on_text = condition_on_text
        self.conditioner = None

        cross_attn_dim_context = None

        # if conditioning on text, set up the text embedding returner
        if condition_on_text:
            self.conditioner = TextEmbeddingReturner(
                model_types = text_condition_model_types,
                cond_drop_prob = text_condition_cond_drop_prob
            )
            cross_attn_dim_context = self.conditioner.dim_latent

        # for summarizing the vertices of each face
        self.to_face_tokens = nn.Sequential(
            nn.Linear(self.num_quantizers * self.num_vertices_per_face * dim, dim),
            nn.LayerNorm(dim)
        )

        # coarse gateloop block
        self.coarse_gateloop_block = GateLoopBlock(dim, depth = coarse_pre_gateloop_depth, use_heinsen = gateloop_use_heinsen) if coarse_pre_gateloop_depth > 0 else None

        # main autoregressive attention network, attending over the face tokens
        self.decoder = Decoder(
            dim = dim,
            depth = attn_depth,
            dim_head = attn_dim_head,
            heads = attn_heads,
            attn_flash = flash_attn,
            attn_dropout = dropout,
            ff_dropout = dropout,
            cross_attend = condition_on_text,
            cross_attn_dim_context = cross_attn_dim_context,
            cross_attn_num_mem_kv = cross_attn_num_mem_kv,
            **attn_kwargs
        )

        # projection from coarse to fine, if needed
        self.maybe_project_coarse_to_fine = nn.Linear(dim, dim_fine) if dim != dim_fine else nn.Identity()

        # addresses a weakness of attention
        self.fine_gateloop_block = GateLoopBlock(dim, depth = fine_pre_gateloop_depth) if fine_pre_gateloop_depth > 0 else None

        # decoding the vertices, two-stage hierarchy
        self.fine_decoder = Decoder(
            dim = dim_fine,
            depth = fine_attn_depth,
            dim_head = attn_dim_head,
            heads = attn_heads,
            attn_flash = flash_attn,
            attn_dropout = dropout,
            ff_dropout = dropout,
            **attn_kwargs
        )

        # to logits
        self.to_logits = nn.Linear(dim_fine, self.codebook_size + 1)

        # padding id - force the autoencoder to use the same pad id as the transformer
        self.pad_id = pad_id
        autoencoder.pad_id = pad_id

    @property
    def device(self):
        # device of the parameters
        return next(self.parameters()).device

    @beartype
    @torch.no_grad()
    # embed texts into the conditioning vector space
    def embed_texts(self, texts: Union[str, List[str]]):
        # whether a single text was passed
        single_text = not isinstance(texts, list)
        # wrap a single text in a list
        if single_text:
            texts = [texts]

        # a conditioner must exist
        assert exists(self.conditioner)
        # embed the texts and detach from the graph
        text_embeds = self.conditioner.embed_texts(texts).detach()

        # if a single text was passed, return just its embedding
        if single_text:
            text_embeds = text_embeds[0]

        # return the text embeddings
        return text_embeds

    # generation
    @eval_decorator
    @torch.no_grad()
    @beartype
    def generate(
        self,
        prompt: Optional[Tensor] = None,
        batch_size: Optional[int] = None,
        filter_logits_fn: Callable = top_k,
        filter_kwargs: dict = dict(),
        temperature = 1.,
        return_codes = False,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[Tensor] = None,
        cond_scale = 1.,
        cache_kv = True,
        max_seq_len = None,
        face_coords_to_file: Optional[Callable[[Tensor], Any]] = None
    ):
        # the max sequence length defaults to the model's setting
        max_seq_len = default(max_seq_len, self.max_seq_len)

        # if a prompt was given
        if exists(prompt):
            # batch_size must not also be given
            assert not exists(batch_size)

            # flatten the prompt
            prompt = rearrange(prompt, 'b ... -> b (...)')
            # the prompt must not exceed the maximum sequence length
            assert prompt.shape[-1] <= self.max_seq_len

            # the batch size comes from the prompt
            batch_size = prompt.shape[0]

        # if generating conditioned on text
        if self.condition_on_text:
            # exactly one of texts or text_embeds must be given
            assert exists(texts) ^ exists(text_embeds), '`text` or `text_embeds` must be passed in if `condition_on_text` is set to True'
            # embed the texts if given as strings
            if exists(texts):
                text_embeds = self.embed_texts(texts)

            # the batch size defaults to the number of text embeddings
            batch_size = default(batch_size, text_embeds.shape[0])

        # the batch size falls back to 1
        batch_size = default(batch_size, 1)

        # initialize the codes tensor
        codes = default(prompt, torch.empty((batch_size, 0), dtype = torch.long, device = self.device))

        # current code length
        curr_length = codes.shape[-1]

        # initialize the cache
        cache = (None, None)

        # sampling loop
        for i in tqdm(range(curr_length, max_seq_len)):

            # eos is only allowed at the end of each face, defined as a face with 3 or 4 vertices with D residual VQ codes per vertex
            can_eos = i != 0 and divisible_by(i, self.num_quantizers * self.num_vertices_per_face)

            # forward pass on the codes
            output = self.forward_on_codes(
                codes,
                text_embeds = text_embeds,
                return_loss = False,
                return_cache = cache_kv,
                append_eos = False,
                cond_scale = cond_scale,
                cfg_routed_kwargs = dict(
                    cache = cache
                )
            )

            # if using the kv cache
            if cache_kv:
                logits, cache = output

                if cond_scale == 1.:
                    cache = (cache, None)
            else:
                logits = output

            logits = logits[:, -1]

            # if eos is not allowed here, set the eos logit to the minimum value
            if not can_eos:
                logits[:, -1] = -torch.finfo(logits.dtype).max

            # filter the logits
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)

            # sample according to the temperature
            if temperature == 0.:
                sample = filtered_logits.argmax(dim = -1)
            else:
                probs = F.softmax(filtered_logits / temperature, dim = -1)
                sample = torch.multinomial(probs, 1)

            # append the sample to the codes
            codes, _ = pack([codes, sample], 'b *')

            # check whether every row has an eos, so generation can terminate
            is_eos_codes = (codes == self.eos_token_id)

            if is_eos_codes.any(dim = -1).all():
                break

        # mask out everything after the first eos
        mask = is_eos_codes.float().cumsum(dim = -1) >= 1
        codes = codes.masked_fill(mask, self.pad_id)

        # remove any leftover partial codes
        code_len = codes.shape[-1]
        round_down_code_len = code_len // self.num_quantizers * self.num_quantizers
        codes = codes[:, :round_down_code_len]

        # if requested, return the raw residual quantizer codes
        if return_codes:
            codes = rearrange(codes, 'b (n q) -> b n q', q = self.num_quantizers)
            return codes

        # put the autoencoder in eval mode
        self.autoencoder.eval()
        # decode the codes into face coordinates
        face_coords, face_mask = self.autoencoder.decode_from_codes_to_faces(codes)

        # if no face-coords-to-file callback was given, return the face coordinates and mask
        if not exists(face_coords_to_file):
            return face_coords, face_mask

        # otherwise write each mesh out through the callback
        files = [face_coords_to_file(coords[mask]) for coords, mask in zip(face_coords, face_mask)]
        return files
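
An illustrative generation sketch (mirroring the repository README); `face_coords_to_file` would be a user-supplied hook, e.g. an .obj writer, and is omitted here. Note max_seq_len must be divisible by num_vertices_per_face * num_quantizers (3 * 2 = 6 with the defaults).

transformer = MeshTransformer(autoencoder, dim = 512, max_seq_len = 768)

# ... train the transformer on tokenized meshes ...

face_coords, face_mask = transformer.generate(batch_size = 2, temperature = 0.7)
# face_coords: (2, nf, 3, 3) continuous coordinates, nan-filled past each mesh's end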

    # forward pass
    def forward(
        self,
        *,
        vertices:       TensorType['b', 'nv', 3, int],
        faces:          TensorType['b', 'nf', 'nvf', int],
        face_edges:     Optional[TensorType['b', 'e', 2, int]] = None,
        codes:          Optional[Tensor] = None,
        cache:          Optional[LayerIntermediates] = None,
        **kwargs
    ):
        # if no codes were given, tokenize with the autoencoder
        if not exists(codes):
            codes = self.autoencoder.tokenize(
                vertices = vertices,
                faces = faces,
                face_edges = face_edges
            )

        # run forward_on_codes with the codes and the remaining arguments
        return self.forward_on_codes(codes, cache = cache, **kwargs)

    # forward pass on the codes; decorated with classifier_free_guidance
    # (the decorator line was dropped from this excerpt and is restored to match the comment)
    @classifier_free_guidance
    def forward_on_codes(
        self,
        codes = None,
        return_loss = True,
        return_cache = False,
        append_eos = True,
        cache = None,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[Tensor] = None,
        cond_drop_prob = None
.\lucidrains\meshgpt-pytorch\meshgpt_pytorch\trainer.py

# imports
from pathlib import Path
from functools import partial
from packaging import version
from contextlib import nullcontext, contextmanager

import torch
from torch import nn, Tensor
from torch.nn import Module
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import _LRScheduler

# custom utility functions and classes
from pytorch_custom_utils import (
    get_adam_optimizer,
    OptimizerWithWarmupSchedule,
    add_wandb_tracker_contextmanager
)

# accelerate library
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# runtime type checking
from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Optional, Tuple, Type, List

# exponential moving average
from ema_pytorch import EMA

# data helpers
from meshgpt_pytorch.data import custom_collate

# version
from meshgpt_pytorch.version import __version__

# MeshGPT models
from meshgpt_pytorch.meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
)

# constants
DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(
    find_unused_parameters = True
)

# helper functions

# check whether a value exists
def exists(v):
    return v is not None

# return a default value
def default(v, d):
    return v if exists(v) else d

# check whether one number divides another
def divisible_by(num, den):
    return (num % den) == 0

# cycle endlessly through a dataloader
def cycle(dl):
    while True:
        for data in dl:
            yield data

# delete the given keys from a dict, if present
def maybe_del(d: dict, *keys):
    for key in keys:
        if key not in d:
            continue

        del d[key]
# mesh autoencoder trainer class

# add the WandB tracking context manager
@add_wandb_tracker_contextmanager()
class MeshAutoencoderTrainer(Module):
    # constructor
    @beartype
    def __init__(
        self,
        model: MeshAutoencoder,
        dataset: Dataset,
        num_train_steps: int,
        batch_size: int,
        grad_accum_every: int,
        val_dataset: Optional[Dataset] = None,
        val_every: int = 100,
        val_num_batches: int = 5,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.,
        max_grad_norm: Optional[float] = None,
        ema_kwargs: dict = dict(),
        scheduler: Optional[Type[_LRScheduler]] = None,
        scheduler_kwargs: dict = dict(),
        accelerator_kwargs: dict = dict(),
        optimizer_kwargs: dict = dict(),
        checkpoint_every = 1000,
        checkpoint_folder = './checkpoints',
        data_kwargs: Tuple[str, ...] = ['vertices', 'faces', 'face_edges'],
        warmup_steps = 1000,
        use_wandb_tracking = False
    ):
        # call the parent class constructor
        super().__init__()

        # experiment tracker

        self.use_wandb_tracking = use_wandb_tracking

        # if using wandb tracking, set the accelerator's logging backend to 'wandb'
        if use_wandb_tracking:
            accelerator_kwargs['log_with'] = 'wandb'

        # default the accelerator's 'kwargs_handlers' to the DDP kwargs defined above
        if 'kwargs_handlers' not in accelerator_kwargs:
            accelerator_kwargs['kwargs_handlers'] = [DEFAULT_DDP_KWARGS]

        # instantiate the accelerator
        self.accelerator = Accelerator(**accelerator_kwargs)

        # set the model
        self.model = model

        # on the main process, maintain an EMA copy of the model
        if self.is_main:
            self.ema_model = EMA(model, **ema_kwargs)

        # optimizer with warmup schedule
        self.optimizer = OptimizerWithWarmupSchedule(
            accelerator = self.accelerator,
            optimizer = get_adam_optimizer(model.parameters(), lr = learning_rate, wd = weight_decay, **optimizer_kwargs),
            scheduler = scheduler,
            scheduler_kwargs = scheduler_kwargs,
            warmup_steps = warmup_steps,
            max_grad_norm = max_grad_norm
        )

        # training dataloader
        self.dataloader = DataLoader(
            dataset,
            batch_size = batch_size,
            shuffle = True,
            drop_last = True,
            collate_fn = partial(custom_collate, pad_id = model.pad_id)
        )

        # whether validation is needed
        self.should_validate = exists(val_dataset)

        if self.should_validate:
            # make sure the validation dataset is not empty
            assert len(val_dataset) > 0, 'your validation dataset is empty'

            # validation frequency and number of validation batches
            self.val_every = val_every
            self.val_num_batches = val_num_batches

            # validation dataloader
            self.val_dataloader = DataLoader(
                val_dataset,
                batch_size = batch_size,
                shuffle = True,
                drop_last = True,
                collate_fn = partial(custom_collate, pad_id = model.pad_id)
            )

        # if the dataset carries its own 'data_kwargs' attribute, prefer those
        if hasattr(dataset, 'data_kwargs') and exists(dataset.data_kwargs):
            # data kwargs must be a list of strings
            assert is_bearable(dataset.data_kwargs, List[str])
            self.data_kwargs = dataset.data_kwargs
        else:
            self.data_kwargs = data_kwargs

        # let the accelerator prepare the model and dataloader
        (
            self.model,
            self.dataloader
        ) = self.accelerator.prepare(
            self.model,
            self.dataloader
        )

        # gradient accumulation steps, number of training steps, and a step counter buffer
        self.grad_accum_every = grad_accum_every
        self.num_train_steps = num_train_steps
        self.register_buffer('step', torch.tensor(0))

        # checkpointing frequency and folder
        self.checkpoint_every = checkpoint_every
        self.checkpoint_folder = Path(checkpoint_folder)
        self.checkpoint_folder.mkdir(exist_ok = True, parents = True)

    # access the EMA tokenizer
    @property
    def ema_tokenizer(self):
        return self.ema_model.ema_model

    # tokenize with the EMA model
    def tokenize(self, *args, **kwargs):
        return self.ema_tokenizer.tokenize(*args, **kwargs)

    # logging
    def log(self, **data_kwargs):
        self.accelerator.log(data_kwargs, step = self.step.item())

    # device
    @property
    def device(self):
        return self.unwrapped_model.device

    # whether this is the main process
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # the model without the accelerator wrapper
    @property
    def unwrapped_model(self):
        return self.accelerator.unwrap_model(self.model)

    # whether this is the local main process
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # synchronize all processes
    def wait(self):
        return self.accelerator.wait_for_everyone()

    # print through the accelerator
    def print(self, msg):
        return self.accelerator.print(msg)

    # save checkpoint
    def save(self, path, overwrite = True):
        path = Path(path)
        # either overwriting is allowed or the path must not already exist
        assert overwrite or not path.exists()

        # package up the model, EMA model, optimizer, and metadata
        pkg = dict(
            model = self.unwrapped_model.state_dict(),
            ema_model = self.ema_model.state_dict(),
            optimizer = self.optimizer.state_dict(),
            version = __version__,
            step = self.step.item(),
            config = self.unwrapped_model._config
        )

        torch.save(pkg, str(path))
    # load checkpoint
    def load(self, path):
        path = Path(path)
        # the path must exist
        assert path.exists()

        # load the saved package
        pkg = torch.load(str(path))

        # warn if the saved version does not match the current package version
        if version.parse(__version__) != version.parse(pkg['version']):
            self.print(f'loading saved mesh autoencoder at version {pkg["version"]}, but current package version is {__version__}')

        # restore model, EMA model, and optimizer state
        self.model.load_state_dict(pkg['model'])
        self.ema_model.load_state_dict(pkg['ema_model'])
        self.optimizer.load_state_dict(pkg['optimizer'])

        # restore the step counter
        self.step.copy_(pkg['step'])

    # fetch the next batch and shape it into keyword arguments for forward
    def next_data_to_forward_kwargs(self, dl_iter) -> dict:
        # get the next batch
        data = next(dl_iter)

        # tuples are zipped with the data kwarg names; dicts pass through
        if isinstance(data, tuple):
            forward_kwargs = dict(zip(self.data_kwargs, data))

        elif isinstance(data, dict):
            forward_kwargs = data

        # drop keys the autoencoder does not accept
        maybe_del(forward_kwargs, 'texts', 'text_embeds')
        return forward_kwargs
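
    # a hypothetical illustration (my own aside, not from the repository) of the
    # tuple -> kwargs mapping above:
    #   data_kwargs = ['vertices', 'faces', 'face_edges']
    #   batch       = (vertices_tensor, faces_tensor, face_edges_tensor)
    #   dict(zip(data_kwargs, batch))
    #     -> {'vertices': vertices_tensor, 'faces': faces_tensor, 'face_edges': face_edges_tensor}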

    # training loop
    def forward(self):
        # current step
        step = self.step.item()
        # endless iterator over the training dataloader
        dl_iter = cycle(self.dataloader)

        # on the main process, prepare the validation iterator if validating
        if self.is_main and self.should_validate:
            val_dl_iter = cycle(self.val_dataloader)

        while step < self.num_train_steps:

            # gradient accumulation
            for i in range(self.grad_accum_every):
                is_last = i == (self.grad_accum_every - 1)
                maybe_no_sync = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext

                # next batch as forward keyword arguments
                forward_kwargs = self.next_data_to_forward_kwargs(dl_iter)

                with self.accelerator.autocast(), maybe_no_sync():

                    # model forward pass
                    total_loss, (recon_loss, commit_loss) = self.model(
                        **forward_kwargs,
                        return_loss_breakdown = True
                    )

                    # backward pass, scaled by the number of accumulation steps
                    self.accelerator.backward(total_loss / self.grad_accum_every)

            # print the reconstruction and commitment losses
            self.print(f'recon loss: {recon_loss.item():.3f} | commit loss: {commit_loss.sum().item():.3f}')

            # log the losses
            self.log(
                total_loss = total_loss.item(),
                commit_loss = commit_loss.sum().item(),
                recon_loss = recon_loss.item()
            )

            # optimizer step
            self.optimizer.step()
            self.optimizer.zero_grad()

            # advance the step counter
            step += 1
            self.step.add_(1)

            self.wait()

            # on the main process, update the EMA model
            if self.is_main:
                self.ema_model.update()

            self.wait()

            # on the main process, validate every `val_every` steps
            if self.is_main and self.should_validate and divisible_by(step, self.val_every):

                total_val_recon_loss = 0.
                self.ema_model.eval()

                num_val_batches = self.val_num_batches * self.grad_accum_every

                # validation passes
                for _ in range(num_val_batches):
                    with self.accelerator.autocast(), torch.no_grad():

                        forward_kwargs = self.next_data_to_forward_kwargs(val_dl_iter)

                        val_loss, (val_recon_loss, val_commit_loss) = self.ema_model(
                            **forward_kwargs,
                            return_loss_breakdown = True
                        )

                        total_val_recon_loss += (val_recon_loss / num_val_batches)

                # print the validation reconstruction loss
                self.print(f'valid recon loss: {total_val_recon_loss:.3f}')

                # log the validation loss
                self.log(val_loss = total_val_recon_loss)

            self.wait()

            # on the main process, save a checkpoint every `checkpoint_every` steps
            if self.is_main and divisible_by(step, self.checkpoint_every):
                checkpoint_num = step // self.checkpoint_every
                self.save(self.checkpoint_folder / f'mesh-autoencoder.ckpt.{checkpoint_num}.pt')

            self.wait()

        # training done
        self.print('training complete')
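
A minimal usage sketch for this trainer (my own illustration, not code from the repository; `autoencoder` is assumed to be a MeshAutoencoder instance as in the README further below, and `mesh_dataset` a hypothetical torch Dataset yielding (vertices, faces, face_edges) tuples or equivalent dicts):

# hypothetical example, under the assumptions stated above
trainer = MeshAutoencoderTrainer(
    model = autoencoder,
    dataset = mesh_dataset,
    num_train_steps = 10000,
    batch_size = 4,
    grad_accum_every = 4,
    checkpoint_every = 1000
)

trainer()  # calling the trainer dispatches to forward(), which runs the training loop above
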
# mesh transformer trainer

# add the WandB tracking context manager
@add_wandb_tracker_contextmanager()
class MeshTransformerTrainer(Module):
    # constructor
    @beartype
    def __init__(
        self,
        model: MeshTransformer,
        dataset: Dataset,
        num_train_steps: int,
        batch_size: int,
        grad_accum_every: int,
        learning_rate: float = 2e-4,
        weight_decay: float = 0.,
        max_grad_norm: Optional[float] = 0.5,
        val_dataset: Optional[Dataset] = None,
        val_every = 1,
        val_num_batches = 5,
        scheduler: Optional[Type[_LRScheduler]] = None,
        scheduler_kwargs: dict = dict(),
        ema_kwargs: dict = dict(),
        accelerator_kwargs: dict = dict(),
        optimizer_kwargs: dict = dict(),
        checkpoint_every = 1000,
        checkpoint_folder = './checkpoints',
        data_kwargs: Tuple[str, ...] = ['vertices', 'faces', 'face_edges', 'texts'],
        warmup_steps = 1000,
        use_wandb_tracking = False
    ):
        super().__init__()

        # experiment tracker

        # whether to use WandB tracking
        self.use_wandb_tracking = use_wandb_tracking

        # if using wandb tracking, set the accelerator's logging backend to 'wandb'
        if use_wandb_tracking:
            accelerator_kwargs['log_with'] = 'wandb'

        # default the accelerator's 'kwargs_handlers' to the DDP kwargs defined above
        if 'kwargs_handlers' not in accelerator_kwargs:
            accelerator_kwargs['kwargs_handlers'] = [DEFAULT_DDP_KWARGS]

        # instantiate the accelerator
        self.accelerator = Accelerator(**accelerator_kwargs)

        # set the model
        self.model = model

        # Adam optimizer, filtered to parameters that require gradients
        optimizer = get_adam_optimizer(
            model.parameters(),
            lr = learning_rate,
            wd = weight_decay,
            filter_by_requires_grad = True,
            **optimizer_kwargs
        )

        # optimizer with warmup schedule
        self.optimizer = OptimizerWithWarmupSchedule(
            accelerator = self.accelerator,
            optimizer = optimizer,
            scheduler = scheduler,
            scheduler_kwargs = scheduler_kwargs,
            warmup_steps = warmup_steps,
            max_grad_norm = max_grad_norm
        )

        # training dataloader
        self.dataloader = DataLoader(
            dataset,
            batch_size = batch_size,
            shuffle = True,
            drop_last = True,
            collate_fn = partial(custom_collate, pad_id = model.pad_id)
        )

        # whether validation is needed
        self.should_validate = exists(val_dataset)

        if self.should_validate:
            assert len(val_dataset) > 0, 'your validation dataset is empty'

            self.val_every = val_every
            self.val_num_batches = val_num_batches

            # validation dataloader
            self.val_dataloader = DataLoader(
                val_dataset,
                batch_size = batch_size,
                shuffle = True,
                drop_last = True,
                collate_fn = partial(custom_collate, pad_id = model.pad_id)
            )

        # if the dataset carries its own 'data_kwargs' attribute, prefer those
        if hasattr(dataset, 'data_kwargs') and exists(dataset.data_kwargs):
            assert is_bearable(dataset.data_kwargs, List[str])
            self.data_kwargs = dataset.data_kwargs
        else:
            self.data_kwargs = data_kwargs

        # let the accelerator prepare the model and dataloader
        (
            self.model,
            self.dataloader
        ) = self.accelerator.prepare(
            self.model,
            self.dataloader
        )

        # gradient accumulation steps, number of training steps, and a step counter buffer
        self.grad_accum_every = grad_accum_every
        self.num_train_steps = num_train_steps
        self.register_buffer('step', torch.tensor(0))

        # checkpointing frequency and folder
        self.checkpoint_every = checkpoint_every
        self.checkpoint_folder = Path(checkpoint_folder)
        self.checkpoint_folder.mkdir(exist_ok = True, parents = True)

    # logging
    def log(self, **data_kwargs):
        self.accelerator.log(data_kwargs, step = self.step.item())

    # device
    @property
    def device(self):
        return self.unwrapped_model.device

    # whether this is the main process
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # the model without the accelerator wrapper
    @property
    def unwrapped_model(self):
        return self.accelerator.unwrap_model(self.model)

    # whether this is the local main process
    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    # synchronize all processes
    def wait(self):
        return self.accelerator.wait_for_everyone()

    # print through the accelerator
    def print(self, msg):
        return self.accelerator.print(msg)

    # fetch the next batch and shape it into keyword arguments for forward
    def next_data_to_forward_kwargs(self, dl_iter) -> dict:
        # get the next batch
        data = next(dl_iter)

        # tuples are zipped with the data kwarg names
        if isinstance(data, tuple):
            forward_kwargs = dict(zip(self.data_kwargs, data))

        # dicts pass through as-is
        elif isinstance(data, dict):
            forward_kwargs = data

        return forward_kwargs

    # save the model and optimizer state to the given path
    def save(self, path, overwrite = True):
        path = Path(path)
        assert overwrite or not path.exists()

        # package up the model, optimizer, and metadata
        pkg = dict(
            model = self.unwrapped_model.state_dict(),
            optimizer = self.optimizer.state_dict(),
            step = self.step.item(),
            version = __version__
        )

        # save the package to the given path
        torch.save(pkg, str(path))

    # load the model and optimizer state from the given path
    def load(self, path):
        path = Path(path)
        assert path.exists()

        # load the saved package
        pkg = torch.load(str(path))

        # warn if the saved version does not match the current package version
        if version.parse(__version__) != version.parse(pkg['version']):
            self.print(f'loading saved mesh transformer at version {pkg["version"]}, but current package version is {__version__}')

        # restore model and optimizer state, plus the step counter
        self.model.load_state_dict(pkg['model'])
        self.optimizer.load_state_dict(pkg['optimizer'])
        self.step.copy_(pkg['step'])

    # training loop
    def forward(self):
        step = self.step.item()
        dl_iter = cycle(self.dataloader)

        # prepare the validation iterator if validating
        if self.should_validate:
            val_dl_iter = cycle(self.val_dataloader)

        while step < self.num_train_steps:

            # gradient accumulation
            for i in range(self.grad_accum_every):
                is_last = i == (self.grad_accum_every - 1)
                maybe_no_sync = partial(self.accelerator.no_sync, self.model) if not is_last else nullcontext

                # next batch as forward keyword arguments
                forward_kwargs = self.next_data_to_forward_kwargs(dl_iter)

                # forward pass under automatic mixed precision
                with self.accelerator.autocast(), maybe_no_sync():
                    loss = self.model(**forward_kwargs)

                    # backward pass, scaled by the number of accumulation steps
                    self.accelerator.backward(loss / self.grad_accum_every)

            self.print(f'loss: {loss.item():.3f}')

            # log the loss
            self.log(loss = loss.item())

            # optimizer step
            self.optimizer.step()
            self.optimizer.zero_grad()

            step += 1
            self.step.add_(1)

            self.wait()

            # on the main process, validate every `val_every` steps
            if self.is_main and self.should_validate and divisible_by(step, self.val_every):

                total_val_loss = 0.
                self.unwrapped_model.eval()

                num_val_batches = self.val_num_batches * self.grad_accum_every

                # accumulate the validation loss
                for _ in range(num_val_batches):
                    with self.accelerator.autocast(), torch.no_grad():

                        forward_kwargs = self.next_data_to_forward_kwargs(val_dl_iter)

                        val_loss = self.unwrapped_model(**forward_kwargs)

                        total_val_loss += (val_loss / num_val_batches)

                self.print(f'valid recon loss: {total_val_loss:.3f}')

                # log the validation loss
                self.log(val_loss = total_val_loss)

            self.wait()

            # on the main process, save a checkpoint every `checkpoint_every` steps
            if self.is_main and divisible_by(step, self.checkpoint_every):
                checkpoint_num = step // self.checkpoint_every
                self.save(self.checkpoint_folder / f'mesh-transformer.ckpt.{checkpoint_num}.pt')

            self.wait()

        self.print('training complete')
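
Analogously, a hedged sketch of driving this trainer (again my own illustration, not from the repository; `transformer` is assumed to be a MeshTransformer wrapping a trained autoencoder, and `mesh_dataset` the same hypothetical Dataset as before):

# hypothetical example, under the assumptions stated above
transformer_trainer = MeshTransformerTrainer(
    model = transformer,
    dataset = mesh_dataset,
    num_train_steps = 100000,
    batch_size = 4,
    grad_accum_every = 4
)

transformer_trainer()  # runs the training loop above, validating and checkpointing along the way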

.\lucidrains\meshgpt-pytorch\meshgpt_pytorch\version.py

# define the current version of the package as '1.1.1'
__version__ = '1.1.1'

.\lucidrains\meshgpt-pytorch\meshgpt_pytorch\__init__.py

# import the MeshAutoencoder and MeshTransformer classes from the meshgpt_pytorch module
from meshgpt_pytorch.meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
)

# import the MeshAutoencoderTrainer and MeshTransformerTrainer classes from the trainer module
from meshgpt_pytorch.trainer import (
    MeshAutoencoderTrainer,
    MeshTransformerTrainer
)

# import the DatasetFromTransforms class and the dataset caching helpers from the data module
from meshgpt_pytorch.data import (
    DatasetFromTransforms,
    cache_text_embeds_for_dataset,
    cache_face_edges_for_dataset
)

MeshGPT - Pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch

Will also add text conditioning, for eventual text-to-3d asset

Please join us on Discord if you are interested in collaborating with others to replicate this work

Appreciation

  • StabilityAI, A16Z Open Source AI Grant Program, and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research

  • Einops for making my life easy

  • Marcus for the initial code review (pointing out some missing derived features) as well as running the first successful end-to-end experiments

  • Marcus for the first successful training of a collection of shapes conditioned on labels

  • Quexi Ma for finding numerous bugs with automatic eos handling

  • Yingtian for finding a bug with the gaussian blurring of the positions for spatial label smoothing

  • Marcus yet again for running the experiments to validate that it is possible to extend the system from triangles to quads

Install

$ pip install meshgpt-pytorch

Usage

import torch

from meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
)

# autoencoder

autoencoder = MeshAutoencoder(
    num_discrete_coors = 128
)

# mock inputs

vertices = torch.randn((2, 121, 3))            # (batch, num vertices, coor (3))
faces = torch.randint(0, 121, (2, 64, 3))      # (batch, num faces, vertices (3))

# make sure faces are padded with `-1` for variable-length meshes

# forward in the faces

loss = autoencoder(
    vertices = vertices,
    faces = faces
)

loss.backward()

# after much training...
# you can pass in the raw face data above to train a transformer to model this sequence of face vertices

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 768
)

loss = transformer(
    vertices = vertices,
    faces = faces
)

loss.backward()

# after much training of transformer, you can now sample novel 3d assets

faces_coordinates, face_mask = transformer.generate()

# (batch, num faces, vertices (3), coordinates (3)), (batch, num faces)
# now post process for the generated 3d asset

For text-conditioned 3d shape synthesis, simply set condition_on_text = True on your MeshTransformer, and then pass in your list of descriptions as the texts keyword argument

ex.

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 768,
    condition_on_text = True
)


loss = transformer(
    vertices = vertices,
    faces = faces,
    texts = ['a high chair', 'a small teapot'],
)

loss.backward()

# after much training of transformer, you can now sample novel 3d assets conditioned on text

faces_coordinates, face_mask = transformer.generate(texts = ['a long table'])

If you want to tokenize meshes, for use in your multimodal transformer, simply invoke .tokenize on your autoencoder (or same method on autoencoder trainer instance for the exponentially smoothed model)


mesh_token_ids = autoencoder.tokenize(
    vertices = vertices,
    faces = faces
)

# (batch, num face vertices, residual quantized layer)

Todo

Citations

@inproceedings{Siddiqui2023MeshGPTGT,
    title   = {MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers},
    author  = {Yawar Siddiqui and Antonio Alliegro and Alexey Artemov and Tatiana Tommasi and Daniele Sirigatti and Vladislav Rosov and Angela Dai and Matthias Nie{\ss}ner},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:265457242}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@inproceedings{Leviathan2022FastIF,
    title   = {Fast Inference from Transformers via Speculative Decoding},
    author  = {Yaniv Leviathan and Matan Kalman and Y. Matias},
    booktitle = {International Conference on Machine Learning},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:254096365}
}
@misc{yu2023language,
    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation}, 
    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
    year    = {2023},
    eprint  = {2310.05737},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Lee2022AutoregressiveIG,
    title   = {Autoregressive Image Generation using Residual Quantization},
    author  = {Doyup Lee and Chiheon Kim and Saehoon Kim and Minsu Cho and Wook-Shin Han},
    journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    year    = {2022},
    pages   = {11513-11522},
    url     = {https://api.semanticscholar.org/CorpusID:247244535}
}
@inproceedings{Katsch2023GateLoopFD,
    title   = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
    author  = {Tobias Katsch},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:265018962}
}

.\lucidrains\meshgpt-pytorch\setup.py

# import setup and package-finding utilities
from setuptools import setup, find_packages

# execute the version file, bringing the version info into the current namespace
exec(open('meshgpt_pytorch/version.py').read())

# package metadata
setup(
  name = 'meshgpt-pytorch',  # package name
  packages = find_packages(exclude=[]),  # find packages
  version = __version__,  # version number
  license='MIT',  # license
  description = 'MeshGPT Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/meshgpt-pytorch',  # URL
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'attention mechanisms',
    'transformers',
    'mesh generation'
  ],
  install_requires=[  # install dependencies
    'accelerate>=0.25.0',
    'beartype',
    'classifier-free-guidance-pytorch>=0.5.1',
    'einops>=0.7.0',
    'einx[torch]>=0.1.3',
    'ema-pytorch',
    'local-attention>=1.9.0',
    'gateloop-transformer>=0.2.2',
    'numpy',
    'pytorch-custom-utils>=0.0.9',
    'taylor-series-linear-attention>=0.1.6',
    'torch>=2.1',
    'torch_geometric',
    'torchtyping',
    'tqdm',
    'vector-quantize-pytorch>=1.12.8',
    'x-transformers>=1.26.0',
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\metaformer-gpt\metaformer_gpt\autoregressive_wrapper.py

# import the torch library
import torch
# import torch's functional library
import torch.nn.functional as F
# import the rearrange function from einops
from einops import rearrange
# import the nn module from torch
from torch import nn

# helper function, checking whether a value exists
def exists(val):
    return val is not None

# decorator that puts the model in eval mode for the wrapped call, then restores training state
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# top-k filtering of logits, keeping the top (1 - thres) fraction of the vocabulary
def top_k(logits, thres=0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(1, ind, val)
    return probs
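
# a small worked example of the filter above (my own aside, not from the repository):
# with a vocabulary of 5 and thres = 0.8, k = int((1 - 0.8) * 5) = 1, so only
# the single highest logit survives
#
#   logits   = torch.tensor([[1.0, 3.0, 2.0, 0.5, -1.0]])
#   filtered = top_k(logits, thres = 0.8)
#   # filtered -> tensor([[-inf, 3.0, -inf, -inf, -inf]])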

# autoregressive wrapper class
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, max_seq_len=2048, pad_value=0):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.pad_value = pad_value
        self.net = net

    # sequence generation, wrapped in torch.no_grad() and the eval decorator
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        start_tokens,
        seq_len,
        eos_token=None,
        temperature=1.0,
        filter_thres=0.9,
        **kwargs
    ):
        b, t, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            logits = self.net(out, **kwargs)[:, -1, :]

            filtered_logits = top_k(logits, thres=filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_tokens = out == eos_token

                if is_eos_tokens.any(dim=-1).all():
                    # mask out everything after the eos tokens
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        out = out[:, t:]
        return out

    # forward pass, computing the next-token cross entropy loss
    def forward(self, x, **kwargs):
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        logits = self.net(x_inp, **kwargs)
        # net outputs (batch, seq, classes); cross_entropy wants (batch, classes, seq)
        return F.cross_entropy(rearrange(logits, "b n c -> b c n"), x_labels)

.\lucidrains\metaformer-gpt\metaformer_gpt\metaformer_gpt.py

import torch
from torch import nn, einsum
from einops import rearrange, repeat

from scipy.fftpack import next_fast_len

# helper functions

def cummean(x, *, dim):
    # cumulative mean along a dimension
    # note: the denominator assumes the sequence sits at dim 1, which is how it is called below
    numer = x.cumsum(dim = dim)
    denom = torch.arange(x.shape[1], device = x.device) + 1
    return numer / rearrange(denom, '... -> ... 1')

def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
    # O(N log N) 1d convolution using the fourier trick

    N = x.shape[dim]
    M = weights.shape[weight_dim]

    fast_len = next_fast_len(N + M - 1)

    # fourier transforms of the input signal and the weights
    f_x = torch.fft.rfft(x, n = fast_len, dim = dim)
    f_weight = torch.fft.rfft(weights, n = fast_len, dim = weight_dim)

    # pointwise product in the frequency domain (conjugating the weights gives cross-correlation)
    f_v_weight = f_x * rearrange(f_weight.conj(), '... -> ... 1')
    out = torch.fft.irfft(f_v_weight, fast_len, dim = dim)
    out = out.roll(-1, dims = (dim,))

    # select the causal slice of the output
    indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
    out = out.index_select(dim, indices)
    return out
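
# as a sanity check (my own aside, not part of this file): with the exponential
# weights used by MultiheadExponentialTimeDecay below, conv1d_fft reduces to a
# causal exponential moving average, out[t] = sum_{s <= t} a * (1 - a)**(t - s) * x[s],
# which can be verified against a brute-force loop:
#
#   n, a = 16, 0.3
#   x = torch.randn(1, 1, n, 4)                  # (batch, heads, seq, dim)
#   alpha = torch.full((1,), a)
#   weights = alpha * (1 - alpha) ** torch.flip(torch.arange(n), dims = (0,))
#
#   fft_out = conv1d_fft(x, weights)
#
#   brute = torch.stack([
#       sum(a * (1 - a) ** (t - s) * x[:, :, s] for s in range(t + 1))
#       for t in range(n)
#   ], dim = 2)
#
#   assert torch.allclose(fft_out, brute, atol = 1e-4)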

# classes

class MeanCenteringPool(nn.Module):
    def __init__(
        self,
        dim
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)
        x = cummean(x, dim = 1) - x
        return self.proj(x)

class MultiheadExponentialTimeDecay(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads = 8,
        dim_head = 64
    ):
        super().__init__()
        self.heads = heads
        inner_dim = heads * dim_head

        self.norm = nn.LayerNorm(dim)
        self.alpha = nn.Parameter(torch.randn(heads))

        self.project_in = nn.Linear(dim, inner_dim, bias = False)
        self.project_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        b, n, d, h, device = *x.shape, self.heads, x.device

        x = self.norm(x)

        # linear projection in

        x = self.project_in(x)

        # split heads

        x = rearrange(x, 'b n (h d) -> b h n d', h = h)

        # prepare the exponential smoothing alphas

        alpha = self.alpha.sigmoid()
        alpha = rearrange(alpha, 'h -> h 1')

        # compute the decay weights and convolve

        arange = torch.arange(n, device = device)
        weights = alpha * (1 - alpha) ** torch.flip(arange, dims = (0,))
        output = conv1d_fft(x, weights)

        # merge heads

        output = rearrange(output, 'b h n d -> b n (h d)')
        return self.project_out(output)

def FeedForward(dim, mult = 4):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias = False)
    )

class MetaformerGPT(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        heads = 16,
        dim_head = 32,
        max_seq_len = 2048,
        ff_mult = 4
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                MultiheadExponentialTimeDecay(dim, heads = heads, dim_head = dim_head),
                MeanCenteringPool(dim),
                FeedForward(dim, mult = ff_mult)
            ]))

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens, bias = False)
        )

    def forward(self, x):
        n, device = x.shape[1], x.device

        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(n, device = device))

        for mh_esa, pool, ff in self.layers:
            x = mh_esa(x) + x
            x = pool(x) + x
            x = ff(x) + x

        return self.to_logits(x)

.\lucidrains\metaformer-gpt\metaformer_gpt\__init__.py

# import the MetaformerGPT and MultiheadExponentialTimeDecay classes from the metaformer_gpt package
from metaformer_gpt.metaformer_gpt import MetaformerGPT, MultiheadExponentialTimeDecay

Metaformer - GPT (wip)

Implementation of Metaformer, but in an autoregressive manner. In particular, they propose simply using mean centering as a way to do token mixing in a parameter-less fashion, alternating with feedforwards.

Install

$ pip install metaformer-gpt

Usage

import torch
from metaformer_gpt import MetaformerGPT

gpt = MetaformerGPT(
    num_tokens = 256,
    dim = 512,
    depth = 8
)

ids = torch.randint(0, 256, (1, 1024))
logits = gpt(ids) # (1, 1024, 256)
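
The repository also ships an AutoregressiveWrapper (its source is shown earlier in this document); a hedged sketch of training and sampling with it follows — this usage is my own, not from the README, and note the wrapper is imported from its module, since the package __init__ does not re-export it:

from metaformer_gpt.autoregressive_wrapper import AutoregressiveWrapper

wrapped = AutoregressiveWrapper(gpt)

loss = wrapped(ids)  # next-token cross entropy on shifted targets
loss.backward()

sampled = wrapped.generate(ids[:, :32], seq_len = 64)  # (1, 64) continuation of the 32-token prime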

Citations

@article{Yu2021MetaFormerIA,
    title   = {MetaFormer is Actually What You Need for Vision},
    author  = {Weihao Yu and Mi Luo and Pan Zhou and Chenyang Si and Yichen Zhou and Xinchao Wang and Jiashi Feng and Shuicheng Yan},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2111.11418}
}
@misc{woo2022etsformer,
    title   = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting},
    author  = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
    year    = {2022},
    eprint  = {2202.01381},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}