Lucidrains Series Projects: Source Code Walkthrough (Part 21)

.\lucidrains\imagen-pytorch\setup.py

# import setuptools' setup and package discovery helpers
from setuptools import setup, find_packages
# execute the version file, bringing __version__ into the current namespace
exec(open('imagen_pytorch/version.py').read())

# package metadata
setup(
  # package name
  name = 'imagen-pytorch',
  # discover all packages, excluding none
  packages = find_packages(exclude=[]),
  # include all data files
  include_package_data = True,
  # entry points defining the command line scripts
  entry_points={
    'console_scripts': [
      'imagen_pytorch = imagen_pytorch.cli:main',
      'imagen = imagen_pytorch.cli:imagen'
    ],
  },
  # version number
  version = __version__,
  # license
  license='MIT',
  # description
  description = 'Imagen - unprecedented photorealism × deep level of language understanding',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # long description content type
  long_description_content_type = 'text/markdown',
  # project url
  url = 'https://github.com/lucidrains/imagen-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'text-to-image',
    'denoising-diffusion'
  ],
  # install dependencies
  install_requires=[
    'accelerate>=0.23.0',
    'beartype',
    'click',
    'datasets',
    'einops>=0.7.0',
    'ema-pytorch>=0.0.3',
    'fsspec',
    'kornia',
    'numpy',
    'packaging',
    'pillow',
    'pydantic>=2',
    'pytorch-warmup',
    'sentencepiece',
    'torch>=1.6',
    'torchvision',
    'transformers',
    'tqdm'
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Insertion Deletion Denoising Diffusion Probabilistic Models (wip)

Implementation of Insertion Deletion Denoising Diffusion Probabilistic Models. This scheme allows DDPMs to go beyond in-place corruption along the sequence, by inserting and deleting tokens. The authors apply it to text generation with lukewarm results. It may hold promise for protein design, as it could infill regions without being constrained to a fixed number of amino acids.

Citations

@article{Johnson2021BeyondIC,
    title   = {Beyond In-Place Corruption: Insertion and Deletion In Denoising Probabilistic Models},
    author  = {Daniel D. Johnson and Jacob Austin and Rianne van den Berg and Daniel Tarlow},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2107.07675}
}

.\lucidrains\invariant-point-attention\denoise.py

# import the required libraries
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.optim import Adam

# import tensor manipulation helpers from einops
from einops import rearrange, repeat
# import the sidechainnet library for protein data
import sidechainnet as scn
# import the IPATransformer class from the invariant_point_attention package
from invariant_point_attention import IPATransformer

# batch size and number of gradient accumulation steps
BATCH_SIZE = 1
GRADIENT_ACCUMULATE_EVERY = 16

# infinite generator over the dataloader, skipping overly long sequences
def cycle(loader, len_thres = 200):
    while True:
        for data in loader:
            # skip sequences longer than the threshold
            if data.seqs.shape[1] > len_thres:
                continue
            yield data

# instantiate the IPATransformer model
net = IPATransformer(
    dim = 16,
    num_tokens = 21,
    depth = 5,
    require_pairwise_repr = False,
    predict_points = True
).cuda()

# load the dataset
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = BATCH_SIZE,
    dynamic_batching = False
)

# create the cycling dataloader
dl = cycle(data['train'])
# initialize the Adam optimizer
optim = Adam(net.parameters(), lr=1e-3)

# training loop
for _ in range(10000):
    # gradient accumulation
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        # fetch one batch of data
        batch = next(dl)
        seqs, coords, masks = batch.seqs, batch.crds, batch.msks

        # move sequences to GPU and convert one-hot to indices
        seqs = seqs.cuda().argmax(dim = -1)
        coords = coords.cuda()
        masks = masks.cuda().bool()

        # get sequence length and reshape coordinates to (batch, residues, atoms, xyz)
        l = seqs.shape[1]
        coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)

        # keep only the C-alpha atom coordinates
        coords = coords[:, :, 1, :]
        # add random noise
        noised_coords = coords + torch.randn_like(coords)

        # run the model to denoise
        denoised_coords = net(
            seqs,
            translations = noised_coords,
            mask = masks
        )

        # compute the loss
        loss = F.mse_loss(denoised_coords[masks], coords[masks])
        # backpropagate, scaled by the number of accumulation steps
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # print the loss
    print('loss:', loss.item())
    # take an optimizer step
    optim.step()
    # reset gradients
    optim.zero_grad()

.\lucidrains\invariant-point-attention\invariant_point_attention\invariant_point_attention.py

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from contextlib import contextmanager
from torch import nn, einsum

from einops.layers.torch import Rearrange
from einops import rearrange, repeat

# helpers

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# maximum negative value for a given tensor's dtype
def max_neg_value(t):
    return -torch.finfo(t.dtype).max

# context manager that temporarily disables TF32 matmuls
# (TF32 trades matmul precision for speed on Ampere+ GPUs; full fp32
# precision is needed when checking invariance to within 1e-6)
@contextmanager
def disable_tf32():
    orig_value = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    yield
    torch.backends.cuda.matmul.allow_tf32 = orig_value

# classes

class InvariantPointAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        scalar_key_dim = 16,
        scalar_value_dim = 16,
        point_key_dim = 4,
        point_value_dim = 4,
        pairwise_repr_dim = None,
        require_pairwise_repr = True,
        eps = 1e-8
    ):
        super().__init__()
        self.eps = eps
        self.heads = heads
        self.require_pairwise_repr = require_pairwise_repr

        # num attention contributions

        num_attn_logits = 3 if require_pairwise_repr else 2

        # qkv projection for scalar attention (normal)

        self.scalar_attn_logits_scale = (num_attn_logits * scalar_key_dim) ** -0.5

        self.to_scalar_q = nn.Linear(dim, scalar_key_dim * heads, bias = False)
        self.to_scalar_k = nn.Linear(dim, scalar_key_dim * heads, bias = False)
        self.to_scalar_v = nn.Linear(dim, scalar_value_dim * heads, bias = False)

        # qkv projection for point attention (coordinate and orientation aware)

        # per-head point weights, kept positive downstream via softplus;
        # initialized to inverse-softplus(1), so the effective weight starts at 1
        point_weight_init_value = torch.log(torch.exp(torch.full((heads,), 1.)) - 1.)
        self.point_weights = nn.Parameter(point_weight_init_value)

        # scale consistent with the Alphafold2 supplement, sqrt(2 / (9 * num point logits))
        self.point_attn_logits_scale = ((num_attn_logits * point_key_dim) * (9 / 2)) ** -0.5

        self.to_point_q = nn.Linear(dim, point_key_dim * heads * 3, bias = False)
        self.to_point_k = nn.Linear(dim, point_key_dim * heads * 3, bias = False)
        self.to_point_v = nn.Linear(dim, point_value_dim * heads * 3, bias = False)

        # pairwise representation projection to attention bias

        pairwise_repr_dim = default(pairwise_repr_dim, dim) if require_pairwise_repr else 0

        if require_pairwise_repr:
            self.pairwise_attn_logits_scale = num_attn_logits ** -0.5

            self.to_pairwise_attn_bias = nn.Sequential(
                nn.Linear(pairwise_repr_dim, heads),
                Rearrange('b ... h -> (b h) ...')
            )

        # combine out - scalar dim + pairwise dim + point dim * (3 for coordinates in R3 and then 1 for norm)

        self.to_out = nn.Linear(heads * (scalar_value_dim + pairwise_repr_dim + point_value_dim * (3 + 1)), dim)

    def forward(
        self,
        single_repr,
        pairwise_repr = None,
        *,
        rotations,
        translations,
        mask = None
    ):
        # the full invariant point attention computation is omitted in this excerpt
        pass

# one transformer block based on IPA

def FeedForward(dim, mult = 1., num_layers = 2, act = nn.ReLU):
    layers = []
    dim_hidden = dim * mult

    for ind in range(num_layers):
        is_first = ind == 0
        is_last  = ind == (num_layers - 1)
        dim_in   = dim if is_first else dim_hidden
        dim_out  = dim if is_last else dim_hidden

        layers.append(nn.Linear(dim_in, dim_out))

        if is_last:
            continue

        layers.append(act())

    return nn.Sequential(*layers)

class IPABlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        ff_mult = 1,
        ff_num_layers = 3,          # in the paper, they used a 3 layer transition (feedforward) block
        post_norm = True,           # in the paper, they used post-layernorm - pre-norm is offered as well
        post_attn_dropout = 0.,
        post_ff_dropout = 0.,
        **kwargs
    ):
        super().__init__()
        # whether to normalize after (post-norm) or before (pre-norm) each sublayer
        self.post_norm = post_norm

        # invariant point attention with its layernorm and dropout
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = InvariantPointAttention(dim = dim, **kwargs)
        self.post_attn_dropout = nn.Dropout(post_attn_dropout)

        # feedforward (transition) block with its layernorm and dropout
        self.ff_norm = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mult = ff_mult, num_layers = ff_num_layers)
        self.post_ff_dropout = nn.Dropout(post_ff_dropout)

    # forward pass
    def forward(self, x, **kwargs):
        post_norm = self.post_norm

        # with post-norm, attention acts directly on the input; with pre-norm, normalize first
        attn_input = x if post_norm else self.attn_norm(x)
        # invariant point attention with a residual connection
        x = self.attn(attn_input, **kwargs) + x
        x = self.post_attn_dropout(x)
        # with post-norm, normalize after the residual
        x = self.attn_norm(x) if post_norm else x

        # same pattern for the feedforward (transition) block
        ff_input = x if post_norm else self.ff_norm(x)
        x = self.ff(ff_input) + x
        x = self.post_ff_dropout(x)
        x = self.ff_norm(x) if post_norm else x
        return x
# add an IPA Transformer - iteratively updating rotations and translations

# this portion is not accurate to AF2, since AF2 applies a FAPE auxiliary loss on each layer, as well as a stop gradient on the rotations
# just an attempt to see if this could evolve into something more generally usable

class IPATransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        num_tokens = None,
        predict_points = False,
        detach_rotations = True,
        **kwargs
    ):
        super().__init__()

        # quaternion functions from pytorch3d

        try:
            from pytorch3d.transforms import quaternion_multiply, quaternion_to_matrix
            self.quaternion_to_matrix = quaternion_to_matrix
            self.quaternion_multiply = quaternion_multiply
        except (ImportError, ModuleNotFoundError) as err:
            print('unable to import pytorch3d - please install with `conda install pytorch3d -c pytorch3d`')
            raise err

        # embedding

        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None

        # layers

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                IPABlock(dim = dim, **kwargs),
                nn.Linear(dim, 6)
            ]))

        # whether to detach the rotations for training stability

        self.detach_rotations = detach_rotations

        # output

        self.predict_points = predict_points

        if predict_points:
            self.to_points = nn.Linear(dim, 3)

    def forward(
        self,
        single_repr,
        *,
        translations = None,
        quaternions = None,
        pairwise_repr = None,
        mask = None
    ):
        x, device, quaternion_multiply, quaternion_to_matrix = single_repr, single_repr.device, self.quaternion_multiply, self.quaternion_to_matrix
        b, n, *_ = x.shape

        if exists(self.token_emb):
            x = self.token_emb(x)

        # if no initial quaternions are passed in, start from identity

        if not exists(quaternions):
            quaternions = torch.tensor([1., 0., 0., 0.], device = device) # initial rotations
            quaternions = repeat(quaternions, 'd -> b n d', b = b, n = n)

        # if no translations are passed in, start from zero

        if not exists(translations):
            translations = torch.zeros((b, n, 3), device = device)

        # go through the layers, applying invariant point attention and feedforward

        for block, to_update in self.layers:
            rotations = quaternion_to_matrix(quaternions)

            if self.detach_rotations:
                rotations = rotations.detach()

            x = block(
                x,
                pairwise_repr = pairwise_repr,
                rotations = rotations,
                translations = translations
            )

            # update the quaternions and translations

            quaternion_update, translation_update = to_update(x).chunk(2, dim = -1)
            quaternion_update = F.pad(quaternion_update, (1, 0), value = 1.)
            quaternion_update = quaternion_update / torch.linalg.norm(quaternion_update, dim=-1, keepdim=True)
            quaternions = quaternion_multiply(quaternions, quaternion_update)
            translations = translations + einsum('b n c, b n c r -> b n r', translation_update, rotations)

        if not self.predict_points:
            return x, translations, quaternions

        points_local = self.to_points(x)
        rotations = quaternion_to_matrix(quaternions)
        points_global = einsum('b n c, b n c d -> b n d', points_local, rotations) + translations
        return points_global

.\lucidrains\invariant-point-attention\invariant_point_attention\utils.py

# import torch
import torch
# import the sin, cos, atan2, acos functions from torch
from torch import sin, cos, atan2, acos
# import the wraps decorator from functools
from functools import wraps

# decorator that casts the input to a torch tensor
def cast_torch_tensor(fn):
    @wraps(fn)
    def inner(t):
        # if the input is not a torch tensor, convert it
        if not torch.is_tensor(t):
            t = torch.tensor(t, dtype=torch.get_default_dtype())
        # call the wrapped function and return its result
        return fn(t)
    return inner

# rotation matrix about the z axis by angle gamma
@cast_torch_tensor
def rot_z(gamma):
    return torch.tensor([
        [cos(gamma), -sin(gamma), 0],
        [sin(gamma), cos(gamma), 0],
        [0, 0, 1]
    ], dtype=gamma.dtype)

# rotation matrix about the y axis by angle beta
@cast_torch_tensor
def rot_y(beta):
    return torch.tensor([
        [cos(beta), 0, sin(beta)],
        [0, 1, 0],
        [-sin(beta), 0, cos(beta)]
    ], dtype=beta.dtype)

# general rotation composed of z-y-z Euler rotations
def rot(alpha, beta, gamma):
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
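
A quick illustrative check (not part of the file above) that the composition is a proper rotation:

import torch
from invariant_point_attention.utils import rot

R = rot(*torch.randn(3))
# a rotation matrix is orthonormal with determinant 1
assert torch.allclose(R @ R.t(), torch.eye(3), atol = 1e-5)
assert torch.allclose(torch.det(R), torch.tensor(1.), atol = 1e-5)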

.\lucidrains\invariant-point-attention\invariant_point_attention\__init__.py

# export InvariantPointAttention, IPABlock, and IPATransformer from the invariant_point_attention module
from invariant_point_attention.invariant_point_attention import InvariantPointAttention, IPABlock, IPATransformer

Invariant Point Attention - Pytorch

Implementation of Invariant Point Attention as a standalone module, which was used in the structure module of Alphafold2 for coordinate refinement.

Install

$ pip install invariant-point-attention

Usage

import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention

attn = InvariantPointAttention(
    dim = 64,                  # single (and pairwise) representation dimension
    heads = 8,                 # number of attention heads
    scalar_key_dim = 16,       # scalar query-key dimension
    scalar_value_dim = 16,     # scalar value dimension
    point_key_dim = 4,         # point query-key dimension
    point_value_dim = 4        # point value dimension
)

single_repr   = torch.randn(1, 256, 64)      # (batch x seq x dim)
pairwise_repr = torch.randn(1, 256, 256, 64) # (batch x seq x seq x dim)
mask          = torch.ones(1, 256).bool()    # (batch x seq)

rotations     = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256)  # (batch x seq x rot1 x rot2) - example is identity
translations  = torch.zeros(1, 256, 3) # translation, also identity for example

attn_out = attn(
    single_repr,
    pairwise_repr,
    rotations = rotations,
    translations = translations,
    mask = mask
)

attn_out.shape # (1, 256, 64)

You can also use this module without the pairwise representations, which are very specific to the Alphafold2 architecture.

import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention

attn = InvariantPointAttention(
    dim = 64,
    heads = 8,
    require_pairwise_repr = False   # set this to False to use the module without pairwise representations
)

seq           = torch.randn(1, 256, 64)
mask          = torch.ones(1, 256).bool()

rotations     = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256)
translations  = torch.randn(1, 256, 3)

attn_out = attn(
    seq,
    rotations = rotations,
    translations = translations,
    mask = mask
)

attn_out.shape # (1, 256, 64)

You can also use one IPA-based transformer block, which is an IPA followed by a feedforward. By default it will use post-layernorm as done in the official code, but you can also try pre-layernorm by setting post_norm = False

import torch
from torch import nn
from einops import repeat
from invariant_point_attention import IPABlock

block = IPABlock(
    dim = 64,
    heads = 8,
    scalar_key_dim = 16,
    scalar_value_dim = 16,
    point_key_dim = 4,
    point_value_dim = 4
)

seq           = torch.randn(1, 256, 64)
pairwise_repr = torch.randn(1, 256, 256, 64)
mask          = torch.ones(1, 256).bool()

rotations     = repeat(torch.eye(3), 'r1 r2 -> b n r1 r2', b = 1, n = 256)
translations  = torch.randn(1, 256, 3)

block_out = block(
    seq,
    pairwise_repr = pairwise_repr,
    rotations = rotations,
    translations = translations,
    mask = mask
)

updates = nn.Linear(64, 6)(block_out)
quaternion_update, translation_update = updates.chunk(2, dim = -1) # (1, 256, 3), (1, 256, 3)

# apply updates to rotations and translations for the next iteration
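
Continuing from the block above, here is a minimal sketch (assuming pytorch3d is installed) of how IPATransformer in this repository turns that 6-dimensional update into the next iteration's rotations and translations:

import torch
import torch.nn.functional as F
from einops import repeat
from pytorch3d.transforms import quaternion_multiply, quaternion_to_matrix

quaternions = repeat(torch.tensor([1., 0., 0., 0.]), 'd -> b n d', b = 1, n = 256)  # start from identity rotation

quaternion_update = F.pad(quaternion_update, (1, 0), value = 1.)  # fix the real part to 1
quaternion_update = quaternion_update / torch.linalg.norm(quaternion_update, dim = -1, keepdim = True)

quaternions = quaternion_multiply(quaternions, quaternion_update)
rotations = quaternion_to_matrix(quaternions)  # rotations for the next iteration
translations = translations + torch.einsum('b n c, b n c r -> b n r', translation_update, rotations)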

Toy Example

To run IPA on a toy task for denoising protein backbone coordinates, first install pytorch3d by running

$ conda install pytorch3d -c pytorch3d

Then you need to install sidechainnet with

$ pip install sidechainnet

Finally

$ python denoise.py

Citations

@Article{AlphaFold2021,
    author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
    journal = {Nature},
    title   = {Highly accurate protein structure prediction with {AlphaFold}},
    year    = {2021},
    doi     = {10.1038/s41586-021-03819-2},
    note    = {(Accelerated article preview)},
}

.\lucidrains\invariant-point-attention\setup.py

# import setuptools' setup and package discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'invariant-point-attention',  # package name
  packages = find_packages(),  # discover all packages
  version = '0.2.2',  # version number
  license='MIT',  # license
  description = 'Invariant Point Attention',  # description
  long_description_content_type = 'text/markdown',  # long description content type
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/invariant-point-attention',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'protein folding'
  ],
  install_requires=[  # install dependencies
    'einops>=0.3',
    'torch>=1.7'
  ],
  setup_requires=[  # setup dependencies
    'pytest-runner',
  ],
  tests_require=[  # test dependencies
    'pytest'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\invariant-point-attention\tests\invariance.py

# import the required libraries
import torch
from torch import nn
from einops import repeat
from invariant_point_attention import InvariantPointAttention, IPABlock
from invariant_point_attention.utils import rot

# test that invariant point attention is invariant to a global rotation
def test_ipa_invariance():
    # create the invariant point attention module
    attn = InvariantPointAttention(
        dim = 64,
        heads = 8,
        scalar_key_dim = 16,
        scalar_value_dim = 16,
        point_key_dim = 4,
        point_value_dim = 4
    )

    # random input sequence, pairwise representation, and mask
    seq           = torch.randn(1, 256, 64)
    pairwise_repr = torch.randn(1, 256, 256, 64)
    mask          = torch.ones(1, 256).bool()

    # random rotations and translations
    rotations     = repeat(rot(*torch.randn(3)), 'r1 r2 -> b n r1 r2', b = 1, n = 256)
    translations  = torch.randn(1, 256, 3)

    # random global rotation, for testing invariance
    random_rotation = rot(*torch.randn(3))

    # output of invariant point attention
    attn_out = attn(
        seq,
        pairwise_repr = pairwise_repr,
        rotations = rotations,
        translations = translations,
        mask = mask
    )

    # output after applying the global rotation
    rotated_attn_out = attn(
        seq,
        pairwise_repr = pairwise_repr,
        rotations = rotations @ random_rotation,
        translations = translations @ random_rotation,
        mask = mask
    )

    # the outputs must be invariant
    diff = (attn_out - rotated_attn_out).max()
    assert diff <= 1e-6, 'must be invariant to global rotation'

# test that the IPA-based transformer block is invariant to a global rotation
def test_ipa_block_invariance():
    # create the IPA block
    attn = IPABlock(
        dim = 64,
        heads = 8,
        scalar_key_dim = 16,
        scalar_value_dim = 16,
        point_key_dim = 4,
        point_value_dim = 4
    )

    # random input sequence, pairwise representation, and mask
    seq           = torch.randn(1, 256, 64)
    pairwise_repr = torch.randn(1, 256, 256, 64)
    mask          = torch.ones(1, 256).bool()

    # random rotations and translations
    rotations     = repeat(rot(*torch.randn(3)), 'r1 r2 -> b n r1 r2', b = 1, n = 256)
    translations  = torch.randn(1, 256, 3)

    # random global rotation, for testing invariance
    random_rotation = rot(*torch.randn(3))

    # output of the IPA block
    attn_out = attn(
        seq,
        pairwise_repr = pairwise_repr,
        rotations = rotations,
        translations = translations,
        mask = mask
    )

    # output after applying the global rotation
    rotated_attn_out = attn(
        seq,
        pairwise_repr = pairwise_repr,
        rotations = rotations @ random_rotation,
        translations = translations @ random_rotation,
        mask = mask
    )

    # the outputs must be invariant
    diff = (attn_out - rotated_attn_out).max()
    assert diff <= 1e-6, 'must be invariant to global rotation'
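
Both checks can be run directly with pytest (listed under tests_require in the setup.py above):

$ pytest tests/invariance.py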

.\lucidrains\isab-pytorch\isab_pytorch\isab_pytorch.py

# import torch
import torch
# import the nn module and einsum function from torch
from torch import nn, einsum
# import torch.nn.functional as F
import torch.nn.functional as F

# import the rearrange and repeat functions from einops
from einops import rearrange, repeat

# helper function that checks whether a value exists
def exists(val):
    return val is not None

# attention

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        and_self_attend = False
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.and_self_attend = and_self_attend

        # project the input to queries
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # project the context to keys and values
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        # project the attention output back to the model dimension
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(
        self,
        x,
        context,
        mask = None
    ):
        h, scale = self.heads, self.scale

        if self.and_self_attend:
            # for self-attending, concatenate the queries onto the context
            context = torch.cat((x, context), dim = -2)

            if exists(mask):
                # pad the mask so the prepended positions are always attended to
                mask = F.pad(mask, (x.shape[-2], 0), value = True)

        # derive queries from x, and keys / values from the context
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        # split out the attention heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        # scaled dot product attention scores
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale

        if exists(mask):
            # mask out attention scores at padded positions
            mask_value = -torch.finfo(dots.dtype).max
            mask = rearrange(mask, 'b n -> b 1 1 n')
            dots.masked_fill_(~mask, mask_value)

        # softmax over the keys to get attention weights
        attn = dots.softmax(dim = -1)
        # aggregate the values
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge the heads and project out
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)

# induced set attention block (ISAB)

class ISAB(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        num_latents = None,
        latent_self_attend = False
    ):
        super().__init__()
        # if a number of latents is given, learn them within the module, otherwise expect them to be passed in
        self.latents = nn.Parameter(torch.randn(num_latents, dim)) if exists(num_latents) else None
        # first attention - latents attend to the input sequence
        self.attn1 = Attention(dim, heads, and_self_attend = latent_self_attend)
        # second attention - the input sequence attends to the latents
        self.attn2 = Attention(dim, heads)

    def forward(self, x, latents = None, mask = None):
        b, *_ = x.shape

        # latents must either be learned within the module or passed in, not both
        assert exists(latents) ^ exists(self.latents), 'you can only either learn the latents within the module, or pass it in externally'
        latents = latents if exists(latents) else self.latents

        if latents.ndim == 2:
            # expand the latents across the batch dimension
            latents = repeat(latents, 'n d -> b n d', b = b)

        # latents attend to the input sequence
        latents = self.attn1(latents, x, mask = mask)
        # the input sequence attends to the latents
        out     = self.attn2(x, latents)

        return out, latents

.\lucidrains\isab-pytorch\isab_pytorch\__init__.py

# export the ISAB class from the isab_pytorch module
from isab_pytorch.isab_pytorch import ISAB

Induced Set Attention Block (ISAB) - Pytorch

A concise implementation of the (Induced) Set Attention Block, from the Set Transformer paper. It proposes to reduce attention from O(n²) to O(mn), where m is the number of inducing points (learned latents).
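
To make the savings concrete with the numbers used below: full self-attention over n = 16384 tokens scores 16384² ≈ 2.7 × 10⁸ query-key pairs per layer, whereas ISAB's two cross attentions against m = 128 latents score only 2 × 128 × 16384 ≈ 4.2 × 10⁶, a 64× reduction.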

Update: Interestingly enough, a new paper has used the ISAB block successfully, in the domain of denoising diffusion for efficient generation of images and video.

Install

$ pip install isab-pytorch

Usage

You can either set the number of latents, in which case they will be instantiated as parameters within the module, and the updated latents will be returned alongside the output of the cross attention.

import torch
from isab_pytorch import ISAB

attn = ISAB(
    dim = 512,
    heads = 8,
    num_latents = 128,
    latent_self_attend = True
)

seq = torch.randn(1, 16384, 512) # (batch, seq, dim)
mask = torch.ones((1, 16384)).bool()

out, latents = attn(seq, mask = mask) # (1, 16384, 512), (1, 128, 512)

Or you can choose not to set the number of latents and pass the latents in yourself (for example, some persistent latents that propagate down the transformer)

import torch
from isab_pytorch import ISAB

attn = ISAB(
    dim = 512,
    heads = 8
)

seq = torch.randn(1, 16384, 512) # (batch, seq, dim)
latents = torch.nn.Parameter(torch.randn(128, 512)) # some memory, passed through multiple ISABs

out, new_latents = attn(seq, latents) # (1, 16384, 512), (1, 128, 512)

Citations

@misc{lee2019set,
    title   = {Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks},
    author  = {Juho Lee and Yoonho Lee and Jungtaek Kim and Adam R. Kosiorek and Seungjin Choi and Yee Whye Teh},
    year    = {2019},
    eprint  = {1810.00825},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@article{Alayrac2022Flamingo,
    title   = {Flamingo: a Visual Language Model for Few-Shot Learning},
    author  = {Jean-Baptiste Alayrac et al},
    year    = {2022}
}

.\lucidrains\isab-pytorch\setup.py

# import setuptools' setup and package discovery helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'isab-pytorch',  # package name
  packages = find_packages(),  # discover and include all packages
  version = '0.2.3',  # version number
  license='MIT',  # license
  description = 'Induced Set Attention Block - Pytorch',  # description
  long_description_content_type = 'text/markdown',  # long description content type
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/isab-pytorch',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'attention mechanism'
  ],
  install_requires=[  # install dependencies
    'torch',
    'einops>=0.3'  # einops, version 0.3 or above
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\iTransformer\iTransformer\attend.py

# import the required libraries
from functools import partial

import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version

from einops import rearrange, repeat

# named tuple holding the three boolean flags for the sdp kernel config
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# decorator that ensures a function is only called once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print that only fires once
print_once = once(print)

# main class

class Attend(nn.Module):
    def __init__(
        self,
        *,
        dropout = 0.,
        heads = None,
        scale = None,
        flash = False,
        causal = False
    ):
        super().__init__()
        self.scale = scale

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal

        # flash attention

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        major, minor = device_properties.major, device_properties.minor

        if (major, minor) == (8, 0):
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        elif (major, minor) == (9, 0):
            print_once('H100 GPU detected, using flash attention')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    # flash attention implementation
    def flash_attn(
        self,
        q, k, v
    ):
        batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # select the sdp kernel config for cuda or cpu

        config = self.cuda_config if is_cuda else self.cpu_config

        # dispatch to pytorch 2.0 flash / efficient attention through the chosen sdp kernel
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                is_causal = self.causal,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    # forward pass
    def forward(
        self,
        q, k, v
    ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, heads, kv_heads, device, dtype = q.shape[-2], q.shape[1], k.shape[1], q.device, q.dtype

        scale = default(self.scale, q.shape[-1] ** -0.5)

        if self.flash:
            return self.flash_attn(q, k, v)

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale

        if self.causal:
            i, j, dtype = *sim.shape[-2:], sim.dtype
            mask_value = -torch.finfo(sim.dtype).max
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

        attn = sim.softmax(dim = -1)
        attn = attn.type(dtype)

        attn = self.attn_dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        return out
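
A small illustrative usage of the Attend module on its own (shapes and values assumed for demonstration):

import torch
from iTransformer.attend import Attend

attend = Attend(causal = True, dropout = 0., flash = False)  # non-flash path shown for clarity

# (batch, heads, seq, dim_head)
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

out = attend(q, k, v)  # (1, 8, 1024, 64)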

.\lucidrains\iTransformer\iTransformer\iTransformer.py

# import torch
import torch
# import nn, einsum, Tensor from torch
from torch import nn, einsum, Tensor
# import Module, ModuleList from torch.nn
from torch.nn import Module, ModuleList
# import torch.nn.functional as F
import torch.nn.functional as F

# import beartype for runtime type checking
from beartype import beartype
from beartype.typing import Optional, Union, Tuple

# import tensor manipulation helpers from einops
from einops import rearrange, reduce, repeat, pack, unpack
from einops.layers.torch import Rearrange

# import the Attend class from the iTransformer.attend module
from iTransformer.attend import Attend
# import the RevIN class from the iTransformer.revin module
from iTransformer.revin import RevIN

# helper functions

# check whether a value exists
def exists(v):
    return v is not None

# return the value if it exists, otherwise the default
def default(v, d):
    return v if exists(v) else d

# return the input unchanged
def identity(t, *args, **kwargs):
    return t

# wrap the input in a tuple if it is not one already
def cast_tuple(t):
    return (t,) if not isinstance(t, tuple) else t

# attention

class Attention(Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 4,
        dropout = 0.,
        flash = True
    ):
        super().__init__()
        # attention scale factor
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        # project the input to queries, keys, and values
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
        )

        # project the input to per-head gates over the aggregated values
        self.to_v_gates = nn.Sequential(
            nn.Linear(dim, dim_inner, bias = False),
            nn.SiLU(),
            Rearrange('b n (h d) -> b h n d', h = heads)
        )

        # attention
        self.attend = Attend(flash = flash, dropout = dropout)

        # output projection
        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        q, k, v = self.to_qkv(x)

        out = self.attend(q, k, v)

        out = out * self.to_v_gates(x)
        return self.to_out(out)

# feedforward

class GEGLU(Module):
    def forward(self, x):
        x, gate = rearrange(x, '... (r d) -> r ... d', r = 2)
        return x * F.gelu(gate)

# feedforward with GEGLU; the 2/3 factor keeps the parameter count
# roughly equal to a standard non-gated feedforward of the same mult
def FeedForward(dim, mult = 4, dropout = 0.):
    dim_inner = int(dim * mult * 2 / 3)
    return nn.Sequential(
        nn.Linear(dim, dim_inner * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim)
    )

# main class

class iTransformer(Module):
    @beartype
    def __init__(
        self,
        *,
        num_variates: int,
        lookback_len: int,
        depth: int,
        dim: int,
        num_tokens_per_variate = 1,
        pred_length: Union[int, Tuple[int, ...]],
        dim_head = 32,
        heads = 4,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        num_mem_tokens = 4,
        use_reversible_instance_norm = False,
        reversible_instance_norm_affine = False,
        flash_attn = True
    ):
        super().__init__()
        # number of variates and lookback length
        self.num_variates = num_variates
        self.lookback_len = lookback_len

        # learned memory tokens, if any
        self.mem_tokens = nn.Parameter(torch.randn(num_mem_tokens, dim)) if num_mem_tokens > 0 else None

        # one or more prediction lengths
        pred_length = cast_tuple(pred_length)
        self.pred_length = pred_length

        # reversible instance normalization, if enabled
        self.reversible_instance_norm = RevIN(num_variates, affine = reversible_instance_norm_affine) if use_reversible_instance_norm else None
        self.num_tokens_per_variate = num_tokens_per_variate

        # transformer layers
        self.layers = ModuleList([])
        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn),
                nn.LayerNorm(dim),
                FeedForward(dim, mult = ff_mult, dropout = ff_dropout),
                nn.LayerNorm(dim)
            ]))

        # project each variate's lookback window into one or more tokens
        self.mlp_in = nn.Sequential(
            nn.Linear(lookback_len, dim * num_tokens_per_variate),
            Rearrange('b v (n d) -> b (v n) d', n = num_tokens_per_variate),
            nn.LayerNorm(dim)
        )

        # prediction heads, one per prediction length
        self.pred_heads = ModuleList([])

        for one_pred_length in pred_length:
            head = nn.Sequential(
                Rearrange('b (v n) d -> b v (n d)', n = num_tokens_per_variate),
                nn.Linear(dim * num_tokens_per_variate, one_pred_length),
                Rearrange('b v n -> b n v')
            )

            self.pred_heads.append(head)

    @beartype
    def forward(
        self,
        x: Tensor,
        targets: Optional[Union[Tensor, Tuple[Tensor, ...]]] = None
    ):
        """
        einstein notation

        b - batch
        n - time
        v - variate
        t - num tokens per variate
        """
        t = self.num_tokens_per_variate

        has_mem = exists(self.mem_tokens)
        assert x.shape[1:] == (self.lookback_len, self.num_variates)

        # rearrange so each variate's time series is a row
        x = rearrange(x, 'b n v -> b v n')

        if exists(self.reversible_instance_norm):
            x, reverse_fn = self.reversible_instance_norm(x)

        x = self.mlp_in(x)

        # memory tokens

        if has_mem:
            m = repeat(self.mem_tokens, 'm d -> b m d', b = x.shape[0])
            x, mem_ps = pack([m, x], 'b * d')

        # attention and feedforward layers

        for attn, attn_post_norm, ff, ff_post_norm in self.layers:
            x = attn(x) + x
            x = attn_post_norm(x)
            x = ff(x) + x
            x = ff_post_norm(x)

        # splice out the memory tokens

        if has_mem:
            _, x = unpack(x, mem_ps, 'b * d')

        # reverse the instance normalization, if needed

        if exists(self.reversible_instance_norm):
            x = rearrange(x, 'b (n t) d -> t b n d', t = t)
            x = reverse_fn(x)
            x = rearrange(x, 't b n d -> b (n t) d', t = t)

        # predict for each prediction length

        pred_list = [fn(x) for fn in self.pred_heads]

        # if targets are passed in, calculate the loss

        if exists(targets):
            targets = cast_tuple(targets)
            assert len(targets) == len(pred_list)

            assert self.training
            mse_loss = 0.
            for target, pred in zip(targets, pred_list):
                assert target.shape == pred.shape

                mse_loss = mse_loss + F.mse_loss(target, pred)

            return mse_loss

        # with a single prediction length, return the lone tensor
        if len(pred_list) == 1:
            return pred_list[0]

        pred_dict = dict(zip(self.pred_length, pred_list))
        return pred_dict
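
A usage sketch for the class above (parameter values assumed for illustration; the input shape follows the assert in forward, and flash attention requires pytorch 2.0 or above):

import torch
from iTransformer import iTransformer

model = iTransformer(
    num_variates = 137,
    lookback_len = 96,              # e.g. 96 time steps of history
    dim = 256,
    depth = 6,
    heads = 8,
    dim_head = 64,
    pred_length = (12, 24, 36, 48)  # one or more prediction lengths
)

time_series = torch.randn(2, 96, 137)  # (batch, lookback_len, num_variates)

preds = model(time_series)  # dict of pred_length -> (batch, pred_length, num_variates)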

.\lucidrains\iTransformer\iTransformer\iTransformer2D.py

# import torch
import torch
# import nn, einsum, Tensor from torch
from torch import nn, einsum, Tensor
# import Module, ModuleList from torch.nn
from torch.nn import Module, ModuleList
# import torch.nn.functional as F
import torch.nn.functional as F

# import beartype for runtime type checking
from beartype import beartype
from beartype.typing import Optional, Union, Tuple

# import tensor manipulation helpers from einops
from einops import rearrange, reduce, repeat, pack, unpack
from einops.layers.torch import Rearrange

# import the Attend class from the iTransformer.attend module
from iTransformer.attend import Attend
# import the RevIN class from the iTransformer.revin module
from iTransformer.revin import RevIN

# import the SimpleGateLoopLayer class from the gateloop_transformer package
from gateloop_transformer import SimpleGateLoopLayer
# import the RotaryEmbedding class from the rotary_embedding_torch package
from rotary_embedding_torch import RotaryEmbedding

# helper functions

# check whether a value exists
def exists(v):
    return v is not None

# return the value if it exists, otherwise the default
def default(v, d):
    return v if exists(v) else d

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# return the input unchanged
def identity(t, *args, **kwargs):
    return t

# whether num is divisible by den
def divisible_by(num, den):
    return (num % den) == 0

# wrap the input in a tuple if it is not one already
def cast_tuple(t):
    return (t,) if not isinstance(t, tuple) else t

# attention

class Attention(Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 4,
        dropout = 0.,
        causal = False,
        flash = True,
        rotary_emb: Optional[RotaryEmbedding] = None,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.rotary_emb = rotary_emb

        # project the input to queries, keys, and values
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
        )

        # project the input to per-head gates over the aggregated values
        self.to_v_gates = nn.Sequential(
            nn.Linear(dim, dim_inner, bias = False),
            nn.SiLU(),
            Rearrange('b n (h d) -> b h n d', h = heads)
        )

        self.attend = Attend(flash = flash, dropout = dropout, causal = causal)

        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False),
            nn.Dropout(dropout)
        )

    @beartype
    def forward(self, x):
        q, k, v = self.to_qkv(x)

        # apply rotary positional embeddings to queries and keys, if given
        if exists(self.rotary_emb):
            q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))

        out = self.attend(q, k, v)

        out = out * self.to_v_gates(x)
        return self.to_out(out)

# feedforward

class GEGLU(Module):
    def forward(self, x):
        x, gate = rearrange(x, '... (r d) -> r ... d', r = 2)
        return x * F.gelu(gate)

# feedforward with GEGLU, as above

def FeedForward(dim, mult = 4, dropout = 0.):
    dim_inner = int(dim * mult * 2 / 3)
    return nn.Sequential(
        nn.Linear(dim, dim_inner * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim)
    )

# one transformer block

class TransformerBlock(Module):
    def __init__(
        self,
        *,
        dim,
        causal = False,
        dim_head = 32,
        heads = 8,
        ff_mult = 4,
        flash_attn = True,
        attn_dropout = 0.,
        ff_dropout = 0.,
        rotary_emb: Optional[RotaryEmbedding] = None,
    ):
        super().__init__()
        self.rotary_emb = rotary_emb

        self.attn = Attention(flash = flash_attn, rotary_emb = rotary_emb, causal = causal, dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
        self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
        self.attn_norm = nn.LayerNorm(dim)
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, x, rotary_emb: Optional[RotaryEmbedding] = None):

        # attention with residual, followed by post-layernorm
        x = self.attn(x) + x
        x = self.attn_norm(x)

        # feedforward with residual, followed by post-layernorm
        x = self.ff(x) + x
        x = self.ff_norm(x)

        return x

# main class

class iTransformer2D(Module):
    @beartype
    def __init__(
        self,
        *,
        num_variates: int,                        # number of variates
        lookback_len: int,                        # lookback length
        num_time_tokens: int,                     # number of time tokens
        depth: int,                               # model depth
        dim: int,                                 # model dimension
        pred_length: Union[int, Tuple[int, ...]], # prediction length(s)
        dim_head = 32,                            # attention head dimension
        heads = 4,                                # number of attention heads
        attn_dropout = 0.,                        # attention dropout
        ff_mult = 4,                              # feedforward expansion factor
        ff_dropout = 0.,                          # feedforward dropout
        num_mem_tokens = 4,                       # number of memory tokens
        use_reversible_instance_norm = False,     # whether to use reversible instance norm
        reversible_instance_norm_affine = True,   # learnable affine parameters for reversible instance norm
        flash_attn = True                         # whether to use flash attention
    ):
        super().__init__()
        assert divisible_by(lookback_len, num_time_tokens)  # lookback length must be divisible by the number of time tokens
        assert num_time_tokens >= 2

        self.num_variates = num_variates
        self.lookback_len = lookback_len
        self.num_time_tokens = num_time_tokens

        # learned memory tokens, if any
        self.mem_tokens = nn.Parameter(torch.randn(num_mem_tokens, dim)) if num_mem_tokens > 0 else None

        # one or more prediction lengths
        pred_length = cast_tuple(pred_length)
        self.pred_length = pred_length

        # reversible instance normalization, if enabled
        self.reversible_instance_norm = RevIN(num_variates, affine = reversible_instance_norm_affine) if use_reversible_instance_norm else None

        # rotary embedding shared across the causal time attention blocks
        rotary_emb = RotaryEmbedding(dim_head)

        self.layers = ModuleList([])

        block_kwargs = dict(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            ff_mult = ff_mult,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            flash_attn = flash_attn
        )

        # each layer: a gateloop block, causal attention across time, then attention across variates
        for _ in range(depth):
            self.layers.append(ModuleList([
                SimpleGateLoopLayer(dim = dim),
                TransformerBlock(causal = True, rotary_emb = rotary_emb, **block_kwargs),
                TransformerBlock(causal = False, **block_kwargs)
            ]))

        # project each variate's full lookback window to a variate token
        self.to_variate_token = nn.Sequential(
            nn.Linear(lookback_len, dim),
            nn.LayerNorm(dim)
        )

        time_kernel_size = lookback_len // num_time_tokens

        # derive time tokens with a causally padded, strided convolution - one token per time_kernel_size steps
        self.to_time_tokens = nn.Sequential(
            Rearrange('b v n -> (b v) 1 n'),
            nn.ConstantPad1d((time_kernel_size, 0), value = 0.),
            nn.Conv1d(1, dim, time_kernel_size * 2, stride = time_kernel_size),
            Rearrange('(b v) d t -> b v t d', v = num_variates),
            nn.LayerNorm(dim)
        )

        # prediction heads, one per prediction length
        self.pred_heads = ModuleList([])

        for one_pred_length in pred_length:
            head = nn.Sequential(
                nn.Linear(dim, one_pred_length),
                Rearrange('b v n -> b n v')
            )

            self.pred_heads.append(head)

    @beartype
    def forward(
        self,
        x: Tensor,
        targets: Optional[Union[Tensor, Tuple[Tensor, ...]]] = None
    ):
        """
        einstein notation

        b - batch
        n - time
        v - variate
        t - number of time tokens
        """

        # whether memory tokens exist
        has_mem = exists(self.mem_tokens)
        # the input must be (batch, lookback_len, num_variates)
        assert x.shape[1:] == (self.lookback_len, self.num_variates)

        # rearrange so each variate's time series is a row
        x = rearrange(x, 'b n v -> b v n')

        # apply reversible instance normalization, if enabled
        if exists(self.reversible_instance_norm):
            x, reverse_fn = self.reversible_instance_norm(x)

        # derive the per-variate time tokens 't'

        t = self.to_time_tokens(x)

        # 'v' will be the variate pool token, the same as the per-variate token in iTransformer

        v = self.to_variate_token(x)

        # combine time and variate tokens into a 2d feature map of variates and time

        x, variate_pool_token_ps = pack((t, v), 'b v * d')

        # memory tokens

        if has_mem:
            m = repeat(self.mem_tokens, 'm d -> b m t d', b = x.shape[0], t = x.shape[-2])
            x, mem_ps = pack([m, x], 'b * t d')

        # attention and feedforward layers

        for gateloop_block, time_attn_block, variate_attn_block in self.layers:
            x, ps = pack_one(x, '* t d')

            # gateloop block with residual
            x = gateloop_block(x) + x

            # causal attention across time for each variate
            x = time_attn_block(x)

            x = unpack_one(x, ps, '* t d')

            x = rearrange(x, 'b v t d -> b t v d')
            x, ps = pack_one(x, '* v d')

            # attention across variates (as in the inverted transformer paper)
            x = variate_attn_block(x)

            x = unpack_one(x, ps, '* v d')
            x = rearrange(x, 'b t v d -> b v t d')

        # splice out the memory tokens

        if has_mem:
            _, x = unpack(x, mem_ps, 'b * t d')

        # get back the original variate pool token

        _, v = unpack(x, variate_pool_token_ps, 'b v * d')

        # reverse the instance normalization, if needed

        if exists(self.reversible_instance_norm):
            v = reverse_fn(v)

        # predict for each prediction length

        pred_list = [fn(v) for fn in self.pred_heads]

        # if targets are passed in, calculate the loss

        if exists(targets):
            targets = cast_tuple(targets)
            assert len(targets) == len(pred_list)

            assert self.training
            mse_loss = 0.
            for target, pred in zip(targets, pred_list):
                assert target.shape == pred.shape

                mse_loss = mse_loss + F.mse_loss(target, pred)

            return mse_loss

        # with a single prediction length, return the lone tensor
        if len(pred_list) == 1:
            return pred_list[0]

        pred_dict = dict(zip(self.pred_length, pred_list))
        return pred_dict
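
A usage sketch for iTransformer2D (values assumed; the gateloop-transformer and rotary-embedding-torch packages are required by the imports above, and num_time_tokens must divide lookback_len):

import torch
from iTransformer import iTransformer2D

model = iTransformer2D(
    num_variates = 137,
    num_time_tokens = 16,   # 96 // 16 = 6 time steps per time token
    lookback_len = 96,
    dim = 256,
    depth = 6,
    heads = 8,
    dim_head = 64,
    pred_length = (12, 24, 36, 48)
)

time_series = torch.randn(2, 96, 137)  # (batch, lookback_len, num_variates)

preds = model(time_series)  # dict of pred_length -> (batch, pred_length, num_variates)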

.\lucidrains\iTransformer\iTransformer\iTransformerFFT.py

# import torch
import torch
# import the fft function from torch.fft
from torch.fft import fft
# import nn, einsum, Tensor from torch
from torch import nn, einsum, Tensor
# import Module, ModuleList from torch.nn
from torch.nn import Module, ModuleList
# import torch.nn.functional as F
import torch.nn.functional as F

# import beartype for runtime type checking
from beartype import beartype
from beartype.typing import Optional, Union, Tuple

# import tensor manipulation helpers from einops
from einops import rearrange, reduce, repeat, pack, unpack
from einops.layers.torch import Rearrange

# import the Attend class from the iTransformer.attend module
from iTransformer.attend import Attend
# import the RevIN class from the iTransformer.revin module
from iTransformer.revin import RevIN

# helper functions

# check whether a value exists
def exists(v):
    return v is not None

# return the value if it exists, otherwise the default
def default(v, d):
    return v if exists(v) else d

# return the input unchanged
def identity(t, *args, **kwargs):
    return t

# wrap the input in a tuple if it is not one already
def cast_tuple(t):
    return (t,) if not isinstance(t, tuple) else t

# attention

class Attention(Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 4,
        dropout = 0.,
        flash = True
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        # project the input to queries, keys, and values
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
        )

        # project the input to per-head gates over the aggregated values
        self.to_v_gates = nn.Sequential(
            nn.Linear(dim, dim_inner, bias = False),
            nn.SiLU(),
            Rearrange('b n (h d) -> b h n d', h = heads)
        )

        self.attend = Attend(flash = flash, dropout = dropout)

        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        q, k, v = self.to_qkv(x)

        out = self.attend(q, k, v)

        out = out * self.to_v_gates(x)
        return self.to_out(out)

# feedforward

class GEGLU(Module):
    def forward(self, x):
        x, gate = rearrange(x, '... (r d) -> r ... d', r = 2)
        return x * F.gelu(gate)

# feedforward with GEGLU, as above
def FeedForward(dim, mult = 4, dropout = 0.):
    dim_inner = int(dim * mult * 2 / 3)
    return nn.Sequential(
        nn.Linear(dim, dim_inner * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim)
    )

# main class

class iTransformerFFT(Module):
    @beartype
    def __init__(
        self,
        *,
        num_variates: int,
        lookback_len: int,
        depth: int,
        dim: int,
        num_tokens_per_variate = 1,
        pred_length: Union[int, Tuple[int, ...]],
        dim_head = 32,
        heads = 4,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        num_mem_tokens = 4,
        use_reversible_instance_norm = False,
        reversible_instance_norm_affine = False,
        flash_attn = True
    ):
        super().__init__()
        # number of variates and lookback length
        self.num_variates = num_variates
        self.lookback_len = lookback_len

        # learned memory tokens, if any
        self.mem_tokens = nn.Parameter(torch.randn(num_mem_tokens, dim)) if num_mem_tokens > 0 else None

        # one or more prediction lengths
        pred_length = cast_tuple(pred_length)
        self.pred_length = pred_length

        # reversible instance normalization, if enabled
        self.reversible_instance_norm = RevIN(num_variates, affine = reversible_instance_norm_affine) if use_reversible_instance_norm else None

        # transformer layers
        self.layers = ModuleList([])
        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn),
                nn.LayerNorm(dim),
                FeedForward(dim, mult = ff_mult, dropout = ff_dropout),
                nn.LayerNorm(dim)
            ]))

        # project each variate's lookback window into one or more tokens
        self.mlp_in = nn.Sequential(
            nn.Linear(lookback_len, dim * num_tokens_per_variate),
            Rearrange('b v (n d) -> b (v n) d', n = num_tokens_per_variate),
            nn.LayerNorm(dim)
        )

        # project the fourier transform of each variate (real and imaginary parts) into tokens
        self.fft_mlp_in = nn.Sequential(
            Rearrange('b v n c -> b v (n c)'),
            nn.Linear(lookback_len * 2, dim * num_tokens_per_variate),
            Rearrange('b v (n d) -> b (v n) d', n = num_tokens_per_variate),
            nn.LayerNorm(dim)
        )

        # prediction heads, one per prediction length
        self.pred_heads = ModuleList([])

        for one_pred_length in pred_length:
            head = nn.Sequential(
                Rearrange('b (v n) d -> b v (n d)', n = num_tokens_per_variate),
                nn.Linear(dim * num_tokens_per_variate, one_pred_length),
                Rearrange('b v n -> b n v')
            )

            self.pred_heads.append(head)

    # 定义模型的前向传播方法
    @beartype
    def forward(
        self,
        x: Tensor,
        targets: Optional[Union[Tensor, Tuple[Tensor, ...]]] = None
        ):
        """
        einstein notation

        b - batch
        n - time
        v - variate
        """
        # 检查是否存在记忆令牌
        has_mem = exists(self.mem_tokens)
        # 断言输入张量的形状符合预期
        assert x.shape[1:] == (self.lookback_len, self.num_variates)

        # 论文的关键在于将变量视为注意力中的空间维度
        # 如果成功复制论文,有很多改进的机会

        # 重新排列输入张量的维度,将时间维度放在最后
        x = rearrange(x, 'b n v -> b v n')

        # 对输入张量进行傅立叶变换
        x_fft = fft(x)
        # 将傅立叶变换后的结果转换为实部和虚部
        x_fft = torch.view_as_real(x_fft)

        # 如果存在可逆实例归一化,则对输入张量进行归一化
        if exists(self.reversible_instance_norm):
            x, reverse_fn = self.reversible_instance_norm(x)

        # 将输入张量投影到变量令牌中,对时间和傅立叶域都进行投影
        x = self.mlp_in(x)
        x_fft = self.fft_mlp_in(x_fft)

        # 将傅立叶变换后的结果放在左侧,以便稍后拼接
        x, fft_ps = pack([x_fft, x], 'b * d')

        # 记忆令牌
        if has_mem:
            # 重复记忆令牌以匹配输入张量的批次维度
            m = repeat(self.mem_tokens, 'm d -> b m d', b = x.shape[0])
            x, mem_ps = pack([m, x], 'b * d')

        # 注意力和前馈层
        for attn, attn_post_norm, ff, ff_post_norm in self.layers:
            x = attn(x) + x
            x = attn_post_norm(x)
            x = ff(x) + x
            x = ff_post_norm(x)

        # 拼接出记忆令牌
        if has_mem:
            _, x = unpack(x, mem_ps, 'b * d')

        # 拼接出傅立叶令牌
        x_fft, x = unpack(x, fft_ps, 'b * d')

        # 如果需要,进行可逆实例归一化
        if exists(self.reversible_instance_norm):
            x = reverse_fn(x)

        # 预测多次
        pred_list = [fn(x) for fn in self.pred_heads]

        # 如果传入了目标值,则计算损失
        if exists(targets):
            targets = cast_tuple(targets)
            assert len(targets) == len(pred_list)

            assert self.training
            mse_loss = 0.
            for target, pred in zip(targets, pred_list):
                assert target.shape == pred.shape

                mse_loss = mse_loss + F.mse_loss(target, pred)

            return mse_loss

        # 如果只有一个预测长度,则直接返回对应的预测值
        if len(pred_list) == 1:
            return pred_list[0]

        # 将预测结果与预测长度组成字典返回
        pred_dict = dict(zip(self.pred_length, pred_list))
        return pred_dict
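
一个极简的 einops pack / unpack 用法示意(独立的小例子,与上面的模型无关,张量形状均为假设值),演示上文"先打包、后按相同顺序拆分"的技巧:

import torch
from einops import pack, unpack

fft_tokens = torch.randn(2, 3, 8)    # 假设的傅立叶令牌 (batch, n1, dim)
var_tokens = torch.randn(2, 5, 8)    # 假设的变量令牌 (batch, n2, dim)

# 沿 * 所在的维度打包,ps 记录各段的长度,便于之后还原
packed, ps = pack([fft_tokens, var_tokens], 'b * d')   # (2, 8, 8)

# 按相同模式拆分,恢复打包前的两段
fft_out, var_out = unpack(packed, ps, 'b * d')
assert fft_out.shape == (2, 3, 8) and var_out.shape == (2, 5, 8)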

.\lucidrains\iTransformer\iTransformer\revin.py

# 导入必要的库
from collections import namedtuple
import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F

# 定义一个命名元组,用于存储统计信息
Statistics = namedtuple('Statistics', [
    'mean',
    'variance',
    'gamma',
    'beta'
])

# 可逆实例归一化
# 提议的实例归一化方法,参考 https://openreview.net/forum?id=cGDAkQo1C0p

class RevIN(Module):
    def __init__(
        self,
        num_variates,
        affine = True,
        eps = 1e-5
    ):
        super().__init__()
        self.eps = eps
        self.num_variates = num_variates
        # 初始化可学习参数 gamma 和 beta
        self.gamma = nn.Parameter(torch.ones(num_variates, 1), requires_grad = affine)
        self.beta = nn.Parameter(torch.zeros(num_variates, 1), requires_grad = affine)

    def forward(self, x, return_statistics = False):
        assert x.shape[1] == self.num_variates

        # 计算均值和方差
        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = -1, keepdim = True)
        var_rsqrt = var.clamp(min = self.eps).rsqrt()
        # 实例归一化
        instance_normalized = (x - mean) * var_rsqrt
        # 重新缩放
        rescaled = instance_normalized * self.gamma + self.beta

        # 定义反向函数
        def reverse_fn(scaled_output):
            clamped_gamma = torch.sign(self.gamma) * self.gamma.abs().clamp(min = self.eps)
            unscaled_output = (scaled_output - self.beta) / clamped_gamma
            return unscaled_output * var.sqrt() + mean

        if not return_statistics:
            return rescaled, reverse_fn

        # 返回统计信息
        statistics = Statistics(mean, var, self.gamma, self.beta)

        return rescaled, reverse_fn, statistics

# 主函数,用于进行简单的测试
if __name__ == '__main__':

    # 创建 RevIN 实例
    rev_in = RevIN(512)

    # 生成随机输入
    x = torch.randn(2, 512, 1024)

    # 进行实例归一化并返回统计信息
    normalized, reverse_fn, statistics = rev_in(x, return_statistics = True)

    # 反向操作
    out = reverse_fn(normalized)

    # 断言输入和输出是否一致
    assert torch.allclose(x, out)

.\lucidrains\iTransformer\iTransformer\__init__.py

# 从 iTransformer 模块中导入 iTransformer 和 RevIN 类
# 从 iTransformer 模块中导入 iTransformer2D 类
# 从 iTransformer 模块中导入 iTransformerFFT 类
from iTransformer.iTransformer import (
    iTransformer,
    RevIN
)

from iTransformer.iTransformer2D import iTransformer2D

from iTransformer.iTransformerFFT import iTransformerFFT

iTransformer

Implementation of iTransformer - SOTA Time Series Forecasting using Attention networks, out of Tsinghua / Ant group

All that remains is tabular data (xgboost still champion here) before one can truly declare "Attention is all you need"

In before Apple gets the authors to change the name.

The official implementation has been released here!

Appreciation

  • StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source current artificial intelligence techniques.

  • Greg DeVos for sharing experiments he ran on iTransformer and some of the improvised variants

Install

$ pip install iTransformer

Usage

import torch
from iTransformer import iTransformer

# using solar energy settings

model = iTransformer(
    num_variates = 137,
    lookback_len = 96,                  # or the lookback length in the paper
    dim = 256,                          # model dimensions
    depth = 6,                          # depth
    heads = 8,                          # attention heads
    dim_head = 64,                      # head dimension
    pred_length = (12, 24, 36, 48),     # can be one prediction, or many
    num_tokens_per_variate = 1,         # experimental setting that projects each variate to more than one token. the idea is that the network can learn to divide up into time tokens for more granular attention across time. thanks to flash attention, you should be able to accommodate long sequence lengths just fine
    use_reversible_instance_norm = True # use reversible instance normalization, proposed here https://openreview.net/forum?id=cGDAkQo1C0p . may be redundant given the layernorms within iTransformer (and whatever else attention learns emergently on the first layer, prior to the first layernorm). if i come across some time, i'll gather up all the statistics across variates, project them, and condition the transformer a bit further. that makes more sense
)

time_series = torch.randn(2, 96, 137)  # (batch, lookback len, variates)

preds = model(time_series)

# preds -> Dict[int, Tensor[batch, pred_length, variate]]
#       -> (12: (2, 12, 137), 24: (2, 24, 137), 36: (2, 36, 137), 48: (2, 48, 137))

For an improvised version that does granular attention across time tokens (as well as the original per-variate tokens), just import iTransformer2D and set the additional num_time_tokens

Update: It works! Thanks goes out to Greg DeVos for running the experiment here!

Update 2: Got an email. Yes you are free to write a paper on this, if the architecture holds up for your problem. I have no skin in the game

import torch
from iTransformer import iTransformer2D

# using solar energy settings

model = iTransformer2D(
    num_variates = 137,
    num_time_tokens = 16,               # number of time tokens (patch size will be (look back length // num_time_tokens))
    lookback_len = 96,                  # the lookback length in the paper
    dim = 256,                          # model dimensions
    depth = 6,                          # depth
    heads = 8,                          # attention heads
    dim_head = 64,                      # head dimension
    pred_length = (12, 24, 36, 48),     # can be one prediction, or many
    use_reversible_instance_norm = True # use reversible instance normalization
)

time_series = torch.randn(2, 96, 137)  # (batch, lookback len, variates)

preds = model(time_series)

# preds -> Dict[int, Tensor[batch, pred_length, variate]]
#       -> (12: (2, 12, 137), 24: (2, 24, 137), 36: (2, 36, 137), 48: (2, 48, 137))

Experimental

iTransformer with fourier tokens

An iTransformer, but also with fourier tokens (the FFT of the time series is projected into tokens of its own and attended alongside the variate tokens, then spliced out at the end)

import torch
from iTransformer import iTransformerFFT

# using solar energy settings

model = iTransformerFFT(
    num_variates = 137,
    lookback_len = 96,                  # or the lookback length in the paper
    dim = 256,                          # model dimensions
    depth = 6,                          # depth
    heads = 8,                          # attention heads
    dim_head = 64,                      # head dimension
    pred_length = (12, 24, 36, 48),     # can be one prediction, or many
    num_tokens_per_variate = 1,         # experimental setting that projects each variate to more than one token. the idea is that the network can learn to divide up into time tokens for more granular attention across time. thanks to flash attention, you should be able to accommodate long sequence lengths just fine
    use_reversible_instance_norm = True # use reversible instance normalization, proposed here https://openreview.net/forum?id=cGDAkQo1C0p . may be redundant given the layernorms within iTransformer (and whatever else attention learns emergently on the first layer, prior to the first layernorm). if i come across some time, i'll gather up all the statistics across variates, project them, and condition the transformer a bit further. that makes more sense
)

time_series = torch.randn(2, 96, 137)  # (batch, lookback len, variates)

preds = model(time_series)

# preds -> Dict[int, Tensor[batch, pred_length, variate]]
#       -> (12: (2, 12, 137), 24: (2, 24, 137), 36: (2, 36, 137), 48: (2, 48, 137))

Todo

Citation

@misc{liu2023itransformer,
  title   = {iTransformer: Inverted Transformers Are Effective for Time Series Forecasting}, 
  author  = {Yong Liu and Tengge Hu and Haoran Zhang and Haixu Wu and Shiyu Wang and Lintao Ma and Mingsheng Long},
  year    = {2023},
  eprint  = {2310.06625},
  archivePrefix = {arXiv},
  primaryClass = {cs.LG}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{burtsev2020memory,
    title   = {Memory Transformer},
    author  = {Mikhail S. Burtsev and Grigory V. Sapunov},
    year    = {2020},
    eprint  = {2006.11527},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{Darcet2023VisionTN,
    title   = {Vision Transformers Need Registers},
    author  = {Timoth'ee Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:263134283}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@Article{AlphaFold2021,
    author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
    journal = {Nature},
    title   = {Highly accurate protein structure prediction with {AlphaFold}},
    year    = {2021},
    doi     = {10.1038/s41586-021-03819-2},
    note    = {(Accelerated article preview)},
}
@inproceedings{kim2022reversible,
    title   = {Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift},
    author  = {Taesung Kim and Jinhee Kim and Yunwon Tae and Cheonbok Park and Jang-Ho Choi and Jaegul Choo},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=cGDAkQo1C0p}
}
@inproceedings{Katsch2023GateLoopFD,
    title   = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
    author  = {Tobias Katsch},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:265018962}
}

.\lucidrains\iTransformer\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的信息
setup(
  name = 'iTransformer', # 包名
  packages = find_packages(exclude=[]), # 查找包
  version = '0.5.5', # 版本号
  license='MIT', # 许可证
  description = 'iTransformer - Inverted Transformers Are Effective for Time Series Forecasting', # 描述
  author = 'Phil Wang', # 作者
  author_email = 'lucidrains@gmail.com', # 作者邮箱
  long_description_content_type = 'text/markdown', # 长描述内容类型
  url = 'https://github.com/lucidrains/iTransformer', # URL
  keywords = [ # 关键词
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'time series forecasting'
  ],
  install_requires=[ # 安装依赖
    'beartype',
    'einops>=0.7.0',
    'gateloop-transformer>=0.2.3',
    'rotary-embedding-torch',
    'torch>=2.1',
  ],
  classifiers=[ # 分类
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\ITTR-pytorch\ITTR_pytorch\ITTR_pytorch.py

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, reduce, repeat

# 定义辅助函数

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def l2norm(t):
    return F.normalize(t, dim = -1)

# 定义辅助类

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class ChanLayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

# 定义类

class HPB(nn.Module):
    """ Hybrid Perception Block """

    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        ff_mult = 4,
        attn_height_top_k = 16,
        attn_width_top_k = 16,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()

        self.attn = DPSA(
            dim = dim,
            heads = heads,
            dim_head = dim_head,
            height_top_k = attn_height_top_k,
            width_top_k = attn_width_top_k,
            dropout = attn_dropout
        )

        self.dwconv = nn.Conv2d(dim, dim, 3, padding = 1, groups = dim)
        self.attn_parallel_combine_out = nn.Conv2d(dim * 2, dim, 1)

        ff_inner_dim = dim * ff_mult

        self.ff = nn.Sequential(
            nn.Conv2d(dim, ff_inner_dim, 1),
            nn.InstanceNorm2d(ff_inner_dim),
            nn.GELU(),
            nn.Dropout(ff_dropout),
            Residual(nn.Sequential(
                nn.Conv2d(ff_inner_dim, ff_inner_dim, 3, padding = 1, groups = ff_inner_dim),
                nn.InstanceNorm2d(ff_inner_dim),
                nn.GELU(),
                nn.Dropout(ff_dropout)
            )),
            nn.Conv2d(ff_inner_dim, dim, 1),
            nn.InstanceNorm2d(ff_inner_dim)
        )

    def forward(self, x):
        attn_branch_out = self.attn(x)
        conv_branch_out = self.dwconv(x)

        concatted_branches = torch.cat((attn_branch_out, conv_branch_out), dim = 1)
        attn_out = self.attn_parallel_combine_out(concatted_branches) + x

        return self.ff(attn_out)

class DPSA(nn.Module):
    """ Dual-pruned Self-attention Block """

    def __init__(
        self,
        dim,
        height_top_k = 16,
        width_top_k = 16,
        dim_head = 32,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head

        self.norm = ChanLayerNorm(dim)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)

        self.height_top_k = height_top_k
        self.width_top_k = width_top_k

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)
    # 定义前向传播函数,接受输入张量 x
    def forward(self, x):
        # 获取输入张量 x 的形状信息
        b, c, h, w = x.shape

        # 对输入张量 x 进行归一化处理
        x = self.norm(x)

        # 将输入张量 x 转换为查询、键、值张量
        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        # 将查询、键、值张量按照头数进行折叠
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) x y d', h = self.heads), (q, k, v))

        # 对查询、键进行 L2 归一化处理
        q, k = map(l2norm, (q, k))

        # 计算是否需要在高度和宽度上进行选择和排名
        need_height_select_and_rank = self.height_top_k < h
        need_width_select_and_rank = self.width_top_k < w

        # 用聚合后的查询作为探针,对键/值沿高度和宽度进行选择与排序
        if need_width_select_and_rank or need_height_select_and_rank:
            q_probe = reduce(q, 'b h w d -> b d', 'sum')

        # 沿着高度和宽度进行聚合
        if need_height_select_and_rank:
            k_height = reduce(k, 'b h w d -> b h d', 'sum')

            top_h_indices = einsum('b d, b h d -> b h', q_probe, k_height).topk(k = self.height_top_k, dim = -1).indices

            top_h_indices = repeat(top_h_indices, 'b h -> b h w d', d = self.dim_head, w = k.shape[-2])

            k, v = map(lambda t: t.gather(1, top_h_indices), (k, v)) # 首先沿着高度进行聚合

        if need_width_select_and_rank:
            k_width = reduce(k, 'b h w d -> b w d', 'sum')

            top_w_indices = einsum('b d, b w d -> b w', q_probe, k_width).topk(k = self.width_top_k, dim = -1).indices

            top_w_indices = repeat(top_w_indices, 'b w -> b h w d', d = self.dim_head, h = k.shape[1])

            k, v = map(lambda t: t.gather(2, top_w_indices), (k, v)) # 然后沿着宽度进行聚合

        # 选择适当的键和值
        q, k, v = map(lambda t: rearrange(t, 'b ... d -> b (...) d'), (q, k, v))

        # 计算余弦相似度
        sim = einsum('b i d, b j d -> b i j', q, k)

        # 注意力计算
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # 聚合输出
        out = einsum('b i j, b j d -> b i d', attn, v)

        # 合并头部并组合输出
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', x = h, y = w, h = self.heads)
        return self.to_out(out)

.\lucidrains\ITTR-pytorch\ITTR_pytorch\__init__.py

# 从 ITTR_pytorch.ITTR_pytorch 模块中导入 HPB 和 DPSA 类
from ITTR_pytorch.ITTR_pytorch import HPB, DPSA

ITTR - Pytorch

Implementation of the Hybrid Perception Block (HPB) and Dual-Pruned Self-Attention (DPSA) block from the ITTR paper for Image to Image Translation using Transformers.

Install

$ pip install ITTR-pytorch

Usage

They had 9 blocks of Hybrid Perception Block (HPB) in the paper

import torch
from ITTR_pytorch import HPB

block = HPB(
    dim = 512,              # dimension
    dim_head = 32,          # dimension per attention head
    heads = 8,              # number of attention heads
    attn_height_top_k = 16, # number of top indices to select along height, for the attention pruning
    attn_width_top_k = 16,  # number of top indices to select along width, for the attention pruning
    attn_dropout = 0.,      # attn dropout
    ff_mult = 4,            # expansion factor of feedforward
    ff_dropout = 0.         # feedforward dropout
)

fmap = torch.randn(1, 512, 32, 32)

out = block(fmap) # (1, 512, 32, 32)

You can also use the dual-pruned self-attention as so

import torch
from ITTR_pytorch import DPSA

attn = DPSA(
    dim = 512,         # dimension
    dim_head = 32,     # dimension per attention head
    heads = 8,         # number of attention heads
    height_top_k = 48, # number of top indices to select along height, for the attention pruning
    width_top_k = 48,  # number of top indices to select along width, for the attention pruning
    dropout = 0.       # attn dropout
)

fmap = torch.randn(1, 512, 32, 32)

out = attn(fmap) # (1, 512, 32, 32)

Citations

@inproceedings{Zheng2022ITTRUI,
  title   = {ITTR: Unpaired Image-to-Image Translation with Transformers},
  author  = {Wanfeng Zheng and Qiang Li and Guoxin Zhang and Pengfei Wan and Zhongyuan Wang},
  year    = {2022}
}

.\lucidrains\ITTR-pytorch\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的元信息
setup(
  name = 'ITTR-pytorch', # 包名
  packages = find_packages(exclude=[]), # 查找所有包
  version = '0.0.4', # 版本号
  license='MIT', # 许可证
  description = 'ITTR - Implementation of the Hybrid Perception Block and Dual-Pruned Self-Attention block', # 描述
  author = 'Phil Wang', # 作者
  author_email = 'lucidrains@gmail.com', # 作者邮箱
  url = 'https://github.com/lucidrains/ITTR-pytorch', # 项目链接
  keywords = [ # 关键词列表
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism'
  ],
  install_requires=[ # 安装依赖
    'einops>=0.4',
    'torch>=1.6',
  ],
  classifiers=[ # 分类器
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\jax2torch\jax2torch\jax2torch.py

# 导入需要的库
import torch
from torch.utils import dlpack as torch_dlpack

import jax
from jax import dlpack as jax_dlpack
import jax.numpy as jnp
from jax.tree_util import tree_map

from inspect import signature
from functools import wraps

# 定义将 JAX 数组转换为 PyTorch 张量的函数
def j2t(x_jax):
    x_torch = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x_jax))
    return x_torch

# 定义将 PyTorch 张量转换为 JAX 数组的函数
def t2j(x_torch):
    x_torch = x_torch.contiguous() # 保证张量是连续的
    x_jax = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(x_torch))
    return x_jax

# 定义将树状结构中的 PyTorch 张量转换为 JAX 数组的函数
def tree_t2j(x_torch):
    return tree_map(lambda t: t2j(t) if isinstance(t, torch.Tensor) else t, x_torch)

# 定义将树状结构中的 JAX 数组转换为 PyTorch 张量的函数
def tree_j2t(x_jax):
    return tree_map(lambda t: j2t(t) if isinstance(t, jnp.ndarray) else t, x_jax)

# 定义装饰器,将 JAX 函数转换为 PyTorch 函数
def jax2torch(fn):
    @wraps(fn)
    def inner(*args, **kwargs):
        # 定义一个继承自 torch.autograd.Function 的类
        class JaxFun(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                # 将输入参数转换为 JAX 数组
                args = tree_t2j(args)
                # 调用 JAX 的 vjp 函数计算函数值和梯度
                y_, ctx.fun_vjp = jax.vjp(fn, *args)
                # 将结果转换为 PyTorch 张量
                return tree_j2t(y_)

            @staticmethod
            def backward(ctx, *grad_args):
                # 将梯度参数转换为 JAX 数组
                grad_args = tree_t2j(grad_args) if len(grad_args) > 1 else t2j(grad_args[0])
                # 计算梯度
                grads = ctx.fun_vjp(grad_args)
                # 将梯度转换为 PyTorch 张量
                grads = tuple(map(lambda t: t if isinstance(t, jnp.ndarray) else None, grads))
                return tree_j2t(grads)

        # 获取函数的参数签名
        sig = signature(fn)
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # 调用 JaxFun 类的 apply 方法
        return JaxFun.apply(*bound.arguments.values())
    return inner

.\lucidrains\jax2torch\jax2torch\__init__.py

# 从 jax2torch.jax2torch 模块中导入 jax2torch 函数
from jax2torch.jax2torch import jax2torch

jax2torch

Use Jax functions in Pytorch with DLPack, as outlined in a gist by @mattjj. The repository was made for the purposes of making this differentiable alignment work interoperable with Pytorch projects.

Install

$ pip install jax2torch

Memory management

By default, Jax pre-allocates 90% of VRAM, which leaves Pytorch with very little left over. To prevent this behavior, set the XLA_PYTHON_CLIENT_PREALLOCATE environmental variable to false before running any Jax code:

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

Usage

Quick test

import jax
import torch
from jax2torch import jax2torch
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Jax function

@jax.jit
def jax_pow(x, y = 2):
  return x ** y

# convert to Torch function

torch_pow = jax2torch(jax_pow)

# run it on Torch data!

x = torch.tensor([1., 2., 3.])
y = torch_pow(x, y = 3)
print(y)  # tensor([1., 8., 27.])

# And differentiate!

x = torch.tensor([2., 3.], requires_grad = True)
y = torch.sum(torch_pow(x, y = 3))
y.backward()
print(x.grad) # tensor([12., 27.])

.\lucidrains\jax2torch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  # 包的名称
  name = 'jax2torch',
  # 查找所有包,不排除任何包
  packages = find_packages(exclude=[]),
  # 版本号
  version = '0.0.7',
  # 许可证类型
  license='MIT',
  # 描述信息
  description = 'Jax 2 Torch',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/jax2torch',
  # 关键词列表
  keywords = [
    'jax',
    'pytorch'
  ],
  # 安装依赖项
  install_requires=[
    'torch>=1.6',
    'jax>=0.2.20'
  ],
  # 分类器列表
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Kalman Filtering Attention (wip)

Implementation of the Kalman Filtering Attention proposed in Kalman Filtering Attention for User Behavior Modeling in CTR Prediction

Will use this repository as guidance. Looks like the core of Kalman filtering is just 5 lines of code.
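
For intuition, a minimal scalar Kalman measurement update could look like the sketch below (a hypothetical illustration only, not this repository's implementation; kalman_update and obs_var are made-up names, with obs_var standing in for an assumed observation-noise variance):

import torch

def kalman_update(mean, var, obs, obs_var):
    # kalman gain: how much to trust the new observation
    gain = var / (var + obs_var)
    # correct the prior mean towards the observation
    new_mean = mean + gain * (obs - mean)
    # posterior variance shrinks by the gain
    new_var = (1. - gain) * var
    return new_mean, new_var

mean, var = torch.tensor(0.), torch.tensor(1.)
for obs in torch.tensor([0.9, 1.1, 1.0]):
    mean, var = kalman_update(mean, var, obs, obs_var = torch.tensor(0.5))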

Citations

@inproceedings{NEURIPS2020_68ce199e,
    author = {Liu, Hu and LU, Jing and Zhao, Xiwei and Xu, Sulong and Peng, Hao and Liu, Yutong and Zhang, Zehua and Li, Jian and Jin, Junsheng and Bao, Yongjun and Yan, Weipeng},
    booktitle = {Advances in Neural Information Processing Systems},
    editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
    pages = {9228--9238},
    publisher = {Curran Associates, Inc.},
    title = {Kalman Filtering Attention for User Behavior Modeling in CTR Prediction},
    url = {https://proceedings.neurips.cc/paper_files/paper/2020/file/68ce199ec2c5517597ce0a4d89620f55-Paper.pdf},
    volume = {33},
    year = {2020}
}

.\lucidrains\kronecker-attention-pytorch\kronecker_attention_pytorch\kronecker_attention_pytorch.py

import torch
from torch import nn, einsum
from einops import rearrange, repeat
import torch.nn.functional as F

class KroneckerSelfAttention(nn.Module):
    def __init__(self, dim, heads, dim_heads = 32):
        super().__init__()
        hidden_dim = heads * dim_heads

        self.heads = heads
        # 定义将输入转换为查询、键、值的卷积层
        self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias = False)
        # 定义将输出转换为原始维度的卷积层
        self.to_out = nn.Conv1d(hidden_dim, dim, 1)

    def forward(self, x):
        h = x.shape[-2]

        # 沿着最后两个维度对输入进行平均并拼接
        x = torch.cat((x.mean(dim=-1), x.mean(dim=-2)), dim=-1)

        # 将输入通过查询、键、值的卷积层
        qkv = self.to_qkv(x)
        # 重新排列查询、键、值的维度
        q, k, v = rearrange(qkv, 'b (qkv h d) n -> qkv b h d n', h=self.heads, qkv=3)
        
        # 计算点积注意力
        dots = einsum('bhdi,bhdj->bhij', q, k)
        # 对注意力进行 softmax 操作
        attn = dots.softmax(dim=-1)
        # 计算输出
        out = einsum('bhij,bhdj->bhdi', attn, v)
        
        # 重新排列输出的维度
        out = rearrange(out, 'b h d n -> b (h d) n')
        # 将输出通过输出转换卷积层
        out = self.to_out(out)

        # 对输出进行外部求和操作
        out = rearrange(out[..., :h], 'b c (n 1) -> b c n 1') + rearrange(out[..., h:], 'b c (1 n) -> b c 1 n')
        return out
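
# 直观理解(注释补充):注意力只在长度为 H + W 的"均值序列"上计算,
# 输出再通过行向量与列向量的外和(Kronecker 和)展开回 (H, W) 特征图,
# 注意力的复杂度由 O((HW)^2) 降为 O((H + W)^2)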

.\lucidrains\kronecker-attention-pytorch\kronecker_attention_pytorch\__init__.py

# 从 kronecker_attention_pytorch 模块中导入 KroneckerSelfAttention 类
from kronecker_attention_pytorch.kronecker_attention_pytorch import KroneckerSelfAttention

Kronecker Attention Pytorch

Implementation of Kronecker Attention in Pytorch. Results look less than stellar, but if someone found some context where this architecture works well, please post in the issues and let everyone know.

Install

$ pip install kronecker_attention_pytorch

Usage

import torch
from kronecker_attention_pytorch import KroneckerSelfAttention

attn = KroneckerSelfAttention(
    dim = 32,
    heads = 8,
    dim_heads = 64
)

x = torch.randn(1, 32, 256, 512)
attn(x) # (1, 32, 256, 512)

Citations

@article{Gao_2020,
   title={Kronecker Attention Networks},
   url={http://dx.doi.org/10.1145/3394486.3403065},
   journal={Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining},
   publisher={ACM},
   author={Gao, Hongyang and Wang, Zhengyang and Ji, Shuiwang},
   year={2020},
   month={Aug}
}

.\lucidrains\kronecker-attention-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'kronecker-attention-pytorch',  # 包的名称
  packages = find_packages(),  # 查找所有包
  version = '0.0.6',  # 版本号
  license='MIT',  # 许可证
  description = 'Kronecker Attention - Pytorch',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/kronecker-attention-pytorch',  # 项目链接
  keywords = [
    'artificial intelligence',  # 关键词:人工智能
    'attention mechanism'  # 关键词:注意力机制
  ],
  install_requires=[
    'torch',  # 安装依赖:torch
    'einops>=0.3'  # 安装依赖:einops 版本大于等于0.3
  ],
  classifiers=[
    'Development Status :: 4 - Beta',  # 分类:开发状态为Beta
    'Intended Audience :: Developers',  # 分类:面向的受众为开发者
    'Topic :: Scientific/Engineering :: Artificial Intelligence',  # 分类:主题为科学/工程 - 人工智能
    'License :: OSI Approved :: MIT License',  # 分类:许可证为MIT
    'Programming Language :: Python :: 3.6',  # 分类:编程语言为Python 3.6
  ],
)

.\lucidrains\lambda-networks\lambda_networks\lambda_networks.py

import torch
from torch import nn, einsum
from einops import rearrange

# helpers functions

# 检查值是否存在
def exists(val):
    return val is not None

# 如果值存在则返回该值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 计算相对位置编码
def calc_rel_pos(n):
    # 生成网格坐标
    pos = torch.meshgrid(torch.arange(n), torch.arange(n))
    # 重新排列坐标
    pos = rearrange(torch.stack(pos), 'n i j -> (i j) n')  # [n*n, 2] pos[n] = (i, j)
    # 计算相对位置
    rel_pos = pos[None, :] - pos[:, None]                  # [n*n, n*n, 2] rel_pos[n, m] = (rel_i, rel_j)
    rel_pos += n - 1                                       # 将值范围从[-n+1, n-1]转换为[0, 2n-2]
    return rel_pos
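
# 举例(n = 2 的假设性小例子):网格上共 4 个位置,calc_rel_pos(2) 的输出形状为 (4, 4, 2),
# 其中 rel_pos[a, b] = (i_b - i_a + 1, j_b - j_a + 1),每个分量取值于 {0, 1, 2},
# 恰好可作为 (2n-1) x (2n-1) 相对位置嵌入表的二维索引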

# lambda layer

class LambdaLayer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_k,
        n = None,
        r = None,
        heads = 4,
        dim_out = None,
        dim_u = 1):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.u = dim_u # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        dim_v = dim_out // heads

        # 定义卷积层
        self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
        self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)

        # 定义归一化层
        self.norm_q = nn.BatchNorm2d(dim_k * heads)
        self.norm_v = nn.BatchNorm2d(dim_v * dim_u)

        # 检查是否存在局部上下文
        self.local_contexts = exists(r)
        if exists(r):
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
        else:
            assert exists(n), 'You must specify the window size (n=h=w)'
            rel_lengths = 2 * n - 1
            self.rel_pos_emb = nn.Parameter(torch.randn(rel_lengths, rel_lengths, dim_k, dim_u))
            self.rel_pos = calc_rel_pos(n)

    def forward(self, x):
        b, c, hh, ww, u, h = *x.shape, self.u, self.heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.norm_q(q)
        v = self.norm_v(v)

        q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
        k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
        v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)

        k = k.softmax(dim=-1)

        λc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b h v n', q, λc)

        if self.local_contexts:
            v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
            λp = self.pos_conv(v)
            Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
        else:
            n, m = self.rel_pos.unbind(dim = -1)
            rel_pos_emb = self.rel_pos_emb[n, m]
            λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b h v n', q, λp)

        Y = Yc + Yp
        out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
        return out

.\lucidrains\lambda-networks\lambda_networks\tfkeras.py

import tensorflow as tf
from einops.layers.tensorflow import Rearrange
from tensorflow.keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
from tensorflow.keras import initializers
from tensorflow import einsum, nn, meshgrid

# 导入所需的库

# helpers functions

def exists(val):
    return val is not None

# 检查值是否存在

def default(val, d):
    return val if exists(val) else d

# 如果值存在则返回该值,否则返回默认值

def calc_rel_pos(n):
    pos = tf.stack(meshgrid(tf.range(n), tf.range(n), indexing = 'ij'))
    pos = Rearrange('n i j -> (i j) n')(pos)             # 重新排列位置信息
    rel_pos = pos[None, :] - pos[:, None]                # 计算相对位置
    rel_pos += n - 1                                     # 调整值范围
    return rel_pos

# 计算相对位置信息

# lambda layer

class LambdaLayer(Layer):
    def __init__(
        self,
        *,
        dim_k,
        n = None,
        r = None,
        heads = 4,
        dim_out = None,
        dim_u = 1):
        super(LambdaLayer, self).__init__()

        self.out_dim = dim_out
        self.u = dim_u  # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        self.dim_v = dim_out // heads
        self.dim_k = dim_k
        self.heads = heads

        self.to_q = Conv2D(self.dim_k * heads, 1, use_bias=False)
        self.to_k = Conv2D(self.dim_k * dim_u, 1, use_bias=False)
        self.to_v = Conv2D(self.dim_v * dim_u, 1, use_bias=False)

        self.norm_q = BatchNormalization()
        self.norm_v = BatchNormalization()

        self.local_contexts = exists(r)
        if exists(r):
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = Conv3D(dim_k, (1, r, r), padding='same')
        else:
            assert exists(n), 'You must specify the window length (n = h = w)'
            rel_length = 2 * n - 1
            self.rel_pos_emb = self.add_weight(name='pos_emb',
                                               shape=(rel_length, rel_length, dim_k, dim_u),
                                               initializer=initializers.random_normal,
                                               trainable=True)
            self.rel_pos = calc_rel_pos(n)

    # 初始化 LambdaLayer 类

    def call(self, x, **kwargs):
        b, hh, ww, c, u, h = *x.get_shape().as_list(), self.u, self.heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.norm_q(q)
        v = self.norm_v(v)

        q = Rearrange('b hh ww (h k) -> b h k (hh ww)', h=h)(q)
        k = Rearrange('b hh ww (u k) -> b u k (hh ww)', u=u)(k)
        v = Rearrange('b hh ww (u v) -> b u v (hh ww)', u=u)(v)

        k = nn.softmax(k)

        Lc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b n h v', q, Lc)

        if self.local_contexts:
            v = Rearrange('b u v (hh ww) -> b v hh ww u', hh=hh, ww=ww)(v)
            Lp = self.pos_conv(v)
            Lp = Rearrange('b v h w k -> b v k (h w)')(Lp)
            Yp = einsum('b h k n, b v k n -> b n h v', q, Lp)
        else:
            rel_pos_emb = tf.gather_nd(self.rel_pos_emb, self.rel_pos)
            Lp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b n h v', q, Lp)

        Y = Yc + Yp
        out = Rearrange('b (hh ww) h v -> b hh ww (h v)', hh = hh, ww = ww)(Y)
        return out

    # 调用 LambdaLayer 类

    def compute_output_shape(self, input_shape):
        return (*input_shape[:2], self.out_dim)

    # 计算输出形状

    def get_config(self):
        config = {'output_dim': (*self.input_shape[:2], self.out_dim)}
        base_config = super(LambdaLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    # 获取配置信息

.\lucidrains\lambda-networks\lambda_networks\__init__.py

# 从 lambda_networks 模块中导入 LambdaLayer 类
from lambda_networks.lambda_networks import LambdaLayer
# 将 LambdaLayer 类赋值给 λLayer 变量
λLayer = LambdaLayer

Lambda Networks - Pytorch

Implementation of λ Networks, a new approach to image recognition that reaches SOTA on ImageNet. The new method utilizes λ layer, which captures interactions by transforming contexts into linear functions, termed lambdas, and applying these linear functions to each input separately.

Yannic Kilcher's paper review

Install

$ pip install lambda-networks

Usage

Global context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,       # channels going in
    dim_out = 32,   # channels out
    n = 64,         # size of the receptive window - max(height, width)
    dim_k = 16,     # key dimension
    heads = 4,      # number of heads, for multi-query
    dim_u = 1       # 'intra-depth' dimension
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)

Localized context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,
    dim_out = 32,
    r = 23,         # the receptive field for relative positional encoding (23 x 23)
    dim_k = 16,
    heads = 4,
    dim_u = 4
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)

For fun, you can also import this as follows

from lambda_networks import λLayer

Tensorflow / Keras version

Shinel94 has added a Keras implementation! It won't be officially supported in this repository, so either copy / paste the code under ./lambda_networks/tfkeras.py or make sure to install tensorflow and keras before running the following.

import tensorflow as tf
from lambda_networks.tfkeras import LambdaLayer

layer = LambdaLayer(
    dim_out = 32,
    r = 23,
    dim_k = 16,
    heads = 4,
    dim_u = 1
)

x = tf.random.normal((1, 64, 64, 16)) # channel last format
layer(x) # (1, 64, 64, 32)

Citations

@inproceedings{
    anonymous2021lambdanetworks,
    title={LambdaNetworks: Modeling long-range Interactions without Attention},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=xTJEN-ggl1b},
    note={under review}
}

.\lucidrains\lambda-networks\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的信息
setup(
  name = 'lambda-networks', # 包的名称
  packages = find_packages(), # 查找所有包
  version = '0.4.0', # 版本号
  license='MIT', # 许可证
  description = 'Lambda Networks - Pytorch', # 描述
  author = 'Phil Wang', # 作者
  author_email = 'lucidrains@gmail.com', # 作者邮箱
  url = 'https://github.com/lucidrains/lambda-networks', # 项目链接
  keywords = [
    'artificial intelligence', # 关键词:人工智能
    'attention mechanism', # 关键词:注意力机制
    'image recognition' # 关键词:图像识别
  ],
  install_requires=[
    'torch>=1.6', # 安装所需的 torch 版本
    'einops>=0.3' # 安装所需的 einops 版本
  ],
  classifiers=[
    'Development Status :: 4 - Beta', # 分类:开发状态为 Beta
    'Intended Audience :: Developers', # 分类:面向的受众为开发者
    'Topic :: Scientific/Engineering :: Artificial Intelligence', # 分类:主题为科学/工程 - 人工智能
    'License :: OSI Approved :: MIT License', # 分类:许可证为 MIT
    'Programming Language :: Python :: 3.6', # 分类:编程语言为 Python 3.6
  ],
)

.\lucidrains\learning-to-expire-pytorch\learning_to_expire_pytorch\learning_to_expire_pytorch.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch.nn.functional 中导入 F 模块
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 collections 模块中导入 namedtuple 类
from collections import namedtuple

# 定义一个命名元组 Memory,包含 mems 和 elapsed_times 两个字段
Memory = namedtuple('Memory', ['mems', 'elapsed_times'])

# 辅助函数

# 判断变量是否存在
def exists(val):
    return val is not None

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 安全地拼接张量
def safe_cat(tensors, dim = -1):
    tensors = list(filter(exists, tensors))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim = dim)

# 安全地对张量进行加法操作
def safe_add(tensor, n):
    if not exists(tensor):
        return None
    return tensor + n

# 位置嵌入

# 相对位移函数
def rel_shift(t):
    b, h, i, j, device, dtype = *t.shape, t.device, t.dtype
    zero_pad = torch.zeros((b, h, i, 1), device = device, dtype = dtype)
    concatted = torch.cat([zero_pad, t], dim = -1)
    shifted = concatted.view(b, h, j + 1, i)[:, :, 1:]
    return shifted.view_as(t)
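
# 说明(注释补充):这是 Transformer-XL 的相对位移技巧,通过在最后一维左侧补一列零、
# 变形再切片,把按"相对距离"索引的分数对齐成按"绝对位置对 (i, j)"索引,
# 避免为每个位置对显式构造相对位置分数矩阵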

# 正弦嵌入类
class SinusoidalEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        n, device = x.shape[1], x.device
        t = torch.arange(n - 1, -1, -1, device = device).type_as(self.inv_freq)
        sinusoid_inp = einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim = -1)
        return emb

# 过期时间跨度逻辑

# 过期时间跨度类
class ExpireSpan(nn.Module):
    def __init__(self, dim, max_mem_len, ramp_length):
        super().__init__()
        self.max_mem_len = max_mem_len
        self.ramp_length = ramp_length
        self.to_expiration = nn.Linear(dim, 1)
        nn.init.constant_(self.to_expiration.bias.data, val = -self.max_mem_len)

    def forward(self, mem, time, seq_len):
        exps = self.to_expiration(mem).squeeze(-1).sigmoid() * self.max_mem_len
        exps = rearrange(exps, 'b j -> b () () j')
        t = rearrange(time, 'b j -> b () () j')
        r = F.pad(exps - t, (0, seq_len), value = 1.)
        mask = torch.clamp((r / self.ramp_length) + 1, min = 0., max = 1.)
        return exps, mask
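
# 数值示例(假设 max_mem_len = 8, ramp_length = 2):若某条记忆的预测过期时间
# exps = 5,已经过时间 t = 6,则剩余 r = exps - t = -1,掩码值为 clamp(-1/2 + 1, 0, 1) = 0.5,
# 即该记忆处于软过期区间:注意力被衰减一半,且会贡献 L1 辅助损失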

# 类

# 预层归一化类
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# 前馈神经网络类
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

# 因果注意力类
class CausalAttention(nn.Module):
    def __init__(self, dim, heads = 8):
        super().__init__()
        dim_head = dim // heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_pos = nn.Linear(dim, dim_head)
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.to_out = nn.Linear(dim, dim)
    # 定义一个前向传播函数,接受输入 x,位置编码 pos_emb,记忆 mem,默认为 None,过期掩码 expire_mask,默认为 None
    def forward(self, x, pos_emb, mem = None, expire_mask = None):
        # 获取输入 x 的维度信息:n 为序列长度,h 为头数,scale 为缩放因子,device 为设备信息
        n, h, scale, device = x.shape[1], self.heads, self.scale, x.device

        # 将输入 x 转换为查询向量 q
        q = self.to_q(x)

        # 如果存在记忆 mem,则获取其长度,否则记忆长度为 0
        mem_len = mem.shape[1] if exists(mem) else 0
        # 将记忆 mem 和输入 x 拼接在一起,形成上下文 context
        context = safe_cat((mem, x), dim = 1)

        # 将上下文 context 转换为键值对 kv,并按键值对拆分为 k 和 v
        kv = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv))

        # 计算点积注意力得分 dots
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale

        # 计算相对位置贡献
        pos = self.to_pos(pos_emb)
        pos_dots = einsum('b h i d, j d -> b h i j', q, pos) * scale
        pos_dots = rel_shift(pos_dots)
        pos_dots = F.pad(pos_dots, (mem_len, 0), value = 0)
        dots += pos_dots

        # 生成因果掩码
        mask = torch.ones(dots.shape[-2:], device = device).triu_(mem_len + 1).bool()
        mask = rearrange(mask, 'i j -> () () i j')
        dots.masked_fill_(mask, float('-inf'))
        del mask

        # 计算注意力权重
        attn = dots.softmax(dim = -1)

        # 如果存在过期掩码,则将注意力权重乘以过期掩码
        if exists(expire_mask):
            attn  = attn * expire_mask

        # 计算输出
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# 定义一个名为 ExpireSpanTransformerXL 的类,继承自 nn.Module
class ExpireSpanTransformerXL(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        *,
        num_tokens,  # 标记的数量
        dim,  # 向量的维度
        depth,  # 模型的深度
        seq_len,  # 序列的长度
        heads = 8,  # 多头注意力机制的头数,默认为 8
        num_memory_blocks = 10,  # 记忆块的数量,默认为 10
        expire_loss_coef = 1e-6,  # 过期损失系数,默认为 1e-6
        ramp_length = 128):  # 渐变长度,默认为 128
        super().__init__()  # 调用父类的初始化函数
        # 创建一个标记嵌入层,将标记映射到指定维度的向量
        self.token_emb = nn.Embedding(num_tokens, dim)
        # 创建一个正弦嵌入层,用于添加正弦位置编码

        self.sinusoidal_emb = SinusoidalEmbedding(dim)

        self.dim = dim  # 将维度赋值给类属性
        self.depth = depth  # 将深度赋值给类属性
        self.seq_len = seq_len  # 将序列长度赋值给类属性
        self.max_mem_len = num_memory_blocks * seq_len  # 计算最大记忆长度

        self.expire_loss_coef = expire_loss_coef  # 将过期损失系数赋值给类属性

        self.layers = nn.ModuleList([])  # 创建一个空的模块列表
        # 循环创建深度次数的层,并添加到模块列表中
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                ExpireSpan(dim, self.max_mem_len, ramp_length),  # 添加过期跨度模块
                PreNorm(dim, CausalAttention(dim, heads = heads)),  # 添加预归一化的因果注意力模块
                PreNorm(dim, FeedForward(dim)),  # 添加预归一化的前馈神经网络模块
            ]))

        self.to_logits = nn.Linear(dim, num_tokens)  # 创建一个线性层,将输出维度映射到标记数量
    # 定义前向传播函数,接受输入 x 和记忆 memory,默认为 None
    def forward(self, x, memory = None):
        # 获取输入 x 的形状信息,包括 batch 大小 b,序列长度 n,维度 d,设备信息 device
        b, n, d, device = *x.shape, self.dim, x.device
        # 对输入 x 进行 token embedding
        x = self.token_emb(x)
        # 生成位置编码
        pos_emb = self.sinusoidal_emb(x)

        hidden_states = []
        expire_masks_layers = []
        # 如果存在记忆,则获取记忆中的 mems 和 elapsed_times,否则初始化为 None
        mems_layers = memory.mems if exists(memory) else ((None,) * self.depth)
        times_layers = memory.elapsed_times if exists(memory) else ((None,) * self.depth)
        # 初始化辅助损失为 0
        aux_loss = torch.tensor(0., requires_grad = True)

        # 遍历每个层的记忆和时间信息,以及每个层的注意力和前馈网络
        for (mem, time, (expire_span, attn, ff)) in zip(mems_layers, times_layers, self.layers):
            hidden_states.append(x)

            # 计算过期时间和过期掩码
            exps, expire_mask = expire_span(mem, time, seq_len = n) if exists(mem) else (None, None)
            expire_masks_layers.append(expire_mask)

            # 训练模式下,根据时间信息生成遗忘掩码
            if self.training and exists(time):
                forget_time_thres = torch.randint(0, self.max_mem_len, (b, 1), device = device)
                forget_dropout_mask = (time < forget_time_thres).float()
                forget_dropout_mask = rearrange(forget_dropout_mask, 'b n -> b () () n')
                forget_dropout_mask = F.pad(forget_dropout_mask, (0, n), value = 1.)
                expire_mask *= forget_dropout_mask

            # 执行注意力和前馈网络操作
            x = attn(x, pos_emb = pos_emb, mem = mem, expire_mask = expire_mask) + x
            x = ff(x) + x

            if exists(exps):
                # 计算辅助损失,仅对产生软掩码值的过期进行 L1 辅助损失
                expiring_exps_mask = (expire_mask > 0) & (expire_mask < 1.)
                expiring_exps = exps.masked_select(expiring_exps_mask[..., :-n])
                aux_loss = aux_loss + (expiring_exps / self.seq_len).sum() * self.expire_loss_coef

        # 生成最终的 logits
        logits = self.to_logits(x)

        # 如果序列长度等于 n
        if self.seq_len == n:
            if exists(expire_mask):
                mems_layers_new = []
                times_layers_new = []

                # 遍历每个层的记忆、时间和过期掩码信息
                for mems, times, expire_mask in zip(mems_layers, times_layers, expire_masks_layers):
                    expire_mask = rearrange(expire_mask, 'b () () i -> b i')
                    # 丢弃已过期的记忆
                    expired_exps_mask = (expire_mask <= 0)[..., :-n]
                    # 各批次无法过期不同数量的记忆,这里取各批次中过期数量的最小值
                    num_to_expire = int(expired_exps_mask.sum(dim = -1).min())
                    _, indices = expired_exps_mask.float().topk(k = num_to_expire, dim = -1)
                    even_expired_exps_mask = torch.zeros_like(expired_exps_mask, device = device).scatter(-1, indices, 1.).bool()

                    mems = mems.masked_select(~even_expired_exps_mask.unsqueeze(-1))
                    mems = mems.reshape(b, -1, d)
                    mems_layers_new.append(mems)

                    times = times.masked_select(~even_expired_exps_mask)
                    times = times.reshape(b, -1)
                    times_layers_new.append(times)

                mems_layers = mems_layers_new
                times_layers = times_layers_new

            # 更新记忆和时间信息
            new_memories = map(lambda t: safe_cat(t, dim = 1), list(zip(mems_layers, hidden_states)))
            new_memories = map(lambda t: t[:, -self.max_mem_len:].detach(), new_memories)

            new_times = torch.arange(n - 1, -1, -1, device = device)
            new_times = repeat(new_times, 'n -> b n', b = b)
            new_elapsed_times = map(lambda t: safe_cat((safe_add(t, n), new_times), dim = 1), times_layers)
            new_elapsed_times = map(lambda t: t[:, -self.max_mem_len:], new_elapsed_times)

            memory = Memory(list(new_memories), list(new_elapsed_times))

        # 返回 logits、memory 和辅助损失
        return logits, memory, aux_loss

.\lucidrains\learning-to-expire-pytorch\learning_to_expire_pytorch\__init__.py

# 从 learning_to_expire_pytorch.learning_to_expire_pytorch 模块中导入 ExpireSpanTransformerXL 类
from learning_to_expire_pytorch.learning_to_expire_pytorch import ExpireSpanTransformerXL

Learning to Expire - Pytorch (wip)

An implementation of Transformer with Expire-Span, a proposed technique for learning which memories to retain for long-range learning in attention-based networks.
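
A minimal usage sketch, inferred from the constructor and forward signatures in the walkthrough above (the repository is a work in progress, so the API may change):

import torch
from learning_to_expire_pytorch import ExpireSpanTransformerXL

model = ExpireSpanTransformerXL(
    num_tokens = 20000,      # vocabulary size
    dim = 512,
    depth = 6,
    seq_len = 512,
    heads = 8,
    num_memory_blocks = 10
)

x = torch.randint(0, 20000, (1, 512))

logits, memory, aux_loss = model(x)                   # first segment, no memory yet
logits, memory, aux_loss = model(x, memory = memory)  # next segment, carrying memories forward

# aux_loss should be added to the main cross entropy loss when training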

Citations

@inproceedings{
    anonymous2021not,
    title={Not All Memories are Created Equal: Learning to Expire},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=ZVBtN6B_6i7},
    note={under review}
}

.\lucidrains\learning-to-expire-pytorch\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的信息
setup(
  # 包名
  name = 'learning-to-expire-pytorch',
  # 查找包,排除 examples 文件夹
  packages = find_packages(exclude=['examples']),
  # 版本号
  version = '0.0.2',
  # 许可证
  license='MIT',
  # 描述
  description = 'Learning to Expire - Pytorch',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/learning-to-expire-pytorch',
  # 关键词
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'transformers',
    'memory'
  ],
  # 安装依赖
  install_requires=[
    'torch>=1.6',
    'einops>=0.3'
  ],
  # 分类
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\lie-transformer-pytorch\lie_transformer_pytorch\lie_transformer_pytorch.py

# 导入数学库
import math
# 从 functools 库中导入 partial 函数
from functools import partial
# 导入 PyTorch 库
import torch
import torch.nn.functional as F
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 lie_transformer_pytorch.se3 模块中导入 SE3 类
from lie_transformer_pytorch.se3 import SE3
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 lie_transformer_pytorch.reversible 模块中导入 SequentialSequence 和 ReversibleSequence 类

# helpers

# 定义函数,判断变量是否存在
def exists(val):
    return val is not None

# 定义函数,将变量转换为元组
def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else ((val,) * depth)

# 定义函数,返回默认值
def default(val, d):
    return val if exists(val) else d

# 定义函数,对张量进行批量索引选择
def batched_index_select(values, indices, dim = 1):
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    dim += value_expand_len
    return values.gather(dim, indices)
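
# 用法示意(假设性示例):values 形状为 (2, 10, 64),indices 形状为 (2, 4),
# 则 batched_index_select(values, indices, dim = 1) 返回形状为 (2, 4, 64) 的张量,
# 即在每个批次内按 indices 选取第 1 维上的元素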

# helper classes

# 定义 Pass 类,用于对输入进行处理
class Pass(nn.Module):
    def __init__(self, fn, dim = 1):
        super().__init__()
        self.fn = fn
        self.dim = dim

    def forward(self,x):
        dim = self.dim
        xs = list(x)
        xs[dim] = self.fn(xs[dim])
        return xs

# 定义 Lambda 类,用于对输入进行处理
class Lambda(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)

# 定义 GlobalPool 类,用于在掩码内的所有空间位置(及群元素)上对值进行归约
class GlobalPool(nn.Module):
    def __init__(self, mean = False):
        super().__init__()
        self.mean = mean

    def forward(self, x):
        coords, vals, mask = x

        if not exists(mask):
            return vals.mean(dim = 1)

        masked_vals = vals.masked_fill_(~mask[..., None], 0.)
        summed = masked_vals.sum(dim = 1)

        if not self.mean:
            return summed

        count = mask.sum(-1).unsqueeze(-1)
        return summed / count

# subsampling code

# 定义 FPSindices 函数,用于根据距离矩阵和掩码进行下采样
def FPSindices(dists, frac, mask):
    """ inputs: pairwise distances DISTS (bs,n,n), downsample_frac (float), valid atom mask (bs,n)
        outputs: chosen_indices (bs,m) """
    m = int(round(frac * dists.shape[1]))
    bs, n, device = *dists.shape[:2], dists.device
    dd_kwargs = {'device': device, 'dtype': torch.long}
    B = torch.arange(bs, **dd_kwargs)

    chosen_indices = torch.zeros(bs, m, **dd_kwargs)
    distances = torch.ones(bs, n, device=device) * 1e8
    a = torch.randint(0, n, (bs,), **dd_kwargs)            # choose random start
    idx = a % mask.sum(-1) + torch.cat([torch.zeros(1, **dd_kwargs), torch.cumsum(mask.sum(-1), dim=0)[:-1]], dim=0)
    farthest = torch.where(mask)[1][idx]

    for i in range(m):
        chosen_indices[:, i] = farthest                    # add point that is farthest to chosen
        dist = dists[B, farthest].masked_fill(~mask, -100) # (bs,n) compute distance from new point to all others
        closer = dist < distances                          # if dist from new point is smaller than chosen points so far
        distances[closer] = dist[closer]                   # update the chosen set's distance to all other points
        farthest = torch.max(distances, -1)[1]             # select the point that is farthest from the set

    return chosen_indices
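
# 直观理解(注释补充):最远点采样(FPS)先随机选取一个有效起点,然后迭代地
# 将"距离已选集合最远"的点加入集合,共选出 m = round(frac * n) 个点,
# 使子采样后的点集尽可能均匀地覆盖原有点集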

# 定义 FPSsubsample 类,用于进行 FPS 下采样
class FPSsubsample(nn.Module):
    def __init__(self, ds_frac, cache = False, group = None):
        super().__init__()
        self.ds_frac = ds_frac
        self.cache = cache
        self.cached_indices = None
        self.group = default(group, SE3())
    # 获取查询索引,根据是否启用缓存和缓存文件是否存在来决定是否重新计算
    def get_query_indices(self, abq_pairs, mask):
        # 如果启用缓存并且缓存文件存在,则直接返回缓存的查询索引
        if self.cache and exists(self.cached_indices):
            return self.cached_indices

        # 定义距离函数,如果存在分组则使用分组的距离函数,否则使用默认的 L2 范数
        dist = self.group.distance if self.group else lambda ab: ab.norm(dim=-1)
        # 根据下采样比例和掩码计算 FPS 索引,并将结果从计算图中分离
        value = FPSindices(dist(abq_pairs), self.ds_frac, mask).detach()

        # 如果启用缓存,则将计算得到的索引值缓存起来
        if self.cache:
            self.cached_indices = value

        # 返回计算得到的索引值
        return value

    def forward(self, inp, withquery = False):
        # unpack the input
        abq_pairs, vals, mask, edges = inp
        device = vals.device

        if self.ds_frac != 1:
            # get the query indices and subsample everything at them
            query_idx = self.get_query_indices(abq_pairs, mask)

            B = torch.arange(query_idx.shape[0], device = device).long()[:, None]
            subsampled_abq_pairs = abq_pairs[B, query_idx][B, :, query_idx]
            subsampled_values = batched_index_select(vals, query_idx, dim = 1)
            subsampled_mask = batched_index_select(mask, query_idx, dim = 1)
            subsampled_edges = edges[B, query_idx][B, :, query_idx] if exists(edges) else None
        else:
            # a subsampling fraction of 1 means no subsampling
            subsampled_abq_pairs = abq_pairs
            subsampled_values = vals
            subsampled_mask = mask
            subsampled_edges = edges
            query_idx = None

        ret = (
            subsampled_abq_pairs,
            subsampled_values,
            subsampled_mask,
            subsampled_edges
        )

        # optionally return the query indices as well
        if withquery:
            ret = (*ret, query_idx)

        return ret

# LieSelfAttention, equivariant self-attention over a lifted point cloud
class LieSelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        edge_dim = None,
        group = None,
        mc_samples = 32,
        ds_frac = 1,
        fill = 1 / 3,
        dim_head = 64,
        heads = 8,
        cache = False
    ):
        super().__init__()
        self.dim = dim

        # number of samples for Monte Carlo estimation of the convolution
        self.mc_samples = mc_samples
        # group to be equivariant to, LieConv style
        self.group = default(group, SE3())
        # local neighborhood radius, set by fill
        self.register_buffer('r', torch.tensor(2.))
        # average fraction of the input that falls within the local neighborhood, determines r
        self.fill_frac = min(fill, 1.)

        # FPS subsampler
        self.subsample = FPSsubsample(ds_frac, cache = cache, group = self.group)
        # internal coefficient for updating r
        self.coeff = .5
        # running average of the fill fraction, for logging only
        self.fill_frac_ema = fill

        # attention related parameters
        inner_dim = dim_head * heads
        self.heads = heads

        # projections for queries, keys, values and output
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        edge_dim = default(edge_dim, 0)
        edge_dim_in = self.group.lie_dim + edge_dim

        # MLP for the location-based attention bias
        self.loc_attn_mlp = nn.Sequential(
            nn.Linear(edge_dim_in, edge_dim_in * 4),
            nn.ReLU(),
            nn.Linear(edge_dim_in * 4, 1),
        )

    # extract the local neighborhood
    def extract_neighborhood(self, inp, query_indices):
        """ inputs: [pairs_abq (bs,n,n,d), inp_vals (bs,n,c), mask (bs,n), query_indices (bs,m)]
            outputs: [neighbor_abq (bs,m,mc_samples,d), neighbor_vals (bs,m,mc_samples,c)]"""

        # unpack the input
        pairs_abq, inp_vals, mask, edges = inp
        device = inp_vals.device

        # subsample pairs_abq, mask and edges at the query indices, if given
        if exists(query_indices):
            abq_at_query = batched_index_select(pairs_abq, query_indices, dim = 1)
            mask_at_query = batched_index_select(mask, query_indices, dim = 1)
            edges_at_query = batched_index_select(edges, query_indices, dim = 1) if exists(edges) else None
        else:
            abq_at_query = pairs_abq
            mask_at_query = mask
            edges_at_query = edges

        mask_at_query = mask_at_query[..., None]

        vals_at_query = inp_vals
        dists = self.group.distance(abq_at_query)
        mask_value = torch.finfo(dists.dtype).max
        # fill invalid positions with the maximum distance so they are never selected
        dists = dists.masked_fill(~mask[:, None, :], mask_value)

        k = min(self.mc_samples, inp_vals.shape[1])

        # sample points from within the distance ball
        bs, m, n = dists.shape
        within_ball = (dists < self.r) & mask[:,None,:] & mask_at_query
        noise = torch.zeros((bs, m, n), device = device).uniform_(0, 1)
        valid_within_ball, nbhd_idx = torch.topk(within_ball + noise, k, dim=-1, sorted=False)
        valid_within_ball = (valid_within_ball > 1)

        # gather the abq_pairs, values and mask at the neighborhood locations
        nbhd_abq = batched_index_select(abq_at_query, nbhd_idx, dim = 2)
        nbhd_vals = batched_index_select(vals_at_query, nbhd_idx, dim = 1)
        nbhd_mask = batched_index_select(mask, nbhd_idx, dim = 1)
        nbhd_edges = batched_index_select(edges_at_query, nbhd_idx, dim = 2) if exists(edges) else None

        # if training, update the ball radius r to match fill_frac
        if self.training:
            navg = (within_ball.float()).sum(-1).sum() / mask_at_query.sum()
            avg_fill = (navg / mask.sum(-1).float().mean()).cpu().item()
            self.r +=  self.coeff * (self.fill_frac - avg_fill)
            self.fill_frac_ema += .1 * (avg_fill-self.fill_frac_ema)

        nbhd_mask &= valid_within_ball.bool()

        return nbhd_abq, nbhd_vals, nbhd_mask, nbhd_edges, nbhd_idx

    def forward(self, inp):
        """ inputs: [pairs_abq (bs,n,n,d)], [inp_vals (bs,n,ci)], [query_indices (bs,m)]
            outputs: [subsampled_abq (bs,m,m,d)], [convolved_vals (bs,m,co)] """
        # subsample the input, keeping the query indices
        sub_abq, sub_vals, sub_mask, sub_edges, query_indices = self.subsample(inp, withquery = True)
        # extract the local neighborhood around each query
        nbhd_abq, nbhd_vals, nbhd_mask, nbhd_edges, nbhd_indices = self.extract_neighborhood(inp, query_indices)

        h, b, n, d, device = self.heads, *sub_vals.shape, sub_vals.device

        # project the subsampled values to queries, the neighborhood values to keys and values
        q, k, v = self.to_q(sub_vals), self.to_k(nbhd_vals), self.to_v(nbhd_vals)

        # split out the attention heads
        q = rearrange(q, 'b n (h d) -> b h n d', h = h)
        k, v = map(lambda t: rearrange(t, 'b n m (h d) -> b h n m d', h = h), (k, v))

        # attention similarity
        sim = einsum('b h i d, b h i j d -> b h i j', q, k) * (q.shape[-1] ** -0.5)

        # concat edge information onto the pair embeddings, if present
        edges = nbhd_abq
        if exists(nbhd_edges):
            edges = torch.cat((nbhd_abq, nbhd_edges), dim = -1)

        # location-based attention bias from the MLP
        loc_attn = self.loc_attn_mlp(edges)
        loc_attn = rearrange(loc_attn, 'b i j () -> b () i j')
        sim = sim + loc_attn

        mask_value = -torch.finfo(sim.dtype).max

        # mask out invalid neighborhood positions
        sim.masked_fill_(~rearrange(nbhd_mask, 'b n m -> b () n m'), mask_value)

        # attention
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h i j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        combined = self.to_out(out)

        # return the subsampled abq pairs, the attended values, the mask and the edges
        return sub_abq, combined, sub_mask, sub_edges
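
One detail worth calling out from extract_neighborhood above: the topk over within_ball + noise samples k in-ball points without replacement. In-ball entries score in (1, 2) and out-of-ball entries in (0, 1), so topk picks random in-ball points first, and the > 1 test flags which of the k picks were genuinely inside the ball. A standalone sketch:

import torch

within_ball = torch.tensor([[True, False, True, True, False]])
noise = torch.rand(1, 5)

scores, idx = torch.topk(within_ball + noise, k = 2, dim = -1)
valid = scores > 1  # True where the sampled index was actually inside the ball
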
class LieSelfAttentionWrapper(nn.Module):
    # wraps LieSelfAttention with pre-layernorm and a residual connection
    def __init__(self, dim, attn):
        super().__init__()
        self.dim = dim
        self.attn = attn

        self.net = nn.Sequential(
            Pass(nn.LayerNorm(dim)),  # pre-layernorm on the values
            self.attn
        )

    def forward(self, inp):
        sub_coords, sub_values, mask, edges = self.attn.subsample(inp)
        new_coords, new_values, mask, edges = self.net(inp)
        new_values[..., :self.dim] += sub_values
        return new_coords, new_values, mask, edges

class FeedForward(nn.Module):
    # feedforward network
    def __init__(self, dim, mult = 4):
        super().__init__()
        self.dim = dim

        self.net = nn.Sequential(
            Pass(nn.LayerNorm(dim)),           # pre-layernorm
            Pass(nn.Linear(dim, mult * dim)),  # project up
            Pass(nn.GELU()),                   # GELU activation
            Pass(nn.Linear(mult * dim, dim)),  # project back down
        )

    def forward(self,inp):
        sub_coords, sub_values, mask, edges = inp
        new_coords, new_values, mask, edges = self.net(inp)
        new_values = new_values + sub_values
        return new_coords, new_values, mask, edges

# transformer class

class LieTransformer(nn.Module):
    """
    [Fill] specifies the fraction of the input which is included in local neighborhood.
            (can be array to specify a different value for each layer)
    [nbhd] number of samples to use for Monte Carlo estimation (p)
    [dim] number of input channels: 1 for MNIST, 3 for RGB images, other for non images
    [ds_frac] total downsampling to perform throughout the layers of the net. In (0,1)
    [num_layers] number of BottleNeck Block layers in the network
    [k] channel width for the network. Can be int (same for all) or array to specify individually.
    [liftsamples] number of samples to use in lifting. 1 for all groups with trivial stabilizer. Otherwise 2+
    [Group] Chosen group to be equivariant to.
    """
    def __init__(
        self,
        dim,
        num_tokens = None,
        num_edge_types = None,
        edge_dim = None,
        heads = 8,
        dim_head = 64,
        depth = 2,
        ds_frac = 1.,
        dim_out = None,
        k = 1536,
        nbhd = 128,
        mean = True,
        per_point = True,
        liftsamples = 4,
        fill = 1 / 4,
        cache = False,
        reversible = False,
        **kwargs
    ):
        super().__init__()
        assert not (ds_frac < 1 and reversible), 'must not downsample if network is reversible'

        dim_out = default(dim_out, dim)
        self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
        self.edge_emb = nn.Embedding(num_edge_types, edge_dim) if exists(num_edge_types) else None

        group = SE3()
        self.group = group
        self.liftsamples = liftsamples

        layers_fill = cast_tuple(fill, depth)
        layers = nn.ModuleList([])

        for _, layer_fill in zip(range(depth), layers_fill):
            layers.append(nn.ModuleList([
                LieSelfAttentionWrapper(dim, LieSelfAttention(dim, heads = heads, dim_head = dim_head, edge_dim = edge_dim, mc_samples = nbhd, ds_frac = ds_frac, group = group, fill = layer_fill, cache = cache, **kwargs)),
                FeedForward(dim)
            ]))

        execute_class = ReversibleSequence if reversible else SequentialSequence
        self.net = execute_class(layers)

        self.to_logits = nn.Sequential(
            Pass(nn.LayerNorm(dim)),       # final layernorm
            Pass(nn.Linear(dim, dim_out))  # project to the output dimension
        )

        self.pool = GlobalPool(mean = mean)  # global pooling

    def forward(self, feats, coors, edges = None, mask = None, return_pooled = False):
        b, n, *_ = feats.shape

        # embed the features, if a token embedding exists
        if exists(self.token_emb):
            feats = self.token_emb(feats)

        # embed the edges, if an edge embedding exists
        if exists(self.edge_emb):
            assert exists(edges), 'edges must be passed in on forward'
            assert edges.shape[1] == edges.shape[2] and edges.shape[1] == n, f'edges must be of the shape ({b}, {n}, {n})'
            edges = self.edge_emb(edges)

        inps = (coors, feats, mask, edges)

        # lift the inputs onto the group, then run them through the network
        lifted_x = self.group.lift(inps, self.liftsamples)
        out = self.net(lifted_x)

        out = self.to_logits(out)

        # if pooling is not requested, return the per-location features
        if not return_pooled:
            features = out[1]
            return features

        # otherwise return the pooled output
        return self.pool(out)

.\lucidrains\lie-transformer-pytorch\lie_transformer_pytorch\reversible.py

import torch
import torch.nn as nn
from torch.autograd.function import Function
# get_device_states and set_device_states are needed to save and replay RNG state across devices
from torch.utils.checkpoint import get_device_states, set_device_states

# helper functions

# sum y into the given dimension of tuple x
def sum_tuple(x, y, dim = 1):
    x = list(x)
    x[dim] += y[dim]
    return tuple(x)

# subtract y from the given dimension of tuple x
def subtract_tuple(x, y, dim = 1):
    x = list(x)
    x[dim] -= y[dim]
    return tuple(x)

# set the given dimension of tuple x to a value
def set_tuple(x, dim, value):
    x = list(x).copy()
    x[dim] = value
    return tuple(x)

# apply a function to the given dimension of tuple x
def map_tuple(fn, x, dim = 1):
    x = list(x)
    x[dim] = fn(x[dim])
    return tuple(x)

# chunk the given dimension of tuple x, returning one tuple per chunk
def chunk_tuple(fn, x, dim = 1):
    x = list(x)
    value = x[dim]
    chunks = fn(value)
    return tuple(map(lambda t: set_tuple(x, 1, t), chunks))

# concatenate the given dimension of tuples x and y along cat_dim
def cat_tuple(x, y, dim = 1, cat_dim = -1):
    x = list(x)
    y = list(y)
    x[dim] = torch.cat((x[dim], y[dim]), dim = cat_dim)
    return tuple(x)

# delete the elements of tuple x
def del_tuple(x):
    for el in x:
        if el is not None:
            del el
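
These helpers exist because every layer threads a (coords, values, mask, edges) tuple through the network, while residuals and chunking should only touch the values at position 1. A tiny sketch with stand-in entries:

import torch

x = ('coords', torch.ones(2, 3), 'mask', None)
y = ('coords', torch.full((2, 3), 2.), 'mask', None)

summed = sum_tuple(x, y, dim = 1)        # values become 1 + 2, everything else taken from x
doubled = map_tuple(lambda t: t * 2, x)  # only the values tensor is doubled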

# saves and restores random number generator state, following the example at https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# reversible block, inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send PR back to the source
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        training = self.training
        x1, x2 = chunk_tuple(lambda t: torch.chunk(t, 2, dim=2), x)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = sum_tuple(self.f(x2, record_rng = training, **f_args), x1)
            y2 = sum_tuple(self.g(y1, record_rng = training, **g_args), x2)

        return cat_tuple(y1, y2, cat_dim = 2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        y1, y2 = chunk_tuple(lambda t: torch.chunk(t, 2, dim=2), y)
        del_tuple(y)

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        with torch.enable_grad():
            y1[1].requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1[1], dy2)

        with torch.no_grad():
            x2 = subtract_tuple(y2, gy1)
            del_tuple(y2)
            del gy1

            dx1 = dy1 + y1[1].grad
            del dy1
            y1[1].grad = None

        with torch.enable_grad():
            x2[1].requires_grad = True
            fx2 = self.f(x2, set_rng = True, **f_args)
            torch.autograd.backward(fx2[1], dx1)

        with torch.no_grad():
            x1 = subtract_tuple(y1, fx2)
            del fx2
            del_tuple(y1)

            dx2 = dy2 + x2[1].grad
            del dy2
            x2[1].grad = None

            x2 = map_tuple(lambda t: t.detach(), x2)
            x = cat_tuple(x1, x2, cat_dim = -1)
            dx = torch.cat((dx1, dx2), dim=2)

        return x, dx
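
The additive coupling above is what makes the block reversible without storing activations: given the outputs, the inputs can be recomputed exactly. A plain-tensor sketch of the identity (the real block applies it to the values slot of the tuple):

import torch

f = lambda t: t * 3.  # stand-ins for the attention and feedforward blocks
g = lambda t: t + 1.

x1, x2 = torch.randn(4), torch.randn(4)

# forward coupling
y1 = x1 + f(x2)
y2 = x2 + g(y1)

# exact inversion, no stored activations needed
x2_rec = y2 - g(y1)
x1_rec = y1 - f(x2_rec)

assert torch.allclose(x1, x1_rec) and torch.allclose(x2, x2_rec)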

class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        ctx.kwargs = kwargs
        # reassemble the (coords, values, mask, edges) tuple around the values
        x = (kwargs.pop('coords'), x, kwargs.pop('mask'), kwargs.pop('edges'))

        # run each reversible block
        for block in blocks:
            x = block(x, **kwargs)

        # save the outputs for the backward pass, detaching the values
        ctx.y = map_tuple(lambda t: t.detach(), x, dim=1)
        ctx.blocks = blocks
        # return only the values
        return x[1]

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        kwargs = ctx.kwargs

        # walk the blocks in reverse, reconstructing activations and gradients
        for block in ctx.blocks[::-1]:
            y, dy = block.backward_pass(y, dy, **kwargs)
        return dy, None, None

class SequentialSequence(nn.Module):
    # executes the attention / feedforward blocks in order, with residual connections
    def __init__(self, blocks):
        super().__init__()
        self.blocks = blocks

    def forward(self, x):
        for (f, g) in self.blocks:
            # residual connections on dimension 1 (the values) of the input tuple
            x = sum_tuple(f(x), x, dim = 1)
            x = sum_tuple(g(x), x, dim = 1)
        return x

class ReversibleSequence(nn.Module):
    # sequence of reversible blocks whose activations are reconstructed during the backward pass
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks])

    def forward(self, x, **kwargs):
        # duplicate the values along the last dimension, forming the two streams of the reversible net
        x = map_tuple(lambda t: torch.cat((t, t), dim = -1), x)

        blocks = self.blocks

        coords, values, mask, edges = x
        kwargs = {'coords': coords, 'mask': mask, 'edges': edges, **kwargs}
        x = _ReversibleFunction.apply(values, blocks, kwargs)

        x = (coords, x, mask, edges)
        # split the values back into the two streams and average them
        return map_tuple(lambda t: sum(t.chunk(2, dim = -1)) * 0.5, x)

.\lucidrains\lie-transformer-pytorch\lie_transformer_pytorch\se3.py

from math import pi
import torch
from functools import wraps
from torch import acos, atan2, cos, sin
from einops import rearrange, repeat

# constants
THRES = 7e-2

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return a tensor's device and dtype as kwargs
def to(t):
    return {'device': t.device, 'dtype': t.dtype}

# decorator that switches fn to its Taylor expansion near x = 0
def taylor(thres):
    def outer(fn):
        @wraps(fn)
        def inner(x):
            usetaylor = x.abs() < thres
            taylor_expanded, full = fn(x, x * x)
            return torch.where(usetaylor, taylor_expanded, full)
        return inner
    return outer

# helper functions for analytic exponential maps, using Taylor expansions near x = 0
# see http://ethaneade.com/lie_groups.pdf for derivations

# Taylor expansion of sin(x)/x
@taylor(THRES)
def sinc(x, x2):
    """ sin(x)/x """
    texpand = 1-x2/6*(1-x2/20*(1-x2/42))
    full = sin(x) / x
    return texpand, full

# Taylor expansion of (1 - sinc(x)) / x^2
@taylor(THRES)
def sincc(x, x2):
    """ (1-sinc(x))/x^2"""
    texpand = 1/6*(1-x2/20*(1-x2/42*(1-x2/72)))
    full = (x-sin(x)) / x**3
    return texpand, full

# Taylor expansion of (1 - cos(x)) / x^2
@taylor(THRES)
def cosc(x, x2):
    """ (1-cos(x))/x^2"""
    texpand = 1/2*(1-x2/12*(1-x2/30*(1-x2/56)))
    full = (1-cos(x)) / x2
    return texpand, full

# Taylor expansion of (1 - x sin(x) / (2 (1 - cos(x)))) / x^2
@taylor(THRES)
def coscc(x, x2):
    texpand = 1/12*(1+x2/60*(1+x2/42*(1+x2/40)))
    costerm = (2*(1-cos(x))).clamp(min=1e-6)
    full = (1-x*sin(x)/costerm) / x2
    return texpand, full

# Taylor expansion of x / sin(x)
@taylor(THRES)
def sinc_inv(x, _):
    texpand = 1+(1/6)*x**2 +(7/360)*x**4
    full = x / sin(x)
    assert not torch.any(torch.isinf(texpand) | torch.isnan(texpand)), 'sinc_inv Taylor expansion contains inf or nan'
    return texpand, full
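
A quick numeric check of the Taylor switchover in sinc: below the threshold the expansion is used, so the function stays finite at zero, where the naive form produces NaN:

import torch

x = torch.tensor([0., 5e-2, 1.0])
print(sinc(x))           # ≈ [1.0000, 0.9996, 0.8415], finite at 0 via the expansion
print(torch.sin(x) / x)  # ≈ [   nan, 0.9996, 0.8415], naive form is undefined at 0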

# Lie groups acting on R3

# Hodge star operator on R3
def cross_matrix(k):
    """Application of hodge star on R3, mapping Λ^1 R3 -> Λ^2 R3"""
    K = torch.zeros(*k.shape[:-1], 3, 3, **to(k))
    K[...,0,1] = -k[...,2]
    K[...,0,2] = k[...,1]
    K[...,1,0] = k[...,2]
    K[...,1,2] = -k[...,0]
    K[...,2,0] = -k[...,1]
    K[...,2,1] = k[...,0]
    return K

# inverse of the Hodge star operator
def uncross_matrix(K):
    """Application of hodge star on R3, mapping Λ^2 R3 -> Λ^1 R3"""
    k = torch.zeros(*K.shape[:-1], **to(K))
    k[...,0] = (K[...,2,1] - K[...,1,2])/2
    k[...,1] = (K[...,0,2] - K[...,2,0])/2
    k[...,2] = (K[...,1,0] - K[...,0,1])/2
    return k
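
cross_matrix is the hat map: multiplying by cross_matrix(k) is the same as taking the cross product with k, and uncross_matrix recovers k. A quick check:

import torch

k, v = torch.randn(3), torch.randn(3)
K = cross_matrix(k)

assert torch.allclose(K @ v, torch.linalg.cross(k, v), atol = 1e-6)
assert torch.allclose(uncross_matrix(K), k, atol = 1e-6)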

# SO3 group
class SO3:
    lie_dim = 3
    rep_dim = 3
    q_dim = 1

    def __init__(self, alpha = .2):
        super().__init__()
        self.alpha = alpha
    
    # exponential map
    def exp(self,w):
        """ Computes (matrix) exponential Lie algebra elements (in a given basis).
            ie out = exp(\sum_i a_i A_i) where A_i are the exponential generators of G.
            Input: [a (*,lie_dim)] where * is arbitrarily shaped
            Output: [exp(a) (*,rep_dim,rep_dim)] returns the matrix for each."""

        """ Rodriguez's formula, assuming shape (*,3)
            where components 1,2,3 are the generators for xrot,yrot,zrot"""
        theta = w.norm(dim=-1)[..., None, None]
        K = cross_matrix(w)
        I = torch.eye(3, **to(K))
        Rs = I + K * sinc(theta) + (K @ K) * cosc(theta)
        return Rs
    
    # logarithm map
    def log(self,R):
        """ Computes components in terms of generators rx,ry,rz. Shape (*,3,3)"""

        """ Computes (matrix) logarithm for collection of matrices and converts to Lie algebra basis.
            Input [u (*,rep_dim,rep_dim)]
            Output [coeffs of log(u) in basis (*,d)] """
        trR = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
        costheta = ((trR-1) / 2).clamp(max=1, min=-1).unsqueeze(-1)
        theta = acos(costheta)
        logR = uncross_matrix(R) * sinc_inv(theta)
        return logR

    # inverse of group elements
    def inv(self,g):
        """ We can compute the inverse of elements g (*,rep_dim,rep_dim) as exp(-log(g))"""
        return self.exp(-self.log(g))
    def elems2pairs(self, a):
        """ Computes log(e^-b e^a) for all a b pairs along the n dimension of the input.
            Input: [a (bs,n,d)] Output: [pairs_ab (bs,n,n,d)] """
        vinv = self.exp(-a.unsqueeze(-3))
        u = self.exp(a.unsqueeze(-2))
        # ((bs,1,n,d) -> (bs,1,n,r,r)) @ ((bs,n,1,d) -> (bs,n,1,r,r))
        return self.log(vinv @ u)

    def lift(self, x, nsamples, **kwargs):
        """ Assumes p has shape (*,n,2), vals has shape (*,n,c), mask has shape (*,n)
            returns (a,v) with shapes [(*,n*nsamples,lie_dim),(*,n*nsamples,c)] """
        p, v, m, e = x
        # expand p into (bs,n*ns,d) and (bs,n*ns,qd)
        expanded_a = self.lifted_elems(p, nsamples, **kwargs)
        nsamples = expanded_a.shape[-2] // m.shape[-1]
        # expand v and the mask in the same way
        expanded_v = repeat(v, 'b n c -> b (n m) c', m = nsamples)
        expanded_mask = repeat(m, 'b n -> b (n m)', m = nsamples)
        expanded_e = repeat(e, 'b n1 n2 c -> b (n1 m1) (n2 m2) c', m1 = nsamples, m2 = nsamples) if exists(e) else None

        # convert from elems to pairs
        paired_a = self.elems2pairs(expanded_a)  # (bs,n*ns,d) -> (bs,n*ns,n*ns,d)
        embedded_locations = paired_a
        return (embedded_locations, expanded_v, expanded_mask, expanded_e)
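
Before moving on to SE3, a sanity-check sketch for the SO3 maps above: exp should produce orthonormal rotation matrices, and log should invert it for rotation angles below pi:

import torch

group = SO3()
w = torch.randn(2, 3) * 0.5  # small algebra elements, |w| < pi

R = group.exp(w)  # (2, 3, 3) rotation matrices
assert torch.allclose(R @ R.transpose(-1, -2), torch.eye(3).expand(2, 3, 3), atol = 1e-5)
assert torch.allclose(group.log(R), w, atol = 1e-4)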

class SE3(SO3):
    lie_dim = 6
    rep_dim = 4
    q_dim = 0

    def __init__(self, alpha = .2, per_point = True):
        super().__init__()
        self.alpha = alpha
        self.per_point = per_point

    def exp(self, w):
        dd_kwargs = to(w)

        # rotation part, via Rodrigues' formula in the parent class
        theta = w[..., :3].norm(dim = -1)[..., None, None]
        K = cross_matrix(w[..., :3])
        R = super().exp(w[..., :3])

        # translation part, using the V matrix
        I = torch.eye(3, **dd_kwargs)
        V = I + cosc(theta) * K + sincc(theta) * (K @ K)

        # assemble the 4x4 homogeneous transform
        U = torch.zeros(*w.shape[:-1], 4, 4, **dd_kwargs)
        U[..., :3, :3] = R
        U[..., :3, 3] = (V @ w[..., 3:].unsqueeze(-1)).squeeze(-1)
        U[..., 3, 3] = 1
        return U

    def log(self, U):
        # rotation part, from the parent class
        w = super().log(U[..., :3, :3])
        I = torch.eye(3, **to(w))
        K = cross_matrix(w[..., :3])
        theta = w.norm(dim = -1)[..., None, None]  # % (2 * pi)
        # translation part, using the inverse of the V matrix
        cosccc = coscc(theta)
        Vinv = I - K / 2 + cosccc * (K @ K)
        u = (Vinv @ U[..., :3, 3].unsqueeze(-1)).squeeze(-1)
        # concatenate the rotational and translational components
        return torch.cat([w, u], dim = -1)

    def lifted_elems(self, pt, nsamples):
        """ pt (bs,n,D) mask (bs,n), per_point specifies whether to
            use a different group element per atom in the molecule"""
        # same lifts for each point right now
        bs, n = pt.shape[:2]
        dd_kwargs = to(pt)

        # sample random unit quaternions
        q = torch.randn(bs, (n if self.per_point else 1), nsamples, 4, **dd_kwargs)
        q /= q.norm(dim = -1).unsqueeze(-1)

        # convert the quaternions to so3 elements, pad with zero translations to get se3 elements
        theta_2 = atan2(q[..., 1:].norm(dim = -1), q[..., 0])[..., None]
        so3_elem = 2 * sinc_inv(theta_2) * q[..., 1:]
        se3_elem = torch.cat([so3_elem, torch.zeros_like(so3_elem)], dim = -1)
        R = self.exp(se3_elem)

        # build homogeneous translation matrices from the points
        T = torch.zeros(bs, n, nsamples, 4, 4, **dd_kwargs)
        T[..., :, :] = torch.eye(4, **dd_kwargs)
        T[..., :3, 3] = pt[..., None, :]

        # compose translation with rotation and map back to the Lie algebra
        a = self.log(T @ R)
        return a.reshape(bs, n * nsamples, 6)

    def distance(self, abq_pairs):
        # rotational distance
        dist_rot = abq_pairs[..., :3].norm(dim = -1)
        # translational distance
        dist_trans = abq_pairs[..., 3:].norm(dim = -1)
        # weighted sum: alpha * rotational + (1 - alpha) * translational
        return dist_rot * self.alpha + (1 - self.alpha) * dist_trans
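
The distance above blends the rotational and translational components: with the default alpha = .2, the translational offset dominates neighborhood selection. A small sketch on lifted pair embeddings:

import torch

group = SE3(alpha = 0.2)
abq = torch.randn(1, 8, 8, 6)  # (bs, n, n, lie_dim) pair embeddings, as produced by lift

d = group.distance(abq)  # (1, 8, 8) = 0.2 * rotational norm + 0.8 * translational norm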

.\lucidrains\lie-transformer-pytorch\lie_transformer_pytorch\__init__.py

# import the LieTransformer class from the lie_transformer_pytorch module
from lie_transformer_pytorch.lie_transformer_pytorch import LieTransformer

Lie Transformer - Pytorch

Implementation of Lie Transformer, Equivariant Self-Attention, in Pytorch. Only the SE3 version will be present in this repository, as it may be needed for Alphafold2 replication.

Install

$ pip install lie-transformer-pytorch

Usage

import torch
from lie_transformer_pytorch import LieTransformer

model = LieTransformer(
    dim = 512,
    depth = 2,
    heads = 8,
    dim_head = 64,
    liftsamples = 4
)

coors = torch.randn(1, 64, 3)
features = torch.randn(1, 64, 512)
mask = torch.ones(1, 64).bool()

out = model(features, coors, mask = mask) # (1, 256, 512) <- 256 = (seq len * liftsamples)

To let the Lie Transformer take care of embedding the features, just specify the number of unique tokens (node types).

import torch
from lie_transformer_pytorch import LieTransformer

model = LieTransformer(
    num_tokens = 28,           # say 28 different types of atoms
    dim = 512,
    depth = 2,
    heads = 8,
    dim_head = 64,
    liftsamples = 4
)

atoms = torch.randint(0, 28, (1, 64))
coors = torch.randn(1, 64, 3)
mask = torch.ones(1, 64).bool()

out = model(atoms, coors, mask = mask) # (1, 256, 512) <- 256 = (seq len * liftsamples)

Although it was not in the paper, I decided to allow for passing in edge information as well (bond types). The edge information will be embedded to the specified dimension, concatenated with the pairwise locations, and passed through the MLP before being summed with the attention matrix.

Simply set two more keyword arguments on initialization of the transformer, and then pass in the specific bond types as shape b x seq x seq.

import torch
from lie_transformer_pytorch import LieTransformer

model = LieTransformer(
    num_tokens = 28,           # say 28 different types of atoms
    num_edge_types = 4,        # number of different edge types
    edge_dim = 16,             # dimension of edges
    dim = 512,
    depth = 2,
    heads = 8,
    dim_head = 64,
    liftsamples = 4
)

atoms = torch.randint(0, 28, (1, 64))
bonds = torch.randint(0, 4, (1, 64, 64))
coors = torch.randn(1, 64, 3)
mask = torch.ones(1, 64).bool()

out = model(atoms, coors, edges = bonds, mask = mask) # (1, 256, 512) <- 256 = (seq len * liftsamples)

Credit

This repository is largely adapted from LieConv, cited below

Citations

@misc{hutchinson2020lietransformer,
    title       = {LieTransformer: Equivariant self-attention for Lie Groups}, 
    author      = {Michael Hutchinson and Charline Le Lan and Sheheryar Zaidi and Emilien Dupont and Yee Whye Teh and Hyunjik Kim},
    year        = {2020},
    eprint      = {2012.10885},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{finzi2020generalizing,
    title   = {Generalizing Convolutional Neural Networks for Equivariance to Lie Groups on Arbitrary Continuous Data}, 
    author  = {Marc Finzi and Samuel Stanton and Pavel Izmailov and Andrew Gordon Wilson},
    year    = {2020},
    eprint  = {2002.12880},
    archivePrefix = {arXiv},
    primaryClass = {stat.ML}
}

.\lucidrains\lie-transformer-pytorch\setup.py

# import setup and package discovery utilities
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'lie-transformer-pytorch', # package name
  packages = find_packages(), # find all packages
  version = '0.0.17', # version number
  license='MIT', # license
  description = 'Lie Transformer - Pytorch', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  url = 'https://github.com/lucidrains/lie-transformer-pytorch', # project url
  keywords = [ # keywords
    'artificial intelligence',
    'attention mechanism',
    'transformers',
    'equivariance',
    'lifting',
    'lie groups'
  ],
  install_requires=[ # runtime dependencies
    'torch>=1.6',
    'einops>=0.3'
  ],
  setup_requires=[ # setup dependencies
    'pytest-runner',
  ],
  tests_require=[ # test dependencies
    'pytest'
  ],
  classifiers=[ # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\lie-transformer-pytorch\tests.py

import torch
# import the LieTransformer class from the lie_transformer_pytorch library
from lie_transformer_pytorch import LieTransformer

# test the LieTransformer class
def test_transformer():
    # create a LieTransformer with dimension 512 and depth 1
    model = LieTransformer(
        dim = 512,
        depth = 1
    )

    feats = torch.randn(1, 64, 512)
    coors = torch.randn(1, 64, 3)
    mask = torch.ones(1, 64).bool()

    out = model(feats, coors, mask = mask)
    # default liftsamples is 4, so the output sequence length is 64 * 4 = 256
    assert out.shape == (1, 256, 512), 'transformer runs'

.\lucidrains\lightweight-gan\lightweight_gan\cli.py

# import required libraries
import os
import fire
import random
from retry.api import retry_call
from tqdm import tqdm
from datetime import datetime
from functools import wraps
from lightweight_gan import Trainer, NanException
from lightweight_gan.diff_augment_test import DiffAugmentTest

import torch
import torch.multiprocessing as mp
import torch.distributed as dist

import numpy as np

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# cast an element to a list
def cast_list(el):
    return el if isinstance(el, list) else [el]

# generate a timestamped filename
def timestamped_filename(prefix = 'generated-'):
    now = datetime.now()
    timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
    return f'{prefix}{timestamp}'

# seed all random number generators for reproducibility
def set_seed(seed):
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

# run the training process
def run_training(rank, world_size, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash):
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

        print(f"{rank + 1}/{world_size} process initialized.")

    model_args.update(
        is_ddp = is_ddp,
        rank = rank,
        world_size = world_size
    )

    model = Trainer(**model_args, hparams=model_args, use_aim=use_aim, aim_repo=aim_repo, aim_run_hash=aim_run_hash)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    model.set_data_src(data)

    progress_bar = tqdm(initial = model.steps, total = num_train_steps, mininterval=10., desc=f'{name}<{data}>')
    while model.steps < num_train_steps:
        retry_call(model.train, tries=3, exceptions=NanException)
        progress_bar.n = model.steps
        progress_bar.refresh()
        if is_main and model.steps % 50 == 0:
            model.print_log()

    model.save(model.checkpoint_num)

    if is_ddp:
        dist.destroy_process_group()

# train a model from a folder of images
def train_from_folder(
    data = './data',
    results_dir = './results',
    models_dir = './models',
    name = 'default',
    new = False,
    load_from = -1,
    image_size = 256,
    optimizer = 'adam',
    fmap_max = 512,
    transparent = False,
    greyscale = False,
    batch_size = 10,
    gradient_accumulate_every = 4,
    num_train_steps = 150000,
    learning_rate = 2e-4,
    save_every = 1000,
    evaluate_every = 1000,
    generate = False,
    generate_types = ['default', 'ema'],
    generate_interpolation = False,
    aug_test = False,
    aug_prob=None,
    aug_types=['cutout', 'translation'],
    dataset_aug_prob=0.,
    attn_res_layers = [32],
    freq_chan_attn = False,
    disc_output_size = 1,
    dual_contrast_loss = False,
    antialias = False,
    interpolation_num_steps = 100,
    save_frames = False,
    num_image_tiles = None,
    num_workers = None,
    multi_gpus = False,
    calculate_fid_every = None,
    calculate_fid_num_images = 12800,
    clear_fid_cache = False,
    seed = 42,
    amp = False,
    show_progress = False,
    use_aim = False,
    aim_repo = None,
    aim_run_hash = None,
    load_strict = True
):
    num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)
    # assemble the model arguments
    model_args = dict(
        name = name,
        results_dir = results_dir,
        models_dir = models_dir,
        batch_size = batch_size,
        gradient_accumulate_every = gradient_accumulate_every,
        attn_res_layers = cast_list(attn_res_layers),
        freq_chan_attn = freq_chan_attn,
        disc_output_size = disc_output_size,
        dual_contrast_loss = dual_contrast_loss,
        antialias = antialias,
        image_size = image_size,
        num_image_tiles = num_image_tiles,
        optimizer = optimizer,
        num_workers = num_workers,
        fmap_max = fmap_max,
        transparent = transparent,
        greyscale = greyscale,
        lr = learning_rate,
        save_every = save_every,
        evaluate_every = evaluate_every,
        aug_prob = aug_prob,
        aug_types = cast_list(aug_types),
        dataset_aug_prob = dataset_aug_prob,
        calculate_fid_every = calculate_fid_every,
        calculate_fid_num_images = calculate_fid_num_images,
        clear_fid_cache = clear_fid_cache,
        amp = amp,
        load_strict = load_strict
    )

    # generate images
    if generate:
        model = Trainer(**model_args, use_aim = use_aim)
        model.load(load_from)
        samples_name = timestamped_filename()
        checkpoint = model.checkpoint_num
        dir_result = model.generate(samples_name, num_image_tiles, checkpoint, generate_types)
        print(f'sample images generated at {dir_result}')
        return

    # generate interpolation images
    if generate_interpolation:
        model = Trainer(**model_args, use_aim = use_aim)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = save_frames)
        print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    # show training progress
    if show_progress:
        model = Trainer(**model_args, use_aim = use_aim)
        model.show_progress(num_images=num_image_tiles, types=generate_types)
        return

    # run the augmentation test
    if aug_test:
        DiffAugmentTest(data=data, image_size=image_size, batch_size=batch_size, types=aug_types, nrow=num_image_tiles)
        return

    # number of available GPUs
    world_size = torch.cuda.device_count()

    # single GPU training
    if world_size == 1 or not multi_gpus:
        run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash)
        return

    # multi-GPU training
    mp.spawn(run_training,
        args=(world_size, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash,),
        nprocs=world_size,
        join=True)
# expose train_from_folder as a command line interface via Fire
def main():
    fire.Fire(train_from_folder)