Lucidrains-系列项目源码解析-三十八-
Lucidrains 系列项目源码解析(三十八)
.\lucidrains\routing-transformer\routing_transformer\routing_transformer.py
# 导入 torch 库
import torch
# 导入 torch 中的神经网络模块
import torch.nn as nn
# 导入 torch 中的函数操作模块
import torch.nn.functional as F
# 导入 math 库
import math
# 从 inspect 模块中导入 isfunction 函数
from inspect import isfunction
# 从 operator 模块中导入 mul 函数
from operator import mul
# 从 functools 模块中导入 partial, reduce, wraps 函数
from functools import partial, reduce, wraps
# 从 einops 库中导入 rearrange, repeat 函数
from einops import rearrange, repeat
# 从 einops.layers.torch 模块中导入 Rearrange 类
from einops.layers.torch import Rearrange
# 从 local_attention 模块中导入 LocalAttention 类
from local_attention import LocalAttention
# 从 product_key_memory 模块中导入 PKM 类
from product_key_memory import PKM
# 从 mixture_of_experts 模块中导入 MoE 类
from mixture_of_experts import MoE
# 从 routing_transformer.reversible 模块中导入 ReversibleSequence, SequentialSequence 类
# 常量定义
# 定义 TOKEN_SELF_ATTN_VALUE 常量为 -5e4
TOKEN_SELF_ATTN_VALUE = -5e4
# 定义 KMEAN_INIT_ITERS 常量为 10
KMEAN_INIT_ITERS = 10
# 辅助函数
def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True
def identity(x, *args, **kwargs):
    """Return `x` unchanged; extra positional and keyword arguments are ignored."""
    result = x
    return result
def default(x, d):
    """Return `x` unless it is None, otherwise fall back to `d`.

    When `d` is a plain Python function (as judged by `inspect.isfunction`)
    it is called without arguments and its result is returned, which lets
    callers supply lazily computed defaults.
    """
    if x is not None:
        return x
    return d() if isfunction(d) else d
def cast_tuple(x):
    """Return `x` itself if it is a tuple, otherwise wrap it in a 1-tuple."""
    if isinstance(x, tuple):
        return x
    return (x,)
def cache_fn(f):
    """Decorator that stores the first non-None result of `f` and replays it.

    The cache is not keyed on the arguments: once a non-None value has been
    stored, every later call returns it regardless of what is passed in.
    A None result is not treated as cached, so `f` runs again next time.
    """
    memo = None

    @wraps(f)
    def cached_fn(*args, **kwargs):
        nonlocal memo
        if memo is None:
            memo = f(*args, **kwargs)
        return memo

    return cached_fn
def compose(*fns):
    """Build a function that applies `fns` from right to left.

    Every function receives the running value as its first argument, followed
    by the same extra positional and keyword arguments.
    """
    right_to_left = tuple(reversed(fns))

    def inner(x, *args, **kwargs):
        value = x
        for step in right_to_left:
            value = step(value, *args, **kwargs)
        return value

    return inner
def to(t):
    """Return a dict holding the `device` and `dtype` of tensor `t`."""
    return dict(device=t.device, dtype=t.dtype)
def find_modules(nn_module, type):
    """Collect every module yielded by `nn_module.modules()` that is an instance of `type`."""
    matches = []
    for module in nn_module.modules():
        if isinstance(module, type):
            matches.append(module)
    return matches
def is_empty(t):
    """Return True when tensor `t` holds no elements."""
    element_count = t.nelement()
    return element_count == 0
def max_neg_value(tensor):
    """Return the most negative finite value of the floating dtype of `tensor`."""
    info = torch.finfo(tensor.dtype)
    return -info.max
def batched_index_select(values, indices):
    """Gather rows of `values` along dimension 2 at the positions in `indices`.

    `indices` gets a trailing axis that is expanded to the size of the last
    dimension of `values`, so whole feature vectors are picked at once.
    """
    feature_dim = values.shape[-1]
    gather_index = indices.unsqueeze(-1).expand(*indices.shape, feature_dim)
    return values.gather(2, gather_index)
def merge_dims(ind_from, ind_to, tensor):
    """Collapse dimensions `ind_from`..`ind_to` (inclusive) of `tensor` into one."""
    sizes = list(tensor.shape)
    span = slice(ind_from, ind_to + 1)
    merged_size = reduce(mul, sizes[span])
    sizes[span] = [merged_size]
    return tensor.reshape(*sizes)
def expand_dim(t, dim, k):
    """Insert a new axis at `dim` and expand it to size `k` without copying data."""
    unsqueezed = t.unsqueeze(dim)
    # -1 keeps the existing size for every axis except the new one
    sizes = [-1 for _ in unsqueezed.shape]
    sizes[dim] = k
    return unsqueezed.expand(*sizes)
def scatter_mean(src, t, index, dim, eps = 1e-5):
    """Scatter-add `t` into `src` along `dim` and divide by the per-slot hit count.

    The count is obtained by scattering ones with the same `index`; `eps` is
    added to it so slots that receive nothing do not divide by zero.
    """
    summed = src.scatter_add(dim, index, t)
    counts = src.scatter_add(dim, index, torch.ones_like(t))
    return summed / (counts + eps)
def split_at_index(dim, index, t):
    """Split `t` along dimension `dim` into the parts before and from `index`."""
    leading = (slice(None),) * dim
    head = t[leading + (slice(None, index),)]
    tail = t[leading + (slice(index, None),)]
    return head, tail
def reshape_dim(t, dim, split_dims):
    """Replace dimension `dim` of `t` by the sizes in `split_dims` via reshape.

    `dim` may be negative; it is normalised against the rank of `t` first.
    """
    sizes = list(t.shape)
    rank = len(sizes)
    axis = (dim + rank) % rank
    sizes[axis:axis + 1] = split_dims
    return t.reshape(sizes)
def ema(old, new, decay):
    """Blend `old` and `new` as `old * decay + new * (1 - decay)`.

    When there is no previous value (`old` is None), `new` is returned as is.
    """
    if old is None:
        return new
    keep = old * decay
    take = new * (1 - decay)
    return keep + take
def ema_inplace(moving_avg, new, decay):
    """Update `moving_avg` in place with an exponential moving average of `new`.

    A tensor without elements is filled by copying `new` into it; otherwise
    its data becomes `moving_avg * decay + new * (1 - decay)`. Returns None.
    """
    if moving_avg.nelement() == 0:
        moving_avg.data.copy_(new)
    else:
        moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
# 辅助类
class Chunk(nn.Module):
    """Apply `fn` to `chunks` slices of the input and concatenate the results.

    The input is split along `along_dim`; with `chunks <= 1` the function is
    simply called on the whole input.
    """

    def __init__(self, chunks, fn, along_dim = -1):
        super().__init__()
        self.dim = along_dim
        self.chunks = chunks
        self.fn = fn

    def forward(self, x, **kwargs):
        if self.chunks <= 1:
            return self.fn(x, **kwargs)
        pieces = x.chunk(self.chunks, dim = self.dim)
        outputs = [self.fn(piece, **kwargs) for piece in pieces]
        return torch.cat(outputs, dim = self.dim)
class PreNorm(nn.ModuleList):
    """Normalise the input with `norm_class(dim)` and then call `fn` on it.

    Keyword arguments of `forward` are passed to `fn` only, not to the norm.
    """

    def __init__(self, norm_class, dim, fn):
        super().__init__()
        self.norm = norm_class(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class ReZero(nn.Module):
    """Scale the output of `fn` by a single learnable weight initialised to zero.

    The scaling is routed through `map_first_tuple_or_el`, a helper that is
    defined outside this block.
    """

    def __init__(self, fn):
        super().__init__()
        self.residual_weight = nn.Parameter(torch.zeros(1))
        self.fn = fn

    def _scale(self, t):
        # multiply by the learnable residual weight
        return t * self.residual_weight

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        return map_first_tuple_or_el(out, self._scale)
class ScaleNorm(nn.Module):
    """Divide by the L2 norm over the last dimension and multiply by a learnable scalar.

    `dim` is accepted so the constructor matches other norm classes, but it is
    not used. The norm is clamped to at least `eps` before dividing. The
    operation is routed through `map_first_tuple_or_el`, a helper that is
    defined outside this block.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1))
        self.eps = eps

    def _normalise(self, t):
        # L2 norm over the last axis, bounded below so the division stays finite
        length = torch.norm(t, dim=-1, keepdim=True).clamp(min=self.eps)
        return t / length * self.g

    def forward(self, x):
        return map_first_tuple_or_el(x, self._normalise)
class ProjectInOut(nn.Module):
    """Wrap `fn` with a linear projection into `dim_out` and, optionally, back to `dim_in`.

    `fn` must return a pair `(output, loss)`; only the output is projected
    back and the loss is passed through untouched. With `project_out=False`
    the output is left at `dim_out`.
    """

    def __init__(self, fn, dim_in, dim_out, project_out = True):
        super().__init__()
        self.fn = fn
        self.project_in = nn.Linear(dim_in, dim_out)
        self.project_out = nn.Linear(dim_out, dim_in) if project_out else identity

    def forward(self, x, **kwargs):
        projected = self.project_in(x)
        inner_out, loss = self.fn(projected, **kwargs)
        return self.project_out(inner_out), loss
class MatrixMultiply(nn.Module):
    """Right-multiply the input by a stored tensor, optionally transposed first."""

    def __init__(self, tensor, transpose = False):
        super().__init__()
        self.tensor = tensor
        self.transpose = transpose

    def forward(self, x):
        # the transpose is taken at call time, so updates to the stored tensor are seen
        weight = self.tensor.t() if self.transpose else self.tensor
        return x @ weight
def shift(t, amount, mask = None):
    """Shift `t` by `amount` positions along its second-to-last axis, zero-filling.

    A positive amount pads at the front and trims the end; a negative amount
    does the opposite. When `mask` is given, entries where it is False are
    zeroed before shifting. An amount of 0 returns `t` itself.
    """
    if amount == 0:
        return t
    if mask is not None:
        t = t.masked_fill(~mask[..., None], 0.)
    padding = (0, 0, amount, -amount)
    return F.pad(t, padding, value = 0.)
class PreShiftTokens(nn.Module):
    """Shift feature segments of the input by different amounts before calling `fn`.

    The last dimension is split into pieces of width `features // len(shifts)`.
    The first `len(shifts)` pieces are shifted by the matching entry of
    `shifts`; any further pieces are passed through unchanged.
    """

    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        # the optional mask stays in kwargs, so `fn` receives it as well
        mask = kwargs.get('mask', None)
        shifts = self.shifts
        num_segments = len(shifts)
        width = x.shape[-1] // num_segments
        pieces = x.split(width, dim = -1)
        to_shift, untouched = pieces[:num_segments], pieces[num_segments:]
        shifted = [shift(piece, amount, mask = mask) for piece, amount in zip(to_shift, shifts)]
        x = torch.cat((*shifted, *untouched), dim = -1)
        return self.fn(x, **kwargs)
class FixedPositionalEmbedding(nn.Module):
    """Precomputed sinusoidal table: all sines first, then all cosines, along the last axis."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1. / (10000 ** exponents)
        position = torch.arange(0, max_seq_len, dtype=torch.float)
        # outer product of positions and inverse frequencies
        angles = torch.einsum("i,j->ij", position, inv_freq)
        table = torch.cat((angles.sin(), angles.cos()), dim=-1)
        self.register_buffer('emb', table)

    def forward(self, x):
        # slice to the length of x (its second axis), add a leading axis, match x's device/dtype
        seq_len = x.shape[1]
        return self.emb[:seq_len].unsqueeze(0).to(x)
def rotate_every_two(x):
    """Map every adjacent pair `(a, b)` on the last axis to `(-b, a)`."""
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    first, second = pairs.unbind(dim = -1)
    rotated = torch.stack((-second, first), dim = -1)
    return rotated.flatten(-2)
def apply_rotary_pos_emb(q, k, v, sinu_pos):
    """Apply the same rotary position embedding to `q`, `k` and `v`.

    `sinu_pos` must have a leading axis of size one and a last axis holding
    the sine half followed by the cosine half. Each half is repeated
    element-wise so it lines up with the adjacent pairs that
    `rotate_every_two` rotates.
    """
    sinu_pos = sinu_pos.type(q.dtype)
    sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2)
    sin, cos = sinu_pos.unbind(dim = -2)
    sin = repeat(sin, 'b n -> b (n j)', j = 2)
    cos = repeat(cos, 'b n -> b (n j)', j = 2)

    def rotate(t):
        return (t * cos) + (rotate_every_two(t) * sin)

    return rotate(q), rotate(k), rotate(v)
def update_kmeans_on_backwards(module):
    """Register a backward hook on `module` that calls `update()` on its Kmeans submodules.

    The Kmeans submodules are looked up once and stored on
    `module.kmean_modules`. The handle returned by `register_backward_hook`
    is returned so the caller can remove the hook later.
    """
    module.kmean_modules = find_modules(module, Kmeans)

    def hook(_, grad_in, grad_out):
        for kmeans in module.kmean_modules:
            kmeans.update()

    handle = module.register_backward_hook(hook)
    return handle
def similarity(x, means):
    """Dot product of every vector in `x` (b, h, l, d) with every mean (h, c, d) of its head."""
    scores = torch.einsum('bhld,hcd->bhlc', x, means)
    return scores
def dists_and_buckets(x, means):
    """Return the similarity scores of `x` to `means` and the best-scoring cluster index.

    The scores are the dot products computed by the 'bhld,hcd->bhlc' einsum;
    the bucket of each vector is the index of its maximum score.
    """
    dists = torch.einsum('bhld,hcd->bhlc', x, means)
    buckets = torch.max(dists, dim=-1).indices
    return dists, buckets
def batched_bincount(index, num_classes, dim=-1):
    """Count how often each value in `[0, num_classes)` occurs along `dim` of `index`.

    The result has the shape and dtype of `index`, except that `dim` has size
    `num_classes`.
    """
    out_shape = list(index.shape)
    out_shape[dim] = num_classes
    counts = index.new_zeros(out_shape)
    ones = torch.ones_like(index, dtype=index.dtype)
    counts.scatter_add_(dim, index, ones)
    return counts
def kmeans_iter(x, means, buckets = None):
    """Run one k-means refinement step and return the updated cluster means.

    `x` is unpacked into four dimensions (b, h, l, d) and the number of
    clusters is read from `means.shape[1]`. When `buckets` (the cluster index
    of every vector in `x`) is not supplied it is computed with
    `dists_and_buckets`.
    """
    b, h, l, d, dtype, num_clusters = *x.shape, x.dtype, means.shape[1]
    # compute the assignments if the caller did not supply them
    if not exists(buckets):
        _, buckets = dists_and_buckets(x, means)
    # count the members of every cluster, then add the counts up over the first (batch) axis
    bins = batched_bincount(buckets, num_clusters).sum(0, keepdim=True)
    # mark the clusters that received no member at all
    zero_mask = bins.long() == 0
    # zero-initialised accumulator of shape (b, h, num_clusters, d) in the dtype of x
    means_ = buckets.new_zeros(b, h, num_clusters, d, dtype=dtype)
    # add every vector of x into the slot of the cluster it was assigned to
    means_.scatter_add_(-2, expand_dim(buckets, -1, d), x)
    # sum over the first axis, L2-normalise each mean and cast back to the dtype of x
    means_ = F.normalize(means_.sum(0, keepdim=True), dim=-1).type(dtype)
    # empty clusters keep their previous mean, all others take the new one
    means = torch.where(zero_mask.unsqueeze(-1), means, means_)
    # drop the leading singleton axis left behind by the keepdim sums
    means = means.squeeze(0)
    return means
def distribution(dists, window_size):
    """Pick, for every cluster, the `window_size` positions that score highest for it.

    `dists` holds one score per (position, cluster) pair on its last two axes.
    The top-k is taken over the position axis (dim -2), so each cluster
    selects its own best positions. The result flattens the last two axes
    into one, cluster by cluster: the first `window_size` indices belong to
    cluster 0, the next `window_size` to cluster 1, and so on.
    """
    _, topk_indices = dists.topk(k=window_size, dim=-2)
    # bring the cluster axis in front of the rank axis so the flattening groups by cluster
    indices = topk_indices.transpose(-2, -1)
    return indices.reshape(*indices.size()[:2], -1)
class Kmeans(nn.Module):
    """Per-head k-means codebook whose means are stored in a buffer.

    `forward` returns the similarity of the (L2-normalised) input to every
    mean together with a commitment loss. New means computed during `forward`
    are accumulated in `self.new_means` and only folded into the stored means
    when `update` is called.
    """

    def __init__(self, num_heads, head_dim, num_clusters, ema_decay = 0.999, commitment = 1e-4):
        super().__init__()
        self.commitment = commitment
        self.ema_decay = ema_decay
        # buffers: the cluster means (randomly initialised) and a flag telling whether `init` has run
        self.register_buffer('means', torch.randn(num_heads, num_clusters, head_dim))
        self.register_buffer('initted', torch.tensor(False))
        # running average of the means proposed since the last `update`, and how many went into it
        self.num_new_means = 0
        self.new_means = None

    @torch.no_grad()
    def init(self, x):
        """Initialise the means from samples of `x`; does nothing once `initted` is set."""
        if self.initted:
            return
        _, h, _, d, device, dtype = *x.shape, x.device, x.dtype
        num_clusters = self.means.shape[1]
        # flatten everything except the head axis into one list of samples per head
        means = x.transpose(0, 1).contiguous().view(h, -1, d)
        num_samples = means.shape[1]
        # choose the initial means: without replacement when there are enough samples,
        # otherwise with replacement
        if num_samples >= num_clusters:
            indices = torch.randperm(num_samples, device=device)[:num_clusters]
        else:
            indices = torch.randint(0, num_samples, (num_clusters,), device=device)
        means = means[:, indices]
        # refine the chosen means with a fixed number of k-means iterations
        for _ in range(KMEAN_INIT_ITERS):
            means = kmeans_iter(x, means)
        self.num_new_means = 0
        self.means.data.copy_(means)
        self.initted.data.copy_(torch.tensor(True))

    @torch.no_grad()
    def update(self, new_means = None):
        """Fold `new_means` (default: the accumulated `self.new_means`) into the stored means.

        Raises an AssertionError when no new means are available.
        """
        new_means = default(new_means, self.new_means)
        assert exists(new_means), 'new kmeans has not been supplied'
        # exponential moving average update of the stored means
        ema_inplace(self.means, new_means, self.ema_decay)
        # reset the accumulator
        del self.new_means
        self.new_means = None
        self.num_new_means = 0

    def forward(self, x, update_means = False):
        """Return `(dists, loss)` for `x`, optionally accumulating new means.

        `dists` are the similarities of the normalised input to every mean and
        `loss` is the MSE between the input and its assigned mean, scaled by
        `self.commitment`.
        """
        self.init(x)
        b, dtype = x.shape[0], x.dtype
        means = self.means.type(dtype)
        x = F.normalize(x, 2, dim=-1).type(dtype)
        # the cluster assignment itself is not differentiated
        with torch.no_grad():
            dists, buckets = dists_and_buckets(x, means)
        # look up the mean each vector was routed to
        routed_means = batched_index_select(expand_dim(means, 0, b), buckets)
        loss = F.mse_loss(x, routed_means) * self.commitment
        if update_means:
            with torch.no_grad():
                means = kmeans_iter(x, means, buckets)
            # decay n / (n + 1) makes `new_means` the plain average of all proposals so far
            self.new_means = ema(self.new_means, means, self.num_new_means / (self.num_new_means + 1))
            self.num_new_means += 1
        return dists, loss
class KmeansAttention(nn.Module):
    """Attention in which queries and keys are routed to k-means clusters.

    Every cluster selects its top `window_size` queries (and keys) via
    `distribution`, attention is computed inside each cluster, and the
    per-cluster outputs are scattered back to their sequence positions and
    averaged. Learned memory key/value slots are prepended to every cluster.
    `forward` returns `(out, aux_loss)`, where `aux_loss` comes from `Kmeans`.
    """

    def __init__(self, num_clusters, window_size, num_heads, head_dim, causal = False, dropout = 0., ema_decay = 0.999, commitment = 1e-4, context_window_size = None, receives_context = False, num_mem_kv = 0, shared_qk = False):
        super().__init__()
        self.num_heads = num_heads
        self.num_clusters = num_clusters
        self.head_dim = head_dim
        self.window_size = window_size
        self.context_window_size = default(context_window_size, window_size)
        self.causal = causal
        self.shared_qk = shared_qk
        self.receives_context = receives_context
        self.kmeans = Kmeans(num_heads, head_dim, num_clusters, ema_decay, commitment)
        self.dropout = nn.Dropout(dropout)
        # at least one memory slot is forced when causal without shared qk
        # NOTE(review): presumably so that no query ends up with every key masked - confirm
        self.num_mem_kv = max(num_mem_kv, 1 if causal and not shared_qk else 0)
        self.mem_key = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim))
        self.mem_value = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim))

    # forward pass: takes q, k, v and optional boolean masks for queries and keys
    def forward(self, q, k, v, query_mask = None, key_mask = None, **kwargs):
        # q is unpacked as (b, h, t, d); kv_t is the length of the key sequence
        b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = *q.shape, k.shape[2], self.window_size, self.context_window_size, self.num_clusters, q.device, q.dtype
        # '_reverse' is popped from kwargs (default False)
        is_reverse = kwargs.pop('_reverse', False)
        # accumulator for the output, same shape as q
        out = torch.zeros_like(q, dtype=dtype)
        # only accumulate new k-means means while training and when '_reverse' is not set
        update_kmeans = self.training and not is_reverse
        # without context the key mask falls back to the query mask
        key_mask = default(key_mask, query_mask) if not self.receives_context else key_mask
        # keys use the context window size when context is received
        kv_wsz = wsz if not self.receives_context else c_wsz
        # windows can never be longer than the sequences themselves
        wsz = min(wsz, t)
        kv_wsz = min(kv_wsz, kv_t)
        if not self.shared_qk or self.receives_context:
            # cluster queries and keys together, then split the scores back apart
            dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans)
            q_dists, k_dists = split_at_index(2, t, dists)
            # per-cluster top positions for queries and for keys
            indices = distribution(q_dists, wsz)
            kv_indices = distribution(k_dists, kv_wsz)
        else:
            # shared qk: cluster on the queries only and reuse the indices for the keys
            dists, aux_loss = self.kmeans(q, update_kmeans)
            # normalise the keys and match the device/dtype of q
            k = F.normalize(k, dim=-1).to(q)
            indices = distribution(dists, wsz)
            kv_indices = indices
        # gather the selected positions
        q = batched_index_select(q, indices)
        k = batched_index_select(k, kv_indices)
        v = batched_index_select(v, kv_indices)
        # regroup the flat selection into (b, h, clusters, window, d)
        reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d)
        q, k, v = map(reshape_with_window, (q, k, v))
        # expand the memory key/values over the batch and match q
        m_k, m_v = map(lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value))
        # prepend the memory slots along the window axis (dim 3)
        k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v)))
        # scaled dot-product scores within every cluster
        dots = torch.einsum('bhnid,bhnjd->bhnij', q, k) * (d ** -0.5)
        mask_value = max_neg_value(dots)
        if exists(query_mask) or exists(key_mask):
            # a missing mask defaults to all True
            query_mask = default(query_mask, lambda: torch.ones((b, t), device=device).bool())
            key_mask = default(key_mask, lambda: torch.ones((b, kv_t), device=device).bool())
            # pick the mask entries of the selected positions
            q_mask = expand_dim(query_mask, 1, h).gather(2, indices)
            kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices)
            q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (q_mask, kv_mask))
            # a pair is valid only when both its query and its key are valid;
            # the memory slots (padded on the left of the key axis) are always valid
            mask = q_mask[:, :, :, :, None] * kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=True)
            dots.masked_fill_(~mask, mask_value)
            del mask
        if self.causal:
            # compare original sequence positions: a query may only see keys at or before it
            q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices))
            mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=True)
            dots.masked_fill_(~mask, mask_value)
            del mask
        if self.shared_qk:
            # with shared qk, down-weight attention of a token to itself
            q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices))
            mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :]
            mask = F.pad(mask, (self.num_mem_kv, 0), value=False)
            dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE)
            del mask
        dots = dots.softmax(dim=-1)
        dots = self.dropout(dots)
        # weighted sum of the values within every cluster
        bo = torch.einsum('bhcij,bhcjd->bhcid', dots, v)
        # flatten the cluster and window axes back into one
        so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype)
        # scatter back to sequence positions, averaging positions selected by several clusters
        out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2)
        return out, aux_loss
class GELU_(nn.Module):
    """Fallback GELU activation using the tanh approximation.

    Computes `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.
    """

    def forward(self, x):
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))


# use the built-in nn.GELU when this torch version provides it, otherwise the fallback above
GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_
class FeedForward(nn.Module):
    """Two-layer feed-forward block with an optional gated (GLU) variant.

    With `glu=True` the first layer produces twice the hidden width; one half
    goes through the activation and is multiplied by the other half. Extra
    keyword arguments of `forward` are accepted and ignored.
    """

    def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False):
        super().__init__()
        # GELU is used when no activation class is given
        activation = default(activation, GELU)
        self.glu = glu
        hidden_dim = dim * mult
        self.w1 = nn.Linear(dim, hidden_dim * (2 if glu else 1))
        self.act = activation()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(hidden_dim, dim)

    def forward(self, x, **kwargs):
        hidden = self.w1(x)
        if self.glu:
            gate_in, value = hidden.chunk(2, dim=-1)
            hidden = self.act(gate_in) * value
        else:
            hidden = self.act(hidden)
        hidden = self.dropout(hidden)
        return self.w2(hidden)
class SelfAttention(nn.Module):
    """Multi-head attention that mixes local-window heads with k-means routed heads.

    `local_attn_heads` of the `heads` use `LocalAttention`; the remaining
    ("global") heads use `KmeansAttention`. `forward` returns the projected
    output together with the auxiliary loss of the routed heads.

    `depth` and `conv_query_kernel` are accepted but not used in this class.
    `context_window_size` (default: `window_size`) is passed on to
    `KmeansAttention`; previously it was computed and then dropped, so the
    caller's value had no effect.
    """

    def __init__(self, dim, depth, max_seq_len, heads, local_attn_heads, window_size, dim_head = None, local_attn_window_size = None, local_attn_radius_blocks = 1, causal = False, attn_dropout = 0., dropout = 0., kmeans_ema_decay = 0.999, commitment_factor = 1e-4, receives_context = False, context_window_size = None, rel_pos_emb = True, num_mem_kv = 0, shared_qk = False, conv_query_kernel = 9):
        super().__init__()
        # without an explicit head dimension, the hidden dimension must split evenly over the heads
        assert dim_head or (dim % heads) == 0, 'hidden dimension must be divisible by number of heads'
        # the number of clusters is max_seq_len // window_size, so it must divide evenly
        assert (max_seq_len % window_size) == 0, 'maximum sequence length must be divisible by the target window size'
        assert local_attn_heads <= heads, 'number of local attention heads must be less than total heads'
        # attention over a context supports neither local heads nor causality
        assert not (receives_context and local_attn_heads > 0), 'local attention cannot be used for self attention with context'
        assert not (receives_context and causal), 'contextual attention layer cannot be causal'

        local_attn_window_size = default(local_attn_window_size, window_size)
        context_window_size = default(context_window_size, window_size)

        self.shared_qk = shared_qk
        self.receives_context = receives_context
        self.heads = heads
        self.local_attn_heads = local_attn_heads
        self.global_attn_heads = heads - local_attn_heads

        self.causal = causal
        self.window_size = window_size

        dim_head = default(dim_head, dim // heads)
        dim_heads = dim_head * heads
        self.dim_head = dim_head

        num_clusters = max_seq_len // window_size

        # local attention heads
        local_dim_heads = dim_head * self.local_attn_heads
        if self.local_attn_heads > 0:
            rel_pos_emb_config = (dim_head, local_attn_heads) if rel_pos_emb else None
            self.local_attn = LocalAttention(dim = dim_head, window_size = local_attn_window_size, causal = causal, dropout = attn_dropout, rel_pos_emb_config = rel_pos_emb_config, look_backward = local_attn_radius_blocks, look_forward = 0 if causal else local_attn_radius_blocks)
            self.local_to_qkv = nn.Linear(dim, 3 * local_dim_heads)

        # global (k-means routed) attention heads
        global_dim_heads = dim_head * self.global_attn_heads
        if self.global_attn_heads > 0:
            self.global_attn = KmeansAttention(num_clusters, window_size, self.global_attn_heads, dim_head, causal = causal, dropout = attn_dropout, ema_decay = kmeans_ema_decay, commitment = commitment_factor, context_window_size = context_window_size, receives_context = receives_context, num_mem_kv = num_mem_kv, shared_qk = shared_qk)

        self.to_q = nn.Linear(dim, global_dim_heads, bias = False)
        self.to_v = nn.Linear(dim, global_dim_heads, bias = False)

        # with shared qk the query projection doubles as the key projection
        if not self.shared_qk:
            self.to_k = nn.Linear(dim, global_dim_heads, bias = False)

        # output projection
        self.to_out = nn.Linear(dim_heads, dim, bias = False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context = None, input_mask = None, context_mask = None, pos_emb = None, **kwargs):
        """Return `(output, total_loss)` for input `x`.

        `context` is required when the layer was built with
        `receives_context=True`; keys and values are then computed from it.
        `pos_emb`, when given, is applied as a rotary embedding to the global
        heads (only without context).
        """
        assert not (self.receives_context and not exists(context)), 'context must be passed if self attention is set to receive context'

        b, t, e, h, dh = *x.shape, self.heads, self.dim_head
        has_local, has_global = map(lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads))

        # split the last axis into (heads, dim_head) and move the head axis to position 1
        split_heads = lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous()

        if has_local:
            local_qkv = self.local_to_qkv(x).chunk(3, dim=-1)
            lq, lk, lv = map(split_heads, local_qkv)

        if has_global:
            # keys and values come from the context when one is received
            kv_input = x if not self.receives_context else context

            q, v = self.to_q(x), self.to_v(kv_input)

            if not self.shared_qk:
                k = self.to_k(kv_input)
            else:
                # shared qk: reuse the query projection (on the context if there is one)
                k = self.to_q(kv_input) if self.receives_context else q

            q, k, v = map(split_heads, (q, k, v))

        out = []
        total_loss = torch.tensor(0., requires_grad=True, **to(x))

        if has_local:
            local_out = self.local_attn(lq, lk, lv, input_mask = input_mask)
            out.append(local_out)

        if has_global:
            if not self.receives_context and exists(pos_emb):
                q, k, v = apply_rotary_pos_emb(q, k, v, pos_emb)

            global_out, loss = self.global_attn(q, k, v, query_mask = input_mask, key_mask = context_mask)
            total_loss = total_loss + loss

            out.append(global_out)

        # concatenate local and global heads along the head axis, then merge heads into features
        out = torch.cat(out, dim=1)
        out = out.reshape(b, h, t, -1).transpose(1, 2).reshape(b, t, -1)
        out = self.to_out(out)
        return self.dropout(out), total_loss
class RoutingTransformer(nn.Module):
    # Routing transformer stack; a subclass of nn.Module.
    # NOTE(review): in this listing the `__init__` signature below is never closed
    # with `):` and the constructor body is absent, so `self.layers` and
    # `self._handle` (used by the methods further down) are never assigned here.
    # TODO: restore the constructor from the original source.
    def __init__(
        self,
        dim,
        depth,
        max_seq_len,
        heads = 8,
        dim_head = None,
        window_size = 64,
        local_attn_window_size = 256,
        local_attn_radius_blocks = 1,
        causal = False,
        weight_tie = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        attn_layer_dropout = 0.,
        layer_dropout = 0.,
        n_local_attn_heads = 0,
        ff_glu = False,
        reversible = False,
        ff_chunks = 1,
        kmeans_ema_decay = 0.999,
        commitment_factor = 1e-4,
        receives_context = False,
        context_window_size = None,
        _register_kmeans_update = False,
        rel_pos_emb = True,
        pkm_layers = tuple(),
        pkm_num_keys = 128,
        moe_layers = tuple(),
        moe_num_experts = 4,
        moe_loss_coef = 1e-2,
        num_mem_kv = 0,
        shared_qk = None,
        context_shared_qk = False,
        use_rezero = False,
        use_scale_norm = False,
        ff_activation = None,
        shift_tokens = False
    # constructor: receives the configuration options of the routing transformer

    def cancel_kmeans_update(self):
        # remove the backward hook that updates the k-means means, if one is registered
        if not exists(self._handle):
            return
        self._handle.remove()
        self._handle = None

    def register_kmeans_update(self):
        # register the backward hook and keep its handle so it can be removed later
        self._handle = update_kmeans_on_backwards(self)

    def forward(self, x, **kwargs):
        # run the layer stack; it returns the output together with a loss
        x, loss = self.layers(x, **kwargs)
        return x, loss
class RoutingTransformerLM(nn.Module):
    """Language-model wrapper: token embedding -> RoutingTransformer -> LayerNorm -> output head.

    The output head is the identity when `return_embeddings` is set, the
    transposed token-embedding matrix when `tie_embedding` is set, and a
    linear layer otherwise. `forward` returns `(output, loss)`.

    `commitment_factor` is forwarded to `RoutingTransformer`; previously it
    was accepted here and then dropped, so the inner default was always used.
    NOTE(review): `ff_mult` and `use_absolute_pos_emb` are accepted but not
    used in this class.
    """

    def __init__(
        self,
        num_tokens,
        dim,
        depth,
        max_seq_len,
        heads = 8,
        dim_head = 64,
        window_size = 64,
        local_attn_window_size = None,
        local_attn_radius_blocks = 1,
        causal = False,
        emb_dim = None,
        weight_tie = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        attn_layer_dropout = 0.,
        layer_dropout = 0.,
        ff_mult = 4,
        ff_activation = None,
        ff_glu = False,
        return_embeddings = False,
        n_local_attn_heads = 0,
        reversible = False,
        ff_chunks = 1,
        kmeans_ema_decay = 0.999,
        commitment_factor = 1e-4,
        receives_context = False,
        context_window_size = None,
        rel_pos_emb = True,
        _register_kmeans_update = True,
        pkm_layers = tuple(),
        pkm_num_keys = 128,
        moe_layers = tuple(),
        moe_num_experts = 4,
        moe_loss_coef = 1e-2,
        num_mem_kv = 0,
        shared_qk = None,
        context_shared_qk = False,
        use_rezero = False,
        use_scale_norm = False,
        tie_embedding = False,
        use_absolute_pos_emb = False,
        shift_tokens = False
    ):
        super().__init__()
        # the number of k-means clusters is max_seq_len // window_size
        assert (max_seq_len % window_size) == 0, 'max sequence length must be divisible by the window size, to calculate number of kmeans cluster'
        # the embedding dimension defaults to the model dimension
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        # sinusoidal table sized for one attention head; handed to the transformer as `pos_emb`
        self.sinu_pos_emb = FixedPositionalEmbedding(dim_head, max_seq_len)

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        nn.init.normal_(self.token_emb.weight, std = 0.02)

        self.routing_transformer = RoutingTransformer(
            dim,
            depth,
            max_seq_len,
            heads = heads,
            dim_head = dim_head,
            window_size = window_size,
            local_attn_window_size = local_attn_window_size,
            local_attn_radius_blocks = local_attn_radius_blocks,
            causal = causal,
            weight_tie = weight_tie,
            ff_dropout = ff_dropout,
            attn_dropout = attn_dropout,
            attn_layer_dropout = attn_layer_dropout,
            layer_dropout = layer_dropout,
            n_local_attn_heads = n_local_attn_heads,
            ff_glu = ff_glu,
            reversible = reversible,
            ff_chunks = ff_chunks,
            kmeans_ema_decay = kmeans_ema_decay,
            commitment_factor = commitment_factor,
            receives_context = receives_context,
            context_window_size = context_window_size,
            rel_pos_emb = rel_pos_emb,
            pkm_layers = pkm_layers,
            pkm_num_keys = pkm_num_keys,
            moe_layers = moe_layers,
            moe_num_experts = moe_num_experts,
            moe_loss_coef = moe_loss_coef,
            num_mem_kv = num_mem_kv,
            shared_qk = shared_qk,
            context_shared_qk = context_shared_qk,
            _register_kmeans_update = _register_kmeans_update,
            use_rezero = use_rezero,
            use_scale_norm = use_scale_norm,
            ff_activation = ff_activation,
            shift_tokens = shift_tokens
        )

        # project between embedding and model dimension when they differ;
        # the projection back is skipped when raw embeddings are requested
        if emb_dim != dim:
            self.routing_transformer = ProjectInOut(self.routing_transformer, emb_dim, dim, project_out = not return_embeddings)

        self.norm = nn.LayerNorm(emb_dim)

        if return_embeddings:
            self.out = nn.Identity()
        elif tie_embedding:
            self.out = MatrixMultiply(self.token_emb.weight, transpose = True)
        else:
            self.out = nn.Linear(emb_dim, num_tokens)

    def cancel_kmeans_update(self):
        """Remove the k-means backward hook of the wrapped RoutingTransformer."""
        # the transformer may be wrapped in ProjectInOut, so search the module tree for it
        transformer = find_modules(self, RoutingTransformer)[0]
        transformer.cancel_kmeans_update()

    def update_kmeans(self):
        """Call `update()` on every Kmeans module contained in this model."""
        for m in find_modules(self, Kmeans):
            m.update()

    def forward(self, x, **kwargs):
        """Embed the token ids `x`, run the transformer and return `(output, loss)`."""
        x = self.token_emb(x)
        rotary_pos_emb = self.sinu_pos_emb(x)
        x, loss = self.routing_transformer(x, pos_emb = rotary_pos_emb, **kwargs)
        x = self.norm(x)
        return self.out(x), loss
.\lucidrains\routing-transformer\routing_transformer\__init__.py
# 从 routing_transformer 包中导入 RoutingTransformer、RoutingTransformerLM、KmeansAttention、update_kmeans_on_backwards 类
from routing_transformer.routing_transformer import RoutingTransformer, RoutingTransformerLM, KmeansAttention, update_kmeans_on_backwards
# 从 routing_transformer 包中导入 RoutingTransformerEncDec 类
from routing_transformer.encoder_decoder import RoutingTransformerEncDec
# 从 routing_transformer 包中导入 AutoregressiveWrapper 类
from routing_transformer.autoregressive_wrapper import AutoregressiveWrapper
# 从 routing_transformer 包中导入 Autopadder 类
from routing_transformer.autopadder import Autopadder
.\lucidrains\routing-transformer\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# package metadata handed to setuptools
metadata = dict(
    name = 'routing_transformer',
    # every package that find_packages discovers, except the examples
    packages = find_packages(exclude=['examples']),
    version = '1.6.1',
    license='MIT',
    description = 'Routing Transformer (Pytorch)',
    author = 'Phil Wang, Aran Komatsuzaki',
    author_email = 'lucidrains@gmail.com, aran1234321@gmail.com',
    url = 'https://github.com/lucidrains/routing-transformer',
    keywords = ['transformers', 'attention', 'artificial intelligence'],
    # runtime dependencies
    install_requires=[
        'einops',
        'local-attention>=1.4.0',
        'mixture-of-experts>=0.2.0',
        'product-key-memory',
        'torch'
    ],
    # trove classifiers
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)

setup(**metadata)
Data source
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
RQ-Transformer
Implementation of RQ Transformer, which proposes a more efficient way of training multi-dimensional sequences autoregressively. This repository will only contain the transformer for now. You can use this vector quantization library for the residual VQ.
This type of axial autoregressive transformer should be compatible with memcodes, proposed in NWT. It would likely also work well with multi-headed VQ.
Install
$ pip install RQ-transformer
Usage
import torch
from rq_transformer import RQTransformer
model = RQTransformer(
num_tokens = 16000, # number of tokens, in the paper they had a codebook size of 16k
dim = 512, # transformer model dimension
max_spatial_seq_len = 1024, # maximum positions along space
depth_seq_len = 4, # number of positions along depth (residual quantizations in paper)
spatial_layers = 8, # number of layers for space
depth_layers = 4, # number of layers for depth
dim_head = 64, # dimension per head
heads = 8, # number of attention heads
)
x = torch.randint(0, 16000, (1, 1024, 4))
loss = model(x, return_loss = True)
loss.backward()
# then after much training
logits = model(x)
# and sample from the logits accordingly
# or you can use the generate function
sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
I also think there is something deeper going on, and have generalized this to any number of dimensions. You can use it by importing the HierarchicalCausalTransformer:
import torch
from rq_transformer import HierarchicalCausalTransformer
model = HierarchicalCausalTransformer(
num_tokens = 16000, # number of tokens
dim = 512, # feature dimension
dim_head = 64, # dimension of attention heads
heads = 8, # number of attention heads
depth = (4, 4, 2), # 3 stages (but can be any number) - transformer of depths 4, 4, 2
max_seq_len = (16, 4, 5) # the maximum sequence length of the first stage, then the fixed sequence lengths of all subsequent stages
).cuda()
x = torch.randint(0, 16000, (1, 10, 4, 5)).cuda()
loss = model(x, return_loss = True)
loss.backward()
# after a lot of training
sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 16, 4, 5)
Todo
Citations
@unknown{unknown,
author = {Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
year = {2022},
month = {03},
title = {Autoregressive Image Generation using Residual Quantization}
}
@misc{press2021ALiBi,
title = {Train Short, Test Long: Attention with Linear Biases Enable Input Length Extrapolation},
author = {Ofir Press and Noah A. Smith and Mike Lewis},
year = {2021},
url = {https://ofir.io/train_short_test_long.pdf}
}
.\lucidrains\RQ-Transformer\rq_transformer\hierarchical_causal_transformer.py
# 导入数学库
import math
# 导入 functools 库
import functools
# 导入 torch 库
import torch
# 导入 torch.nn.functional 库
import torch.nn.functional as F
# 从 torch 中导入 nn 和 einsum
from torch import nn, einsum
# 从 einops_exts 中导入 rearrange_with_anon_dims
from einops_exts import rearrange_with_anon_dims
# 从 einops 中导入 rearrange, reduce, repeat
# helpers
def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return `val` unless it is None, in which case return `d`."""
    if val is None:
        return d
    return val
def remainder_to_mult(num, mult):
    """Return how much must be added to `num` to reach the next multiple of `mult`.

    The result is 0 when `num` is already a multiple.
    """
    overshoot = num % mult
    return (mult - overshoot) % mult
def cast_tuple(t, length = 1):
    """Return `t` itself if it is a tuple, otherwise a tuple repeating it `length` times."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
def reduce_mult(nums):
    """Return the product of all items in `nums`; an empty input gives 1."""
    product = 1
    for factor in nums:
        product = product * factor
    return product
# tensor helpers
def log(t, eps = 1e-20):
    """Natural logarithm of `t` after clamping it to at least `eps`."""
    clamped = t.clamp(min = eps)
    return torch.log(clamped)
def gumbel_noise(t):
    """Return Gumbel noise with the shape, dtype and device of `t`.

    Uniform samples `u` from `uniform_(0, 1)` are transformed as
    `-log(-log(u))`, using the clamped `log` helper.
    """
    uniform = torch.zeros_like(t).uniform_(0, 1)
    inner = -log(uniform)
    return -log(inner)
def gumbel_sample(t, temperature = 1., dim = -1):
    """Sample indices along `dim` by adding Gumbel noise to `t / temperature` and taking the argmax."""
    scaled = t / temperature
    perturbed = scaled + gumbel_noise(t)
    return perturbed.argmax(dim = dim)
def top_k(logits, thres = 0.5):
    """Keep the largest logits along the last dimension and set the rest to -inf.

    The number kept is `max(int((1 - thres) * num_logits), 1)`, where
    `num_logits` is the size of the last dimension. The result has the shape
    of `logits`.

    The kept values are scattered along the last dimension, matching the
    dimension `torch.topk` selects on, so inputs of any rank work. Scattering
    on a hard-coded dimension 1 was only correct for 2-D input.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs
# positional bias
# Per-head linear positional bias (ALiBi-style slopes) with a lazily grown cache.
class Alibi(nn.Module):
    def __init__(self, heads, **kwargs):
        super().__init__()
        self.heads = heads
        per_head = torch.Tensor(self._get_slopes(heads))
        # stored as (heads, 1, 1) so it broadcasts against a (1, 1, j) position row
        self.register_buffer('slopes', rearrange(per_head, 'h -> h 1 1'), persistent = False)
        # filled in on the first forward call
        self.register_buffer('bias', None, persistent = False)

    @staticmethod
    def _get_slopes(heads):
        # geometric sequence start, start^2, ... for a power-of-two head count
        def geometric(n):
            start = (2**(-2**-(math.log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        if math.log2(heads).is_integer():
            return geometric(heads)

        # otherwise: slopes of the nearest lower power of two, topped up with
        # every other slope of the next power of two
        lower = 2 ** math.floor(math.log2(heads))
        extra = geometric(2 * lower)[0::2][:heads - lower]
        return geometric(lower) + extra

    def forward(self, i, j, device):
        # `i` is accepted but not used; only the key length `j` matters here
        cached = self.bias
        if exists(cached) and cached.shape[-1] >= j:
            return cached[..., :j]

        positions = rearrange(torch.arange(j, device = device), 'j -> 1 1 j')
        self.register_buffer('bias', positions * self.slopes, persistent = False)
        return self.bias
# norm
# Root-mean-square normalisation over the last dimension with a learned gain.
class RMSNorm(nn.Module):
    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        self.eps = eps
        # L2 norm * dim^-0.5 equals the root mean square of the features
        self.scale = dim ** -0.5
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        l2 = torch.norm(x, dim = -1, keepdim = True)
        rms = (l2 * self.scale).clamp(min = self.eps)
        return x / rms * self.g
# helper classes
# Pre-normalised feed-forward block: RMSNorm -> Linear -> GELU -> Dropout -> Linear,
# with a hidden width of dim * mult.
def FeedForward(*, dim, mult = 4, dropout = 0.):
    hidden = dim * mult
    stages = [
        RMSNorm(dim),
        nn.Linear(dim, hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, dim),
    ]
    return nn.Sequential(*stages)
# Causal self-attention with multi-head queries and a single shared key/value head.
class Attention(nn.Module):
    def __init__(self, *, dim, dim_head = 64, heads = 8, dropout = 0.):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.dropout = nn.Dropout(dropout)
        self.norm = RMSNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # keys and values are projected once (dim_head each) and shared by all heads
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, attn_bias = None):
        h, device = self.heads, x.device

        x = self.norm(x)

        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim = -1)

        # split heads on the queries only, then apply the 1/sqrt(dim_head) scale
        q = rearrange(q, 'b n (h d) -> b h n d', h = h) * self.scale

        sim = einsum('b h i d, b j d -> b h i j', q, k)

        if exists(attn_bias):
            sim = sim + attn_bias

        # causal mask: query i may not look at keys beyond its own position
        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # subtract the row max before softmax for numerical stability
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = self.dropout(sim.softmax(dim = -1))

        out = einsum('b h i j, b j d -> b h i d', attn, v)
        merged = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(merged)
# Stack of (attention, feed-forward) residual layers followed by a final RMSNorm.
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        layers,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_mult = 4,
        rel_pos_bias = True
    ):
        super().__init__()
        # optional positional bias shared by every layer
        self.alibi = Alibi(heads = heads) if rel_pos_bias else None

        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ])
            for _ in range(layers)
        ])

        self.norm = RMSNorm(dim)

    def forward(self, x):
        seq_len = x.shape[-2]

        attn_bias = None
        if exists(self.alibi):
            attn_bias = self.alibi(seq_len, seq_len, device = x.device)

        for attn, ff in self.layers:
            x = attn(x, attn_bias = attn_bias) + x
            x = ff(x) + x

        return self.norm(x)
# Main class: one causal Transformer per hierarchy stage, run from the coarsest
# stage to the finest, with each stage's output seeding the next stage.
class HierarchicalCausalTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,          # vocabulary size
        dim,                 # model width
        depth,               # tuple: number of layers for each stage
        max_seq_len,         # tuple: sequence length for each stage
        dim_head = 64,       # per-head width
        heads = 8,           # number of attention heads
        attn_dropout = 0.,   # dropout on attention weights
        ff_mult = 4,         # feed-forward expansion factor
        ff_dropout = 0.,     # dropout inside the feed-forward
        pad_id = 0,          # id used for padding and ignored by the loss
        rel_pos_bias = True  # whether each stage uses the Alibi bias
    ):
        super().__init__()

        # Per-stage configuration, e.g.
        #   depth = (2, 2, 4)        -> 2, 2 and 4 layers for stages 1, 2, 3
        #   max_seq_len = (16, 8, 4) -> lengths 16, 8 and 4 for stages 1, 2, 3
        assert isinstance(depth, tuple) and isinstance(max_seq_len, tuple)
        assert len(depth) == len(max_seq_len)

        self.stages = len(depth)

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.start_tokens = nn.Parameter(torch.randn(dim))

        self.max_seq_len = max_seq_len
        # one absolute position table per stage
        self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, dim) for seq_len in max_seq_len])

        self.transformers = nn.ModuleList([])

        for stage_depth in depth:
            self.transformers.append(Transformer(
                dim = dim,
                layers = stage_depth,
                dim_head = dim_head,
                heads = heads,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                ff_mult = ff_mult,
                rel_pos_bias = rel_pos_bias
            ))

        self.to_logits = nn.Linear(dim, num_tokens)
        self.pad_id = pad_id

    # Autoregressively sample until the flat sequence has prod(max_seq_len)
    # tokens, then reshape it to (batch, *max_seq_len).
    # `prime` is an optional (batch, n) tensor of ids to continue from; when it
    # is omitted an empty prompt of `default_batch_size` rows is used.
    # Sampling never needs gradients, so the whole loop runs under no_grad to
    # avoid building an autograd graph for every generated token.
    @torch.no_grad()
    def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1):
        total_seq_len = reduce_mult(self.max_seq_len)
        device = next(self.parameters()).device

        if not exists(prime):
            prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device)

        seq = prime

        for _ in range(total_seq_len - seq.shape[-1]):
            # logits for the next position only
            logits = self.forward(seq)[:, -1]
            logits = top_k(logits, thres = filter_thres)
            sampled = gumbel_sample(logits, dim = -1, temperature = temperature)
            seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1)

        return rearrange_with_anon_dims(seq, 'b (...d) -> b ...d', d = self.max_seq_len)

    # Special case used when sampling from an empty prompt: only the start
    # token is fed through every stage, in order.
    def forward_empty(self, batch_size):
        tokens = repeat(self.start_tokens, 'd -> b 1 d', b = batch_size)

        for transformer in self.transformers:
            tokens = transformer(tokens)

        return self.to_logits(tokens)

    # `ids` is either flat (batch, seq) or already shaped with one axis per
    # stage (ndim == stages + 1). Returns logits, or the cross-entropy loss
    # when `return_loss` is True.
    def forward(self, ids, return_loss = False):
        assert ids.ndim in {2, self.stages + 1}
        flattened_dims = ids.ndim == 2

        if ids.numel() == 0:
            return self.forward_empty(ids.shape[0])

        if flattened_dims:
            # pad the flat sequence up to a multiple of the product of the
            # inner stage lengths, then fold it into one axis per stage
            seq_len = ids.shape[-1]
            multiple_of = reduce_mult(self.max_seq_len[1:])
            padding = remainder_to_mult(seq_len, multiple_of)
            ids = F.pad(ids, (0, padding), value = self.pad_id)
            ids = rearrange_with_anon_dims(ids, 'b (l ...d) -> b l ...d', d = self.max_seq_len[1:])

        b, *prec_dims, device = *ids.shape, ids.device

        # shape checks
        assert prec_dims[0] <= self.max_seq_len[0], 'the first dimension of your axial autoregressive transformer must be less than the first tuple element of max_seq_len (like any autoregressive transformer)'
        assert tuple(prec_dims[1:]) == tuple(self.max_seq_len[1:]), 'all subsequent dimensions must match exactly'

        tokens = self.token_emb(ids)

        # Build the token tensor for every stage, finest first: each coarser
        # stage sums out the innermost remaining axis; every stage gets its
        # own absolute positions. Inserting at index 0 leaves the list ordered
        # coarsest -> finest.
        tokens_at_stages = []
        reduced_tokens = tokens

        for ind, pos_emb in zip(range(len(prec_dims)), reversed(self.pos_embs)):
            is_first = ind == 0

            if not is_first:
                reduced_tokens = reduce(reduced_tokens, 'b ... r d -> b ... d', 'sum')

            positions = pos_emb(torch.arange(reduced_tokens.shape[-2], device = device))
            tokens_with_position = reduced_tokens + positions
            tokens_at_stages.insert(0, tokens_with_position)

        # the learned start token is prepended at the coarsest stage
        start_tokens = repeat(self.start_tokens, 'f -> b 1 f', b = b)

        for stage_tokens, transformer in zip(tokens_at_stages, self.transformers):
            stage_tokens = torch.cat((
                start_tokens,
                stage_tokens,
            ), dim = -2)

            # fold all leading axes into the batch for the transformer call
            *prec_dims, _, _ = stage_tokens.shape

            stage_tokens = rearrange(stage_tokens, '... n d -> (...) n d')
            attended = transformer(stage_tokens)
            attended = rearrange_with_anon_dims(attended, '(...b) n d -> ...b n d', b = prec_dims)

            # every output except the last becomes the start token of the
            # corresponding group in the next (finer) stage
            start_tokens = rearrange(attended[..., :-1, :], '... n d -> ... n 1 d')

        logits = self.to_logits(attended)
        logits = logits[..., 1:, :]

        if not return_loss:
            if flattened_dims:
                # undo the folding and drop the positions that came from padding
                logits = rearrange(logits, 'b ... n -> b (...) n')
                logits = logits[:, :seq_len]

            return logits

        preds = rearrange(logits, 'b ... c -> b c (...)')
        labels = rearrange(ids, 'b ... -> b (...)')

        # next-token objective over the flattened sequence; padding is ignored
        loss = F.cross_entropy(
            preds[..., :-1],
            labels[..., 1:],
            ignore_index = self.pad_id
        )

        return loss
.\lucidrains\RQ-Transformer\rq_transformer\rq_transformer.py
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops_exts import rearrange_with_anon_dims
from einops import rearrange, reduce, repeat
# helpers
# True for every value except None.
def exists(val):
    missing = val is None
    return not missing
# Use `d` only when `val` was not supplied (is None).
def default(val, d):
    if val is not None:
        return val
    return d
# Padding needed to bring `num` up to the next multiple of `mult`
# (0 if it already is one).
def remainder_to_mult(num, mult):
    overshoot = num % mult
    shortfall = mult - overshoot
    return shortfall % mult
# Log with a floor of `eps` on the input, so the result is always finite.
def log(t, eps = 1e-20):
    floored = t.clamp(min = eps)
    return torch.log(floored)
# Gumbel noise shaped like `t`: uniform samples pushed through -log(-log(u)).
def gumbel_noise(t):
    u = torch.zeros_like(t).uniform_(0, 1)
    neg_log_u = -log(u)
    return -log(neg_log_u)
# Gumbel-max sampling: add Gumbel noise to the temperature-scaled logits and
# take the argmax along `dim`.
def gumbel_sample(t, temperature = 1., dim = -1):
    scaled = t / temperature
    noisy = scaled + gumbel_noise(t)
    return noisy.argmax(dim = dim)
# Keep the k largest logits along the last dimension and set the rest to -inf.
#
# k = max(int((1 - thres) * num_logits), 1): a higher `thres` keeps fewer
# entries, and at least one entry is always kept. The result has the same
# shape as `logits`.
def top_k(logits, thres = 0.5):
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k, dim = -1)
    filtered = torch.full_like(logits, float('-inf'))
    # write back along the last dimension (the one topk used) rather than a
    # hard-coded dim 1, so 1-D and >2-D inputs are handled as well
    filtered.scatter_(-1, ind, val)
    return filtered
# helper classes
# Feed-forward block: LayerNorm -> Linear -> GELU -> Dropout -> Linear, with a
# hidden width of dim * mult.
def FeedForward(*, dim, mult = 4, dropout = 0.):
    hidden = dim * mult
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, dim),
    ]
    return nn.Sequential(*stages)
# Causal multi-head self-attention with a pre-LayerNorm.
class Attention(nn.Module):
    def __init__(self, *, dim, dim_head = 64, heads = 8, dropout = 0.):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dim)
        # a single projection yields queries, keys and values
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        h, device = self.heads, x.device

        x = self.norm(x)

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h = h)

        q, k, v = (split_heads(t) for t in self.to_qkv(x).chunk(3, dim = -1))
        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # causal mask: positions may only attend to themselves and earlier ones
        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # subtract the row max before softmax for numerical stability
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = self.dropout(sim.softmax(dim = -1))

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        merged = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(merged)
# Stack of residual (attention, feed-forward) layers with a final LayerNorm.
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        layers,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_mult = 4
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            ])
            for _ in range(layers)
        ])

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        hidden = x
        for attn, ff in self.layers:
            hidden = attn(hidden) + hidden
            hidden = ff(hidden) + hidden

        return self.norm(hidden)
# Main class: a spatial transformer over positions followed by a depth
# transformer over the codes at each position.
class RQTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,            # vocabulary size
        dim,                   # model width
        max_spatial_seq_len,   # maximum number of spatial positions
        depth_seq_len,         # number of codes per spatial position
        spatial_layers,        # layers in the spatial transformer
        depth_layers,          # layers in the depth transformer
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        pad_id = 0             # id used for padding and ignored by the loss
    ):
        super().__init__()
        self.dim = dim
        self.max_spatial_seq_len = max_spatial_seq_len
        self.depth_seq_len = depth_seq_len

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.spatial_start_token = nn.Parameter(torch.randn(dim))

        # one extra spatial position is allocated to account for a boundary case
        self.spatial_pos_emb = nn.Embedding(max_spatial_seq_len + 1, dim)
        self.depth_pos_emb = nn.Embedding(depth_seq_len, dim)

        self.spatial_transformer = Transformer(
            dim = dim,
            layers = spatial_layers,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_mult = ff_mult
        )

        self.depth_transformer = Transformer(
            dim = dim,
            layers = depth_layers,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_mult = ff_mult
        )

        self.to_logits = nn.Linear(dim, num_tokens)
        self.pad_id = pad_id

    # Autoregressively sample until the flat sequence holds
    # depth_seq_len * max_spatial_seq_len tokens, then reshape it to
    # (batch, spatial, depth). `prime` is an optional (batch, n) tensor of ids
    # to continue from; without it an empty prompt of `default_batch_size`
    # rows is used.
    # Sampling needs no gradients, so the loop runs under no_grad to avoid
    # building an autograd graph for every generated token.
    @torch.no_grad()
    def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1):
        total_seq_len = self.depth_seq_len * self.max_spatial_seq_len
        device = next(self.parameters()).device

        if not exists(prime):
            prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device)

        seq = prime

        for _ in range(total_seq_len - seq.shape[-1]):
            # logits for the next position only
            logits = self.forward(seq)[:, -1]
            logits = top_k(logits, thres = filter_thres)
            sampled = gumbel_sample(logits, dim = -1, temperature = temperature)
            seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1)

        return rearrange(seq, 'b (s d) -> b s d', d = self.depth_seq_len)

    # Special case used when sampling from an empty prompt: only the spatial
    # start token is run through the spatial and then the depth transformer.
    def forward_empty(self, batch_size):
        spatial_tokens = repeat(self.spatial_start_token, 'd -> b 1 d', b = batch_size)
        depth_tokens = self.spatial_transformer(spatial_tokens)
        depth_tokens = self.depth_transformer(depth_tokens)
        return self.to_logits(depth_tokens)

    # `ids` is either flat (batch, seq) or (batch, spatial, depth). Returns
    # logits, or the cross-entropy loss when `return_loss` is True.
    def forward(self, ids, return_loss = False):
        assert ids.ndim in {2, 3}
        flattened_dim = ids.ndim == 2

        if ids.numel() == 0:
            return self.forward_empty(ids.shape[0])

        if flattened_dim:
            # allow ids of shape (batch, seq): pad to the next multiple of
            # depth_seq_len and fold into (batch, spatial, depth)
            seq_len = ids.shape[-1]
            padding = remainder_to_mult(seq_len, self.depth_seq_len)
            ids = F.pad(ids, (0, padding), value = self.pad_id)
            ids = rearrange(ids, 'b (s d) -> b s d', d = self.depth_seq_len)
        else:
            seq_len = ids.shape[1] * ids.shape[2]

        b, space, depth, device = *ids.shape, ids.device
        assert space <= (self.max_spatial_seq_len + 1), 'spatial dimension is greater than the max_spatial_seq_len set'
        assert depth == self.depth_seq_len, 'depth dimension must be equal to depth_seq_len'

        tokens = self.token_emb(ids)

        spatial_pos = self.spatial_pos_emb(torch.arange(space, device = device))
        depth_pos = self.depth_pos_emb(torch.arange(depth, device = device))

        tokens_with_depth_pos = tokens + depth_pos

        # spatial tokens: codes summed over the depth axis, plus spatial positions
        spatial_tokens = reduce(tokens_with_depth_pos, 'b s d f -> b s f', 'sum') + spatial_pos

        spatial_tokens = torch.cat((
            repeat(self.spatial_start_token, 'f -> b 1 f', b = b),
            spatial_tokens
        ), dim = -2)

        spatial_tokens = self.spatial_transformer(spatial_tokens)
        spatial_tokens = rearrange(spatial_tokens, 'b s f -> b s 1 f')

        # Spatial outputs become the start tokens of the depth dimension. The
        # start token added one spatial position, so the depth tokens are
        # padded with one all-zero spatial position to match.
        tokens_with_depth_pos = F.pad(tokens_with_depth_pos, (0, 0, 0, 0, 0, 1), value = 0.)

        depth_tokens = torch.cat((spatial_tokens, tokens_with_depth_pos), dim = -2)

        # fold (batch, spatial) together so the depth transformer sees one
        # short sequence per spatial position
        depth_tokens = rearrange(depth_tokens, '... n d -> (...) n d')
        depth_tokens = self.depth_transformer(depth_tokens)
        depth_tokens = rearrange(depth_tokens, '(b s) d f -> b s d f', b = b)

        logits = self.to_logits(depth_tokens)
        logits = rearrange(logits, 'b ... f -> b (...) f')
        logits = logits[:, :(seq_len + 1)]

        if not return_loss:
            logits = logits[:, 1:]

            if flattened_dim:
                return rearrange(logits, 'b ... n -> b (...) n')

            return logits

        logits = logits[:, :-1]

        preds = rearrange(logits, 'b ... c -> b c (...)')
        labels = rearrange(ids, 'b s d -> b (s d)')

        # NOTE(review): `preds` covers `seq_len` positions while `labels` has
        # the padded length; these look like they only agree when the flat
        # input length is a multiple of depth_seq_len - confirm before relying
        # on return_loss with unpadded flat input.
        loss = F.cross_entropy(preds, labels, ignore_index = self.pad_id)
        return loss
.\lucidrains\RQ-Transformer\rq_transformer\__init__.py
# 从 rq_transformer 模块中导入 RQTransformer 类
from rq_transformer.rq_transformer import RQTransformer
# 从 rq_transformer 模块中导入 HierarchicalCausalTransformer 类
from rq_transformer.hierarchical_causal_transformer import HierarchicalCausalTransformer
.\lucidrains\RQ-Transformer\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# Package metadata passed to setuptools
setup(
name = 'RQ-transformer', # distribution name
packages = find_packages(exclude=[]), # include every package that is found
version = '0.1.9', # release version
license='MIT', # license
description = 'RQ Transformer - Autoregressive Transformer for Residual Quantized Codes', # short description
author = 'Phil Wang', # author
author_email = 'lucidrains@gmail.com', # author e-mail
url = 'https://github.com/lucidrains/RQ-transformer', # project URL
keywords = [ # search keywords
'artificial intelligence',
'deep learning',
'transformers',
'attention-mechanism',
'autoregressive',
],
install_requires=[ # runtime dependencies
'einops>=0.4',
'einops-exts',
'torch>=1.6',
],
classifiers=[ # trove classifiers
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\RQ-Transformer\train.py
# 导入所需的库
from rq_transformer import HierarchicalCausalTransformer
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# Training constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
# number of forward/backward passes accumulated before each optimizer step
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
# validation and sampling intervals, in training steps
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
# number of prompt tokens used when sampling
PRIME_LEN = 100
# length of each training sequence
SEQ_LEN = 1024
# 辅助函数
# Turn one token id into a character; ids below 32 are rendered as a space.
def decode_token(token):
    codepoint = token if token > 32 else 32
    return str(chr(codepoint))
# Decode a sequence of token ids into a single string.
def decode_tokens(tokens):
    return ''.join(decode_token(token) for token in tokens)
# Instantiate the GPT-like decoder model: four stages over byte-level tokens,
# moved to the GPU
model = HierarchicalCausalTransformer(
num_tokens = 256,
dim = 512,
depth = (4, 3, 3, 3),
max_seq_len = (4, 4, 8, 8)
).cuda()
# Prepare the enwik8 data: read the first 95e6 bytes and split them into
# 90e6 training bytes and the remainder for validation.
with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring is deprecated for binary input; frombuffer returns a
    # read-only view over the bytes, so copy to get a writable array that
    # torch.from_numpy accepts without warnings
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
# Dataset that ignores the requested index and returns a random window of
# `seq_len` tokens from `data`, moved to the GPU.
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        upper = self.data.size(0) - self.seq_len
        start = torch.randint(0, upper, (1,))
        window = self.data[start: start + self.seq_len].long()
        return window.cuda()

    def __len__(self):
        # number of non-overlapping windows that fit in the data
        total = self.data.size(0)
        return total // self.seq_len
# Wrap a loader so it can be drawn from indefinitely with next(). This script
# calls cycle() below but neither defines nor imports it.
def cycle(loader):
    while True:
        for data in loader:
            yield data

# Endless training and validation batch streams
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# Optimizer (this name rebinds the `torch.optim` module alias imported above)
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # accumulate gradients over several batches before each optimizer step
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime_inp = inp[:PRIME_LEN]
        prime = decode_tokens(prime_inp)
        # interpolate the prompt and a separator; the previous call printed the
        # raw '%s' template followed by a tuple
        print(f'{prime} \n\n {"*" * 100}')

        # sampling needs no gradients
        with torch.no_grad():
            sample = model.generate(prime_inp[None, :])
        sample = sample.flatten(1)

        # show only the continuation, without the prompt
        output_str = decode_tokens(sample[0][PRIME_LEN:])
        print(output_str)
Data source
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
RVQ-VAE-GPT - Residual Vector Quantize VAE - GPT (wip)
My attempts at applying Soundstream design on learned tokenization of text and then applying a hierarchical transformer to text generation.
The Soundstream will be modified to use all local attention. Experiments will compare VQ, RVQ, and also multi-headed VQ
Was told by a researcher friend this will likely fail 😂😂 but I will try it anyways, yolo. In the case it does not work, maybe it can still be useful for genomics. Come to think of it, why shouldn't it be able to at least learn bigrams (for english) and codons (for genomics)? Why don't we have hierarchical predictive coding? We should
Update: Some live experiments
Todo
Citations
@misc{https://doi.org/10.48550/arxiv.2107.03312,
title = {SoundStream: An End-to-End Neural Audio Codec},
author = {Zeghidour, Neil and Luebs, Alejandro and Omran, Ahmed and Skoglund, Jan and Tagliasacchi, Marco},
publisher = {arXiv},
url = {https://arxiv.org/abs/2107.03312},
year = {2021}
}
@unknown{unknown,
author = {Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
year = {2022},
month = {03},
title = {Autoregressive Image Generation using Residual Quantization}
}
@article{Sunkara2022NoMS,
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
author = {Raja Sunkara and Tie Luo},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.03641}
}
.\lucidrains\rvq-vae-gpt\rvq_vae_gpt\rvq_vae_gpt.py
# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块并重命名为 F
import torch.nn.functional as F
# 从 torch 模块中导入 nn、einsum
from torch import nn, einsum
# 从 einops 库中导入 rearrange、repeat、pack、unpack
from einops import rearrange, repeat, pack, unpack
# 从 einops.layers.torch 模块中导入 Rearrange
from einops.layers.torch import Rearrange
# 导入自定义的 local_attention 模块中的 LocalMHA 类
from local_attention import LocalMHA
# 导入自定义的 vector_quantize_pytorch 模块中的 VectorQuantize、ResidualVQ 类
from vector_quantize_pytorch import VectorQuantize, ResidualVQ
# 从 beartype 库中导入 beartype、Tuple、Optional、Union
from beartype import beartype
from beartype.typing import Tuple, Optional, Union
# 从 pathlib 模块中导入 Path 类
from pathlib import Path
# 导入 pickle 库
import pickle
# 辅助函数
# A value "exists" when it is anything other than None.
def exists(val):
    return not (val is None)
# Element at index 0 of an indexable collection.
def first(it):
    leading = it[0]
    return leading
# First argument that is not None, scanning left to right; None when every
# argument is None (or none were given).
def default(*vals):
    return next((val for val in vals if val is not None), None)
# Whether `numer` is an exact multiple of `denom`.
def divisible_by(numer, denom):
    leftover = numer % denom
    return leftover == 0
# Tuples are returned as-is; any other value is repeated `len` times in a new
# tuple.
def cast_tuple(t, len = 1):
    if isinstance(t, tuple):
        return t
    return (t,) * len
# Token shift, as used in RWKV.
# The feature dimension is split in half; the second half is shifted forward
# by one step along the sequence axis (zero-filled at the start, last step
# dropped) and concatenated back onto the untouched first half.
def shift_tokens(t):
    kept, shifted = t.chunk(2, dim = -1)
    shifted = F.pad(shifted, (0, 0, 1, -1), value = 0.)
    return torch.cat((kept, shifted), dim = -1)
# 前馈网络
# GEGLU activation: the input is split in half along the last dimension and
# the first half is gated by GELU of the second half.
class GEGLU(nn.Module):
    def forward(self, x):
        value, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * value
# Gated feed-forward block: LayerNorm -> Linear -> GEGLU -> Linear.
# The inner width is int(dim * mult * 2 / 3); the first projection is twice
# that wide because GEGLU halves the feature dimension.
def FeedForward(dim, mult = 4):
    dim_inner = int(dim * mult * 2 / 3)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, dim_inner * 2),
        GEGLU(),
        nn.Linear(dim_inner, dim),
    ]
    return nn.Sequential(*stages)
# 最佳的上采样和下采样方式
# Upsampling module: a linear projection to dim_out * factor features, SiLU,
# then a rearrange that turns every input step into `factor` output steps.
class Upsample(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        factor = 2
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        linear = nn.Linear(dim, dim_out * factor)

        self.net = nn.Sequential(
            linear,
            nn.SiLU(),
            Rearrange('b n (p d) -> b (n p) d', p = factor)
        )

        self.factor = factor
        self.init_(linear)

    # Initialise the projection from a smaller kaiming-uniform matrix whose rows
    # are repeated `factor` times, and zero the bias.
    def init_(self, linear):
        o, i = linear.weight.shape

        linear_weight = torch.empty(o // self.factor, i)
        nn.init.kaiming_uniform_(linear_weight)

        linear_weight = repeat(linear_weight, 'o ... -> (o r) ...', r = self.factor)

        # write into the layer's own weight; copying the temporary onto itself
        # left the layer with its default initialisation
        # NOTE(review): rows are repeated in '(o r)' order while the output is
        # unpacked as '(p d)' - confirm this is the intended row layout.
        linear.weight.data.copy_(linear_weight)
        nn.init.zeros_(linear.bias.data)

    def forward(self, x):
        return self.net(x)
# Downsampling module: fold every `factor` consecutive steps into the feature
# dimension, then project dim * factor features to `dim_out` (defaults to `dim`).
def Downsample(
    dim,
    dim_out = None,
    factor = 2
):
    dim_out = default(dim_out, dim)

    fold_steps = Rearrange('b (n p) d -> b n (p d)', p = factor)
    project = nn.Linear(dim * factor, dim_out)
    return nn.Sequential(fold_steps, project)
# 本地注意力
# Local transformer: a stack of causal windowed attention and feed-forward
# layers, each applied to token-shifted input with a residual connection.
class LocalTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        dim_head,
        window_size
    ):
        super().__init__()

        def make_layer():
            attn = LocalMHA(
                dim = dim,
                heads = heads,
                dim_head = dim_head,
                qk_rmsnorm = True,
                window_size = window_size,
                use_rotary_pos_emb = True,
                use_xpos = True,
                causal = True
            )
            ff = FeedForward(dim = dim)
            return nn.ModuleList([attn, ff])

        self.layers = nn.ModuleList([make_layer() for _ in range(depth)])

    def forward(self, x):
        hidden = x
        for attn, ff in self.layers:
            hidden = attn(shift_tokens(hidden)) + hidden
            hidden = ff(shift_tokens(hidden)) + hidden

        return hidden
# 模块
# Text VQ-VAE: a strided local-attention encoder, a residual vector quantizer,
# and a mirrored decoder that reconstructs the input token ids.
@beartype
class TextVQVAE(nn.Module): # or genomics, eventually, with num_tokens set to 4
    # dim / depth / strides accept an int or a tuple with one entry per layer;
    # the number of layers is taken from `strides`.
    def __init__(
        self,
        *,
        num_tokens,
        dim: Union[int, Tuple[int, ...]],
        depth: Union[int, Tuple[int, ...]],
        strides: Union[int, Tuple[int, ...]],
        codebook_size = 1024,
        local_attn_window_size = 32,
        local_attn_heads = 8,
        local_attn_dim_head = 64,
        num_codebooks = 4,
        vq_decay = 0.9,
        rvq_quantize_dropout = True
    ):
        super().__init__()

        # Snapshot the constructor arguments so init_and_load can rebuild the
        # model. This must run before any other local is assigned.
        config = locals()
        config.pop('self')
        config.pop('__class__')
        self._config = config

        assert 0 < vq_decay <= 1.

        strides = cast_tuple(strides)
        num_layers = len(strides)

        # broadcast scalar settings to one entry per layer
        dim = cast_tuple(dim, num_layers)
        depth = cast_tuple(depth, num_layers)
        local_attn_window_size = cast_tuple(local_attn_window_size, num_layers)

        assert num_layers == len(depth) == len(local_attn_window_size) == len(dim)

        init_dim, vq_dim = dim[0], dim[-1]

        # (dim_in, dim_out) for every layer; the first layer starts from init_dim
        dims = [first(dim), *dim]
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        self.token_emb = nn.Embedding(num_tokens, init_dim)

        # overall downsampling factor: the product of all strides
        self.total_strides = torch.tensor(list(strides)).cumprod(dim = -1)[-1].item()

        self.encoder = nn.ModuleList([])

        layer_params = tuple(zip(
            strides,
            depth,
            local_attn_window_size,
            dim_pairs
        ))

        # transformers applied at full resolution, before encoding and after decoding
        self.init_transformer = LocalTransformer(
            dim = init_dim,
            depth = first(depth),
            heads = local_attn_heads,
            dim_head = local_attn_dim_head,
            window_size = first(local_attn_window_size)
        )

        self.final_transformer = LocalTransformer(
            dim = init_dim,
            depth = first(depth),
            heads = local_attn_heads,
            dim_head = local_attn_dim_head,
            window_size = first(local_attn_window_size)
        )

        for layer_stride, layer_depth, layer_local_attn_window_size, (dim_in, dim_out) in layer_params:
            self.encoder.append(nn.ModuleList([
                Downsample(dim = dim_in, dim_out = dim_out, factor = layer_stride),
                LocalTransformer(
                    dim = dim_out,
                    depth = layer_depth,
                    heads = local_attn_heads,
                    dim_head = local_attn_dim_head,
                    window_size = layer_local_attn_window_size
                )
            ]))

        self.encoder_norm = nn.LayerNorm(vq_dim)

        self.vq = ResidualVQ(
            dim = vq_dim,
            num_quantizers = num_codebooks,
            codebook_size = codebook_size,
            decay = vq_decay,
            quantize_dropout = num_codebooks > 1 and rvq_quantize_dropout,
            commitment_weight = 0., # the weight on the commitment loss
            kmeans_init = True,
            kmeans_iters = 10
        )

        self.decoder = nn.ModuleList([])

        # the decoder mirrors the encoder: layers in reverse order, attention
        # at the coarse width followed by upsampling back to dim_in
        for layer_stride, layer_depth, layer_local_attn_window_size, (dim_in, dim_out) in reversed(layer_params):
            self.decoder.append(nn.ModuleList([
                Upsample(dim = dim_out, dim_out = dim_in, factor = layer_stride),
                LocalTransformer(
                    dim = dim_out,
                    depth = layer_depth,
                    heads = local_attn_heads,
                    dim_head = local_attn_dim_head,
                    window_size = layer_local_attn_window_size
                )
            ]))

        self.to_logits = nn.Sequential(
            nn.LayerNorm(init_dim),
            nn.Linear(init_dim, num_tokens)
        )

    # Save the weights together with the pickled constructor config.
    def save(self, path):
        path = Path(path)
        pkg = dict(
            model = self.state_dict(),
            config = pickle.dumps(self._config)
        )
        torch.save(pkg, str(path))

    # Load weights from a checkpoint written by save() into this instance.
    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))
        self.load_state_dict(pkg['model'])

    # Rebuild a model from the config stored in a checkpoint, then load its weights.
    @classmethod
    def init_and_load(cls, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))
        # SECURITY: unpickling executes arbitrary code - only load checkpoints
        # from trusted sources
        model = cls(**pickle.loads(pkg['config']))
        model.load(path)
        return model

    # Device of the model's parameters.
    @property
    def device(self):
        return next(self.parameters()).device

    # Embed the ids and run them through the initial transformer and every
    # (downsample, local attention) stage, then normalise.
    def encode(self, ids):
        tokens = self.token_emb(ids)

        tokens = self.init_transformer(tokens)

        for downsample, local_attn in self.encoder:
            tokens = downsample(tokens)
            tokens = local_attn(tokens)

        return self.encoder_norm(tokens)

    # Map quantized codes back to per-position logits over the vocabulary.
    def decode(self, codes):
        tokens = codes

        for upsample, local_attn in self.decoder:
            tokens = local_attn(tokens)
            tokens = upsample(tokens)

        tokens = self.final_transformer(tokens)

        logits = self.to_logits(tokens)
        return logits

    # Decode straight from codebook indices (no gradients).
    @torch.no_grad()
    def decode_from_codebook_ids(self, codebook_ids):
        codes = self.vq.get_codes_from_indices(codebook_ids)
        return self.decode(codes)

    # Encode, quantize and decode `ids`, returning the reconstruction
    # cross-entropy loss. With return_codebook_indices the quantizer indices
    # are returned instead; with return_reconstruction the result is
    # (loss, reconstructed ids). `return_loss_breakdown` is accepted but
    # currently unused.
    def forward(
        self,
        ids,
        return_codebook_indices = False,
        return_reconstruction = False,
        return_loss_breakdown = False
    ):
        batch, seq = ids.shape
        # the sequence must fold evenly through every downsampling stride
        assert divisible_by(seq, self.total_strides)

        ids = ids.to(self.device)

        tokens = self.encode(ids)

        # the quantizer's own loss output is discarded
        tokens, indices, _ = self.vq(tokens)

        if return_codebook_indices:
            return indices

        logits = self.decode(tokens)

        # cross_entropy expects the class dimension second
        logits = rearrange(logits, 'b n c -> b c n')

        loss = F.cross_entropy(
            logits,
            ids
        )

        if return_reconstruction:
            return loss, logits.argmax(dim = 1)

        return loss
# Placeholder for the hierarchical transformer.
class Transformer(nn.Module):
    """Hierarchical transformer (not implemented yet)."""
.\lucidrains\rvq-vae-gpt\rvq_vae_gpt\__init__.py
# 从 rvq_vae_gpt.rvq_vae_gpt 模块中导入 TextVQVAE 和 Transformer 类
from rvq_vae_gpt.rvq_vae_gpt import TextVQVAE, Transformer
.\lucidrains\rvq-vae-gpt\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# Package metadata passed to setuptools
setup(
name = 'rvq-vae-gpt', # distribution name
packages = find_packages(exclude=[]), # include every package that is found
version = '0.0.4', # release version
license='MIT', # license
description = 'Yet another attempt at GPT in quantized latent space', # short description
author = 'Phil Wang', # author
author_email = 'lucidrains@gmail.com', # author e-mail
long_description_content_type = 'text/markdown', # content type of the long description
url = 'https://github.com/lucidrains/rvq-vae-gpt', # project URL
keywords = [ # search keywords
'artificial intelligence',
'deep learning',
'transformers',
'attention mechanism'
],
install_requires=[ # runtime dependencies
'beartype',
'einops>=0.4',
'local-attention>=1.0.0',
'torch>=1.6',
'vector-quantize-pytorch>=1.1.2'
],
classifiers=[ # trove classifiers
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\rvq-vae-gpt\train.py
# 导入所需的库
import gzip
import random
import tqdm
import numpy as np
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 导入自定义模块
from rvq_vae_gpt import TextVQVAE
# Training constants
NUM_BATCHES = int(1e5)  # number of iterations of the outer training loop
BATCH_SIZE = 4  # batch size passed to both DataLoaders
GRADIENT_ACCUMULATE_EVERY = 4  # backward passes performed before each optimizer step
LEARNING_RATE = 1e-4  # learning rate given to Adam
VALIDATE_EVERY = 100  # run a validation batch every this many iterations
SAVE_EVERY = 1000  # save the model every this many iterations
SEQ_LEN = 2048  # length of each sampled sequence
# Helper functions
def cycle(loader):
    """Yield the items of ``loader`` forever, restarting it each time it is exhausted."""
    while True:
        yield from loader
def first(it):
    """Return the element stored at index 0 of ``it`` (``it`` must support indexing)."""
    head = it[0]
    return head
def decode_token(token):
    """Turn a token id into a character; ids that are not above 32 become a space."""
    codepoint = token if token > 32 else 32
    return chr(codepoint)
def decode_tokens(tokens):
    """Decode every token id in ``tokens`` with ``decode_token`` and concatenate the characters."""
    return "".join(decode_token(token) for token in tokens)
# Instantiate the TextVQVAE model and move it to the GPU
model = TextVQVAE(
    num_tokens = 256,
    dim = (128, 256, 512),
    depth = (2, 2, 4),
    local_attn_window_size = 64,
    num_codebooks = 8,
    strides = (2, 2, 2)
).cuda()

# Prepare the enwik8 data: read the first 95e6 bytes of the gzip archive,
# then split at byte 90e6 into a training part and a validation part
with gzip.open("./data/enwik8.gz") as file:
    # .copy() gives a writable array (np.frombuffer over bytes is read-only)
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)
# Dataset that serves randomly positioned, fixed-length windows of a token tensor
class TextSamplerDataset(Dataset):
    """Sample random contiguous windows of ``seq_len`` tokens from ``data``.

    ``__getitem__`` ignores the requested index and draws a new random start
    on every call; the window is cast to ``long`` and moved to the GPU.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.data = data

    def __len__(self):
        # number of non-overlapping windows that fit in the data
        total_tokens = self.data.size(0)
        return total_tokens // self.seq_len

    def __getitem__(self, index):
        window = self.seq_len
        upper_bound = self.data.size(0) - window
        # one-element tensor holding the random start offset
        start = torch.randint(0, upper_bound, (1,))
        stop = start + window
        sampled = self.data[start : stop].long()
        return sampled.cuda()
# Build the training / validation datasets and wrap their DataLoaders in endless iterators
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
# Optimizer over all model parameters
optim = Adam(model.parameters(), lr = LEARNING_RATE)
# Training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    # accumulate gradients over several batches before a single optimizer step
    # NOTE(review): the loss is not divided by GRADIENT_ACCUMULATE_EVERY, so the
    # gradients are summed across the accumulation steps - confirm this is intended
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    # only the loss of the last accumulated batch is printed
    print(f"training loss: {loss.item():.3f}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    # skip validation and saving on the very first iteration
    if i == 0:
        continue

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            valid_text = next(val_loader)
            loss, recon = model(valid_text, return_reconstruction = True)

        print(f"validation loss: {loss.item():.3f}")

        # show the first sequence of the batch next to its reconstruction
        print(f"\n\n\n[input text]\n\n {decode_tokens(first(valid_text))}")
        print(f"\n\n[reconstructed text]\n\n {decode_tokens(first(recon))}\n\n")

    if i % SAVE_EVERY == 0:
        model.save('./text-vae.pt')
SAC (Soft Actor Critic) - Pytorch (wip)
Implementation of Soft Actor Critic and some of its improvements in Pytorch. Interest comes from watching this lecture
Citations
@article{Haarnoja2018SoftAA,
title = {Soft Actor-Critic Algorithms and Applications},
author = {Tuomas Haarnoja and Aurick Zhou and Kristian Hartikainen and G. Tucker and Sehoon Ha and Jie Tan and Vikash Kumar and Henry Zhu and Abhishek Gupta and P. Abbeel and Sergey Levine},
journal = {ArXiv},
year = {2018},
volume = {abs/1812.05905},
url = {https://api.semanticscholar.org/CorpusID:55703664}
}
@article{Hiraoka2021DropoutQF,
title = {Dropout Q-Functions for Doubly Efficient Reinforcement Learning},
author = {Takuya Hiraoka and Takahisa Imagawa and Taisei Hashimoto and Takashi Onishi and Yoshimasa Tsuruoka},
journal = {ArXiv},
year = {2021},
volume = {abs/2110.02034},
url = {https://api.semanticscholar.org/CorpusID:238353966}
}
@inproceedings{ObandoCeron2024MixturesOE,
title = {Mixtures of Experts Unlock Parameter Scaling for Deep RL},
author = {Johan S. Obando-Ceron and Ghada Sokar and Timon Willi and Clare Lyle and Jesse Farebrother and Jakob Foerster and Gintare Karolina Dziugaite and Doina Precup and Pablo Samuel Castro},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:267637059}
}
.\lucidrains\SAC-pytorch\SAC_pytorch\SAC.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn, einsum, Tensor
from torch import nn, einsum, Tensor
# 从 torch.nn 库中导入 Module, ModuleList
from torch.nn import Module, ModuleList
# 导入 beartype 库
from beartype import beartype
# 从 beartype.typing 中导入 Tuple, List, Optional, Union
from beartype.typing import Tuple, List, Optional, Union
# 导入 einx 库中的 get_at 函数
from einx import get_at
# 导入 einops 库中的 rearrange, repeat, reduce, pack, unpack 函数
from einops import rearrange, repeat, reduce, pack, unpack
# 导入 ema_pytorch 库中的 EMA 类
from ema_pytorch import EMA
# Helper: tell whether a value was provided
def exists(v):
    """Return ``True`` for anything other than ``None``."""
    return not (v is None)
# Helper: coerce a value into a tuple
def cast_tuple(t, length = 1):
    """Return ``t`` untouched when it already is a tuple, else ``t`` repeated ``length`` times in a tuple."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
# Factory for a simple multi-layer perceptron
@beartype
def MLP(
    dim,
    dim_out,
    dim_hiddens: Union[int, Tuple[int, ...]],
    layernorm = False,
    dropout = 0.,
    activation = nn.ReLU
):
    """
    Build a simple MLP (``nn.Sequential``) for the Q and value networks.

    Every hidden width contributes Linear -> Dropout -> LayerNorm (only when
    ``layernorm`` is set) -> activation; a last Linear projects to ``dim_out``.
    The placement of dropout and layernorm follows Figure 1 of
    https://arxiv.org/pdf/2110.02034.pdf. Be aware that Levine in his lecture
    has ablations showing layernorm alone (without dropout) is sufficient for
    regularization.

    Args:
        dim: input feature dimension.
        dim_out: output feature dimension.
        dim_hiddens: one hidden width or a tuple of hidden widths.
        layernorm: whether to add a LayerNorm after each dropout.
        dropout: dropout probability used in every hidden block.
        activation: zero-argument callable producing the activation module.
    """
    modules = []
    width_in = dim

    for width_out in cast_tuple(dim_hiddens):
        hidden_block = [nn.Linear(width_in, width_out), nn.Dropout(dropout)]

        if layernorm:
            hidden_block.append(nn.LayerNorm(width_out))

        hidden_block.append(activation())
        modules.extend(hidden_block)
        width_in = width_out

    # final projection to the output dimension
    modules.append(nn.Linear(width_in, dim_out))
    return nn.Sequential(*modules)
# Actor network producing the parameters of continuous actions
class Actor(Module):
    """Policy head that maps a state to a ``mu`` and a ``sigma`` per continuous action."""

    def __init__(
        self,
        *,
        dim_state,
        num_cont_actions,
        dim_hiddens: Tuple[int, ...] = tuple(),
        eps = 1e-5
    ):
        super().__init__()
        # lower bound applied to sigma in forward
        self.eps = eps

        # two outputs (mu and sigma) for every continuous action
        self.to_cont_actions = MLP(
            dim_state,
            dim_hiddens = dim_hiddens,
            dim_out = num_cont_actions * 2
        )

    def forward(
        self,
        state,
        sample = False
    ):
        """
        Return ``(mu, sigma)``, or a sampled action ``mu + sigma * noise`` when ``sample`` is set.

        einops notation
        n - num actions
        ms - mu sigma
        """
        action_params = self.to_cont_actions(state)

        # split the last axis into its mu / sigma halves, moved to the front
        mu, sigma = rearrange(action_params, '... (n ms) -> ms ... n', ms = 2)

        # squash sigma into (0, 1) and keep it at or above eps
        sigma = torch.sigmoid(sigma).clamp(min = self.eps)

        if sample:
            noise = torch.randn_like(sigma)
            return mu + sigma * noise

        return mu, sigma
# Critic network scoring a (state, continuous actions) pair
class Critic(Module):
    """Q-network evaluated on the concatenation of state and action features."""

    @beartype
    def __init__(
        self,
        *,
        dim_state,
        num_continuous_actions,
        dim_hiddens: Tuple[int, ...] = tuple(),
        layernorm = False,
        dropout = 0.
    ):
        """
        Args:
            dim_state: feature dimension of the state.
            num_continuous_actions: number of continuous action values appended to the state.
            dim_hiddens: hidden widths of the underlying MLP.
            layernorm: forwarded to ``MLP``.
            dropout: forwarded to ``MLP``.
        """
        super().__init__()

        self.to_q = MLP(
            dim_state + num_continuous_actions,
            dim_out = 1,
            dim_hiddens = dim_hiddens,
            layernorm = layernorm,
            dropout = dropout
        )

    def forward(
        self,
        state,
        actions
    ):
        """Return one Q-value per batch element."""
        # concatenate state and actions along the flattened non-batch axes
        state_actions, _ = pack([state, actions], 'b *')

        q_values = self.to_q(state_actions)

        # drop the trailing singleton axis; the tensor has to be passed as the
        # first argument (it was missing before, which raised on every call)
        q_values = rearrange(q_values, 'b 1 -> b')
        return q_values
# Value network estimating one scalar per state
class ValueNetwork(Module):
    """State-value estimator built on ``MLP`` with a single output unit."""

    @beartype
    def __init__(
        self,
        *,
        dim_state,
        dim_hiddens: Tuple[int, ...] = tuple()
    ):
        super().__init__()

        self.to_values = MLP(
            dim_state,
            dim_out= 1,
            dim_hiddens = dim_hiddens
        )

    def forward(
        self,
        states
    ):
        """Return one value per batch element."""
        raw_values = self.to_values(states)
        # remove the trailing singleton axis
        return rearrange(raw_values, 'b 1 -> b')
# Top-level SAC module (work in progress)
class SAC(Module):
    """Soft Actor Critic container; currently a pass-through with no sub-modules."""

    def __init__(
        self
    ):
        """Initialise the underlying ``Module`` only."""
        super().__init__()

    def forward(self, x):
        """Hand ``x`` back unchanged."""
        return x
.\lucidrains\SAC-pytorch\SAC_pytorch\__init__.py
# 从SAC_pytorch包中导入SAC类
from SAC_pytorch.SAC import SAC
.\lucidrains\SAC-pytorch\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# Package metadata for the SAC-pytorch distribution
setup(
  # distribution name
  name = 'SAC-pytorch',
  # discover every package, excluding none
  packages = find_packages(exclude=[]),
  # version number
  version = '0.0.1',
  # license
  license='MIT',
  # short description
  description = 'Soft Actor Critic',
  # author
  author = 'Phil Wang',
  # author e-mail
  author_email = 'lucidrains@gmail.com',
  # content type of the long description
  long_description_content_type = 'text/markdown',
  # project URL
  url = 'https://github.com/lucidrains/SAC-pytorch',
  # search keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'reinforcement learning',
    'soft actor critic'
  ],
  # runtime dependencies
  install_requires=[
    'beartype',
    'einops>=0.7.0',
    'einx[torch]>=0.1.3',
    'ema-pytorch',
    'pytorch-custom-utils>=0.0.18',
    'soft-moe-pytorch>=0.1.6',
    'torch>=2.0'
  ],
  # trove classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
Scattering Compositional Learner
Implementation of Scattering Compositional Learner, which reached superhuman levels on Raven's Progressive Matrices, a type of IQ test for analogical reasoning.
This repository is meant to be exploratory, so it may not follow the exact architecture of the paper down to the T. It is meant to find the underlying inductive bias that could be exported for use in attention networks. The paper suggests this to be the 'Scattering Transform', which is basically grouped convolutions but where each group is transformed by one shared neural network.
If you would like the exact architecture used in the paper, the official repository is here.
Install
$ pip install scattering-transform
Use
Complete Scattering Compositional Learner network
import torch
import torch.nn.functional as F
from scattering_transform import SCL, SCLTrainingWrapper
# data - (batch, number of choices, channel dimension, image height, image width)
questions = torch.randn(1, 8, 1, 160, 160)
answers = torch.randn(1, 8, 1, 160, 160)
labels = torch.tensor([2])
# instantiate model
model = SCL(
image_size = 160, # size of image
set_size = 9, # number of questions + 1 answer
conv_channels = [1, 16, 16, 32, 32, 32], # convolutional channel progression, 1 for greyscale, 3 for rgb
conv_output_dim = 80, # model dimension, the output dimension of the vision net
attr_heads = 10, # number of attribute heads
attr_net_hidden_dims = [128], # attribute scatter transform MLP hidden dimension(s)
rel_heads = 80, # number of relationship heads
rel_net_hidden_dims = [64, 23, 5] # MLP for relationship net
)
model = SCLTrainingWrapper(model)
logits = model(questions, answers) # (1, 8) - the logits of each answer being the correct match
# train
loss = F.cross_entropy(logits, labels)
loss.backward()
Scattering Transform, which is basically one MLP that acts over groups of the dimension
import torch
from torch import nn
from scattering_transform import ScatteringTransform
# for potential use in a Transformer
mlp = ScatteringTransform(
dims = [1024, 4096, 1024], # MLP - dimension in -> hidden sizes -> dimension out
heads = 16, # number of groups (heads)
activation = nn.LeakyReLU # activation to use in the MLP
)
x = torch.randn(1, 512, 1024)
mlp(x) # (1, 512, 1024)
Citation
@misc{wu2020scattering,
title={The Scattering Compositional Learner: Discovering Objects, Attributes, Relationships in Analogical Reasoning},
author={Yuhuai Wu and Honghua Dong and Roger Grosse and Jimmy Ba},
year={2020},
eprint={2007.04212},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
.\lucidrains\scattering-compositional-learner\scattering_transform\scattering_transform.py
# 导入 PyTorch 库
import torch
from torch import nn
import torch.nn.functional as F
# Helper functions

# Fall back to default_val when val is None
def default(val, default_val):
    """Return ``val`` unless it is ``None``, in which case return ``default_val``."""
    if val is None:
        return default_val
    return val
# Insert a new axis at position `dim` and broadcast it to length `k`
def expand_dim(t, dim, k):
    """Unsqueeze ``t`` at ``dim`` and expand that new axis to size ``k`` (other axes unchanged)."""
    widened = t.unsqueeze(dim)
    # -1 keeps the size of an axis as it is
    target_sizes = [-1] * widened.dim()
    target_sizes[dim] = k
    return widened.expand(target_sizes)
# Simple multi-layer perceptron, ReLU-activated unless another activation is given
class MLP(nn.Module):
    """Stack of Linear layers over ``dims`` with an activation after every layer but the last."""

    def __init__(self, *dims, activation = None):
        super().__init__()
        assert len(dims) > 2, 'must have at least 3 dimensions, for dimension in and dimension out'

        # ReLU is the fallback when no activation class is supplied
        make_activation = nn.ReLU if activation is None else activation

        num_linears = len(dims) - 1
        modules = []

        for layer_no, (dim_in, dim_out) in enumerate(zip(dims, dims[1:]), start = 1):
            modules.append(nn.Linear(dim_in, dim_out))

            # no activation after the final projection
            if layer_no < num_linears:
                modules.append(make_activation())

        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the stacked layers."""
        return self.net(x)
# Feed-forward residual block mentioned in the paper,
# used after visual feature extraction and after attribute extraction
class FeedForwardResidual(nn.Module):
    """Residual block: ``x + Linear(ReLU(LayerNorm(Linear(x))))`` with a ``dim * mult`` hidden width."""

    def __init__(self, dim, mult = 4):
        super().__init__()
        hidden_dim = dim * mult

        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace = True),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x):
        """Add the feed-forward branch to its own input."""
        branch = self.net(x)
        return x + branch
# Convolutional network
# todo: make it customizable and add Evonorm for batch-independent normalization
class ConvNet(nn.Module):
    """
    Stack of 3x3 stride-2 convolutions followed by a linear projection to ``output_dim``.

    Args:
        image_size: side length of the square input image.
        chans: channel progression; consecutive entries give the in / out
            channels of each convolution.
        output_dim: dimension of the feature vector produced per image.
    """

    def __init__(self, image_size, chans, output_dim):
        super().__init__()

        num_conv_layers = len(chans) - 1

        # A 3x3 convolution with stride 2 and padding 1 maps a side of n to
        # floor((n - 1) / 2) + 1 == ceil(n / 2). Halving with a ceiling once per
        # layer therefore gives the true spatial size for every image size,
        # whereas image_size // 2 ** num_conv_layers was only right when the
        # image size was divisible by 2 ** num_conv_layers.
        conv_output_size = image_size
        for _ in range(num_conv_layers):
            conv_output_size = (conv_output_size + 1) // 2

        convolutions = []
        channel_pairs = list(zip(chans[:-1], chans[1:]))

        for ind, (chan_in, chan_out) in enumerate(channel_pairs):
            is_last = ind >= (len(channel_pairs) - 1)
            convolutions.append(nn.Conv2d(chan_in, chan_out, 3, padding=1, stride=2))

            # batch norm follows every convolution except the last one
            if not is_last:
                convolutions.append(nn.BatchNorm2d(chan_out))

        self.net = nn.Sequential(
            *convolutions,
            nn.Flatten(1),
            nn.Linear(chans[-1] * (conv_output_size ** 2), output_dim),
            nn.ReLU(inplace=True),
            FeedForwardResidual(output_dim)
        )

    def forward(self, x):
        """Map a batch of images to feature vectors of size ``output_dim``."""
        return self.net(x)
# Scattering transform
class ScatteringTransform(nn.Module):
    """Apply one shared MLP to each of ``heads`` equal groups of the last dimension."""

    def __init__(self, dims, heads, activation = None):
        super().__init__()
        assert len(dims) > 2, 'must have at least 3 dimensions, for dimension in, the hidden dimension, and dimension out'

        total_in, *hidden_sizes, total_out = dims

        # the shared MLP works on a single group, so in / out widths are per head
        per_head_in = total_in // heads
        per_head_out = total_out // heads

        self.heads = heads
        self.mlp = MLP(per_head_in, *hidden_sizes, per_head_out, activation = activation)

    def forward(self, x):
        """Split the last axis into heads, run the shared MLP, and restore the input shape."""
        heads = self.heads
        shape = x.shape
        dim = shape[-1]
        assert (dim % heads) == 0, f'the dimension {dim} must be divisible by the number of heads {heads}'

        grouped = x.reshape(-1, heads, dim // heads)
        transformed = self.mlp(grouped)
        return transformed.reshape(shape)
# Main Scattering Compositional Learner class
class SCL(nn.Module):
    """
    Scattering Compositional Learner: vision net -> attribute scattering
    transform -> relationship MLP -> one logit per candidate set.
    """

    def __init__(
        self,
        image_size = 160,
        set_size = 9,
        conv_channels = [1, 16, 16, 32, 32, 32],
        conv_output_dim = 80,
        attr_heads = 10,
        attr_net_hidden_dims = [128],
        rel_heads = 80,
        rel_net_hidden_dims = [64, 23, 5]
    ):
        """
        Args:
            image_size: size of the image.
            set_size: size of each set of images.
            conv_channels: convolutional channel progression.
            conv_output_dim: output dimension of the vision net.
            attr_heads: number of attribute heads.
            attr_net_hidden_dims: hidden dimensions of the attribute network.
            rel_heads: number of relationship heads.
            rel_net_hidden_dims: hidden dimensions of the relationship network.
        """
        super().__init__()

        # vision model
        self.vision = ConvNet(image_size, conv_channels, conv_output_dim)

        # attribute heads and attribute network
        self.attr_heads = attr_heads
        attr_dims = [conv_output_dim, *attr_net_hidden_dims, conv_output_dim]
        self.attr_net = ScatteringTransform(attr_dims, heads = attr_heads)
        self.ff_residual = FeedForwardResidual(conv_output_dim)

        # relationship heads and relationship network
        self.rel_heads = rel_heads
        rel_input_dim = set_size * (conv_output_dim // rel_heads)
        self.rel_net = MLP(rel_input_dim, *rel_net_hidden_dims)

        # linear layer producing the logits
        rel_output_dim = rel_net_hidden_dims[-1] * rel_heads
        self.to_logit = nn.Linear(rel_output_dim, 1)

    def forward(self, sets):
        """Return logits of shape (b, m) for a 6-axis input of shape (b, m, n, c, h, w)."""
        b, m, n, c, h, w = sets.shape

        # fold every leading axis into the batch so the vision net sees plain images
        flat_images = sets.view(-1, c, h, w)
        features = self.vision(flat_images)

        # attributes
        attrs = self.ff_residual(self.attr_net(features))

        # (b, m, n, heads, d) -> (b, m, heads, n, d) -> (b, m, heads, n * d)
        per_head = attrs.reshape(b, m, n, self.rel_heads, -1)
        per_head = per_head.transpose(-2, -3).flatten(3)

        # relationships: merge the head axis with the relation features
        rels = self.rel_net(per_head).flatten(2)

        # one logit per candidate set
        return self.to_logit(rels).flatten(1)
# Wrapper that makes training easier
class SCLTrainingWrapper(nn.Module):
    """Pair the shared questions with every candidate answer and score each pairing with ``scl``."""

    def __init__(self, scl):
        super().__init__()
        self.scl = scl

    def forward(self, questions, answers):
        """
        Args:
            questions: tensor whose axis 1 enumerates the question images.
            answers: tensor whose axis 1 enumerates the candidate answers;
                all remaining axes must match those of ``questions``.

        Returns:
            Whatever ``self.scl`` returns for the stacked candidate sets.
        """
        # the number of candidates comes from the answers instead of a fixed 8,
        # so any number of answer choices is supported
        num_choices = answers.shape[1]

        # give every answer its own set axis of length 1
        answers = answers.unsqueeze(2)

        # repeat the questions once per candidate answer (as a broadcast view)
        questions = questions.unsqueeze(1).expand(-1, num_choices, *questions.shape[1:])

        # append each answer to its copy of the questions along the set axis
        permutations = torch.cat((questions, answers), dim=2)

        return self.scl(permutations)
.\lucidrains\scattering-compositional-learner\scattering_transform\__init__.py
# 从scattering_transform包中导入SCL, ScatteringTransform, SCLTrainingWrapper类
from scattering_transform.scattering_transform import SCL, ScatteringTransform, SCLTrainingWrapper
.\lucidrains\scattering-compositional-learner\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# Package metadata for the scattering-transform distribution
setup(
  name = 'scattering-transform', # distribution name
  packages = find_packages(), # discover all packages
  version = '0.0.7', # version number
  license='MIT', # license
  description = 'Scattering Transform module from the paper Scattering Compositional Learner', # short description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author e-mail
  url = 'https://github.com/lucidrains/scattering-compositional-learner', # project URL
  keywords = ['artificial intelligence', 'deep learning', 'reasoning'], # search keywords
  install_requires=[
    'torch' # runtime dependency
  ],
  classifiers=[
    'Development Status :: 4 - Beta', # development status
    'Intended Audience :: Developers', # intended audience
    'Topic :: Scientific/Engineering :: Artificial Intelligence', # topic
    'License :: OSI Approved :: MIT License', # license type
    'Programming Language :: Python :: 3.6', # programming language
  ],
)
.\lucidrains\se3-transformer-pytorch\denoise.py
# 导入 PyTorch 库
import torch
# 导入 PyTorch 中的函数库
import torch.nn.functional as F
# 从 torch.optim 中导入 Adam 优化器
from torch.optim import Adam
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 导入 sidechainnet 库,并从 se3_transformer_pytorch 中导入 SE3Transformer 类
import sidechainnet as scn
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
# Use float64 as the default floating-point dtype
torch.set_default_dtype(torch.float64)
# Batch size handed to the sidechainnet data loaders
BATCH_SIZE = 1
# Number of backward passes accumulated before each optimizer step
GRADIENT_ACCUMULATE_EVERY = 16
# Endless iterator over a data loader that drops over-long sequences
def cycle(loader, len_thres = 500):
    """Loop over ``loader`` forever, yielding only batches whose ``seqs.shape[1]`` is at most ``len_thres``."""
    while True:
        yield from (
            batch for batch in loader
            if batch.seqs.shape[1] <= len_thres
        )
# Create the SE3Transformer model
transformer = SE3Transformer(
    num_tokens = 24,
    dim = 8,
    dim_head = 8,
    heads = 2,
    depth = 2,
    attend_self = True,
    input_degrees = 1,
    output_degrees = 2,
    reduce_dim_out = True,
    differentiable_coors = True,
    num_neighbors = 0,
    attend_sparse_neighbors = True,
    num_adj_degrees = 2,
    adj_dim = 4,
    num_degrees=2,
)

# Load the sidechainnet dataset as PyTorch dataloaders
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = BATCH_SIZE,
    dynamic_batching = False
)

# Endless iterator over the training split (over-long sequences are skipped by cycle)
dl = cycle(data['train'])

# Adam optimizer over the SE3Transformer parameters
optim = Adam(transformer.parameters(), lr=1e-4)

# Move the model to the GPU
transformer = transformer.cuda()
# Training loop
for _ in range(10000):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        # fetch one batch
        batch = next(dl)
        seqs, coords, masks = batch.seqs, batch.crds, batch.msks

        # move the sequences to the GPU and take the argmax over the last axis
        seqs = seqs.cuda().argmax(dim = -1)
        # move the coordinates to the GPU as float64
        coords = coords.cuda().type(torch.float64)
        # move the masks to the GPU as booleans
        masks = masks.cuda().bool()

        # sequence length
        # NOTE(review): `l` is not read again in this loop
        l = seqs.shape[1]

        # group the flat coordinate axis into 14 entries per sequence position
        coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)

        # keep only the backbone coordinates (the first 3 of the 14 entries)
        coords = coords[:, :, 0:3, :]
        coords = rearrange(coords, 'b l s c -> b (l s) c')

        # repeat sequence tokens and masks 3 times so they line up with the kept coordinates
        seq = repeat(seqs, 'b n -> b (n c)', c = 3)
        masks = repeat(masks, 'b n -> b (n c)', c = 3)

        # add Gaussian noise to the coordinates
        noised_coords = coords + torch.randn_like(coords).cuda()

        # adjacency matrix linking every index to itself and its immediate neighbours
        i = torch.arange(seq.shape[-1], device = seqs.device)
        adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

        # forward pass through the SE3Transformer
        out = transformer(
            seq,
            noised_coords,
            mask = masks,
            adj_mat = adj_mat,
            return_type = 1
        )

        # the output is added to the noised coordinates as a correction;
        # the loss is the mean squared error against the clean coordinates at masked-in positions
        denoised_coords = noised_coords + out
        loss = F.mse_loss(denoised_coords[masks], coords[masks])

        # backward pass, scaled so the accumulated gradient is an average
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

        # print the unscaled loss
        print('loss:', loss.item())

    # optimizer step
    optim.step()
    # reset the gradients
    optim.zero_grad()
SE3 Transformer - Pytorch
Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 results and other drug discovery applications.
If you had been using any version of SE3 Transformers prior to version 0.6.0, please update. A huge bug has been uncovered by @MattMcPartlon, if you were not using the adjacency sparse neighbors settings and relying on nearest neighbors functionality
Update: It is recommended that you use Equiformer instead
Install
$ pip install se3-transformer-pytorch
Usage
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
dim = 512,
heads = 8,
depth = 6,
dim_head = 64,
num_degrees = 4,
valid_radius = 10
)
feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask = torch.ones(1, 1024).bool()
out = model(feats, coors, mask) # (1, 1024, 512)
Potential example usage in Alphafold2, as outlined here
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
dim = 64,
depth = 2,
input_degrees = 1,
num_degrees = 2,
output_degrees = 2,
reduce_dim_out = True,
differentiable_coors = True
)
atom_feats = torch.randn(2, 32, 64)
coors = torch.randn(2, 32, 3)
mask = torch.ones(2, 32).bool()
refined_coors = coors + model(atom_feats, coors, mask, return_type = 1) # (2, 32, 3)
You can also let the base transformer class take care of embedding the type 0 features being passed in. Assuming they are atoms
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
num_tokens = 28, # 28 unique atoms
dim = 64,
depth = 2,
input_degrees = 1,
num_degrees = 2,
output_degrees = 2,
reduce_dim_out = True
)
atoms = torch.randint(0, 28, (2, 32))
coors = torch.randn(2, 32, 3)
mask = torch.ones(2, 32).bool()
refined_coors = coors + model(atoms, coors, mask, return_type = 1) # (2, 32, 3)
If you think the net could further benefit from positional encoding, you can featurize your positions in space and pass it in as follows.
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
dim = 64,
depth = 2,
input_degrees = 2,
num_degrees = 2,
output_degrees = 2,
reduce_dim_out = True # reduce out the final dimension
)
atom_feats = torch.randn(2, 32, 64, 1) # b x n x d x type0
coors_feats = torch.randn(2, 32, 64, 3) # b x n x d x type1
# atom features are type 0, predicted coordinates are type 1
features = {'0': atom_feats, '1': coors_feats}
coors = torch.randn(2, 32, 3)
mask = torch.ones(2, 32).bool()
refined_coors = coors + model(features, coors, mask, return_type = 1) # (2, 32, 3) - equivariant to input type 1 features and coordinates
Edges
To offer edge information to SE3 Transformers (say bond types between atoms), you just have to pass in two more keyword arguments on initialization.
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
num_tokens = 28,
dim = 64,
num_edge_tokens = 4, # number of edge type, say 4 bond types
edge_dim = 16, # dimension of edge embedding
depth = 2,
input_degrees = 1,
num_degrees = 3,
output_degrees = 1,
reduce_dim_out = True
)
atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask = torch.ones(2, 32).bool()
pred = model(atoms, coors, mask, edges = bonds, return_type = 0) # (2, 32, 1)
If you would like to pass in continuous values for your edges, you can choose to not set the num_edge_tokens
, encode your discrete bond types, and then concat it to the fourier features of these continuous values
import torch
from se3_transformer_pytorch import SE3Transformer
from se3_transformer_pytorch.utils import fourier_encode
model = SE3Transformer(
dim = 64,
depth = 1,
attend_self = True,
num_degrees = 2,
output_degrees = 2,
edge_dim = 34 # edge dimension must match the final dimension of the edges being passed in
)
feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask = torch.ones(1, 32).bool()
pairwise_continuous_values = torch.randint(0, 4, (1, 32, 32, 2)) # say there are 2
edges = fourier_encode(
pairwise_continuous_values,
num_encodings = 8,
include_self = True
) # (1, 32, 32, 34) - {2 * (2 * 8 + 1)}
out = model(feats, coors, mask, edges = edges, return_type = 1)
Sparse Neighbors
If you know the connectivity of your points (say you are working with molecules), you can pass in an adjacency matrix, in the form of a boolean mask (where True
indicates connectivity).
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
dim = 32,
heads = 8,
depth = 1,
dim_head = 64,
num_degrees = 2,
valid_radius = 10,
attend_sparse_neighbors = True, # this must be set to true, in which case it will assert that you pass in the adjacency matrix
num_neighbors = 0, # if you set this to 0, it will only consider the connected neighbors as defined by the adjacency matrix. but if you set a value greater than 0, it will continue to fetch the closest points up to this many, excluding the ones already specified by the adjacency matrix
max_sparse_neighbors = 8 # you can cap the number of neighbors, sampled from within your sparse set of neighbors as defined by the adjacency matrix, if specified
)
feats = torch.randn(1, 128, 32)
coors = torch.randn(1, 128, 3)
mask = torch.ones(1, 128).bool()
# placeholder adjacency matrix
# naively assuming the sequence is one long chain (128, 128)
i = torch.arange(128)
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
out = model(feats, coors, mask, adj_mat = adj_mat) # (1, 128, 512)
You can also have the network automatically derive for you the Nth-degree neighbors with one extra keyword num_adj_degrees
. If you would like the system to differentiate between the degree of the neighbors as edge information, further pass in a non-zero adj_dim
.
import torch
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
dim = 64,
depth = 1,
attend_self = True,
num_degrees = 2,
output_degrees = 2,
num_neighbors = 0,
attend_sparse_neighbors = True,
num_adj_degrees = 2, # automatically derive 2nd degree neighbors
adj_dim = 4 # embed 1st and 2nd degree neighbors (as well as null neighbors) with edge embeddings of this dimension
)
feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask = torch.ones(1, 32).bool()
# placeholder adjacency matrix
# naively assuming the sequence is one long chain (32, 32)
i = torch.arange(32)
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
out = model(feats, coors, mask, adj_mat = adj_mat, return_type = 1)
To have fine control over the dimensionality of each type, you can use the hidden_fiber_dict
and out_fiber_dict
keywords to pass in a dictionary with the degree to dimension values as the key / values.
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
num_tokens = 28,
dim = 64,
num_edge_tokens = 4,
edge_dim = 16,
depth = 2,
input_degrees = 1,
num_degrees = 3,
output_degrees = 1,
hidden_fiber_dict = {0: 16, 1: 8, 2: 4},
out_fiber_dict = {0: 16, 1: 1},
reduce_dim_out = False
)
atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask = torch.ones(2, 32).bool()
pred = model(atoms, coors, mask, edges = bonds)
pred['0'] # (2, 32, 16)
pred['1'] # (2, 32, 1, 3)
Neighbors
You can further control which nodes can be considered by passing in a neighbor mask. All False
values will be masked out of consideration.
import torch
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
dim = 16,
dim_head = 16,
attend_self = True,
num_degrees = 4,
output_degrees = 2,
num_edge_tokens = 4,
num_neighbors = 8, # make sure you set this value as the maximum number of neighbors set by your neighbor_mask, or it will throw a warning
edge_dim = 2,
depth = 3
)
feats = torch.randn(1, 32, 16)
coors = torch.randn(1, 32, 3)
mask = torch.ones(1, 32).bool()
bonds = torch.randint(0, 4, (1, 32, 32))
neighbor_mask = torch.ones(1, 32, 32).bool() # set the nodes you wish to be masked out as False
out = model(
feats,
coors,
mask,
edges = bonds,
neighbor_mask = neighbor_mask,
return_type = 1
)
Global Nodes
This feature allows you to pass in vectors that can be viewed as global nodes that are seen by all other nodes. The idea would be to pool your graph into a few feature vectors, which will be projected to key / values across all the attention layers in the network. All nodes will have full access to global node information, regardless of nearest neighbors or adjacency calculation.
import torch
from torch import nn
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
dim = 64,
depth = 1,
num_degrees = 2,
num_neighbors = 4,
valid_radius = 10,
global_feats_dim = 32 # this must be set to the dimension of the global features, in this example, 32
)
feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask = torch.ones(1, 32).bool()
# naively derive global features
# by pooling features and projecting
global_feats = nn.Linear(64, 32)(feats.mean(dim = 1, keepdim = True)) # (1, 1, 32)
out = model(feats, coors, mask, return_type = 0, global_feats = global_feats)
Todo:
Autoregressive
You can use SE3 Transformers autoregressively with just one extra flag
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
dim = 512,
heads = 8,
depth = 6,
dim_head = 64,
num_degrees = 4,
valid_radius = 10,
causal = True # set this to True
)
feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask = torch.ones(1, 1024).bool()
out = model(feats, coors, mask) # (1, 1024, 512)
Experimental Features
Non-pairwise convolved keys
I've discovered that using linearly projected keys (rather than the pairwise convolution) seems to do ok in a toy denoising task. This leads to 25% memory savings. You can try this feature by setting linear_proj_keys = True
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
dim = 64,
depth = 1,
num_degrees = 4,
num_neighbors = 8,
valid_radius = 10,
splits = 4,
linear_proj_keys = True # set this to True
).cuda()
feats = torch.randn(1, 32, 64).cuda()
coors = torch.randn(1, 32, 3).cuda()
mask = torch.ones(1, 32).bool().cuda()
out = model(feats, coors, mask, return_type = 0)
Shared key / values across all heads
There is a relatively unknown technique for transformers where one can share one key / value head across all the heads of the queries. In my experience in NLP, this usually leads to worse performance, but if you are really in need to tradeoff memory for more depth or higher number of degrees, this may be a good option.
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
dim = 64,
depth = 8,
num_degrees = 4,
num_neighbors = 8,
valid_radius = 10,
splits = 4,
one_headed_key_values = True # one head of key / values shared across all heads of the queries
).cuda()
feats = torch.randn(1, 32, 64).cuda()
coors = torch.randn(1, 32, 3).cuda()
mask = torch.ones(1, 32).bool().cuda()
out = model(feats, coors, mask, return_type = 0)
Tied key / values
You can also tie the key / values (have them be the same), for half memory savings
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
dim = 64,
depth = 8,
num_degrees = 4,
num_neighbors = 8,
valid_radius = 10,
splits = 4,
tie_key_values = True # set this to True
).cuda()
feats = torch.randn(1, 32, 64).cuda()
coors = torch.randn(1, 32, 3).cuda()
mask = torch.ones(1, 32).bool().cuda()
out = model(feats, coors, mask, return_type = 0)
Using EGNN
This is an experimental version of EGNN that works for higher types, and greater dimensionality than just 1 (for the coordinates). The class name is still SE3Transformer
since it reuses some preexisting logic, so just ignore that for now until I clean it up later.
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
dim = 32,
num_neighbors = 8,
num_edge_tokens = 4,
edge_dim = 4,
num_degrees = 4, # number of higher order types - will use basis on a TCN to project to these dimensions
use_egnn = True, # set this to true to use EGNN instead of equivariant attention layers
egnn_hidden_dim = 64, # egnn hidden dimension
depth = 4, # depth of EGNN
reduce_dim_out = True # will project the dimension of the higher types to 1
).cuda()
feats = torch.randn(2, 32, 32).cuda()
coors = torch.randn(2, 32, 3).cuda()
bonds = torch.randint(0, 4, (2, 32, 32)).cuda()
mask = torch.ones(2, 32).bool().cuda()
refinement = model(feats, coors, mask, edges = bonds, return_type = 1) # (2, 32, 3)
coors = coors + refinement # update coors with refinement
If you would like to specify individual dimensions for each of the higher types, just pass in hidden_fiber_dict
where the dictionary is in the format {<degree>:<dim>} instead of num_degrees
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
dim = 32,
num_neighbors = 8,
hidden_fiber_dict = {0: 32, 1: 16, 2: 8, 3: 4},
use_egnn = True,
depth = 4,
egnn_hidden_dim = 64,
egnn_weights_clamp_value = 2,
reduce_dim_out = True
).cuda()
feats = torch.randn(2, 32, 32).cuda()
coors = torch.randn(2, 32, 3).cuda()
mask = torch.ones(2, 32).bool().cuda()
refinement = model(feats, coors, mask, return_type = 1) # (2, 32, 3)
coors = coors + refinement # update coors with refinement
Scaling (wip)
This section will list ongoing efforts to make SE3 Transformer scale a little better.
Firstly, I have added reversible networks. This allows me to add a little more depth before hitting the usual memory roadblocks. Equivariance preservation is demonstrated in the tests.
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer(
num_tokens = 20,
dim = 32,
dim_head = 32,
heads = 4,
depth = 12, # 12 layers
input_degrees = 1,
num_degrees = 3,
output_degrees = 1,
reduce_dim_out = True,
reversible = True # set reversible to True
).cuda()
atoms = torch.randint(0, 4, (2, 32)).cuda()
coors = torch.randn(2, 32, 3).cuda()
mask = torch.ones(2, 32).bool().cuda()
pred = model(atoms, coors, mask = mask, return_type = 0)
loss = pred.sum()
loss.backward()
Examples
First install sidechainnet
$ pip install sidechainnet
Then run the protein backbone denoising task
$ python denoise.py
Caching
By default, the basis vectors are cached. However, if you ever need to clear the cache, simply set the environment variable CLEAR_CACHE
to any value when launching the script
$ CLEAR_CACHE=1 python train.py
Or you can try deleting the cache directory, which should exist at
$ rm -rf ~/.cache.equivariant_attention
You can also designate your own directory for the caches, in case the default directory has permission issues
$ CACHE_PATH=./path/to/my/cache python train.py
Testing
$ python setup.py pytest
Credit
This library is largely a port of Fabian's official repository, but without the DGL library.
Citations
@misc{fuchs2020se3transformers,
title = {SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks},
author = {Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling},
year = {2020},
eprint = {2006.10503},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{satorras2021en,
title = {E(n) Equivariant Graph Neural Networks},
author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
year = {2021},
eprint = {2102.09844},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{gomez2017reversible,
title = {The Reversible Residual Network: Backpropagation Without Storing Activations},
author = {Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger B. Grosse},
year = {2017},
eprint = {1707.04585},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@misc{shazeer2019fast,
title = {Fast Transformer Decoding: One Write-Head is All You Need},
author = {Noam Shazeer},
year = {2019},
eprint = {1911.02150},
archivePrefix = {arXiv},
primaryClass = {cs.NE}
}
.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\basis.py
# 导入必要的库
import os
from math import pi
import torch
from torch import einsum
from einops import rearrange
from itertools import product
from contextlib import contextmanager
# 导入自定义库
from se3_transformer_pytorch.irr_repr import irr_repr, spherical_harmonics
from se3_transformer_pytorch.utils import torch_default_dtype, cache_dir, exists, default, to_order
from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache
# constants
# cache location: the CACHE_PATH environment variable if set, otherwise ~/.cache.equivariant_attention
CACHE_PATH = default(os.getenv('CACHE_PATH'), os.path.expanduser('~/.cache.equivariant_attention'))
# if the CLEAR_CACHE environment variable is set (to any value), the cache path becomes None
CACHE_PATH = CACHE_PATH if not exists(os.environ.get('CLEAR_CACHE')) else None
# fixed list of angle triples; used as the default `random_angles` of basis_transformation_Q_J
# todo (figure out why this was hard coded in official repo)
RANDOM_ANGLES = [
    [4.41301023, 5.56684102, 4.59384642],
    [4.93325116, 6.12697327, 4.14574096],
    [0.53878964, 4.09050444, 5.36539036],
    [2.16017393, 3.48835314, 5.55174441],
    [2.52385107, 0.2908958, 3.90040975]
]
# 辅助函数
# 空上下文管理器
@contextmanager
def null_context():
    """Context manager that does nothing on entry or exit and yields None."""
    yield None
# 函数定义
def get_matrix_kernel(A, eps = 1e-10):
    '''
    Compute an orthonormal basis (x_1, x_2, ...) of the kernel of A:
    A x_i = 0
    scalar_product(x_i, x_j) = delta_ij

    :param A: matrix
    :param eps: singular values below this threshold are treated as zero
    :return: matrix whose rows are the basis vectors of the kernel of A
    '''
    _, singular_values, right_vectors = torch.svd(A)
    # rows of V^T whose singular value is (numerically) zero span the kernel
    in_kernel = singular_values < eps
    return right_vectors.t()[in_kernel]
def get_matrices_kernel(As, eps = 1e-10):
    '''
    Compute the common kernel of all matrices in As.

    The matrices are stacked along dim 0 and the kernel of the stacked
    matrix is returned (see get_matrix_kernel).
    '''
    stacked = torch.cat(As, dim = 0)
    return get_matrix_kernel(stacked, eps)
def get_spherical_from_cartesian(cartesian, divide_radius_by = 1.0):
    """
    Convert cartesian coordinates to spherical coordinates.

    The last axis of the input is read as (y, z, x); the last axis of the
    output is written as (radius, alpha, beta).

    # ON ANGLE CONVENTION
    #
    # sh has following convention for angles:
    # :param theta: the colatitude / polar angle, ranging from 0(North Pole, (X, Y, Z) = (0, 0, 1)) to pi(South Pole, (X, Y, Z) = (0, 0, -1)).
    # :param phi: the longitude / azimuthal angle, ranging from 0 to 2 pi.
    #
    # the 3D steerable CNN code therefore (probably) has the following convention for alpha and beta:
    # beta = pi - theta; ranging from 0(South Pole, (X, Y, Z) = (0, 0, -1)) to pi(North Pole, (X, Y, Z) = (0, 0, 1).
    # alpha = phi
    #
    """
    # component layout of the input's last axis
    y = cartesian[..., 0]
    z = cartesian[..., 1]
    x = cartesian[..., 2]

    # squared length of the projection onto the xy plane
    planar_sq = x ** 2 + y ** 2

    out = torch.zeros_like(cartesian)

    # beta: angle measured from the z axis down
    out[..., 2] = torch.atan2(torch.sqrt(planar_sq), z)

    # alpha: angle within the xy plane
    out[..., 1] = torch.atan2(y, x)

    # overall radius, optionally rescaled
    radius = torch.sqrt(planar_sq + z ** 2)
    if divide_radius_by != 1.0:
        radius = radius / divide_radius_by
    out[..., 0] = radius

    return out
def kron(a, b):
    """
    Kronecker product of matrices a and b over their last two dimensions.

    Leading dimensions are carried through the einsum ellipsis.
    """
    # outer[..., i, k, j, l] = a[..., i, j] * b[..., k, l]
    outer = einsum('... i j, ... k l -> ... i k j l', a, b)
    # merge the (i, k) row pair and the (j, l) column pair
    return rearrange(outer, '... i j k l -> ... (i j) (k l)')
def get_R_tensor(order_out, order_in, a, b, c):
    """Return the Kronecker product of the output and input irreducible representations.

    Both representations are evaluated with ``irr_repr`` at the same angles
    ``(a, b, c)``.
    """
    return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))
def sylvester_submatrix(order_out, order_in, J, a, b, c):
    '''Generate the Kronecker product matrix for solving the Sylvester equation in subspace J.

    :param order_out: order of the output representation
    :param order_in: order of the input representation
    :param J: order of the subspace
    :param a, b, c: angles passed to irr_repr
    :return: matrix of shape [(m_out * m_in) * m, (m_out * m_in) * m]
    '''
    R_tensor = get_R_tensor(order_out, order_in, a, b, c)  # [m_out * m_in, m_out * m_in]
    R_irrep_J = irr_repr(J, a, b, c)  # [m, m]

    # identities are created with torch's current default dtype
    R_tensor_identity = torch.eye(R_tensor.shape[0])
    R_irrep_J_identity = torch.eye(R_irrep_J.shape[0])

    return kron(R_tensor, R_irrep_J_identity) - kron(R_tensor_identity, R_irrep_J.t())  # [(m_out * m_in) * m, (m_out * m_in) * m]
# results are cached through cache_dir, keyed to the CACHE_PATH directory
# the computation runs with float64 as torch's default dtype, since the
# null-space threshold (eps = 1e-10) is far below float32 precision
# no gradients are tracked while solving for the basis change
@cache_dir(CACHE_PATH)
@torch_default_dtype(torch.float64)
@torch.no_grad()
def basis_transformation_Q_J(J, order_in, order_out, random_angles = RANDOM_ANGLES):
    """
    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :param random_angles: iterable of (a, b, c) angle triples, one Sylvester submatrix is built per triple
    :return: one part of the Q^-1 matrix of the article, shape [m_out * m_in, m]
    """
    # one Sylvester submatrix per angle triple
    sylvester_submatrices = [sylvester_submatrix(order_out, order_in, J, a, b, c) for a, b, c in random_angles]

    # common null space of all submatrices
    null_space = get_matrices_kernel(sylvester_submatrices)

    # the solution is expected to be a unique one-dimensional subspace
    assert null_space.size(0) == 1, null_space.size()

    Q_J = null_space[0]  # [(m_out * m_in) * m]
    Q_J = Q_J.view(to_order(order_out) * to_order(order_in), to_order(J))  # [m_out * m_in, m]

    # hand the result back in single precision
    return Q_J.float()  # [m_out * m_in, m]
# 预计算球谐函数直到最大阶数 max_J
def precompute_sh(r_ij, max_J):
    """
    Pre-compute the spherical harmonics for every order from 0 up to max_J.

    :param r_ij: relative positions; index 1 of the last axis is read as alpha, index 2 as beta
    :param max_J: maximum order used in the entire network
    :return: dict keyed by J; per the original note each entry has shape [B,N,K,2J+1] (not verifiable from this function alone)
    """
    alpha_index, beta_index = 1, 2

    harmonics = {}
    for J in range(max_J + 1):
        harmonics[J] = spherical_harmonics(J, r_ij[..., alpha_index], r_ij[..., beta_index])

    # drop the intermediate values cached while evaluating the harmonics
    clear_spherical_harmonics_cache()
    return harmonics
# 获取等变权重基础(基础)函数
def get_basis(r_ij, max_degree, differentiable = False):
    """Return equivariant weight basis (basis)

    Call this function *once* at the start of each forward pass of the model.
    It computes the equivariant weight basis, W_J^lk(x), and internodal
    distances, needed to compute varphi_J^lk(x), of eqn 8 of
    https://arxiv.org/pdf/2006.10503.pdf. The return values of this function
    can be shared as input across all SE(3)-Transformer layers in a model.

    Args:
        r_ij: relative positional vectors
        max_degree: non-negative int for degree of highest feature-type
        differentiable: whether r_ij should receive gradients from basis
    Returns:
        dict of equivariant bases, keys are in the form '<d_in>,<d_out>'
    """
    # gradient tracking is only kept when a differentiable basis is requested;
    # otherwise the whole computation runs under no_grad
    context = null_context if differentiable else torch.no_grad

    device, dtype = r_ij.device, r_ij.dtype

    with context():
        # cartesian -> spherical coordinates
        r_ij = get_spherical_from_cartesian(r_ij)

        # spherical harmonics up to order 2 * max_degree
        Y = precompute_sh(r_ij, 2 * max_degree)

        # equivariant basis, one entry per (d_in, d_out) pair
        basis = {}
        for d_in, d_out in product(range(max_degree + 1), range(max_degree + 1)):
            K_Js = []
            for J in range(abs(d_in - d_out), d_in + d_out + 1):
                # change-of-basis matrix for this J, moved to the input's dtype / device
                Q_J = basis_transformation_Q_J(J, d_in, d_out)
                Q_J = Q_J.type(dtype).to(device)

                # create the kernel from the spherical harmonics
                K_J = torch.matmul(Y[J], Q_J.T)
                K_Js.append(K_J)

            # reshape so that linear combinations can be taken with a dot product
            K_Js = torch.stack(K_Js, dim = -1)
            size = (*r_ij.shape[:-1], 1, to_order(d_out), 1, to_order(d_in), to_order(min(d_in, d_out)))
            basis[f'{d_in},{d_out}'] = K_Js.view(*size)

    # extra detach for safe measure
    if not differentiable:
        basis = {key: value.detach() for key, value in basis.items()}

    return basis
.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\irr_repr.py
# 导入所需的库
import os
import numpy as np
import torch
from torch import sin, cos, atan2, acos
from math import pi
from pathlib import Path
from functools import wraps
# 导入自定义的函数和类
from se3_transformer_pytorch.utils import exists, default, cast_torch_tensor, to_order
from se3_transformer_pytorch.spherical_harmonics import get_spherical_harmonics, clear_spherical_harmonics_cache
# 设置数据路径
# directory named 'data' located next to this module
DATA_PATH = path = Path(os.path.dirname(__file__)) / 'data'

# try the torch-serialized J matrices first
try:
    path = DATA_PATH / 'J_dense.pt'
    Jd = torch.load(str(path))
except Exception:
    # any loading failure falls back to the numpy copy, converted to torch tensors;
    # interpreter-level signals such as KeyboardInterrupt are no longer swallowed
    path = DATA_PATH / 'J_dense.npy'
    Jd_np = np.load(str(path), allow_pickle = True)
    Jd = list(map(torch.from_numpy, Jd_np))
# 创建 Wigner D 矩阵
def wigner_d_matrix(degree, alpha, beta, gamma, dtype = None, device = None):
    """Create wigner D matrices for batch of ZYZ Euler angles for degree l.

    The result is the product Z(alpha) @ J @ Z(beta) @ J @ Z(gamma), where J is
    the precomputed matrix Jd[degree] and Z is z_rot_mat, viewed as a square
    matrix of side to_order(degree).
    """
    J = Jd[degree].type(dtype).to(device)
    rot_alpha, rot_beta, rot_gamma = (z_rot_mat(angle, degree) for angle in (alpha, beta, gamma))
    side = to_order(degree)
    product_matrix = rot_alpha @ J @ rot_beta @ J @ rot_gamma
    return product_matrix.view(side, side)
# 创建绕 Z 轴旋转的矩阵
def z_rot_mat(angle, l):
    """Create the z-axis rotation matrix of order l for the given angle.

    Sines are written on the anti-diagonal and cosines on the diagonal; the
    cosine write happens last, so it wins at the centre element where both
    index sets coincide.
    """
    device, dtype = angle.device, angle.dtype
    side = to_order(l)
    mat = angle.new_zeros((side, side))

    diag = torch.arange(0, side, 1, dtype=torch.long, device=device)
    anti_diag = torch.arange(2 * l, -1, -1, dtype=torch.long, device=device)

    # frequencies l, l-1, ..., -l, with a leading axis added
    freqs = torch.arange(l, -l - 1, -1, dtype=dtype, device=device)[None]

    mat[diag, anti_diag] = sin(freqs * angle[None])
    mat[diag, diag] = cos(freqs * angle[None])
    return mat
# 创建不可约表示
def irr_repr(order, alpha, beta, gamma, dtype = None):
    """
    irreducible representation of SO3
    - compatible with compose and spherical_harmonics

    The three angles are passed through cast_torch_tensor before being handed
    to wigner_d_matrix; dtype falls back to torch's current default dtype.
    """
    to_tensor = cast_torch_tensor(lambda t: t)
    dtype = default(dtype, torch.get_default_dtype())
    alpha = to_tensor(alpha)
    beta = to_tensor(beta)
    gamma = to_tensor(gamma)
    return wigner_d_matrix(order, alpha, beta, gamma, dtype = dtype)
# 绕 Z 轴旋转
@cast_torch_tensor
def rot_z(gamma):
    '''
    Rotation around Z axis

    Returns a 3x3 tensor with the dtype of gamma.
    '''
    c, s = cos(gamma), sin(gamma)
    rows = [
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1]
    ]
    return torch.tensor(rows, dtype=gamma.dtype)
# 绕 Y 轴旋转
@cast_torch_tensor
def rot_y(beta):
    '''
    Rotation around Y axis

    Returns a 3x3 tensor with the dtype of beta.
    '''
    c, s = cos(beta), sin(beta)
    rows = [
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c]
    ]
    return torch.tensor(rows, dtype=beta.dtype)
# 将球面上的点转换为 alpha 和 beta
@cast_torch_tensor
def x_to_alpha_beta(x):
    '''
    Convert point (x, y, z) on the sphere into (alpha, beta)

    The input is normalised first; beta is acos of the z component and alpha
    is atan2(y, x).
    '''
    unit = x / torch.norm(x)
    polar = acos(unit[2])
    azimuth = atan2(unit[1], unit[0])
    return (azimuth, polar)
# ZYZ 欧拉角旋转
def rot(alpha, beta, gamma):
    '''
    ZYZ Euler angles rotation: rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
    '''
    first_z = rot_z(alpha)
    middle_y = rot_y(beta)
    last_z = rot_z(gamma)
    return first_z @ middle_y @ last_z
# 合成旋转
def compose(a1, b1, c1, a2, b2, c2):
    """
    (a, b, c) = (a1, b1, c1) composed with (a2, b2, c2)

    The two rotations are multiplied, (a, b) are read off the image of the
    vector (0, 0, 1), and c is recovered from what remains after undoing
    that part of the rotation.
    """
    combined = rot(a1, b1, c1) @ rot(a2, b2, c2)
    z_axis = torch.tensor([0, 0, 1.])
    a, b = x_to_alpha_beta(combined @ z_axis)
    remainder = rot(0, -b, -a) @ combined
    c = atan2(remainder[1, 0], remainder[0, 0])
    return a, b, c
# 计算球谐函数
def spherical_harmonics(order, alpha, beta, dtype = None):
    """Evaluate the spherical harmonics of the given order.

    Translates this module's (alpha, beta) convention to the (theta, phi)
    arguments of get_spherical_harmonics: theta = pi - beta, phi = alpha.
    The dtype argument is accepted but not used.
    """
    theta = pi - beta
    return get_spherical_harmonics(order, theta = theta, phi = alpha)
.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\reversible.py
import torch # 导入 PyTorch 库
import torch.nn as nn # 导入 PyTorch 中的神经网络模块
from torch.autograd.function import Function # 导入 PyTorch 中的自动微分函数
from torch.utils.checkpoint import get_device_states, set_device_states # 导入 PyTorch 中的检查点函数
# 辅助函数
def map_values(fn, x):
    """Return a new dict with the same keys as x and fn applied to every value."""
    return {key: fn(value) for key, value in x.items()}
def dict_chunk(x, chunks, dim):
    """Chunk every tensor in x along dim and return the two pieces as two dicts.

    Each value's ``chunk`` result is unpacked into exactly two parts, so this
    only works when the chunking yields two pieces.
    """
    first, second = {}, {}
    for key, tensor in x.items():
        first[key], second[key] = tensor.chunk(chunks, dim=dim)
    return first, second
def dict_sum(x, y):
    """Return a dict holding x[k] + y[k] for every key k of x."""
    return {key: x[key] + y[key] for key in x.keys()}
def dict_subtract(x, y):
    """Return a dict holding x[k] - y[k] for every key k of x."""
    return {key: x[key] - y[key] for key in x.keys()}
def dict_cat(x, y, dim):
    """Concatenate x[k] and y[k] along dim for every key k of x."""
    return {key: torch.cat((left, y[key]), dim=dim) for key, left in x.items()}
def dict_set_(x, key, value):
    """Set attribute `key` to `value` on every value stored in x (in place)."""
    for item in x.values():
        setattr(item, key, value)
def dict_backwards_(outputs, grad_tensors):
    """Run a backward pass from every tensor in outputs, using the matching entry of grad_tensors.

    The graph is retained after each call so that later entries can still
    backpropagate through shared parts of it.
    """
    for key in outputs:
        torch.autograd.backward(outputs[key], grad_tensors[key], retain_graph=True)
def dict_del_(x):
    """Drop this function's local references to the dict and to each of its values.

    Only local names are unbound; the caller's dict and its contents are left
    untouched.
    """
    for value in x.values():
        del value
    del x
def values(d):
    """Return the values of d as a list, in the dict's iteration order."""
    return list(d.values())
# 参考以下示例保存和设置随机数生成器 https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    """Wrap a module so that a forward call can be replayed with the same RNG state.

    Calling with ``record_rng = True`` stores the CPU RNG state (and the device
    states when CUDA has been initialised) before running the wrapped module.
    Calling with ``set_rng = True`` restores those states inside a forked RNG
    scope before running the wrapped module.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        """Snapshot the current RNG state(s)."""
        self.cpu_state = torch.get_rng_state()
        if not torch.cuda._initialized:
            return
        self.cuda_in_fwd = True
        self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if set_rng:
            # replay: restore the recorded states inside a forked RNG scope
            devices = self.gpu_devices if self.cuda_in_fwd else []
            with torch.random.fork_rng(devices=devices, enabled=True):
                torch.set_rng_state(self.cpu_state)
                if self.cuda_in_fwd:
                    set_device_states(self.gpu_devices, self.gpu_states)
                return self.net(*args, **kwargs)

        return self.net(*args, **kwargs)
# 受 https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 启发
# 一旦多 GPU 工作正常,重构并将 PR 发回源代码
class ReversibleBlock(nn.Module): # one reversible block built from two sub-networks f and g
    """Reversible residual block over dicts of tensors.

    forward computes y1 = x1 + f(x2), y2 = x2 + g(y1), where (x1, x2) are the
    two halves of every input tensor along the last dimension. backward_pass
    reconstructs the inputs from the outputs and computes the input gradients.
    """
    def __init__(self, f, g):
        super().__init__()
        # both sub-networks are wrapped so their RNG state can be recorded and replayed
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, **kwargs): # forward pass; kwargs are forwarded to f only
        training = self.training
        x1, x2 = dict_chunk(x, 2, dim=-1)
        y1, y2 = None, None

        # no graph is built here; RNG state is recorded only in training mode
        with torch.no_grad():
            y1 = dict_sum(x1, self.f(x2, record_rng=training, **kwargs))
            y2 = dict_sum(x2, self.g(y1, record_rng=training))

        return dict_cat(y1, y2, dim=-1)

    # backward pass: takes the block outputs y and their gradients dy,
    # returns the reconstructed inputs x and their gradients dx
    def backward_pass(self, y, dy, **kwargs):
        # split y into its two halves y1 and y2 along the last dimension
        y1, y2 = dict_chunk(y, 2, dim = -1)
        dict_del_(y)

        # split dy into its two halves dy1 and dy2 along the last dimension
        dy1, dy2 = dict_chunk(dy, 2, dim = -1)
        dict_del_(dy)

        # re-run g on y1 with the recorded RNG state, with grad tracking on
        with torch.enable_grad():
            dict_set_(y1, 'requires_grad', True)
            gy1 = self.g(y1, set_rng = True)
            # backpropagate dy2 through g
            dict_backwards_(gy1, dy2)

        with torch.no_grad():
            # recover x2 = y2 - g(y1)
            x2 = dict_subtract(y2, gy1)
            dict_del_(y2)
            dict_del_(gy1)

            # dx1 = dy1 + gradient accumulated on y1
            dx1 = dict_sum(dy1, map_values(lambda t: t.grad, y1))
            dict_del_(dy1)
            # clear the gradients stored on y1
            dict_set_(y1, 'grad', None)

        # re-run f on x2 with the recorded RNG state, with grad tracking on
        with torch.enable_grad():
            dict_set_(x2, 'requires_grad', True)
            fx2 = self.f(x2, set_rng = True, **kwargs)
            # backpropagate dx1 through f
            dict_backwards_(fx2, dx1)

        with torch.no_grad():
            # recover x1 = y1 - f(x2)
            x1 = dict_subtract(y1, fx2)
            dict_del_(y1)
            dict_del_(fx2)

            # dx2 = dy2 + gradient accumulated on x2
            dx2 = dict_sum(dy2, map_values(lambda t: t.grad, x2))
            dict_del_(dy2)
            # clear the gradients stored on x2
            dict_set_(x2, 'grad', None)

            # detach every tensor of x2
            x2 = map_values(lambda t: t.detach(), x2)

            # concatenate the halves back together along the last dimension
            x = dict_cat(x1, x2, dim = -1)
            dx = dict_cat(dx1, dx2, dim = -1)

        return x, dx
class _ReversibleFunction(Function):
    """Custom autograd Function that runs a stack of reversible blocks.

    The input arrives as one tensor concatenated along the last dimension; it
    is split back into a dict for the blocks and concatenated again on return.
    Only the detached final output is saved; backward walks the blocks in
    reverse through their backward_pass methods.
    """
    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        # 'input_keys' and 'split_dims' are removed from kwargs (the dict is mutated);
        # what remains is forwarded to every block
        input_keys = kwargs.pop('input_keys')
        split_dims = kwargs.pop('split_dims')
        # split the flat tensor back into one tensor per key
        input_values = x.split(split_dims, dim = -1)
        x = dict(zip(input_keys, input_values))

        # keep what backward needs on the context object
        ctx.kwargs = kwargs
        ctx.split_dims = split_dims
        ctx.input_keys = input_keys

        # run every block in order
        for block in blocks:
            x = block(x, **kwargs)

        # save the detached outputs and the blocks for backward
        ctx.y = map_values(lambda t: t.detach(), x)
        ctx.blocks = blocks

        # flatten the dict of outputs into one tensor along the last dimension
        x = torch.cat(values(x), dim = -1)
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        kwargs = ctx.kwargs
        input_keys = ctx.input_keys
        split_dims = ctx.split_dims

        # split the incoming gradient the same way the input was split
        dy = dy.split(split_dims, dim = -1)
        dy = dict(zip(input_keys, dy))

        # walk the blocks in reverse, reconstructing inputs and gradients
        for block in ctx.blocks[::-1]:
            y, dy = block.backward_pass(y, dy, **kwargs)

        # flatten the gradient dict; blocks and kwargs receive no gradient
        dy = torch.cat(values(dy), dim = -1)
        return dy, None, None
class SequentialSequence(nn.Module):
    """Apply a sequence of (attention, feedforward) pairs one after another.

    Keyword arguments given to forward are passed to the attention callable
    of every pair; the feedforward callable only receives the running value.
    """

    def __init__(self, blocks):
        super().__init__()
        self.blocks = blocks

    def forward(self, x, **kwargs):
        for pair in self.blocks:
            attn, ff = pair
            x = ff(attn(x, **kwargs))
        return x
class ReversibleSequence(nn.Module):
    """Run a stack of reversible blocks over a dict of tensors.

    Every (f, g) pair is wrapped in a ReversibleBlock. In forward, each input
    tensor is duplicated along the last dimension to form the two reversible
    streams, the whole dict is flattened into one tensor for
    _ReversibleFunction, and the two streams of the result are averaged.
    """

    def __init__(self, blocks):
        super().__init__()
        wrapped = [ReversibleBlock(f, g) for (f, g) in blocks]
        self.blocks = nn.ModuleList(wrapped)

    def forward(self, x, **kwargs):
        # duplicate every tensor along the last dimension
        doubled = {key: torch.cat((t, t), dim = -1) for key, t in x.items()}

        input_keys = doubled.keys()
        split_dims = tuple(t.shape[-1] for t in doubled.values())

        # the bookkeeping entries travel alongside the user's kwargs
        block_kwargs = {'input_keys': input_keys, 'split_dims': split_dims, **kwargs}

        # flatten the dict into one tensor, run the blocks, then unflatten
        flat = torch.cat(list(doubled.values()), dim = -1)
        flat = _ReversibleFunction.apply(flat, self.blocks, block_kwargs)
        pieces = flat.split(split_dims, dim = -1)

        # average the two streams of every entry
        out = {}
        for key, piece in zip(input_keys, pieces):
            halves = piece.chunk(2, dim = -1)
            out[key] = torch.stack(halves).mean(dim = 0)
        return out
.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\rotary.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 定义 SinusoidalEmbeddings 类,继承自 nn.Module
class SinusoidalEmbeddings(nn.Module):
    """Produce the frequency arguments used for sinusoidal / rotary position encodings."""

    def __init__(self, dim):
        super().__init__()
        # inverse frequencies 1 / 10000^(2i / dim), one per pair of channels
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, t):
        # outer product of the positions with the inverse frequencies
        freqs = t[..., None].float() * self.inv_freq[None, :]
        # every frequency is repeated twice, adjacent to itself
        return repeat(freqs, '... d -> ... (d r)', r = 2)
# 定义 rotate_half 函数,用于旋转输入张量的一半
def rotate_half(x):
    """Pair up the second-to-last axis and return cat(-second, first) along it.

    The second-to-last axis is regrouped into pairs (d, 2); the two pair
    members are separated and concatenated back as (-second, first).
    """
    pairs = rearrange(x, '... (d j) m -> ... d j m', j = 2)
    first, second = pairs.unbind(dim = -2)
    return torch.cat((-second, first), dim = -2)
# 定义 apply_rotary_pos_emb 函数,用于应用旋转位置编码
def apply_rotary_pos_emb(t, freqs):
    """Apply rotary position encoding to the leading channels of t.

    The number of rotated channels is taken from freqs.shape[-2]; channels of
    t beyond that (along the second-to-last axis) are passed through unchanged.
    """
    rot_dim = freqs.shape[-2]
    rotated_part = t[..., :rot_dim, :]
    untouched_part = t[..., rot_dim:, :]
    rotated_part = (rotated_part * freqs.cos()) + (rotate_half(rotated_part) * freqs.sin())
    return torch.cat((rotated_part, untouched_part), dim = -2)
.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\se3_transformer_pytorch.py
# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 从 itertools 模块中导入 product 函数
from itertools import product
# 从 collections 模块中导入 namedtuple 类
from collections import namedtuple
# 导入 torch 库
import torch
# 从 torch.nn.functional 模块中导入 F 别名
import torch.nn.functional as F
# 从 torch 模块中导入 nn 和 einsum 函数
from torch import nn, einsum
# 导入自定义模块
from se3_transformer_pytorch.basis import get_basis
from se3_transformer_pytorch.utils import exists, default, uniq, map_values, batched_index_select, masked_mean, to_order, fourier_encode, cast_tuple, safe_cat, fast_split, rand_uniform, broadcat
from se3_transformer_pytorch.reversible import ReversibleSequence, SequentialSequence
from se3_transformer_pytorch.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb
# 从 einops 模块中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 定义命名元组 FiberEl,包含 degrees 和 dim 两个字段
# one fiber element: a degree together with its channel dimension
FiberEl = namedtuple('FiberEl', ['degrees', 'dim'])

class Fiber(nn.Module):
    """Description of a set of feature types as (degree, dim) pairs.

    `structure` is either a list of (degree, dim) pairs or a dict mapping
    degree -> dim, which is converted to a list of FiberEl.
    """

    def __init__(
        self,
        structure
    ):
        super().__init__()
        if isinstance(structure, dict):
            structure = [FiberEl(degree, dim) for degree, dim in structure.items()]
        self.structure = structure

    @property
    def dims(self):
        """Channel dimensions of all elements, passed through uniq."""
        return uniq(map(lambda t: t[1], self.structure))

    @property
    def degrees(self):
        """Iterator over the degrees of all elements (a fresh iterator on every access)."""
        return map(lambda t: t[0], self.structure)

    @staticmethod
    def create(num_degrees, dim):
        """Build a Fiber for degrees 0..num_degrees-1; dim is one int for all degrees or a tuple with one entry per degree."""
        dim_tuple = dim if isinstance(dim, tuple) else ((dim,) * num_degrees)
        return Fiber([FiberEl(degree, degree_dim) for degree, degree_dim in zip(range(num_degrees), dim_tuple)])

    def __getitem__(self, degree):
        """Return the dim registered for `degree` (KeyError if absent)."""
        return dict(self.structure)[degree]

    def __iter__(self):
        return iter(self.structure)

    def __mul__(self, fiber):
        """Cartesian product of the elements of both fibers."""
        return product(self.structure, fiber.structure)

    def __and__(self, fiber):
        """Return (degree, dim_in, dim_out) for every degree present in both fibers, in this fiber's order."""
        # build the degree -> dim lookup of the other fiber once
        dims_out = dict(fiber.structure)
        return [(degree, dim, dims_out[degree]) for degree, dim in self if degree in dims_out]
# 获取张量的设备和数据类型
def get_tensor_device_and_dtype(features):
    """Return (device, dtype) of the first tensor stored in the features dict."""
    first_tensor = next(iter(features.values()))
    return first_tensor.device, first_tensor.dtype
# 定义 ResidualSE3 类
class ResidualSE3(nn.Module):
    """ only support instance where both Fibers are identical """

    def forward(self, x, res):
        """Add res[degree] onto x[degree] wherever res has that degree.

        Keys of the output are the stringified keys of x; entries without a
        counterpart in res are passed through unchanged.
        """
        summed = {}
        for degree, tensor in x.items():
            key = str(degree)
            summed[key] = tensor + res[key] if key in res else tensor
        return summed
# 定义 LinearSE3 类
class LinearSE3(nn.Module):
    """Per-degree linear map over the channel dimension.

    One (dim_in, dim_out) weight is created for every degree shared by
    fiber_in and fiber_out, initialised as randn / sqrt(dim_in).
    """

    def __init__(
        self,
        fiber_in,
        fiber_out
    ):
        super().__init__()
        self.weights = nn.ParameterDict()

        shared = fiber_in & fiber_out
        for degree, dim_in, dim_out in shared:
            init = torch.randn(dim_in, dim_out) / sqrt(dim_in)
            self.weights[str(degree)] = nn.Parameter(init)

    def forward(self, x):
        # mix channels (axis 2) for every degree that has a weight
        return {
            degree: einsum('b n d m, d e -> b n e m', x[degree], weight)
            for degree, weight in self.weights.items()
        }
# 定义 NormSE3 类
class NormSE3(nn.Module):
    """Norm-based SE(3)-equivariant nonlinearity.

    Nonlinearities are important in SE(3) equivariant GCNs. They are also quite
    expensive to compute, so it is convenient for them to share resources with
    other layers, such as normalization. The general workflow is as follows:

    > for feature type in features:
    >    norm, phase <- feature
    >    output = fnc(norm) * phase

    where fnc: {R+}^m -> R^m is a learnable map from m norms to m scalars.
    """
    def __init__(
        self,
        fiber,
        nonlin = nn.GELU(),
        gated_scale = False,
        eps = 1e-12,
    ):
        """
        Args:
            fiber: iterable of (degree, channels) pairs
            nonlin: nonlinearity applied to the scaled norms
            gated_scale: if True, the scale is computed from the norms through a learned
                (channels, channels) matrix; otherwise a learned per-channel scale is used
            eps: lower clamp for the norms, to avoid division by zero
        """
        super().__init__()
        self.fiber = fiber
        self.nonlin = nonlin
        self.eps = eps

        # Norm mappings: 1 per feature type
        self.transform = nn.ModuleDict()
        for degree, chan in fiber:
            # exactly one of 'scale' / 'w_gate' is a parameter, the other is stored as None
            self.transform[str(degree)] = nn.ParameterDict({
                'scale': nn.Parameter(torch.ones(1, 1, chan)) if not gated_scale else None,
                'w_gate': nn.Parameter(rand_uniform((chan, chan), -1e-3, 1e-3)) if gated_scale else None
            })

    def forward(self, features):
        """Apply the norm nonlinearity to every entry of the features dict."""
        output = {}
        for degree, t in features.items():
            # Compute the norms and normalized features
            norm = t.norm(dim = -1, keepdim = True).clamp(min = self.eps)
            phase = t / norm

            # Transform on norms
            parameters = self.transform[degree]
            gate_weights, scale = parameters['w_gate'], parameters['scale']

            # drop the trailing singleton axis of the norms
            transformed = rearrange(norm, '... () -> ...')

            # gated variant: derive the scale from the norms themselves
            if not exists(scale):
                scale = einsum('b n d, d e -> b n e', transformed, gate_weights)

            transformed = self.nonlin(transformed * scale)
            transformed = rearrange(transformed, '... -> ... ()')

            # rescale the unit-norm features by the transformed norms
            output[degree] = (transformed * phase).view(*t.shape)

        return output
class ConvSE3(nn.Module):
    """A tensor field network layer

    ConvSE3 stands for a SE(3)-equivariant convolutional layer. It is the
    equivalent of a linear layer in an MLP, a conv layer in a CNN, or a graph
    conv layer in a GCN.

    At each node, the activations are split into different "feature types",
    indexed by the SE(3) representation type: non-negative integers 0, 1, 2, ..
    """
    def __init__(
        self,
        fiber_in,
        fiber_out,
        self_interaction = True,
        pool = True,
        edge_dim = 0,
        fourier_encode_dist = False,
        num_fourier_features = 4,
        splits = 4
    ):
        super().__init__()
        # AttentionSE3 forwards its own default of None; treat that as "no edge features"
        edge_dim = default(edge_dim, 0)

        self.fiber_in = fiber_in
        self.fiber_out = fiber_out
        self.edge_dim = edge_dim
        self.self_interaction = self_interaction

        self.num_fourier_features = num_fourier_features
        self.fourier_encode_dist = fourier_encode_dist

        # radial function will assume a dimension of at minimum 1, for the relative distance - extra fourier features must be added to the edge dimension
        edge_dim += (0 if not fourier_encode_dist else (num_fourier_features * 2))

        # Neighbor -> center weights
        self.kernel_unary = nn.ModuleDict()

        self.splits = splits  # for splitting the computation of kernel and basis, to reduce peak memory usage

        for (di, mi), (do, mo) in (self.fiber_in * self.fiber_out):
            self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim = edge_dim, splits = splits)

        self.pool = pool

        # Center -> center weights
        if self_interaction:
            assert self.pool, 'must pool edges if followed with self interaction'
            self.self_interact = LinearSE3(fiber_in, fiber_out)
            self.self_interact_sum = ResidualSE3()

    def forward(
        self,
        inp,
        edge_info,
        rel_dist = None,
        basis = None
    ):
        """
        Args:
            inp: dict of input features keyed by degree (as str)
            edge_info: tuple (neighbor_indices, neighbor_masks, edges)
            rel_dist: relative distances, rearranged here as 'b m n -> b m n ()'
            basis: dict of equivariant bases keyed '<d_in>,<d_out>'
        Returns:
            dict of output features keyed by output degree (as str)
        """
        splits = self.splits
        neighbor_indices, neighbor_masks, edges = edge_info
        rel_dist = rearrange(rel_dist, 'b m n -> b m n ()')

        outputs = {}

        if self.fourier_encode_dist:
            rel_dist = fourier_encode(rel_dist[..., None], num_encodings = self.num_fourier_features)

        # split basis along dim 1, giving one dict per chunk
        basis_keys = basis.keys()
        split_basis_values = list(zip(*list(map(lambda t: fast_split(t, splits, dim = 1), basis.values()))))
        split_basis = list(map(lambda v: dict(zip(basis_keys, v)), split_basis_values))

        # edge features do not depend on the degrees, so build them once
        edge_features = torch.cat((rel_dist, edges), dim = -1) if exists(edges) else rel_dist

        # go through every permutation of input degree type to output degree type
        for degree_out in self.fiber_out.degrees:
            output = 0
            degree_out_key = str(degree_out)

            for degree_in, m_in in self.fiber_in:
                etype = f'({degree_in},{degree_out})'

                x = inp[str(degree_in)]

                # gather the neighbors' features and flatten (channels, order) into one axis
                x = batched_index_select(x, neighbor_indices, dim = 1)
                x = x.view(*x.shape[:3], to_order(degree_in) * m_in, 1)

                kernel_fn = self.kernel_unary[etype]

                output_chunk = None
                split_x = fast_split(x, splits, dim = 1)
                split_edge_features = fast_split(edge_features, splits, dim = 1)

                # process input, edges, and basis in chunks along the sequence dimension
                for x_chunk, edge_chunk, basis_chunk in zip(split_x, split_edge_features, split_basis):
                    kernel = kernel_fn(edge_chunk, basis = basis_chunk)
                    chunk = einsum('... o i, ... i c -> ... o c', kernel, x_chunk)
                    output_chunk = safe_cat(output_chunk, chunk, dim = 1)

                output = output + output_chunk

            if self.pool:
                output = masked_mean(output, neighbor_masks, dim = 2) if exists(neighbor_masks) else output.mean(dim = 2)

            # the neighbor axis is kept only when not pooling
            leading_shape = x.shape[:2] if self.pool else x.shape[:3]
            output = output.view(*leading_shape, -1, to_order(degree_out))

            outputs[degree_out_key] = output

        if self.self_interaction:
            self_interact_out = self.self_interact(inp)
            outputs = self.self_interact_sum(outputs, self_interact_out)

        return outputs
class RadialFunc(nn.Module):
    """NN parameterized radial profile function."""

    def __init__(
        self,
        num_freq,
        in_dim,
        out_dim,
        edge_dim = None,
        mid_dim = 128
    ):
        super().__init__()
        self.num_freq = num_freq
        self.in_dim = in_dim
        self.mid_dim = mid_dim
        self.out_dim = out_dim
        self.edge_dim = default(edge_dim, 0)

        # the input carries edge_dim features plus one more value
        input_width = self.edge_dim + 1
        output_width = num_freq * in_dim * out_dim

        layers = [
            nn.Linear(input_width, mid_dim),
            nn.LayerNorm(mid_dim),
            nn.GELU(),
            nn.Linear(mid_dim, mid_dim),
            nn.LayerNorm(mid_dim),
            nn.GELU(),
            nn.Linear(mid_dim, output_width),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # unpack the flat output into (out, 1, in, 1, freq) trailing axes
        flat = self.net(x)
        return rearrange(flat, '... (o i f) -> ... o () i () f', i = self.in_dim, o = self.out_dim)
class PairwiseConv(nn.Module):
    """SE(3)-equivariant convolution between two single-type features."""

    def __init__(
        self,
        degree_in,
        nc_in,
        degree_out,
        nc_out,
        edge_dim = 0,
        splits = 4
    ):
        super().__init__()
        self.degree_in = degree_in
        self.degree_out = degree_out
        self.nc_in = nc_in
        self.nc_out = nc_out

        self.num_freq = to_order(min(degree_in, degree_out))
        self.d_out = to_order(degree_out)
        self.edge_dim = edge_dim

        self.rp = RadialFunc(self.num_freq, nc_in, nc_out, edge_dim)
        self.splits = splits

    def forward(self, feat, basis):
        radial = self.rp(feat)
        kernel_basis = basis[f'{self.degree_in},{self.degree_out}']

        target_shape = (*radial.shape[:3], self.d_out * self.nc_out, -1)

        # torch.sum(R * B, dim = -1) is too memory intensive
        # so the frequency axis is accumulated one slice at a time
        acc = 0
        for freq in range(radial.shape[-1]):
            acc += radial[..., freq] * kernel_basis[..., freq]

        acc = rearrange(acc, 'b n h s ... -> (b n h s) ...')

        # reshape to the kernel layout and return
        return acc.view(*target_shape)
# feed forwards
class FeedForwardSE3(nn.Module):
    """Feedforward: project to a fiber `mult` times wider, apply NormSE3, project back."""

    def __init__(
        self,
        fiber,
        mult = 4
    ):
        super().__init__()
        self.fiber = fiber

        # same degrees as the input fiber, every dim multiplied by mult
        widened = [(degree, dim * mult) for degree, dim in fiber]
        fiber_hidden = Fiber(widened)

        self.project_in = LinearSE3(fiber, fiber_hidden)
        self.nonlin = NormSE3(fiber_hidden)
        self.project_out = LinearSE3(fiber_hidden, fiber)

    def forward(self, features):
        hidden = self.nonlin(self.project_in(features))
        return self.project_out(hidden)
class FeedForwardBlockSE3(nn.Module):
    """Pre-norm feedforward block with a residual connection."""

    def __init__(
        self,
        fiber,
        norm_gated_scale = False
    ):
        super().__init__()
        self.fiber = fiber
        self.prenorm = NormSE3(fiber, gated_scale = norm_gated_scale)
        self.feedforward = FeedForwardSE3(fiber)
        self.residual = ResidualSE3()

    def forward(self, features):
        # norm -> feedforward, then add the untouched input back
        transformed = self.feedforward(self.prenorm(features))
        return self.residual(transformed, features)
# attention
class AttentionSE3(nn.Module):
def __init__(
    self,
    fiber,
    dim_head = 64,
    heads = 8,
    attend_self = False,
    edge_dim = None,
    fourier_encode_dist = False,
    rel_dist_num_fourier_features = 4,
    use_null_kv = False,
    splits = 4,
    global_feats_dim = None,
    linear_proj_keys = False,
    tie_key_values = False
):
    """Build the projections used by the attention layer.

    Queries come from a LinearSE3. Values come from a ConvSE3. Keys come from
    a LinearSE3 (linear_proj_keys), are shared with the values
    (tie_key_values, to_k is None), or come from their own ConvSE3.
    Optional extras: learned null keys / values, self keys / values, and
    projections of global features.
    """
    super().__init__()
    # every degree of the hidden fiber has dim_head * heads channels
    hidden_dim = dim_head * heads
    hidden_fiber = Fiber(list(map(lambda t: (t[0], hidden_dim), fiber)))
    # the output projection is skipped only for a single head whose dim_head equals the fiber's single dim
    project_out = not (heads == 1 and len(fiber.dims) == 1 and dim_head == fiber.dims[0])

    self.scale = dim_head ** -0.5
    self.heads = heads

    # whether keys are produced by a plain linear projection of the features
    self.linear_proj_keys = linear_proj_keys
    # queries
    self.to_q = LinearSE3(fiber, hidden_fiber)
    # values (unpooled, no self interaction)
    self.to_v = ConvSE3(fiber, hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)

    # the two key options are mutually exclusive
    assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time'

    # keys: linear projection, their own convolution, or None when tied to the values
    if linear_proj_keys:
        self.to_k = LinearSE3(fiber, hidden_fiber)
    elif not tie_key_values:
        self.to_k = ConvSE3(fiber, hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, fourier_encode_dist = fourier_encode_dist, num_fourier_features = rel_dist_num_fourier_features, splits = splits)
    else:
        self.to_k = None

    # output projection back to the input fiber
    self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity()

    # optional learned null keys / values, zero-initialised, one per degree
    self.use_null_kv = use_null_kv
    if use_null_kv:
        self.null_keys = nn.ParameterDict()
        self.null_values = nn.ParameterDict()

        for degree in fiber.degrees:
            m = to_order(degree)
            degree_key = str(degree)
            self.null_keys[degree_key] = nn.Parameter(torch.zeros(heads, dim_head, m))
            self.null_values[degree_key] = nn.Parameter(torch.zeros(heads, dim_head, m))

    # optional keys / values computed from each node's own features
    self.attend_self = attend_self
    if attend_self:
        self.to_self_k = LinearSE3(fiber, hidden_fiber)
        self.to_self_v = LinearSE3(fiber, hidden_fiber)

    # optional keys / values computed from global features (a single degree-0 fiber)
    self.accept_global_feats = exists(global_feats_dim)
    if self.accept_global_feats:
        global_input_fiber = Fiber.create(1, global_feats_dim)
        global_output_fiber = Fiber.create(1, hidden_fiber[0])
        self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber)
        self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber)
# 定义前向传播函数,接收特征、边信息、相对距离、基础信息、全局特征、位置嵌入和掩码作为输入
def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
    """Attend from every node to its neighbors, one feature degree at a time.

    Args:
        features: mapping from degree key (e.g. '0') to a feature tensor.
        edge_info: tuple ``(neighbor_indices, neighbor_mask, edges)``; the
            mask may be ``None`` and is otherwise laid out as ``b i j``.
        rel_dist: forwarded untouched to the value / key convolutions.
        basis: forwarded untouched to the value / key convolutions.
        global_feats: optional global features; when given they are projected
            and prepended to the keys / values of degree '0' only.
        pos_emb: optional ``(query_pos_emb, key_pos_emb)`` pair, applied as
            rotary embeddings to degree '0' only.
        mask: accepted for interface compatibility; never read here.

    Returns:
        The result of ``self.to_out`` applied to the per-degree outputs.
    """
    num_heads = self.heads
    include_self = self.attend_self

    # Retained from the original implementation; neither value is used below.
    device, dtype = get_tensor_device_and_dtype(features)

    neighbor_indices, neighbor_mask, edges = edge_info

    # Insert a broadcastable head axis into the neighbor mask.
    if exists(neighbor_mask):
        neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')

    queries = self.to_q(features)
    values = self.to_v(features, edge_info, rel_dist, basis)

    if self.linear_proj_keys:
        # Keys are projected per node, then gathered for every neighbor.
        projected_keys = self.to_k(features)
        keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), projected_keys)
    elif exists(self.to_k):
        keys = self.to_k(features, edge_info, rel_dist, basis)
    else:
        # No key projection configured: keys are tied to the values.
        keys = values

    if include_self:
        self_keys = self.to_self_k(features)
        self_values = self.to_self_v(features)

    has_global = exists(global_feats)
    if has_global:
        global_keys = self.to_global_k(global_feats)
        global_values = self.to_global_v(global_feats)

    outputs = {}

    for degree in features.keys():
        q, k, v = queries[degree], keys[degree], values[degree]

        # Split the channel axis into (heads, dim_head).
        q = rearrange(q, 'b i (h d) m -> b h i d m', h = num_heads)
        k = rearrange(k, 'b i j (h d) m -> b h i j d m', h = num_heads)
        v = rearrange(v, 'b i j (h d) m -> b h i j d m', h = num_heads)

        if include_self:
            # Each node's own key / value becomes the first "neighbor".
            self_k = rearrange(self_keys[degree], 'b n (h d) m -> b h n () d m', h = num_heads)
            self_v = rearrange(self_values[degree], 'b n (h d) m -> b h n () d m', h = num_heads)
            k = torch.cat((self_k, k), dim = 3)
            v = torch.cat((self_v, v), dim = 3)

        if exists(pos_emb) and degree == '0':
            query_pos_emb, key_pos_emb = pos_emb
            query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
            key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b () i j d ()')
            q = apply_rotary_pos_emb(q, query_pos_emb)
            k = apply_rotary_pos_emb(k, key_pos_emb)
            v = apply_rotary_pos_emb(v, key_pos_emb)

        if self.use_null_kv:
            # Learned null key / value, broadcast over batch and query nodes.
            batch, num_nodes = q.shape[0], q.shape[2]
            null_k = repeat(self.null_keys[degree], 'h d m -> b h i () d m', b = batch, i = num_nodes)
            null_v = repeat(self.null_values[degree], 'h d m -> b h i () d m', b = batch, i = num_nodes)
            k = torch.cat((null_k, k), dim = 3)
            v = torch.cat((null_v, v), dim = 3)

        if has_global and degree == '0':
            num_queries = k.shape[2]
            global_k = repeat(global_keys[degree], 'b j (h d) m -> b h i j d m', h = num_heads, i = num_queries)
            global_v = repeat(global_values[degree], 'b j (h d) m -> b h i j d m', h = num_heads, i = num_queries)
            k = torch.cat((global_k, k), dim = 3)
            v = torch.cat((global_v, v), dim = 3)

        sim = einsum('b h i d m, b h i j d m -> b h i j', q, k) * self.scale

        if exists(neighbor_mask):
            # Everything prepended above (self / null / global) sits to the
            # left of the real neighbors and is always attendable.
            num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
            padded_mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
            sim = sim.masked_fill(~padded_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)
        attended = einsum('b h i j, b h i j d m -> b h i d m', attn, v)
        outputs[degree] = rearrange(attended, 'b h n d m -> b n (h d) m')

    return self.to_out(outputs)
# 定义一个带有一个键/值投影的注意力机制类,该投影在所有查询头之间共享
class OneHeadedKVAttentionSE3(nn.Module):
    """Attention whose single key / value projection is shared by all query heads.

    Queries are projected to ``heads * dim_head`` channels per degree, while
    keys and values are projected to ``dim_head`` channels only.
    """

    def __init__(
        self,
        fiber,
        dim_head = 64,
        heads = 8,
        attend_self = False,
        edge_dim = None,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        use_null_kv = False,
        splits = 4,
        global_feats_dim = None,
        linear_proj_keys = False,
        tie_key_values = False
    ):
        super().__init__()
        hidden_dim = dim_head * heads
        hidden_fiber = Fiber([(el[0], hidden_dim) for el in fiber])
        kv_hidden_fiber = Fiber([(el[0], dim_head) for el in fiber])

        # The output projection is skipped only when it would be a no-op in shape.
        project_out = heads != 1 or len(fiber.dims) != 1 or dim_head != fiber.dims[0]

        self.scale = dim_head ** -0.5
        self.heads = heads

        # Whether keys come from a linear projection instead of a basis convolution.
        self.linear_proj_keys = linear_proj_keys

        # Arguments shared by the value convolution and the optional key convolution.
        conv_kwargs = dict(
            edge_dim = edge_dim,
            pool = False,
            self_interaction = False,
            fourier_encode_dist = fourier_encode_dist,
            num_fourier_features = rel_dist_num_fourier_features,
            splits = splits
        )

        self.to_q = LinearSE3(fiber, hidden_fiber)
        self.to_v = ConvSE3(fiber, kv_hidden_fiber, **conv_kwargs)

        assert not (linear_proj_keys and tie_key_values), 'you cannot do linear projection of keys and have shared key / values turned on at the same time'

        if linear_proj_keys:
            self.to_k = LinearSE3(fiber, kv_hidden_fiber)
        elif tie_key_values:
            # Keys reuse the values in forward().
            self.to_k = None
        else:
            self.to_k = ConvSE3(fiber, kv_hidden_fiber, **conv_kwargs)

        self.to_out = LinearSE3(hidden_fiber, fiber) if project_out else nn.Identity()

        self.use_null_kv = use_null_kv
        if use_null_kv:
            # One zero-initialised null key / value per degree, shared by all heads.
            self.null_keys = nn.ParameterDict()
            self.null_values = nn.ParameterDict()
            for degree in fiber.degrees:
                order = to_order(degree)
                name = str(degree)
                self.null_keys[name] = nn.Parameter(torch.zeros(dim_head, order))
                self.null_values[name] = nn.Parameter(torch.zeros(dim_head, order))

        self.attend_self = attend_self
        if attend_self:
            self.to_self_k = LinearSE3(fiber, kv_hidden_fiber)
            self.to_self_v = LinearSE3(fiber, kv_hidden_fiber)

        self.accept_global_feats = exists(global_feats_dim)
        if self.accept_global_feats:
            global_input_fiber = Fiber.create(1, global_feats_dim)
            global_output_fiber = Fiber.create(1, kv_hidden_fiber[0])
            self.to_global_k = LinearSE3(global_input_fiber, global_output_fiber)
            self.to_global_v = LinearSE3(global_input_fiber, global_output_fiber)

    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
        """Run shared-key/value attention over each node's neighbors.

        Args:
            features: mapping from degree key (e.g. '0') to a feature tensor.
            edge_info: tuple ``(neighbor_indices, neighbor_mask, edges)``; the
                mask may be ``None`` and is otherwise laid out as ``b i j``.
            rel_dist: forwarded untouched to the value / key convolutions.
            basis: forwarded untouched to the value / key convolutions.
            global_feats: optional global features, prepended to the keys /
                values of degree '0' only.
            pos_emb: optional ``(query_pos_emb, key_pos_emb)`` pair, applied
                as rotary embeddings to degree '0' only.
            mask: accepted for interface compatibility; never read here.

        Returns:
            The result of ``self.to_out`` applied to the per-degree outputs.
        """
        num_heads = self.heads
        include_self = self.attend_self

        # Retained from the original implementation; neither value is used below.
        device, dtype = get_tensor_device_and_dtype(features)

        neighbor_indices, neighbor_mask, edges = edge_info

        # Insert a broadcastable head axis into the neighbor mask.
        if exists(neighbor_mask):
            neighbor_mask = rearrange(neighbor_mask, 'b i j -> b () i j')

        queries = self.to_q(features)
        values = self.to_v(features, edge_info, rel_dist, basis)

        if self.linear_proj_keys:
            # Keys are projected per node, then gathered for every neighbor.
            projected_keys = self.to_k(features)
            keys = map_values(lambda val: batched_index_select(val, neighbor_indices, dim = 1), projected_keys)
        elif exists(self.to_k):
            keys = self.to_k(features, edge_info, rel_dist, basis)
        else:
            keys = values

        if include_self:
            self_keys = self.to_self_k(features)
            self_values = self.to_self_v(features)

        has_global = exists(global_feats)
        if has_global:
            global_keys = self.to_global_k(global_feats)
            global_values = self.to_global_v(global_feats)

        outputs = {}

        for degree in features.keys():
            q, k, v = queries[degree], keys[degree], values[degree]

            # Only the queries carry a head axis; keys / values stay single-headed.
            q = rearrange(q, 'b i (h d) m -> b h i d m', h = num_heads)

            if include_self:
                # Each node's own key / value becomes the first "neighbor".
                self_k = rearrange(self_keys[degree], 'b n d m -> b n () d m')
                self_v = rearrange(self_values[degree], 'b n d m -> b n () d m')
                k = torch.cat((self_k, k), dim = 2)
                v = torch.cat((self_v, v), dim = 2)

            if exists(pos_emb) and degree == '0':
                query_pos_emb, key_pos_emb = pos_emb
                query_pos_emb = rearrange(query_pos_emb, 'b i d -> b () i d ()')
                key_pos_emb = rearrange(key_pos_emb, 'b i j d -> b i j d ()')
                q = apply_rotary_pos_emb(q, query_pos_emb)
                k = apply_rotary_pos_emb(k, key_pos_emb)
                v = apply_rotary_pos_emb(v, key_pos_emb)

            if self.use_null_kv:
                # Learned null key / value, broadcast over batch and query nodes.
                batch, num_nodes = q.shape[0], q.shape[2]
                null_k = repeat(self.null_keys[degree], 'd m -> b i () d m', b = batch, i = num_nodes)
                null_v = repeat(self.null_values[degree], 'd m -> b i () d m', b = batch, i = num_nodes)
                k = torch.cat((null_k, k), dim = 2)
                v = torch.cat((null_v, v), dim = 2)

            if has_global and degree == '0':
                num_queries = k.shape[1]
                global_k = repeat(global_keys[degree], 'b j d m -> b i j d m', i = num_queries)
                global_v = repeat(global_values[degree], 'b j d m -> b i j d m', i = num_queries)
                k = torch.cat((global_k, k), dim = 2)
                v = torch.cat((global_v, v), dim = 2)

            sim = einsum('b h i d m, b i j d m -> b h i j', q, k) * self.scale

            if exists(neighbor_mask):
                # Everything prepended above (self / null / global) sits to the
                # left of the real neighbors and is always attendable.
                num_left_pad = sim.shape[-1] - neighbor_mask.shape[-1]
                padded_mask = F.pad(neighbor_mask, (num_left_pad, 0), value = True)
                sim = sim.masked_fill(~padded_mask, -torch.finfo(sim.dtype).max)

            attn = sim.softmax(dim = -1)
            attended = einsum('b h i j, b i j d m -> b h i d m', attn, v)
            outputs[degree] = rearrange(attended, 'b h n d m -> b n (h d) m')

        return self.to_out(outputs)
# 定义一个注意力块类,继承自 nn.Module
class AttentionBlockSE3(nn.Module):
    """Pre-normalised attention layer wrapped in a residual connection.

    The block normalises its input with ``NormSE3``, runs ``attention_klass``
    on the result and merges the attention output with the untouched input
    through ``ResidualSE3``.

    Args:
        fiber: fiber describing the input (and output) features.
        dim_head, heads, attend_self, edge_dim, use_null_kv,
        fourier_encode_dist, rel_dist_num_fourier_features, splits,
        linear_proj_keys, tie_key_values: forwarded unchanged to
            ``attention_klass``.
        global_feats_dim: dimension of the optional global features. Both
            ``None`` and the legacy default ``False`` mean "no global features".
        attention_klass: attention class to instantiate.
        norm_gated_scale: forwarded to ``NormSE3`` as ``gated_scale``.
    """

    def __init__(
        self,
        fiber,
        dim_head = 24,
        heads = 8,
        attend_self = False,
        edge_dim = None,
        use_null_kv = False,
        fourier_encode_dist = False,
        rel_dist_num_fourier_features = 4,
        splits = 4,
        global_feats_dim = False,
        linear_proj_keys = False,
        tie_key_values = False,
        attention_klass = AttentionSE3,
        norm_gated_scale = False
    ):
        super().__init__()

        # The attention classes decide whether to build global key / value
        # projections via `exists(global_feats_dim)`. Assuming `exists` is the
        # usual `is not None` check, the legacy default `False` would count as
        # a configured dimension, so map it to None before forwarding.
        if global_feats_dim is False:
            global_feats_dim = None

        self.attn = attention_klass(
            fiber,
            heads = heads,
            dim_head = dim_head,
            attend_self = attend_self,
            edge_dim = edge_dim,
            use_null_kv = use_null_kv,
            rel_dist_num_fourier_features = rel_dist_num_fourier_features,
            fourier_encode_dist = fourier_encode_dist,
            splits = splits,
            global_feats_dim = global_feats_dim,
            linear_proj_keys = linear_proj_keys,
            tie_key_values = tie_key_values
        )
        self.prenorm = NormSE3(fiber, gated_scale = norm_gated_scale)
        self.residual = ResidualSE3()

    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
        """Apply prenorm -> attention -> residual and return the result."""
        res = features
        outputs = self.prenorm(features)
        outputs = self.attn(outputs, edge_info, rel_dist, basis, global_feats, pos_emb, mask)
        return self.residual(outputs, res)
# 定义 Swish_ 类
class Swish_(nn.Module):
    """Swish / SiLU activation: the input multiplied by its own sigmoid."""

    def forward(self, x):
        gate = x.sigmoid()
        return x * gate
# 如果 nn 模块中有 SiLU 函数,则使用 nn.SiLU,否则使用自定义的 Swish_ 类
# Prefer the built-in nn.SiLU; fall back to Swish_ on torch builds without it.
try:
    SiLU = nn.SiLU
except AttributeError:
    SiLU = Swish_
# 定义 HtypesNorm 类
class HtypesNorm(nn.Module):
    """Rescale vectors along the last axis with a learned affine map of their norm.

    Each vector is reduced to unit length (the norm is clamped to ``eps`` to
    avoid division by zero) and then multiplied by ``norm * scale + bias``,
    where ``scale`` and ``bias`` are learned per channel.
    """

    def __init__(self, dim, eps = 1e-8, scale_init = 1e-2, bias_init = 1e-2):
        super().__init__()
        self.eps = eps

        # One scale and one bias per channel, broadcastable against inputs
        # whose second-to-last axis has size `dim`.
        param_shape = (1, 1, 1, dim, 1)
        param_dtype = torch.get_default_dtype()
        self.scale = nn.Parameter(torch.full(param_shape, scale_init, dtype = param_dtype))
        self.bias = nn.Parameter(torch.full(param_shape, bias_init, dtype = param_dtype))

    def forward(self, coors):
        # Length of every vector, kept as a trailing singleton axis.
        length = coors.norm(dim = -1, keepdim = True)
        unit_coors = coors / length.clamp(min = self.eps)
        new_length = length * self.scale + self.bias
        return unit_coors * new_length
# 定义 EGNN 类
class EGNN(nn.Module):
    """EGNN-style message passing layer over type-0 and higher-type features.

    Type-0 (node) features drive an edge MLP whose output is used both to
    update the node features and to weight updates of every higher-type
    feature.

    Args:
        fiber: fiber describing the features; ``fiber[0]`` is the node dim.
        hidden_dim: width of the edge message.
        edge_dim: number of extra edge feature channels (0 for none).
        init_eps: std of the normal init applied to every ``nn.Linear`` weight.
        coor_weights_clamp_value: if set, higher-type weights are clamped to
            ``[-value, value]``.
    """

    def __init__(
        self,
        fiber,
        hidden_dim = 32,
        edge_dim = 0,
        init_eps = 1e-3,
        coor_weights_clamp_value = None
    ):
        super().__init__()
        self.fiber = fiber
        node_dim = fiber[0]

        # Higher types are all fiber elements whose degree is not 0.
        htypes = list(filter(lambda t: t.degrees != 0, fiber))
        htype_dims = sum([fiberel.dim for fiberel in htypes])

        # Both endpoint nodes, one distance per higher-type channel, the edge
        # features and the relative coordinate distance.
        edge_input_dim = node_dim * 2 + htype_dims + edge_dim + 1

        self.node_norm = nn.LayerNorm(node_dim)

        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, edge_input_dim * 2),
            SiLU(),
            nn.Linear(edge_input_dim * 2, hidden_dim),
            SiLU()
        )

        self.htype_norms = nn.ModuleDict({})
        self.htype_gating = nn.ModuleDict({})

        for degree, dim in fiber:
            if degree == 0:
                continue
            self.htype_norms[str(degree)] = HtypesNorm(dim)
            self.htype_gating[str(degree)] = nn.Linear(node_dim, dim)

        self.htypes_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            SiLU(),
            nn.Linear(hidden_dim * 4, htype_dims)
        )

        self.node_mlp = nn.Sequential(
            nn.Linear(node_dim + hidden_dim, node_dim * 2),
            SiLU(),
            nn.Linear(node_dim * 2, node_dim)
        )

        self.coor_weights_clamp_value = coor_weights_clamp_value
        self.init_eps = init_eps
        self.apply(self.init_)

    def init_(self, module):
        """Re-initialise the weight of plain ``nn.Linear`` modules with a small std."""
        if type(module) in {nn.Linear}:
            nn.init.normal_(module.weight, std = self.init_eps)

    def forward(
        self,
        features,
        edge_info,
        rel_dist,
        mask = None,
        **kwargs
    ):
        """Update node and higher-type features from neighbor messages.

        Args:
            features: mapping from degree key to feature tensor; it is updated
                in place and also returned.
            edge_info: tuple ``(neighbor_indices, neighbor_masks, edges)``.
            rel_dist: relative distances, laid out as ``b i j``.
            mask: ignored; the neighbor mask from ``edge_info`` is used instead.
            **kwargs: accepted and ignored so callers can pass a uniform set
                of arguments.

        Returns:
            The same ``features`` mapping with every degree updated.
        """
        neighbor_indices, neighbor_masks, edges = edge_info

        # Only the neighbor mask matters for this layer.
        mask = neighbor_masks

        # type 0 features, with the trailing singleton axis dropped
        nodes = features['0']
        nodes = rearrange(nodes, '... () -> ...')

        # higher types (htype)
        htypes = list(filter(lambda t: t[0] != '0', features.items()))
        htype_degrees = list(map(lambda t: t[0], htypes))
        htype_dims = list(map(lambda t: t[1].shape[-2], htypes))

        # pairwise differences of every higher type and their norms
        rel_htypes = []
        rel_htypes_dists = []

        for degree, htype in htypes:
            rel_htype = rearrange(htype, 'b i d m -> b i () d m') - rearrange(htype, 'b j d m -> b () j d m')
            rel_htype_dist = rel_htype.norm(dim = -1)

            rel_htypes.append(rel_htype)
            rel_htypes_dists.append(rel_htype_dist)

        # assemble the edge MLP input
        nodes_i = rearrange(nodes, 'b i d -> b i () d')
        nodes_j = batched_index_select(nodes, neighbor_indices, dim = 1)
        neighbor_higher_type_dists = map(lambda t: batched_index_select(t, neighbor_indices, dim = 2), rel_htypes_dists)
        coor_rel_dist = rearrange(rel_dist, 'b i j -> b i j ()')

        edge_mlp_inputs = broadcat((nodes_i, nodes_j, *neighbor_higher_type_dists, coor_rel_dist), dim = -1)

        if exists(edges):
            edge_mlp_inputs = torch.cat((edge_mlp_inputs, edges), dim = -1)

        # intermediate edge representation
        m_ij = self.edge_mlp(edge_mlp_inputs)

        # weights for the higher-type updates
        htype_weights = self.htypes_mlp(m_ij)

        if exists(self.coor_weights_clamp_value):
            clamp_value = self.coor_weights_clamp_value
            htype_weights.clamp_(min = -clamp_value, max = clamp_value)

        # Mask BEFORE splitting: masked_fill returns a new tensor, so chunks
        # split off earlier would still carry the unmasked weights and let
        # masked-out neighbors contribute to the updates.
        if exists(mask):
            htype_mask = rearrange(mask, 'b i j -> b i j ()')
            htype_weights = htype_weights.masked_fill(~htype_mask, 0.)

        split_htype_weights = htype_weights.split(htype_dims, dim = -1)

        htype_updates = []

        for degree, rel_htype, htype_weight in zip(htype_degrees, rel_htypes, split_htype_weights):
            normed_rel_htype = self.htype_norms[str(degree)](rel_htype)
            normed_rel_htype = batched_index_select(normed_rel_htype, neighbor_indices, dim = 2)

            htype_update = einsum('b i j d m, b i j d -> b i d m', normed_rel_htype, htype_weight)
            htype_updates.append(htype_update)

        # aggregate messages into the nodes
        if exists(mask):
            m_ij_mask = rearrange(mask, '... -> ... ()')
            m_ij = m_ij.masked_fill(~m_ij_mask, 0.)

        m_i = m_ij.sum(dim = -2)

        normed_nodes = self.node_norm(nodes)
        node_mlp_input = torch.cat((normed_nodes, m_i), dim = -1)
        node_out = self.node_mlp(node_mlp_input) + nodes

        # write back the nodes, restoring the trailing singleton axis
        features['0'] = rearrange(node_out, '... -> ... ()')

        # add the higher-type updates
        update_htype_dicts = dict(zip(htype_degrees, htype_updates))

        for degree, update_htype in update_htype_dicts.items():
            features[degree] = features[degree] + update_htype

        # gate every higher type with a sigmoid computed from the new nodes
        for degree in htype_degrees:
            gating = self.htype_gating[str(degree)](node_out).sigmoid()
            features[degree] = features[degree] * rearrange(gating, '... -> ... ()')

        return features
# 定义一个 EGnnNetwork 类,继承自 nn.Module 类
class EGnnNetwork(nn.Module):
    """Stack of ``EGNN`` layers, each optionally followed by a feedforward block.

    Args:
        fiber: fiber describing the features handled by every layer.
        depth: number of EGNN layers.
        edge_dim: forwarded to every ``EGNN``.
        hidden_dim: forwarded to every ``EGNN``.
        coor_weights_clamp_value: forwarded to every ``EGNN``.
        feedforward: when true, a ``FeedForwardBlockSE3`` follows each layer.
    """

    def __init__(
        self,
        *,
        fiber,
        depth,
        edge_dim = 0,
        hidden_dim = 32,
        coor_weights_clamp_value = None,
        feedforward = False
    ):
        super().__init__()
        self.fiber = fiber

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                EGNN(fiber = fiber, edge_dim = edge_dim, hidden_dim = hidden_dim, coor_weights_clamp_value = coor_weights_clamp_value),
                FeedForwardBlockSE3(fiber) if feedforward else None
            ]))

    def forward(
        self,
        features,
        edge_info,
        rel_dist,
        basis,
        global_feats = None,
        pos_emb = None,
        mask = None,
        **kwargs
    ):
        """Run every layer in sequence and return the updated features.

        Each node is prepended to its own neighbor list before the layers
        run: the incoming neighbor information excludes the node itself,
        which suits the attention layers but not EGNN.

        Args:
            features: mapping from degree key to feature tensor.
            edge_info: tuple ``(neighbor_indices, neighbor_masks, edges)``;
                the mask and the edges may each be ``None``.
            rel_dist: relative distances to the neighbors.
            basis, global_feats, pos_emb, mask, **kwargs: passed through to
                every ``EGNN`` layer.

        Returns:
            The features produced by the last layer.
        """
        neighbor_indices, neighbor_masks, edges = edge_info
        device = neighbor_indices.device

        # Prepend every node's own index to its neighbor list.
        self_indices = torch.arange(neighbor_indices.shape[1], device = device)
        self_indices = rearrange(self_indices, 'i -> () i ()')
        neighbor_indices = broadcat((self_indices, neighbor_indices), dim = -1)

        # The self entry is always valid. The mask is optional, as in the
        # EGNN layer itself, so only pad it when present.
        if exists(neighbor_masks):
            neighbor_masks = F.pad(neighbor_masks, (1, 0), value = True)

        # A node is at distance 0 from itself.
        rel_dist = F.pad(rel_dist, (1, 0), value = 0.)

        if exists(edges):
            # For now the edge from a token to itself is all zeros.
            edges = F.pad(edges, (0, 0, 1, 0), value = 0.)

        edge_info = (neighbor_indices, neighbor_masks, edges)

        for egnn, ff in self.layers:
            features = egnn(
                features,
                edge_info = edge_info,
                rel_dist = rel_dist,
                basis = basis,
                global_feats = global_feats,
                pos_emb = pos_emb,
                mask = mask,
                **kwargs
            )

            if exists(ff):
                features = ff(features)

        return features
# 主类
class SE3Transformer(nn.Module):
# 初始化函数,接收多个参数
def __init__(
self,
*,
dim,
heads = 8,
dim_head = 24,
depth = 2,
input_degrees = 1,
num_degrees = None,
output_degrees = 1,
valid_radius = 1e5,
reduce_dim_out = False,
num_tokens = None,
num_positions = None,
num_edge_tokens = None,
edge_dim = None,
reversible = False,
attend_self = True,
use_null_kv = False,
differentiable_coors = False,
fourier_encode_dist = False,
rel_dist_num_fourier_features = 4,
num_neighbors = float('inf'),
attend_sparse_neighbors = False,
num_adj_degrees = None,
adj_dim = 0,
max_sparse_neighbors = float('inf'),
dim_in = None,
dim_out = None,
norm_out = False,
num_conv_layers = 0,
causal = False,
splits = 4,
global_feats_dim = None,
linear_proj_keys = False,
one_headed_key_values = False,
tie_key_values = False,
rotary_position = False,
rotary_rel_dist = False,
norm_gated_scale = False,
use_egnn = False,
egnn_hidden_dim = 32,
egnn_weights_clamp_value = None,
egnn_feedforward = False,
hidden_fiber_dict = None,
out_fiber_dict = None
# 前向传播函数,接收多个参数
def forward(
self,
feats,
coors,
mask = None,
adj_mat = None,
edges = None,
return_type = None,
return_pooled = False,
neighbor_mask = None,
global_feats = None