Lucidrains series source code analysis (Part 35)

.\lucidrains\q-transformer\q_transformer\q_robotic_transformer.py

# `random` for epsilon-greedy sampling, `cache` for memoizing the distributed check
from random import random
from functools import partial, cache

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.cuda.amp import autocast
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList

# beartype for runtime type checking
from beartype import beartype
from beartype.typing import Union, List, Optional, Callable, Tuple, Dict, Any

# einops for tensor rearrangement
from einops import pack, unpack, repeat, reduce, rearrange
from einops.layers.torch import Rearrange, Reduce

# attention implementation (supports flash attention)
from q_transformer.attend import Attend

# text conditioning modules and the classifier-free-guidance decorator
from classifier_free_guidance_pytorch import (
    TextConditioner,
    AttentionTextConditioner,
    NullConditioner,
    classifier_free_guidance
)

# helpers

# returns True if a value is not None
def exists(val):
    return val is not None

# logical xnor - True when both or neither are truthy
def xnor(x, y):
    """ (True, True) or (False, False) -> True """
    return not (x ^ y)

# whether num is evenly divisible by den
def divisible_by(num, den):
    return (num % den) == 0

# fall back to the default d when val is None
def default(val, d):
    return val if exists(val) else d

# cast a value to a tuple of the given length
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# tensor helpers

# l2 normalize a tensor along a dimension
def l2norm(t, dim = -1):
    return F.normalize(t, dim = dim)

# pack a single tensor according to an einops pattern
def pack_one(x, pattern):
    return pack([x], pattern)

# inverse of pack_one
def unpack_one(x, ps, pattern):
    return unpack(x, ps, pattern)[0]

# 2d rotary positional embedding
# https://arxiv.org/abs/2104.09864

# axial rotary embedding over the 2d positions within an attention window
class RotaryEmbedding(Module):
    def __init__(self, dim, omega = 10000):
        super().__init__()
        inv_freq = 1.0 / (omega ** (torch.arange(0, dim, 4).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    @autocast(enabled = False)
    def forward(self, height_width):
        device, dtype = self.inv_freq.device, self.inv_freq.dtype

        axial_pos = torch.arange(height_width, device = device).type(dtype)

        freqs = torch.einsum('i, j -> i j', axial_pos, self.inv_freq)
        freqs = repeat(freqs, '... f -> ... (f c)', c = 2)

        freqs = torch.broadcast_tensors(freqs[None, :, :], freqs[:, None, :])
        freqs = torch.cat(freqs, dim = -1)
        return rearrange(freqs, '... f -> (...) f')

# rotate half the feature dimensions, for the rotary embedding
def rotate_half(x):
    x1, x2 = rearrange(x, '... (d c) -> ... d c', c = 2).unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d c -> ... (d c)')

@autocast(enabled = False)
# apply the rotary positional embedding to a tensor
def apply_rotary_pos_emb(pos, t):
    return t * pos.cos() + rotate_half(t) * pos.sin()
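
# a minimal usage sketch of the rotary embedding above, assuming the default
# 7x7 attention window and a head dimension of 32 (must be divisible by 4):

rotary = RotaryEmbedding(dim = 32)
pos = rotary(7)                            # (49, 32) - one frequency row per window position
q = torch.randn(1, 8, 49, 32)              # (batch, heads, window positions, dim_head)
q_rotated = apply_rotary_pos_emb(pos, q)   # same shape, with the 2d position encoded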

# sync batchnorm

# cache the result, so the distributed check only runs once
@cache
def get_is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

# returns SyncBatchNorm when training is distributed, otherwise BatchNorm2d
def MaybeSyncBatchnorm2d(is_distributed = None):
    is_distributed = default(is_distributed, get_is_distributed())
    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm2d

# channel rmsnorm

# RMSNorm over the last (feature) dimension
class RMSNorm(Module):
    def __init__(self, dim, affine = True):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim)) if affine else 1.

    def forward(self, x):
        return l2norm(x) * self.gamma * self.scale

# RMSNorm over the channel dimension, for convolutional feature maps
class ChanRMSNorm(Module):
    def __init__(self, dim, affine = True):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim, 1, 1)) if affine else 1.

    def forward(self, x):
        return l2norm(x, dim = 1) * self.gamma * self.scale

# sinusoidal positions

# 1d sinusoidal positional embedding
def posemb_sincos_1d(seq, dim, temperature = 10000, device = None, dtype = torch.float32):
    n = torch.arange(seq, device = device)
    omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
    omega = 1. / (temperature ** omega)

    n = n[:, None] * omega[None, :]
    pos_emb = torch.cat((n.sin(), n.cos()), dim = 1)
    return pos_emb.type(dtype)
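
# a quick usage sketch - one embedding row per frame, which is later added to
# each frame's learned tokens (see QRoboticTransformer.encode_state below):

pos = posemb_sincos_1d(6, 512)   # (6, 512) - sines in the first half, cosines in the second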

# helper classes

# residual wrapper
class Residual(Module):
    @beartype
    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# feedforward, with optional adaptive layernorm conditioning
class FeedForward(Module):
    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.,
        adaptive_ln = False
    ):
        super().__init__()
        self.adaptive_ln = adaptive_ln

        inner_dim = int(dim * mult)

        # the rmsnorm is non-affine when adaptive layernorm conditioning is used
        self.norm = RMSNorm(dim, affine = not adaptive_ln)

        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        cond_fn: Optional[Callable] = None
    ):
        # pre-norm
        x = self.norm(x)

        # adaptive layernorm must be paired with a conditioning function, and vice versa
        assert xnor(self.adaptive_ln, exists(cond_fn))

        if exists(cond_fn):
            # the conditioning function acts as the adaptive layernorm scale (and shift)
            x = cond_fn(x)

        return self.net(x)

# squeeze excitation - channel attention, as in SENet
class SqueezeExcitation(Module):
    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        # global average pool -> bottleneck mlp -> per-channel sigmoid gate
        self.gate = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim, hidden_dim, bias = False),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim, bias = False),
            nn.Sigmoid(),
            Rearrange('b c -> b c 1 1')
        )

    def forward(self, x):
        return x * self.gate(x)  # reweight the channels by the learned gates
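
# a minimal usage sketch - the gate rescales each channel by a value in (0, 1)
# derived from the globally pooled features:

se = SqueezeExcitation(dim = 64)
feats = torch.randn(2, 64, 14, 14)
gated = se(feats)   # (2, 64, 14, 14) - same shape, channels reweighted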

# residual wrapper with stochastic depth (drop sample)
class MBConvResidual(Module):
    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        out = self.fn(x)
        out = self.dropsample(out)
        return out + x

# stochastic depth - randomly drops whole samples from the residual branch during training
class Dropsample(Module):
    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        batch, device = x.shape[0], x.device

        if self.prob == 0. or (not self.training):
            return x

        # one uniform draw per sample - torch.empty is used here, as torch.FloatTensor((batch, 1, 1, 1))
        # would build a tensor holding those values rather than one of that shape
        keep_mask = torch.empty((batch, 1, 1, 1), device = device).uniform_() > self.prob
        return x * keep_mask / (1 - self.prob)  # rescale so the expected value is unchanged

# MBConv - inverted residual block, as in MobileNetV2 / EfficientNet
def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.,
    is_distributed = None,
    use_layernorm = True
):
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    if use_layernorm:
        norm_klass = ChanRMSNorm
    else:
        norm_klass = MaybeSyncBatchnorm2d(is_distributed)

    # pointwise expand, depthwise conv, squeeze-excite, pointwise project
    net = nn.Sequential(
        nn.Conv2d(dim_in, hidden_dim, 1),
        norm_klass(hidden_dim),
        nn.GELU(),
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
        norm_klass(hidden_dim),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        norm_klass(dim_out)
    )

    if dim_in == dim_out and not downsample:
        net = MBConvResidual(net, dropout = dropout)  # residual connection, with stochastic depth

    return net
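
# a minimal usage sketch - a downsampling MBConv halves the spatial resolution
# and may change the channel count, in which case no residual path is added:

block = MBConv(64, 128, downsample = True)
x = torch.randn(2, 64, 56, 56)
out = block(x)   # (2, 128, 28, 28)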

# window attention for the maxvit blocks, with gated values and rotary embedding
class Attention(Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 32,
        dropout = 0.,
        window_size = 7,
        num_mem_kv = 4,
        flash = True
    ):
        super().__init__()
        dim_inner = dim_head * heads

        self.norm = RMSNorm(dim)
        self.heads = heads

        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)  # project to queries, keys, values

        self.to_v_gates = nn.Sequential(
            nn.Linear(dim, self.heads),
            nn.Sigmoid(),
            Rearrange('b n h -> b h n 1')
        )

        self.attend = Attend(
            causal = False,
            dropout = dropout,
            flash = flash
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        rotary_emb = None
    ):
        # unpack the shape and device info
        batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads

        # pre-norm
        x = self.norm(x)

        # flatten the windows into the batch, and the window positions into the sequence
        x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')

        # project to queries, keys, values
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # per-head gates for the values
        g = self.to_v_gates(x)

        # split out the heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # apply the rotary positional embedding to queries and keys, if given
        if exists(rotary_emb):
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        # attention
        out = self.attend(q, k, v)

        # gate the aggregated values per head, allowing heads to attend to nothing
        out = out * g

        # merge the heads and restore the window dimensions
        out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)

        out = self.to_out(out)
        return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)

# MaxViT - alternating mbconv, block (windowed) attention, and grid attention
class MaxViT(Module):
    @beartype
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth: Tuple[int, ...],  # number of blocks per stage
        heads = 8,
        dim_head = 64,
        dim_conv_stem = None,
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        use_layernorm = True,
        dropout = 0.1,
        channels = 3,
        flash_attn = True
    ):
        super().__init__()

        # convolutional stem
        dim_conv_stem = default(dim_conv_stem, dim)
        self.conv_stem = nn.Sequential(
            nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
            nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
        )

        # stage dimensions
        num_stages = len(depth)
        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (dim_conv_stem, *dims)
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))
        self.layers = ModuleList([])

        # window size for the efficient block-grid like attention
        self.window_size = window_size
        w = window_size

        # rotary embedding, precomputed once for the window size
        assert divisible_by(dim_head, 4), f'{dim_head} must be divisible by 4 for axial rotary embedding for maxvit'
        self.axial_rotary_emb = RotaryEmbedding(dim_head)
        self.register_buffer('cached_rotary_emb', self.axial_rotary_emb(window_size), persistent = False)

        # iterate through the stages
        cond_hidden_dims = []
        
        for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
            for stage_ind in range(layer_depth):
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim
                cond_hidden_dims.append(stage_dim_in)
                
                # one block: mbconv, then block-like attention, then grid-like attention
                block = nn.ModuleList([
                    MBConv(
                        stage_dim_in,
                        layer_dim,
                        downsample = is_first,
                        expansion_rate = mbconv_expansion_rate,
                        shrinkage_rate = mbconv_shrinkage_rate,
                        use_layernorm = use_layernorm
                    ),
                    Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),  # block-like attention
                    Residual(Attention(dim = layer_dim, heads = heads, dim_head = dim_head, dropout = dropout, window_size = w, flash = flash_attn)),
                    Residual(FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
                    
                    Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),  # grid-like attention
                    Residual(Attention(dim = layer_dim, heads = heads, dim_head = dim_head, dropout = dropout, window_size = w, flash = flash_attn)),
                    Residual(FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                ])
                
                self.layers.append(block)
        
        embed_dim = dims[-1]
        self.embed_dim = dims[-1]
        self.cond_hidden_dims = cond_hidden_dims
        
        # mlp head for the classification output
        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            RMSNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
    
    @beartype
    def forward(
        self,
        img,
        texts: Optional[List[str]] = None,
        cond_fns: Optional[Tuple[Callable, ...]] = None,
        cond_drop_prob = 0.,
        return_embeddings = False
    ):
        # the height and width of the image must be divisible by the window size
        assert all([divisible_by(d, self.window_size) for d in img.shape[-2:]])

        # convolutional stem
        x = self.conv_stem(img)

        # fetch the cached rotary embedding
        rotary_emb = self.cached_rotary_emb

        # turn the conditioning functions into an iterator
        cond_fns = iter(default(cond_fns, []))

        # go through each layer
        for (
            mb_conv,
            rearr_windowed_in,
            windowed_attn,
            windowed_ff,
            rearr_windowed_out,
            rearr_grid_in,
            grid_attn,
            grid_ff,
            rearr_grid_out
        ) in self.layers:
            # next conditioning function, if any
            cond_fn = next(cond_fns, None)

            # condition the feature map
            if exists(cond_fn):
                x = cond_fn(x)

            # mbconv, then windowed (block) attention and feedforward,
            # then grid attention and feedforward
            x = mb_conv(x)
            x = rearr_windowed_in(x)
            x = windowed_attn(x, rotary_emb = rotary_emb)
            x = windowed_ff(x)
            x = rearr_windowed_out(x)

            x = rearr_grid_in(x)
            x = grid_attn(x, rotary_emb = rotary_emb)
            x = grid_ff(x)
            x = rearr_grid_out(x)

        # return the embeddings, if requested
        if return_embeddings:
            return x

        # otherwise pool and project to class logits
        return self.mlp_head(x)
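
# a minimal usage sketch for the backbone above, assuming 224x224 rgb input so
# that every stage's feature map stays divisible by the 7x7 window:

vit = MaxViT(num_classes = 1000, dim = 64, depth = (2, 2), window_size = 7)
img = torch.randn(1, 3, 224, 224)
logits = vit(img)                           # (1, 1000)
feats = vit(img, return_embeddings = True)  # (1, 128, 28, 28) - final stage feature map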

# attention used by the transformer over learned tokens, with memory kv and a kv cache

class TransformerAttention(Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        dim_context = None,
        heads = 8,
        num_mem_kv = 4,
        norm_context = False,
        adaptive_ln = False,
        dropout = 0.1,
        flash = True,
        causal = False
    ):
        super().__init__()
        self.heads = heads
        inner_dim = dim_head * heads

        dim_context = default(dim_context, dim)

        self.adaptive_ln = adaptive_ln
        self.norm = RMSNorm(dim, affine = not adaptive_ln)

        self.context_norm = RMSNorm(dim_context) if norm_context else None

        self.attn_dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)

        self.num_mem_kv = num_mem_kv
        self.mem_kv = None
        if num_mem_kv > 0:
            self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))

        self.attend = Attend(
            dropout = dropout,
            flash = flash,
            causal = causal
        )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    # forward, with an optional kv cache for autoregressive decoding
    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_mask = None,
        cond_fn: Optional[Callable] = None,
        cache: Optional[Tensor] = None,
        return_cache = False
    ):
        b = x.shape[0]

        assert xnor(exists(context), exists(self.context_norm))

        if exists(context):
            context = self.context_norm(context)

        kv_input = default(context, x)

        x = self.norm(x)

        assert xnor(exists(cond_fn), self.adaptive_ln)

        if exists(cond_fn):
            x = cond_fn(x)

        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        if exists(cache):
            ck, cv = cache
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        new_kv_cache = torch.stack((k, v))

        if exists(self.mem_kv):
            mk, mv = map(lambda t: repeat(t, '... -> b ...', b = b), self.mem_kv)

            k = torch.cat((mk, k), dim = -2)
            v = torch.cat((mv, v), dim = -2)

            if exists(mask):
                mask = F.pad(mask, (self.num_mem_kv, 0), value = True)

            if exists(attn_mask):
                attn_mask = F.pad(attn_mask, (self.num_mem_kv, 0), value = True)

        out = self.attend(q, k, v, mask = mask, attn_mask = attn_mask)

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        if not return_cache:
            return out

        return out, new_kv_cache
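
# a sketch of the kv cache in use - decode one token at a time, feeding the
# cache from the previous step back in (shapes are illustrative):

attn = TransformerAttention(dim = 512, causal = True)
x = torch.randn(1, 4, 512)
out, kv_cache = attn(x, return_cache = True)                # kv_cache stacks the keys and values
out_next = attn(torch.randn(1, 1, 512), cache = kv_cache)   # only the new token is projected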

# transformer, with optional cross attention and kv caching

class Transformer(Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        depth = 6,
        attn_dropout = 0.,
        ff_dropout = 0.,
        adaptive_ln = False,
        flash_attn = True,
        cross_attend = False,
        causal = False,
        final_norm = True
    ):
        super().__init__()
        self.layers = ModuleList([])

        attn_kwargs = dict(
            dim = dim,
            heads = heads,
            dim_head = dim_head,
            dropout = attn_dropout,
            flash = flash_attn
        )

        for _ in range(depth):
            self.layers.append(ModuleList([
                TransformerAttention(**attn_kwargs, causal = causal, adaptive_ln = adaptive_ln, norm_context = False),
                TransformerAttention(**attn_kwargs, norm_context = True) if cross_attend else None,
                FeedForward(dim = dim, dropout = ff_dropout, adaptive_ln = adaptive_ln)
            ]))

        self.norm = RMSNorm(dim) if final_norm else nn.Identity()

    @beartype
    # forward, accepting conditioning functions, an attention mask, cross attention context, and a kv cache
    def forward(
        self,
        x,
        cond_fns: Optional[Tuple[Callable, ...]] = None,
        attn_mask = None,
        context: Optional[Tensor] = None,
        cache: Optional[Tensor] = None,
        return_cache = False
    ):
        has_cache = exists(cache)

        # if a cache is passed in, only the last token needs to go through the layers
        if has_cache:
            x_prev, x = x[..., :-1, :], x[..., -1:, :]

        # turn the conditioning functions and per-layer caches into iterators
        cond_fns = iter(default(cond_fns, []))
        cache = iter(default(cache, []))

        # collect the new key / value caches
        new_caches = []

        # attention, optional cross attention, and feedforward, per layer
        for attn, maybe_cross_attn, ff in self.layers:
            attn_out, new_cache = attn(
                x,
                attn_mask = attn_mask,
                cond_fn = next(cond_fns, None),
                return_cache = True,
                cache = next(cache, None)
            )

            new_caches.append(new_cache)

            # residual
            x = x + attn_out

            # cross attend to the context, if present
            if exists(maybe_cross_attn):
                assert exists(context)
                x = maybe_cross_attn(x, context = context) + x

            x = ff(x, cond_fn = next(cond_fns, None)) + x

        new_caches = torch.stack(new_caches)

        # re-attach the tokens that were cached
        if has_cache:
            x = torch.cat((x_prev, x), dim = -2)

        out = self.norm(x)

        if not return_cache:
            return out

        return out, new_caches

# token learner module

class TokenLearner(Module):
    """
    https://arxiv.org/abs/2106.11297
    using the 1.1 version with the MLP (2 dense layers with gelu) for generating attention map
    """

    def __init__(
        self,
        *,
        dim,
        ff_mult = 2,
        num_output_tokens = 8,
        num_layers = 2
    ):
        super().__init__()
        inner_dim = dim * ff_mult * num_output_tokens

        self.num_output_tokens = num_output_tokens

        # grouped 1x1 convolutions act as per-token mlps for producing the attention maps
        self.net = nn.Sequential(
            nn.Conv2d(dim * num_output_tokens, inner_dim, 1, groups = num_output_tokens),
            nn.GELU(),
            nn.Conv2d(inner_dim, num_output_tokens, 1, groups = num_output_tokens),
        )

    def forward(self, x):
        # fold any leading dimensions (e.g. frames) into the batch
        x, ps = pack_one(x, '* c h w')
        x = repeat(x, 'b c h w -> b (g c) h w', g = self.num_output_tokens)

        # one spatial attention map per output token
        attn = self.net(x)

        attn = rearrange(attn, 'b g h w -> b 1 g h w')
        x = rearrange(x, 'b (g c) h w -> b c g h w', g = self.num_output_tokens)

        # attention-weighted mean pool over space
        x = reduce(x * attn, 'b c g h w -> b c g', 'mean')

        # restore the leading dimensions
        x = unpack_one(x, ps, '* c n')
        return x
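
# a minimal usage sketch - the token learner compresses each frame's feature
# map down to a fixed number of learned tokens:

learner = TokenLearner(dim = 128, num_output_tokens = 8)
feats = torch.randn(2, 6, 128, 7, 7)   # (batch, frames, channels, height, width)
tokens = learner(feats)                # (2, 6, 128, 8) - eight tokens per frame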

# Dueling heads for Q value

class DuelingHead(Module):
    def __init__(
        self,
        dim,
        expansion_factor = 2,
        action_bins = 256
    ):
        super().__init__()
        dim_hidden = dim * expansion_factor

        self.stem = nn.Sequential(
            nn.Linear(dim, dim_hidden),
            nn.SiLU()
        )

        self.to_values = nn.Sequential(
            nn.Linear(dim_hidden, 1)
        )

        self.to_advantages = nn.Sequential(
            nn.Linear(dim_hidden, action_bins)
        )

    def forward(self, x):
        x = self.stem(x)

        advantages = self.to_advantages(x)
        advantages = advantages - reduce(advantages, '... a -> ... 1', 'mean')

        values = self.to_values(x)

        q_values = values + advantages
        return q_values.sigmoid()
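
# the head above implements the dueling decomposition
# Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)), squashed to (0, 1) by the sigmoid.
# a minimal usage sketch:

head = DuelingHead(dim = 512, action_bins = 256)
state = torch.randn(2, 512)
q = head(state)   # (2, 256) - one normalized q-value per action bin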

# Q head modules, for either single or multiple actions

class QHeadSingleAction(Module):
    def __init__(
        self,
        dim,
        *,
        num_learned_tokens = 8,
        action_bins = 256,
        dueling = False
    ):
        super().__init__()
        self.action_bins = action_bins

        if dueling:
            self.to_q_values = nn.Sequential(
                Reduce('b (f n) d -> b d', 'mean', n = num_learned_tokens),
                DuelingHead(
                    dim,
                    action_bins = action_bins
                )
            )
        else:
            self.to_q_values = nn.Sequential(
                Reduce('b (f n) d -> b d', 'mean', n = num_learned_tokens),
                RMSNorm(dim),
                nn.Linear(dim, action_bins),
                nn.Sigmoid()
            )

    def get_random_actions(self, batch_size):
        return torch.randint(0, self.action_bins, (batch_size,), device = self.device)

    def get_optimal_actions(
        self,
        encoded_state,
        return_q_values = False,
        actions = None,
        **kwargs
    ):
        assert not exists(actions), 'single actions will never receive previous actions'

        q_values = self.forward(encoded_state)

        max_q, action_indices = q_values.max(dim = -1)

        if not return_q_values:
            return action_indices

        return action_indices, max_q

    def forward(self, encoded_state):
        return self.to_q_values(encoded_state)

class QHeadMultipleActions(Module):
    def __init__(
        self,
        dim,
        *,
        num_actions = 8,
        action_bins = 256,
        attn_depth = 2,
        attn_dim_head = 32,
        attn_heads = 8,
        dueling = False,
        weight_tie_action_bin_embed = False
    ):
        super().__init__()
        self.num_actions = num_actions
        self.action_bins = action_bins

        # learned embeddings for every bin of every action dimension
        self.action_bin_embeddings = nn.Parameter(torch.zeros(num_actions, action_bins, dim))
        nn.init.normal_(self.action_bin_embeddings, std = 0.02)

        # optional separate projection to q-values (otherwise weight-tied to the bin embeddings)
        self.to_q_values = None
        if not weight_tie_action_bin_embed:
            self.to_q_values = nn.Linear(dim, action_bins)

        # autoregressive transformer over the action tokens, cross attending to the encoded state
        self.transformer = Transformer(
            dim = dim,
            depth = attn_depth,
            dim_head = attn_dim_head,
            heads = attn_heads,
            cross_attend = True,
            adaptive_ln = False,
            causal = True,
            final_norm = True
        )

        self.final_norm = RMSNorm(dim)

        # dueling option - separate state-value projections per action dimension
        self.dueling = dueling
        if dueling:
            self.to_values = nn.Parameter(torch.zeros(num_actions, dim))

    @property
    def device(self):
        return self.action_bin_embeddings.device

    def maybe_append_actions(self, sos_tokens, actions: Optional[Tensor] = None):
        if not exists(actions):
            return sos_tokens

        batch, num_actions = actions.shape

        # gather the embeddings of the chosen bins
        action_embeddings = self.action_bin_embeddings[:num_actions]

        action_embeddings = repeat(action_embeddings, 'n a d -> b n a d', b = batch)
        past_action_bins = repeat(actions, 'b n -> b n 1 d', d = action_embeddings.shape[-1])

        bin_embeddings = action_embeddings.gather(-2, past_action_bins)
        bin_embeddings = rearrange(bin_embeddings, 'b n 1 d -> b n d')

        tokens, _ = pack((sos_tokens, bin_embeddings), 'b * d')
        tokens = tokens[:, :self.num_actions] # the last action's bin embedding is not needed for proposing q-values
        return tokens

    def get_q_values(self, embed):
        num_actions = embed.shape[-2]

        if exists(self.to_q_values):
            logits = self.to_q_values(embed)
        else:
            # weight-tied: each token scores the next action's bins against their embeddings
            action_bin_embeddings = self.action_bin_embeddings[:num_actions]
            action_bin_embeddings = torch.roll(action_bin_embeddings, shifts = -1, dims = 1)
            logits = einsum('b n d, n a d -> b n a', embed, action_bin_embeddings)

        if self.dueling:
            advantages = logits
            values = einsum('b n d, n d -> b n', embed, self.to_values[:num_actions])
            values = rearrange(values, 'b n -> b n 1')

            q_values = values + (advantages - reduce(advantages, '... a -> ... 1', 'mean'))
        else:
            q_values = logits

        return q_values.sigmoid()

    def get_random_actions(self, batch_size, num_actions = None):
        num_actions = default(num_actions, self.num_actions)
        return torch.randint(0, self.action_bins, (batch_size, num_actions), device = self.device)

    @torch.no_grad()
    def get_optimal_actions(
        self,
        encoded_state,
        return_q_values = False,
        actions: Optional[Tensor] = None,
        prob_random_action: float = 0.5,
        **kwargs
    ):
        # the random action probability must be in [0, 1]
        assert 0. <= prob_random_action <= 1.
        batch = encoded_state.shape[0]

        # if the random action probability is 1, just return random actions
        if prob_random_action == 1:
            return self.get_random_actions(batch)

        # mean pool the encoded state for the start-of-sequence token
        sos_token = reduce(encoded_state, 'b ... d -> b 1 d', 'mean')
        tokens = self.maybe_append_actions(sos_token, actions = actions)

        action_bins = []
        cache = None

        # decode one action dimension at a time
        for action_idx in range(self.num_actions):

            embed, cache = self.transformer(
                tokens,
                context = encoded_state,
                cache = cache,
                return_cache = True
            )

            # embedding for the current action, and the bin embeddings to score against
            last_embed = embed[:, action_idx]
            bin_embeddings = self.action_bin_embeddings[action_idx]

            # q-values over the bins of this action dimension
            q_values = einsum('b d, a d -> b a', last_embed, bin_embeddings)

            # greedily pick the bin with the highest q-value
            selected_action_bins = q_values.argmax(dim = -1)

            # epsilon-greedy exploration: with some probability, substitute a random bin
            if prob_random_action > 0.:
                random_mask = torch.zeros_like(selected_action_bins).float().uniform_(0., 1.) < prob_random_action
                random_actions = self.get_random_actions(batch, 1)
                random_actions = rearrange(random_actions, '... 1 -> ...')

                selected_action_bins = torch.where(
                    random_mask,
                    random_actions,
                    selected_action_bins
                )

            # embed the chosen bin and append it as the next token
            next_action_embed = bin_embeddings[selected_action_bins]

            tokens, _ = pack((tokens, next_action_embed), 'b * d')

            action_bins.append(selected_action_bins)

        action_bins = torch.stack(action_bins, dim = -1)

        if not return_q_values:
            return action_bins

        # also return the q-values for all decoded actions
        all_q_values = self.get_q_values(embed)
        return action_bins, all_q_values

    def forward(
        self,
        encoded_state: Tensor,
        actions: Optional[Tensor] = None
    ):
        """
        einops
        b - batch
        n - number of actions
        a - action bins
        d - dimension
        """

        # mean pool the encoded state for the start-of-sequence token
        sos_token = reduce(encoded_state, 'b ... d -> b 1 d', 'mean')

        # append the embeddings of any past actions
        tokens = self.maybe_append_actions(sos_token, actions = actions)

        # attend, cross attending to the encoded state
        embed = self.transformer(tokens, context = encoded_state)

        return self.get_q_values(embed)
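
# a minimal usage sketch - without past actions only the first action dimension
# is scored; passing previously chosen bins scores every dimension autoregressively:

q_head = QHeadMultipleActions(dim = 512, num_actions = 8, action_bins = 256)
state = torch.randn(2, 16, 512)         # encoded state tokens
q_first = q_head(state)                 # (2, 1, 256)
past = torch.randint(0, 256, (2, 8))    # bins already chosen for all 8 dimensions
q_all = q_head(state, actions = past)   # (2, 8, 256)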

# the robotic q-transformer - MaxViT encoder, token learner, transformer, and q-head(s)

class QRoboticTransformer(Module):

    @beartype
    def __init__(
        self,
        *,
        vit: Union[Dict[str, Any], MaxViT],  # a MaxViT instance, or a dict of kwargs for constructing one
        num_actions = 8,                     # number of discrete action dimensions
        action_bins = 256,                   # number of bins per action dimension
        depth = 6,
        heads = 8,
        dim_head = 64,
        token_learner_ff_mult = 2,
        token_learner_num_layers = 2,
        token_learner_num_output_tokens = 8,
        cond_drop_prob = 0.2,                # conditioning dropout, for classifier free guidance
        use_attn_conditioner = False,
        conditioner_kwargs: dict = dict(),
        dueling = False,                     # whether to use the dueling q-head
        flash_attn = True,
        condition_on_text = True,
        q_head_attn_kwargs: dict = dict(
            attn_heads = 8,
            attn_dim_head = 64,
            attn_depth = 2
        ),
        weight_tie_action_bin_embed = True   # tie the q-value projection weights to the action bin embeddings
    ):
        super().__init__()

        # construct the vit from a dict of kwargs, if needed
        if isinstance(vit, dict):
            vit = MaxViT(**vit)

        self.vit = vit

        self.num_vit_stages = len(vit.cond_hidden_dims)

        attend_dim = vit.embed_dim

        # q-transformer related action embeddings

        assert num_actions >= 1

        self.num_actions = num_actions
        self.is_single_action = num_actions == 1
        self.action_bins = action_bins

        # conditioning

        self.condition_on_text = condition_on_text

        if condition_on_text:
            # choose the conditioner class, depending on whether attention conditioning is used
            conditioner_klass = AttentionTextConditioner if use_attn_conditioner else TextConditioner

            self.conditioner = conditioner_klass(
                hidden_dims = (*tuple(vit.cond_hidden_dims), *((attend_dim,) * depth * 2)),
                hiddens_channel_first = (*((True,) * self.num_vit_stages), *((False,) * depth * 2)),
                cond_drop_prob = cond_drop_prob,
                **conditioner_kwargs
            )
        else:
            self.conditioner = NullConditioner(hidden_dims = tuple())

        self.token_learner = TokenLearner(
            dim = vit.embed_dim,
            ff_mult = token_learner_ff_mult,
            num_output_tokens = token_learner_num_output_tokens,
            num_layers = token_learner_num_layers
        )

        self.num_learned_tokens = token_learner_num_output_tokens

        self.transformer_depth = depth

        self.transformer = Transformer(
            dim = attend_dim,
            dim_head = dim_head,
            heads = heads,
            depth = depth,
            flash_attn = flash_attn,
            adaptive_ln = condition_on_text,
            final_norm = True
        )

        self.cond_drop_prob = cond_drop_prob

        # q head - either single action or autoregressive multiple actions

        if self.is_single_action:
            self.q_head = QHeadSingleAction(
                attend_dim,
                num_learned_tokens = self.num_learned_tokens,
                action_bins = action_bins,
                dueling = dueling
            )
        else:
            self.q_head = QHeadMultipleActions(
                attend_dim,
                action_bins = action_bins,
                dueling = dueling,
                weight_tie_action_bin_embed = weight_tie_action_bin_embed,
                **q_head_attn_kwargs
            )

    # device of the parameters
    @property
    def device(self):
        return next(self.parameters()).device

    # random actions, for exploration
    def get_random_actions(self, batch_size = 1):
        return self.q_head.get_random_actions(batch_size)

    # embed the text instructions
    @beartype
    def embed_texts(self, texts: List[str]):
        return self.conditioner.embed_texts(texts)

    # greedy actions from the q-head
    @torch.no_grad()
    def get_optimal_actions(
        self,
        *args,
        return_q_values = False,
        actions: Optional[Tensor] = None,
        **kwargs
    ):
        encoded_state = self.encode_state(*args, **kwargs)
        return self.q_head.get_optimal_actions(encoded_state, return_q_values = return_q_values, actions = actions)

    # action selection with epsilon-greedy exploration
    def get_actions(
        self,
        video,
        *args,
        prob_random_action = 0.,  # otherwise known as epsilon in reinforcement learning
        **kwargs,
    ):
        batch_size = video.shape[0]
        assert 0. <= prob_random_action <= 1.

        # with probability epsilon, take a random action
        if random() < prob_random_action:
            return self.get_random_actions(batch_size = batch_size)

        # otherwise take the greedy action
        return self.get_optimal_actions(video, *args, **kwargs)

    # encode the state from video, optionally conditioned on text instructions or embeddings
    def encode_state(
        self,
        video: Tensor,
        texts: Optional[Union[List[str], Tuple[str]]] = None,
        text_embeds: Optional[Tensor] = None,
        actions: Optional[Tensor] = None,
        cond_drop_prob = 0.,
    ):
        """
        einops
        b - batch
        c - channels
        f - frames
        h - height
        w - width
        n - number of learned tokens
        """

        # if not conditioning on text, neither texts nor text embeds should be passed in
        if not self.condition_on_text:
            assert (not exists(texts) and not exists(text_embeds)), 'neither texts nor text embeds should be passed in'
        else:
            # otherwise exactly one of texts / text embeds must be given
            assert exists(texts) ^ exists(text_embeds), 'either texts or text embeds must be passed in if conditioning on instructions'

        # accept a tuple of texts and convert it to a list
        if exists(texts) and isinstance(texts, tuple):
            texts = list(texts)

        text_cond_kwargs = dict(texts = texts, text_embeds = text_embeds)

        depth = self.transformer_depth
        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        frames, device = video.shape[2], video.device

        # build the conditioning functions - one per vit stage (repeated over frames), plus two per transformer layer
        cond_fns, _ = self.conditioner(
            **text_cond_kwargs,
            cond_drop_prob = cond_drop_prob,
            repeat_batch = (*((frames,) * self.num_vit_stages), *((1,) * self.transformer_depth * 2))
        )

        # split the conditioning functions between the vit and the transformer
        vit_cond_fns, transformer_cond_fns = cond_fns[:-(depth * 2)], cond_fns[-(depth * 2):]

        # fold the frames into the batch and run each frame through the vit
        video = rearrange(video, 'b c f h w -> b f c h w')
        images, packed_shape = pack_one(video, '* c h w')

        # per-frame feature maps from the vit
        tokens = self.vit(
            images,
            texts = texts,
            cond_fns = vit_cond_fns,
            cond_drop_prob = cond_drop_prob,
            return_embeddings = True
        )

        tokens = unpack_one(tokens, packed_shape, '* c h w')
        learned_tokens = self.token_learner(tokens)

        tokens_per_frame = learned_tokens.shape[-1]
        learned_tokens = rearrange(learned_tokens, 'b f c n -> b (f n) c')

        # causal attention mask - each frame attends only to itself and past frames

        attn_mask = ~torch.ones((frames, frames), dtype = torch.bool, device = device).triu(1)
        attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1 = self.num_learned_tokens, r2 = self.num_learned_tokens)

        # sinusoidal positional embedding over the frames

        pos_emb = posemb_sincos_1d(frames, learned_tokens.shape[-1], dtype = learned_tokens.dtype, device = learned_tokens.device)

        learned_tokens = learned_tokens + repeat(pos_emb, 'n d -> (n r) d', r = self.num_learned_tokens)

        # attend

        attended_tokens = self.transformer(learned_tokens, cond_fns = transformer_cond_fns, attn_mask = attn_mask)

        return attended_tokens

    # forward pass from video (and optional instructions / past actions) to q-values
    @classifier_free_guidance
    def forward(
        self,
        video: Tensor,
        texts: Optional[List[str]] = None,
        text_embeds: Optional[Tensor] = None,
        actions: Optional[Tensor] = None,
        cond_drop_prob = 0.,
    ):
        # move the inputs to the same device as the model
        video = video.to(self.device)

        if exists(actions):
            actions = actions.to(self.device)

        # encode the state
        encoded_state = self.encode_state(
            video = video,
            texts = texts,
            text_embeds = text_embeds,
            actions = actions,
            cond_drop_prob = cond_drop_prob
        )

        # q head - handles both single and multiple actions
        if self.is_single_action:
            assert not exists(actions), 'actions should not be passed in for single action robotic transformer'
            q_values = self.q_head(encoded_state)
        else:
            q_values = self.q_head(encoded_state, actions = actions)

        return q_values

.\lucidrains\q-transformer\q_transformer\__init__.py

# export the main model and the vision backbone
from q_transformer.q_robotic_transformer import (
    QRoboticTransformer,
    MaxViT
)

# export the q-learning trainer
from q_transformer.q_learner import (
    QLearner
)

# export the agent, replay memory dataset, and environment base class
from q_transformer.agent import (
    Agent,
    ReplayMemoryDataset,
    BaseEnvironment
)

.\lucidrains\q-transformer\README.md

Q-transformer

Implementation of Q-Transformer, Scalable Offline Reinforcement Learning via Autoregressive Q-Functions, out of Google Deepmind

I will be keeping around the logic for Q-learning on a single action, just for final comparison with the proposed autoregressive Q-learning on multiple actions. It also serves as education for myself and the public.

Install

$ pip install q-transformer

Usage

import torch

from q_transformer import (
    QRoboticTransformer,
    QLearner,
    Agent,
    ReplayMemoryDataset
)

# the attention model

model = QRoboticTransformer(
    vit = dict(
        num_classes = 1000,
        dim_conv_stem = 64,
        dim = 64,
        dim_head = 64,
        depth = (2, 2, 5, 2),
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1
    ),
    num_actions = 8,
    action_bins = 256,
    depth = 1,
    heads = 8,
    dim_head = 64,
    cond_drop_prob = 0.2,
    dueling = True
)

# you need to supply your own environment, by overriding BaseEnvironment

from q_transformer.mocks import MockEnvironment

env = MockEnvironment(
    state_shape = (3, 6, 224, 224),
    text_embed_shape = (768,)
)

# env.init()     should return instructions and initial state: Tuple[str, Tensor[*state_shape]]
# env(actions)   should return rewards, next state, and done flag: Tuple[Tensor[()], Tensor[*state_shape], Tensor[()]]

# agent is a class that allows the q-model to interact with the environment to generate a replay memory dataset for learning

agent = Agent(
    model,
    environment = env,
    num_episodes = 1000,
    max_num_steps_per_episode = 100,
)

agent()

# Q learning on the replay memory dataset on the model

q_learner = QLearner(
    model,
    dataset = ReplayMemoryDataset(),
    num_train_steps = 10000,
    learning_rate = 3e-4,
    batch_size = 4,
    grad_accum_every = 16,
)

q_learner()

# after much learning
# your robot should be better at selecting optimal actions

video = torch.randn(2, 3, 6, 224, 224)

instructions = [
    'bring me that apple sitting on the table',
    'please pass the butter'
]

actions = model.get_optimal_actions(video, instructions)

Appreciation

Todo

Citations

@inproceedings{qtransformer,
    title   = {Q-Transformer: Scalable Offline Reinforcement Learning via Autoregressive Q-Functions},
    authors = {Yevgen Chebotar and Quan Vuong and Alex Irpan and Karol Hausman and Fei Xia and Yao Lu and Aviral Kumar and Tianhe Yu and Alexander Herzog and Karl Pertsch and Keerthana Gopalakrishnan and Julian Ibarz and Ofir Nachum and Sumedh Sontakke and Grecia Salazar and Huong T Tran and Jodilyn Peralta and Clayton Tan and Deeksha Manjunath and Jaspiar Singht and Brianna Zitkovich and Tomas Jackson and Kanishka Rao and Chelsea Finn and Sergey Levine},
    booktitle = {7th Annual Conference on Robot Learning},
    year   = {2023}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}

.\lucidrains\q-transformer\setup.py

# setuptools helpers for packaging
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'q-transformer',
  packages = find_packages(exclude=[]),
  version = '0.1.14',
  license='MIT',
  description = 'Q-Transformer',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/q-transformer',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'attention mechanisms',
    'transformers',
    'q-learning'
  ],
  install_requires=[
    'accelerate',
    'beartype',
    'classifier-free-guidance-pytorch>=0.4.2',
    'einops>=0.7.0',
    'ema-pytorch>=0.3.1',
    'numpy',
    'torchtyping',
    'torch>=2.0'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\quartic-transformer\quartic_transformer\multi_stream_transformer.py

        """
        实现注意力机制的模块
        参数:
            dim - 输入特征的维度
            num_streams - 流的数量
            dim_head - 每个头的维度
            heads - 头的数量
            dropout - 丢弃率
            causal - 是否使用因果注意力
            pre_talking_heads - 是否使用预对话头
            post_talking_heads - 是否使用后对话头
            non_linear_talking_heads - 是否使用非线性对话头
        """
        super().__init__()
        dim_inner = dim_head * heads
        all_heads = num_streams * heads

        self.num_streams = num_streams

        # project the input to queries, keys, values
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', h = heads, qkv = 3)
        )

        # per-head output gates
        self.to_gates = nn.Sequential(
            nn.Linear(dim, heads),
            Rearrange('b n h -> b h n 1'),
            nn.Sigmoid()
        )

        # rmsnorm
        self.rmsnorm = einn.Norm('b... [d]', mean = False, bias = False)

        self.scale = dim_head ** 0.5
        self.causal = causal
        self.dropout = nn.Dropout(dropout)

        self.pre_talking_heads = None
        self.post_talking_heads = None

        # talking heads are either a feedforward (non-linear) or a 1x1 conv across all streams' heads
        if non_linear_talking_heads:
            self.pre_talking_heads = TalkingHeadsFeedForward(all_heads) if pre_talking_heads else None
            self.post_talking_heads = TalkingHeadsFeedForward(all_heads) if post_talking_heads else None
        else:
            self.pre_talking_heads = nn.Conv2d(all_heads, all_heads, 1, bias = False) if pre_talking_heads else None
            self.post_talking_heads = nn.Conv2d(all_heads, all_heads, 1, bias = False) if post_talking_heads else None

            # initialize the convolutional talking heads to the identity
            nn.init.dirac_(self.pre_talking_heads.weight)
            nn.init.dirac_(self.post_talking_heads.weight)

        # output projection
        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        mask = None
    ):
        # number of streams
        s = self.num_streams

        # rmsnorm
        x = self.rmsnorm(x)

        # project to queries, keys, values
        q, k, v = self.to_qkv(x)

        # scale the queries and compute the attention similarity matrix
        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        mask_value = -torch.finfo(sim.dtype).max

        # pre-softmax talking heads, mixing across the heads of all streams
        if exists(self.pre_talking_heads):
            sim = rearrange(sim, '(b s) h n d -> b (s h) n d', s = s)
            sim = self.pre_talking_heads(sim)
            sim = rearrange(sim, 'b (s h) n d -> (b s) h n d', s = s)

        # key padding mask
        if exists(mask):
            sim = einx.where('b j, b ... j, ', mask, sim, mask_value)

        # causal mask
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        # attention
        attn = einx.softmax('b h i [j]', sim)

        # keep the post-softmax attention, for the stream diversity loss
        post_softmax_attn = attn

        attn = self.dropout(attn)

        # post-softmax talking heads
        if exists(self.post_talking_heads):
            attn = rearrange(attn, '(b s) h n d -> b (s h) n d', s = s)
            attn = self.post_talking_heads(attn)
            attn = rearrange(attn, 'b (s h) n d -> (b s) h n d', s = s)

        # aggregate the values
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # gate per head, then project out
        out = out * self.to_gates(x)
        out = self.to_out(out)

        return out, post_softmax_attn

# feedforward

def FeedForward(dim, mult = 4, dropout = 0.):
    dim_inner = int(dim * mult)
    return nn.Sequential(
        einn.Norm('b... [d]', mean = False, bias = False),
        nn.Linear(dim, dim_inner, bias = False),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim, bias = False)
    )

# talking heads feedforward - mixes information across attention heads with 1x1 convolutions

def TalkingHeadsFeedForward(dim, mult = 2, dropout = 0.):
    dim_inner = int(dim * mult)
    net = nn.Sequential(
        einn.Norm('b [c] ...', mean = False, bias = False),
        nn.Conv2d(dim, dim_inner, 1, bias = False),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Conv2d(dim_inner, dim, 1, bias = False)
    )

    # zero-init the last layer, so the residual branch starts as the identity
    nn.init.zeros_(net[-1].weight)
    return Residual(net)

# token and positional embeddings shared across streams, with a learned per-stream embedding
class TokenAndPosEmb(Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        max_seq_len,
        num_streams
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)
        self.stream_emb = nn.Parameter(torch.zeros(num_streams, dim))
        nn.init.normal_(self.stream_emb, std = 0.02)

    def forward(self, x):
        seq_len = torch.arange(x.shape[-1], device = x.device)
        token_emb = self.token_emb(x)
        pos_emb = self.pos_emb(seq_len)
        # sum the token, positional, and stream embeddings, expanding the streams into the batch
        return einx.add('b n d, n d, s d -> (b s) n d', token_emb, pos_emb, self.stream_emb)

# separate token and positional embeddings per stream
class SeparateTokenAndPosEmb(Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        max_seq_len,
        num_streams
    ):
        super().__init__()
        self.token_emb = nn.Parameter(torch.zeros(num_streams, num_tokens, dim))
        self.pos_emb = nn.Parameter(torch.zeros(num_streams, max_seq_len, dim))
        nn.init.normal_(self.token_emb, std = 0.02)
        nn.init.normal_(self.pos_emb, std = 0.02)

    def forward(self, x):
        seq_len = torch.arange(x.shape[-1], device = x.device)
        token_emb = get_at('s [e] d, b n -> b s n d', self.token_emb, x)
        # index the positional embeddings by position (seq_len), not by token id
        pos_emb = get_at('s [e] d, n -> s n d', self.pos_emb, seq_len)
        return einx.add('b s n d, s n d -> (b s) n d', token_emb, pos_emb)

# multi-stream transformer
class MultiStreamTransformer(Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        num_streams = 2,
        dim_head = 64,
        heads = 8,
        max_seq_len = 2048,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_mult = 4.,
        ablate_cross_stream_talking_heads = False,
        pre_talking_heads = True,
        post_talking_heads = True,
        separate_stream_emb = True,
        non_linear_talking_heads = False
    ):
        super().__init__()

        # choose the embedding class, depending on whether streams get separate embeddings
        embed_klass = SeparateTokenAndPosEmb if separate_stream_emb else TokenAndPosEmb

        self.emb = embed_klass(
            dim = dim,
            num_tokens = num_tokens,
            num_streams = num_streams,
            max_seq_len = max_seq_len
        )

        self.num_streams = num_streams
        self.layers = ModuleList([])

        # the talking heads mix across streams, unless that is ablated
        talking_heads_num_streams = 2 if not ablate_cross_stream_talking_heads else 1

        # attention and feedforward, per layer
        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    dropout = attn_dropout,
                    num_streams = talking_heads_num_streams,
                    pre_talking_heads = pre_talking_heads,
                    post_talking_heads = post_talking_heads,
                    non_linear_talking_heads = non_linear_talking_heads
                ),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ]))

        # sum the streams, normalize, and project to logits
        self.to_logits = nn.Sequential(
            Reduce('(b s) n d -> b n d', 'sum', s = num_streams),
            einn.Norm('b... [d]', mean = False, bias = False),
            nn.Linear(dim, num_tokens, bias = False)
        )

    def forward(
        self,
        x,
        mask = None,
        stream_attn_diversity_loss = False
    ):
        b, n, s, device = *x.shape, self.num_streams, x.device

        # the diversity loss only makes sense with more than one stream
        stream_attn_diversity_loss &= s > 1

        x = self.emb(x)

        # store the post-softmax attention matrices of every layer
        attn_matrices = []

        for attn, ff in self.layers:
            attn_out, post_softmax_attn = attn(x, mask = mask)

            attn_matrices.append(post_softmax_attn)

            x = x + attn_out
            x = ff(x) + x

        # auxiliary loss encouraging the streams to attend differently
        if stream_attn_diversity_loss:
            aux_loss = sum([calc_stream_loss(attn_matrix, s).mean() for attn_matrix in attn_matrices])

        logits = self.to_logits(x)

        if not stream_attn_diversity_loss:
            return logits

        return logits, aux_loss
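
# a minimal usage sketch for the multi-stream transformer above, assuming the
# einx / colt5-attention dependencies of this package are installed:

model = MultiStreamTransformer(dim = 512, num_tokens = 256, depth = 2, num_streams = 2)
ids = torch.randint(0, 256, (1, 128))
logits = model(ids)   # (1, 128, 256) - the streams are summed before the logits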

.\lucidrains\quartic-transformer\quartic_transformer\quartic_transformer.py

import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList

# einops for tensor rearrangement
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# einx and its torch layer namespace
# (the import statements below are reconstructed from the original comments, which survived the extraction)
import einx
import einx.nn.torch as einn

# differentiable topk from colt5-attention
from colt5_attention import topk

# taylor series linear attention
from taylor_series_linear_attention import TaylorSeriesLinearAttn

# dynamic position bias from x-transformers
from x_transformers.x_transformers import DynamicPositionBias
# helper functions

# returns True if a value is not None
def exists(v):
    return v is not None

# fall back to a default when the value is None
def default(v, d):
    return v if exists(v) else d

# pack a single tensor according to an einops pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# inverse of pack_one
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# attention

class Attention(Module):
    def __init__(
        self,
        dim,
        dim_edges = None,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        causal = False,
        incorporate_edges = True
    ):
        super().__init__()
        dim_edges = default(dim_edges, dim)
        dim_inner = dim_head * heads

        # QKV projection followed by a rearrange into heads
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', h = heads, qkv = 3)
        )

        # per-head output gates: linear projection followed by a sigmoid
        self.to_gates = nn.Sequential(
            nn.Linear(dim, heads),
            Rearrange('b n h -> b h n 1'),
            nn.Sigmoid()
        )

        # RMSNorm layer
        self.rmsnorm = einn.Norm('b... [d]', mean = False, bias = False)

        # query scaling factor
        self.scale = dim_head ** -0.5
        self.causal = causal
        self.dropout = nn.Dropout(dropout)

        self.edges_to_attn_bias = None

        if incorporate_edges:
            # project the edge features to a per-head attention bias
            self.edges_to_attn_bias = nn.Sequential(
                einn.Norm('b... [d]', mean = False, bias = False),
                nn.Linear(dim_edges, heads),
                Rearrange('b i j h -> b h i j')
            )

        # pre-softmax talking heads, as a 1x1 convolution across the heads
        self.pre_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)

        self.to_edges_out = None

        if incorporate_edges:
            # project the attention matrix back out to edge features
            self.to_edges_out = nn.Sequential(
                nn.Conv2d(heads, dim_edges, 1, bias = False),
                Rearrange('b d i j -> b i j d')
            )

        # output projection
        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        mask = None,
        edges = None
    ):
        x = self.rmsnorm(x)

        q, k, v = self.to_qkv(x)

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        mask_value = -torch.finfo(sim.dtype).max

        if exists(edges) and exists(self.edges_to_attn_bias):
            attn_bias = self.edges_to_attn_bias(edges)
            sim = sim + attn_bias

        sim = self.pre_talking_heads(sim)

        if exists(mask):
            sim = einx.where('b j, b ... j, ', mask, sim, mask_value)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        attn = einx.softmax('b h i [j]', sim)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = out * self.to_gates(x)
        out = self.to_out(out)

        edges_out = None
        if exists(self.to_edges_out):
            edges_out = self.to_edges_out(attn)

        if not exists(edges_out):
            return out

        return out, edges_out
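
A quick shape check for the class above, as a sketch with arbitrary sizes: given edges, the forward pass returns updated nodes along with updated edges derived from the attention matrix.

import torch

attn = Attention(dim = 64, dim_edges = 32, heads = 4)

x = torch.randn(1, 16, 64)          # node features (batch, seq, dim)
edges = torch.randn(1, 16, 16, 32)  # one feature vector per directed edge (i, j)

out, edges_out = attn(x, edges = edges)
assert out.shape == (1, 16, 64)
assert edges_out.shape == (1, 16, 16, 32)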

# feedforward

def FeedForward(dim, mult = 4, dropout = 0.):
    dim_inner = int(dim * mult)
    return nn.Sequential(
        einn.Norm('b... [d]', mean = False, bias = False),
        nn.Linear(dim, dim_inner, bias = False),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim, bias = False)
    )

# edge embedding

class EdgeEmbed(Module):
    # takes the model dimension and an optional edge dimension
    def __init__(self, dim, dim_edges = None):
        super().__init__()
        # default the edge dimension to the model dimension
        dim_edges = default(dim_edges, dim)
        # project each token to its "row" edge representation
        self.to_rows = nn.Linear(dim, dim_edges, bias = False)
        # project each token to its "column" edge representation
        self.to_cols = nn.Linear(dim, dim_edges, bias = False)

        # final projection and layernorm over the edge dimension
        self.to_edges = nn.Sequential(
            nn.Linear(dim_edges, dim_edges, bias = False),
            nn.LayerNorm(dim_edges)
        )

    def forward(self, x):
        # project tokens to row representations
        rows = self.to_rows(x)
        # project tokens to column representations
        cols = self.to_cols(x)
        # the outer sum gives one feature vector per (i, j) pair
        outer_sum = einx.add('b i d, b j d -> b i j d', rows, cols)
        # final projection of the outer sum to edge features
        return self.to_edges(outer_sum)
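
To make the outer sum concrete, a minimal sketch of what einx.add does with that pattern (arbitrary sizes):

import torch
import einx

rows = torch.randn(1, 4, 8)  # (batch, seq, dim_edges)
cols = torch.randn(1, 4, 8)

# every (i, j) pair gets the sum of row i and column j features
outer_sum = einx.add('b i d, b j d -> b i j d', rows, cols)
assert outer_sum.shape == (1, 4, 4, 8)
assert torch.allclose(outer_sum[0, 1, 2], rows[0, 1] + cols[0, 2])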
# AxialLinearAttention - linear attention along the rows and columns of the edge tensor
class AxialLinearAttention(Module):
    def __init__(
        self,
        dim,
        diagonal_attn = True,
        **attn_kwargs
    ):
        super().__init__()

        # linear attention across the rows
        self.row_attn = TaylorSeriesLinearAttn(dim = dim, gate_value_heads = True, prenorm = True, **attn_kwargs)
        # linear attention across the columns
        self.col_attn = TaylorSeriesLinearAttn(dim = dim, gate_value_heads = True, prenorm = True, **attn_kwargs)

        # optionally, full attention along the diagonal of the edge tensor
        self.diagonal_attn = Attention(dim = dim, incorporate_edges = False, **attn_kwargs) if diagonal_attn else None

    def forward(
        self,
        x,
        mask = None
    ):
        # x has shape (batch, i, j, dim)
        b, n, device = *x.shape[:2], x.device

        # fold the rows into the batch so row attention runs along j
        x = rearrange(x, 'b i j d -> (b i) j d')

        # row attention with a residual
        x = self.row_attn(x, mask = mask) + x

        # regroup so column attention runs along i
        x = rearrange(x, '(b i) j d -> (b j) i d', b = b)

        # column attention with a residual
        x = self.col_attn(x, mask = mask) + x

        # restore the (batch, i, j, dim) layout
        x = rearrange(x, '(b j) i d -> b i j d', b = b)

        # if there is no diagonal attention, return as is
        if not exists(self.diagonal_attn):
            return x

        # boolean mask selecting the diagonal
        diagonal_mask = torch.eye(n, dtype = torch.bool, device = device)
        diagonal_mask = rearrange(diagonal_mask, 'i j -> 1 i j')

        # gather the diagonal elements
        diagonal_el = rearrange(x[diagonal_mask], '(b n) d -> b n d', b = b)

        # full attention over the diagonal with a residual
        diagonal_el = self.diagonal_attn(diagonal_el) + diagonal_el

        # scatter the attended diagonal back into the edge tensor
        diagonal_mask = rearrange(diagonal_mask, '... -> ... 1')
        x = x.masked_scatter(diagonal_mask, diagonal_el)
        return x
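
The axial trick is easiest to see in the rearranges alone; a small sketch of the round trip through the row and column views (arbitrary sizes):

import torch
from einops import rearrange

b, n, d = 2, 4, 8
edges = torch.randn(b, n, n, d)

rows = rearrange(edges, 'b i j d -> (b i) j d')          # row attention runs along j
cols = rearrange(rows, '(b i) j d -> (b j) i d', b = b)  # column attention runs along i
back = rearrange(cols, '(b j) i d -> b i j d', b = b)

assert torch.equal(back, edges)  # the two reshapes compose back to the original layout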

# QuarticTransformer
class QuarticTransformer(Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        dim_edges = None,
        dim_head = 64,
        heads = 8,
        causal = False,
        linear_dim_head = 16,
        linear_heads = 16,
        ff_mult = 4,
        dropout = 0.,
        max_seq_len = 2048,
        ablate_edges = False,
        edges_diagonal_attn = True
    ):
        super().__init__()
        dim_edges = default(dim_edges, dim)

        # store configuration
        self.ablate_edges = ablate_edges
        self.max_seq_len = max_seq_len

        # token and position embeddings
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # dynamic relative position bias, producing one bias per edge channel
        self.dynamic_rel_pos_bias = DynamicPositionBias(dim, depth = 2, heads = dim_edges)

        # edge embedding
        self.to_edge_emb = EdgeEmbed(dim, dim_edges)

        # layers: per depth, a (node attention, feedforward) pair and an (axial edge attention, feedforward) pair
        self.layers = ModuleList([])
        for _ in range(depth):
            self.layers.append(ModuleList([
                ModuleList([
                    Attention(dim = dim, dim_edges = dim_edges, dim_head = dim_head, heads = heads, dropout = dropout, causal = causal),
                    FeedForward(dim = dim, mult = ff_mult, dropout = dropout)
                ]),
                ModuleList([
                    AxialLinearAttention(dim = dim_edges, dim_head = linear_dim_head, heads = linear_heads, causal = causal, diagonal_attn = edges_diagonal_attn),
                    FeedForward(dim = dim_edges, mult = ff_mult)
                ])
            ]))

        # output head
        self.to_logits = nn.Sequential(
            einn.Norm('b... [d]', mean = False, bias = False),
            nn.Linear(dim, num_tokens, bias = False)
        )

    def forward(
        self,
        x,
        mask = None
    ):
        # get the sequence length and device
        seq_len, device = x.shape[-1], x.device
        # the sequence length must not exceed the maximum
        assert seq_len <= self.max_seq_len

        # token embedding
        x = self.token_emb(x)

        # add the position embedding
        x = x + self.pos_emb(torch.arange(seq_len, device = device))
        # derive the edge embeddings from the tokens
        edges = self.to_edge_emb(x)

        # dynamic relative position bias over the edges
        edges_rel_pos = self.dynamic_rel_pos_bias(seq_len, seq_len)
        # add the relative position bias to the edge embeddings
        edges = einx.add('b i j d, d i j -> b i j d', edges, edges_rel_pos)

        # derive the edge mask from the token mask, if given
        edges_mask = None
        if exists(mask):
            edges_mask = einx.logical_and('b i, b j -> b (i j)', mask, mask)

        # run through the layers
        for (attn, ff), (edges_linear_attn, edges_ff) in self.layers:

            # node attention, conditioned on the edges unless edges are ablated
            nodes_out, edges_out = attn(x, mask = mask, edges = edges if not self.ablate_edges else None)

            # residual updates on the nodes
            x = x + nodes_out
            x = ff(x) + x

            # skip the edge updates if edges are ablated
            if self.ablate_edges:
                continue

            # residual update on the edges from the attention matrix
            edges = edges + edges_out

            # axial linear attention over the edges
            edges = edges_linear_attn(edges, mask = mask) + edges

            # feedforward over the edges
            edges = edges_ff(edges) + edges

        # project the nodes to logits
        return self.to_logits(x)

.\lucidrains\quartic-transformer\quartic_transformer\__init__.py

# import the QuarticTransformer class from the quartic_transformer package
from quartic_transformer.quartic_transformer import QuarticTransformer

# import the MultiStreamTransformer class from the quartic_transformer package
from quartic_transformer.multi_stream_transformer import MultiStreamTransformer

Quartic Transformer (wip)

Exploring an idea where one forgets about efficiency and carries out attention on each edge of the nodes (tokens). You can think of it as doing attention on the attention matrix, taking the perspective of the attention matrix as all the directed edges of a fully connected graph.

The hypothesis is that there is a task out there that the (sub)quartic transformer can do that quadratic transformers cannot.

Will also contain a modified implementation of the multi-stream transformer (which is not quartic, but rather the number of streams times quadratic).

Appreciation

Install

$ pip install quartic-transformer

Usage

import torch
from quartic_transformer import QuarticTransformer

model = QuarticTransformer(
    num_tokens = 256,
    depth = 2,
    dim = 512,
    dim_edges = 32
)

tokens = torch.randint(0, 256, (1, 128))

logits = model(tokens) # (1, 128, 256)

Todo

Citation

@inproceedings{Keles2022OnTC,
    title   = {On The Computational Complexity of Self-Attention},
    author  = {Feyza Duman Keles and Pruthuvi Maheshakya Wijewardena and Chinmay Hegde},
    booktitle = {International Conference on Algorithmic Learning Theory},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:252198880}
}
@article{Burtsev2021MultiStreamT,
    title   = {Multi-Stream Transformers},
    author  = {Mikhail S. Burtsev and Anna Rumshisky},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2107.10342},
    url     = {https://api.semanticscholar.org/CorpusID:236171087}
}
@misc{Sutton,
    title  = {The Bitter Lesson},
    url    = {http://www.incompleteideas.net/IncIdeas/BitterLesson.html},
    author = {Sutton, Rich}
}
@article{Shazeer2020TalkingHeadsA,
    title   = {Talking-Heads Attention},
    author  = {Noam M. Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
    journal = {ArXiv},
    year    = {2020},
    volume  = {abs/2003.02436},
    url     = {https://api.semanticscholar.org/CorpusID:212414717}
}

.\lucidrains\quartic-transformer\setup.py

# import setup and find_packages
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'quartic-transformer', # package name
  packages = find_packages(exclude=[]), # find all packages
  version = '0.0.12', # version
  license='MIT', # license
  description = 'Quartic Transformer', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  long_description_content_type = 'text/markdown', # long description content type
  url = 'https://github.com/lucidrains/quartic-transformer', # URL
  keywords = [ # keywords
    'artificial intelligence',
    'deep learning',
    'transformer',
    'attention'
  ],
  install_requires=[ # dependencies
    'colt5-attention',
    'einops>=0.7.0',
    'einx[torch]>=0.1.3',
    'taylor-series-linear-attention',
    'torch>=2.0',
    'x-transformers'
  ],
  classifiers=[ # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Recurrent Interface Network (RIN) - Pytorch

Implementation of Recurrent Interface Network (RIN), for highly efficient generation of images and video without cascading networks, in Pytorch. The author unwittingly reinvented the induced set-attention block from the set transformers paper. They also combine this with the self-conditioning technique from the Bit Diffusion paper, specifically for the latents. The last ingredient seems to be a new noise function based around the sigmoid, which the author claims is better than the cosine scheduler for larger images.

The big surprise is that the generations can reach this level of fidelity. Will need to verify this on my own machine.

Additionally, we will try adding an extra linear attention on the main branch as well as self conditioning in the pixel-space.

The two main findings are the ability to self-condition on any hidden state of the network and the newly proposed sigmoid noise schedule.

This repository also contains the ability to noise higher resolution images more, using the scale keyword argument on the GaussianDiffusion class. It also contains the simple linear gamma schedule proposed in that paper.

Appreciation

  • Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research

Install

$ pip install rin-pytorch

Usage

from rin_pytorch import GaussianDiffusion, RIN, Trainer

model = RIN(
    dim = 256,                  # model dimensions
    image_size = 128,           # image size
    patch_size = 8,             # patch size
    depth = 6,                  # depth
    num_latents = 128,          # number of latents. they used 256 in the paper
    dim_latent = 512,           # can be greater than the image dimension (dim) for greater capacity
    latent_self_attn_depth = 4, # number of latent self attention blocks per recurrent step, K in the paper
).cuda()

diffusion = GaussianDiffusion(
    model,
    timesteps = 400,
    train_prob_self_cond = 0.9,  # how often to self condition on latents
    scale = 1.                   # this will be set to < 1. for more noising and leads to better convergence when training on higher resolution images (512, 1024) - input noised images will be auto variance normalized
).cuda()

trainer = Trainer(
    diffusion,
    '/path/to/your/images',
    num_samples = 16,
    train_batch_size = 4,
    gradient_accumulate_every = 4,
    train_lr = 1e-4,
    save_and_sample_every = 1000,
    train_num_steps = 700000,         # total training steps
    ema_decay = 0.995,                # exponential moving average decay
)

trainer.train()

Results will be saved periodically to the ./results folder

If you would like to experiment with the RIN and GaussianDiffusion class outside the Trainer

import torch
from rin_pytorch import RIN, GaussianDiffusion

model = RIN(
    dim = 256,                  # model dimensions
    image_size = 128,           # image size
    patch_size = 8,             # patch size
    depth = 6,                  # depth
    num_latents = 128,          # number of latents. they used 256 in the paper
    latent_self_attn_depth = 4, # number of latent self attention blocks per recurrent step, K in the paper
).cuda()

diffusion = GaussianDiffusion(
    model,
    timesteps = 1000,
    train_prob_self_cond = 0.9,
    scale = 1.
)

training_images = torch.randn(8, 3, 128, 128).cuda() # images are normalized from 0 to 1
loss = diffusion(training_images)
loss.backward()
# after a lot of training

sampled_images = diffusion.sample(batch_size = 4)
sampled_images.shape # (4, 3, 128, 128)

Todo

Citations

@misc{jabri2022scalable,
    title   = {Scalable Adaptive Computation for Iterative Generation}, 
    author  = {Allan Jabri and David Fleet and Ting Chen},
    year    = {2022},
    eprint  = {2212.11972},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{Chen2023OnTI,
    title   = {On the Importance of Noise Scheduling for Diffusion Models},
    author  = {Ting Chen},
    year    = {2023}
}
@article{Salimans2022ProgressiveDF,
    title   = {Progressive Distillation for Fast Sampling of Diffusion Models},
    author  = {Tim Salimans and Jonathan Ho},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.00512}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
@inproceedings{Hang2023EfficientDT,
    title   = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
    author  = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
    year    = {2023}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@inproceedings{Hoogeboom2023simpleDE,
    title   = {simple diffusion: End-to-end diffusion for high resolution images},
    author  = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
    year    = {2023}
}

.\lucidrains\recurrent-interface-network-pytorch\rin_pytorch\attend.py

# import the required modules and classes
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# a namedtuple holding the FlashAttention kernel configuration
FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper to check whether a value exists
def exists(val):
    return val is not None

# decorator that makes sure the wrapped function only runs once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print function that only prints once
print_once = once(print)

# main Attend class
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine the efficient attention configs for cuda and cpu
        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(False, True, True)

    # flash attention path
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # expand the mask to a shape compatible with scaled_dot_product_attention
        if exists(mask):
            mask = mask.expand(-1, heads, q_len, -1)

        # pick the attention config depending on whether the input is on cuda
        config = self.cuda_config if is_cuda else self.cpu_config

        # dispatch to pytorch 2.0 flash attention through the sdp_kernel context
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    # forward
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')

        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask
        if exists(mask):
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out
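
A minimal sketch of calling Attend directly, with shapes following the einstein notation in the docstring:

import torch
from rin_pytorch.attend import Attend

attend = Attend(dropout = 0., flash = False)

q = torch.randn(1, 4, 16, 32)  # (batch, heads, seq, dim_head)
k = torch.randn(1, 4, 16, 32)
v = torch.randn(1, 4, 16, 32)

out = attend(q, k, v)
assert out.shape == (1, 4, 16, 32)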

.\lucidrains\recurrent-interface-network-pytorch\rin_pytorch\rin_pytorch.py

import math
from pathlib import Path
from random import random
from functools import partial
from multiprocessing import cpu_count

import torch
from torch import nn, einsum
from torch.special import expm1
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torch.optim import Adam
from torchvision import transforms as T, utils

from beartype import beartype

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

from rin_pytorch.attend import Attend

from PIL import Image
from tqdm.auto import tqdm
from ema_pytorch import EMA

from accelerate import Accelerator, DistributedDataParallelKwargs

# helper functions

# check whether a value exists
def exists(x):
    return x is not None

# identity function
def identity(x):
    return x

# return the value if it exists, otherwise the default (calling it if callable)
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# check whether numer is divisible by denom
def divisible_by(numer, denom):
    return (numer % denom) == 0

# division that is safe against a zero denominator
def safe_div(numer, denom, eps = 1e-10):
    return numer / denom.clamp(min = eps)

# endlessly cycle over a dataloader
def cycle(dl):
    while True:
        for data in dl:
            yield data

# check whether a number has an integer square root
def has_int_squareroot(num):
    num_sqrt = math.sqrt(num)
    return int(num_sqrt) == num_sqrt

# split a number into groups of at most divisor
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

# convert an image to the given mode if it is not already
def convert_image_to(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

# Sequential that filters out non-existent modules
def Sequential(*mods):
    return nn.Sequential(*filter(exists, mods))

# use layernorm without bias, more stable

# LayerNorm with a learned gain and no bias
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# multi-headed RMSNorm
class MultiHeadedRMSNorm(nn.Module):
    def __init__(self, dim, heads = 1):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

# positional embeds

# learned sinusoidal positional embedding
class LearnedSinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered
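
A quick shape check: for dim = 16 the embedding concatenates the raw time with 8 sine and 8 cosine features, which is why fourier_dim is learned_sinusoidal_dim + 1 further down.

import torch

pos_emb = LearnedSinusoidalPosEmb(16)

t = torch.rand(8)                   # diffusion times in [0, 1]
assert pos_emb(t).shape == (8, 17)  # raw time + 8 sin + 8 cos features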

# linear attention
class LinearAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32,
        norm = False,
        qk_norm = False,
        time_cond_dim = None
    ):
        super().__init__()
        hidden_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.time_cond = None

        if exists(time_cond_dim):
            self.time_cond = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim * 2),
                Rearrange('b d -> b 1 d')
            )

            nn.init.zeros_(self.time_cond[-2].weight)
            nn.init.zeros_(self.time_cond[-2].bias)

        self.norm = LayerNorm(dim) if norm else nn.Identity()

        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias = False)

        self.qk_norm = qk_norm
        if qk_norm:
            self.q_norm = MultiHeadedRMSNorm(dim_head, heads)
            self.k_norm = MultiHeadedRMSNorm(dim_head, heads)

        self.to_out = nn.Sequential(
            nn.Linear(hidden_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(
        self,
        x,
        time = None
    ):
        # number of attention heads
        h = self.heads
        # normalize the input
        x = self.norm(x)

        # if time conditioning is present
        if exists(self.time_cond):
            # time must be provided
            assert exists(time)
            # derive the scale and shift from the time condition and apply them
            scale, shift = self.time_cond(time).chunk(2, dim = -1)
            x = (x * (scale + 1)) + shift

        # project to queries, keys, values and split into heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # optionally normalize the queries and keys
        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # softmax the queries along the feature dimension and the keys along the sequence dimension
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        # scale the queries
        q = q * self.scale

        # aggregate the global context
        context = torch.einsum('b h n d, b h n e -> b h d e', k, v)

        # compute the output from the context and the queries
        out = torch.einsum('b h d e, b h n d -> b h n e', context, q)
        # merge the heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        # final output projection
        return self.to_out(out)
# attention module
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_context = None,
        heads = 4,
        dim_head = 32,
        norm = False,
        norm_context = False,
        time_cond_dim = None,
        flash = False,
        qk_norm = False
    ):
        super().__init__()
        hidden_dim = dim_head * heads
        dim_context = default(dim_context, dim)

        self.time_cond = None

        # if a time conditioning dimension is given, build the time conditioning module
        if exists(time_cond_dim):
            self.time_cond = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim * 2),
                Rearrange('b d -> b 1 d')
            )

            nn.init.zeros_(self.time_cond[-2].weight)
            nn.init.zeros_(self.time_cond[-2].bias)

        self.scale = dim_head ** -0.5
        self.heads = heads

        # LayerNorm or identity, depending on whether normalization is requested
        self.norm = LayerNorm(dim) if norm else nn.Identity()
        self.norm_context = LayerNorm(dim_context) if norm_context else nn.Identity()

        # linear projections
        self.to_q = nn.Linear(dim, hidden_dim, bias = False)
        self.to_kv = nn.Linear(dim_context, hidden_dim * 2, bias = False)
        self.to_out = nn.Linear(hidden_dim, dim, bias = False)

        self.qk_norm = qk_norm
        # if the queries and keys are to be normalized, build the MultiHeadedRMSNorm modules
        if qk_norm:
            self.q_norm = MultiHeadedRMSNorm(dim_head, heads)
            self.k_norm = MultiHeadedRMSNorm(dim_head, heads)

        # the Attend module handles flash / regular attention
        self.attend = Attend(flash = flash)

    def forward(
        self,
        x,
        context = None,
        time = None
    ):
        h = self.heads

        # normalize the context if it is given
        if exists(context):
            context = self.norm_context(context)

        x = self.norm(x)

        context = default(context, x)

        # apply time conditioning if present
        if exists(self.time_cond):
            assert exists(time)
            scale, shift = self.time_cond(time).chunk(2, dim = -1)
            x = (x * (scale + 1)) + shift

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        out = self.attend(q, k, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# positional encoding generator (PEG)
class PEG(nn.Module):
    def __init__(
        self,
        dim
    ):
        super().__init__()
        # depthwise convolution
        self.ds_conv = nn.Conv2d(dim, dim, 3, padding = 1, groups = dim)

    def forward(self, x):
        b, n, d = x.shape
        hw = int(math.sqrt(n))
        x = rearrange(x, 'b (h w) d -> b d h w', h = hw)
        x = self.ds_conv(x)
        x = rearrange(x, 'b d h w -> b (h w) d')
        return x

# feedforward
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, time_cond_dim = None):
        super().__init__()
        self.norm = LayerNorm(dim)

        self.time_cond = None

        # if a time conditioning dimension is given, build the time conditioning module
        if exists(time_cond_dim):
            self.time_cond = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim * 2),
                Rearrange('b d -> b 1 d')
            )

            nn.init.zeros_(self.time_cond[-2].weight)
            nn.init.zeros_(self.time_cond[-2].bias)

        inner_dim = int(dim * mult)
        # the feedforward network itself
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Linear(inner_dim, dim)
        )

    def forward(self, x, time = None):
        x = self.norm(x)

        if exists(self.time_cond):
            assert exists(time)
            scale, shift = self.time_cond(time).chunk(2, dim = -1)
            x = (x * (scale + 1)) + shift

        return self.net(x)

# RINBlock
class RINBlock(nn.Module):
    def __init__(
        self,
        dim,
        latent_self_attn_depth,
        dim_latent = None,
        final_norm = True,
        patches_self_attn = True,
        **attn_kwargs
    ):
        super().__init__()
        # default the latent dimension to the patch dimension
        dim_latent = default(dim_latent, dim)

        # cross attention: the latents attend to the patches
        self.latents_attend_to_patches = Attention(dim_latent, dim_context = dim, norm = True, norm_context = True, **attn_kwargs)
        # feedforward following the latent cross attention
        self.latents_cross_attn_ff = FeedForward(dim_latent)

        # latent self attention layers
        self.latent_self_attns = nn.ModuleList([])
        for _ in range(latent_self_attn_depth):
            self.latent_self_attns.append(nn.ModuleList([
                Attention(dim_latent, norm = True, **attn_kwargs),
                FeedForward(dim_latent)
            ]))

        # final normalization of the latents
        self.latent_final_norm = LayerNorm(dim_latent) if final_norm else nn.Identity()

        # positional encoding generator for the patches
        self.patches_peg = PEG(dim)
        self.patches_self_attn = patches_self_attn

        # if patch self attention is enabled, the flag is replaced by the module itself
        if patches_self_attn:
            self.patches_self_attn = LinearAttention(dim, norm = True, **attn_kwargs)
            self.patches_self_attn_ff = FeedForward(dim)

        # cross attention: the patches attend to the latents, followed by a feedforward
        self.patches_attend_to_latents = Attention(dim, dim_context = dim_latent, norm = True, norm_context = True, **attn_kwargs)
        self.patches_cross_attn_ff = FeedForward(dim)

    def forward(self, patches, latents, t):
        # apply the positional encoding generator to the patches
        patches = self.patches_peg(patches) + patches

        # the latents extract or cluster information from the patches
        latents = self.latents_attend_to_patches(latents, patches, time = t) + latents

        # feedforward after the latent cross attention
        latents = self.latents_cross_attn_ff(latents, time = t) + latents

        # latent self attention
        for attn, ff in self.latent_self_attns:
            latents = attn(latents, time = t) + latents
            latents = ff(latents, time = t) + latents

        # optional additional self attention over the patches
        if self.patches_self_attn:
            patches = self.patches_self_attn(patches, time = t) + patches
            patches = self.patches_self_attn_ff(patches) + patches

        # the patches attend to the latents
        patches = self.patches_attend_to_latents(patches, latents, time = t) + patches

        # feedforward after the patch cross attention
        patches = self.patches_cross_attn_ff(patches, time = t) + patches

        # final normalization of the latents
        latents = self.latent_final_norm(latents)
        return patches, latents
# RIN (Recurrent Interface Network)
class RIN(nn.Module):
    def __init__(
        self,
        dim,
        image_size,
        patch_size = 16,
        channels = 3,
        depth = 6,                      # number of RIN blocks
        latent_self_attn_depth = 2,     # number of latent self attentions per round of cross attention from pixel space to latents
        dim_latent = None,              # latent dimension, defaults to the image dimension (dim)
        num_latents = 256,              # a fair number of latents (256) is still needed for good results, in line with Deepmind's Perceiver line of papers
        learned_sinusoidal_dim = 16,
        latent_token_time_cond = False, # whether to use one latent token as the time condition, or adaptive layernorm (as shown in other papers, e.g. "Paella" - Dominic Rampas et al.)
        dual_patchnorm = True,
        patches_self_attn = True,       # the self attention in this repository does not strictly follow the design proposed in the paper; provide a way to remove it, in case it is the source of instability
        **attn_kwargs
    ):
        # call the parent constructor
        super().__init__()
        # the image size must be divisible by the patch size
        assert divisible_by(image_size, patch_size)
        # default the latent dimension to the model dimension
        dim_latent = default(dim_latent, dim)

        # store the image size and channel count (the input channels are doubled for self-conditioning in to_patches below)
        self.image_size = image_size
        self.channels = channels

        # number of patches in the image and the dimension of each pixel patch
        patch_height_width = image_size // patch_size
        num_patches = patch_height_width ** 2
        pixel_patch_dim = channels * (patch_size ** 2)

        # time conditioning

        # learned sinusoidal positional embedding for the time
        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
        time_dim = dim * 4
        fourier_dim = learned_sinusoidal_dim + 1

        self.latent_token_time_cond = latent_token_time_cond
        time_output_dim = dim_latent if latent_token_time_cond else time_dim

        # time MLP
        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_output_dim)
        )

        # pixels to patches and back

        self.to_patches = Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(pixel_patch_dim * 2) if dual_patchnorm else None,
            nn.Linear(pixel_patch_dim * 2, dim),
            nn.LayerNorm(dim) if dual_patchnorm else None,
        )

        # axial positional embeddings, parameterized by an MLP

        pos_emb_dim = dim // 2

        self.axial_pos_emb_height_mlp = nn.Sequential(
            Rearrange('... -> ... 1'),
            nn.Linear(1, pos_emb_dim),
            nn.SiLU(),
            nn.Linear(pos_emb_dim, pos_emb_dim),
            nn.SiLU(),
            nn.Linear(pos_emb_dim, dim)
        )

        self.axial_pos_emb_width_mlp = nn.Sequential(
            Rearrange('... -> ... 1'),
            nn.Linear(1, pos_emb_dim),
            nn.SiLU(),
            nn.Linear(pos_emb_dim, pos_emb_dim),
            nn.SiLU(),
            nn.Linear(pos_emb_dim, dim)
        )

        # nn.Parameter(torch.randn(2, patch_height_width, dim) * 0.02)

        self.to_pixels = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, pixel_patch_dim),
            Rearrange('b (h w) (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size, h = patch_height_width)
        )

        # initialize the latents
        self.latents = nn.Parameter(torch.randn(num_latents, dim_latent))
        nn.init.normal_(self.latents, std = 0.02)

        self.init_self_cond_latents = nn.Sequential(
            FeedForward(dim_latent),
            LayerNorm(dim_latent)
        )

        nn.init.zeros_(self.init_self_cond_latents[-1].gamma)

        # the main RIN body parameters - another attention-is-all-you-need moment

        if not latent_token_time_cond:
            attn_kwargs = {**attn_kwargs, 'time_cond_dim': time_dim}

        # create the list of RINBlock modules
        self.blocks = nn.ModuleList([RINBlock(dim, dim_latent = dim_latent, latent_self_attn_depth = latent_self_attn_depth, patches_self_attn = patches_self_attn, **attn_kwargs) for _ in range(depth)])

    @property
    def device(self):
        # device of the model parameters
        return next(self.parameters()).device

    def forward(
        self,
        x,
        time,
        x_self_cond = None,
        latent_self_cond = None,
        return_latents = False
    ):
        # batch size of the input
        batch = x.shape[0]

        # if no pixel self-condition is given, use zeros
        x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))

        # concatenate the self-condition with the input along the channel dimension
        x = torch.cat((x_self_cond, x), dim = 1)

        # prepare the time condition
        t = self.time_mlp(time)

        # prepare the latents
        latents = repeat(self.latents, 'n d -> b n d', b = batch)

        # initialize the latents from the self-conditioned latents, as done in the paper
        if exists(latent_self_cond):
            latents = latents + self.init_self_cond_latents(latent_self_cond)

        # the time condition is either a latent token or the scale/shift for adaptive layernorm
        if self.latent_token_time_cond:
            t = rearrange(t, 'b d -> b 1 d')
            latents = torch.cat((latents, t), dim = -2)

        # convert the input to patches
        patches = self.to_patches(x)

        # height and width coordinate ranges
        height_range = width_range = torch.linspace(0., 1., steps = int(math.sqrt(patches.shape[-2])), device = self.device)
        pos_emb_h, pos_emb_w = self.axial_pos_emb_height_mlp(height_range), self.axial_pos_emb_width_mlp(width_range)

        # axial positional embedding
        pos_emb = rearrange(pos_emb_h, 'i d -> i 1 d') + rearrange(pos_emb_w, 'j d -> 1 j d')
        patches = patches + rearrange(pos_emb, 'i j d -> (i j) d')

        # run through each block of the recurrent interface network
        for block in self.blocks:
            patches, latents = block(patches, latents, t)

        # convert the patches back to pixels
        pixels = self.to_pixels(patches)

        # return just the pixels if the latents are not requested
        if not return_latents:
            return pixels

        # if the time condition was a latent token, remove it
        if self.latent_token_time_cond:
            latents = latents[:, :-1]

        # return the pixels and the latents
        return pixels, latents
# normalize an image from [0, 1] to [-1, 1]
def normalize_img(x):
    return x * 2 - 1

# undo the image normalization
def unnormalize_img(x):
    return (x + 1) * 0.5

# normalize the variance of a noised image, used when the scale is not 1
def normalize_img_variance(x, eps = 1e-5):
    std = reduce(x, 'b c h w -> b 1 1 1', partial(torch.std, unbiased = False))
    return x / std.clamp(min = eps)

# clamped natural log
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# pad the dimensions of t on the right to match x
def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))
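
For example, a per-batch scalar gets the trailing singleton dimensions it needs to broadcast against an image batch:

import torch

x = torch.randn(2, 3, 8, 8)  # image batch
t = torch.randn(2)           # one value per batch element

assert right_pad_dims_to(x, t).shape == (2, 1, 1, 1)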

# simple linear gamma schedule
def simple_linear_schedule(t, clip_min = 1e-9):
    return (1 - t).clamp(min = clip_min)

# cosine gamma schedule
def cosine_schedule(t, start = 0, end = 1, tau = 1, clip_min = 1e-9):
    power = 2 * tau
    v_start = math.cos(start * math.pi / 2) ** power
    v_end = math.cos(end * math.pi / 2) ** power
    output = torch.cos((t * (end - start) + start) * math.pi / 2) ** power
    output = (v_end - output) / (v_end - v_start)
    return output.clamp(min = clip_min)

# sigmoid gamma schedule
def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-9):
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    gamma = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
    return gamma.clamp_(min = clamp_min, max = 1.)
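
All three schedules map t in [0, 1] to a gamma that falls from about 1 (no noise) to about 0 (pure noise). For instance, with the default sigmoid schedule:

import torch

t = torch.linspace(0., 1., 5)
print(sigmoid_schedule(t))
# ~ tensor([1.0000, 0.8509, 0.5000, 0.1491, 0.0000]) - monotonically decreasing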

# convert gamma to alpha and sigma
def gamma_to_alpha_sigma(gamma, scale = 1):
    return torch.sqrt(gamma) * scale, torch.sqrt(1 - gamma)

# convert gamma to log signal-to-noise ratio
def gamma_to_log_snr(gamma, scale = 1, eps = 1e-5):
    return log(gamma * (scale ** 2) / (1 - gamma), eps = eps)

# gaussian diffusion class
@beartype
class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        model: RIN,
        *,
        timesteps = 1000,
        use_ddim = True,
        noise_schedule = 'sigmoid',
        objective = 'v',
        schedule_kwargs: dict = dict(),
        time_difference = 0.,
        min_snr_loss_weight = True,
        min_snr_gamma = 5,
        train_prob_self_cond = 0.9,
        scale = 1.                      # this will be set to < 1. for better convergence when training on higher resolution images
    ):
        super().__init__()
        self.model = model
        self.channels = self.model.channels

        assert objective in {'x0', 'eps', 'v'}, 'objective must be one of x0 (predict start image), eps (predict noise), or v (predict velocity)'
        self.objective = objective

        self.image_size = model.image_size

        if noise_schedule == "linear":
            self.gamma_schedule = simple_linear_schedule
        elif noise_schedule == "cosine":
            self.gamma_schedule = cosine_schedule
        elif noise_schedule == "sigmoid":
            self.gamma_schedule = sigmoid_schedule
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        assert scale <= 1, 'scale must be less than or equal to 1'
        self.scale = scale
        self.maybe_normalize_img_variance = normalize_img_variance if scale < 1 else identity

        self.gamma_schedule = partial(self.gamma_schedule, **schedule_kwargs)

        self.timesteps = timesteps
        self.use_ddim = use_ddim

        self.time_difference = time_difference

        self.train_prob_self_cond = train_prob_self_cond

        self.min_snr_loss_weight = min_snr_loss_weight
        self.min_snr_gamma = min_snr_gamma

    @property
    def device(self):
        return next(self.model.parameters()).device
    # get the sampling timesteps
    def get_sampling_timesteps(self, batch, *, device):
        # evenly spaced times from 1 to 0, with self.timesteps + 1 points
        times = torch.linspace(1., 0., self.timesteps + 1, device = device)
        # repeat the times across the batch
        times = repeat(times, 't -> b t', b = batch)
        # pair up adjacent times
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times
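
    # for example, with timesteps = 4 and batch = 1, times is [1.00, 0.75, 0.50, 0.25, 0.00],
    # and the pairing yields (1.00, 0.75), (0.75, 0.50), (0.50, 0.25), (0.25, 0.00) -
    # these are the (time, time_next) pairs the sampling loops below iterate over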

    # sampling is done without gradients
    @torch.no_grad()
    def ddpm_sample(self, shape, time_difference = None):
        batch, device = shape[0], self.device

        # set the time difference
        time_difference = default(time_difference, self.time_difference)

        # get the pairs of sampling times
        time_pairs = self.get_sampling_timesteps(batch, device = device)

        # start from random noise
        img = torch.randn(shape, device = device)

        x_start = None
        last_latents = None

        # iterate over the time pairs
        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.timesteps):

            # add the time delay
            time_next = (time_next - time_difference).clamp(min = 0.)

            noise_cond = time

            # get the predicted x0
            maybe_normalized_img = self.maybe_normalize_img_variance(img)
            model_output, last_latents = self.model(maybe_normalized_img, noise_cond, x_start, last_latents, return_latents = True)

            # get gamma (the noise level) for this step and the next
            gamma = self.gamma_schedule(time)
            gamma_next = self.gamma_schedule(time_next)
            gamma, gamma_next = map(partial(right_pad_dims_to, img), (gamma, gamma_next))

            # get alpha and sigma
            alpha, sigma = gamma_to_alpha_sigma(gamma)
            alpha_next, sigma_next = gamma_to_alpha_sigma(gamma_next)

            # derive x0, depending on the objective
            if self.objective == 'x0':
                x_start = model_output
            elif self.objective == 'eps':
                x_start = safe_div(img - sigma * model_output, alpha)
            elif self.objective == 'v':
                x_start = alpha * img - sigma * model_output

            # clip x0 to [-1, 1]
            x_start.clamp_(-1., 1.)

            # derive the posterior mean and variance
            log_snr, log_snr_next = map(gamma_to_log_snr, (gamma, gamma_next))
            c = -expm1(log_snr - log_snr_next)
            mean = alpha_next * (img * (1 - c) / alpha + c * x_start)
            variance = (sigma_next ** 2) * c
            log_variance = log(variance)

            # sample noise, except at the final step
            noise = torch.where(
                rearrange(time_next > 0, 'b -> b 1 1 1'),
                torch.randn_like(img),
                torch.zeros_like(img)
            )

            img = mean + (0.5 * log_variance).exp() * noise

        return unnormalize_img(img)

    @torch.no_grad()
    def ddim_sample(self, shape, time_difference = None):
        # get the batch size and device from the shape
        batch, device = shape[0], self.device

        # set the time difference
        time_difference = default(time_difference, self.time_difference)

        # get the pairs of sampling times
        time_pairs = self.get_sampling_timesteps(batch, device = device)

        # start from random noise
        img = torch.randn(shape, device = device)

        x_start = None
        last_latents = None

        # iterate over the time pairs
        for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):

            # get the noise levels for this step and the next
            gamma = self.gamma_schedule(times)
            gamma_next = self.gamma_schedule(times_next)

            # pad the noise levels to the dimensions of the image
            padded_gamma, padded_gamma_next = map(partial(right_pad_dims_to, img), (gamma, gamma_next))

            # convert the noise levels to alpha and sigma
            alpha, sigma = gamma_to_alpha_sigma(padded_gamma)
            alpha_next, sigma_next = gamma_to_alpha_sigma(padded_gamma_next)

            # add the time delay
            times_next = (times_next - time_difference).clamp(min = 0.)

            # predict x0
            maybe_normalized_img = self.maybe_normalize_img_variance(img)
            model_output, last_latents = self.model(maybe_normalized_img, times, x_start, last_latents, return_latents = True)

            # derive x0, depending on the objective
            if self.objective == 'x0':
                x_start = model_output
            elif self.objective == 'eps':
                x_start = safe_div(img - sigma * model_output, alpha)
            elif self.objective == 'v':
                x_start = alpha * img - sigma * model_output

            # clip x0 to [-1, 1]
            x_start.clamp_(-1., 1.)

            # get the predicted noise
            pred_noise = safe_div(img - alpha * x_start, sigma)

            # compute the next image
            img = x_start * alpha_next + pred_noise * sigma_next

        # return the unnormalized image
        return unnormalize_img(img)

    @torch.no_grad()
    def sample(self, batch_size = 16):
        image_size, channels = self.image_size, self.channels
        # choose the sampling function depending on whether DDIM is used
        sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample
        return sample_fn((batch_size, channels, image_size, image_size))
    # forward pass, taking the image and computing the diffusion loss
    def forward(self, img, *args, **kwargs):
        # unpack the image shape and device
        batch, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        # the image height and width must match the configured image size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'

        # sample random times
        times = torch.zeros((batch,), device = device).float().uniform_(0, 1.)

        # normalize the image to [-1, 1]
        img = normalize_img(img)

        # sample noise
        noise = torch.randn_like(img)

        # compute gamma
        gamma = self.gamma_schedule(times)
        padded_gamma = right_pad_dims_to(img, gamma)
        alpha, sigma = gamma_to_alpha_sigma(padded_gamma, self.scale)

        # noise the image
        noised_img = alpha * img + sigma * noise

        # maybe normalize the variance of the noised image
        noised_img = self.maybe_normalize_img_variance(noised_img)

        # in the paper, they have to use a really high probability of latent self-conditioning,
        # up to 90% of the time - a slight drawback
        self_cond = self_latents = None

        if random() < self.train_prob_self_cond:
            with torch.no_grad():
                model_output, self_latents = self.model(noised_img, times, return_latents = True)
                self_latents = self_latents.detach()

                if self.objective == 'x0':
                    self_cond = model_output

                elif self.objective == 'eps':
                    self_cond = safe_div(noised_img - sigma * model_output, alpha)

                elif self.objective == 'v':
                    self_cond = alpha * noised_img - sigma * model_output

                self_cond.clamp_(-1., 1.)
                self_cond = self_cond.detach()

        # predict and take the gradient step
        pred = self.model(noised_img, times, self_cond, self_latents)

        if self.objective == 'eps':
            target = noise

        elif self.objective == 'x0':
            target = img

        elif self.objective == 'v':
            target = alpha * noise - sigma * img

        # compute the loss
        loss = F.mse_loss(pred, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        # min-snr loss weighting
        snr = (alpha * alpha) / (sigma * sigma)
        maybe_clipped_snr = snr.clone()

        if self.min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = self.min_snr_gamma)

        if self.objective == 'eps':
            loss_weight = maybe_clipped_snr / snr

        elif self.objective == 'x0':
            loss_weight = maybe_clipped_snr

        elif self.objective == 'v':
            loss_weight = maybe_clipped_snr / (snr + 1)

        return (loss * loss_weight).mean()
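
To see what the min-SNR weighting does for the default 'v' objective, a small numeric sketch with min_snr_gamma = 5:

import torch

snr = torch.tensor([0.1, 1., 5., 50.])
clipped = snr.clamp(max = 5)

print(clipped / (snr + 1))
# tensor([0.0909, 0.5000, 0.8333, 0.0980]) - nearly-clean (very high SNR) steps are down-weighted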
# dataset classes

# Dataset, subclassing torch.utils.data.Dataset
class Dataset(Dataset):
    def __init__(
        self,
        folder,  # path to the dataset folder
        image_size,  # image size
        exts = ['jpg', 'jpeg', 'png', 'tiff'],  # image file extensions
        augment_horizontal_flip = False,  # whether to augment with horizontal flips
        convert_image_to = None  # optional image mode conversion
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # gather all file paths in the folder with the given extensions
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        # partially apply the conversion function, if given
        maybe_convert_fn = partial(convert_image_to, convert_image_to) if exists(convert_image_to) else nn.Identity()

        # image transform pipeline
        self.transform = T.Compose([
            T.Lambda(maybe_convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    # length of the dataset
    def __len__(self):
        return len(self.paths)

    # get the item at the given index
    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# trainer class

@beartype
class Trainer(object):
    def __init__(
        self,
        diffusion_model: GaussianDiffusion,  # the diffusion model
        folder,  # path to the dataset folder
        *,
        train_batch_size = 16,  # training batch size
        gradient_accumulate_every = 1,  # gradient accumulation steps
        augment_horizontal_flip = True,  # whether to augment with horizontal flips
        train_lr = 1e-4,  # training learning rate
        train_num_steps = 100000,  # number of training steps
        max_grad_norm = 1.,  # gradient clipping threshold
        ema_update_every = 10,  # EMA update frequency
        ema_decay = 0.995,  # EMA decay rate
        betas = (0.9, 0.99),  # betas for the Adam optimizer
        save_and_sample_every = 1000,  # how often to save and sample
        num_samples = 25,  # number of samples
        results_folder = './results',  # folder for results
        amp = False,  # whether to use mixed precision training
        mixed_precision_type = 'fp16',  # mixed precision type
        split_batches = True,  # whether to split batches
        convert_image_to = None  # optional image mode conversion
    ):
        super().__init__()

        # initialize the accelerator
        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = mixed_precision_type if amp else 'no',
            kwargs_handlers = [DistributedDataParallelKwargs(find_unused_parameters=True)]
        )

        # the diffusion model
        self.model = diffusion_model

        # the number of samples must have an integer square root so they can be tiled into a grid
        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every

        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.max_grad_norm = max_grad_norm

        self.train_num_steps = train_num_steps
        self.image_size = diffusion_model.image_size

        # dataset and dataloader

        self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to)
        dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())

        # prepare the dataloader and cycle over it endlessly
        dl = self.accelerator.prepare(dl)
        self.dl = cycle(dl)

        # optimizer

        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = betas)

        # periodically log results to the results folder

        self.results_folder = Path(results_folder)

        if self.accelerator.is_local_main_process:
            self.results_folder.mkdir(exist_ok = True)

        if self.accelerator.is_main_process:
            self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)

        # step counter state

        self.step = 0

        # prepare the model and optimizer with the accelerator

        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

    # save the model at the given milestone
    def save(self, milestone):
        if not self.accelerator.is_local_main_process:
            return

        data = {
            'step': self.step + 1,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
        }

        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
    # 加载指定里程碑的模型数据
    def load(self, milestone):
        # 从文件中加载模型数据
        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))

        # 获取未加速的模型对象
        model = self.accelerator.unwrap_model(self.model)
        # 加载模型的状态字典
        model.load_state_dict(data['model'])

        # 设置当前训练步数
        self.step = data['step']
        # 加载优化器的状态字典
        self.opt.load_state_dict(data['opt'])

        # 如果是主进程,则加载指数移动平均模型的状态字典
        if self.accelerator.is_main_process:
            self.ema.load_state_dict(data['ema'])

        # 如果加速器和数据中都存在缩放器状态字典,则加载缩放器的状态字典
        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    # 训练模型
    def train(self):
        # 获取加速器和设备
        accelerator = self.accelerator
        device = accelerator.device

        # 使用 tqdm 显示训练进度
        with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:

            # loop until the step budget is exhausted
            while self.step < self.train_num_steps:

                total_loss = 0.

                # run the forward / backward pass once per gradient accumulation step
                for _ in range(self.gradient_accumulate_every):
                    # fetch the next batch and move it to the device
                    data = next(self.dl).to(device)

                    # compute the loss under automatic mixed precision
                    with accelerator.autocast():
                        loss = self.model(data)
                        loss = loss / self.gradient_accumulate_every
                        total_loss += loss.item()

                    # backpropagate
                    accelerator.backward(loss)

                # show the current loss on the progress bar
                pbar.set_description(f'loss: {total_loss:.4f}')

                # wait for all processes to finish the step
                accelerator.wait_for_everyone()
                # clip gradients
                accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                # optimizer update
                self.opt.step()
                # clear gradients
                self.opt.zero_grad()

                # wait for all processes again
                accelerator.wait_for_everyone()

                # save milestones on every local main process, but sample only on the global main process
                if accelerator.is_local_main_process:
                    milestone = self.step // self.save_and_sample_every
                    save_and_sample = self.step != 0 and self.step % self.save_and_sample_every == 0

                    if accelerator.is_main_process:
                        # move the EMA model to the device
                        self.ema.to(device)
                        # update the EMA model
                        self.ema.update()

                        if save_and_sample:
                            # put the EMA model in eval mode
                            self.ema.ema_model.eval()

                            with torch.no_grad():
                                # split the requested number of samples into batches and sample
                                batches = num_to_groups(self.num_samples, self.batch_size)
                                all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))

                            all_images = torch.cat(all_images_list, dim = 0)
                            # save the sampled images as a grid
                            utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))

                    if save_and_sample:
                        # save a checkpoint for this milestone
                        self.save(milestone)

                # advance the step counter and the progress bar
                self.step += 1
                pbar.update(1)

        # done
        accelerator.print('training complete')
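
# a minimal end-to-end usage sketch for this Trainer (hedged illustration:
# the RIN / GaussianDiffusion hyperparameters below are assumptions drawn
# from the rin_pytorch README, not values taken from this file)
#
# from rin_pytorch import GaussianDiffusion, RIN, Trainer
#
# model = RIN(dim = 256, image_size = 128, patch_size = 8, depth = 6)
# diffusion = GaussianDiffusion(model, timesteps = 400)
#
# trainer = Trainer(
#     diffusion,
#     '/path/to/images',
#     train_batch_size = 4,
#     gradient_accumulate_every = 4,
#     train_lr = 1e-4,
#     train_num_steps = 700000,
#     save_and_sample_every = 1000
# )
#
# trainer.train()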

.\lucidrains\recurrent-interface-network-pytorch\rin_pytorch\__init__.py

# import the GaussianDiffusion, RIN and Trainer classes from the rin_pytorch.rin_pytorch module
from rin_pytorch.rin_pytorch import GaussianDiffusion, RIN, Trainer

.\lucidrains\recurrent-interface-network-pytorch\setup.py

# import setup and package-finding helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'RIN-pytorch',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.7.10',  # version
  license='MIT',  # license
  description = 'RIN - Recurrent Interface Network - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/RIN-pytorch',  # URL
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'attention mechanism',
    'denoising diffusion',
    'image and video generation'
  ],
  install_requires=[  # dependencies
    'accelerate',
    'beartype',
    'ema-pytorch',
    'einops>=0.6',
    'pillow',
    'torch>=1.12.0',
    'torchvision',
    'tqdm'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

Recurrent Memory Transformer - Pytorch

Implementation of Recurrent Memory Transformer (openreview) in Pytorch. They had a short follow-up paper recently that demonstrated it was able to copy information across at least 1 million tokens.

There is no doubt in my mind that RMT would make a stronger RL agent than AdA, which is just a Transformer-XL - Update: Recurrent Memory Decision Transformer

Yannic Kilcher paper review

Appreciation

  • Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research

Install

$ pip install recurrent-memory-transformer-pytorch

Usage

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 20000,               # number of tokens
    num_memory_tokens = 128,          # number of memory tokens, this will determine the bottleneck for information being passed to the future
    dim = 512,                        # model dimensions
    depth = 6,                        # transformer depth
    causal = True,                    # autoregressive or not
    dim_head = 64,                    # dimension per head
    heads = 8,                        # heads
    seq_len = 1024,                   # sequence length of a segment
    use_flash_attn = True             # whether to use flash attention
)

x = torch.randint(0, 256, (1, 1024))

logits1, mem1, _ = model(x)        # (1, 1024, 20000), (1, 128, 512), None
logits2, mem2, _ = model(x, mem1)  # (1, 1024, 20000), (1, 128, 512), None
logits3, mem3, _ = model(x, mem2)  # (1, 1024, 20000), (1, 128, 512), None

# and so on ...

With XL memories

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 20000,
    num_memory_tokens = 128,
    dim = 512,
    depth = 6,
    causal = True,
    dim_head = 64,
    heads = 8,
    seq_len = 1024,
    use_flash_attn = True,
    use_xl_memories = True,    # set this to True
    xl_mem_len = 512           # can be shorter than the seq len - i think having just a bit of the past will keep the RMT memories from memorizing the immediately preceding text
)

x = torch.randint(0, 256, (1, 1024))

logits1, mem1, xl_mem1 = model(x)                               # (1, 1024, 20000), (1, 128, 512), [(2, 1, 512, 512)]
logits2, mem2, xl_mem2 = model(x, mem1, xl_memories = xl_mem1)  # (1, 1024, 20000), (1, 128, 512), [(2, 1, 512, 512)]
logits3, mem3, xl_mem3 = model(x, mem2, xl_memories = xl_mem2)  # (1, 1024, 20000), (1, 128, 512), [(2, 1, 512, 512)]

# and so on ...

Train on an absurdly long sequence

import torch
from recurrent_memory_transformer_pytorch import (
    RecurrentMemoryTransformer,
    RecurrentMemoryTransformerWrapper
)

model = RecurrentMemoryTransformer(
    num_tokens = 256,
    num_memory_tokens = 128,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    use_flash_attn = True,
    causal = True
)

model = RecurrentMemoryTransformerWrapper(model).cuda()

seq = torch.randint(0, 256, (4, 65536)).cuda()   # absurdly long sequence, in reality, they curriculum learned this starting with 1 segment to about 7-8 segments

loss = model(seq, memory_replay_backprop = True) # memory efficient training from memformer paper
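
Note that with memory_replay_backprop = True the backward pass happens inside the wrapper, so a training step only needs gradient clipping and an optimizer update. A minimal sketch (the optimizer and clipping value here are illustrative choices, not prescribed by the library):

from torch.optim import Adam

optim = Adam(model.parameters(), lr = 1e-4)

loss = model(seq, memory_replay_backprop = True)  # backward is already called internally

torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()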

Todo

Alternatives

Citations

@inproceedings{bulatov2022recurrent,
  title     = {Recurrent Memory Transformer},
  author    = {Aydar Bulatov and Yuri Kuratov and Mikhail Burtsev},
  booktitle = {Advances in Neural Information Processing Systems},
  editor    = {Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
  year      = {2022},
  url       = {https://openreview.net/forum?id=Uynr3iPhksa}
}
@misc{bulatov2023scaling,
  title     = {Scaling Transformer to 1M tokens and beyond with RMT},
  author    = {Aydar Bulatov and Yuri Kuratov and Mikhail S. Burtsev},
  year      = {2023},
  eprint    = {2304.11062},
  archivePrefix = {arXiv},
  primaryClass = {cs.CL}
}
@inproceedings{dao2022flashattention,
  title     = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
  author    = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  booktitle = {Advances in Neural Information Processing Systems},
  year      = {2022}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{Wu2020MemformerAM,
    title   = {Memformer: A Memory-Augmented Transformer for Sequence Modeling},
    author  = {Qingyang Wu and Zhenzhong Lan and Kun Qian and Jing Gu and Alborz Geramifard and Zhou Yu},
    booktitle = {AACL/IJCNLP},
    year    = {2020}
}
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@software{Dayma_DALLE_Mini_2021,
    author  = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
    doi     = {10.5281/zenodo.5146400},
    license = {Apache-2.0},
    month   = {jul},
    title   = {{DALL·E Mini}},
    url     = {https://github.com/borisdayma/dalle-mini},
    version = {v0.1-alpha},
    year    = {2021}
}
@inproceedings{anonymous2022normformer,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=GMYWzWztDx5},
    note    = {under review}
}
@misc{ding2021erniedoc,
    title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
    author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
    year    = {2021},
    eprint  = {2012.15688},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@article{Xie2023ResiDualTW,
  title     = {ResiDual: Transformer with Dual Residual Connections},
  author    = {Shufang Xie and Huishuai Zhang and Junliang Guo and Xu Tan and Jiang Bian and Hany Hassan Awadalla and Arul Menezes and Tao Qin and Rui Yan},
  journal   = {ArXiv},
  year      = {2023},
  volume    = {abs/2304.14802}
}

.\lucidrains\recurrent-memory-transformer-pytorch\recurrent_memory_transformer_pytorch\attend.py

# import required libraries
from collections import namedtuple
from functools import wraps
from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# a named tuple Config holding three boolean flags for the attention kernels
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper to check whether a value exists
def exists(val):
    return val is not None

# decorator that ensures the wrapped function runs only once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# wrap print with once, so a given message is only printed a single time
print_once = once(print)

# main Attend class
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        causal = False,
        use_flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.register_buffer("mask", None, persistent=False)

        self.use_flash = use_flash
        assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    # build (and cache) the causal mask
    def get_mask(self, n, device):
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask
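
    # for intuition: the cached causal mask is an upper-triangular boolean
    # matrix, with True marking the future positions later filled with -inf, e.g.
    #
    #   torch.ones((4, 4), dtype = torch.bool).triu(1)
    #   # tensor([[False,  True,  True,  True],
    #   #         [False, False,  True,  True],
    #   #         [False, False, False,  True],
    #   #         [False, False, False, False]])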

    # flash attention path
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # check if a mask exists and expand it to a compatible shape
        if exists(mask):
            if mask.ndim != 4:
                mask = rearrange(mask, 'b j -> b 1 1 j')

            mask = mask.expand(-1, heads, q_len, -1)

        # pick the sdp kernel config depending on whether the tensors are on cuda
        config = self.cuda_config if is_cuda else self.cpu_config

        # run scaled dot product attention under torch.backends.cuda.sdp_kernel
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0., 
                is_causal = self.causal
            )

        return out
    # forward pass, taking queries (q), keys (k), values (v) and an optional mask
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # get the query sequence length (n) and the device
        n, device = q.shape[-2], q.device

        # scale is the inverse square root of the head dimension
        scale = q.shape[-1] ** -0.5

        # if flash attention is enabled, dispatch to flash_attn
        if self.use_flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask

        if exists(mask):
            if mask.ndim != 4:
                mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask

        if self.causal:
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
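
# a minimal smoke test of the module above (the shapes are illustrative
# choices: batch 1, 8 heads, sequence length 16, head dimension 64)
if __name__ == '__main__':
    attend = Attend(causal = True, dropout = 0.1)

    q = torch.randn(1, 8, 16, 64)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)

    out = attend(q, k, v)
    print(out.shape)  # torch.Size([1, 8, 16, 64])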

.\lucidrains\recurrent-memory-transformer-pytorch\recurrent_memory_transformer_pytorch\recurrent_memory_transformer.py

# import math
import math
# import partial
from functools import partial
# import zip_longest
from itertools import zip_longest
# import nullcontext
from contextlib import nullcontext

# typing helpers
from typing import Optional, List, Tuple

# import torch
import torch
# import torch.nn.functional
import torch.nn.functional as F
# import nn, einsum, Tensor
from torch import nn, einsum, Tensor

# import rearrange, repeat, pack, unpack
from einops import rearrange, repeat, pack, unpack

# import the Attend class
from recurrent_memory_transformer_pytorch.attend import Attend

# Linear is nn.Linear with no bias, via partial
Linear = partial(nn.Linear, bias = False)

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the first argument unchanged
def identity(t, *args, **kwargs):
    return t

# return the first value that is not None
def default(*vals):
    for val in vals:
        if exists(val):
            return val
    return None

# decorator that runs a function in eval mode, restoring training mode after
def eval_decorator(fn):
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        out = fn(self, *args, **kwargs)
        self.train(was_training)
        return out
    return inner

# check whether one number divides evenly into another
def divisible_by(numer, denom):
    return (numer % denom) == 0

# sampling helpers

# log of a tensor, clamped for numerical stability
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# gumbel sampling with temperature
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

# top-k filtering of logits
def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
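
# taken together, top_k keeps roughly the top (1 - thres) fraction of logits
# (everything else set to -inf) and gumbel_sample then draws from what
# remains, e.g. (illustrative shapes):
#
#   logits = torch.randn(2, 1000)              # (batch, vocab)
#   filtered = top_k(logits, thres = 0.9)      # keeps ceil(0.1 * 1000) = 100 per row
#   sampled = gumbel_sample(filtered)          # shape (2,)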

# token shift, mixing in features from the previous position (read / write memories excluded)
def token_shift_fn(t, ps):
    read_mem, t, write_mem = unpack(t, ps, 'b * d')
    t, t_shift = t.chunk(2, dim = -1)
    t_shift = F.pad(t_shift, (0, 0, 1, -1), value = 0.)
    t = torch.cat((t, t_shift), dim = -1)
    return torch.cat((read_mem, t, write_mem), dim = -2)

# pass only a fraction of the gradient through, detaching the rest
def frac_gradient(t, frac = 1.):
    if frac == 1.:
        return t

    return t * frac + t.detach() * (1. - frac)

# rotary embedding

# rotary embedding class
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, theta = 32768):
        super().__init__()
        inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, positions):
        freqs = torch.einsum('i , j -> i j', positions, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)
        return freqs

# rotate half the feature dimensions
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# apply rotary positional embedding
def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())
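
# sanity check: the rotation preserves vector norms while encoding relative
# position into the q-k dot product, e.g. (illustrative shapes):
#
#   rotary = RotaryEmbedding(dim = 64)
#   freqs = rotary(torch.arange(16).float())    # (16, 64)
#   t = torch.randn(1, 8, 16, 64)               # (batch, heads, seq, dim_head)
#   rotated = apply_rotary_pos_emb(freqs, t)
#   assert torch.allclose(t.norm(dim = -1), rotated.norm(dim = -1), atol = 1e-5)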

# normalization

# root mean square norm
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

# feedforward

# GEGLU activation
class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.gelu(gate)

# feedforward block with GEGLU; the inner dim is scaled by 2/3 to keep the parameter count comparable to a plain MLP
def FeedForward(dim, mult = 4, dropout = 0.):
    dim_inner = int(dim * mult * 2 / 3)
    return nn.Sequential(
        Linear(dim, dim_inner * 2, bias = False),
        GEGLU(),
        RMSNorm(dim_inner),
        nn.Dropout(dropout),
        Linear(dim_inner, dim, bias = False)
    )
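
# e.g. with dim = 512 and mult = 4, dim_inner = int(512 * 4 * 2 / 3) = 1365,
# so the first projection outputs 2 * 1365 channels which GEGLU splits into
# value and gate; a quick shape check (illustrative):
#
#   ff = FeedForward(dim = 512, mult = 4)
#   ff(torch.randn(1, 16, 512)).shape  # torch.Size([1, 16, 512])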

# attention

# attention class
class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        causal = False,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        use_flash_attn = False,
        use_custom_causal_attn_mask = False
    ):
        super().__init__()
        dim_inner = dim_head * heads
        self.heads = heads

        self.attend = Attend(
            causal = causal and not use_custom_causal_attn_mask,
            dropout = dropout,
            use_flash = use_flash_attn
        )

        self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head))

        self.to_q = Linear(dim, dim_inner)
        self.to_kv = Linear(dim, dim_inner * 2)
        self.to_out = Linear(dim_inner, dim)
    # forward pass, taking input x, optional rotary embeddings, a mask, and optional xl memories
    def forward(
        self,
        x,
        rotary_emb: Optional[Tuple[Tensor, Tensor]] = None,
        mask = None,
        xl_memories = None
    ):
        # number of heads
        h = self.heads

        # project input x to queries
        q = self.to_q(x)
        # project input x to keys and values
        k, v = self.to_kv(x).chunk(2, dim = -1)

        # split out heads for multi-head attention
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # add a null key / value pair, so a fully masked-out sequence still has something to attend to
        nk, nv = map(lambda t: repeat(t, 'h d -> b h 1 d', b = x.shape[0]), self.null_kv)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # if a mask exists, pad it by one position in front for the null key / value
        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)

        # manage memories
        next_xl_memories = torch.stack((k, v))

        # if xl memories exist, concatenate them in front of the current keys and values
        if exists(xl_memories):
            kx, vx = xl_memories
            k = torch.cat((kx, k), dim = -2)
            v = torch.cat((vx, v), dim = -2)

            # pad the mask in front by the xl memory length
            if exists(mask):
                mask = F.pad(mask, (xl_memories.shape[-2], 0), value = True)

        # if rotary embeddings exist, apply them to queries and keys
        if exists(rotary_emb):
            q_rotary_emb, k_rotary_emb = rotary_emb

            q = apply_rotary_pos_emb(q_rotary_emb, q)
            k = apply_rotary_pos_emb(k_rotary_emb, k)

        # compute attention
        out = self.attend(q, k, v, mask = mask)

        # merge heads back together
        out = rearrange(out, 'b h n d -> b n (h d)')

        # project out, returning the next xl memories alongside
        return self.to_out(out), next_xl_memories
# the RecurrentMemoryTransformer module
class RecurrentMemoryTransformer(nn.Module):
    # initializer, accepting many hyperparameters
    def __init__(
        self,
        dim,
        *,
        num_tokens,
        depth,
        num_memory_tokens,
        seq_len,
        causal = True,        
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        use_flash_attn = False,
        ignore_index = -1,
        abs_pos_emb = True,
        rotary_pos_emb = False,
        token_shift = True,
        use_xl_memories = True,
        xl_mem_len = None,
        enhanced_xl_recurrence = False,      # whether to use the enhanced xl memory scheme from the ernie-doc paper
        emb_gradient_frac = 0.1,             # trick from the cogview paper that leads to a bit more stability
        memory_not_causal = True,            # flash attention behaves better without an explicit causal mask, so it is worth turning this on
        add_write_to_next_write_mem = False, # add the previous step's write memories to the next write step - thanks to @IcarusWizard for pointing out this inconsistency
        next_write_mem_stop_grad = True,     # whether to stop the gradient from the previous read memory into the next write memory
        always_have_read_memories = True,    # whether to always have read memories, even on the first step, so the model can be exported to ONNX
        resi_dual_scale = 1.,                # in the case of fp16 overflow on the prenorm branch, set this to a value less than 1
        ):
        # call the parent constructor
        super().__init__()
        # store model hyperparameters
        self.causal = causal
        self.seq_len = seq_len

        self.emb_gradient_frac = emb_gradient_frac

        # resi_dual_scale must be between 0 and 1
        assert 0 < resi_dual_scale <= 1., 'resiDual scale must be between 0 and 1'
        self.resi_dual_scale = resi_dual_scale

        assert num_memory_tokens > 0

        # token embedding
        self.token_emb = nn.Embedding(num_tokens, dim)

        # positional encoding - at least one of absolute, rotary, or token shift must be enabled
        assert any([abs_pos_emb, rotary_pos_emb, token_shift])

        self.pos_emb = nn.Embedding(seq_len, dim) if abs_pos_emb else None

        self.rotary_pos_emb = RotaryEmbedding(dim_head) if rotary_pos_emb else None

        self.maybe_token_shift = token_shift_fn if token_shift else identity

        # memory related parameters
        self.num_memory_tokens = num_memory_tokens

        self.read_memory_emb = nn.Parameter(torch.zeros(num_memory_tokens, dim))
        nn.init.normal_(self.read_memory_emb, std = 0.02)

        self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
        nn.init.normal_(self.memory_tokens, std = 0.02)

        # xl memories
        xl_mem_len = default(xl_mem_len, seq_len)
        assert xl_mem_len <= seq_len
        self.xl_mem_len = xl_mem_len

        self.use_xl_memories = use_xl_memories
        self.enhanced_xl_recurrence = enhanced_xl_recurrence

        # layers
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(
                    dim = dim,
                    dim_head = dim_head,
                    causal = causal,
                    heads = heads,
                    use_flash_attn = use_flash_attn,
                    use_custom_causal_attn_mask = memory_not_causal,
                    dropout = attn_dropout
                ),
                RMSNorm(dim),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
                RMSNorm(dim)
            ]))

        self.norm = RMSNorm(dim)
        self.to_logits = nn.Linear(dim, num_tokens)

        self.ignore_index = ignore_index

        # use a custom attention mask if causal while the memories should not be causal
        self.use_custom_causal_attn_mask = causal and memory_not_causal

        # in the paper, they actually also use the previous write memories for generating the next write memories
        self.add_write_to_next_write_mem = add_write_to_next_write_mem
        self.next_write_mem_stop_grad = next_write_mem_stop_grad

        # allow attending to the raw read memory positional embeddings on the first step
        # done so the model can be exported to onnx; should have no effect otherwise
        self.always_have_read_memories = always_have_read_memories

    # initialize memories for a given batch size
    def init_memory(self, batch):
        return repeat(self.memory_tokens, 'm d -> b m d', b = batch)

    # forward pass
    def forward(
        self,
        x,
        read_memories = None,
        *,
        mask = None,
        labels = None,
        xl_memories: Optional[List[Tensor]] = None,
        mask_out_read_memories = False   # for the onnx-exported model, when passing in read memories of all zeros
# wrapper for managing multiple segments

class RecurrentMemoryTransformerWrapper(nn.Module):
    def __init__(
        self,
        transformer: RecurrentMemoryTransformer,
        truncate_at_step = None  # number of steps before detaching memories (truncated bptt); with memory replay checkpointing there should be no memory issues, but set this if instabilities arise, as reported in the initial paper
    ):
        super().__init__()
        self.transformer = transformer
        self.seq_len = transformer.seq_len
        self.truncate_at_step = truncate_at_step

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prime,
        *,
        length,
        memories = None,
        xl_memories: Optional[List[Tensor]] = None,
        temperature = 1.,
        filter_thres = 0.9
    ):
        assert self.transformer.causal, 'only autoregressive transformers can generate'

        start_len, seq_len = prime.shape[-1], self.seq_len

        assert length >= start_len

        *past_segments, curr_segment = prime.split(seq_len, dim = -1)

        # catch memories up to the current segment

        for past_segment in past_segments:
            _, memories, xl_memories = self.transformer(past_segment, memories, xl_memories = xl_memories)

        # sample for the remaining length

        for ind in range(length - start_len):
            logits, next_memories, next_xl_memories = self.transformer(curr_segment, memories, xl_memories = xl_memories)

            logits = logits[:, -1]

            filtered_logits = top_k(logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature)
            sampled = rearrange(sampled, 'b -> b 1')

            curr_segment = torch.cat((curr_segment, sampled), dim = -1)

            if divisible_by(curr_segment.shape[-1] - 1, seq_len):
                memories = next_memories
                xl_memories = next_xl_memories

                past_segment, curr_segment = curr_segment[..., :seq_len], curr_segment[..., -1:]
                past_segments.append(past_segment)

        # add current segment to all segments

        past_segments.append(curr_segment)

        # reconcat all segments

        output = torch.cat(past_segments, dim = -1)

        output = output[:, start_len:]
        return output
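
    # usage sketch (illustrative; assumes a causal transformer wrapped as in
    # the usage section earlier, living on GPU):
    #
    #   prime = torch.randint(0, 256, (1, 128)).cuda()
    #   sampled = model.generate(prime, length = 1024)
    #   sampled.shape  # (1, 896) - only the newly generated tokens are returned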

    def forward(
        self,
        x,
        memories = None,
        *,
        mask = None,
        xl_memories: Optional[List[Tensor]] = None,
        return_loss = False,
        labels = None,
        truncate_at_step = None,         # if set, this overrides the truncate_at_step given at init
        memory_replay_backprop = False,  # whether to have the class do memory-efficient backprop
        mrbp_loss_weight = 1.            # if using memory replay backprop with gradient accumulation, scale the loss by this factor, e.g. (1. / <num grad accum steps>)
        ):
            # resolve the sequence length and truncation step
            seq_len, truncate_at_step = self.seq_len, default(truncate_at_step, self.truncate_at_step)

            labels = None
            # if a loss is needed (or memory replay backprop), derive labels from the input
            if (return_loss or memory_replay_backprop) and not exists(labels):
                x, labels = x[:, :-1], x[:, 1:]

            # split the input into segments
            segments = x.split(seq_len, dim = -1)
            total_length = x.shape[-1]
            num_segments = len(segments)
            segment_length_frac = tuple(map(lambda t: t.shape[-1] / total_length, segments))

            # defaults
            label_segments = mask_segments = (None,)

            # handle labels
            if exists(labels):
                label_segments = labels.split(seq_len, dim = -1)

            # handle the mask
            if exists(mask):
                mask_segments = mask.split(seq_len, dim = -1)

            # keep a replay buffer of memories
            replay_buffer = [memories]

            # replay buffer for xl memories
            xl_segments = [xl_memories]

            # decide the forward context, depending on whether doing memory replay backprop
            forward_context = nullcontext if not memory_replay_backprop else torch.no_grad

            # forward and get all outputs (can be either losses or logits)
            logits = []
            losses = []

            for step, (segment, mask_segment, label_segment, loss_weight) in enumerate(zip_longest(segments, mask_segments, label_segments, segment_length_frac)):

                with forward_context():
                    output, memories, xl_memories = self.transformer(segment, memories, mask = mask_segment, labels = label_segment)

                if exists(truncate_at_step) and divisible_by(step + 1, truncate_at_step):
                    memories = memories.detach()

                replay_buffer.append(memories)

                xl_segments.append(xl_memories)

                if return_loss:
                    losses.append(output * loss_weight)
                else:
                    logits.append(output)

            # whether to do memory replay backpropagation
            # https://arxiv.org/abs/2010.06891
            # algorithm 1
            if memory_replay_backprop:
                memories_grad = torch.zeros_like(replay_buffer[-1])

                reversed_inputs = zip_longest(*map(reversed, [
                    range(num_segments),
                    segments,
                    replay_buffer[:-1],
                    xl_segments[:-1],
                    mask_segments,
                    label_segments,
                    segment_length_frac,
                ]))

                total_loss = 0.

                for step, segment, segment_memories, segment_xl_memories, mask_segment, label_segment, loss_weight in reversed_inputs:
                    is_first = step == 0

                    if exists(segment_memories):
                        segment_memories.requires_grad_()

                    loss, next_segment_memories, _ = self.transformer(segment, segment_memories, mask = mask_segment, xl_memories = segment_xl_memories, labels = label_segment)

                    weighted_loss = loss * loss_weight * mrbp_loss_weight

                    weighted_loss.backward(retain_graph = True)

                    next_segment_memories.backward(memories_grad)

                    total_loss += weighted_loss

                    if is_first:
                        continue

                    if exists(truncate_at_step) and divisible_by(step, truncate_at_step):
                        memories_grad.zero_()
                    else:
                        memories_grad.copy_(segment_memories.grad.data)

                return total_loss

            # if a loss is not needed, return the logits
            if not return_loss:
                logits = torch.cat(logits, dim = -2)
                return logits, memories

            # otherwise return the summed loss
            return sum(losses), memories

.\lucidrains\recurrent-memory-transformer-pytorch\recurrent_memory_transformer_pytorch\__init__.py

# import the RecurrentMemoryTransformer and RecurrentMemoryTransformerWrapper classes
from recurrent_memory_transformer_pytorch.recurrent_memory_transformer import RecurrentMemoryTransformer, RecurrentMemoryTransformerWrapper

.\lucidrains\recurrent-memory-transformer-pytorch\setup.py

# import setup and package-finding helpers
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'recurrent-memory-transformer-pytorch',
  # find all packages, excluding none
  packages = find_packages(exclude=[]),
  # version
  version = '0.5.6',
  # license
  license='MIT',
  # description
  description = 'Recurrent Memory Transformer - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # long description content type
  long_description_content_type = 'text/markdown',
  # project url
  url = 'https://github.com/lucidrains/recurrent-memory-transformer-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'recurrence',
    'memory',
    'long-context'
  ],
  # dependencies
  install_requires=[
    'einops>=0.6.1',
    'torch>=1.6',
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\recurrent-memory-transformer-pytorch\train.py

# import required libraries
import gzip
import random
import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# import the model and wrapper
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer, RecurrentMemoryTransformerWrapper

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 250
GENERATE_LENGTH = 2048
SEQ_LEN = 2048

# helper functions

# decode a token to a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens to a string
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# cycle a dataloader endlessly (needed below for train_loader / val_loader)
def cycle(loader):
    while True:
        for data in loader:
            yield data


# instantiate the RecurrentMemoryTransformer
model = RecurrentMemoryTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8,
    seq_len = 512,
    use_flash_attn = True,
    num_memory_tokens = 128,
    use_xl_memories = True,
    xl_mem_len = 256
)

# wrap the model
model = RecurrentMemoryTransformerWrapper(model)

# move the model to GPU
model.cuda()

# prepare enwik8 data

# read the data from the compressed file
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# build dataloaders for the train and validation sets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()

    total_loss = 0.
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(
            next(train_loader),
            memory_replay_backprop = True,
            mrbp_loss_weight = 1. / GRADIENT_ACCUMULATE_EVERY
        )

        total_loss += loss

    print(f"training loss: {total_loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss, _ = model(next(val_loader), return_loss = True)
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        print(f"%s \n\n %s", (prime, "*" * 100))

        sample = model.generate(inp[None, :], length = GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str, "\n")