Lucidrains-系列项目源码解析-三十八-

Lucidrains 系列项目源码解析(三十八)

.\lucidrains\routing-transformer\routing_transformer\routing_transformer.py

# 导入 torch 库
import torch
# 导入 torch 中的神经网络模块
import torch.nn as nn
# 导入 torch 中的函数操作模块
import torch.nn.functional as F
# 导入 math 库
import math
# 从 inspect 模块中导入 isfunction 函数
from inspect import isfunction
# 从 operator 模块中导入 mul 函数
from operator import mul
# 从 functools 模块中导入 partial, reduce, wraps 函数
from functools import partial, reduce, wraps

# 从 einops 库中导入 rearrange, repeat 函数
from einops import rearrange, repeat
# 从 einops.layers.torch 模块中导入 Rearrange 类
from einops.layers.torch import Rearrange

# 从 local_attention 模块中导入 LocalAttention 类
from local_attention import LocalAttention
# 从 product_key_memory 模块中导入 PKM 类
from product_key_memory import PKM
# 从 mixture_of_experts 模块中导入 MoE 类
from mixture_of_experts import MoE
# 从 routing_transformer.reversible 模块中导入 ReversibleSequence, SequentialSequence 类

# Module-level constants

# Logit written where a query would attend to its own key when queries and keys are
# shared (see the shared_qk branch of KmeansAttention.forward).
TOKEN_SELF_ATTN_VALUE = -5.0e4
# Number of k-means refinement iterations run when cluster means are first seeded.
KMEAN_INIT_ITERS = 10

# 辅助函数

# exists: None-check helper
def exists(val):
    """Return ``True`` unless ``val`` is ``None`` (falsy values such as 0 still count)."""
    missing = val is None
    return not missing

# identity: pass-through callable
def identity(x, *args, **kwargs):
    """Return ``x`` unchanged; extra positional/keyword arguments are ignored."""
    del args, kwargs
    return x

# default: fall back to `d` when `x` is None
def default(x, d):
    """Return ``x`` if it is not ``None``, otherwise the fallback ``d``.

    When ``d`` is a plain function (e.g. a lambda) it is called to build the fallback,
    so the fallback is only constructed when actually needed. Classes are returned as-is.
    """
    if exists(x):
        return x
    return d() if isfunction(d) else d

# cast_tuple: normalize a scalar-or-tuple argument to a tuple
def cast_tuple(x):
    """Return ``x`` itself if it is a tuple, otherwise the 1-tuple ``(x,)``."""
    if isinstance(x, tuple):
        return x
    return (x,)

# cache_fn: memoize the first result of `f`, regardless of later arguments
def cache_fn(f):
    """Decorator that calls ``f`` once and returns that first result on every later call.

    The cache ignores arguments: whatever the first call returned is returned forever.
    A private sentinel marks "not computed yet", so a legitimately ``None`` result is
    cached too. The previous ``exists(cache)`` check recomputed ``f`` on every call in
    that case.
    """
    not_computed = object()
    cache = not_computed

    @wraps(f)
    def cached_fn(*args, **kwargs):
        nonlocal cache
        if cache is not_computed:
            cache = f(*args, **kwargs)
        return cache
    return cached_fn

# compose: right-to-left function composition
def compose(*fns):
    """Return ``g`` such that ``g(x) == fns[0](fns[1](...fns[-1](x)))``.

    Extra positional/keyword arguments passed to ``g`` are forwarded to every function.
    """
    def inner(x, *args, **kwargs):
        result = x
        for fn in fns[::-1]:
            result = fn(result, *args, **kwargs)
        return result
    return inner

# to: factory kwargs matching an existing tensor
def to(t):
    """Return ``{'device': ..., 'dtype': ...}`` taken from tensor ``t`` (for ``torch.*`` factories)."""
    return dict(device = t.device, dtype = t.dtype)

# find_modules: collect submodules of a given type
def find_modules(nn_module, type):
    """Return every module in ``nn_module.modules()`` (itself included) that is a ``type``, in traversal order."""
    return list(filter(lambda module: isinstance(module, type), nn_module.modules()))

# is_empty: tensor with zero elements
def is_empty(t):
    """Return ``True`` when tensor ``t`` holds no elements."""
    return t.numel() == 0

# max_neg_value: most negative finite value of a float dtype (used as a mask fill)
def max_neg_value(tensor):
    """Return ``-finfo(tensor.dtype).max``, the most negative finite value of the dtype."""
    info = torch.finfo(tensor.dtype)
    return -info.max

# batched_index_select: gather whole feature vectors along dim 2
def batched_index_select(values, indices):
    """Gather entries of ``values`` along dim 2 with ``indices``, keeping the full last dim.

    ``indices`` is broadcast over the last (feature) dimension of ``values``, so e.g.
    ``(b, h, n, d)`` values with ``(b, h, k)`` indices produce a ``(b, h, k, d)`` result.
    """
    feature_dim = values.shape[-1]
    index = expand_dim(indices, -1, feature_dim)
    return values.gather(2, index)

# merge_dims: flatten a contiguous range of dimensions
def merge_dims(ind_from, ind_to, tensor):
    """Collapse dimensions ``ind_from`` through ``ind_to`` (inclusive) of ``tensor`` into one."""
    dims = list(tensor.shape)
    span = slice(ind_from, ind_to + 1)
    merged_size = reduce(mul, dims[span])
    dims[span] = [merged_size]
    return tensor.reshape(*dims)

# expand_dim: insert a broadcast axis of size k
def expand_dim(t, dim, k):
    """Insert a new axis at ``dim`` and broadcast it (as a view, no copy) to size ``k``."""
    expanded = t.unsqueeze(dim)
    sizes = [-1 for _ in expanded.shape]
    sizes[dim] = k
    return expanded.expand(*sizes)

# scatter_mean: average values that scatter onto the same destination
def scatter_mean(src, t, index, dim, eps = 1e-5):
    """Scatter ``t`` onto ``src`` along ``dim`` and average colliding entries.

    Both the value sums and the per-position counts are scatter-added on top of
    ``src``'s own values (out of place). ``eps`` keeps untouched positions finite.
    """
    totals = src.scatter_add(dim, index, t)
    counts = src.scatter_add(dim, index, torch.ones_like(t))
    return totals / (counts + eps)

# split_at_index: cut a tensor in two at `index` along `dim`
def split_at_index(dim, index, t):
    """Return ``(t[..., :index], t[..., index:])`` where the slicing happens on axis ``dim``.

    ``dim`` is expected to be non-negative (a negative value ends up slicing axis 0).
    """
    lead = [slice(None)] * dim
    left = tuple(lead + [slice(None, index)])
    right = tuple(lead + [slice(index, None)])
    return t[left], t[right]

# reshape_dim: split one dimension into several
def reshape_dim(t, dim, split_dims):
    """Replace axis ``dim`` of ``t`` (negative allowed) with the sizes in ``split_dims``.

    One entry of ``split_dims`` may be ``-1`` to be inferred by ``reshape``.
    """
    sizes = list(t.shape)
    axis = dim % len(sizes)
    sizes[axis:axis + 1] = split_dims
    return t.reshape(sizes)

# ema: one exponential-moving-average step (out of place)
def ema(old, new, decay):
    """Return ``old * decay + new * (1 - decay)``, or ``new`` when ``old`` is ``None``."""
    if exists(old):
        return old * decay + new * (1 - decay)
    return new

# ema_inplace: exponential-moving-average step written into `moving_avg`
def ema_inplace(moving_avg, new, decay):
    """Update ``moving_avg`` in place to ``moving_avg * decay + new * (1 - decay)``.

    An empty ``moving_avg`` is instead overwritten with a copy of ``new``. Returns ``None``.
    """
    if not is_empty(moving_avg):
        moving_avg.data.mul_(decay).add_(new, alpha= (1 - decay))
        return
    moving_avg.data.copy_(new)

# 辅助类

# Chunk: apply `fn` piecewise along one dimension and re-concatenate
class Chunk(nn.Module):
    """Split the input into ``chunks`` pieces along ``along_dim``, apply ``fn`` to each, concatenate.

    With ``chunks <= 1`` ``fn`` is applied to the whole input. Keyword arguments are
    forwarded to every ``fn`` call.
    """
    def __init__(self, chunks, fn, along_dim = -1):
        super().__init__()
        self.dim = along_dim
        self.chunks = chunks
        self.fn = fn

    def forward(self, x, **kwargs):
        if self.chunks > 1:
            pieces = x.chunk(self.chunks, dim = self.dim)
            outputs = [self.fn(piece, **kwargs) for piece in pieces]
            return torch.cat(outputs, dim = self.dim)
        return self.fn(x, **kwargs)

# PreNorm: normalize the input, then apply the wrapped function
class PreNorm(nn.Module):
    """Apply ``norm_class(dim)`` to the input, then call ``fn`` on the normalized tensor.

    Keyword arguments given to ``forward`` are forwarded to ``fn`` only.
    """
    # A plain nn.Module: this is a wrapper, not a list container. Deriving from
    # nn.ModuleList (as before) made instances iterable/len()-able over their
    # submodules while integer indexing failed.
    def __init__(self, norm_class, dim, fn):
        super().__init__()
        self.norm = norm_class(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# ReZero: scale the wrapped function's output by a learned scalar initialised to 0
class ReZero(nn.Module):
    """Call ``fn`` and multiply its output by ``residual_weight`` (starts at zero).

    If ``fn`` returns a tuple, only the first element is scaled; the remaining elements
    (e.g. auxiliary losses) are passed through unchanged.
    """
    def __init__(self, fn):
        super().__init__()
        self.residual_weight = nn.Parameter(torch.zeros(1))
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        # Tuple handling is inlined: the helper this used to call
        # (`map_first_tuple_or_el`) is not defined in this file, so forward raised NameError.
        if isinstance(out, tuple):
            return (out[0] * self.residual_weight,) + out[1:]
        return out * self.residual_weight
# ScaleNorm: rescale vectors to unit L2 norm, then multiply by one learned gain
class ScaleNorm(nn.Module):
    """Compute ``x / max(||x||_2, eps) * g`` over the last dimension with a scalar gain ``g``.

    If the input is a tuple, only its first element is normalized; the remaining elements
    are passed through unchanged.
    """
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        # `dim` is unused. It is accepted so the class can be constructed as
        # `norm_class(dim)`, the same way nn.LayerNorm is (e.g. by PreNorm).
        self.g = nn.Parameter(torch.ones(1))
        self.eps = eps

    def _norm(self, t):
        """Normalize ``t`` along its last dimension (norm clamped at ``eps``) and apply the gain."""
        n = torch.norm(t, dim=-1, keepdim=True).clamp(min=self.eps)
        return t / n * self.g

    def forward(self, x):
        # Tuple handling is inlined: the helper this used to call
        # (`map_first_tuple_or_el`) is not defined in this file, so forward raised NameError.
        if isinstance(x, tuple):
            return (self._norm(x[0]),) + x[1:]
        return self._norm(x)

# ProjectInOut: run `fn` at a different feature width than the caller
class ProjectInOut(nn.Module):
    """Project ``dim_in -> dim_out``, apply ``fn``, and (optionally) project back to ``dim_in``.

    ``fn`` must return an ``(output, loss)`` pair; ``loss`` is returned untouched. With
    ``project_out=False`` the output is left ``dim_out`` wide.
    """
    def __init__(self, fn, dim_in, dim_out, project_out = True):
        super().__init__()
        self.fn = fn
        self.project_in = nn.Linear(dim_in, dim_out)
        if project_out:
            self.project_out = nn.Linear(dim_out, dim_in)
        else:
            self.project_out = identity

    def forward(self, x, **kwargs):
        hidden, loss = self.fn(self.project_in(x), **kwargs)
        return self.project_out(hidden), loss

# MatrixMultiply: right-multiply by a fixed (possibly shared) matrix
class MatrixMultiply(nn.Module):
    """Return ``x @ tensor`` (or ``x @ tensor.t()`` when ``transpose`` is set).

    Assigning an ``nn.Parameter`` here registers it on this module, which is how an
    existing weight (e.g. a token embedding matrix) can be reused as an output layer.
    """
    def __init__(self, tensor, transpose = False):
        super().__init__()
        self.tensor = tensor
        self.transpose = transpose

    def forward(self, x):
        weight = self.tensor.t() if self.transpose else self.tensor
        return x @ weight

# shift: move entries along dim -2 by `amount`, zero-filling the vacated slots
def shift(t, amount, mask = None):
    """Shift ``t`` by ``amount`` positions along dimension -2.

    A positive ``amount`` moves entries toward higher indices, a negative one toward
    lower indices; vacated slots become 0. If ``mask`` is given, entries where it is
    False are zeroed first. ``amount == 0`` returns ``t`` untouched (mask not applied).
    """
    if amount == 0:
        return t

    masked = t.masked_fill(~mask[..., None], 0.) if exists(mask) else t
    padding = (0, 0, amount, -amount)
    return F.pad(masked, padding, value = 0.)

# PreShiftTokens: token-shift leading feature groups before calling `fn`
class PreShiftTokens(nn.Module):
    """Shift groups of features along dim -2, then call ``fn``.

    The last dimension is cut into pieces of ``dim // len(shifts)`` features. Piece ``i``
    (for ``i < len(shifts)``) is shifted by ``shifts[i]`` via ``shift``, which honors an
    optional ``mask`` keyword argument. Any leftover pieces pass through unchanged. All
    keyword arguments are also forwarded to ``fn``.
    """
    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        num_segments = len(self.shifts)
        pieces = x.split(x.shape[-1] // num_segments, dim = -1)
        shifted = [shift(piece, amount, mask = mask) for piece, amount in zip(pieces[:num_segments], self.shifts)]
        remainder = pieces[num_segments:]
        x = torch.cat((*shifted, *remainder), dim = -1)
        return self.fn(x, **kwargs)

# FixedPositionalEmbedding: precomputed sinusoid table
class FixedPositionalEmbedding(nn.Module):
    """Fixed sinusoidal table stored in the ``emb`` buffer.

    Row ``p`` is ``[sin(p * w), cos(p * w)]`` with ``w_i = 10000 ** (-2i / dim)``, giving
    shape ``(max_seq_len, dim)`` for even ``dim``. ``forward(x)`` returns the first
    ``x.shape[1]`` rows with a leading singleton axis, cast to ``x``'s dtype/device.
    """
    def __init__(self, dim, max_seq_len):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1. / (10000 ** exponents)
        positions = torch.arange(0, max_seq_len, dtype=torch.float)
        angles = torch.einsum("i,j->ij", positions, inv_freq)
        self.register_buffer('emb', torch.cat((angles.sin(), angles.cos()), dim=-1))

    def forward(self, x):
        seq_len = x.shape[1]
        table = self.emb[:seq_len]
        return table.unsqueeze(0).to(x)

# rotate_every_two: rotate consecutive feature pairs by 90 degrees
def rotate_every_two(x):
    """Map each adjacent pair ``(a, b)`` in the last dimension to ``(-b, a)``."""
    pairs = rearrange(x, '... (d j) -> ... d j', j = 2)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim = -1)
    return rearrange(rotated, '... d j -> ... (d j)')

# apply_rotary_pos_emb: rotary position embedding applied to q, k and v
def apply_rotary_pos_emb(q, k, v, sinu_pos):
    """Rotate ``q``, ``k`` and ``v`` by position-dependent angles.

    ``sinu_pos`` must have a leading singleton axis and a last axis laid out as a sin
    half followed by a cos half (the layout FixedPositionalEmbedding produces). Each
    sin/cos value is repeated twice so it lines up with the pairs used by
    ``rotate_every_two``. Note that ``v`` is rotated as well.
    """
    table = rearrange(sinu_pos.type(q.dtype), '() n (j d) -> n j d', j = 2)
    sin, cos = table.unbind(dim = -2)
    sin = repeat(sin, 'b n -> b (n j)', j = 2)
    cos = repeat(cos, 'b n -> b (n j)', j = 2)

    def rotate(t):
        return (t * cos) + (rotate_every_two(t) * sin)

    return rotate(q), rotate(k), rotate(v)

# update_kmeans_on_backwards: refresh k-means codebooks whenever backward runs
def update_kmeans_on_backwards(module):
    """Register a backward hook on ``module`` that calls ``update()`` on each Kmeans inside it.

    The Kmeans submodules are collected once, at registration time, and stored on
    ``module.kmean_modules``. The gradients passed to the hook are ignored. Returns the
    hook handle (call ``.remove()`` to detach).

    NOTE(review): ``register_backward_hook`` is deprecated in recent PyTorch in favour of
    ``register_full_backward_hook``; kept as-is here.
    """
    module.kmean_modules = find_modules(module, Kmeans)

    def _on_backward(_module, _grad_in, _grad_out):
        for kmeans_module in module.kmean_modules:
            kmeans_module.update()

    return module.register_backward_hook(_on_backward)

# similarity: dot products between vectors and per-head cluster means
def similarity(x, means):
    """Return ``x @ means^T`` per head: ``(b, h, l, d)`` x ``(h, c, d)`` -> ``(b, h, l, c)``."""
    equation = 'bhld,hcd->bhlc'
    scores = torch.einsum(equation, x, means)
    return scores

# dists_and_buckets: similarities plus hard cluster assignments
def dists_and_buckets(x, means):
    """Return ``(dists, buckets)``: similarity scores and the index of the most similar mean."""
    dists = similarity(x, means)
    buckets = torch.max(dists, dim=-1).indices
    return dists, buckets

# batched_bincount: bincount along one dimension of an index tensor
def batched_bincount(index, num_classes, dim=-1):
    """Count each class id in ``index`` along ``dim``; that dimension becomes ``num_classes`` long."""
    out_shape = list(index.shape)
    out_shape[dim] = num_classes
    counts = index.new_zeros(out_shape)
    ones = torch.ones_like(index, dtype=index.dtype)
    counts.scatter_add_(dim, index, ones)
    return counts

# kmeans_iter: one k-means refinement step shared across the batch
def kmeans_iter(x, means, buckets = None):
    """Run one k-means step and return the updated means.

    ``x`` is unpacked as ``(b, h, l, d)`` and ``means`` as ``(h, num_clusters, d)``.
    Every vector is assigned to its most similar mean (unless ``buckets`` supplies the
    assignments). Each new mean is the L2-normalized sum of its members, pooled over the
    batch. Clusters that received no members keep their previous mean. The result is
    ``(h, num_clusters, d)``.
    """
    batch, heads, _, dim_head = x.shape
    dtype = x.dtype
    num_clusters = means.shape[1]

    if buckets is None:
        buckets = dists_and_buckets(x, means)[1]

    # member counts pooled over the batch: (1, h, num_clusters)
    counts = batched_bincount(buckets, num_clusters).sum(0, keepdim=True)
    empty_clusters = counts.long() == 0

    # per-batch member sums, then pooled and normalized: (1, h, num_clusters, d)
    sums = buckets.new_zeros(batch, heads, num_clusters, dim_head, dtype=dtype)
    sums.scatter_add_(-2, expand_dim(buckets, -1, dim_head), x)
    new_means = F.normalize(sums.sum(0, keepdim=True), dim=-1).type(dtype)

    updated = torch.where(empty_clusters.unsqueeze(-1), means, new_means)
    return updated.squeeze(0)
def distribution(dists, window_size):
    """Pick, for every cluster, the ``window_size`` positions most similar to its mean.

    The ``def`` line of this function was missing, which left a module-level ``return``
    (SyntaxError). KmeansAttention calls ``distribution``.

    Args:
        dists: similarity scores laid out as ``(b, h, positions, clusters)``, as produced
            by ``similarity``/``dists_and_buckets``.
        window_size: number of positions kept per cluster (must not exceed ``positions``).

    Returns:
        Long tensor ``(b, h, clusters * window_size)`` of position indices, grouped
        cluster by cluster, each group ordered from most to least similar.
    """
    # top `window_size` positions for each cluster: (b, h, window_size, clusters)
    _, topk_indices = dists.topk(k=window_size, dim=-2)
    # clusters before positions: (b, h, clusters, window_size)
    indices = topk_indices.transpose(-2, -1)
    # flatten clusters and windows together: (b, h, clusters * window_size)
    return indices.reshape(*indices.size()[:2], -1)

# Kmeans: per-head k-means over L2-normalized vectors (dot-product similarity), whose
# means are refreshed through an exponential moving average.
class Kmeans(nn.Module):
    def __init__(self, num_heads, head_dim, num_clusters, ema_decay = 0.999, commitment = 1e-4):
        super().__init__()
        # scale applied to the commitment (MSE) loss returned by forward
        self.commitment = commitment
        # EMA decay used by update() when folding new means into self.means
        self.ema_decay = ema_decay

        # cluster means, one set per head: (num_heads, num_clusters, head_dim)
        self.register_buffer('means', torch.randn(num_heads, num_clusters, head_dim))
        # becomes True once init() has seeded the means from data
        self.register_buffer('initted', torch.tensor(False))
        # running average of the means computed by forward(update_means=True) since the
        # last update(); consumed and cleared by update()
        self.num_new_means = 0
        self.new_means = None

    @torch.no_grad()
    def init(self, x):
        """Seed the means from data on the first call; later calls return immediately.

        x is unpacked as (b, h, l, d). Per head, num_clusters of the b*l vectors are
        sampled (without replacement when enough exist, with replacement otherwise) and
        then refined with KMEAN_INIT_ITERS k-means iterations.
        """
        if self.initted:
            return
        _, h, _, d, device, dtype = *x.shape, x.device, x.dtype

        num_clusters = self.means.shape[1]

        # pool every batch/sequence vector per head: (h, b * l, d)
        means = x.transpose(0, 1).contiguous().view(h, -1, d)
        num_samples = means.shape[1]

        # choose the initial centres
        if num_samples >= num_clusters:
            indices = torch.randperm(num_samples, device=device)[:num_clusters]
        else:
            indices = torch.randint(0, num_samples, (num_clusters,), device=device)

        means = means[:, indices]

        # refine; note forward() calls this with the raw (not yet normalized) input
        for _ in range(KMEAN_INIT_ITERS):
            means = kmeans_iter(x, means)

        self.num_new_means = 0
        self.means.data.copy_(means)
        self.initted.data.copy_(torch.tensor(True))

    @torch.no_grad()
    def update(self, new_means = None):
        """EMA-merge new_means (default: the accumulated self.new_means) into self.means.

        Raises AssertionError when no new means are available; clears the accumulator.
        """
        new_means = default(new_means, self.new_means)
        assert exists(new_means), 'new kmeans has not been supplied'
        # self.means = self.means * ema_decay + new_means * (1 - ema_decay), in place
        ema_inplace(self.means, new_means, self.ema_decay)

        del self.new_means
        self.new_means = None
        self.num_new_means = 0

    def forward(self, x, update_means = False):
        """Assign x to clusters and compute the commitment loss.

        Returns (dists, loss): dists are the similarities of the L2-normalized x to every
        mean, (b, h, l, num_clusters); loss is commitment * MSE between normalized x and
        the mean each vector is assigned to. With update_means=True, one k-means step is
        also computed and averaged into self.new_means (applied later by update()).
        """
        self.init(x)

        b, dtype = x.shape[0], x.dtype
        means = self.means.type(dtype)
        x = F.normalize(x, 2, dim=-1).type(dtype)

        # cluster assignments are computed without gradient tracking
        with torch.no_grad():
            dists, buckets = dists_and_buckets(x, means)

        # the mean each vector was assigned to: (b, h, l, d)
        routed_means = batched_index_select(expand_dim(means, 0, b), buckets)
        loss = F.mse_loss(x, routed_means) * self.commitment

        if update_means:
            with torch.no_grad():
                means = kmeans_iter(x, means, buckets)
            # cumulative mean across calls: previous average weighted n / (n + 1)
            self.new_means = ema(self.new_means, means, self.num_new_means / (self.num_new_means + 1))
            self.num_new_means += 1

        return dists, loss

# KmeansAttention: routing attention. Queries and keys are grouped into k-means clusters
# and attention is computed only inside each cluster's window of routed tokens.
class KmeansAttention(nn.Module):
    def __init__(self, num_clusters, window_size, num_heads, head_dim, causal = False, dropout = 0., ema_decay = 0.999, commitment = 1e-4, context_window_size = None, receives_context = False, num_mem_kv = 0, shared_qk = False):
        super().__init__()
        self.num_heads = num_heads
        self.num_clusters = num_clusters
        self.head_dim = head_dim

        # tokens routed to each cluster (queries), and for keys when attending to context
        self.window_size = window_size
        self.context_window_size = default(context_window_size, window_size)
        self.causal = causal

        self.shared_qk = shared_qk
        self.receives_context = receives_context
        self.kmeans = Kmeans(num_heads, head_dim, num_clusters, ema_decay, commitment)
        self.dropout = nn.Dropout(dropout)

        # At least one memory key/value per cluster is forced when causal without shared qk.
        # NOTE(review): presumably so no causal query row is fully masked — confirm.
        self.num_mem_kv = max(num_mem_kv, 1 if causal and not shared_qk else 0)
        # learned memory keys/values prepended to every cluster window:
        # (num_heads, num_clusters, num_mem_kv, head_dim)
        self.mem_key = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim))
        self.mem_value = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim))

    def forward(self, q, k, v, query_mask = None, key_mask = None, **kwargs):
        """Cluster-routed attention.

        q is unpacked as (b, h, t, d); k and v have kv_t = k.shape[2] positions.
        query_mask / key_mask are optional boolean masks of shape (b, t) / (b, kv_t).
        Returns (out, aux_loss): out has q's shape, aux_loss is the k-means commitment loss.
        """
        b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = *q.shape, k.shape[2], self.window_size, self.context_window_size, self.num_clusters, q.device, q.dtype
        # private flag popped from kwargs; when True, new k-means means are not accumulated
        is_reverse = kwargs.pop('_reverse', False)

        # accumulator for scattering per-cluster outputs back to sequence positions
        out = torch.zeros_like(q, dtype=dtype)

        # accumulate new k-means means only while training and not flagged as reverse
        update_kmeans = self.training and not is_reverse

        # without context, keys fall back to the query mask
        key_mask = default(key_mask, query_mask) if not self.receives_context else key_mask
        # key/value windows use the context window size only when attending to context
        kv_wsz = wsz if not self.receives_context else c_wsz

        # windows cannot exceed the available sequence lengths
        wsz = min(wsz, t)
        kv_wsz = min(kv_wsz, kv_t)

        if not self.shared_qk or self.receives_context:
            # cluster queries and keys jointly, then split the similarities back apart
            dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans)
            q_dists, k_dists = split_at_index(2, t, dists)
            # per cluster: top-wsz query positions and top-kv_wsz key positions
            indices = distribution(q_dists, wsz)
            kv_indices = distribution(k_dists, kv_wsz)
        else:
            # shared qk: cluster queries only and reuse the same routing for keys
            dists, aux_loss = self.kmeans(q, update_kmeans)
            # keys are the L2-normalized (shared) queries
            k = F.normalize(k, dim=-1).to(q)
            indices = distribution(dists, wsz)
            kv_indices = indices

        # gather the routed tokens: (b, h, clusters * window, d)
        q = batched_index_select(q, indices)
        k = batched_index_select(k, kv_indices)
        v = batched_index_select(v, kv_indices)

        # (b, h, clusters * window, d) -> (b, h, clusters, window, d)
        reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d)
        q, k, v = map(reshape_with_window, (q, k, v))

        # prepend the learned memory keys/values to every cluster window
        m_k, m_v = map(lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value))
        k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v)))

        # scaled dot-product scores within each cluster
        dots = torch.einsum('bhnid,bhnjd->bhnij', q, k) * (d ** -0.5)

        mask_value = max_neg_value(dots)

        if exists(query_mask) or exists(key_mask):
            # a missing mask defaults to all-True
            query_mask = default(query_mask, lambda: torch.ones((b, t), device=device).bool())
            key_mask = default(key_mask, lambda: torch.ones((b, kv_t), device=device).bool())

            # route the masks exactly like the tokens
            q_mask = expand_dim(query_mask, 1, h).gather(2, indices)
            kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices)
            q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (q_mask, kv_mask))
            # pairwise query/key mask; memory key slots are always visible
            mask = q_mask[:, :, :, :, None] * kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=True)
            dots.masked_fill_(~mask, mask_value)
            del mask

        if self.causal:
            # a query may only see keys whose original position is <= its own
            q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices))
            mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=True)
            dots.masked_fill_(~mask, mask_value)
            del mask

        if self.shared_qk:
            # with shared qk, discourage (but do not forbid) attending to the same position
            q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices))
            mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=False)
            dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE)
            del mask

        dots = dots.softmax(dim=-1)
        dots = self.dropout(dots)

        # per-cluster outputs, flattened back to (b, h, clusters * window, d)
        bo = torch.einsum('bhcij,bhcjd->bhcid', dots, v)
        so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype)
        # scatter back to sequence positions; tokens routed to several clusters are
        # averaged, tokens routed to none stay zero
        out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2)
        return out, aux_loss
# GELU_: tanh approximation of GELU, used only when torch.nn lacks a GELU module
class GELU_(nn.Module):
    def forward(self, x):
        cubic = x + 0.044715 * torch.pow(x, 3)
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * cubic))

# Prefer the built-in module whenever this torch version provides it.
GELU = getattr(nn, 'GELU', GELU_)

# FeedForward: position-wise MLP, optionally gated
class FeedForward(nn.Module):
    """Compute ``w2(dropout(act(w1(x))))`` with hidden width ``dim * mult``.

    With ``glu=True``, ``w1`` is twice as wide and its output is split into halves
    ``(x, v)``; the hidden activation becomes ``act(x) * v``. ``activation`` is a class
    (default GELU) instantiated once. Extra keyword arguments to ``forward`` are ignored.
    """
    def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False):
        super().__init__()
        activation = default(activation, GELU)
        hidden = dim * mult

        self.glu = glu
        self.w1 = nn.Linear(dim, hidden * 2 if glu else hidden)
        self.act = activation()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(hidden, dim)

    def forward(self, x, **kwargs):
        hidden = self.w1(x)
        if self.glu:
            hidden, gate = hidden.chunk(2, dim=-1)
            hidden = self.act(hidden) * gate
        else:
            hidden = self.act(hidden)
        return self.w2(self.dropout(hidden))

# SelfAttention: mixes local (windowed) attention heads with global k-means routed heads.
class SelfAttention(nn.Module):
    """Multi-head attention split into local and global (k-means routed) heads.

    ``local_attn_heads`` of the ``heads`` use LocalAttention; the remaining heads use
    KmeansAttention with ``max_seq_len // window_size`` clusters. Every head is
    ``dim_head`` wide (default ``dim // heads``).

    ``forward`` returns ``(out, aux_loss)``:
    - ``out`` is ``(b, t, dim)`` for a ``(b, t, dim)`` input.
    - ``aux_loss`` is the k-means commitment loss of the global heads, or a zero tensor
      when there are no global heads.

    NOTE(review): ``depth`` and ``conv_query_kernel`` are accepted but not used here.
    """
    def __init__(self, dim, depth, max_seq_len, heads, local_attn_heads, window_size, dim_head = None, local_attn_window_size = None, local_attn_radius_blocks = 1, causal = False, attn_dropout = 0., dropout = 0., kmeans_ema_decay = 0.999, commitment_factor = 1e-4, receives_context = False, context_window_size = None, rel_pos_emb = True, num_mem_kv = 0, shared_qk = False, conv_query_kernel = 9):
        super().__init__()
        assert dim_head or (dim % heads) == 0, 'hidden dimension must be divisible by number of heads'
        # the number of k-means clusters is max_seq_len // window_size
        assert (max_seq_len % window_size) == 0, 'maximum sequence length must be divisible by the target window size'
        assert local_attn_heads <= heads, 'number of local attention heads must be less than total heads'
        assert not (receives_context and local_attn_heads > 0), 'local attention cannot be used for self attention with context'
        assert not (receives_context and causal), 'contextual attention layer cannot be causal'

        local_attn_window_size = default(local_attn_window_size, window_size)
        context_window_size = default(context_window_size, window_size)

        self.shared_qk = shared_qk
        self.receives_context = receives_context
        self.heads = heads
        self.local_attn_heads = local_attn_heads
        self.global_attn_heads = heads - local_attn_heads

        self.causal = causal
        self.window_size = window_size

        dim_head = default(dim_head, dim // heads)
        dim_heads = dim_head * heads
        self.dim_head = dim_head

        num_clusters = max_seq_len // window_size

        # local attention heads (fixed windows, optional relative position embedding)
        local_dim_heads = dim_head * self.local_attn_heads
        if self.local_attn_heads > 0:
            rel_pos_emb_config = (dim_head, local_attn_heads) if rel_pos_emb else None
            self.local_attn = LocalAttention(dim = dim_head, window_size = local_attn_window_size, causal = causal, dropout = attn_dropout, rel_pos_emb_config = rel_pos_emb_config, look_backward = local_attn_radius_blocks, look_forward = 0 if causal else local_attn_radius_blocks)
            self.local_to_qkv = nn.Linear(dim, 3 * local_dim_heads)

        # global (k-means routed) attention heads
        global_dim_heads = dim_head * self.global_attn_heads
        if self.global_attn_heads > 0:
            # context_window_size is forwarded so it sizes the key/value windows when
            # attending to context. Previously it was computed above but never passed,
            # so KmeansAttention always fell back to window_size.
            self.global_attn = KmeansAttention(num_clusters, window_size, self.global_attn_heads, dim_head, causal = causal, dropout = attn_dropout, ema_decay = kmeans_ema_decay, commitment = commitment_factor, context_window_size = context_window_size, receives_context = receives_context, num_mem_kv = num_mem_kv, shared_qk = shared_qk)

        # projections for the global heads (with shared_qk, keys reuse to_q)
        self.to_q = nn.Linear(dim, global_dim_heads, bias = False)
        self.to_v = nn.Linear(dim, global_dim_heads, bias = False)

        if not self.shared_qk:
            self.to_k = nn.Linear(dim, global_dim_heads, bias = False)

        # output projection over all heads
        self.to_out = nn.Linear(dim_heads, dim, bias = False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context = None, input_mask = None, context_mask = None, pos_emb = None, **kwargs):
        """Attend over ``x`` (unpacked as ``(b, t, dim)``), or over ``context`` for keys/values.

        Args:
            x: input of shape (b, t, dim).
            context: required when the layer was built with ``receives_context``.
            input_mask: masks the queries; without context it also masks the keys.
            context_mask: masks the context keys.
            pos_emb: sinusoid table; applied as a rotary embedding to the global heads,
                and only when not attending to context.

        Returns ``(out, aux_loss)``.
        """
        assert not (self.receives_context and not exists(context)), 'context must be passed if self attention is set to receive context'
        b, t, e, h, dh = *x.shape, self.heads, self.dim_head
        has_local, has_global = map(lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads))

        # (b, n, heads * dh) -> (b, heads, n, dh)
        split_heads = lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous()

        if has_local:
            local_qkv = self.local_to_qkv(x).chunk(3, dim=-1)
            lq, lk, lv = map(split_heads, local_qkv)

        if has_global:
            # keys/values come from the context when this layer attends to context
            kv_input = x if not self.receives_context else context

            q, v = self.to_q(x), self.to_v(kv_input)

            if not self.shared_qk:
                k = self.to_k(kv_input)
            else:
                k = self.to_q(kv_input) if self.receives_context else q

            q, k, v = map(split_heads, (q, k, v))

        out = []
        total_loss = torch.tensor(0., requires_grad=True, **to(x))

        if has_local:
            local_out = self.local_attn(lq, lk, lv, input_mask = input_mask)
            out.append(local_out)

        if has_global:
            if not self.receives_context and exists(pos_emb):
                q, k, v = apply_rotary_pos_emb(q, k, v, pos_emb)

            global_out, loss = self.global_attn(q, k, v, query_mask = input_mask, key_mask = context_mask)
            total_loss = total_loss + loss

            out.append(global_out)

        # stack local + global heads, then merge heads: (b, h, t, dh) -> (b, t, h * dh)
        out = torch.cat(out, dim=1)
        out = out.reshape(b, h, t, -1).transpose(1, 2).reshape(b, t, -1)
        out = self.to_out(out)
        return self.dropout(out), total_loss
class RoutingTransformer(nn.Module):
    # Stack of routing-transformer layers. forward() delegates to self.layers and returns
    # its (x, loss) pair.
    # NOTE(review): this copy of the file is truncated. The __init__ parameter list below
    # is never closed with "):" and the constructor body is missing, so this class does
    # not parse. The missing body must at least set self.layers (used by forward) and
    # self._handle (used by cancel_kmeans_update). Restore it from the upstream source
    # before use.
    def __init__(
        self,
        dim,
        depth,
        max_seq_len,
        heads = 8,
        dim_head = None,
        window_size = 64,
        local_attn_window_size = 256,
        local_attn_radius_blocks = 1,
        causal = False,
        weight_tie = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        attn_layer_dropout = 0.,
        layer_dropout = 0.,
        n_local_attn_heads = 0,
        ff_glu = False,
        reversible = False,
        ff_chunks = 1,
        kmeans_ema_decay = 0.999,
        commitment_factor = 1e-4,
        receives_context = False,
        context_window_size = None,
        _register_kmeans_update = False,
        rel_pos_emb = True,
        pkm_layers = tuple(),
        pkm_num_keys = 128,
        moe_layers = tuple(),
        moe_num_experts = 4,
        moe_loss_coef = 1e-2,
        num_mem_kv = 0,
        shared_qk = None,
        context_shared_qk = False,
        use_rezero = False,
        use_scale_norm = False,
        ff_activation = None,
        shift_tokens = False
    # NOTE(review): the closing "):" and the constructor body that belong here are missing.
    def cancel_kmeans_update(self):
        """Remove the k-means backward hook, if one is registered."""
        if not exists(self._handle):
            return
        self._handle.remove()
        self._handle = None

    def register_kmeans_update(self):
        """Install a backward hook that calls update() on every Kmeans submodule."""
        self._handle = update_kmeans_on_backwards(self)

    def forward(self, x, **kwargs):
        """Run the layer stack; returns (x, loss) exactly as self.layers produces them."""
        x, loss = self.layers(x, **kwargs)
        return x, loss

class RoutingTransformerLM(nn.Module):
    """Token-level language model around a RoutingTransformer.

    Pipeline:
    1. Token embedding.
    2. RoutingTransformer, wrapped in ProjectInOut when ``emb_dim != dim``.
    3. LayerNorm.
    4. Output head, which is one of:
       - an ``nn.Linear`` to ``num_tokens`` logits (the default);
       - the transposed token-embedding matrix when ``tie_embedding`` is set;
       - identity when ``return_embeddings`` is set.

    A fixed sinusoidal table, one head wide, is passed to the transformer as
    ``pos_emb``. It is presumably routed to SelfAttention, which applies such a table as
    a rotary embedding.

    ``forward(x, **kwargs)`` returns ``(output, aux_loss)``, where ``aux_loss`` is the
    loss the RoutingTransformer returns alongside its output.

    NOTE(review): ``ff_mult`` and ``use_absolute_pos_emb`` are accepted but unused here.
    The RoutingTransformer signature in this file has no ``ff_mult`` parameter.
    """
    def __init__(
        self,
        num_tokens,
        dim,
        depth,
        max_seq_len,
        heads = 8,
        dim_head = 64,
        window_size = 64,
        local_attn_window_size = None,
        local_attn_radius_blocks = 1,
        causal = False,
        emb_dim = None,
        weight_tie = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        attn_layer_dropout = 0.,
        layer_dropout = 0.,
        ff_mult = 4,
        ff_activation = None,
        ff_glu = False,
        return_embeddings = False,
        n_local_attn_heads = 0,
        reversible = False,
        ff_chunks = 1,
        kmeans_ema_decay = 0.999,
        commitment_factor = 1e-4,
        receives_context = False,
        context_window_size = None,
        rel_pos_emb = True,
        _register_kmeans_update = True,
        pkm_layers = tuple(),
        pkm_num_keys = 128,
        moe_layers = tuple(),
        moe_num_experts = 4,
        moe_loss_coef = 1e-2,
        num_mem_kv = 0,
        shared_qk = None,
        context_shared_qk = False,
        use_rezero = False,
        use_scale_norm = False,
        tie_embedding = False,
        use_absolute_pos_emb = False,
        shift_tokens = False
    ):
        super().__init__()
        # the number of k-means clusters is max_seq_len // window_size
        assert (max_seq_len % window_size) == 0, 'max sequence length must be divisible by the window size, to calculate number of kmeans cluster'
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        # The table must be as wide as one attention head. When dim_head is None, use
        # the same `dim // heads` fallback SelfAttention applies (None used to crash here).
        self.sinu_pos_emb = FixedPositionalEmbedding(default(dim_head, dim // heads), max_seq_len)

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        nn.init.normal_(self.token_emb.weight, std = 0.02)

        # commitment_factor is now forwarded; previously the LM argument had no effect.
        self.routing_transformer = RoutingTransformer(
            dim, depth, max_seq_len,
            heads = heads,
            dim_head = dim_head,
            window_size = window_size,
            local_attn_window_size = local_attn_window_size,
            local_attn_radius_blocks = local_attn_radius_blocks,
            causal = causal,
            weight_tie = weight_tie,
            ff_dropout = ff_dropout,
            attn_dropout = attn_dropout,
            attn_layer_dropout = attn_layer_dropout,
            layer_dropout = layer_dropout,
            n_local_attn_heads = n_local_attn_heads,
            ff_glu = ff_glu,
            reversible = reversible,
            ff_chunks = ff_chunks,
            kmeans_ema_decay = kmeans_ema_decay,
            commitment_factor = commitment_factor,
            receives_context = receives_context,
            context_window_size = context_window_size,
            rel_pos_emb = rel_pos_emb,
            pkm_layers = pkm_layers,
            pkm_num_keys = pkm_num_keys,
            moe_layers = moe_layers,
            moe_num_experts = moe_num_experts,
            moe_loss_coef = moe_loss_coef,
            num_mem_kv = num_mem_kv,
            shared_qk = shared_qk,
            context_shared_qk = context_shared_qk,
            _register_kmeans_update = _register_kmeans_update,
            use_rezero = use_rezero,
            use_scale_norm = use_scale_norm,
            ff_activation = ff_activation,
            shift_tokens = shift_tokens
        )

        # When the widths differ, project into `dim` and, unless raw embeddings are
        # requested, back out to `emb_dim`.
        if emb_dim != dim:
            self.routing_transformer = ProjectInOut(self.routing_transformer, emb_dim, dim, project_out = not return_embeddings)

        # The final norm must match the width the transformer stack actually emits:
        # `dim` when the output projection was skipped above, `emb_dim` otherwise.
        # (It used to be emb_dim unconditionally, which fails on a shape mismatch for
        # return_embeddings with emb_dim != dim.)
        out_dim = dim if (emb_dim != dim and return_embeddings) else emb_dim
        self.norm = nn.LayerNorm(out_dim)

        if return_embeddings:
            self.out = nn.Identity()
        elif tie_embedding:
            self.out = MatrixMultiply(self.token_emb.weight, transpose = True)
        else:
            self.out = nn.Linear(emb_dim, num_tokens)

    def cancel_kmeans_update(self):
        """Remove the backward hook that updates k-means on the inner RoutingTransformer."""
        transformer = find_modules(self, RoutingTransformer)[0]
        transformer.cancel_kmeans_update()

    def update_kmeans(self):
        """Apply the accumulated new means to every Kmeans module (see Kmeans.update)."""
        for m in find_modules(self, Kmeans):
            m.update()

    def forward(self, x, **kwargs):
        """Embed token ids ``x``, run the transformer, normalize, and apply the output head."""
        x = self.token_emb(x)

        # (1, seq_len, dim_head) sinusoid table, passed on as pos_emb
        rotary_pos_emb = self.sinu_pos_emb(x)
        x, loss = self.routing_transformer(x, pos_emb = rotary_pos_emb, **kwargs)

        x = self.norm(x)
        return self.out(x), loss

.\lucidrains\routing-transformer\routing_transformer\__init__.py

# 从 routing_transformer 包中导入 RoutingTransformer、RoutingTransformerLM、KmeansAttention、update_kmeans_on_backwards 类
from routing_transformer.routing_transformer import RoutingTransformer, RoutingTransformerLM, KmeansAttention, update_kmeans_on_backwards
# 从 routing_transformer 包中导入 RoutingTransformerEncDec 类
from routing_transformer.encoder_decoder import RoutingTransformerEncDec
# 从 routing_transformer 包中导入 AutoregressiveWrapper 类
from routing_transformer.autoregressive_wrapper import AutoregressiveWrapper
# 从 routing_transformer 包中导入 Autopadder 类
from routing_transformer.autopadder import Autopadder

.\lucidrains\routing-transformer\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# Runtime dependencies of the routing_transformer package
_REQUIREMENTS = [
    'einops',
    'local-attention>=1.4.0',
    'mixture-of-experts>=0.2.0',
    'product-key-memory',
    'torch'
]

# Trove classifiers published with the package
_CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

# Package metadata; every package except `examples` is shipped
setup(
    name = 'routing_transformer',
    packages = find_packages(exclude = ['examples']),
    version = '1.6.1',
    license = 'MIT',
    description = 'Routing Transformer (Pytorch)',
    author = 'Phil Wang, Aran Komatsuzaki',
    author_email = 'lucidrains@gmail.com, aran1234321@gmail.com',
    url = 'https://github.com/lucidrains/routing-transformer',
    keywords = ['transformers', 'attention', 'artificial intelligence'],
    install_requires = _REQUIREMENTS,
    classifiers = _CLASSIFIERS,
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

RQ-Transformer

Implementation of RQ Transformer, which proposes a more efficient way of training multi-dimensional sequences autoregressively. This repository will only contain the transformer for now. You can use this vector quantization library for the residual VQ.

This type of axial autoregressive transformer should be compatible with memcodes, proposed in NWT. It would likely also work well with multi-headed VQ.

Install

$ pip install RQ-transformer

Usage

import torch
from rq_transformer import RQTransformer

model = RQTransformer(
    num_tokens = 16000,             # number of tokens, in the paper they had a codebook size of 16k
    dim = 512,                      # transformer model dimension
    max_spatial_seq_len = 1024,     # maximum positions along space
    depth_seq_len = 4,              # number of positions along depth (residual quantizations in paper)
    spatial_layers = 8,             # number of layers for space
    depth_layers = 4,               # number of layers for depth
    dim_head = 64,                  # dimension per head
    heads = 8,                      # number of attention heads
)

x = torch.randint(0, 16000, (1, 1024, 4))

loss = model(x, return_loss = True)
loss.backward()

# then after much training

logits = model(x)

# and sample from the logits accordingly
# or you can use the generate function

sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)

I also think there is something deeper going on, and have generalized this to any number of dimensions. You can use it by importing the HierarchicalCausalTransformer:

import torch
from rq_transformer import HierarchicalCausalTransformer

model = HierarchicalCausalTransformer(
    num_tokens = 16000,                   # number of tokens
    dim = 512,                            # feature dimension
    dim_head = 64,                        # dimension of attention heads
    heads = 8,                            # number of attention heads
    depth = (4, 4, 2),                    # 3 stages (but can be any number) - transformer of depths 4, 4, 2
    max_seq_len = (16, 4, 5)              # the maximum sequence length of the first stage, then the fixed sequence lengths of all subsequent stages
).cuda()

x = torch.randint(0, 16000, (1, 10, 4, 5)).cuda()

loss = model(x, return_loss = True)
loss.backward()

# after a lot of training

sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 16, 4, 5)

Todo

Citations

@unknown{unknown,
    author  = {Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
    year    = {2022},
    month   = {03},
    title   = {Autoregressive Image Generation using Residual Quantization}
}
@misc{press2021ALiBi,
    title   = {Train Short, Test Long: Attention with Linear Biases Enable Input Length Extrapolation},
    author  = {Ofir Press and Noah A. Smith and Mike Lewis},
    year    = {2021},
    url     = {https://ofir.io/train_short_test_long.pdf}
}

.\lucidrains\RQ-Transformer\rq_transformer\hierarchical_causal_transformer.py

# 导入数学库
import math
# 导入 functools 库
import functools
# 导入 torch 库
import torch
# 导入 torch.nn.functional 库
import torch.nn.functional as F
# 从 torch 中导入 nn 和 einsum
from torch import nn, einsum
# 从 einops_exts 中导入 rearrange_with_anon_dims
from einops_exts import rearrange_with_anon_dims
# 从 einops 中导入 rearrange, reduce, repeat
from einops import rearrange, reduce, repeat

# helpers

# Return True unless the value is None
def exists(val):
    """Return True when ``val`` is anything other than None."""
    is_missing = val is None
    return not is_missing

# Return the value if it is not None, otherwise the fallback
def default(val, d):
    """Return ``val`` unless it is None, in which case return ``d``."""
    if val is None:
        return d
    return val

# Padding needed to round num up to a multiple of mult
def remainder_to_mult(num, mult):
    """Return how much must be added to ``num`` to reach the next multiple of ``mult``.

    Returns 0 when ``num`` is already a multiple of ``mult``.
    """
    shortfall = mult - num % mult
    return shortfall % mult

# Wrap a non-tuple value into a tuple
def cast_tuple(t, length = 1):
    """Return ``t`` if it is already a tuple, else a tuple repeating ``t`` ``length`` times."""
    if isinstance(t, tuple):
        return t
    return (t,) * length

# Product of a sequence of numbers
def reduce_mult(nums):
    """Return the product of ``nums`` (1 for an empty iterable), multiplying left to right."""
    product = 1
    for num in nums:
        product = product * num
    return product

# tensor helpers

# Numerically safe natural logarithm
def log(t, eps = 1e-20):
    """Natural log of ``t`` after clamping values below ``eps`` up to ``eps`` (avoids -inf)."""
    clamped = t.clamp(min = eps)
    return torch.log(clamped)

# Standard Gumbel noise shaped like the input
def gumbel_noise(t):
    """Return Gumbel noise ``-log(-log(U))`` with ``U ~ Uniform(0, 1)``, shaped like ``t``."""
    uniform = torch.zeros_like(t).uniform_(0, 1)
    inner = -log(uniform)
    return -log(inner)

# Sample indices with the Gumbel-max trick
def gumbel_sample(t, temperature = 1., dim = -1):
    """Return the argmax along ``dim`` of ``t / temperature`` perturbed by Gumbel noise."""
    perturbed = t / temperature + gumbel_noise(t)
    return perturbed.argmax(dim = dim)

# Keep only the top-k logits (the rest become -inf)
def top_k(logits, thres = 0.5):
    """Mask all but the largest logits along the last dimension.

    ``k = max(int((1 - thres) * num_logits), 1)``, so ``thres`` is the fraction of
    logits to discard; at least one logit is always kept.

    Args:
        logits: tensor of any rank; filtering happens along the last dimension.
        thres: fraction of logits to drop.

    Returns:
        A tensor shaped like ``logits`` holding the kept values and ``-inf`` elsewhere.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k, dim = -1)
    probs = torch.full_like(logits, float('-inf'))
    # scatter along the same (last) dimension topk selected from; a hard-coded dim 1
    # only agrees with topk for 2-D inputs
    probs.scatter_(-1, ind, val)
    return probs

# positional bias

# ALiBi positional bias
class Alibi(nn.Module):
    """Linear positional attention bias with one fixed slope per head.

    ``forward(i, j, device)`` returns a tensor of shape ``(heads, 1, j)`` whose entry
    ``[h, 0, k]`` is ``k * slope_h``. ``i`` is accepted for interface symmetry but is
    not used. The largest bias computed so far is cached in the non-persistent
    ``bias`` buffer and sliced for smaller ``j``.
    """

    def __init__(self, heads, **kwargs):
        super().__init__()
        self.heads = heads
        slope_list = self._get_slopes(heads)
        slopes = torch.Tensor(slope_list).view(len(slope_list), 1, 1)
        self.register_buffer('slopes', slopes, persistent = False)
        self.register_buffer('bias', None, persistent = False)

    @staticmethod
    def _get_slopes(heads):
        """Return ``heads`` slopes forming a geometric sequence (interleaved when not a power of 2)."""
        def slopes_for_power_of_two(n):
            first = 2 ** (-2 ** -(math.log2(n) - 3))
            return [first * first ** idx for idx in range(n)]

        if math.log2(heads).is_integer():
            return slopes_for_power_of_two(heads)

        # otherwise: slopes for the nearest lower power of two, topped up with every
        # other slope from the next power of two
        lower_pow2 = 2 ** math.floor(math.log2(heads))
        extra = slopes_for_power_of_two(2 * lower_pow2)[0::2][:heads - lower_pow2]
        return slopes_for_power_of_two(lower_pow2) + extra

    def forward(self, i, j, device):
        cached = self.bias
        if cached is not None and cached.shape[-1] >= j:
            return cached[..., :j]

        positions = torch.arange(j, device = device).view(1, 1, j)
        bias = positions * self.slopes

        # cache as a non-persistent buffer so it follows the module across devices
        self.register_buffer('bias', bias, persistent = False)
        return self.bias

# norm

# Root-mean-square layer normalization
class RMSNorm(nn.Module):
    """Normalize the last dimension by its RMS, then apply a learned per-feature gain.

    Computes ``x / max(||x||_2 * dim ** -0.5, eps) * g``.
    """

    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = x.norm(dim = -1, keepdim = True) * self.scale
        denom = rms.clamp(min = self.eps)
        return x / denom * self.g

# helper classes

# Pre-norm position-wise feed-forward block
def FeedForward(*, dim, mult = 4, dropout = 0.):
    """Build RMSNorm -> Linear(dim, dim * mult) -> GELU -> Dropout -> Linear(dim * mult, dim)."""
    hidden_dim = dim * mult
    layers = [
        RMSNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
    ]
    return nn.Sequential(*layers)

# Causal multi-query attention
class Attention(nn.Module):
    """Causal self-attention where ``heads`` query heads share one key/value head.

    The input is pre-normalized with RMSNorm. An optional ``attn_bias`` is added
    to the similarity logits before the causal mask is applied.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.dropout = nn.Dropout(dropout)
        self.norm = RMSNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # a single key/value head (dim_head each) shared by all query heads
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, attn_bias = None):
        x = self.norm(x)

        q = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h = self.heads)
        k, v = self.to_kv(x).chunk(2, dim = -1)

        scores = einsum('b h i d, b j d -> b h i j', q * self.scale, k)

        if attn_bias is not None:
            scores = scores + attn_bias

        i, j = scores.shape[-2:]
        # True above the offset diagonal: key positions later than the query
        future = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
        scores = scores.masked_fill(future, -torch.finfo(scores.dtype).max)

        # subtract the detached row max for numerical stability before softmax
        scores = scores - scores.amax(dim = -1, keepdim = True).detach()
        attn = self.dropout(scores.softmax(dim = -1))

        out = einsum('b h i j, b j d -> b h i d', attn, v)
        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))
# Stack of pre-norm attention / feed-forward residual layers
class Transformer(nn.Module):
    """Residual (attention, feed-forward) layers followed by a final RMSNorm.

    When ``rel_pos_bias`` is True, one ALiBi bias is computed per forward pass and
    shared by every attention layer.
    """

    def __init__(
        self,
        *,
        dim,
        layers,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_mult = 4,
        rel_pos_bias = True
    ):
        super().__init__()
        self.alibi = Alibi(heads = heads) if rel_pos_bias else None

        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ])
            for _ in range(layers)
        ])

        self.norm = RMSNorm(dim)

    def forward(self, x):
        seq_len = x.shape[-2]

        attn_bias = None
        if self.alibi is not None:
            attn_bias = self.alibi(seq_len, seq_len, device = x.device)

        for attn, ff in self.layers:
            x = x + attn(x, attn_bias = attn_bias)
            x = x + ff(x)

        return self.norm(x)

# 主类
class HierarchicalCausalTransformer(nn.Module):
    """Autoregressive transformer over a token grid, with one transformer per axis ("stage").

    ``depth`` and ``max_seq_len`` are equal-length tuples with one entry per stage:
    ``depth[s]`` is the layer count of stage ``s``'s transformer and ``max_seq_len[s]``
    its sequence length. The first length is an upper bound; every later length must
    be matched exactly by the input (enforced in ``forward``).
    """
    def __init__(
        self,
        *,
        num_tokens,  # vocabulary size
        dim,  # model / embedding dimension
        depth,  # tuple: number of layers per stage
        max_seq_len,  # tuple: sequence length per stage (first entry is a maximum)
        dim_head = 64,  # dimension per attention head
        heads = 8,  # number of attention heads
        attn_dropout = 0.,  # dropout on attention weights
        ff_mult = 4,  # feed-forward hidden-size multiplier
        ff_dropout = 0.,  # dropout inside the feed-forward block
        pad_id = 0,  # padding token id; also ignored by the loss
        rel_pos_bias = True  # passed to each stage's Transformer (enables its Alibi bias)
    ):
        super().__init__()

        # Per-stage configuration, for example:
        # depth = (2, 2, 4) -> stage depths 2, 2 and 4
        # max_seq_len = (16, 8, 4) -> stage lengths 16, 8 and 4

        assert isinstance(depth, tuple) and isinstance(max_seq_len, tuple)
        assert len(depth) == len(max_seq_len)

        # one stage per entry of the depth tuple
        self.stages = len(depth)

        self.token_emb = nn.Embedding(num_tokens, dim)
        # learned start token, prepended to the coarsest stage's sequence
        self.start_tokens = nn.Parameter(torch.randn(dim))

        # one absolute positional embedding table per stage
        self.max_seq_len = max_seq_len
        self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, dim) for seq_len in max_seq_len])

        # one transformer per stage
        self.transformers = nn.ModuleList([])

        for stage_depth in depth:
            self.transformers.append(Transformer(
                dim = dim,
                layers = stage_depth,
                dim_head = dim_head,
                heads = heads,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                ff_mult = ff_mult,
                rel_pos_bias = rel_pos_bias
            ))

        # shared output projection to vocabulary logits
        self.to_logits = nn.Linear(dim, num_tokens)
        self.pad_id = pad_id

    def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1):
        """Sample tokens autoregressively until the grid holds ``prod(max_seq_len)`` tokens.

        Args:
            prime: optional flattened ``(batch, n)`` LongTensor of known tokens; when
                omitted, ``default_batch_size`` sequences are sampled from scratch.
            filter_thres: passed to ``top_k``; the fraction of logits discarded per step.
            temperature: Gumbel sampling temperature.

        Returns:
            The sampled tokens reshaped to ``(batch, *max_seq_len)``.
        """
        total_seq_len = reduce_mult(self.max_seq_len)
        device = next(self.parameters()).device

        if not exists(prime):
            prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device)

        seq = prime

        for _ in range(total_seq_len - seq.shape[-1]):
            # logits at the last position are used to sample the next token
            logits = self.forward(seq)[:, -1]
            logits = top_k(logits, thres = filter_thres)
            sampled = gumbel_sample(logits, dim = -1, temperature = temperature)
            seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1)

        return rearrange_with_anon_dims(seq, 'b (...d) -> b ...d', d = self.max_seq_len)

    def forward_empty(self, batch_size):
        """Logits when there is no input yet: only the start token is fed through.

        Returns a tensor of shape ``(batch_size, 1, num_tokens)``.
        """
        # special case: sampling from an input of length 0 (start token only)

        tokens = repeat(self.start_tokens, 'd -> b 1 d', b = batch_size)

        # pass the lone start token through every stage transformer in turn
        for transformer in self.transformers:
            tokens = transformer(tokens)

        return self.to_logits(tokens)
    def forward(self, ids, return_loss = False):
        """Compute next-token logits, or the autoregressive cross-entropy loss.

        Args:
            ids: LongTensor, either flattened ``(batch, n)`` or a grid
                ``(batch, n_0, ..., n_{stages-1})``. Flattened input is padded with
                ``pad_id`` to a multiple of ``prod(max_seq_len[1:])`` and reshaped to a grid.
            return_loss: when True, return the scalar loss instead of logits.

        Returns:
            Logits (for flattened input: ``(batch, n, num_tokens)``), or the loss.
            Empty input is delegated to ``forward_empty``.
        """
        assert ids.ndim in {2, self.stages + 1}
        flattened_dims = ids.ndim == 2
        # NOTE(review): ids_orig_ndim is never used below
        ids_orig_ndim = ids.ndim

        if ids.numel() == 0:
            return self.forward_empty(ids.shape[0])

        # flattened input: pad to a whole number of finest-stage groups, then unflatten
        if flattened_dims:
            seq_len = ids.shape[-1]
            multiple_of = reduce_mult(self.max_seq_len[1:])
            padding = remainder_to_mult(seq_len, multiple_of)
            ids = F.pad(ids, (0, padding), value = self.pad_id)
            ids = rearrange_with_anon_dims(ids, 'b (l ...d) -> b l ...d', d = self.max_seq_len[1:])

        # prec_dims are the grid sizes, one per stage
        b, *prec_dims, device = *ids.shape, ids.device

        # validate the grid against max_seq_len

        assert prec_dims[0] <= self.max_seq_len[0], 'the first dimension of your axial autoregressive transformer must be less than the first tuple element of max_seq_len (like any autoregressive transformer)'
        assert tuple(prec_dims[1:]) == tuple(self.max_seq_len[1:]), 'all subsequent dimensions must match exactly'

        # token embeddings

        tokens = self.token_emb(ids)

        # Build per-stage inputs, finest stage first: each later iteration sums the
        # innermost token axis away, and each stage gets its own absolute positional
        # embedding. insert(0, ...) leaves tokens_at_stages ordered coarsest first.

        tokens_at_stages = []
        reduced_tokens = tokens

        for ind, pos_emb in zip(range(len(prec_dims)), reversed(self.pos_embs)):
            is_first = ind == 0

            if not is_first:
                reduced_tokens = reduce(reduced_tokens, 'b ... r d -> b ... d', 'sum')

            positions = pos_emb(torch.arange(reduced_tokens.shape[-2], device = device))
            tokens_with_position = reduced_tokens + positions
            tokens_at_stages.insert(0, tokens_with_position)

        # the learned start token begins the coarsest stage's sequence

        start_tokens = repeat(self.start_tokens, 'f -> b 1 f', b = b)

        # Run the stages coarse to fine. Each stage's outputs (except the last position)
        # become the per-group start tokens of the next, finer stage.

        for ind, (stage_tokens, transformer) in enumerate(zip(tokens_at_stages, self.transformers)):
            # NOTE(review): is_last is computed but never used
            is_last = ind == (self.stages - 1)

            stage_tokens = torch.cat((
                start_tokens,
                stage_tokens,
            ), dim = -2)

            # remember the leading (batch + group) dims so they can be restored
            *prec_dims, _, _ = stage_tokens.shape

            stage_tokens = rearrange(stage_tokens, '... n d -> (...) n d')
            attended = transformer(stage_tokens)
            attended = rearrange_with_anon_dims(attended, '(...b) n d -> ...b n d', b = prec_dims)

            start_tokens = rearrange(attended[..., :-1, :], '... n d -> ... n 1 d')

        logits = self.to_logits(attended)

        # drop the outputs at the start-token positions of the finest stage
        logits = logits[..., 1:, :]

        # inference: return logits (flattened input gets flattened, unpadded logits)

        if not return_loss:

            if flattened_dims:
                logits = rearrange(logits, 'b ... n -> b (...) n')
                logits = logits[:, :seq_len]

            return logits

        preds = rearrange(logits, 'b ... c -> b c (...)')
        labels = rearrange(ids, 'b ... -> b (...)')

        # flattened logit t is trained against flattened token t + 1; pad tokens are ignored
        loss = F.cross_entropy(
            preds[..., :-1],
            labels[..., 1:],
            ignore_index = self.pad_id
        )
        return loss

.\lucidrains\RQ-Transformer\rq_transformer\rq_transformer.py

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops_exts import rearrange_with_anon_dims
from einops import rearrange, reduce, repeat

# helpers

# True when the value is not None
def exists(val):
    """Return True if ``val`` is not None."""
    return not (val is None)

# Value, or the fallback when the value is None
def default(val, d):
    """Fall back to ``d`` only when ``val`` is None."""
    return d if val is None else val

# Padding needed to reach the next multiple
def remainder_to_mult(num, mult):
    """Return the amount to add to ``num`` so it becomes a multiple of ``mult`` (0 if it already is)."""
    leftover = num % mult
    return (mult - leftover) % mult

# Logarithm with values clamped away from zero
def log(t, eps = 1e-20):
    """Natural log of ``t``, with values below ``eps`` first raised to ``eps``."""
    return t.clamp(min = eps).log()

# Gumbel noise shaped like the input
def gumbel_noise(t):
    """Return ``-log(-log(U))`` for ``U`` uniform on [0, 1), with the shape/dtype/device of ``t``."""
    u = torch.zeros_like(t).uniform_(0, 1)
    inner = -log(u)
    return -log(inner)

# Gumbel-max sampling
def gumbel_sample(t, temperature = 1., dim = -1):
    """Pick an index along ``dim`` by adding Gumbel noise to ``t / temperature`` and taking argmax."""
    scaled = t / temperature
    noisy = scaled + gumbel_noise(t)
    return noisy.argmax(dim = dim)

# Keep the k largest logits along the last dimension, set the rest to -inf
def top_k(logits, thres = 0.5):
    """Return ``logits`` with all but the top-k entries of the last dimension set to -inf.

    Args:
        logits: tensor of any rank.
        thres: fraction of logits to discard; ``k = max(int((1 - thres) * n), 1)``.

    Returns:
        Tensor shaped like ``logits``.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k, dim = -1)
    probs = torch.full_like(logits, float('-inf'))
    # write the kept values back along the dimension topk searched (the last one);
    # scattering along a fixed dim 1 is only correct for 2-D logits
    probs.scatter_(-1, ind, val)
    return probs

# helper classes

# Pre-norm feed-forward block
def FeedForward(*, dim, mult = 4, dropout = 0.):
    """Build LayerNorm -> Linear(dim, dim * mult) -> GELU -> Dropout -> Linear(dim * mult, dim)."""
    inner = dim * mult
    norm = nn.LayerNorm(dim)
    project_in = nn.Linear(dim, inner)
    activation = nn.GELU()
    drop = nn.Dropout(dropout)
    project_out = nn.Linear(inner, dim)
    return nn.Sequential(norm, project_in, activation, drop, project_out)

# Causal multi-head self-attention
class Attention(nn.Module):
    """Causal multi-head self-attention with a pre-LayerNorm and a fused QKV projection.

    Each position attends only to itself and earlier positions.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        normed = self.norm(x)
        q, k, v = [
            rearrange(t, 'b n (h d) -> b h n d', h = self.heads)
            for t in self.to_qkv(normed).chunk(3, dim = -1)
        ]

        scores = einsum('b h i d, b h j d -> b h i j', q * self.scale, k)

        n_q, n_k = scores.shape[-2:]
        # mask out keys that lie in the future relative to each query
        causal_mask = torch.ones((n_q, n_k), dtype = torch.bool, device = x.device).triu(n_k - n_q + 1)
        scores = scores.masked_fill(causal_mask, -torch.finfo(scores.dtype).max)

        # stabilize softmax by subtracting the (detached) row maximum
        scores = scores - scores.amax(dim = -1, keepdim = True).detach()
        weights = self.dropout(scores.softmax(dim = -1))

        out = einsum('b h i j, b h j d -> b h i d', weights, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# Transformer stack
class Transformer(nn.Module):
    """Pre-norm residual (attention, feed-forward) layers followed by a final LayerNorm."""

    def __init__(
        self,
        *,
        dim,
        layers,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_mult = 4
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(layers):
            attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
            ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            self.layers.append(nn.ModuleList([attn, ff]))

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        for attn, ff in self.layers:
            x = x + attn(x)
            x = x + ff(x)

        return self.norm(x)

# 主类

class RQTransformer(nn.Module):
    """Two-level autoregressive transformer over a ``(spatial, depth)`` grid of codes.

    A spatial transformer runs over one embedding per spatial position (the sum of
    that position's code embeddings plus a spatial position embedding), prefixed by
    a learned start token. Its output at each spatial position is the start token of
    a depth transformer, which predicts that position's ``depth_seq_len`` codes in order.
    """

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        max_spatial_seq_len,
        depth_seq_len,
        spatial_layers,
        depth_layers,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        pad_id = 0
    ):
        super().__init__()
        self.dim = dim
        self.max_spatial_seq_len = max_spatial_seq_len
        self.depth_seq_len = depth_seq_len

        self.token_emb = nn.Embedding(num_tokens, dim)
        # learned start token prepended to the spatial sequence
        self.spatial_start_token = nn.Parameter(torch.randn(dim))

        # one extra spatial position: forward accepts up to max_spatial_seq_len + 1 positions
        self.spatial_pos_emb = nn.Embedding(max_spatial_seq_len + 1, dim)
        self.depth_pos_emb = nn.Embedding(depth_seq_len, dim)

        self.spatial_transformer = Transformer(
            dim = dim,
            layers = spatial_layers,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_mult = ff_mult
        )

        self.depth_transformer = Transformer(
            dim = dim,
            layers = depth_layers,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_mult = ff_mult
        )

        self.to_logits = nn.Linear(dim, num_tokens)
        self.pad_id = pad_id

    def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1):
        """Sample until ``max_spatial_seq_len * depth_seq_len`` tokens exist.

        Args:
            prime: optional flattened ``(batch, n)`` LongTensor of known tokens; when
                omitted, ``default_batch_size`` sequences are sampled from scratch.
            filter_thres: fraction of logits discarded by ``top_k`` at each step.
            temperature: Gumbel sampling temperature.

        Returns:
            Tokens of shape ``(batch, max_spatial_seq_len, depth_seq_len)``.
        """
        total_seq_len = self.depth_seq_len * self.max_spatial_seq_len
        device = next(self.parameters()).device

        if not exists(prime):
            prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device)

        seq = prime

        for _ in range(total_seq_len - seq.shape[-1]):
            # logits at the last position predict the next token
            logits = self.forward(seq)[:, -1]
            logits = top_k(logits, thres = filter_thres)
            sampled = gumbel_sample(logits, dim = -1, temperature = temperature)
            seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1)

        return rearrange(seq, 'b (s d) -> b s d', d = self.depth_seq_len)

    def forward_empty(self, batch_size):
        """Logits for the very first token: only the spatial start token is fed.

        Returns a tensor of shape ``(batch_size, 1, num_tokens)``.
        """
        spatial_tokens = repeat(self.spatial_start_token, 'd -> b 1 d', b = batch_size)
        depth_tokens = self.spatial_transformer(spatial_tokens)
        depth_tokens = self.depth_transformer(depth_tokens)
        return self.to_logits(depth_tokens)

    def forward(self, ids, return_loss = False):
        """Compute next-token logits, or the cross-entropy loss.

        Args:
            ids: LongTensor, either flattened ``(batch, seq)`` or ``(batch, space, depth)``.
                Flattened input is padded with ``pad_id`` to a multiple of ``depth_seq_len``.
            return_loss: when True, return the scalar loss instead of logits.

        Returns:
            Logits of shape ``(batch, seq_len, num_tokens)``, where ``seq_len`` is the
            number of (unpadded) input tokens and ``logits[:, t]`` predicts token ``t + 1``;
            or, with ``return_loss``, the cross-entropy of predicting every input token.
            Empty input is delegated to ``forward_empty``.
        """
        assert ids.ndim in {2, 3}
        flattened_dim = ids.ndim == 2

        if ids.numel() == 0:
            return self.forward_empty(ids.shape[0])

        if flattened_dim:
            # allow (batch, seq) input, padded up to a multiple of depth_seq_len
            seq_len = ids.shape[-1]
            padding = remainder_to_mult(seq_len, self.depth_seq_len)
            ids = F.pad(ids, (0, padding), value = self.pad_id)
            ids = rearrange(ids, 'b (s d) -> b s d', d = self.depth_seq_len)
        else:
            seq_len = ids.shape[1] * ids.shape[2]

        b, space, depth, device = *ids.shape, ids.device
        assert space <= (self.max_spatial_seq_len + 1), 'spatial dimension is greater than the max_spatial_seq_len set'
        assert depth == self.depth_seq_len, 'depth dimension must be equal to depth_seq_len'

        tokens = self.token_emb(ids)

        spatial_pos = self.spatial_pos_emb(torch.arange(space, device = device))
        depth_pos = self.depth_pos_emb(torch.arange(depth, device = device))

        tokens_with_depth_pos = tokens + depth_pos

        # one token per spatial position: its codes summed over depth, plus spatial position
        spatial_tokens = reduce(tokens_with_depth_pos, 'b s d f -> b s f', 'sum') + spatial_pos

        spatial_tokens = torch.cat((
            repeat(self.spatial_start_token, 'f -> b 1 f', b = b),
            spatial_tokens
        ), dim = -2)

        # (b, space + 1, f): output i has seen the start token and spatial positions < i
        spatial_tokens = self.spatial_transformer(spatial_tokens)
        spatial_tokens = rearrange(spatial_tokens, 'b s f -> b s 1 f')

        # Each spatial output starts a depth row. Append one all-zero spatial row so the
        # code rows line up with the space + 1 spatial outputs, and feed only the first
        # depth - 1 codes: the output after a row's final code predicts nothing. Keeping
        # it made each row depth + 1 long and shifted the flattened logits away from
        # the flattened targets whenever space > 1. Attention is causal, so dropping the
        # final input does not change the remaining outputs.
        tokens_with_depth_pos = F.pad(tokens_with_depth_pos, (0, 0, 0, 0, 0, 1), value = 0.)
        depth_tokens = torch.cat((spatial_tokens, tokens_with_depth_pos[:, :, :-1]), dim = -2)

        # (b, space + 1, depth, f): output [i, k] predicts code k of spatial position i
        depth_tokens = rearrange(depth_tokens, '... n d -> (...) n d')
        depth_tokens = self.depth_transformer(depth_tokens)
        depth_tokens = rearrange(depth_tokens, '(b s) d f -> b s d f', b = b)

        # flattened logit p predicts flattened token p; keep one extra for the next token
        logits = self.to_logits(depth_tokens)
        logits = rearrange(logits, 'b ... f -> b (...) f')
        logits = logits[:, :(seq_len + 1)]

        if not return_loss:
            # shift so logits[:, t] predicts token t + 1 (the last one is the next token)
            logits = logits[:, 1:]

            if flattened_dim:
                return rearrange(logits, 'b ... n -> b (...) n')

            return logits

        logits = logits[:, :-1]

        preds = rearrange(logits, 'b ... c -> b c (...)')
        labels = rearrange(ids, 'b s d -> b (s d)')
        # drop labels for the padding added above so they match the seq_len predictions
        labels = labels[:, :seq_len]

        loss = F.cross_entropy(preds, labels, ignore_index = self.pad_id)
        return loss

.\lucidrains\RQ-Transformer\rq_transformer\__init__.py

# 从 rq_transformer 模块中导入 RQTransformer 类
from rq_transformer.rq_transformer import RQTransformer
# 从 rq_transformer 模块中导入 HierarchicalCausalTransformer 类
from rq_transformer.hierarchical_causal_transformer import HierarchicalCausalTransformer

.\lucidrains\RQ-Transformer\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# Search keywords for the package index
_KEYWORDS = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention-mechanism',
    'autoregressive',
]

# Runtime dependencies
_REQUIREMENTS = [
    'einops>=0.4',
    'einops-exts',
    'torch>=1.6',
]

# Trove classifiers
_CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

# Package metadata; all discovered packages are included
setup(
    name = 'RQ-transformer',
    packages = find_packages(exclude = []),
    version = '0.1.9',
    license = 'MIT',
    description = 'RQ Transformer - Autoregressive Transformer for Residual Quantized Codes',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/RQ-transformer',
    keywords = _KEYWORDS,
    install_requires = _REQUIREMENTS,
    classifiers = _CLASSIFIERS,
)

.\lucidrains\RQ-Transformer\train.py

# 导入所需的库
from rq_transformer import HierarchicalCausalTransformer

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# Training hyperparameters
NUM_BATCHES = 100_000
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4

# How often (in batches) to validate and to print a sample
VALIDATE_EVERY = 100
GENERATE_EVERY = 500

# Sampling / sequence lengths (in bytes)
PRIME_LEN = 100
SEQ_LEN = 1024

# 辅助函数

# Token -> printable character
def decode_token(token):
    """Return the character for byte value ``token``; values below 32 become a space."""
    code_point = 32 if token < 32 else token
    return chr(code_point)

# Token sequence -> string
def decode_tokens(tokens):
    """Decode each token with ``decode_token`` and concatenate the characters."""
    return ''.join(decode_token(token) for token in tokens)

# GPT-like hierarchical decoder over raw bytes (256-token vocabulary).
# The product of max_seq_len (4 * 4 * 8 * 8) is 1024 tokens.
model = HierarchicalCausalTransformer(
    num_tokens = 256,
    dim = 512,
    depth = (4, 3, 3, 3),
    max_seq_len = (4, 4, 8, 8)
).cuda()

# enwik8: read the first 95M bytes; the first 90M are for training, the rest for validation
with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring is deprecated for binary data; frombuffer + copy yields a writable
    # array, which torch.from_numpy expects (it warns on read-only buffers)
    X = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# Random-window text dataset
class TextSamplerDataset(Dataset):
    """Dataset returning random fixed-length windows from a 1-D token tensor.

    Each item is a LongTensor of ``seq_len`` tokens starting at a random offset,
    moved to CUDA. The requested index is ignored. ``len()`` is the number of
    non-overlapping ``seq_len`` windows that fit in the data.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # the index is deliberately unused: every access draws a fresh random window
        max_start = self.data.size(0) - self.seq_len
        start = torch.randint(0, max_start, (1,))
        window = self.data[start: start + self.seq_len].long()
        return window.cuda()

    def __len__(self):
        total_tokens = self.data.size(0)
        return total_tokens // self.seq_len

def cycle(loader):
    """Yield batches from ``loader`` forever, restarting it whenever it is exhausted."""
    while True:
        for data in loader:
            yield data

# Training / validation datasets and infinite loaders
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# Optimizer (named `optimizer` so it does not shadow the imported `optim` module)
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# Training
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    # accumulate gradients over several batches before each optimizer step
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    optimizer.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime_inp = inp[:PRIME_LEN]
        prime = decode_tokens(prime_inp)
        # interpolate the prime and a separator line (a %-style string passed to print
        # alongside a tuple would be printed literally)
        print(f'{prime} \n\n {"*" * 100}')

        # sampling needs no autograd graph
        with torch.no_grad():
            sample = model.generate(prime_inp[None, :])
        sample = sample.flatten(1)

        output_str = decode_tokens(sample[0][PRIME_LEN:])
        print(output_str)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

RVQ-VAE-GPT - Residual Vector Quantize VAE - GPT (wip)

My attempts at applying Soundstream design on learned tokenization of text and then applying a hierarchical transformer to text generation.

The Soundstream will be modified to use all local attention. Experiments will compare VQ, RVQ, and also multi-headed VQ.

Was told by a researcher friend this will likely fail 😂😂 but I will try it anyway, yolo. In the case it does not work, maybe it can still be useful for genomics. Come to think of it, why shouldn't it be able to at least learn bigrams (for english) and codons (for genomics)? Why don't we have hierarchical predictive coding? We should.

Update: Some live experiments

Todo

Citations

@misc{https://doi.org/10.48550/arxiv.2107.03312,
  title  = {SoundStream: An End-to-End Neural Audio Codec},
  author = {Zeghidour, Neil and Luebs, Alejandro and Omran, Ahmed and Skoglund, Jan and Tagliasacchi, Marco},
  publisher = {arXiv},
  url    = {https://arxiv.org/abs/2107.03312},
  year   = {2021}
}
@unknown{unknown,
    author  = {Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
    year    = {2022},
    month   = {03},
    title   = {Autoregressive Image Generation using Residual Quantization}
}
@article{Sunkara2022NoMS,
    title   = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
    author  = {Raja Sunkara and Tie Luo},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.03641}
}

.\lucidrains\rvq-vae-gpt\rvq_vae_gpt\rvq_vae_gpt.py

# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块并重命名为 F
import torch.nn.functional as F
# 从 torch 模块中导入 nn、einsum
from torch import nn, einsum

# 从 einops 库中导入 rearrange、repeat、pack、unpack
from einops import rearrange, repeat, pack, unpack
# 从 einops.layers.torch 模块中导入 Rearrange
from einops.layers.torch import Rearrange

# 导入自定义的 local_attention 模块中的 LocalMHA 类
from local_attention import LocalMHA
# 导入自定义的 vector_quantize_pytorch 模块中的 VectorQuantize、ResidualVQ 类
from vector_quantize_pytorch import VectorQuantize, ResidualVQ

# 从 beartype 库中导入 beartype、Tuple、Optional、Union
from beartype import beartype
from beartype.typing import Tuple, Optional, Union

# 从 pathlib 模块中导入 Path 类
from pathlib import Path
# 导入 pickle 库
import pickle

# 辅助函数

# existence check
def exists(val):
    """Return True for any value other than None (falsy values such as 0 count as existing)."""
    return not (val is None)

# leading element of a sequence
def first(it):
    """Return the item at index 0; `it` must be indexable (e.g. a tuple), not a bare iterator."""
    leading_index = 0
    return it[leading_index]

# first value that exists
def default(*vals):
    """Return the first argument that is not None, or None when every argument is None."""
    return next((candidate for candidate in vals if exists(candidate)), None)

# exact-divisibility check
def divisible_by(numer, denom):
    """Return True when `numer` is an exact multiple of `denom`."""
    remainder = numer % denom
    return remainder == 0

# normalise a scalar-or-tuple argument to a tuple
def cast_tuple(t, len = 1):
    """Return `t` unchanged if it is a tuple; otherwise a tuple holding `t` repeated `len` times."""
    if isinstance(t, tuple):
        return t
    return (t,) * len

# token shift - RWKV 中使用

# split the feature dim in half and shift the second half one step along the sequence
def shift_tokens(t):
    """
    Token shift: keep the first half of the last dimension as is, and delay the
    second half by one position along dim -2 (zero-filled at the front, last
    position dropped). Sequence length and feature width are unchanged.
    """
    kept, delayed = t.chunk(2, dim = -1)
    # pad spec (0, 0, 1, -1): last dim untouched; on dim -2 prepend one zero
    # row and crop one row from the end (negative padding trims)
    return torch.cat((kept, F.pad(delayed, (0, 0, 1, -1), value = 0.)), dim = -1)

# 前馈网络

# GEGLU activation
class GEGLU(nn.Module):
    """Gated GELU: split the last dimension in half and gate the first half with GELU of the second."""

    def forward(self, x):
        value, gate = x.chunk(2, dim = -1)
        activated_gate = F.gelu(gate)
        return value * activated_gate

# pre-norm GEGLU feed-forward block
def FeedForward(dim, mult = 4):
    """
    Build LayerNorm -> Linear(dim, 2 * inner) -> GEGLU -> Linear(inner, dim),
    where inner = int(dim * mult * 2 / 3). The 2/3 factor presumably offsets
    the doubled projection that GEGLU consumes.
    """
    inner = int(dim * mult * 2 / 3)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, inner * 2),   # value and gate halves for GEGLU
        GEGLU(),
        nn.Linear(inner, dim),
    ]
    return nn.Sequential(*layers)

# 最佳的上采样和下采样方式

# upsampling module
class Upsample(nn.Module):
    """
    Upsample along the sequence axis by `factor`.

    A linear layer widens every token to `dim_out * factor` features, SiLU is
    applied, and the result is unfolded into `factor` consecutive tokens:
    (b, n, dim) -> (b, n * factor, dim_out).
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        factor = 2
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        linear = nn.Linear(dim, dim_out * factor)

        self.net = nn.Sequential(
            linear,
            nn.SiLU(),
            Rearrange('b n (p d) -> b (n p) d', p = factor)
        )

        self.factor = factor
        self.init_(linear)

    def init_(self, linear):
        """
        Initialise `linear` so that all `factor` upsampled positions produced
        from one token start out identical, and zero its bias.
        """
        o, i = linear.weight.shape

        base_weight = torch.empty(o // self.factor, i)
        nn.init.kaiming_uniform_(base_weight)

        # Tile the whole block `factor` times along the output dim (row layout
        # (r o)). This matches the '(p d)' split in the Rearrange above, so
        # every output group p uses exactly `base_weight`. The former
        # 'o -> (o r)' interleave did not line up with '(p d)'.
        tiled_weight = base_weight.repeat(self.factor, 1)

        # Write into the layer itself. The previous code copied the tensor onto
        # itself, so this initialisation was silently discarded.
        linear.weight.data.copy_(tiled_weight)
        nn.init.zeros_(linear.bias.data)

    def forward(self, x):
        return self.net(x)

# downsampling module
def Downsample(
    dim,
    dim_out = None,
    factor = 2
):
    """
    Downsample along the sequence axis by `factor`: fold each run of `factor`
    consecutive tokens into one wider token, then project it to `dim_out`
    (defaults to `dim`). (b, n * factor, dim) -> (b, n, dim_out)
    """
    out_features = default(dim_out, dim)
    fold_tokens = Rearrange('b (n p) d -> b n (p d)', p = factor)
    project = nn.Linear(dim * factor, out_features)
    return nn.Sequential(fold_tokens, project)

# 本地注意力

# local-attention transformer
class LocalTransformer(nn.Module):
    """
    `depth` residual layers. Each layer is a causal LocalMHA followed by a
    FeedForward, and both sub-layers receive token-shifted inputs.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        dim_head,
        window_size
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            attention = LocalMHA(
                dim = dim,
                heads = heads,
                dim_head = dim_head,
                qk_rmsnorm = True,
                window_size = window_size,
                use_rotary_pos_emb = True,
                use_xpos = True,
                causal = True
            )
            feedforward = FeedForward(dim = dim)
            self.layers.append(nn.ModuleList([attention, feedforward]))

    def forward(self, x):
        for attention, feedforward in self.layers:
            x = x + attention(shift_tokens(x))
            x = x + feedforward(shift_tokens(x))
        return x

# 模块

# text VQ-VAE model
@beartype
class TextVQVAE(nn.Module): # or genomics, eventually, with num_tokens set to 4
    """
    Residual-VQ autoencoder over token sequences.

    Tokens are embedded and passed through a local-attention transformer, then
    through one (Downsample, LocalTransformer) stage per entry of `strides`.
    The encoding is quantized by a ResidualVQ and decoded by mirrored
    (LocalTransformer, Upsample) stages into per-token logits.
    """

    def __init__(
        self,
        *,
        num_tokens,
        dim: Union[int, Tuple[int, ...]],
        depth: Union[int, Tuple[int, ...]],
        strides: Union[int, Tuple[int, ...]],
        codebook_size = 1024,
        local_attn_window_size = 32,
        local_attn_heads = 8,
        local_attn_dim_head = 64,
        num_codebooks = 4,
        vq_decay = 0.9,
        rvq_quantize_dropout = True
    ):
        # A single keyword-only signature with defaults, matching how the
        # training script constructs the model. A second, unterminated
        # positional signature had been spliced in here, which did not parse.
        super().__init__()

        # Snapshot the constructor arguments so save()/init_and_load() can
        # rebuild the model. This must run before any other local is bound.
        config = locals()
        config.pop('self')
        config.pop('__class__')   # present because super() creates a __class__ cell
        self._config = config

        assert 0 < vq_decay <= 1.

        strides = cast_tuple(strides)
        num_layers = len(strides)

        # broadcast per-stage hyperparameters to one entry per stride
        dim = cast_tuple(dim, num_layers)
        depth = cast_tuple(depth, num_layers)
        local_attn_window_size = cast_tuple(local_attn_window_size, num_layers)

        assert num_layers == len(depth) == len(local_attn_window_size) == len(dim)

        init_dim, vq_dim = dim[0], dim[-1]

        # (dim_in, dim_out) per stage; the first stage keeps the embedding width
        dims = [first(dim), *dim]
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        self.token_emb = nn.Embedding(num_tokens, init_dim)

        # product of all strides: sequence lengths must be divisible by this
        self.total_strides = torch.tensor(list(strides)).cumprod(dim = -1)[-1].item()

        self.encoder = nn.ModuleList([])

        layer_params = tuple(zip(
            strides,
            depth,
            local_attn_window_size,
            dim_pairs
        ))

        self.init_transformer = LocalTransformer(
            dim = init_dim,
            depth = first(depth),
            heads = local_attn_heads,
            dim_head = local_attn_dim_head,
            window_size = first(local_attn_window_size)
        )

        self.final_transformer = LocalTransformer(
            dim = init_dim,
            depth = first(depth),
            heads = local_attn_heads,
            dim_head = local_attn_dim_head,
            window_size = first(local_attn_window_size)
        )

        for layer_stride, layer_depth, layer_local_attn_window_size, (dim_in, dim_out) in layer_params:
            self.encoder.append(nn.ModuleList([
                Downsample(dim = dim_in, dim_out = dim_out, factor = layer_stride),
                LocalTransformer(
                    dim = dim_out,
                    depth = layer_depth,
                    heads = local_attn_heads,
                    dim_head = local_attn_dim_head,
                    window_size = layer_local_attn_window_size
                )
            ]))

        self.encoder_norm = nn.LayerNorm(vq_dim)

        self.vq = ResidualVQ(
            dim = vq_dim,
            num_quantizers = num_codebooks,
            codebook_size = codebook_size,
            decay = vq_decay,
            quantize_dropout = num_codebooks > 1 and rvq_quantize_dropout,
            commitment_weight = 0.,   # the weight on the commitment loss
            kmeans_init = True,
            kmeans_iters = 10
        )

        # decoder mirrors the encoder: attend at dim_out, then upsample to dim_in
        self.decoder = nn.ModuleList([])

        for layer_stride, layer_depth, layer_local_attn_window_size, (dim_in, dim_out) in reversed(layer_params):
            self.decoder.append(nn.ModuleList([
                Upsample(dim = dim_out, dim_out = dim_in, factor = layer_stride),
                LocalTransformer(
                    dim = dim_out,
                    depth = layer_depth,
                    heads = local_attn_heads,
                    dim_head = local_attn_dim_head,
                    window_size = layer_local_attn_window_size
                )
            ]))

        self.to_logits = nn.Sequential(
            nn.LayerNorm(init_dim),
            nn.Linear(init_dim, num_tokens)
        )

    def save(self, path):
        """Save the state dict together with the pickled constructor config."""
        path = Path(path)
        pkg = dict(
            model = self.state_dict(),
            config = pickle.dumps(self._config)
        )
        torch.save(pkg, str(path))

    def load(self, path):
        """Load weights written by `save` into this model."""
        path = Path(path)
        assert path.exists()
        # Map to CPU first. load_state_dict copies into each parameter's own
        # device, so GPU-saved checkpoints also load on CPU-only machines.
        # NOTE: torch.load unpickles data, so only load trusted checkpoints.
        pkg = torch.load(str(path), map_location = 'cpu')
        self.load_state_dict(pkg['model'])

    @classmethod
    def init_and_load(cls, path):
        """Rebuild a model from the config stored by `save`, then load its weights."""
        path = Path(path)
        assert path.exists()
        # NOTE: pickle.loads executes arbitrary code, so only use trusted files.
        pkg = torch.load(str(path), map_location = 'cpu')
        model = cls(**pickle.loads(pkg['config']))
        model.load(path)
        return model

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

    def encode(self, ids):
        """Embed token ids and run the downsampling encoder; returns normalised latents."""
        tokens = self.token_emb(ids)
        tokens = self.init_transformer(tokens)

        for downsample, local_attn in self.encoder:
            tokens = downsample(tokens)
            tokens = local_attn(tokens)

        return self.encoder_norm(tokens)

    def decode(self, codes):
        """Run quantized latents through the upsampling decoder and return token logits."""
        tokens = codes

        for upsample, local_attn in self.decoder:
            tokens = local_attn(tokens)
            tokens = upsample(tokens)

        tokens = self.final_transformer(tokens)

        logits = self.to_logits(tokens)
        return logits

    @torch.no_grad()
    def decode_from_codebook_ids(self, codebook_ids):
        """Look up codes for the given codebook indices and decode them to logits."""
        codes = self.vq.get_codes_from_indices(codebook_ids)
        return self.decode(codes)

    def forward(
        self,
        ids,
        return_codebook_indices = False,
        return_reconstruction = False,
        return_loss_breakdown = False
    ):
        """
        Encode, quantize and decode `ids` of shape (batch, seq). `seq` must be
        a multiple of the total stride.

        Returns the codebook indices when `return_codebook_indices` is set.
        Otherwise returns the cross-entropy reconstruction loss, plus the
        argmax reconstruction when `return_reconstruction` is set.
        `return_loss_breakdown` is accepted but currently unused.
        """
        _, seq = ids.shape
        assert divisible_by(seq, self.total_strides)

        ids = ids.to(self.device)

        tokens = self.encode(ids)

        # quantized latents, codebook indices, commitment loss (unused here)
        tokens, indices, _ = self.vq(tokens)

        if return_codebook_indices:
            return indices

        logits = self.decode(tokens)

        # cross_entropy expects the class dimension second
        logits = rearrange(logits, 'b n c -> b c n')

        loss = F.cross_entropy(
            logits,
            ids
        )

        if return_reconstruction:
            return loss, logits.argmax(dim = 1)

        return loss
# hierarchical transformer (not implemented yet)
class Transformer(nn.Module):
    """Placeholder for the hierarchical transformer; it currently defines no layers or forward."""

.\lucidrains\rvq-vae-gpt\rvq_vae_gpt\__init__.py

# 从 rvq_vae_gpt.rvq_vae_gpt 模块中导入 TextVQVAE 和 Transformer 类
from rvq_vae_gpt.rvq_vae_gpt import TextVQVAE, Transformer

.\lucidrains\rvq-vae-gpt\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# package metadata for rvq-vae-gpt
setup(
  # identity
  name = 'rvq-vae-gpt',
  version = '0.0.4',
  description = 'Yet another attempt at GPT in quantized latent space',
  long_description_content_type = 'text/markdown',
  license='MIT',
  # authorship / location
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/rvq-vae-gpt',
  # contents
  packages = find_packages(exclude=[]),
  install_requires=[
    'beartype',
    'einops>=0.4',
    'local-attention>=1.0.0',
    'torch>=1.6',
    'vector-quantize-pytorch>=1.1.2'
  ],
  # discovery
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\rvq-vae-gpt\train.py

# 导入所需的库
import gzip
import random
import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# 导入自定义模块
from rvq_vae_gpt import TextVQVAE

# training hyperparameters
NUM_BATCHES = 100_000
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
SAVE_EVERY = 1_000
SEQ_LEN = 2_048

# 定义辅助函数
def cycle(loader):
    """Yield items from `loader` forever, re-iterating it from the start each time it is exhausted."""
    while True:
        yield from loader

def first(it):
    """Return the element at index 0 of an indexable batch or sequence."""
    head = it[0]
    return head

def decode_token(token):
    """Convert a byte value to a character; codes below 32 (control characters) become a space."""
    code_point = max(32, token)
    return chr(code_point)

def decode_tokens(tokens):
    """Decode an iterable of byte values into one printable string."""
    return "".join(decode_token(token) for token in tokens)

# Build the text VQ-VAE over raw bytes (256 tokens): three stages of widths
# 128 -> 256 -> 512, each with stride 2, quantized by 8 residual codebooks.
# The model is moved to the GPU, so this script requires CUDA.
model = TextVQVAE(
    num_tokens = 256,    
    dim = (128, 256, 512),
    depth = (2, 2, 4),
    local_attn_window_size = 64,
    num_codebooks = 8,
    strides = (2, 2, 2)
).cuda()

# Load enwik8: read (up to) the first 95e6 bytes as uint8 tokens, then use the
# first 90e6 for training and the remainder for validation.
with gzip.open("./data/enwik8.gz") as file:
    # .copy() gives a writable array (np.frombuffer over bytes is read-only)
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# random-window dataset over a flat token tensor
class TextSamplerDataset(Dataset):
    """
    Dataset whose items are random windows of `seq_len` tokens from `data`.

    The index passed to __getitem__ is ignored. Each access draws a fresh
    random start, and the window is returned as int64 on the GPU.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data, self.seq_len = data, seq_len

    def __getitem__(self, index):
        max_start = self.data.size(0) - self.seq_len
        start = torch.randint(0, max_start, (1,))
        window = self.data[start : start + self.seq_len].long()
        return window.cuda()

    def __len__(self):
        total_tokens = self.data.size(0)
        return total_tokens // self.seq_len

# endless train / validation loaders over random windows
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# training
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        # divide so the accumulated gradient is the mean over micro-batches
        # rather than a sum that grows with GRADIENT_ACCUMULATE_EVERY
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print(f"training loss: {loss.item():.3f}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    # skip validation / checkpointing on the very first step
    if i == 0:
        continue

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            valid_text = next(val_loader)
            loss, recon = model(valid_text, return_reconstruction = True)

            print(f"validation loss: {loss.item():.3f}")

            print(f"\n\n\n[input text]\n\n {decode_tokens(first(valid_text))}")
            print(f"\n\n[reconstructed text]\n\n {decode_tokens(first(recon))}\n\n")

    if i % SAVE_EVERY == 0:
        model.save('./text-vae.pt')

SAC (Soft Actor Critic) - Pytorch (wip)

Implementation of Soft Actor Critic and some of its improvements in Pytorch. Interest comes from watching this lecture.

Temporary Discord

Citations

@article{Haarnoja2018SoftAA,
    title   = {Soft Actor-Critic Algorithms and Applications},
    author  = {Tuomas Haarnoja and Aurick Zhou and Kristian Hartikainen and G. Tucker and Sehoon Ha and Jie Tan and Vikash Kumar and Henry Zhu and Abhishek Gupta and P. Abbeel and Sergey Levine},
    journal = {ArXiv},
    year    = {2018},
    volume  = {abs/1812.05905},
    url     = {https://api.semanticscholar.org/CorpusID:55703664}
}
@article{Hiraoka2021DropoutQF,
    title   = {Dropout Q-Functions for Doubly Efficient Reinforcement Learning},
    author  = {Takuya Hiraoka and Takahisa Imagawa and Taisei Hashimoto and Takashi Onishi and Yoshimasa Tsuruoka},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.02034},
    url     = {https://api.semanticscholar.org/CorpusID:238353966}
}
@inproceedings{ObandoCeron2024MixturesOE,
    title   = {Mixtures of Experts Unlock Parameter Scaling for Deep RL},
    author  = {Johan S. Obando-Ceron and Ghada Sokar and Timon Willi and Clare Lyle and Jesse Farebrother and Jakob Foerster and Gintare Karolina Dziugaite and Doina Precup and Pablo Samuel Castro},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:267637059}
}

.\lucidrains\SAC-pytorch\SAC_pytorch\SAC.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn, einsum, Tensor
from torch import nn, einsum, Tensor
# 从 torch.nn 库中导入 Module, ModuleList
from torch.nn import Module, ModuleList

# 导入 beartype 库
from beartype import beartype
# 从 beartype.typing 中导入 Tuple, List, Optional, Union
from beartype.typing import Tuple, List, Optional, Union

# 导入 einx 库中的 get_at 函数
from einx import get_at
# 导入 einops 库中的 rearrange, repeat, reduce, pack, unpack 函数
from einops import rearrange, repeat, reduce, pack, unpack

# 导入 ema_pytorch 库中的 EMA 类
from ema_pytorch import EMA

# existence check
def exists(v):
    """True for every value except None."""
    return not (v is None)

# normalise a scalar-or-tuple argument to a tuple
def cast_tuple(t, length = 1):
    """Return tuples unchanged; wrap any other value in a tuple repeated `length` times."""
    if not isinstance(t, tuple):
        t = (t,) * length
    return t

# simple multilayer perceptron factory
@beartype
def MLP(
    dim,
    dim_out,
    dim_hiddens: Union[int, Tuple[int, ...]],
    layernorm = False,
    dropout = 0.,
    activation = nn.ReLU
):
    """
    simple mlp for Q and value networks

    following Figure 1 in https://arxiv.org/pdf/2110.02034.pdf for placement of dropouts and layernorm
    however, be aware that Levine in his lecture has ablations that show layernorm alone (without dropout) is sufficient for regularization

    Each hidden layer is Linear -> Dropout -> (LayerNorm if `layernorm`) ->
    activation, followed by a final Linear to `dim_out`.
    """
    widths = (dim, *cast_tuple(dim_hiddens))

    layers = []
    for width_in, width_out in zip(widths[:-1], widths[1:]):
        layers.append(nn.Linear(width_in, width_out))
        layers.append(nn.Dropout(dropout))

        if layernorm:
            layers.append(nn.LayerNorm(width_out))

        layers.append(activation())

    # final projection; with no hidden layers this maps dim -> dim_out directly
    layers.append(nn.Linear(widths[-1], dim_out))

    return nn.Sequential(*layers)

# actor (policy) network
class Actor(Module):
    """
    Policy network for continuous actions: an MLP maps a state to a mean and
    a scale for every action dimension.
    """

    def __init__(
        self,
        *,
        dim_state,
        num_cont_actions,
        dim_hiddens: Tuple[int, ...] = tuple(),
        eps = 1e-5
    ):
        super().__init__()
        self.eps = eps

        # two outputs per action (mu and sigma), interleaved as '(n ms)'
        self.to_cont_actions = MLP(
            dim_state,
            dim_out = num_cont_actions * 2,
            dim_hiddens = dim_hiddens
        )

    def forward(
        self,
        state,
        sample = False
    ):
        """
        einops notation
        n - num actions
        ms - mu sigma

        Returns (mu, sigma), each of shape (..., n). When `sample` is True it
        instead returns one draw mu + sigma * N(0, 1).
        """
        raw = self.to_cont_actions(state)
        mu, sigma = rearrange(raw, '... (n ms) -> ms ... n', ms = 2)

        # squash to (0, 1) and floor at eps so the scale never reaches zero
        sigma = sigma.sigmoid().clamp(min = self.eps)

        if sample:
            return mu + sigma * torch.randn_like(sigma)

        return mu, sigma

# critic (Q) network
class Critic(Module):
    """Q-network: scores concatenated (state, action) features with an MLP, giving one value per batch item."""

    @beartype
    def __init__(
        self,
        *,
        dim_state,
        num_continuous_actions,
        dim_hiddens: Tuple[int, ...] = tuple(),
        layernorm = False,
        dropout = 0.
    ):
        super().__init__()

        self.to_q = MLP(
            dim_state + num_continuous_actions,
            dim_out = 1,
            dim_hiddens = dim_hiddens,
            layernorm = layernorm,
            dropout = dropout
        )

    def forward(
        self,
        state,
        actions
    ):
        # join state and action features per batch item (pack flattens trailing dims)
        state_actions, _ = pack([state, actions], 'b *')

        q_values = self.to_q(state_actions)
        # squeeze the singleton output dim; the tensor argument was missing
        # before, so this line always raised at runtime
        q_values = rearrange(q_values, 'b 1 -> b')

        return q_values

# state-value network
class ValueNetwork(Module):
    """V(s): an MLP mapping each state in the batch to a single scalar."""

    @beartype
    def __init__(
        self,
        *,
        dim_state,
        dim_hiddens: Tuple[int, ...] = tuple()
    ):
        super().__init__()

        self.to_values = MLP(dim_state, dim_out = 1, dim_hiddens = dim_hiddens)

    def forward(
        self,
        states
    ):
        # (b, 1) -> (b,)
        return rearrange(self.to_values(states), 'b 1 -> b')

# Soft Actor-Critic (work in progress)
class SAC(Module):
    """Placeholder for Soft Actor-Critic; forward currently returns its input unchanged."""

    def __init__(
        self
    ):
        super().__init__()

    def forward(self, x):
        # identity pass-through until the algorithm is implemented
        output = x
        return output

.\lucidrains\SAC-pytorch\SAC_pytorch\__init__.py

# 从SAC_pytorch包中导入SAC类
from SAC_pytorch.SAC import SAC

.\lucidrains\SAC-pytorch\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# package metadata for SAC-pytorch
setup(
  # identity
  name = 'SAC-pytorch',
  version = '0.0.1',
  description = 'Soft Actor Critic',
  long_description_content_type = 'text/markdown',
  license='MIT',
  # authorship / location
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/SAC-pytorch',
  # contents: every package found, nothing excluded
  packages = find_packages(exclude=[]),
  install_requires=[
    'beartype',
    'einops>=0.7.0',
    'einx[torch]>=0.1.3',
    'ema-pytorch',
    'pytorch-custom-utils>=0.0.18',
    'soft-moe-pytorch>=0.1.6',
    'torch>=2.0'
  ],
  # discovery
  keywords = [
    'artificial intelligence',
    'deep learning',
    'reinforcement learning',
    'soft actor critic'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Scattering Compositional Learner

Implementation of Scattering Compositional Learner, which reached superhuman levels on Raven's Progressive Matrices, a type of IQ test for analogical reasoning.

This repository is meant to be exploratory, so it may not follow the exact architecture of the paper down to a T. It is meant to find the underlying inductive bias that could be exported for use in attention networks. The paper suggests this to be the 'Scattering Transform', which is basically grouped convolutions but where each group is transformed by one shared neural network.

If you would like the exact architecture used in the paper, the official repository is here.

Install

$ pip install scattering-transform

Use

Complete Scattering Compositional Learner network

import torch
import torch.nn.functional as F
from scattering_transform import SCL, SCLTrainingWrapper

# data - (batch, number of panels, channel dimension, image height, image width)
# questions holds the 8 context panels and answers the 8 candidate panels;
# the wrapper appends each candidate to the context, forming 8 sets of 9 panels

questions = torch.randn(1, 8, 1, 160, 160)
answers   = torch.randn(1, 8, 1, 160, 160)
labels    = torch.tensor([2])   # index of the correct candidate

# instantiate model

model = SCL(
    image_size = 160,                           # size of image
    set_size = 9,                               # panels per set: 8 questions + 1 candidate answer
    conv_channels = [1, 16, 16, 32, 32, 32],    # convolutional channel progression, 1 for greyscale, 3 for rgb
    conv_output_dim = 80,                       # model dimension, the output dimension of the vision net
    attr_heads = 10,                            # number of attribute heads
    attr_net_hidden_dims = [128],               # attribute scatter transform MLP hidden dimension(s)
    rel_heads = 80,                             # number of relationship heads
    rel_net_hidden_dims = [64, 23, 5]           # MLP for relationship net
)

model = SCLTrainingWrapper(model)
logits = model(questions, answers) # (1, 8) - the logits of each answer being the correct match

# train

loss = F.cross_entropy(logits, labels)
loss.backward()

Scattering Transform, which is basically one MLP that acts over groups of the dimension

import torch
from torch import nn   # needed for nn.LeakyReLU below; missing before
from scattering_transform import ScatteringTransform

# for potential use in a Transformer

mlp = ScatteringTransform(
    dims = [1024, 4096, 1024],    # MLP - dimension in -> hidden sizes -> dimension out
    heads = 16,                   # number of groups (heads)
    activation = nn.LeakyReLU     # activation to use in the MLP
)

x = torch.randn(1, 512, 1024)
mlp(x) # (1, 512, 1024)

Citation

@misc{wu2020scattering,
    title={The Scattering Compositional Learner: Discovering Objects, Attributes, Relationships in Analogical Reasoning},
    author={Yuhuai Wu and Honghua Dong and Roger Grosse and Jimmy Ba},
    year={2020},
    eprint={2007.04212},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

.\lucidrains\scattering-compositional-learner\scattering_transform\scattering_transform.py

# 导入 PyTorch 库
import torch
from torch import nn
import torch.nn.functional as F

# 辅助函数

# fall back to a default when a value is None
def default(val, default_val):
    """Return `val` unless it is None, in which case return `default_val`."""
    if val is None:
        return default_val
    return val

# add an axis and broadcast it to size k
def expand_dim(t, dim, k):
    """
    Insert a new axis at `dim` and broadcast it to size `k` using `expand`,
    so the result is a view rather than a copy.
    """
    t = t.unsqueeze(dim)
    sizes = [-1 for _ in t.shape]
    sizes[dim] = k
    return t.expand(*sizes)

# simple multilayer perceptron (ReLU activations by default)

class MLP(nn.Module):
    """
    Linear layers over `dims` = (in, hidden..., out), with `activation`
    (nn.ReLU when None) between consecutive layers and none after the last.
    """

    def __init__(self, *dims, activation = None):
        super().__init__()
        assert len(dims) > 2, 'must have at least 3 dimensions, for dimension in and dimension out'
        activation = default(activation, nn.ReLU)

        layer_pairs = list(zip(dims[:-1], dims[1:]))
        last_index = len(layer_pairs) - 1

        layers = []
        for index, (dim_in, dim_out) in enumerate(layer_pairs):
            layers.append(nn.Linear(dim_in, dim_out))
            if index < last_index:
                layers.append(activation())

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# feed-forward residual block from the paper, used after the visual
# features are extracted and again after the attribute transform

class FeedForwardResidual(nn.Module):
    """
    x + MLP(x), where the MLP widens to dim * mult, applies LayerNorm and
    ReLU, then projects back to `dim`.
    """

    def __init__(self, dim, mult = 4):
        super().__init__()
        hidden = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(inplace = True),
            nn.Linear(hidden, dim)
        )

    def forward(self, x):
        residual = self.net(x)
        return x + residual

# convolutional feature extractor
# TODO: make it customizable and add EvoNorm for batch-independent normalization

class ConvNet(nn.Module):
    """
    Stride-2 conv stack, with BatchNorm after every conv except the last,
    followed by flatten -> Linear -> ReLU -> FeedForwardResidual to
    `output_dim` features. Assumes square images of side `image_size`.

    NOTE(review): there is no nonlinearity between the conv layers themselves.
    Adding one would shift nn.Sequential indices and break saved checkpoints,
    so it is left as is.
    """

    def __init__(self, image_size, chans, output_dim):
        super().__init__()

        channel_pairs = list(zip(chans[:-1], chans[1:]))

        # Each Conv2d(kernel 3, padding 1, stride 2) maps a side of length s to
        # floor((s - 1) / 2) + 1 == ceil(s / 2), so track that per layer. The
        # previous image_size // 2 ** num_layers undercounts whenever
        # image_size is not a multiple of 2 ** num_layers, which mis-sizes the
        # Linear below. Both agree when it is a multiple (e.g. 160 with 5 layers).
        conv_output_size = image_size
        for _ in channel_pairs:
            conv_output_size = (conv_output_size + 1) // 2

        convolutions = []

        for ind, (chan_in, chan_out) in enumerate(channel_pairs):
            is_last = ind >= (len(channel_pairs) - 1)
            convolutions.append(nn.Conv2d(chan_in, chan_out, 3, padding=1, stride=2))
            if not is_last:
                convolutions.append(nn.BatchNorm2d(chan_out))

        self.net = nn.Sequential(
            *convolutions,
            nn.Flatten(1),
            nn.Linear(chans[-1] * (conv_output_size ** 2), output_dim),
            nn.ReLU(inplace=True),
            FeedForwardResidual(output_dim)
        )

    def forward(self, x):
        return self.net(x)

# scattering transform

class ScatteringTransform(nn.Module):
    """
    One MLP shared across `heads` equal slices of the last dimension.

    dims = [dim_in, *hidden_sizes, dim_out]. dim_in and dim_out are total
    widths, split evenly across heads; hidden_sizes are per-slice widths.
    Output shape is (*input_shape[:-1], heads * (dim_out // heads)).
    """

    def __init__(self, dims, heads, activation = None):
        super().__init__()
        assert len(dims) > 2, 'must have at least 3 dimensions, for dimension in, the hidden dimension, and dimension out'

        dim_in, *hidden_sizes, dim_out = dims

        dim_in //= heads
        dim_out //= heads

        self.heads = heads
        self.mlp = MLP(dim_in, *hidden_sizes, dim_out, activation = activation)

    def forward(self, x):
        shape, heads = x.shape, self.heads
        dim = shape[-1]

        assert (dim % heads) == 0, f'the dimension {dim} must be divisible by the number of heads {heads}'

        # every slice of every position goes through the same MLP
        x = x.reshape(-1, heads, dim // heads)
        x = self.mlp(x)

        # Restore the leading dims with an explicit output width. The former
        # reshape(shape) only worked when dim_out == dim_in. An explicit size
        # (rather than -1) also keeps zero-size batches working.
        return x.reshape(*shape[:-1], heads * x.shape[-1])

# main Scattering Compositional Learner

class SCL(nn.Module):
    """
    Scattering Compositional Learner.

    Each panel is encoded by a ConvNet. Its features are split into
    `attr_heads` groups that share one attribute MLP (ScatteringTransform),
    then regrouped into `rel_heads` groups gathered across the `set_size`
    panels of a set. A relation MLP and a final linear layer score each set.
    """

    def __init__(
        self,
        image_size = 160,                           # image side length
        set_size = 9,                               # panels per set
        conv_channels = (1, 16, 16, 32, 32, 32),    # conv channel progression
        conv_output_dim = 80,                       # width of the vision features
        attr_heads = 10,                            # number of attribute heads
        attr_net_hidden_dims = (128,),              # attribute MLP hidden dims
        rel_heads = 80,                             # number of relation heads
        rel_net_hidden_dims = (64, 23, 5)):         # relation MLP dims
        # Defaults are tuples instead of lists to avoid shared mutable default
        # arguments. Every use below only indexes or unpacks them, so callers
        # passing lists are unaffected.

        super().__init__()
        self.vision = ConvNet(image_size, conv_channels, conv_output_dim)

        self.attr_heads = attr_heads
        self.attr_net = ScatteringTransform([conv_output_dim, *attr_net_hidden_dims, conv_output_dim], heads = attr_heads)
        self.ff_residual = FeedForwardResidual(conv_output_dim)

        self.rel_heads = rel_heads
        self.rel_net = MLP(set_size * (conv_output_dim // rel_heads), *rel_net_hidden_dims)

        self.to_logit = nn.Linear(rel_net_hidden_dims[-1] * rel_heads, 1)

    def forward(self, sets):
        """
        sets: (b, m, n, c, h, w), i.e. b batches of m sets of n panels each.
        Returns logits of shape (b, m), one score per set.
        """
        b, m, n, c, h, w = sets.shape
        # reshape (not view) so non-contiguous inputs are accepted as well
        images = sets.reshape(-1, c, h, w)
        features = self.vision(images)

        attrs = self.attr_net(features)
        attrs = self.ff_residual(attrs)

        # split the feature dim into rel_heads groups and gather each group
        # across the n panels of a set
        attrs = attrs.reshape(b, m, n, self.rel_heads, -1).transpose(-2, -3).flatten(3)
        rels = self.rel_net(attrs)
        rels = rels.flatten(2)

        logits = self.to_logit(rels).flatten(1)
        return logits
# wrapper that builds one candidate set per answer, for easier training
class SCLTrainingWrapper(nn.Module):
    """Pair the shared question panels with each candidate answer and score every resulting set with `scl`."""

    def __init__(self, scl):
        super().__init__()
        self.scl = scl

    def forward(self, questions, answers):
        """
        questions: (b, q, c, h, w) context panels; answers: (b, k, c, h, w)
        candidates (assumed layouts, matching the README example). Builds sets
        of shape (b, k, q + 1, c, h, w) and returns `self.scl` applied to them.
        """
        # one copy of the context per candidate (was hard-coded to 8, which
        # broke for any other number of answers)
        num_answers = answers.shape[1]

        answers = answers.unsqueeze(2)
        questions = expand_dim(questions, dim=1, k=num_answers)

        # append each candidate as the last panel of its set
        permutations = torch.cat((questions, answers), dim=2)
        return self.scl(permutations)

.\lucidrains\scattering-compositional-learner\scattering_transform\__init__.py

# 从scattering_transform包中导入SCL, ScatteringTransform, SCLTrainingWrapper类
from scattering_transform.scattering_transform import SCL, ScatteringTransform, SCLTrainingWrapper

.\lucidrains\scattering-compositional-learner\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# Package metadata for the `scattering-transform` distribution.
setup(
    name='scattering-transform',
    version='0.0.7',
    description='Scattering Transform module from the paper Scattering Compositional Learner',
    url='https://github.com/lucidrains/scattering-compositional-learner',
    author='Phil Wang',
    author_email='lucidrains@gmail.com',
    license='MIT',
    keywords=['artificial intelligence', 'deep learning', 'reasoning'],
    packages=find_packages(),
    install_requires=['torch'],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)

.\lucidrains\se3-transformer-pytorch\denoise.py

# 导入 PyTorch 库
import torch
# 导入 PyTorch 中的函数库
import torch.nn.functional as F
# 从 torch.optim 中导入 Adam 优化器
from torch.optim import Adam

# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat

# 导入 sidechainnet 库,并从 se3_transformer_pytorch 中导入 SE3Transformer 类
import sidechainnet as scn
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer

# Number of micro-batches whose gradients are accumulated per optimizer step.
GRADIENT_ACCUMULATE_EVERY = 16
# Number of proteins per micro-batch.
BATCH_SIZE = 1

# Make float64 the default floating-point dtype for every tensor created below.
torch.set_default_dtype(torch.float64)

# 定义一个循环函数,用于处理数据加载器
def cycle(loader, len_thres = 500):
    while True:
        for data in loader:
            # 如果数据序列长度大于指定阈值,则继续循环
            if data.seqs.shape[1] > len_thres:
                continue
            yield data

# Build the SE3Transformer used as the coordinate denoiser.
transformer = SE3Transformer(
    num_tokens = 24,
    dim = 8,
    dim_head = 8,
    heads = 2,
    depth = 2,
    attend_self = True,
    input_degrees = 1,
    output_degrees = 2,
    reduce_dim_out = True,
    differentiable_coors = True,
    num_neighbors = 0,
    attend_sparse_neighbors = True,
    num_adj_degrees = 2,
    adj_dim = 4,
    num_degrees=2,
)

# Load the CASP12 dataset (thinning 30) as PyTorch dataloaders.
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = BATCH_SIZE,
    dynamic_batching = False
)

dl = cycle(data['train'])

# Move the model to the GPU *before* building the optimizer so the optimizer
# is constructed over the parameters that are actually trained.
transformer = transformer.cuda()
optim = Adam(transformer.parameters(), lr=1e-4)

for _ in range(10000):
    # running sum of the un-scaled micro-batch losses in this accumulation window
    total_loss = 0.

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        batch = next(dl)
        seqs, coords, masks = batch.seqs, batch.crds, batch.msks

        # argmax over the last dim turns the (presumably one-hot) sequence into token ids
        seqs = seqs.cuda().argmax(dim = -1)
        coords = coords.cuda().type(torch.float64)
        masks = masks.cuda().bool()

        # coordinates are laid out as 14 atom slots per residue; keep the first 3 (backbone)
        coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)
        coords = coords[:, :, 0:3, :]
        coords = rearrange(coords, 'b l s c -> b (l s) c')

        # one token / mask entry per kept backbone atom
        seq = repeat(seqs, 'b n -> b (n c)', c = 3)
        masks = repeat(masks, 'b n -> b (n c)', c = 3)

        # unit Gaussian noise; randn_like already allocates on coords' (GPU) device
        noised_coords = coords + torch.randn_like(coords)

        # chain adjacency: each atom is adjacent to itself and its immediate neighbours
        i = torch.arange(seq.shape[-1], device = seqs.device)
        adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

        out = transformer(
            seq,
            noised_coords,
            mask = masks,
            adj_mat = adj_mat,
            return_type = 1
        )

        # the type-1 output is added to the noised coordinates as a correction
        denoised_coords = noised_coords + out
        loss = F.mse_loss(denoised_coords[masks], coords[masks])
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
        total_loss += loss.item()

    # report the mean over the whole accumulation window, not just the last micro-batch
    print('loss:', total_loss / GRADIENT_ACCUMULATE_EVERY)
    optim.step()
    optim.zero_grad()

SE3 Transformer - Pytorch

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 results and other drug discovery applications.

Open In Colab Example of equivariance

If you have been using any version of SE3 Transformers prior to version 0.6.0, please update. @MattMcPartlon uncovered a major bug that affects you if you were not using the adjacency sparse neighbors settings and were relying on the nearest neighbors functionality.

Update: it is recommended that you use Equiformer instead.

Install

$ pip install se3-transformer-pytorch

Usage

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask  = torch.ones(1, 1024).bool()

out = model(feats, coors, mask) # (1, 1024, 512)

Potential example usage in Alphafold2, as outlined here

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True,
    differentiable_coors = True
)

atom_feats = torch.randn(2, 32, 64)
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refined_coors = coors + model(atom_feats, coors, mask, return_type = 1) # (2, 32, 3)

You can also let the base transformer class take care of embedding the type 0 features being passed in, assuming they are atom tokens:

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,       # 28 unique atoms
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refined_coors = coors + model(atoms, coors, mask, return_type = 1) # (2, 32, 3)

If you think the net could further benefit from positional encoding, you can featurize your positions in space and pass it in as follows.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 2,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True  # reduce out the final dimension
)

atom_feats  = torch.randn(2, 32, 64, 1) # b x n x d x type0
coors_feats = torch.randn(2, 32, 64, 3) # b x n x d x type1

# atom features are type 0, predicted coordinates are type 1
features = {'0': atom_feats, '1': coors_feats}
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refined_coors = coors + model(features, coors, mask, return_type = 1) # (2, 32, 3) - equivariant to input type 1 features and coordinates

Edges

To offer edge information to SE3 Transformers (say bond types between atoms), you just have to pass in two more keyword arguments on initialization.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,
    dim = 64,
    num_edge_tokens = 4,       # number of edge type, say 4 bond types
    edge_dim = 16,             # dimension of edge embedding
    depth = 2,
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

pred = model(atoms, coors, mask, edges = bonds, return_type = 0) # (2, 32, 1)

If you would like to pass in continuous values for your edges, you can choose not to set num_edge_tokens, encode your discrete bond types, and then concatenate them with the Fourier features of these continuous values.

import torch
from se3_transformer_pytorch import SE3Transformer
from se3_transformer_pytorch.utils import fourier_encode

model = SE3Transformer(
    dim = 64,
    depth = 1,
    attend_self = True,
    num_degrees = 2,
    output_degrees = 2,
    edge_dim = 34           # edge dimension must match the final dimension of the edges being passed in
)

feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask  = torch.ones(1, 32).bool()

pairwise_continuous_values = torch.randint(0, 4, (1, 32, 32, 2))  # say there are 2

edges = fourier_encode(
    pairwise_continuous_values,
    num_encodings = 8,
    include_self = True
) # (1, 32, 32, 34) - {2 * (2 * 8 + 1)}

out = model(feats, coors, mask, edges = edges, return_type = 1)

Sparse Neighbors

If you know the connectivity of your points (say you are working with molecules), you can pass in an adjacency matrix, in the form of a boolean mask (where True indicates connectivity).

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 32,
    heads = 8,
    depth = 1,
    dim_head = 64,
    num_degrees = 2,
    valid_radius = 10,
    attend_sparse_neighbors = True,  # this must be set to true, in which case it will assert that you pass in the adjacency matrix
    num_neighbors = 0,               # if you set this to 0, it will only consider the connected neighbors as defined by the adjacency matrix. but if you set a value greater than 0, it will continue to fetch the closest points up to this many, excluding the ones already specified by the adjacency matrix
    max_sparse_neighbors = 8         # you can cap the number of neighbors, sampled from within your sparse set of neighbors as defined by the adjacency matrix, if specified
)

feats = torch.randn(1, 128, 32)
coors = torch.randn(1, 128, 3)
mask  = torch.ones(1, 128).bool()

# placeholder adjacency matrix
# naively assuming the sequence is one long chain (128, 128)

i = torch.arange(128)
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

out = model(feats, coors, mask, adj_mat = adj_mat) # (1, 128, 32)

You can also have the network automatically derive for you the Nth-degree neighbors with one extra keyword num_adj_degrees. If you would like the system to differentiate between the degree of the neighbors as edge information, further pass in a non-zero adj_dim.

import torch
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 1,
    attend_self = True,
    num_degrees = 2,
    output_degrees = 2,
    num_neighbors = 0,
    attend_sparse_neighbors = True,
    num_adj_degrees = 2,    # automatically derive 2nd degree neighbors
    adj_dim = 4             # embed 1st and 2nd degree neighbors (as well as null neighbors) with edge embeddings of this dimension
)

feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask  = torch.ones(1, 32).bool()

# placeholder adjacency matrix
# naively assuming the sequence is one long chain (32, 32)

i = torch.arange(32)
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

out = model(feats, coors, mask, adj_mat = adj_mat, return_type = 1)

To have fine control over the dimensionality of each type, you can use the hidden_fiber_dict and out_fiber_dict keywords to pass in a dictionary with the degree to dimension values as the key / values.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,
    dim = 64,
    num_edge_tokens = 4,
    edge_dim = 16,
    depth = 2,
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    hidden_fiber_dict = {0: 16, 1: 8, 2: 4},
    out_fiber_dict = {0: 16, 1: 1},
    reduce_dim_out = False
)

atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

pred = model(atoms, coors, mask, edges = bonds)

pred['0'] # (2, 32, 16)
pred['1'] # (2, 32, 1, 3)

Neighbors

You can further control which nodes can be considered by passing in a neighbor mask. All False values will be masked out of consideration.

import torch
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 16,
    dim_head = 16,
    attend_self = True,
    num_degrees = 4,
    output_degrees = 2,
    num_edge_tokens = 4,
    num_neighbors = 8,      # make sure you set this value as the maximum number of neighbors set by your neighbor_mask, or it will throw a warning
    edge_dim = 2,
    depth = 3
)

feats = torch.randn(1, 32, 16)
coors = torch.randn(1, 32, 3)
mask  = torch.ones(1, 32).bool()
bonds = torch.randint(0, 4, (1, 32, 32))

neighbor_mask = torch.ones(1, 32, 32).bool() # set the nodes you wish to be masked out as False

out = model(
    feats,
    coors,
    mask,
    edges = bonds,
    neighbor_mask = neighbor_mask,
    return_type = 1
)

Global Nodes

This feature allows you to pass in vectors that can be viewed as global nodes that are seen by all other nodes. The idea would be to pool your graph into a few feature vectors, which will be projected to key / values across all the attention layers in the network. All nodes will have full access to global node information, regardless of nearest neighbors or adjacency calculation.

import torch
from torch import nn
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 1,
    num_degrees = 2,
    num_neighbors = 4,
    valid_radius = 10,
    global_feats_dim = 32 # this must be set to the dimension of the global features, in this example, 32
)

feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask  = torch.ones(1, 32).bool()

# naively derive global features
# by pooling features and projecting
global_feats = nn.Linear(64, 32)(feats.mean(dim = 1, keepdim = True)) # (1, 1, 32)

out = model(feats, coors, mask, return_type = 0, global_feats = global_feats)

Todo:

Autoregressive

You can use SE3 Transformers autoregressively with just one extra flag

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10,
    causal = True          # set this to True
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask  = torch.ones(1, 1024).bool()

out = model(feats, coors, mask) # (1, 1024, 512)

Experimental Features

Non-pairwise convolved keys

I've discovered that using linearly projected keys (rather than the pairwise convolution) seems to do ok in a toy denoising task. This leads to 25% memory savings. You can try this feature by setting linear_proj_keys = True

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 1,
    num_degrees = 4,
    num_neighbors = 8,
    valid_radius = 10,
    splits = 4,
    linear_proj_keys = True # set this to True
).cuda()

feats = torch.randn(1, 32, 64).cuda()
coors = torch.randn(1, 32, 3).cuda()
mask  = torch.ones(1, 32).bool().cuda()

out = model(feats, coors, mask, return_type = 0)

Shared key / values across all heads

There is a relatively unknown technique for transformers where a single key / value head is shared across all the query heads. In my experience in NLP, this usually leads to worse performance, but if you really need to trade off memory for more depth or a higher number of degrees, this may be a good option.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 8,
    num_degrees = 4,
    num_neighbors = 8,
    valid_radius = 10,
    splits = 4,
    one_headed_key_values = True  # one head of key / values shared across all heads of the queries
).cuda()

feats = torch.randn(1, 32, 64).cuda()
coors = torch.randn(1, 32, 3).cuda()
mask  = torch.ones(1, 32).bool().cuda()

out = model(feats, coors, mask, return_type = 0)

Tied key / values

You can also tie the key / values (have them be the same), for half memory savings

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 8,
    num_degrees = 4,
    num_neighbors = 8,
    valid_radius = 10,
    splits = 4,
    tie_key_values = True # set this to True
).cuda()

feats = torch.randn(1, 32, 64).cuda()
coors = torch.randn(1, 32, 3).cuda()
mask  = torch.ones(1, 32).bool().cuda()

out = model(feats, coors, mask, return_type = 0)

Using EGNN

This is an experimental version of EGNN that works for higher types, and greater dimensionality than just 1 (for the coordinates). The class name is still SE3Transformer since it reuses some preexisting logic, so just ignore that for now until I clean it up later.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 32,
    num_neighbors = 8,
    num_edge_tokens = 4,
    edge_dim = 4,
    num_degrees = 4,       # number of higher order types - will use basis on a TCN to project to these dimensions
    use_egnn = True,       # set this to true to use EGNN instead of equivariant attention layers
    egnn_hidden_dim = 64,  # egnn hidden dimension
    depth = 4,             # depth of EGNN
    reduce_dim_out = True  # will project the dimension of the higher types to 1
).cuda()

feats = torch.randn(2, 32, 32).cuda()
coors = torch.randn(2, 32, 3).cuda()
bonds = torch.randint(0, 4, (2, 32, 32)).cuda()
mask  = torch.ones(2, 32).bool().cuda()

refinement = model(feats, coors, mask, edges = bonds, return_type = 1) # (2, 32, 3)

coors = coors + refinement  # update coors with refinement

If you would like to specify individual dimensions for each of the higher types, just pass in hidden_fiber_dict where the dictionary is in the format {<degree>:<dim>} instead of num_degrees

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 32,
    num_neighbors = 8,
    hidden_fiber_dict = {0: 32, 1: 16, 2: 8, 3: 4},
    use_egnn = True,
    depth = 4,
    egnn_hidden_dim = 64,
    egnn_weights_clamp_value = 2, 
    reduce_dim_out = True
).cuda()

feats = torch.randn(2, 32, 32).cuda()
coors = torch.randn(2, 32, 3).cuda()
mask  = torch.ones(2, 32).bool().cuda()

refinement = model(feats, coors, mask, return_type = 1) # (2, 32, 3)

coors = coors + refinement  # update coors with refinement

Scaling (wip)

This section will list ongoing efforts to make SE3 Transformer scale a little better.

Firstly, I have added reversible networks. This allows me to add a little more depth before hitting the usual memory roadblocks. Equivariance preservation is demonstrated in the tests.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 20,
    dim = 32,
    dim_head = 32,
    heads = 4,
    depth = 12,             # 12 layers
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    reduce_dim_out = True,
    reversible = True       # set reversible to True
).cuda()

atoms = torch.randint(0, 4, (2, 32)).cuda()
coors = torch.randn(2, 32, 3).cuda()
mask  = torch.ones(2, 32).bool().cuda()

pred = model(atoms, coors, mask = mask, return_type = 0)

loss = pred.sum()
loss.backward()

Examples

First install sidechainnet

$ pip install sidechainnet

Then run the protein backbone denoising task

$ python denoise.py

Caching

By default, the basis vectors are cached. If you ever need to clear the cache, set the environment variable CLEAR_CACHE to any value when launching the script:

$ CLEAR_CACHE=1 python train.py

Or you can try deleting the cache directory, which should exist at

$ rm -rf ~/.cache.equivariant_attention

You can also designate your own directory where you want the caches to be stored, in the case that the default directory may have permission issues

CACHE_PATH=./path/to/my/cache python train.py

Testing

$ python setup.py pytest

Credit

This library is largely a port of Fabian's official repository, but without the DGL library.

Citations

@misc{fuchs2020se3transformers,
    title   = {SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks}, 
    author  = {Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling},
    year    = {2020},
    eprint  = {2006.10503},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{satorras2021en,
    title   = {E(n) Equivariant Graph Neural Networks},
    author  = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year    = {2021},
    eprint  = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{gomez2017reversible,
    title     = {The Reversible Residual Network: Backpropagation Without Storing Activations},
    author    = {Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger B. Grosse},
    year      = {2017},
    eprint    = {1707.04585},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{shazeer2019fast,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam Shazeer},
    year    = {2019},
    eprint  = {1911.02150},
    archivePrefix = {arXiv},
    primaryClass = {cs.NE}
}

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\basis.py

# 导入必要的库
import os
from math import pi
import torch
from torch import einsum
from einops import rearrange
from itertools import product
from contextlib import contextmanager

# 导入自定义库
from se3_transformer_pytorch.irr_repr import irr_repr, spherical_harmonics
from se3_transformer_pytorch.utils import torch_default_dtype, cache_dir, exists, default, to_order
from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache

# 常量定义

# Cache directory: $CACHE_PATH if set, otherwise ~/.cache.equivariant_attention.
# When the CLEAR_CACHE environment variable is set (to any value), CACHE_PATH is None.
if exists(os.environ.get('CLEAR_CACHE')):
    CACHE_PATH = None
else:
    CACHE_PATH = default(os.getenv('CACHE_PATH'), os.path.expanduser('~/.cache.equivariant_attention'))

# Fixed Euler-angle triples (a, b, c) at which the Sylvester constraints are sampled.
# TODO: figure out why these were hard-coded in the official repo.
RANDOM_ANGLES = [
    [4.41301023, 5.56684102, 4.59384642],
    [4.93325116, 6.12697327, 4.14574096],
    [0.53878964, 4.09050444, 5.36539036],
    [2.16017393, 3.48835314, 5.55174441],
    [2.52385107, 0.2908958, 3.90040975],
]

# 辅助函数

# No-op context manager (used when gradients should be tracked normally).
@contextmanager
def null_context():
    """Enter and exit without doing anything; the bound value is None."""
    yield None

# 函数定义

def get_matrix_kernel(A, eps = 1e-10):
    '''
    Orthonormal basis (x_1, x_2, ...) of the kernel of matrix A:
    A x_i = 0 and scalar_product(x_i, x_j) = delta_ij.

    :param A: matrix
    :param eps: singular values below this threshold count as zero
    :return: matrix whose rows are the kernel basis vectors of A

    NOTE: torch.svd returns the reduced SVD, so only min(rows, cols)
    right-singular vectors are examined.
    '''
    singular_values, right_vectors = torch.svd(A)[1:]
    null_rows = singular_values < eps
    return right_vectors.t()[null_rows]

def get_matrices_kernel(As, eps = 1e-10):
    '''
    Common kernel of all matrices in As: the kernel of their row-wise stack.
    '''
    stacked = torch.cat(As, dim=0)
    return get_matrix_kernel(stacked, eps)

def get_spherical_from_cartesian(cartesian, divide_radius_by = 1.0):
    """
    Convert cartesian coordinates into spherical coordinates.

    The last dimension of the result is (radius, alpha, beta). The cartesian
    components are read from permuted positions of the last dimension:
    x from index 2, y from index 0, z from index 1.

    :param cartesian: tensor whose last dimension holds the coordinates
    :param divide_radius_by: the radius is divided by this value when it is not 1.0
    :return: tensor shaped like ``cartesian`` holding (radius, alpha, beta)

    # ON ANGLE CONVENTION
    #
    # sh has following convention for angles:
    # :param theta: the colatitude / polar angle, ranging from 0(North Pole, (X, Y, Z) = (0, 0, 1)) to pi(South Pole, (X, Y, Z) = (0, 0, -1)).
    # :param phi: the longitude / azimuthal angle, ranging from 0 to 2 pi.
    #
    # the 3D steerable CNN code therefore (probably) has the following convention for alpha and beta:
    # beta = pi - theta; ranging from 0(South Pole, (X, Y, Z) = (0, 0, -1)) to pi(North Pole, (X, Y, Z) = (0, 0, 1).
    # alpha = phi
    #
    """
    spherical = torch.zeros_like(cartesian)

    # output slots
    ind_radius, ind_alpha, ind_beta = 0, 1, 2
    # input slots (note the permutation)
    cartesian_x, cartesian_y, cartesian_z = 2, 0, 1

    # squared radius of the projection onto the xy-plane
    r_xy = cartesian[..., cartesian_x] ** 2 + cartesian[..., cartesian_y] ** 2

    # beta: 'elevation angle defined from Z-axis down'
    spherical[..., ind_beta] = torch.atan2(torch.sqrt(r_xy), cartesian[..., cartesian_z])

    # alpha: angle within the xy-plane
    spherical[..., ind_alpha] = torch.atan2(cartesian[..., cartesian_y], cartesian[..., cartesian_x])

    radius = torch.sqrt(r_xy + cartesian[..., cartesian_z] ** 2)

    if divide_radius_by != 1.0:
        # Out-of-place on purpose: sqrt saves its output for backward, so an
        # in-place `/=` here made autograd fail whenever gradients were needed.
        radius = radius / divide_radius_by

    spherical[..., ind_radius] = radius
    return spherical

def kron(a, b):
    """
    Kronecker product over the last two dimensions of a and b; any leading
    dimensions are broadcast together.
    """
    # outer[..., i, k, j, l] = a[..., i, j] * b[..., k, l]
    outer = einsum('... i j, ... k l -> ... i k j l', a, b)
    *lead, rows_a, rows_b, cols_a, cols_b = outer.shape
    return outer.reshape(*lead, rows_a * rows_b, cols_a * cols_b)

def get_R_tensor(order_out, order_in, a, b, c):
    """
    Kronecker product of the irr_repr matrices of degrees ``order_out`` and
    ``order_in`` evaluated at Euler angles (a, b, c).
    """
    # the original line was missing its closing parenthesis (SyntaxError)
    return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))

def sylvester_submatrix(order_out, order_in, J, a, b, c):
    ''' Generate the Kronecker product matrix for solving the Sylvester equation in subspace J.

    Returns kron(R_tensor, I_J) - kron(I_tensor, R_J^T) for the rotation given
    by Euler angles (a, b, c); vectors in its null space satisfy the constraint.
    '''
    R_tensor = get_R_tensor(order_out, order_in, a, b, c)  # [m_out * m_in, m_out * m_in]
    R_irrep_J = irr_repr(J, a, b, c)  # [m, m]

    # identities match the representations' dtype/device so the difference below is well-typed
    R_tensor_identity = torch.eye(R_tensor.shape[0], dtype = R_tensor.dtype, device = R_tensor.device)
    R_irrep_J_identity = torch.eye(R_irrep_J.shape[0], dtype = R_irrep_J.dtype, device = R_irrep_J.device)

    return kron(R_tensor, R_irrep_J_identity) - kron(R_tensor_identity, R_irrep_J.t())  # [(m_out * m_in) * m, (m_out * m_in) * m]
# The three decorators below were previously only described in comments and
# never applied: cache keyed on CACHE_PATH (presumably a directory cache --
# see cache_dir), float64 default dtype during the computation, and no autograd.
@cache_dir(CACHE_PATH)
@torch_default_dtype(torch.float64)
@torch.no_grad()
def basis_transformation_Q_J(J, order_in, order_out, random_angles = RANDOM_ANGLES):
    """
    Solve for Q_J as the common null space of the Sylvester constraint
    matrices sampled at ``random_angles``.

    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param random_angles: (a, b, c) Euler-angle triples at which constraints are sampled
    :return: one part of the Q^-1 matrix of the article, shape [m_out * m_in, m]
    """
    # one constraint matrix per sampled rotation; Q_J has to satisfy all of them
    sylvester_submatrices = [sylvester_submatrix(order_out, order_in, J, a, b, c) for a, b, c in random_angles]
    # float64 matters here: the default kernel threshold (1e-10) is far below float32 resolution
    null_space = get_matrices_kernel(sylvester_submatrices)
    # the solution must be unique (a one-dimensional null space)
    assert null_space.size(0) == 1, null_space.size()
    Q_J = null_space[0]  # [(m_out * m_in) * m]
    Q_J = Q_J.view(to_order(order_out) * to_order(order_in), to_order(J))  # [m_out * m_in, m]
    return Q_J.float()  # [m_out * m_in, m]

# Precompute spherical harmonics for every degree up to max_J.
def precompute_sh(r_ij, max_J):
    """
    Precompute spherical harmonics for degrees 0..max_J.

    :param r_ij: relative positions in spherical form; along the last dimension,
        index 1 is read as alpha and index 2 as beta
    :param max_J: maximum degree used anywhere in the network
    :return: dict mapping J -> harmonics (documented upstream as [B,N,K,2J+1])
    """
    alpha = r_ij[..., 1]
    beta = r_ij[..., 2]

    Y_Js = {}
    for J in range(max_J + 1):
        Y_Js[J] = spherical_harmonics(J, alpha, beta)

    clear_spherical_harmonics_cache()
    return Y_Js

# Compute the equivariant weight basis (once per forward pass).
def get_basis(r_ij, max_degree, differentiable = False):
    """Return equivariant weight basis (basis)

    Call this function *once* at the start of each forward pass of the model.
    It computes the equivariant weight basis, W_J^lk(x), and internodal
    distances, needed to compute varphi_J^lk(x), of eqn 8 of
    https://arxiv.org/pdf/2006.10503.pdf. The return values of this function
    can be shared as input across all SE(3)-Transformer layers in a model.

    Args:
        r_ij: relative positional vectors (cartesian, in the last dimension)
        max_degree: non-negative int for degree of highest feature-type
        differentiable: whether r_ij should receive gradients from basis
    Returns:
        dict of equivariant bases, keys are in form '<d_in>,<d_out>'
    """

    # Track gradients only when the caller asked for them. This was previously
    # inverted: differentiable=True ran under torch.no_grad, so r_ij never
    # received gradients.
    context = null_context if differentiable else torch.no_grad

    device, dtype = r_ij.device, r_ij.dtype

    with context():
        r_ij = get_spherical_from_cartesian(r_ij)

        # degrees up to d_in + d_out <= 2 * max_degree are needed below
        Y = precompute_sh(r_ij, 2 * max_degree)

        basis = {}
        for d_in, d_out in product(range(max_degree+1), range(max_degree+1)):
            K_Js = []
            for J in range(abs(d_in - d_out), d_in + d_out + 1):
                Q_J = basis_transformation_Q_J(J, d_in, d_out)
                Q_J = Q_J.type(dtype).to(device)

                # kernel for this J: project the harmonics through Q_J
                K_J = torch.matmul(Y[J], Q_J.T)
                K_Js.append(K_J)

            # stack the J components last; the final view size
            # to_order(min(d_in, d_out)) has to equal the number of J values
            K_Js = torch.stack(K_Js, dim = -1)
            size = (*r_ij.shape[:-1], 1, to_order(d_out), 1, to_order(d_in), to_order(min(d_in,d_out)))
            basis[f'{d_in},{d_out}'] = K_Js.view(*size)

    # extra detach for safety when gradients are not wanted
    if not differentiable:
        for k, v in basis.items():
            basis[k] = v.detach()

    return basis

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\irr_repr.py

# 导入所需的库
import os
import numpy as np
import torch
from torch import sin, cos, atan2, acos
from math import pi
from pathlib import Path
from functools import wraps

# 导入自定义的函数和类
from se3_transformer_pytorch.utils import exists, default, cast_torch_tensor, to_order
from se3_transformer_pytorch.spherical_harmonics import get_spherical_harmonics, clear_spherical_harmonics_cache

# Directory holding the precomputed data files, next to this module.
DATA_PATH = path = Path(os.path.dirname(__file__)) / 'data'

# Jd[degree] is the precomputed J matrix consumed by wigner_d_matrix.
# Prefer the torch serialization and fall back to the numpy copy.
try:
    path = DATA_PATH / 'J_dense.pt'
    Jd = torch.load(str(path))
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit propagate.
    # NOTE: allow_pickle=True is only acceptable because this file lives next to
    # the module; never point it at untrusted data.
    path = DATA_PATH / 'J_dense.npy'
    Jd_np = np.load(str(path), allow_pickle = True)
    Jd = list(map(torch.from_numpy, Jd_np))

# Build the Wigner D matrix for one set of ZYZ Euler angles.
def wigner_d_matrix(degree, alpha, beta, gamma, dtype = None, device = None):
    """Create the Wigner D matrix of the given degree for ZYZ Euler angles.

    :param degree: representation degree l
    :param alpha, beta, gamma: ZYZ Euler angles as scalar tensors
    :param dtype: dtype for the precomputed J matrix; defaults to alpha's dtype
    :param device: device for the precomputed J matrix; defaults to alpha's device
    :return: (to_order(degree), to_order(degree)) matrix
    """
    # Tensor.type(None) returns a type *string* rather than casting, so the old
    # `.type(dtype).to(device)` crashed with the default dtype=None. Resolve
    # the defaults from the angles so J matches the z_rot_mat outputs.
    dtype = default(dtype, alpha.dtype)
    device = default(device, alpha.device)
    J = Jd[degree].to(device = device, dtype = dtype)
    order = to_order(degree)
    x_a = z_rot_mat(alpha, degree)
    x_b = z_rot_mat(beta, degree)
    x_c = z_rot_mat(gamma, degree)
    res = x_a @ J @ x_b @ J @ x_c
    return res.view(order, order)

# Matrix for a rotation about the Z axis, used by wigner_d_matrix.
def z_rot_mat(angle, l):
    """
    Square matrix of size to_order(l) for a rotation by ``angle`` about Z.

    With frequencies f_k running from l down to -l, entry (k, 2l - k) holds
    sin(f_k * angle) and the diagonal entry (k, k) holds cos(f_k * angle).
    ``angle`` is expected to be a scalar tensor.
    """
    dev, dt = angle.device, angle.dtype
    size = to_order(l)
    freqs = torch.arange(l, -l - 1, -1, dtype=dt, device=dev)[None]
    rows = torch.arange(0, size, 1, dtype=torch.long, device=dev)
    anti_cols = torch.arange(2 * l, -1, -1, dtype=torch.long, device=dev)

    mat = angle.new_zeros((size, size))
    # sin first: the cos write afterwards wins at the shared centre entry (k == l)
    mat[rows, anti_cols] = sin(freqs * angle[None])
    mat[rows, rows] = cos(freqs * angle[None])
    return mat

# Irreducible representation of SO(3).
def irr_repr(order, alpha, beta, gamma, dtype = None):
    """
    irreducible representation of SO3
    - compatible with compose and spherical_harmonics

    The angles are converted with cast_torch_tensor; dtype defaults to torch's
    current default dtype.
    """
    dtype = default(dtype, torch.get_default_dtype())
    as_tensor = cast_torch_tensor(lambda t: t)
    alpha, beta, gamma = (as_tensor(angle) for angle in (alpha, beta, gamma))
    return wigner_d_matrix(order, alpha, beta, gamma, dtype = dtype)

# Rotation matrix about the Z axis.
@cast_torch_tensor
def rot_z(gamma):
    '''
    Rotation around Z axis
    '''
    c, s = cos(gamma), sin(gamma)
    rows = [
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ]
    return torch.tensor(rows, dtype=gamma.dtype)

# Rotation matrix about the Y axis.
@cast_torch_tensor
def rot_y(beta):
    '''
    Rotation around Y axis
    '''
    c, s = cos(beta), sin(beta)
    rows = [
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ]
    return torch.tensor(rows, dtype=beta.dtype)

# Convert a point on the sphere into (alpha, beta) angles.
@cast_torch_tensor
def x_to_alpha_beta(x):
    '''
    Convert point (x, y, z) on the sphere into (alpha, beta)
    '''
    unit = x / torch.norm(x)
    alpha = atan2(unit[1], unit[0])
    beta = acos(unit[2])
    return (alpha, beta)

# Rotation matrix from ZYZ Euler angles.
def rot(alpha, beta, gamma):
    '''
    ZYZ Euler angles rotation: rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
    '''
    first = rot_z(alpha)
    middle = rot_y(beta)
    last = rot_z(gamma)
    return first @ middle @ last

# Compose two ZYZ Euler-angle rotations.
def compose(a1, b1, c1, a2, b2, c2):
    """
    (a, b, c) = (a1, b1, c1) composed with (a2, b2, c2)

    Returns the ZYZ Euler angles of rot(a1, b1, c1) @ rot(a2, b2, c2).
    """
    combined = rot(a1, b1, c1) @ rot(a2, b2, c2)
    # the image of the +Z axis determines the first two angles
    alpha, beta = x_to_alpha_beta(combined @ torch.tensor([0, 0, 1.]))
    # undoing them leaves a rotation about Z whose angle is the third one
    residual = rot(0, -beta, -alpha) @ combined
    gamma = atan2(residual[1, 0], residual[0, 0])
    return alpha, beta, gamma

# Spherical harmonics in this module's (alpha, beta) angle convention.
def spherical_harmonics(order, alpha, beta, dtype = None):
    """Spherical harmonics of the given order; ``dtype`` is accepted but unused."""
    # get_spherical_harmonics takes theta = pi - beta and phi = alpha
    theta = pi - beta
    return get_spherical_harmonics(order, theta = theta, phi = alpha)

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\reversible.py

import torch  # 导入 PyTorch 库
import torch.nn as nn  # 导入 PyTorch 中的神经网络模块
from torch.autograd.function import Function  # 导入 PyTorch 中的自动微分函数
from torch.utils.checkpoint import get_device_states, set_device_states  # 导入 PyTorch 中的检查点函数

# 辅助函数

def map_values(fn, x):
    """Return a new dict with ``fn`` applied to every value of ``x`` (keys and order kept)."""
    return {key: fn(value) for key, value in x.items()}

def dict_chunk(x, chunks, dim):
    """Split every tensor in ``x`` with ``chunk(chunks, dim=dim)`` into two dicts.

    Each value must split into exactly two pieces (the tuple unpacking below
    raises otherwise). Returns ``(first_halves, second_halves)``.
    """
    first, second = {}, {}
    for key, tensor in x.items():
        first[key], second[key] = tensor.chunk(chunks, dim=dim)
    return first, second

def dict_sum(x, y):
    """Element-wise ``x[k] + y[k]`` for every key of ``x`` (``y`` must contain them all)."""
    return {key: x[key] + y[key] for key in x}

def dict_subtract(x, y):
    """Element-wise ``x[k] - y[k]`` for every key of ``x`` (``y`` must contain them all)."""
    return {key: x[key] - y[key] for key in x}

def dict_cat(x, y, dim):
    """Concatenate ``x[k]`` and ``y[k]`` along ``dim`` for every key of ``x``."""
    return {key: torch.cat((left, y[key]), dim=dim) for key, left in x.items()}

def dict_set_(x, key, value):
    """In place: set attribute ``key`` to ``value`` on every value of ``x`` (e.g. ``requires_grad``)."""
    for item in x.values():
        setattr(item, key, value)

def dict_backwards_(outputs, grad_tensors):
    """Backpropagate every ``outputs[k]`` with gradient ``grad_tensors[k]``.

    ``retain_graph=True`` keeps the graph alive between the per-key calls, since
    several outputs may share parts of the same graph.
    """
    for name, tensor in outputs.items():
        torch.autograd.backward(tensor, grad_tensors[name], retain_graph=True)

def dict_del_(x):
    """Intended to release the tensors held by ``x``.

    NOTE: ``del`` only unbinds names local to this function, so the caller's
    dict and its tensors are left untouched; the call is effectively a no-op.
    Changing this to ``x.clear()`` is NOT safe as-is: ReversibleBlock may pass
    a dict that aliases another still-needed dict (e.g. when f/g return their
    input unchanged).
    """
    for _tensor in list(x.values()):
        del _tensor
    del x

def values(d):
    """Return the values of ``d`` as a new list, in insertion order."""
    return list(d.values())

# 参考以下示例保存和设置随机数生成器 https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    """Wrap ``net`` so that a forward call can later be replayed with the same RNG state.

    Call with ``record_rng = True`` to snapshot the CPU RNG state (and, once CUDA
    has been initialised, the RNG state of the GPUs holding the inputs) before
    running ``net``. A later call with ``set_rng = True`` runs ``net`` inside
    ``torch.random.fork_rng`` with those states restored, so stochastic layers
    (e.g. dropout) make the same draws as in the recorded call.
    Pattern adapted from torch.utils.checkpoint
    (https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html).
    """
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        """Snapshot RNG state. ``args`` may be tensors or dicts of tensors."""
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            # The nets in this module receive *dicts* of tensors. Older versions of
            # get_device_states only inspect top-level tensor arguments, so a dict
            # yields no devices and the GPU RNG state would never be replayed.
            # Unpack dict values explicitly (harmless on versions that recurse).
            tensors = []
            for arg in args:
                if isinstance(arg, dict):
                    tensors.extend(arg.values())
                else:
                    tensors.append(arg)
            self.gpu_devices, self.gpu_states = get_device_states(*tensors)

    def forward(self, *args, record_rng=False, set_rng=False, **kwargs):
        """Run ``net(*args, **kwargs)``, optionally recording or replaying the RNG state."""
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        # fork_rng restores the surrounding RNG state when the block exits
        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# 受 https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 启发
# 一旦多 GPU 工作正常,重构并将 PR 发回源代码
class ReversibleBlock(nn.Module):  # reversible coupling block over dicts of tensors
    """Reversible coupling block operating on dicts of tensors.

    Each input tensor is split in half along its last axis into (x1, x2) and
    the block computes

        y1 = x1 + f(x2)
        y2 = x2 + g(y1)

    so the inputs can be recomputed from the outputs in ``backward_pass``
    instead of being stored. ``f`` and ``g`` are wrapped in ``Deterministic``
    so their RNG draws can be replayed during recomputation.
    """
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, **kwargs):  # forward pass; **kwargs are passed to f only, not to g
        training = self.training
        x1, x2 = dict_chunk(x, 2, dim=-1)
        y1, y2 = None, None

        # no autograd graph is built here; gradients are reconstructed in backward_pass
        with torch.no_grad():
            y1 = dict_sum(x1, self.f(x2, record_rng=training, **kwargs))
            y2 = dict_sum(x2, self.g(y1, record_rng=training))

        return dict_cat(y1, y2, dim=-1)
    # Backward pass: given the block outputs y and their gradients dy, recompute
    # the block inputs x and their gradients dx; returns (x, dx). kwargs go to f only.
    def backward_pass(self, y, dy, **kwargs):
        # split the outputs into halves y1, y2 along the last axis
        y1, y2 = dict_chunk(y, 2, dim = -1)
        # meant to free y early; note dict_del_ only unbinds its own local names
        dict_del_(y)

        # split the incoming gradient the same way
        dy1, dy2 = dict_chunk(dy, 2, dim = -1)
        dict_del_(dy)

        # recompute g(y1) with gradient tracking (replaying the RNG recorded in
        # forward) and backpropagate dy2 through it; this fills y1[k].grad
        with torch.enable_grad():
            dict_set_(y1, 'requires_grad', True)
            gy1 = self.g(y1, set_rng = True)
            dict_backwards_(gy1, dy2)

        with torch.no_grad():
            # invert the second coupling: x2 = y2 - g(y1)
            x2 = dict_subtract(y2, gy1)
            dict_del_(y2)
            dict_del_(gy1)

            # dx1 = dy1 + gradient that reached y1 through g
            dx1 = dict_sum(dy1, map_values(lambda t: t.grad, y1))
            dict_del_(dy1)
            dict_set_(y1, 'grad', None)

        # recompute f(x2) with gradient tracking and backpropagate dx1 through it;
        # this fills x2[k].grad
        with torch.enable_grad():
            dict_set_(x2, 'requires_grad', True)
            fx2 = self.f(x2, set_rng = True, **kwargs)
            dict_backwards_(fx2, dx1)

        with torch.no_grad():
            # invert the first coupling: x1 = y1 - f(x2)
            x1 = dict_subtract(y1, fx2)
            dict_del_(y1)
            dict_del_(fx2)

            # dx2 = dy2 + gradient that reached x2 through f
            dx2 = dict_sum(dy2, map_values(lambda t: t.grad, x2))
            dict_del_(dy2)
            dict_set_(x2, 'grad', None)

            # detach x2 so the reconstructed inputs carry no graph
            x2 = map_values(lambda t: t.detach(), x2)

            # re-concatenate the halves into the block inputs and their gradients
            x = dict_cat(x1, x2, dim = -1)
            dx = dict_cat(dx1, dx2, dim = -1)

        return x, dx
class _ReversibleFunction(Function):
    """Autograd function running a stack of ``ReversibleBlock``s on dicts of tensors.

    The dict crosses the autograd boundary as a single tensor (values concatenated
    along the last axis); ``kwargs['input_keys']`` and ``kwargs['split_dims']``
    describe how to split it back. Only the final outputs are saved (``ctx.y``);
    ``backward`` walks the blocks in reverse and reconstructs each block's inputs
    with ``backward_pass``. All other ``kwargs`` are forwarded to every block.
    """
    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        # shallow copy so the pops below do not mutate the caller's dict
        # (a second call with the same dict would otherwise raise KeyError)
        kwargs = dict(kwargs)
        input_keys = kwargs.pop('input_keys')
        split_dims = kwargs.pop('split_dims')

        # unpack the concatenated tensor back into a dict
        input_values = x.split(split_dims, dim = -1)
        x = dict(zip(input_keys, input_values))

        ctx.kwargs = kwargs
        ctx.split_dims = split_dims
        ctx.input_keys = input_keys

        for block in blocks:
            x = block(x, **kwargs)

        # keep only the detached final outputs; intermediates are recomputed in backward
        ctx.y = map_values(lambda t: t.detach(), x)
        ctx.blocks = blocks

        return torch.cat(values(x), dim = -1)

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        kwargs = ctx.kwargs
        input_keys = ctx.input_keys
        split_dims = ctx.split_dims

        # split the incoming gradient into a dict matching the forward outputs
        dy = dy.split(split_dims, dim = -1)
        dy = dict(zip(input_keys, dy))

        # walk the blocks in reverse, reconstructing inputs and their gradients
        for block in ctx.blocks[::-1]:
            y, dy = block.backward_pass(y, dy, **kwargs)

        dy = torch.cat(values(dy), dim = -1)
        # no gradients for the `blocks` and `kwargs` arguments
        return dy, None, None
        # 返回处理后的dy,以及None值



class SequentialSequence(nn.Module):
    """Plain (non-reversible) execution of a stack of ``(attention, feedforward)`` pairs.

    ``forward`` applies ``x = attn(x, **kwargs)`` then ``x = ff(x)`` for every pair.

    When ``blocks`` is a plain iterable whose pairs are all ``nn.Module``s, each pair
    is wrapped in an ``nn.ModuleList`` so the parameters are registered. This gives
    the same ``blocks.<i>.0`` / ``blocks.<i>.1`` state-dict layout as passing an
    ``nn.ModuleList`` of ``nn.ModuleList`` pairs, which ``ReversibleSequence`` gets
    via its own ModuleList. An ``nn.Module`` (e.g. a ModuleList) is stored as-is.
    Pairs containing plain callables are kept in a list, as before.
    """
    def __init__(self, blocks):
        super().__init__()
        if not isinstance(blocks, nn.Module):
            blocks = [tuple(pair) for pair in blocks]
            if all(isinstance(layer, nn.Module) for pair in blocks for layer in pair):
                blocks = nn.ModuleList([nn.ModuleList(list(pair)) for pair in blocks])
        self.blocks = blocks

    def forward(self, x, **kwargs):
        for (attn, ff) in self.blocks:
            x = attn(x, **kwargs)
            x = ff(x)
        return x
        # 返回处理后的x



class ReversibleSequence(nn.Module):
    """Run ``(f, g)`` pairs as reversible blocks over a dict of feature tensors.

    Every feature is duplicated along its last axis so each ``ReversibleBlock``
    has two halves to couple. After the stack, the two halves are averaged back
    to the original width. ``kwargs`` are forwarded to the blocks through
    ``_ReversibleFunction``.
    """
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(ReversibleBlock(f, g) for f, g in blocks)

    def forward(self, x, **kwargs):
        # duplicate each feature along the last axis: [t] -> [t, t]
        doubled = {key: torch.cat((feat, feat), dim = -1) for key, feat in x.items()}

        # remember how to unpack the single concatenated tensor afterwards
        keys = doubled.keys()
        widths = tuple(feat.shape[-1] for feat in doubled.values())

        packed = torch.cat(list(doubled.values()), dim = -1)
        packed = _ReversibleFunction.apply(
            packed,
            self.blocks,
            {'input_keys': keys, 'split_dims': widths, **kwargs}
        )

        # split back per key and average the two reversible halves
        pieces = packed.split(widths, dim = -1)
        return {
            key: torch.stack(piece.chunk(2, dim = -1)).mean(dim = 0)
            for key, piece in zip(keys, pieces)
        }
        # 返回处理后的x

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\rotary.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat

# 定义 SinusoidalEmbeddings 类,继承自 nn.Module
class SinusoidalEmbeddings(nn.Module):
    """Per-position rotation angles for rotary embeddings.

    ``inv_freq[k] = 1 / 10000 ** (2k / dim)`` for even offsets ``2k < dim``.
    ``forward(t)`` multiplies every position in ``t`` by each inverse frequency
    and repeats each angle twice in a row along the last axis
    (``[a0, a0, a1, a1, ...]``), so adjacent feature pairs share an angle.
    """
    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, t):
        angles = t[..., None].float() * self.inv_freq[None, :]
        # interleaved duplication, same as einops '... d -> ... (d r)' with r = 2
        return angles.repeat_interleave(2, dim = -1)

# 定义 rotate_half 函数,用于旋转输入张量的一半
def rotate_half(x):
    """Rotate each adjacent feature pair along axis -2 by 90 degrees: (a, b) -> (-b, a).

    Axis -2 is viewed as consecutive pairs ``(x0, x1), (x2, x3), ...``. This is the
    interleaved layout produced by ``SinusoidalEmbeddings``, which repeats every
    frequency twice in a row. The last axis (-1) is carried along unchanged.
    The output keeps the interleaved layout, so ``t * cos + rotate_half(t) * sin``
    is a true rotation of every pair.

    The previous version split the input into interleaved pairs but concatenated
    the result as two halves (``[-x1, -x3, ..., x0, x2, ...]``). That mixed features
    that do not share a frequency and did not preserve norms.
    NOTE: outputs differ from models trained with the old formulation.

    Raises if the size of axis -2 is odd.
    """
    *lead, feat, trail = x.shape
    # reshape (not view): callers pass non-contiguous slices
    pairs = x.reshape(*lead, feat // 2, 2, trail)
    x1, x2 = pairs.unbind(dim = -2)
    rotated = torch.stack((-x2, x1), dim = -2)
    return rotated.reshape(*lead, feat, trail)

# 定义 apply_rotary_pos_emb 函数,用于应用旋转位置编码
def apply_rotary_pos_emb(t, freqs):
    """Apply rotary position embedding angles ``freqs`` to the leading features of ``t``.

    Only the first ``freqs.shape[-2]`` entries of ``t`` along axis -2 are rotated;
    the remaining entries are passed through and re-attached afterwards.
    """
    n = freqs.shape[-2]
    head = t[..., :n, :]
    tail = t[..., n:, :]
    cos_f, sin_f = freqs.cos(), freqs.sin()
    head = (head * cos_f) + (rotate_half(head) * sin_f)
    return torch.cat([head, tail], dim = -2)

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\se3_transformer_pytorch.py

# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 从 itertools 模块中导入 product 函数
from itertools import product
# 从 collections 模块中导入 namedtuple 类
from collections import namedtuple

# 导入 torch 库
import torch
# 从 torch.nn.functional 模块中导入 F 别名
import torch.nn.functional as F
# 从 torch 模块中导入 nn 和 einsum 函数
from torch import nn, einsum

# 导入自定义模块
from se3_transformer_pytorch.basis import get_basis
from se3_transformer_pytorch.utils import exists, default, uniq, map_values, batched_index_select, masked_mean, to_order, fourier_encode, cast_tuple, safe_cat, fast_split, rand_uniform, broadcat
from se3_transformer_pytorch.reversible import ReversibleSequence, SequentialSequence
from se3_transformer_pytorch.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb

# 从 einops 模块中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat

# 定义命名元组 FiberEl,包含 degrees 和 dim 两个字段
# One (degree, channels) entry of a Fiber. NOTE: the first field is named
# `degrees` but holds a single degree; `dim` is the number of channels.
FiberEl = namedtuple('FiberEl', 'degrees dim')

# 定义 Fiber 类
class Fiber(nn.Module):
    """A collection of SE(3) feature types, as an ordered list of ``(degree, dim)`` pairs.

    ``structure`` may be a sequence of ``(degree, dim)`` pairs or a ``{degree: dim}``
    dict (converted to ``FiberEl`` entries).
    """
    def __init__(
        self,
        structure
    ):
        super().__init__()
        if isinstance(structure, dict):
            structure = [FiberEl(degree, dim) for degree, dim in structure.items()]
        self.structure = structure

    @property
    def dims(self):
        # channel counts, deduplicated by the project helper `uniq`
        return uniq(map(lambda t: t[1], self.structure))

    @property
    def degrees(self):
        # a tuple rather than a one-shot `map`, so the result can be iterated and
        # membership-tested more than once
        return tuple(t[0] for t in self.structure)

    @staticmethod
    def create(num_degrees, dim):
        """Fiber with degrees 0..num_degrees-1; ``dim`` is an int for all, or a per-degree tuple."""
        dim_tuple = dim if isinstance(dim, tuple) else ((dim,) * num_degrees)
        return Fiber([FiberEl(degree, dim) for degree, dim in zip(range(num_degrees), dim_tuple)])

    def __getitem__(self, degree):
        """Channel count of ``degree`` (KeyError if absent)."""
        return dict(self.structure)[degree]

    def __iter__(self):
        return iter(self.structure)

    def __mul__(self, fiber):
        """All ``((deg_in, dim_in), (deg_out, dim_out))`` combinations, as an iterator."""
        return product(self.structure, fiber.structure)

    def __and__(self, fiber):
        """``[(degree, dim_self, dim_other)]`` for every degree present in both fibers."""
        # build the degree -> channels map of `fiber` once instead of per lookup
        dims_out = dict(fiber.structure)
        return [(degree, dim, dims_out[degree]) for degree, dim in self if degree in dims_out]

# 获取张量的设备和数据类型
def get_tensor_device_and_dtype(features):
    """Return ``(device, dtype)`` of the first value in the ``features`` dict.

    Raises StopIteration if ``features`` is empty.
    """
    first_tensor = next(iter(features.values()))
    return first_tensor.device, first_tensor.dtype

# 定义 ResidualSE3 类
class ResidualSE3(nn.Module):
    """ only support instance where both Fibers are identical """
    # Keys of `x` are converted with str(); a residual is added only where
    # `res` has the same (string) key, otherwise x's tensor passes through.
    def forward(self, x, res):
        summed = {}
        for degree, tensor in x.items():
            key = str(degree)
            summed[key] = tensor + res[key] if key in res else tensor
        return summed

# 定义 LinearSE3 类
class LinearSE3(nn.Module):
    """Per-degree linear map of channels, applied to every degree shared by both fibers.

    Weights for degree ``d`` have shape ``(dim_in, dim_out)`` and are initialised
    as ``randn / sqrt(dim_in)``. ``forward`` maps ``'b n d m'`` features to
    ``'b n e m'`` and returns only the degrees that have a weight.
    """
    def __init__(
        self,
        fiber_in,
        fiber_out
    ):
        super().__init__()
        self.weights = nn.ParameterDict()

        for degree, dim_in, dim_out in (fiber_in & fiber_out):
            init = torch.randn(dim_in, dim_out) / sqrt(dim_in)
            self.weights[str(degree)] = nn.Parameter(init)

    def forward(self, x):
        return {
            degree: einsum('b n d m, d e -> b n e m', x[degree], weight)
            for degree, weight in self.weights.items()
        }

# 定义 NormSE3 类
class NormSE3(nn.Module):
    """Norm-based SE(3)-equivariant nonlinearity.

    Nonlinearities are important in SE(3) equivariant GCNs. They are also quite
    expensive to compute, so it is convenient for them to share resources with
    other layers, such as normalization. The general workflow is as follows:

    > for feature type in features:
    >    norm, phase <- feature
    >    output = fnc(norm) * phase

    where fnc: {R+}^m -> R^m is a learnable map from m norms to m scalars.

    Args:
        fiber: iterable of ``(degree, channels)`` pairs; one parameter set per degree.
        nonlin: module applied to the scaled norms.
        gated_scale: if False, each degree has a learnable ``scale`` of shape
            ``(1, 1, channels)``. If True, the scale is computed from the norms
            through a learnable ``(channels, channels)`` matrix ``w_gate``.
        eps: lower bound on the norm, avoiding division by zero in the phase.
    """
    def __init__(
        self,
        fiber,
        nonlin = nn.GELU(),
        gated_scale = False,
        eps = 1e-12,
    ):
        super().__init__()
        self.fiber = fiber
        self.nonlin = nonlin
        self.eps = eps

        # Norm mappings: 1 per feature type.
        # Exactly one of 'scale' / 'w_gate' is a parameter; the other is None.
        self.transform = nn.ModuleDict()
        for degree, chan in fiber:
            self.transform[str(degree)] = nn.ParameterDict({
                'scale': nn.Parameter(torch.ones(1, 1, chan)) if not gated_scale else None,
                'w_gate': nn.Parameter(rand_uniform((chan, chan), -1e-3, 1e-3)) if gated_scale else None
            })

    def forward(self, features):
        output = {}
        for degree, t in features.items():
            # split each feature into its norm over the last axis and a unit "phase"
            norm = t.norm(dim = -1, keepdim = True).clamp(min = self.eps)
            phase = t / norm

            # Transform on norms
            parameters = self.transform[degree]
            gate_weights, scale = parameters['w_gate'], parameters['scale']

            # drop the trailing singleton axis left by keepdim
            transformed = rearrange(norm, '... () -> ...')

            # gated mode: derive the scale from the norms themselves
            if not exists(scale):
                scale = einsum('b n d, d e -> b n e', transformed, gate_weights)

            transformed = self.nonlin(transformed * scale)
            transformed = rearrange(transformed, '... -> ... ()')

            # recombine the transformed norm with the original phase
            output[degree] = (transformed * phase).view(*t.shape)

        return output
class ConvSE3(nn.Module):
    """A Tensor Field Network layer.

    ConvSE3 is an SE(3)-equivariant convolution layer. It is the equivalent of a
    linear layer in an MLP, a conv layer in a CNN, or a graph conv layer in a GCN.

    At each node, the activations are split into different "feature types",
    indexed by the SE(3) representation type: non-negative integers 0, 1, 2, ..

    Args:
        fiber_in / fiber_out: input / output fibers. One PairwiseConv kernel is
            built for every (input degree, output degree) combination.
        self_interaction: add a LinearSE3 projection of the input to the pooled
            output (requires pool=True).
        pool: average per-neighbor outputs over the neighbor axis.
        edge_dim: size of extra edge features for the radial network; None is
            treated as 0 (no edge features), matching RadialFunc's default.
        fourier_encode_dist / num_fourier_features: optionally Fourier-encode the
            relative distance, adding num_fourier_features * 2 to the radial input.
        splits: number of chunks along dim 1 used when computing kernels, to
            bound peak memory.
    """
    def __init__(
        self,
        fiber_in,
        fiber_out,
        self_interaction = True,
        pool = True,
        edge_dim = 0,
        fourier_encode_dist = False,
        num_fourier_features = 4,
        splits = 4
    ):
        super().__init__()
        # callers such as AttentionSE3 forward edge_dim = None by default;
        # `None += int` would raise below, so treat None as "no edge features"
        edge_dim = default(edge_dim, 0)

        self.fiber_in = fiber_in
        self.fiber_out = fiber_out
        self.edge_dim = edge_dim
        self.self_interaction = self_interaction

        self.num_fourier_features = num_fourier_features
        self.fourier_encode_dist = fourier_encode_dist

        # radial function will assume a dimension of at minimum 1, for the relative distance - extra fourier features must be added to the edge dimension
        edge_dim += (0 if not fourier_encode_dist else (num_fourier_features * 2))

        # Neighbor -> center weights
        self.kernel_unary = nn.ModuleDict()

        self.splits = splits # for splitting the computation of kernel and basis, to reduce peak memory usage

        for (di, mi), (do, mo) in (self.fiber_in * self.fiber_out):
            self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim = edge_dim, splits = splits)

        self.pool = pool

        # Center -> center weights
        if self_interaction:
            assert self.pool, 'must pool edges if followed with self interaction'
            self.self_interact = LinearSE3(fiber_in, fiber_out)
            self.self_interact_sum = ResidualSE3()

    def forward(
        self,
        inp,
        edge_info,
        rel_dist = None,
        basis = None
    ):
        # NOTE(review): rel_dist and basis default to None but are used
        # unconditionally below; callers must supply them.
        splits = self.splits
        neighbor_indices, neighbor_masks, edges = edge_info
        # add a trailing feature axis to the relative distances
        rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')

        kernels = {}
        outputs = {}

        if self.fourier_encode_dist:
            rel_dist = fourier_encode(rel_dist[..., None], num_encodings = self.num_fourier_features)

        # split every basis tensor into `splits` chunks along dim 1 and regroup
        # them into one {key: chunk} dict per chunk index

        basis_keys = basis.keys()
        split_basis_values = list(zip(*list(map(lambda t: fast_split(t, splits, dim = 1), basis.values()))))
        split_basis = list(map(lambda v: dict(zip(basis_keys, v)), split_basis_values))

        # go through every permutation of input degree type to output degree type

        for degree_out in self.fiber_out.degrees:
            output = 0
            degree_out_key = str(degree_out)

            for degree_in, m_in in self.fiber_in:
                etype = f'({degree_in},{degree_out})'

                # gather neighbor features, then flatten (order x channels) into one axis
                x = inp[str(degree_in)]

                x = batched_index_select(x, neighbor_indices, dim = 1)
                x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)

                kernel_fn = self.kernel_unary[etype]
                edge_features = torch.cat((rel_dist, edges), dim = -1) if exists(edges) else rel_dist

                output_chunk = None
                split_x = fast_split(x, splits, dim = 1)
                split_edge_features = fast_split(edge_features, splits, dim = 1)

                # process input, edges, and basis in chunks along the sequence dimension

                for x_chunk, edge_chunk, basis_chunk in zip(split_x, split_edge_features, split_basis):
                    kernel = kernel_fn(edge_chunk, basis = basis_chunk)
                    chunk = einsum('... o i, ... i c -> ... o c', kernel, x_chunk)
                    output_chunk = safe_cat(output_chunk, chunk, dim = 1)

                output = output + output_chunk

            # optionally average over the neighbor axis (respecting the mask if given)
            if self.pool:
                output = masked_mean(output, neighbor_masks, dim = 2) if exists(neighbor_masks) else output.mean(dim = 2)

            # leading axes are taken from the last input degree's gathered tensor
            leading_shape = x.shape[:2] if self.pool else x.shape[:3]
            output = output.view(*leading_shape, -1, to_order(degree_out))

            outputs[degree_out_key] = output

        if self.self_interaction:
            self_interact_out = self.self_interact(inp)
            outputs = self.self_interact_sum(outputs, self_interact_out)

        return outputs
class RadialFunc(nn.Module):
    """Radial profile network: maps edge inputs to per-(out, in, freq) kernel weights.

    Input size is ``edge_dim + 1`` (the relative distance plus optional edge
    features). The output of size ``num_freq * in_dim * out_dim`` is reshaped to
    ``(..., out_dim, 1, in_dim, 1, num_freq)``.
    """
    def __init__(
        self,
        num_freq,
        in_dim,
        out_dim,
        edge_dim = None,
        mid_dim = 128
    ):
        super().__init__()
        self.num_freq = num_freq
        self.in_dim = in_dim
        self.mid_dim = mid_dim
        self.out_dim = out_dim
        self.edge_dim = default(edge_dim, 0)

        # two hidden Linear -> LayerNorm -> GELU stages, then the output projection
        hidden = [
            nn.Linear(self.edge_dim + 1, mid_dim),
            nn.LayerNorm(mid_dim),
            nn.GELU(),
            nn.Linear(mid_dim, mid_dim),
            nn.LayerNorm(mid_dim),
            nn.GELU(),
        ]
        head = nn.Linear(mid_dim, num_freq * in_dim * out_dim)
        self.net = nn.Sequential(*hidden, head)

    def forward(self, x):
        flat = self.net(x)
        return rearrange(flat, '... (o i f) -> ... o () i () f', i = self.in_dim, o = self.out_dim)

class PairwiseConv(nn.Module):
    """SE(3)-equivariant convolution between two single-type features.

    Combines a learned radial profile (``RadialFunc``) with the precomputed basis
    stored in ``basis`` under the key ``'<degree_in>,<degree_out>'``.
    """
    def __init__(
        self,
        degree_in,
        nc_in,
        degree_out,
        nc_out,
        edge_dim = 0,
        splits = 4
    ):
        super().__init__()
        self.degree_in, self.degree_out = degree_in, degree_out
        self.nc_in, self.nc_out = nc_in, nc_out

        # to_order is the project helper giving a degree's representation size
        self.num_freq = to_order(min(degree_in, degree_out))
        self.d_out = to_order(degree_out)
        self.edge_dim = edge_dim

        self.rp = RadialFunc(self.num_freq, nc_in, nc_out, edge_dim)

        # kept for API compatibility; forward() does not read it
        self.splits = splits

    def forward(self, feat, basis):
        radial = self.rp(feat)
        basis_block = basis[f'{self.degree_in},{self.degree_out}']
        target_shape = (*radial.shape[:3], self.d_out * self.nc_out, -1)

        # torch.sum(R * B, dim = -1) is too memory intensive; accumulate one
        # frequency at a time instead to keep peak memory down
        acc = 0
        for freq in range(radial.shape[-1]):
            acc += radial[..., freq] * basis_block[..., freq]

        acc = rearrange(acc, 'b n h s ... -> (b n h s) ...')
        return acc.view(*target_shape)

# feed forwards

class FeedForwardSE3(nn.Module):
    """Position-wise SE(3) feedforward.

    Expands each degree's channels by ``mult``, applies a norm-based
    nonlinearity, then projects back to the input fiber.
    """
    def __init__(
        self,
        fiber,
        mult = 4
    ):
        super().__init__()
        self.fiber = fiber
        fiber_hidden = Fiber([(t[0], t[1] * mult) for t in fiber])

        self.project_in  = LinearSE3(fiber, fiber_hidden)
        self.nonlin      = NormSE3(fiber_hidden)
        self.project_out = LinearSE3(fiber_hidden, fiber)

    def forward(self, features):
        hidden = self.nonlin(self.project_in(features))
        return self.project_out(hidden)

class FeedForwardBlockSE3(nn.Module):
    """Pre-norm residual wrapper around ``FeedForwardSE3``: ``x + ff(norm(x))``."""
    def __init__(
        self,
        fiber,
        norm_gated_scale = False
    ):
        super().__init__()
        self.fiber = fiber
        self.prenorm = NormSE3(fiber, gated_scale = norm_gated_scale)
        self.feedforward = FeedForwardSE3(fiber)
        self.residual = ResidualSE3()

    def forward(self, features):
        normed = self.prenorm(features)
        return self.residual(self.feedforward(normed), features)

# attention

class AttentionSE3(nn.Module):
    """SE(3)-equivariant multi-head attention over each node's neighbors.

    Queries are a per-node linear projection. Values (and, unless disabled, keys)
    are computed with ConvSE3 over the neighbors given in ``edge_info``.
    Optional extras:
      - ``attend_self``: prepend a key/value for the node itself.
      - ``use_null_kv``: prepend a learned (zero-initialised) key/value.
      - ``global_feats_dim``: prepend keys/values from global features (degree '0' only).
      - ``linear_proj_keys``: keys are a linear projection gathered at the neighbor indices.
      - ``tie_key_values``: reuse the values as keys.
    """
    def __init__(
        self,
        fiber,
        dim_head = 64,
        heads = 8,
        attend_self = False,
        edge_dim = None,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        use_null_kv = False,
        splits = 4,
        global_feats_dim = None,
        linear_proj_keys = False,
        tie_key_values = False
        ):
        # call the parent constructor
        super().__init__()
        # total hidden width across all heads
        hidden_dim = dim_head * heads
        # hidden fiber: same degrees as the input, hidden_dim channels each
        hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber)))
        # the output projection is skipped only for a single head whose width already matches the (single) input width
        project_out = not (heads == 1 and len(fiber.dims) == 1 and dim_head == fiber.dims[0])

        # attention logit scale
        self.scale = dim_head ** -0.5
        self.heads = heads

        # whether keys come from a linear projection rather than a ConvSE3
        self.linear_proj_keys = linear_proj_keys
        # queries: per-node linear projection
        self.to_q = LinearSE3(fiber, hidden_fiber)
        # values: ConvSE3 over neighbors, not pooled
        self.to_v = ConvSE3(fiber, hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)

        # linear key projection and tied keys/values are mutually exclusive
        assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time'

        # keys: linear projection, separate ConvSE3, or None (reuse values)
        if linear_proj_keys:
            self.to_k = LinearSE3(fiber, hidden_fiber)
        elif not tie_key_values:
            self.to_k = ConvSE3(fiber, hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
        else:
            self.to_k = None

        # output projection back to the input fiber (or identity)
        self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity()

        # optional learned "null" key/value prepended to every neighborhood
        self.use_null_kv = use_null_kv
        if use_null_kv:
            self.null_keys = nn.ParameterDict()
            self.null_values = nn.ParameterDict()

            # one zero-initialised (heads, dim_head, order) parameter per degree
            for degree in fiber.degrees:
                m = to_order(degree)
                degree_key = str(degree)
                self.null_keys[degree_key] = nn.Parameter(torch.zeros(heads, dim_head, m))
                self.null_values[degree_key] = nn.Parameter(torch.zeros(heads, dim_head, m))

        # optional attention of each node to itself
        self.attend_self = attend_self
        if attend_self:
            # self key projection
            self.to_self_k = LinearSE3(fiber, hidden_fiber)
            # self value projection
            self.to_self_v = LinearSE3(fiber, hidden_fiber)

        # optional global features (a single degree-0 fiber)
        self.accept_global_feats = exists(global_feats_dim)
        if self.accept_global_feats:
            # global key projection
            global_input_fiber = Fiber.create(1, global_feats_dim)
            global_output_fiber = Fiber.create(1, hidden_fiber[0])
            self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber)
            # global value projection
            self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber)
    # Forward pass: features, edge info, relative distances, basis, optional global features, rotary position embedding and mask.
    # NOTE(review): the `mask` parameter is never read; the name is reassigned below from neighbor_mask.
    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
        # number of heads and the self-attention flag
        h, attend_self = self.heads, self.attend_self
        # device and dtype of the inputs (not used further in this method)
        device, dtype = get_tensor_device_and_dtype(features)
        # unpack edge info
        neighbor_indices, neighbor_mask, edges = edge_info

        # add a head axis to the neighbor mask so it broadcasts against the logits
        if exists(neighbor_mask):
            neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')

        # project features to queries and values
        queries = self.to_q(features)
        values  = self.to_v(features, edge_info, rel_dist, basis)

        # linear keys are computed per node, then gathered at the neighbor indices
        if self.linear_proj_keys:
            keys = self.to_k(features)
            keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys)
        # tied keys/values: reuse the values as keys
        elif not exists(self.to_k):
            keys = values
        else:
            keys = self.to_k(features, edge_info, rel_dist, basis)

        # self keys/values, if attending to self
        if attend_self:
            self_keys, self_values = self.to_self_k(features), self.to_self_v(features)

        # global keys/values, if global features are given
        if exists(global_feats):
            global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats)

        # per-degree outputs
        outputs = {}
        # attend separately for every degree
        for degree in features.keys():
            # this degree's queries, keys and values
            q, k, v = map(lambda t: t[degree], (queries, keys, values))

            # split channels into heads
            q = rearrange(q, 'b i (h d) m -> b h i d m', h = h)
            k, v = map(lambda t: rearrange(t, 'b i j (h d) m -> b h i j d m', h = h), (k, v))

            # prepend the self key/value along the neighbor axis (dim 3)
            if attend_self:
                self_k, self_v = map(lambda t: t[degree], (self_keys, self_values))
                self_k, self_v = map(lambda t: rearrange(t, 'b n (h d) m -> b h n () d m', h = h), (self_k, self_v))
                k = torch.cat((self_k, k), dim = 3)
                v = torch.cat((self_v, v), dim = 3)

            # rotary position embedding, degree '0' only (applied to q, k and also v)
            if exists(pos_emb) and degree == '0':
                query_pos_emb, key_pos_emb = pos_emb
                query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
                key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b () i j d ()')
                q = apply_rotary_pos_emb(q, query_pos_emb)
                k = apply_rotary_pos_emb(k, key_pos_emb)
                v = apply_rotary_pos_emb(v, key_pos_emb)

            # prepend the learned null key/value, broadcast over batch and nodes
            if self.use_null_kv:
                null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values))
                null_k, null_v = map(lambda t: repeat(t, 'h d m -> b h i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v))
                k = torch.cat((null_k, k), dim = 3)
                v = torch.cat((null_v, v), dim = 3)

            # prepend global keys/values, degree '0' only, repeated for every node
            if exists(global_feats) and degree == '0':
                global_k, global_v = map(lambda t: t[degree], (global_keys, global_values))
                global_k, global_v = map(lambda t: repeat(t, 'b j (h d) m -> b h i j d m', h = h, i = k.shape[2]), (global_k, global_v))
                k = torch.cat((global_k, k), dim = 3)
                v = torch.cat((global_v, v), dim = 3)

            # attention logits: contract over channels (d) and representation (m)
            sim = einsum('b h i d m, b h i j d m -> b h i j', q, k) * self.scale

            # mask invalid neighbors; the prepended extra slots are left-padded as always valid
            if exists(neighbor_mask):
                num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
                mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
                sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

            # softmax over neighbors, weighted sum of values, merge heads
            attn = sim.softmax(dim = -1)
            out = einsum('b h i j, b h i j d m -> b h i d m', attn, v)
            outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')

        # project back to the input fiber
        return self.to_out(outputs)
# Attention mechanism whose single key/value projection is shared by all query heads
class OneHeadedKVAttentionSE3(nn.Module):
    """SE3 attention in which only the queries are split into heads.

    Queries are projected to `heads * dim_head` channels per degree, while keys and
    values are projected to a single `dim_head`-channel head per degree that every
    query head attends over (k / v carry no head axis in the einsums of `forward`).

    Args:
        fiber: input/output fiber; iterated as (degree, dim)-like pairs.
        dim_head: channels per query head, and channels of the shared key/value head.
        heads: number of query heads.
        attend_self: if True, prepend a key/value computed from each node itself
            (`to_self_k` / `to_self_v`) in front of its neighbor keys/values.
        edge_dim: forwarded to the `ConvSE3` key/value projections.
        fourier_encode_dist: forwarded to `ConvSE3`.
        rel_dist_num_fourier_features: forwarded to `ConvSE3` as `num_fourier_features`.
        use_null_kv: if True, prepend one learned null key/value per degree.
        splits: forwarded to `ConvSE3`.
        global_feats_dim: if not None, build projections from global features to
            extra keys/values, which are only used for degree '0'.
        linear_proj_keys: compute keys with a `LinearSE3` on node features and index
            them by neighbor, instead of convolving with a `ConvSE3`.
        tie_key_values: reuse the values as keys (no separate `to_k`).
            Mutually exclusive with `linear_proj_keys`.
    """
    def __init__(
        self,
        fiber,
        dim_head = 64,
        heads = 8,
        attend_self = False,
        edge_dim = None,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        use_null_kv = False,
        splits = 4,
        global_feats_dim = None,
        linear_proj_keys = False,
        tie_key_values = False
    ):
        super().__init__()
        hidden_dim = dim_head * heads
        # queries: `heads * dim_head` channels for every degree of the input fiber
        hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber)))
        # keys / values: a single shared head of `dim_head` channels per degree
        kv_hidden_fiber = Fiber(list(map(lambda t: (t[0], dim_head), fiber)))
        # the output projection is skipped only for one head whose width equals the (single) input dim
        project_out = not (heads == 1 and len(fiber.dims) == 1 and dim_head == fiber.dims[0])

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.linear_proj_keys = linear_proj_keys # whether keys use a linear projection instead of a convolution with the basis

        # query projection (per node, no neighbor information)
        self.to_q = LinearSE3(fiber, hidden_fiber)
        # value projection: convolution over neighbors
        self.to_v = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)

        assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time'

        if linear_proj_keys:
            # keys from a per-node linear projection (neighbor-indexed in forward)
            self.to_k = LinearSE3(fiber, kv_hidden_fiber)
        elif not tie_key_values:
            # keys from their own convolution over neighbors
            self.to_k = ConvSE3(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
        else:
            # tied keys/values: forward reuses the values as keys
            self.to_k = None

        # project the concatenated heads back to the input fiber
        self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity()

        self.use_null_kv = use_null_kv
        if use_null_kv:
            # one learned (dim_head, 2 * degree + 1)-shaped null key and value per degree
            # NOTE(review): assumes to_order(degree) == 2 * degree + 1 — confirm against its definition
            self.null_keys = nn.ParameterDict()
            self.null_values = nn.ParameterDict()

            for degree in fiber.degrees:
                m = to_order(degree)
                degree_key = str(degree)
                self.null_keys[degree_key] = nn.Parameter(torch.zeros(dim_head, m))
                self.null_values[degree_key] = nn.Parameter(torch.zeros(dim_head, m))

        self.attend_self = attend_self
        if attend_self:
            # per-node key/value projections so each node can attend to itself
            self.to_self_k = LinearSE3(fiber, kv_hidden_fiber)
            self.to_self_v = LinearSE3(fiber, kv_hidden_fiber)

        self.accept_global_feats = exists(global_feats_dim)
        if self.accept_global_feats:
            # projections from global features to extra keys/values
            global_input_fiber = Fiber.create(1, global_feats_dim)
            global_output_fiber = Fiber.create(1, kv_hidden_fiber[0])
            self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber)
            self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber)
    # forward pass: multi-head queries attend over the shared single-head keys/values
    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
        """Attend from each node's query heads over the shared key/value head of its neighbors.

        Args:
            features: dict keyed by degree string ('0', '1', ...); each query tensor is
                rearranged as 'b i (h d) m'.
            edge_info: tuple (neighbor_indices, neighbor_mask, edges). `neighbor_mask`, if
                given, is 'b i j'; False positions are filled with -finfo.max before softmax.
            rel_dist: forwarded to the ConvSE3 projections.
            basis: forwarded to the ConvSE3 projections.
            global_feats: optional global features giving extra keys/values (degree '0' only).
            pos_emb: optional (query_pos_emb, key_pos_emb) rotary embeddings (degree '0' only).
            mask: not used by this method.

        Returns:
            `self.to_out` applied to a dict of per-degree outputs shaped 'b n (h d) m'.
        """
        # number of query heads and whether self keys/values are prepended
        h, attend_self = self.heads, self.attend_self
        # NOTE(review): device and dtype are not used below
        device, dtype = get_tensor_device_and_dtype(features)
        # unpack neighbor indices, neighbor mask and edge features
        neighbor_indices, neighbor_mask, edges = edge_info

        # add a singleton head axis so the mask broadcasts against 'b h i j' similarities
        if exists(neighbor_mask):
            neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')

        # project features to queries and values
        queries = self.to_q(features)
        values  = self.to_v(features, edge_info, rel_dist, basis)

        # linear keys are computed per node, then indexed by neighbor_indices
        if self.linear_proj_keys:
            keys = self.to_k(features)
            keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), keys)
        # tied keys/values: the values double as keys
        elif not exists(self.to_k):
            keys = values
        else:
            keys = self.to_k(features, edge_info, rel_dist, basis)

        # keys/values of each node itself
        if attend_self:
            self_keys, self_values = self.to_self_k(features), self.to_self_v(features)

        # keys/values derived from the global features
        if exists(global_feats):
            global_keys, global_values = self.to_global_k(global_feats), self.to_global_v(global_feats)

        # per-degree attention outputs
        outputs = {}
        for degree in features.keys():
            # query, key and value of the current degree
            q, k, v = map(lambda t: t[degree], (queries, keys, values))

            # only the queries get a head axis; k / v stay 'b i j d m'
            q = rearrange(q, 'b i (h d) m -> b h i d m', h = h)

            # prepend self key/value at neighbor position 0
            if attend_self:
                self_k, self_v = map(lambda t: t[degree], (self_keys, self_values))
                self_k, self_v = map(lambda t: rearrange(t, 'b n d m -> b n () d m'), (self_k, self_v))
                k = torch.cat((self_k, k), dim = 2)
                v = torch.cat((self_v, v), dim = 2)

            # rotary position embedding, applied to degree '0' only
            if exists(pos_emb) and degree == '0':
                query_pos_emb, key_pos_emb = pos_emb
                query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
                key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b i j d ()')
                q = apply_rotary_pos_emb(q, query_pos_emb)
                k = apply_rotary_pos_emb(k, key_pos_emb)
                v = apply_rotary_pos_emb(v, key_pos_emb)

            # prepend the learned null key/value, broadcast to every batch item and node
            if self.use_null_kv:
                null_k, null_v = map(lambda t: t[degree], (self.null_keys, self.null_values))
                null_k, null_v = map(lambda t: repeat(t, 'd m -> b i () d m', b = q.shape[0], i = q.shape[2]), (null_k, null_v))
                k = torch.cat((null_k, k), dim = 2)
                v = torch.cat((null_v, v), dim = 2)

            # prepend global keys/values (degree '0' only), repeated for every node
            if exists(global_feats) and degree == '0':
                global_k, global_v = map(lambda t: t[degree], (global_keys, global_values))
                global_k, global_v = map(lambda t: repeat(t, 'b j d m -> b i j d m', i = k.shape[1]), (global_k, global_v))
                k = torch.cat((global_k, k), dim = 2)
                v = torch.cat((global_v, v), dim = 2)

            # scaled dot-product similarity: every query head against the shared keys
            sim = einsum('b h i d m, b i j d m -> b h i j', q, k) * self.scale

            # the prepended self/null/global entries are left-padded as True, so they are never masked
            if exists(neighbor_mask):
                num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
                mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
                sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

            # softmax over neighbors, weighted sum of the shared values, then merge heads
            attn = sim.softmax(dim = -1)
            out = einsum('b h i j, b i j d m -> b h i d m', attn, v)
            outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')

        # project back to the output fiber
        return self.to_out(outputs)
# Pre-norm residual attention block
class AttentionBlockSE3(nn.Module):
    """Pre-norm residual attention block over fiber features.

    The input is normalized with ``NormSE3``, passed through an attention module
    built from ``attention_klass``, and combined with the un-normalized input via
    ``ResidualSE3``.

    Args:
        fiber: input/output fiber of the block.
        dim_head, heads, attend_self, edge_dim, use_null_kv, fourier_encode_dist,
        rel_dist_num_fourier_features, splits, linear_proj_keys, tie_key_values:
            forwarded unchanged to ``attention_klass``.
        global_feats_dim: dimension of optional global features, forwarded to
            ``attention_klass``; ``None`` (the default) disables global keys/values.
        attention_klass: attention module class to instantiate, e.g. ``AttentionSE3``
            or ``OneHeadedKVAttentionSE3``.
        norm_gated_scale: forwarded to ``NormSE3`` as ``gated_scale``.
    """
    def __init__(
        self,
        fiber,
        dim_head = 24,
        heads = 8,
        attend_self = False,
        edge_dim = None,
        use_null_kv = False,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        splits = 4,
        # None, not False: OneHeadedKVAttentionSE3 builds its global key/value projections
        # whenever `exists(global_feats_dim)` (i.e. `is not None`), so a False default
        # silently created zero-width global projections
        global_feats_dim = None,
        linear_proj_keys = False,
        tie_key_values = False,
        attention_klass = AttentionSE3,
        norm_gated_scale = False
    ):
        super().__init__()
        self.attn = attention_klass(fiber, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, use_null_kv = use_null_kv, rel_dist_num_fourier_features = rel_dist_num_fourier_features, fourier_encode_dist = fourier_encode_dist, splits = splits, global_feats_dim = global_feats_dim, linear_proj_keys = linear_proj_keys, tie_key_values = tie_key_values)
        self.prenorm = NormSE3(fiber, gated_scale = norm_gated_scale)
        self.residual = ResidualSE3()

    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
        """Run ``residual(attn(prenorm(features)), features)``.

        All arguments other than ``features`` are forwarded positionally to the
        attention module.
        """
        res = features
        outputs = self.prenorm(features)
        outputs = self.attn(outputs, edge_info, rel_dist, basis, global_feats, pos_emb, mask)
        return self.residual(outputs, res)

# Swish / SiLU activation, used as a fallback on older PyTorch versions
class Swish_(nn.Module):
    """Element-wise Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return gate * x


# Prefer the built-in module when this PyTorch build provides one.
SiLU = getattr(nn, 'SiLU', Swish_)

# Norm rescaling for higher-degree (htype) features
class HtypesNorm(nn.Module):
    """Keep each vector's direction and replace its norm with a learned affine map of it.

    Along the last axis every vector ``v`` becomes
    ``v / max(|v|, eps) * (|v| * scale + bias)``. ``scale`` and ``bias`` are learned
    per channel and have shape ``(1, 1, 1, dim, 1)``.

    Args:
        dim: number of channels (size of the input's second-to-last axis).
        eps: lower bound applied to the norm before dividing by it.
        scale_init: initial value of every entry of ``scale``.
        bias_init: initial value of every entry of ``bias``.
    """

    def __init__(self, dim, eps = 1e-8, scale_init = 1e-2, bias_init = 1e-2):
        super().__init__()
        self.eps = eps
        param_shape = (1, 1, 1, dim, 1)
        float_dtype = torch.get_default_dtype()
        self.scale = nn.Parameter(torch.full(param_shape, scale_init, dtype = float_dtype))
        self.bias = nn.Parameter(torch.full(param_shape, bias_init, dtype = float_dtype))

    def forward(self, coors):
        magnitude = torch.norm(coors, dim = -1, keepdim = True)
        direction = coors / magnitude.clamp(min = self.eps)
        return direction * (magnitude * self.scale + self.bias)

# EGNN layer over fiber features
class EGNN(nn.Module):
    """Message-passing layer that updates degree-0 (node) and higher-degree features.

    Degree-0 features are updated by an edge MLP (over node pairs, neighbor
    distances and optional edge features) followed by a residual node MLP. Each
    higher-degree feature type ("htype") is updated by a learned, masked weighted
    sum of normalized relative differences to its neighbors. It is then gated by
    a sigmoid of the updated node features.

    Args:
        fiber: fiber of the features; ``fiber[0]`` is the node (degree-0) dim and
            iteration yields elements with ``degrees`` / ``dim`` fields.
        hidden_dim: width of the edge messages.
        edge_dim: number of edge-feature channels appended to the edge MLP input.
        init_eps: std of the normal init applied to every ``nn.Linear`` weight.
        coor_weights_clamp_value: if set, clamp the htype weights to
            ``[-value, value]``.
    """
    def __init__(
        self,
        fiber,
        hidden_dim = 32,
        edge_dim = 0,
        init_eps = 1e-3,
        coor_weights_clamp_value = None
    ):
        super().__init__()
        self.fiber = fiber
        node_dim = fiber[0]

        htypes = list(filter(lambda t: t.degrees != 0, fiber))
        htype_dims = sum([fiberel.dim for fiberel in htypes])

        # node_i, node_j, one distance per htype channel, relative distance, edge features
        edge_input_dim = node_dim * 2 + htype_dims + edge_dim + 1

        self.node_norm = nn.LayerNorm(node_dim)

        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, edge_input_dim * 2),
            SiLU(),
            nn.Linear(edge_input_dim * 2, hidden_dim),
            SiLU()
        )

        self.htype_norms = nn.ModuleDict({})
        self.htype_gating = nn.ModuleDict({})

        for degree, dim in fiber:
            if degree == 0:
                continue
            self.htype_norms[str(degree)] = HtypesNorm(dim)
            self.htype_gating[str(degree)] = nn.Linear(node_dim, dim)

        # edge message -> one weight per htype channel
        self.htypes_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            SiLU(),
            nn.Linear(hidden_dim * 4, htype_dims)
        )

        self.node_mlp = nn.Sequential(
            nn.Linear(node_dim + hidden_dim, node_dim * 2),
            SiLU(),
            nn.Linear(node_dim * 2, node_dim)
        )

        self.coor_weights_clamp_value = coor_weights_clamp_value
        self.init_eps = init_eps
        self.apply(self.init_)

    def init_(self, module):
        """Initialize ``nn.Linear`` weights (exact type match) from N(0, init_eps)."""
        if type(module) in {nn.Linear}:
            nn.init.normal_(module.weight, std = self.init_eps)

    def forward(
        self,
        features,
        edge_info,
        rel_dist,
        mask = None,
        **kwargs
        ):
        """Update node and higher-degree features from their neighbors.

        Args:
            features: dict keyed by degree string. ``features['0']`` is
                'b n d 1'; other entries are 'b n d m'.
            edge_info: tuple (neighbor_indices, neighbor_masks, edges);
                ``neighbor_masks`` (if not None) is a 'b i j' mask where True keeps a neighbor.
            rel_dist: 'b i j' distances to each neighbor.
            mask: ignored; the neighbor mask from ``edge_info`` is used instead.
            **kwargs: ignored (accepts the extra arguments passed by the network).

        Returns:
            A new dict with the updated features; the input dict is not modified.
        """
        neighbor_indices, neighbor_masks, edges = edge_info

        # all masking below is per (node, neighbor) pair, so the neighbor mask wins
        mask = neighbor_masks

        # shallow copy: entries are reassigned below, the caller's dict stays intact
        features = dict(features)

        # type-0 features act as node features: 'b n d 1' -> 'b n d'
        nodes = rearrange(features['0'], '... () -> ...')

        # higher types
        htypes = [(degree, htype) for degree, htype in features.items() if degree != '0']
        htype_degrees = [degree for degree, _ in htypes]
        htype_dims = [htype.shape[-2] for _, htype in htypes]

        # pairwise differences (and their norms) for every higher type
        rel_htypes = []
        rel_htypes_dists = []

        for _, htype in htypes:
            rel_htype = rearrange(htype, 'b i d m -> b i () d m') - rearrange(htype, 'b j d m -> b () j d m')
            rel_htypes.append(rel_htype)
            rel_htypes_dists.append(rel_htype.norm(dim = -1))

        # edge MLP inputs
        nodes_i = rearrange(nodes, 'b i d -> b i () d')
        nodes_j = batched_index_select(nodes, neighbor_indices, dim = 1)
        neighbor_higher_type_dists = [batched_index_select(t, neighbor_indices, dim = 2) for t in rel_htypes_dists]
        coor_rel_dist = rearrange(rel_dist, 'b i j -> b i j ()')

        edge_mlp_inputs = broadcat((nodes_i, nodes_j, *neighbor_higher_type_dists, coor_rel_dist), dim = -1)

        if exists(edges):
            edge_mlp_inputs = torch.cat((edge_mlp_inputs, edges), dim = -1)

        m_ij = self.edge_mlp(edge_mlp_inputs)

        # per-neighbor weights for the htype updates
        htype_weights = self.htypes_mlp(m_ij)

        if exists(self.coor_weights_clamp_value):
            clamp_value = self.coor_weights_clamp_value
            htype_weights = htype_weights.clamp(min = -clamp_value, max = clamp_value)

        # mask BEFORE splitting: the split tensors are what the update loop consumes,
        # so masking afterwards let padded neighbors leak into the htype updates
        if exists(mask):
            htype_mask = rearrange(mask, 'b i j -> b i j ()')
            htype_weights = htype_weights.masked_fill(~htype_mask, 0.)

        split_htype_weights = htype_weights.split(htype_dims, dim = -1)

        htype_updates = []

        for degree, rel_htype, htype_weight in zip(htype_degrees, rel_htypes, split_htype_weights):
            normed_rel_htype = self.htype_norms[str(degree)](rel_htype)
            normed_rel_htype = batched_index_select(normed_rel_htype, neighbor_indices, dim = 2)

            htype_update = einsum('b i j d m, b i j d -> b i d m', normed_rel_htype, htype_weight)
            htype_updates.append(htype_update)

        # aggregate masked messages into node updates
        if exists(mask):
            m_ij_mask = rearrange(mask, '... -> ... ()')
            m_ij = m_ij.masked_fill(~m_ij_mask, 0.)

        m_i = m_ij.sum(dim = -2)

        normed_nodes = self.node_norm(nodes)
        node_mlp_input = torch.cat((normed_nodes, m_i), dim = -1)
        node_out = self.node_mlp(node_mlp_input) + nodes

        features['0'] = rearrange(node_out, '... -> ... ()')

        # add the htype updates, then gate each htype by the updated node features
        for degree, htype_update in zip(htype_degrees, htype_updates):
            gating = self.htype_gating[str(degree)](node_out).sigmoid()
            features[degree] = (features[degree] + htype_update) * rearrange(gating, '... -> ... ()')

        return features
# Stack of EGNN layers
class EGnnNetwork(nn.Module):
    """Stack of ``depth`` EGNN layers, each optionally followed by a FeedForwardBlockSE3.

    Args:
        fiber: fiber of the features passed through every layer.
        depth: number of EGNN layers.
        edge_dim: edge-feature channels, forwarded to each EGNN.
        hidden_dim: edge-message width, forwarded to each EGNN.
        coor_weights_clamp_value: forwarded to each EGNN.
        feedforward: if True, add a FeedForwardBlockSE3 after every EGNN layer.
    """
    def __init__(
        self,
        *,
        fiber,
        depth,
        edge_dim = 0,
        hidden_dim = 32,
        coor_weights_clamp_value = None,
        feedforward = False
    ):
        super().__init__()
        self.fiber = fiber
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                EGNN(fiber = fiber, edge_dim = edge_dim, hidden_dim = hidden_dim, coor_weights_clamp_value = coor_weights_clamp_value),
                FeedForwardBlockSE3(fiber) if feedforward else None
            ]))

    def forward(
        self,
        features,
        edge_info,
        rel_dist,
        basis,
        global_feats = None,
        pos_emb = None,
        mask = None,
        **kwargs
    ):
        """Run all layers after adding every node as its own first neighbor.

        The self neighbor is prepended at position 0 of the neighbor axis, with mask
        True, relative distance 0 and all-zero edge features. The original code's note
        says that self-edges are removed for the SE3 attention layers but are wanted
        here. A ``None`` neighbor mask or ``None`` edges are passed through as ``None``.

        Returns:
            The features dict produced by the last layer.
        """
        neighbor_indices, neighbor_masks, edges = edge_info
        device = neighbor_indices.device

        # prepend each node's own index to its neighbor list
        self_indices = torch.arange(neighbor_indices.shape[1], device = device)
        self_indices = rearrange(self_indices, 'i -> () i ()')
        neighbor_indices = broadcat((self_indices, neighbor_indices), dim = -1)

        # the neighbor mask is optional downstream (EGNN checks `exists(mask)`),
        # so only pad it when present instead of crashing in F.pad on None
        if exists(neighbor_masks):
            neighbor_masks = F.pad(neighbor_masks, (1, 0), value = True)

        rel_dist = F.pad(rel_dist, (1, 0), value = 0.)

        if exists(edges):
            edges = F.pad(edges, (0, 0, 1, 0), value = 0.)  # self-edges get zero features for now

        edge_info = (neighbor_indices, neighbor_masks, edges)

        for egnn, ff in self.layers:
            features = egnn(
                features,
                edge_info = edge_info,
                rel_dist = rel_dist,
                basis = basis,
                global_feats = global_feats,
                pos_emb = pos_emb,
                mask = mask,
                **kwargs
            )

            if exists(ff):
                features = ff(features)

        return features

# 主类
class SE3Transformer(nn.Module):
    # 初始化函数,接收多个参数
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 24,
        depth = 2,
        input_degrees = 1,
        num_degrees = None,
        output_degrees = 1,
        valid_radius = 1e5,
        reduce_dim_out = False,
        num_tokens = None,
        num_positions = None,
        num_edge_tokens = None,
        edge_dim = None,
        reversible = False,
        attend_self = True,
        use_null_kv = False,
        differentiable_coors = False,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        num_neighbors = float('inf'),
        attend_sparse_neighbors = False,
        num_adj_degrees = None,
        adj_dim = 0,
        max_sparse_neighbors = float('inf'),
        dim_in = None,
        dim_out = None,
        norm_out = False,
        num_conv_layers = 0,
        causal = False,
        splits = 4,
        global_feats_dim = None,
        linear_proj_keys = False,
        one_headed_key_values = False,
        tie_key_values = False,
        rotary_position = False,
        rotary_rel_dist = False,
        norm_gated_scale = False,
        use_egnn = False,
        egnn_hidden_dim = 32,
        egnn_weights_clamp_value = None,
        egnn_feedforward = False,
        hidden_fiber_dict = None,
        out_fiber_dict = None
    # 前向传播函数,接收多个参数
    def forward(
        self,
        feats,
        coors,
        mask = None,
        adj_mat = None,
        edges = None,
        return_type = None,
        return_pooled = False,
        neighbor_mask = None,
        global_feats = None
posted @ 2024-06-28 14:02  绝不原创的飞龙  阅读(4)  评论(0编辑  收藏  举报