
Lucidrains Series: Project Source Code Walkthrough (Part 32)

Parti - Pytorch

Implementation of Parti, Google's pure attention-based text-to-image neural network, in Pytorch. Project Page

This repository also contains working training code for the ViT VQGan VAE. It also includes some additional modifications from the vision transformer literature for faster training.

Yannic Kilcher

Please join us on Discord if you are interested in helping out with the replication with the LAION community

Install

$ pip install parti-pytorch

Usage

First you will need to train your Transformer VQ-GAN VAE

from parti_pytorch import VitVQGanVAE, VQGanVAETrainer

vit_vae = VitVQGanVAE(
    dim = 256,               # dimensions
    image_size = 256,        # target image size
    patch_size = 16,         # size of the patches in the image attending to each other
    num_layers = 3           # number of layers
).cuda()

trainer = VQGanVAETrainer(
    vit_vae,
    folder = '/path/to/your/images',
    num_train_steps = 100000,
    lr = 3e-4,
    batch_size = 4,
    grad_accum_every = 8,
    amp = True
)

trainer.train()
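With `batch_size = 4` and `grad_accum_every = 8`, each optimizer update presumably accumulates gradients over 4 × 8 = 32 images. The sketch below shows gradient accumulation in plain PyTorch, assuming this is what `grad_accum_every` controls; the toy `nn.Linear` model and reconstruction loss are hypothetical stand-ins, not `VQGanVAETrainer` internals.

```python
import torch
from torch import nn

torch.manual_seed(0)

# hypothetical stand-ins for the VAE and its loss
model = nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

batch_size = 4
grad_accum_every = 8

def train_step():
    # one optimizer update built from `grad_accum_every` micro-batches
    for _ in range(grad_accum_every):
        x = torch.randn(batch_size, 16)
        loss = (model(x) - x).pow(2).mean()   # toy reconstruction loss
        # scale each micro-batch loss so the summed gradients equal the mean over all micro-batches
        (loss / grad_accum_every).backward()
    optimizer.step()
    optimizer.zero_grad()
    # number of samples that contributed to this update
    return batch_size * grad_accum_every

effective_batch = train_step()
print(effective_batch)  # 32
```

Accumulation trades wall-clock time for memory: only one micro-batch of 4 images needs to be resident on the GPU at a time.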

Then train Parti:

import torch
from parti_pytorch import Parti, VitVQGanVAE

# first instantiate your ViT VQGan VAE
# a VQGan VAE made of transformers

vit_vae = VitVQGanVAE(
    dim = 256,               # dimensions
    image_size = 256,        # target image size
    patch_size = 16,         # size of the patches in the image attending to each other
    num_layers = 3           # number of layers
).cuda()

vit_vae.load_state_dict(torch.load('/path/to/vae.pt')) # you will want to load the exponential moving average (EMA) version of the VAE

# then plug the ViT VQGan VAE into Parti like so

parti = Parti(
    vae = vit_vae,            # vit vqgan vae
    dim = 512,                # model dimension
    depth = 8,                # depth
    dim_head = 64,            # attention head dimension
    heads = 8,                # attention heads
    dropout = 0.,             # dropout
    cond_drop_prob = 0.25,    # conditional dropout, for classifier free guidance
    ff_mult = 4,              # feedforward expansion factor
    t5_name = 't5-large',     # name of your T5
)

# ready your training text and images

texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
]

images = torch.randn(4, 3, 256, 256).cuda()

# feed it into your parti instance, with return_loss set to True

loss = parti(
    texts = texts,
    images = images,
    return_loss = True
)

loss.backward()

# do this for a long time on much data
# then...

images = parti.generate(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3., return_pil_images = True) # conditioning scale for classifier free guidance

# returns List[PIL.Image] (256 x 256 RGB)
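`cond_drop_prob = 0.25` and `cond_scale = 3.` are the two halves of classifier-free guidance: during training the text condition is dropped for a fraction of samples so that the same network also learns an unconditional prediction, and at sampling time the two predictions are combined. A minimal sketch of both pieces, assuming guidance is applied to the logits and that a dropped condition is represented by zeroed embeddings (parti-pytorch may use a learned null embedding or a mask instead; both function names are hypothetical):

```python
import torch

def drop_condition(text_embeds, cond_drop_prob = 0.25):
    # training: drop the text condition for a random subset of the batch,
    # here by zeroing the embeddings (a simplifying assumption)
    keep = torch.rand(text_embeds.shape[0], device = text_embeds.device) >= cond_drop_prob
    return text_embeds * keep[:, None, None], keep

def guided_logits(cond_logits, null_logits, cond_scale = 3.):
    # sampling: classifier-free guidance, extrapolating from the unconditional
    # prediction toward the text-conditioned one (cond_scale = 1 returns cond_logits)
    return null_logits + (cond_logits - null_logits) * cond_scale

cond = torch.tensor([2., 0.])
null = torch.tensor([1., 1.])
print(guided_logits(cond, null).tolist())  # [4.0, -2.0]
```

Scales above 1 push samples further toward the text, usually at some cost in diversity.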

Realistically, when scaling up, you'll want to pre-encode your text into T5 token embeddings and their corresponding mask

from parti_pytorch.t5 import t5_encode_text

images = torch.randn(4, 3, 256, 256).cuda()

text_token_embeds, text_mask = t5_encode_text([
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
], name = 't5-large', output_device = images.device)

# store somewhere, then load with the dataloader

loss = parti(
    text_token_embeds = text_token_embeds,
    text_mask = text_mask,
    images = images,
    return_loss = True
)

loss.backward()
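One way to "store somewhere, then load with the dataloader" is sketched below. It assumes each sample is saved as its own `.pt` file (for example by slicing rows out of a `t5_encode_text` batch) and that the embeddings are padded to a fixed length so the default collate function can stack them. The dataset class, the file layout and `max_text_len` are hypothetical; 1024 is the hidden size of `t5-large`.

```python
import os
import tempfile
import torch
from torch.utils.data import Dataset, DataLoader

class CachedTextImageDataset(Dataset):
    # hypothetical: one .pt file per sample holding the precomputed T5 embeddings,
    # their mask and the image, all saved with torch.save
    def __init__(self, folder):
        self.paths = sorted(os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.pt'))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        item = torch.load(self.paths[idx])
        return item['text_token_embeds'], item['text_mask'], item['image']

# write a few fake samples (random tensors) to a temporary folder
folder = tempfile.mkdtemp()
max_text_len, t5_dim = 16, 1024     # assumed fixed padding length; 1024 is t5-large's hidden size

for i in range(4):
    torch.save({
        'text_token_embeds': torch.randn(max_text_len, t5_dim),
        'text_mask': torch.ones(max_text_len, dtype = torch.bool),
        'image': torch.randn(3, 256, 256)
    }, os.path.join(folder, f'{i}.pt'))

loader = DataLoader(CachedTextImageDataset(folder), batch_size = 2)
text_token_embeds, text_mask, images = next(iter(loader))
print(text_token_embeds.shape, text_mask.shape, images.shape)
```

Each batch can then be moved to the GPU and passed to `parti(text_token_embeds = ..., text_mask = ..., images = ..., return_loss = True)` as shown above.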

Appreciation

  • StabilityAI for the sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

  • 🤗 Huggingface for the transformers library and the ease for encoding text with T5 language model


Citations

@inproceedings{Yu2022Pathways,
    title   = {Pathways Autoregressive Text-to-Image Model},
    author  = {Jiahui Yu and Yuanzhong Xu and Jing Yu Koh and Thang Luong and Gunjan Baid and Zirui Wang and Vijay Vasudevan and Alexander Ku and Yinfei Yang and Burcu Karagol Ayan and Ben Hutchinson and Wei Han and Zarana Parekh and Xin Li and Han Zhang and Jason Baldridge and Yonghui Wu},
    year    = {2022}
}
@article{Shleifer2021NormFormerIT,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Sam Shleifer and Jason Weston and Myle Ott},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.09456}
}
@article{Sankararaman2022BayesFormerTW,
    title   = {BayesFormer: Transformer with Uncertainty Estimation},
    author  = {Karthik Abinav Sankararaman and Sinong Wang and Han Fang},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2206.00826}
}
@article{Lee2021VisionTF,
    title   = {Vision Transformer for Small-Size Datasets},
    author  = {Seung Hoon Lee and Seunghyun Lee and Byung Cheol Song},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2112.13492}
}
@article{Chu2021DoWR,
    title   = {Do We Really Need Explicit Position Encodings for Vision Transformers?},
    author  = {Xiangxiang Chu and Bo Zhang and Zhi Tian and Xiaolin Wei and Huaxia Xia},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2102.10882}
}
@article{So2021PrimerSF,
    title   = {Primer: Searching for Efficient Transformers for Language Modeling},
    author  = {David R. So and Wojciech Ma{\'n}ke and Hanxiao Liu and Zihang Dai and Noam M. Shazeer and Quoc V. Le},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2109.08668}
}
@inproceedings{Wang2021CrossFormerAV,
    title   = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
    author  = {Wenxiao Wang and Lulian Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
    year    = {2021}
}
@misc{mentzer2023finite,
    title   = {Finite Scalar Quantization: VQ-VAE Made Simple},
    author  = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
    year    = {2023},
    eprint  = {2309.15505},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yu2023language,
    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
    year    = {2023},
    eprint  = {2310.05737},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\parti-pytorch\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages
# execute version.py so that __version__ is defined in this namespace
exec(open('parti_pytorch/version.py').read())

# package metadata
setup(
  name = 'parti-pytorch',  # package name
  packages = find_packages(exclude=[]),  # include all packages
  version = __version__,  # version number, read from parti_pytorch/version.py
  license='MIT',  # license
  description = 'Parti - Pathways Autoregressive Text-to-Image Model - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # format of the long description
  url = 'https://github.com/lucidrains/parti-pytorch',  # project URL
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'text-to-image'
  ],
  install_requires=[  # dependencies
    'einops>=0.7',
    'einops-exts',
    'ema-pytorch',
    'torch>=1.6',
    'torchvision',
    'transformers',
    'vector-quantize-pytorch>=1.11.8'
  ],
  classifiers=[  # PyPI classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\pause-transformer\pause_transformer\pause_transformer.py

import torch
import torch.nn.functional as F
from torch import nn, Tensor, einsum
from torch.nn import Module, ModuleList, Sequential

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

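# functions

# check whether a value exists (is not None)
def exists(v):
    return v is not None

# tensor functions

# log with the input clamped to eps, avoiding -inf at zero
def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

# Shannon entropy, -sum(p * log p), of the softmax distribution along `dim`
def entropy(t, dim = -1):
    prob = t.softmax(dim = dim)
    return -(prob * log(prob)).sum(dim = dim)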

# norm

# RMS normalization: L2-normalize, then rescale by sqrt(dim) and a learned gain
class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

# cheap relative positions
# from Peng Bo's RWKV

# token shift: half of the feature channels are shifted forward by one position along the sequence axis
class ShiftTokens(Module):
    def forward(self, x):
        x, x_shift = x.chunk(2, dim = -1)
        x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
        return torch.cat((x, x_shift), dim = -1)

# feedforward

# feedforward with token shift and pre-RMSNorm
def FeedForward(dim, mult = 4):
    dim_inner = int(dim * mult)
    return Sequential(
        ShiftTokens(),
        RMSNorm(dim),
        nn.Linear(dim, dim_inner),
        nn.GELU(),
        nn.Linear(dim_inner, dim)
    )

# CausalAttention

# causal self attention; also returns the stacked keys and values
class CausalAttention(Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        dim_inner = dim_head * heads

        self.norm = RMSNorm(dim)

        self.to_qkv = Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
        )

        self.to_out = Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias = False)
        )

    def forward(self, x):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)

        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        return self.to_out(out), torch.stack((k, v))

# integrate previous pause / thinking information

# fuses each token with the last pause token of the preceding position
class IntegratePreviousThought(Module):
    def __init__(self, dim):
        super().__init__()
        self.net =  Sequential(
            RMSNorm(dim),
            Rearrange('b n p d -> b n (p d)'),
            nn.Linear(dim * 2, dim)
        )

    def forward(
        self,
        x,
        pause_tokens,
        pause_lengths = None
    ):
        if not exists(pause_lengths):
            p = pause_tokens[:, :, -1]
        else:
            batch, seq_len = x.shape[:2]
            batch_arange = torch.arange(batch, device = x.device)[:, None, None]
            seq_arange = torch.arange(seq_len, device = x.device)[:, None]
            pause_lengths = pause_lengths[:, :, None]

            p = pause_tokens[batch_arange, seq_arange, pause_lengths]
            p = rearrange(p, '... 1 d -> ... d')

        p = F.pad(p, (0, 0, 1, -1), value = 0.)

        x = torch.stack((x, p), dim = -2)
        out = self.net(x)
        return out

# class

# the pause transformer
class PauseTransformer(Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        max_pause_length = 2,
        dim_head = 64,
        heads = 8,
        ff_mult = 4
    ):
        # call the parent constructor
        super().__init__()

        # token embedding: maps input token ids to vectors of size dim
        self.token_emb = nn.Embedding(num_tokens, dim)

        # maximum pause length
        self.max_pause_length = max_pause_length

        # learned pause tokens, one per pause position
        self.pause_tokens = nn.Parameter(torch.randn(max_pause_length, dim))

        # module that integrates the previous position's pause
        self.integrate_prev_pause = IntegratePreviousThought(dim)

        # empty module list to hold the layers
        self.layers = ModuleList([])

        # build `depth` layers
        for _ in range(depth):
            # each layer is a causal attention block followed by a feedforward block
            self.layers.append(ModuleList([
                CausalAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        # projection to output logits
        self.to_logits = Sequential(
            RMSNorm(dim),
            nn.Linear(dim, num_tokens, bias = False)
        )

    def forward(
        self,
        x,
        return_loss = False,
        return_logit_entropy = False,
        arrest_pausing = False,
        no_prev_pause_integration = False,
        pause_lengths = None,
        rand_uniform_pausing = False        # random pausing, with pause lengths sampled uniformly from [0, max_pause_length)
    ):
        """
        einstein notation:
        b - batch
        n - main sequence length
        p - thinking sequence length (pause)
        d - feature dimension
        """

        # when computing the loss, split into inputs and next-token labels
        if return_loss:
            x, labels = x[:, :-1], x[:, 1:]

        # unless pausing is arrested
        if not arrest_pausing:
            # if random pausing is requested and no pause lengths were given, sample them
            if rand_uniform_pausing and not exists(pause_lengths):
                pause_lengths = torch.randint(0, self.max_pause_length, x.shape, device = x.device)

        # batch size and sequence length
        batch, seq_len = x.shape

        # embed the input tokens
        x = self.token_emb(x)

        # repeat the pause tokens for every batch element and sequence position
        p = repeat(self.pause_tokens, 'p d -> b n p d', b = batch, n = seq_len)

        # if pause lengths are given
        if exists(pause_lengths):
            max_pause = int(pause_lengths.amax().item())
            p = p[:, :, :(max_pause + 1)]

            # a maximum pause length of 0 means no pausing
            arrest_pausing = max_pause == 0

        # apply attention and feedforward in each layer
        for attn, ff in self.layers:
            attn_out, cached_kvs = attn(x)
            x = x + attn_out
            x = ff(x) + x

            # skip the pause processing when pausing is arrested
            if arrest_pausing:
                continue

            # process the thinking tokens

            x, ps = pack([x, p], 'b n * d')
            x = rearrange(x, '... p d -> (...) p d')

            attn_out, _ = attn(x)

            x = x + attn_out
            x = ff(x) + x

            x = rearrange(x, '(b n) p d -> b n p d', b = batch)
            x, p = unpack(x, ps, 'b n * d')

            # during training, optionally let each token think independently, unaffected by the previous token's thoughts
            if no_prev_pause_integration:
                continue

            # integrate the last pause token of the previous position
            x = x + self.integrate_prev_pause(x, p, pause_lengths)

        # unless pausing is arrested, pack the main and pause tokens together again
        if not arrest_pausing:
            x, _ = pack([x, p], 'b n * d')

        # compute logits
        logits = self.to_logits(x)

        # optionally return the entropy of the logits
        if return_logit_entropy:
            return entropy(logits)

        # return the logits if no loss is requested
        if not return_loss:
            return logits

        # move the class dimension to position 1, as cross_entropy expects
        if arrest_pausing:
            logits = rearrange(logits, 'b n d -> b d n')
        else:
            labels = repeat(labels, 'b n -> (b p) n', p = self.max_pause_length + 1)
            logits = rearrange(logits, 'b n p d -> (b p) d n')

        # cross entropy loss
        loss = F.cross_entropy(logits, labels)
        return loss
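The loss branch of `PauseTransformer.forward` scores every pause position against the same next-token labels. Assuming `logits` has shape `(b, n, max_pause_length + 1, num_tokens)` after the final `pack`, the two einops patterns are equivalent to the plain-PyTorch reshapes below (toy sizes, random data):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, n, max_pause_length, num_tokens = 2, 5, 2, 11
p = max_pause_length + 1      # the main token plus its pause tokens, after the final pack

logits = torch.randn(b, n, p, num_tokens)
labels = torch.randint(0, num_tokens, (b, n))

# 'b n p d -> (b p) d n': cross_entropy expects the class dimension second
flat_logits = logits.permute(0, 2, 3, 1).reshape(b * p, num_tokens, n)
# 'b n -> (b p) n': every row of labels is repeated p times, consecutively
flat_labels = labels.repeat_interleave(p, dim = 0)

loss = F.cross_entropy(flat_logits, flat_labels)

# the same loss, computed slice by slice: every pause position is scored on the next token
per_slice = torch.stack([
    F.cross_entropy(logits[i, :, j], labels[i])
    for i in range(b) for j in range(p)
])
print(torch.allclose(loss, per_slice.mean()))  # True
```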

.\lucidrains\pause-transformer\pause_transformer\__init__.py

# import the PauseTransformer class from the pause_transformer.pause_transformer module
from pause_transformer.pause_transformer import PauseTransformer

Pause Transformer (wip)

Yet another random morning idea, to be tried quickly and shared if the architecture works: allow the transformer to pause for any amount of time on any token.

Again, the idea relies on axial attention; one axis attends along the sequence length as in the usual transformer, the other along a thinking or pause dimension.
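The two attention axes can be sketched in a few lines. This is a simplified illustration, not the repository's code: it uses a single head with no query/key/value projections, and `causal_attend` and `axial_step` are hypothetical names. It shows how the main sequence `(b, n, d)` and the pause tokens `(b, n, p, d)` are attended along different axes, mirroring the `pack` / `rearrange` steps in `PauseTransformer.forward`.

```python
import torch

def causal_attend(x):
    # single-head causal self attention along the second-to-last axis of x: (..., length, d)
    # simplified: queries, keys and values are x itself (no projections)
    d = x.shape[-1]
    sim = x @ x.transpose(-1, -2) * d ** -0.5
    length = sim.shape[-1]
    mask = torch.ones(length, length, dtype = torch.bool, device = x.device).triu(1)
    sim = sim.masked_fill(mask, -torch.finfo(sim.dtype).max)
    return sim.softmax(dim = -1) @ x

def axial_step(tokens, pauses):
    # tokens: (b, n, d) main sequence; pauses: (b, n, p, d) thinking tokens at every position
    b, n, p, d = pauses.shape
    # axis 1: ordinary causal attention along the sequence
    tokens = tokens + causal_attend(tokens)
    # axis 2: each position attends over [token, pause_1, ..., pause_p] on its own
    x = torch.cat((tokens[:, :, None], pauses), dim = 2).reshape(b * n, p + 1, d)
    x = x + causal_attend(x)
    x = x.reshape(b, n, p + 1, d)
    return x[:, :, 0], x[:, :, 1:]

tokens, pauses = axial_step(torch.randn(2, 6, 32), torch.randn(2, 6, 2, 32))
print(tokens.shape, pauses.shape)  # torch.Size([2, 6, 32]) torch.Size([2, 6, 2, 32])
```

Because the pause axis is folded into the batch dimension, extra thinking steps add compute per position without lengthening the sequence that the main axis attends over.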


Citations

@inproceedings{Goyal2023ThinkBY,
    title   = {Think before you speak: Training Language Models With Pause Tokens},
    author  = {Sachin Goyal and Ziwei Ji and Ankit Singh Rawat and Aditya Krishna Menon and Sanjiv Kumar and Vaishnavh Nagarajan},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:263608983}
}

.\lucidrains\pause-transformer\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'pause-transformer',
  # include all packages, excluding none
  packages = find_packages(exclude=[]),
  # version number
  version = '0.0.7',
  # license
  license='MIT',
  # description
  description = 'Pause Transformer',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # format of the long description
  long_description_content_type = 'text/markdown',
  # project URL
  url = 'https://github.com/lucidrains/pause-transformer',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'adaptive computation'
  ],
  # dependencies
  install_requires=[
    'einops>=0.7.0',
    'torch>=2.0'
  ],
  # PyPI classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
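enwik8 is the first 10^8 bytes of an English Wikipedia XML dump, so it can be modeled at the byte level: each byte is a token and the vocabulary has 256 entries. The repository's `train.py` is not reproduced in this section; the sketch below is a hypothetical illustration of sampling training windows of `seq_len + 1` tokens from such a byte stream (the extra token supplies the next-token labels that `AutoregressiveWrapper.forward` splits off), with synthetic bytes standing in for the real file.

```python
import torch

def random_crops(data, seq_len, batch_size):
    # data: 1-D uint8 tensor of raw bytes, one token per byte
    starts = torch.randint(0, data.numel() - seq_len - 1, (batch_size,))
    # seq_len + 1 tokens per window, so every position has a next-byte label
    return torch.stack([data[s: s + seq_len + 1] for s in starts.tolist()]).long()

raw = bytes(range(256)) * 40                       # synthetic stand-in for the enwik8 bytes
data = torch.tensor(list(raw), dtype = torch.uint8)
batch = random_crops(data, seq_len = 4096, batch_size = 2)
print(batch.shape)  # torch.Size([2, 4097])
```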

.\lucidrains\perceiver-ar-pytorch\perceiver_ar_pytorch\autoregressive_wrapper.py

# import torch
import torch
# functional API (softmax, pad, ...)
import torch.nn.functional as F
# import rearrange from einops
from einops import rearrange
# import the nn module from torch
from torch import nn

# helper: check whether a value exists (is not None)
def exists(val):
    return val is not None

# decorator that puts the model in eval mode for the call, then restores its previous training state
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# top-k filtering: keep the largest (1 - thres) fraction of the logits and set the rest to -inf
def top_k(logits, thres=0.9):
    k = max(int((1 - thres) * logits.shape[-1]), 1)  # keep at least one logit so sampling is always possible
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(1, ind, val)
    return probs

# autoregressive wrapper class
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, pad_value=0):
        super().__init__()
        self.max_seq_len = net.max_seq_len
        self.pad_value = pad_value
        self.net = net

    # generate a sequence by sampling one token at a time
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        start_tokens,
        seq_len,
        eos_token=None,
        temperature=1.0,
        filter_thres=0.9,
        **kwargs
    ):
        b, n, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            logits = self.net(
                out[:, -self.max_seq_len:],
                **kwargs
            )[:, -1]

            filtered_logits = top_k(logits, thres=filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)
            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_token = out == eos_token

                if is_eos_token.any(dim=-1).all():
                    # mask out everything after the eos tokens
                    shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        out = out[:, n:]
        return out

    # forward pass for training: predict each next token
    def forward(self, x, **kwargs):
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        return self.net(x_inp, labels=x_labels, **kwargs)

.\lucidrains\perceiver-ar-pytorch\perceiver_ar_pytorch\perceiver_ar_pytorch.py

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat

# helper functions

# helper: check whether a value exists (is not None)
def exists(val):
    return val is not None

# feedforward

# feedforward block
def FeedForward(dim, mult = 4, dropout = 0.):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),  # pre-norm on the input
        nn.Linear(dim, hidden_dim, bias = False),  # up-projection
        nn.GELU(),  # GELU activation
        nn.Dropout(dropout),  # dropout regularization
        nn.Linear(hidden_dim, dim, bias = False)  # down-projection
    )

# rotary positional embedding
# https://arxiv.org/abs/2104.09864

# rotary embedding: computes the rotation angles for every position
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device = device, dtype = self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim = -1)


# rotate half: split the features into halves (x1, x2) and return (-x2, x1)
def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j = 2)
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)


# apply the rotary embedding to the first rotate_dim features; the rest pass through unchanged
def apply_rotary_pos_emb(pos, t):
    seq_len, rotate_dim = t.shape[-2], pos.shape[-1]
    pos = pos[..., -seq_len:, :]
    t, t_pass = t[..., :rotate_dim], t[..., rotate_dim:]
    t = (t * pos.cos()) + (rotate_half(t) * pos.sin())
    return torch.cat((t, t_pass), dim = -1)

# attention

# causal self attention
class CausalAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head

        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, rotary_pos_emb = None):
        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        q = q * self.scale

        if exists(rotary_pos_emb):
            q = apply_rotary_pos_emb(rotary_pos_emb, q)
            k = apply_rotary_pos_emb(rotary_pos_emb, k)

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# causal prefix attention: queries attend to the whole prefix (context) and causally to themselves
class CausalPrefixAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        max_heads_process = 2,
        dropout = 0.,
        cross_attn_dropout = 0.
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.max_heads_process = max_heads_process

        inner_dim = heads * dim_head

        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

        self.cross_attn_dropout = cross_attn_dropout # they drop out a percentage of the prefix during training, shown to help prevent overfitting

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    # forward pass: input x, context (the prefix), optional context mask and rotary embedding
    def forward(self, x, context, context_mask = None, rotary_pos_emb = None):
        # batch size, context length and device
        batch, context_len, device = x.shape[0], context.shape[-2], x.device

        # queries and keys start from the same rotary embedding
        q_rotary_pos_emb = rotary_pos_emb
        k_rotary_pos_emb = rotary_pos_emb

        # handle cross attention dropout

        if self.training and self.cross_attn_dropout > 0.:
            # random scores used to pick which prefix positions to keep
            rand = torch.zeros((batch, context_len), device = device).uniform_()
            keep_context_len = context_len - int(context_len * self.cross_attn_dropout)
            keep_indices = rand.topk(keep_context_len, dim = -1).indices
            keep_mask = torch.zeros_like(rand).scatter_(1, keep_indices, 1).bool()

            # keep only the selected part of the context
            context = rearrange(context[keep_mask], '(b n) d -> b n d', b = batch)

            if exists(context_mask):
                context_mask = rearrange(context_mask[keep_mask], '(b n) -> b n', b = batch)

            # drop the matching rotary embeddings from the keys
            k_rotary_pos_emb = repeat(k_rotary_pos_emb, '... -> b ...', b = batch)
            k_rotary_pos_emb_context, k_rotary_pos_emb_seq = k_rotary_pos_emb[:, :context_len], k_rotary_pos_emb[:, context_len:]
            k_rotary_pos_emb_context = rearrange(k_rotary_pos_emb_context[keep_mask], '(b n) d -> b n d', b = batch)

            k_rotary_pos_emb = torch.cat((k_rotary_pos_emb_context, k_rotary_pos_emb_seq), dim = 1)
            k_rotary_pos_emb = rearrange(k_rotary_pos_emb, 'b n d -> b 1 n d')

        # normalization
        x = self.norm(x)
        context = self.context_norm(context)

        # queries, keys and values
        q = self.to_q(x)

        k_input, v_input = self.to_kv(x).chunk(2, dim = -1)
        k_context, v_context = self.to_kv(context).chunk(2, dim = -1)

        k = torch.cat((k_context, k_input), dim = 1)
        v = torch.cat((v_context, v_input), dim = 1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        q = q * self.scale

        # rotate queries and keys with the rotary embedding
        if exists(rotary_pos_emb):
            q = apply_rotary_pos_emb(q_rotary_pos_emb, q)
            k = apply_rotary_pos_emb(k_rotary_pos_emb, k)

        # masking
        i, j = q.shape[-2], k.shape[-2]
        mask_value = -torch.finfo(q.dtype).max

        if exists(context_mask):
            mask_len = context_mask.shape[-1]
            context_mask = F.pad(context_mask, (0, max(j - mask_len, 0)), value = True)
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')

        causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)

        # process the heads in chunks to lower peak memory
        out = []

        max_heads = self.max_heads_process

        for q_chunk, k_chunk, v_chunk in zip(q.split(max_heads, dim = 1), k.split(max_heads, dim = 1), v.split(max_heads, dim = 1)):
            sim = einsum('b h i d, b h j d -> b h i j', q_chunk, k_chunk)

            if exists(context_mask):
                sim = sim.masked_fill(~context_mask, mask_value)

            sim = sim.masked_fill(causal_mask, mask_value)

            attn = sim.softmax(dim = -1)
            attn = self.dropout(attn)

            out_chunk = einsum('b h i j, b h j d -> b h i d', attn, v_chunk)
            out.append(out_chunk)

        # concatenate all heads
        out = torch.cat(out, dim = 1)

        # merge the heads and apply the output projection
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out)


class PerceiverAR(nn.Module):
    # the PerceiverAR model, a subclass of nn.Module
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        max_seq_len,
        cross_attn_seq_len,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        cross_attn_dropout = 0.,
        ff_mult = 4,
        perceive_depth = 1,
        perceive_max_heads_process = 2 # processes the heads in the perceiver layer in chunks to lower peak memory, in the case the prefix is really long
    ):
        # call the parent constructor
        super().__init__()
        # max_seq_len must be greater than cross_attn_seq_len
        assert max_seq_len > cross_attn_seq_len, 'max_seq_len must be greater than cross_attn_seq_len, the length of the sequence for which to cross attend to "perceiver" style'
        self.max_seq_len = max_seq_len
        self.cross_attn_seq_len = cross_attn_seq_len

        # token embedding
        self.token_emb = nn.Embedding(num_tokens, dim)
        # absolute positional embedding
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # rotary positional embedding
        self.rotary_pos_emb = RotaryEmbedding(dim = max(32, dim_head // 2))

        # perceiver (cross attention) layers
        self.perceive_layers = nn.ModuleList([])

        for _ in range(perceive_depth):
            # each perceiver layer: causal prefix attention followed by a feedforward
            self.perceive_layers.append(nn.ModuleList([
                CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, max_heads_process = perceive_max_heads_process, dropout = dropout, cross_attn_dropout = cross_attn_dropout),
                FeedForward(dim, mult = ff_mult, dropout = dropout)
            ]))

        # self attention layers
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            # each layer: causal self attention followed by a feedforward
            self.layers.append(nn.ModuleList([
                CausalAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim, mult = ff_mult, dropout = dropout),
            ]))

        # projection to output logits
        self.to_logits = nn.Linear(dim, num_tokens, bias = False)

    def forward(
        self,
        x,
        prefix_mask = None,
        labels = None
    ):
        # sequence length and device
        seq_len, device = x.shape[1], x.device
        # the sequence must be longer than the prefix and no longer than max_seq_len
        assert self.cross_attn_seq_len < seq_len <= self.max_seq_len

        # token embedding
        x = self.token_emb(x)
        # add the absolute positional embedding
        x = x + self.pos_emb(torch.arange(seq_len, device = device))

        # rotary positional embedding

        rotary_pos_emb = self.rotary_pos_emb(seq_len, device = device)

        # divide into prefix to cross attend to and sequence to self attend to

        prefix, x = x[:, :self.cross_attn_seq_len], x[:, self.cross_attn_seq_len:]

        # initial perceiver attention and feedforward (one cross attention)

        for cross_attn, ff in self.perceive_layers:
            x = cross_attn(x, prefix, context_mask = prefix_mask, rotary_pos_emb = rotary_pos_emb) + x
            x = ff(x) + x

        # layers

        for attn, ff in self.layers:
            x = attn(x, rotary_pos_emb = rotary_pos_emb) + x
            x = ff(x) + x

        # to logits

        logits = self.to_logits(x)

        # take care of cross entropy loss if labels are provided

        if not exists(labels):
            return logits

        # only positions after the prefix have predictions, so drop the prefix labels
        labels = labels[:, self.cross_attn_seq_len:]
        # cross entropy, ignoring token id 0 (the default pad value)
        return F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = 0)

.\lucidrains\perceiver-ar-pytorch\perceiver_ar_pytorch\__init__.py

# import the PerceiverAR class from the perceiver_ar_pytorch.perceiver_ar_pytorch module
from perceiver_ar_pytorch.perceiver_ar_pytorch import PerceiverAR

Perceiver AR - Pytorch

Implementation of Perceiver AR, DeepMind's new long-context attention network based on the Perceiver architecture, in Pytorch.

Generated piano samples

I am building this out of popular demand, not because I believe in the architecture. As someone else put it succinctly, this is equivalent to an encoder / decoder transformer where the encoder has 0 layers (and the decoder cross attention is restricted to 1 layer).
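In `CausalPrefixAttention` that single cross-attention layer is realized with one mask: every query after the prefix sees all prefix positions plus the non-prefix positions up to and including itself. A minimal sketch of that mask with toy sizes, using the same `triu(j - i + 1)` construction (the helper name is hypothetical):

```python
import torch

def causal_prefix_mask(prefix_len, seq_len):
    # rows (queries): the seq_len positions after the prefix
    # columns (keys): the prefix_len prefix positions, then the same seq_len positions
    i, j = seq_len, prefix_len + seq_len
    # True marks a blocked (query, key) pair, as in CausalPrefixAttention
    return torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)

allowed = ~causal_prefix_mask(prefix_len = 3, seq_len = 4)
print(allowed.int().tolist())
# [[1, 1, 1, 1, 0, 0, 0],
#  [1, 1, 1, 1, 1, 0, 0],
#  [1, 1, 1, 1, 1, 1, 0],
#  [1, 1, 1, 1, 1, 1, 1]]
```

The first three columns (the prefix) are open to every query, which is what lets a long prefix be read once by cross attention without paying for self attention over it.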

However, the experimental results they provided are still worthwhile and I'll build it out so students and researchers alike can explore along this avenue.

Official Jax repository

Update: it seems to be performing decently well on enwik8 with a 4096 context length. Maybe I was wrong to be pessimistic.

Install

$ pip install perceiver-ar-pytorch

Usage

import torch
from perceiver_ar_pytorch import PerceiverAR

model = PerceiverAR(
    num_tokens = 20000,             # number of tokens
    dim = 512,                      # model dimensions
    depth = 8,                      # model depth
    dim_head = 64,                  # attention head dimension
    heads = 8,                      # attention heads
    max_seq_len = 4096,             # total max sequence length
    cross_attn_seq_len = 3072,      # length of the prefix that is only cross-attended to and does not undergo self attention (must be less than max_seq_len)
    cross_attn_dropout = 0.5,       # fraction of the prefix to drop out during training; the paper's experiments showed that up to 50% dropout helped prevent overfitting
)

x = torch.randint(0, 20000, (1, 4096))

logits = model(x) # (1, 1024, 20000) - (4096 [seq len] - 3072 [perceived prefix] == 1024)
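The logits shape follows from `max_seq_len - cross_attn_seq_len = 4096 - 3072 = 1024` predicted positions. The `cross_attn_dropout` option removes part of the prefix during training. The standard-library sketch below illustrates that idea only. It assumes a uniformly random subset of prefix positions is kept and put back into original order. This is not the library's implementation, and `keep_prefix_positions` is a hypothetical name.

```python
import random

def keep_prefix_positions(prefix_len, dropout, rng):
    """Hypothetical helper: sorted indices of the prefix positions kept
    during training when a fraction `dropout` of the prefix is dropped."""
    num_keep = int((1. - dropout) * prefix_len)
    kept = rng.sample(range(prefix_len), num_keep)   # random subset, no repeats
    return sorted(kept)                              # restore the original token order

max_seq_len, cross_attn_seq_len = 4096, 3072
print(max_seq_len - cross_attn_seq_len)             # 1024 positions receive logits

rng = random.Random(0)
kept = keep_prefix_positions(cross_attn_seq_len, 0.5, rng)
print(len(kept))                                    # 1536 prefix tokens remain to cross-attend to
```

Dropping prefix tokens also shrinks the cross-attention cost during training, since the queries attend over fewer keys.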

Test

Enwik8 at 4096

$ python train.py

Citations

@article{Hawthorne2022GeneralpurposeLA,
    title   = {General-purpose, long-context autoregressive modeling with Perceiver AR},
    author  = {Curtis Hawthorne and Andrew Jaegle and Cătălina Cangea and Sebastian Borgeaud and Charlie Nash and Mateusz Malinowski and Sander Dieleman and Oriol Vinyals and Matthew M. Botvinick and Ian Simon and Hannah R. Sheahan and Neil Zeghidour and Jean-Baptiste Alayrac and Jo{\~a}o Carreira and Jesse Engel},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.07765}
}

.\lucidrains\perceiver-ar-pytorch\setup.py

# import setup and the package discovery helper
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'perceiver-ar-pytorch',
  # find all packages, excluding none
  packages = find_packages(exclude=[]),
  # version number
  version = '0.0.10',
  # license
  license='MIT',
  # short description
  description = 'Perceiver AR',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # content type of the long description
  long_description_content_type = 'text/markdown',
  # project URL
  url = 'https://github.com/lucidrains/perceiver-ar-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'long context',
    'attention'
  ],
  # runtime dependencies
  install_requires=[
    'einops>=0.4',
    'torch>=1.6',
  ],
  # trove classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\perceiver-ar-pytorch\train.py

# import the required libraries
import gzip
import random

import numpy as np
import torch
import torch.optim as optim
import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# import the model and the autoregressive wrapper
from perceiver_ar_pytorch import PerceiverAR
from perceiver_ar_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# hyperparameters and constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 4096
PREFIX_SEQ_LEN = 3584

# cycle endlessly over a data loader
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a single token (byte) into a character; bytes below 32 (control characters) become spaces
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens into a string
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

# create the PerceiverAR model
model = PerceiverAR(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    heads = 8,
    dim_head = 64,
    cross_attn_dropout = 0.5,
    max_seq_len = SEQ_LEN,
    cross_attn_seq_len = PREFIX_SEQ_LEN
)

# wrap the model with AutoregressiveWrapper (it returns the training loss and provides generate())
model = AutoregressiveWrapper(model)
# move the model to the GPU
model.cuda()

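# prepare the enwik8 data: the first 90M bytes for training, the next 5M for validation
with gzip.open("./data/enwik8.gz") as file:
    # np.fromstring is deprecated for binary data; copy so the array is writable for torch.from_numpy
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)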

# dataset that samples random windows of seq_len + 1 tokens
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# datasets and infinitely cycling data loaders for training and validation
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# Adam optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f"training loss: {loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print("%s \n\n %s" % (prime, "*" * 100))

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)
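The script treats enwik8 as raw bytes. Each training sample is a random window of `SEQ_LEN + 1` bytes, with one extra byte so that inputs and targets can be offset by one. Decoding maps control bytes below 32 to a space. The standard-library sketch below reproduces those two steps, using a `bytes` object in place of the tensor. `sample_window` is a hypothetical name.

```python
import random

def decode_token(token):
    # bytes below 32 (newlines, tabs, other control codes) are shown as spaces
    return chr(max(32, token))

def decode_tokens(tokens):
    return ''.join(decode_token(t) for t in tokens)

def sample_window(data, seq_len, rng):
    # mirrors TextSamplerDataset.__getitem__: a random start, then seq_len + 1 tokens
    start = rng.randrange(0, len(data) - seq_len)
    return list(data[start : start + seq_len + 1])

text = b"Hello,\nenwik8 world!"
window = sample_window(text, 8, random.Random(0))
print(len(window))                             # 9
print(repr(decode_tokens(b"Hi\tthere\n")))     # 'Hi there '
```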

.\lucidrains\perceiver-pytorch\perceiver_pytorch\experimental.py

# import torch
import torch
# import nn and einsum from torch
from torch import nn, einsum
# import torch.nn.functional as F
import torch.nn.functional as F

# import rearrange and repeat from einops
from einops import rearrange, repeat

# import the helpers and building blocks from perceiver_pytorch.perceiver_pytorch
from perceiver_pytorch.perceiver_pytorch import exists, default, cache_fn, fourier_encode, PreNorm, FeedForward, Attention

# LinearAttention: attention whose cost grows linearly with sequence length
class LinearAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads = 4,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        # a single projection produces queries, keys and values (3 * inner_dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # output projection followed by dropout
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    # forward pass
    def forward(self, x, mask = None):
        h = self.heads
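        # project the input x to queries, keys and values
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        # split heads: (b, n, h * d) -> (b * h, n, d)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        # scale the queries
        q = q * self.scale

        # if a mask is given (True = keep), exclude padded keys from the softmax over the sequence
        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            mask = repeat(mask, 'b n -> (b h) n ()', h = h)
            k = k.masked_fill(~mask, -torch.finfo(k.dtype).max)

        # queries: softmax over features; keys: softmax over the sequence
        q, k = q.softmax(dim = -1), k.softmax(dim = -2)

        # aggregate keys and values into a (d x e) context per head
        context = einsum('b n d, b n e -> b d e', k, v)
        # each query reads from the context
        out = einsum('b n d, b d e -> b n e', q, context)
        # merge heads back: (b * h, n, d) -> (b, n, h * d)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)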

# main class: Perceiver
class Perceiver(nn.Module):
    def __init__(
        self,
        *,
        num_freq_bands,
        depth,
        max_freq,
        input_channels = 3,
        input_axis = 2,
        num_latents = 512,
        latent_dim = 512,
        cross_heads = 1,
        latent_heads = 8,
        cross_dim_head = 64,
        latent_dim_head = 64,
        num_classes = 1000,
        attn_dropout = 0.,
        ff_dropout = 0.,
        weight_tie_layers = False,
        fourier_encode_data = True
        ):
        # call the parent constructor
        super().__init__()
        # number of input axes
        self.input_axis = input_axis
        # maximum frequency of the Fourier features
        self.max_freq = max_freq
        # number of frequency bands
        self.num_freq_bands = num_freq_bands
        # whether to Fourier-encode positions and append them to the data
        self.fourier_encode_data = fourier_encode_data

        # the input dimension starts as the number of channels
        input_dim = input_channels

        # Fourier encoding adds (2 * num_freq_bands + 1) features per axis
        if fourier_encode_data:
            input_dim += input_axis * ((num_freq_bands * 2) + 1)

        # learned latent array
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

        # projection applied to the input data
        self.data_proj = nn.Linear(input_dim, input_dim)

        # factory for cross-attention (latents attend to the data)
        get_cross_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim)
        # factory for the cross-attention feedforward
        get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))

        # factory for linear self-attention over the data
        get_input_attn = lambda: PreNorm(input_dim, LinearAttention(input_dim, dropout = attn_dropout))
        # factory for reverse cross-attention (data attends to the latents)
        get_rev_cross_attn = lambda: PreNorm(input_dim, Attention(input_dim, latent_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = latent_dim)
        # factory for the reverse cross-attention feedforward
        get_rev_cross_ff = lambda: PreNorm(input_dim, FeedForward(input_dim, dropout = ff_dropout))

        # factory for latent self-attention
        get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout))
        # factory for the latent feedforward
        get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))

        # wrap the factories with cache_fn so layers can share weights
        get_cross_attn, get_cross_ff, get_rev_cross_attn, get_rev_cross_ff, get_input_attn, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_rev_cross_attn, get_rev_cross_ff, get_input_attn, get_latent_attn, get_latent_ff))

        # build the layers; with weight_tie_layers, every layer after the first reuses the cached modules
        self.layers = nn.ModuleList([])
        for i in range(depth):
            should_cache = i > 0 and weight_tie_layers
            cache_args = {'_cache': should_cache}

            self.layers.append(nn.ModuleList([
                get_cross_attn(**cache_args),
                get_cross_ff(**cache_args),
                get_rev_cross_attn(**cache_args),
                get_rev_cross_ff(**cache_args),
                get_input_attn(**cache_args),
                get_latent_attn(**cache_args),
                get_latent_ff(**cache_args)
            ]))

        # classification head
        self.to_logits = nn.Sequential(
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, num_classes)
        )

    def forward(self, data, mask = None):
        # batch size, spatial axes and device
        b, *axis, _, device = *data.shape, data.device
        # the data must have as many axes as input_axis
        assert len(axis) == self.input_axis, 'input data must have the right number of axis'

        # optionally Fourier-encode positions
        if self.fourier_encode_data:
            # Fourier-encode positions in [-1, 1] along every axis
            axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
            pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
            enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
            enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
            enc_pos = repeat(enc_pos, '... -> b ...', b = b)

            # concatenate the encoded positions to the data channels
            data = torch.cat((data, enc_pos), dim = -1)

        # flatten the axes
        data = rearrange(data, 'b ... d -> b (...) d')

        # project the data
        data = self.data_proj(data)

        # repeat the latents for each batch element
        x = repeat(self.latents, 'n d -> b n d', b = b)

        # run the layers; the data stream is updated in every layer except the last
        for i, (cross_attn, cross_ff, rev_cross_attn, rev_cross_ff, input_attn, latent_attn, latent_ff) in enumerate(self.layers):
            is_last = i == (len(self.layers) - 1)

            x = cross_attn(x, context = data, mask = mask) + x
            x = cross_ff(x) + x

            if not is_last:
                data = input_attn(data, mask = mask) + data
                data = rev_cross_attn(data, context = x) + data
                data = rev_cross_ff(data) + data

            x = latent_attn(x) + x
            x = latent_ff(x) + x

        # mean-pool over the latents, then project to class logits
        x = x.mean(dim = -2)
        return self.to_logits(x)
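LinearAttention applies softmax to queries over the feature axis and to keys over the sequence axis. That is the setup of the "efficient attention" factorization (Shen et al.). Keys and values are first aggregated into a d x e context matrix, and each query then reads from it, so cost grows linearly with sequence length. Below is a single-head, pure-Python sketch of that formulation. It omits the `dim_head ** -0.5` scaling and masking and is written for clarity, not speed. The function names are illustrative.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def transpose(rows):
    return [list(col) for col in zip(*rows)]

def linear_attention(q, k, v):
    """q, k: n x d, v: n x e (lists of rows). Single head, no masking."""
    # queries: softmax over the feature axis (each row sums to 1)
    q = [softmax(row) for row in q]
    # keys: softmax over the sequence axis (each feature column sums to 1)
    k = transpose([softmax(col) for col in transpose(k)])
    d, e = len(k[0]), len(v[0])
    # context[a][b] = sum_n k[n][a] * v[n][b], a d x e summary of the whole sequence
    context = [[sum(k[n][a] * v[n][b] for n in range(len(v))) for b in range(e)] for a in range(d)]
    # out[n][b] = sum_a q[n][a] * context[a][b]
    return [[sum(row[a] * context[a][b] for a in range(d)) for b in range(e)] for row in q]

q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
k = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
v = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
out = linear_attention(q, k, v)
print(len(out), len(out[0]))   # 3 3
```

Both normalizations are non-negative and sum to 1, so every output row is a convex combination of the rows of `v`, just as in ordinary softmax attention.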

.\lucidrains\perceiver-pytorch\perceiver_pytorch\gated.py

# import torch
import torch
# import nn and einsum from torch
from torch import nn, einsum
# import torch.nn.functional as F
import torch.nn.functional as F

# import rearrange and repeat from einops
from einops import rearrange, repeat

# import the helpers and building blocks from perceiver_pytorch.perceiver_pytorch
from perceiver_pytorch.perceiver_pytorch import exists, default, cache_fn, fourier_encode, PreNorm, FeedForward, Attention

# helpers

# Residual: adds the wrapped function's output to its input
class Residual(nn.Module):
    # constructor
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    # forward pass
    def forward(self, x, **kwargs):
        return x + self.fn(x, **kwargs)

# GRUGating: merges the sublayer output into the residual stream with a GRU cell instead of a plain sum
class GRUGating(nn.Module):
    # constructor
    def __init__(self, dim, fn):
        super().__init__()
        self.dim = dim
        self.fn = fn
        self.gru = nn.GRUCell(dim, dim)

    # forward pass: the sublayer output is the GRU input, the residual stream x is the hidden state
    def forward(self, x, **kwargs):
        b, dim = x.shape[0], self.dim
        y = self.fn(x, **kwargs)

        gated_output = self.gru(
            rearrange(y, '... d -> (...) d'),
            rearrange(x, '... d -> (...) d')
        )

        gated_output = rearrange(gated_output, '(b n) d -> b n d', b = b)
        return gated_output

# main class

# Perceiver (gated variant)
class Perceiver(nn.Module):
    # constructor
    def __init__(
        self,
        *,
        num_freq_bands,
        depth,
        max_freq,
        input_channels = 3,
        input_axis = 2,
        num_latents = 512,
        latent_dim = 512,
        cross_heads = 1,
        latent_heads = 8,
        cross_dim_head = 64,
        latent_dim_head = 64,
        num_classes = 1000,
        attn_dropout = 0.,
        ff_dropout = 0.,
        weight_tie_layers = False
    ):
        super().__init__()
        self.input_axis = input_axis
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands

        input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels

        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

        get_cross_attn  = lambda: GRUGating(latent_dim, PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim))
        get_latent_attn = lambda: GRUGating(latent_dim, PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout))
        get_cross_ff    = lambda: Residual(PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout)))
        get_latent_ff   = lambda: Residual(PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout)))

        get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])
        for i in range(depth):
            should_cache = i > 0 and weight_tie_layers
            cache_args = {'_cache': should_cache}

            self.layers.append(nn.ModuleList([
                get_cross_attn(**cache_args),
                get_cross_ff(**cache_args),
                get_latent_attn(**cache_args),
                get_latent_ff(**cache_args)
            ]))

        self.to_logits = nn.Sequential(
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, num_classes)
        )

    # forward pass, taking the data and an optional mask
    def forward(self, data, mask = None):
        # batch size, spatial axes and device
        b, *axis, _, device = *data.shape, data.device
        # the data must have as many axes as input_axis
        assert len(axis) == self.input_axis, 'input data must have the right number of axis'

        # Fourier-encode positions in [-1, 1] along every axis

        # evenly spaced positions along each axis
        axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
        # grid of coordinates
        pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
        # Fourier features of the coordinates
        enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
        # merge the per-axis features into one vector per position
        enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
        # repeat the encoding for every batch element
        enc_pos = repeat(enc_pos, '... -> b ...', b = b)

        # concatenate the encoded positions to the data channels and flatten the axes

        data = torch.cat((data, enc_pos), dim = -1)
        data = rearrange(data, 'b ... d -> b (...) d')

        # repeat the latents for each batch element
        x = repeat(self.latents, 'n d -> b n d', b = b)

        # each layer: cross-attention, cross feedforward, latent attention, latent feedforward
        # (the residual connections live inside GRUGating and Residual)
        for cross_attn, cross_ff, latent_attn, latent_ff in self.layers:
            x = cross_attn(x, context = data, mask = mask)
            x = cross_ff(x)
            x = latent_attn(x)
            x = latent_ff(x)

        # mean-pool over the latents and return the logits
        x = x.mean(dim = -2)
        return self.to_logits(x)
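GRUGating calls `self.gru(y, x)`, so the sublayer output `y` is the GRU input and the residual stream `x` is the hidden state. The update gate then decides, per feature, how much of the old stream to keep. The sketch below uses the standard GRU update equations, with one assumption made for readability: every weight is a per-gate scalar applied elementwise. `nn.GRUCell` uses full weight matrices. `gru_gate` and the weight keys are hypothetical names.

```python
import math

def sigmoid(t):
    return 1. / (1. + math.exp(-t))

def gru_gate(x, y, w):
    """Elementwise GRU update with hidden state x (residual stream) and
    input y (sublayer output). Scalar weights are a simplification."""
    out = []
    for h, inp in zip(x, y):
        r = sigmoid(w['ir'] * inp + w['hr'] * h + w['br'])                   # reset gate
        z = sigmoid(w['iz'] * inp + w['hz'] * h + w['bz'])                   # update gate
        n = math.tanh(w['in'] * inp + w['bin'] + r * (w['hn'] * h + w['bhn']))  # candidate state
        out.append((1. - z) * n + z * h)                                     # blend candidate and old state
    return out

weights = {'ir': 0.5, 'hr': -0.3, 'br': 0.1,
           'iz': 0.8, 'hz': 0.2, 'bz': -0.1,
           'in': 1.0, 'bin': 0.0, 'hn': 0.7, 'bhn': 0.0}
x = [0.5, -1.0, 2.0]    # residual stream (hidden state)
y = [1.0, 0.0, -0.5]    # sublayer output (GRU input)
print(gru_gate(x, y, weights))
```

When the update gate saturates near 1, the layer passes `x` through unchanged. This is the learned alternative to the fixed `x + fn(x)` used by `Residual`.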

.\lucidrains\perceiver-pytorch\perceiver_pytorch\mixed_latents.py

# import the required libraries
import torch
from torch import nn, einsum
import torch.nn.functional as F

# einops helpers
from einops import rearrange, repeat

# helpers and building blocks from perceiver_pytorch
from perceiver_pytorch.perceiver_pytorch import exists, default, cache_fn, fourier_encode, PreNorm, FeedForward, Attention

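# latent mixer: kernel-size-1 convolutions whose channels are the latents, so each feature is mixed across all latents (token mixing)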
def Mixer(seq_len, mult = 4, dropout = 0.):
    return nn.Sequential(
        nn.Conv1d(seq_len, seq_len * mult, 1),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Conv1d(seq_len * mult, seq_len, 1)
    )

# main Perceiver class
class Perceiver(nn.Module):
    def __init__(
        self,
        *,
        num_freq_bands,
        depth,
        max_freq,
        input_channels = 3,
        input_axis = 2,
        num_latents = 512,
        latent_dim = 512,
        cross_heads = 1,
        latent_heads = 8,
        cross_dim_head = 64,
        latent_dim_head = 64,
        num_classes = 1000,
        attn_dropout = 0.,
        ff_dropout = 0.,
        weight_tie_layers = False,
        **kwargs
    ):
        super().__init__()
        self.input_axis = input_axis
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands

        # input dimension: (2 * num_freq_bands + 1) Fourier features per axis plus the data channels
        input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels

        # learned latent array
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

        # block factories; the latent self-attention slot is filled by the Mixer
        get_cross_attn  = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim)
        get_latent_attn = lambda: PreNorm(latent_dim, Mixer(num_latents, dropout = ff_dropout))
        get_cross_ff    = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
        get_latent_ff   = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))

        # wrap the factories with cache_fn so layers can share weights
        get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))

        # build the layers
        self.layers = nn.ModuleList([])
        for i in range(depth):
            should_cache = i > 0 and weight_tie_layers
            cache_args = {'_cache': should_cache}

            self.layers.append(nn.ModuleList([
                get_cross_attn(**cache_args),
                get_cross_ff(**cache_args),
                get_latent_attn(**cache_args),
                get_latent_ff(**cache_args)
            ]))

        # classification head
        self.to_logits = nn.Sequential(
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, num_classes)
        )

    def forward(self, data, mask = None):
        # batch size, spatial axes and device
        b, *axis, _, device = *data.shape, data.device
        assert len(axis) == self.input_axis, 'input data must have the right number of axis'

        # Fourier-encode positions in [-1, 1] along every axis
        axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
        pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
        enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
        enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
        enc_pos = repeat(enc_pos, '... -> b ...', b = b)

        # concatenate the encoded positions to the data and flatten the axes
        data = torch.cat((data, enc_pos), dim = -1)
        data = rearrange(data, 'b ... d -> b (...) d')

        # repeat the latents for each batch element
        x = repeat(self.latents, 'n d -> b n d', b = b)

        # run each layer (residual connections)
        for cross_attn, cross_ff, latent_attn, latent_ff in self.layers:
            x = cross_attn(x, context = data, mask = mask) + x
            x = cross_ff(x) + x
            x = latent_attn(x) + x
            x = latent_ff(x) + x

        # mean-pool over the latents and return the class logits
        x = x.mean(dim = -2)
        return self.to_logits(x)
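Every Perceiver variant here appends Fourier features of each pixel's coordinates to the data, which is where `input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels` comes from. The standard-library sketch below mirrors `fourier_encode` for a single coordinate. It assumes positions are already scaled to [-1, 1], as the `linspace(-1., 1., ...)` grid does. The helper names are hypothetical.

```python
import math

def linspace(start, end, steps):
    if steps == 1:
        return [start]
    return [start + (end - start) * i / (steps - 1) for i in range(steps)]

def fourier_encode_scalar(x, max_freq, num_bands):
    # frequencies evenly spaced between 1 and max_freq / 2, as in fourier_encode
    scales = linspace(1., max_freq / 2, num_bands)
    sines = [math.sin(x * s * math.pi) for s in scales]
    cosines = [math.cos(x * s * math.pi) for s in scales]
    return sines + cosines + [x]            # 2 * num_bands + 1 features

def encode_position(coords, max_freq, num_bands):
    # one encoding per axis, concatenated (the '... n d -> ... (n d)' rearrange)
    feats = []
    for c in coords:
        feats.extend(fourier_encode_scalar(c, max_freq, num_bands))
    return feats

def perceiver_input_dim(input_channels, input_axis, num_freq_bands):
    return input_axis * (2 * num_freq_bands + 1) + input_channels

enc = encode_position([-1.0, 0.5], max_freq = 10., num_bands = 6)
print(len(enc))                         # 26 = 2 axes * 13 features
print(perceiver_input_dim(3, 2, 6))     # 29 = 26 + 3 RGB channels
```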

.\lucidrains\perceiver-pytorch\perceiver_pytorch\perceiver_io.py

# import pi and log from math
# import wraps from functools
# import torch and its nn, einsum and functional modules
# import rearrange and repeat from einops
from math import pi, log
from functools import wraps

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat

# helpers

# check whether a value is not None
def exists(val):
    return val is not None

# return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# cache a factory's result (used to share weights across layers)
def cache_fn(f):
    cache = None
    @wraps(f)
    def cached_fn(*args, _cache = True, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        nonlocal cache
        if cache is not None:
            return cache
        cache = f(*args, **kwargs)
        return cache
    return cached_fn

# structured dropout, more effective than traditional attention dropouts

# drop a random subset of sequence positions, respecting the mask
def dropout_seq(seq, mask, dropout):
    b, n, *_, device = *seq.shape, seq.device
    logits = torch.randn(b, n, device = device)

    if exists(mask):
        logits = logits.masked_fill(~mask, -torch.finfo(logits.dtype).max)

    keep_prob = 1. - dropout
    num_keep = max(1,  int(keep_prob * n))
    keep_indices = logits.topk(num_keep, dim = 1).indices

    batch_indices = torch.arange(b, device = device)
    batch_indices = rearrange(batch_indices, 'b -> b 1')

    seq = seq[batch_indices, keep_indices]

    if exists(mask):
        seq_counts = mask.sum(dim = -1)
        seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
        keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')

        mask = mask[batch_indices, keep_indices] & keep_mask

    return seq, mask

# helper classes

# pre-layer normalization (also normalizes the context when context_dim is given)
class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim = None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)

        if exists(self.norm_context):
            context = kwargs['context']
            normed_context = self.norm_context(context)
            kwargs.update(context = normed_context)

        return self.fn(x, **kwargs)

# GEGLU activation
class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

# feedforward network
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

# attention
class Attention(nn.Module):
    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context = None, mask = None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h = h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim = -1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)

# main class

class PerceiverIO(nn.Module):
    # constructor: set up the model
    def __init__(
        self,
        *,
        depth,
        dim,
        queries_dim,
        logits_dim = None,
        num_latents = 512,
        latent_dim = 512,
        cross_heads = 1,
        latent_heads = 8,
        cross_dim_head = 64,
        latent_dim_head = 64,
        weight_tie_layers = False,
        decoder_ff = False,
        seq_dropout_prob = 0.
    ):
        # call the parent constructor
        super().__init__()
        # probability of dropping input positions (structured dropout)
        self.seq_dropout_prob = seq_dropout_prob

        # learned latent array
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

        # cross-attention (latents attend to the data) and its feedforward
        self.cross_attend_blocks = nn.ModuleList([
            PreNorm(latent_dim, Attention(latent_dim, dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = dim),
            PreNorm(latent_dim, FeedForward(latent_dim))
        ])

        # factories for latent self-attention and feedforward
        get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head))
        get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
        # wrap the factories with cache_fn so layers can share weights
        get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))

        # latent layers
        self.layers = nn.ModuleList([])
        cache_args = {'_cache': weight_tie_layers}

        # stack depth self-attention / feedforward pairs
        for i in range(depth):
            self.layers.append(nn.ModuleList([
                get_latent_attn(**cache_args),
                get_latent_ff(**cache_args)
            ]))

        # decoder: queries cross-attend to the latents, with an optional feedforward
        self.decoder_cross_attn = PreNorm(queries_dim, Attention(queries_dim, latent_dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = latent_dim)
        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None

        # output projection (identity when logits_dim is not given)
        self.to_logits = nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity()

    # 前向传播函数
    def forward(
        self,
        data,
        mask = None,
        queries = None
    ):
        # batch size and device
        b, *_, device = *data.shape, data.device

        # repeat the latents for each batch element
        x = repeat(self.latents, 'n d -> b n d', b = b)

        # cross-attention block and its feedforward
        cross_attn, cross_ff = self.cross_attend_blocks

        # structured dropout of input positions (training only)
        if self.training and self.seq_dropout_prob > 0.:
            data, mask = dropout_seq(data, mask, self.seq_dropout_prob)

        # latents cross-attend to the data
        x = cross_attn(x, context = data, mask = mask) + x
        x = cross_ff(x) + x

        # latent self-attention and feedforward layers
        for self_attn, self_ff in self.layers:
            x = self_attn(x) + x
            x = self_ff(x) + x

        # without queries, return the latents directly
        if not exists(queries):
            return x

        # make sure the queries have a batch dimension
        if queries.ndim == 2:
            queries = repeat(queries, 'n d -> b n d', b = b)

        # decoder queries cross-attend to the latents
        latents = self.decoder_cross_attn(queries, context = x)

        # optional decoder feedforward
        if exists(self.decoder_ff):
            latents = latents + self.decoder_ff(latents)

        # final linear projection
        return self.to_logits(latents)

# Perceiver LM example

class PerceiverLM(nn.Module):
    def __init__(
        self,
        *,
        dim,  # model dimension
        num_tokens,  # vocabulary size
        max_seq_len,  # maximum sequence length
        **kwargs  # remaining arguments are passed to PerceiverIO
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)  # token embedding
        self.pos_emb = nn.Embedding(max_seq_len, dim)  # absolute positional embedding

        self.perceiver_io = PerceiverIO(  # the PerceiverIO backbone
            dim = dim,
            queries_dim = dim,
            logits_dim = num_tokens,
            **kwargs
        )

    def forward(
        self,
        x,  # input token ids
        mask = None  # optional mask, None by default
    ):
        n, device = x.shape[1], x.device  # sequence length and device
        x = self.token_emb(x)  # embed the tokens

        pos_emb = self.pos_emb(torch.arange(n, device = device))  # positional embeddings for positions 0..n-1
        pos_emb = rearrange(pos_emb, 'n d -> () n d')  # add a broadcastable batch dimension
        x = x + pos_emb  # add token and positional embeddings

        logits = self.perceiver_io(x, mask = mask, queries = x)  # the embedded sequence also serves as the decoder queries
        return logits  # per-position logits over the vocabulary
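`dropout_seq`, which PerceiverIO applies when `seq_dropout_prob > 0` in training, drops whole input positions rather than individual attention weights. Below is a pure-Python sketch of one batch row. It uses Gaussian scores like `torch.randn` and `-inf` in place of the negative `finfo` maximum. `dropout_seq_row` is a hypothetical name. As with `topk` indices, the kept positions come out in score order rather than sequence order. That is harmless here because cross-attention output does not depend on the order of its context tokens.

```python
import math
import random

def dropout_seq_row(seq, mask, dropout, rng):
    """One batch row of structured dropout (sketch): score positions randomly,
    push padded positions to the bottom, keep the top num_keep, and rebuild
    the mask so ceil(valid * keep_prob) real tokens stay marked valid."""
    n = len(seq)
    scores = [rng.gauss(0., 1.) for _ in range(n)]
    if mask is not None:
        scores = [s if m else -math.inf for s, m in zip(scores, mask)]
    keep_prob = 1. - dropout
    num_keep = max(1, int(keep_prob * n))
    # like torch.topk: indices of the largest scores, in descending score order
    keep_indices = sorted(range(n), key = lambda i: scores[i], reverse = True)[:num_keep]
    new_seq = [seq[i] for i in keep_indices]
    if mask is None:
        return new_seq, None
    seq_keep_count = math.ceil(sum(mask) * keep_prob)
    new_mask = [mask[i] and j < seq_keep_count for j, i in enumerate(keep_indices)]
    return new_seq, new_mask

rng = random.Random(0)
tokens = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
mask = [True] * 6 + [False] * 2          # the last two positions are padding
kept, kept_mask = dropout_seq_row(tokens, mask, dropout = 0.5, rng = rng)
print(len(kept))        # 4 = int(0.5 * 8)
print(sum(kept_mask))   # 3 = ceil(0.5 * 6) real tokens marked valid
```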

.\lucidrains\perceiver-pytorch\perceiver_pytorch\perceiver_pytorch.py

# import pi and log from math
# import the wraps decorator from functools
# import torch and related modules
# import nn and einsum from torch
# import torch.nn.functional as F
# import rearrange and repeat from einops
# import the Reduce layer from einops.layers.torch
from math import pi, log
from functools import wraps

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Reduce

# helpers

# check whether a value is not None
def exists(val):
    return val is not None

# return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# decorator caching a factory's result, optionally per key (used for weight tying)
def cache_fn(f):
    cache = dict()
    @wraps(f)
    def cached_fn(*args, _cache = True, key = None, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        nonlocal cache
        if key in cache:
            return cache[key]
        result = f(*args, **kwargs)
        cache[key] = result
        return result
    return cached_fn

# Fourier-encode x: sin/cos features at num_bands frequencies, plus the raw value
def fourier_encode(x, max_freq, num_bands = 4):
    x = x.unsqueeze(-1)
    device, dtype, orig_x = x.device, x.dtype, x

    scales = torch.linspace(1., max_freq / 2, num_bands, device = device, dtype = dtype)
    scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis]

    x = x * scales * pi
    x = torch.cat([x.sin(), x.cos()], dim = -1)
    x = torch.cat((x, orig_x), dim = -1)
    return x

# helper classes

# pre-layer normalization (also normalizes the context when context_dim is given)
class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim = None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)

        if exists(self.norm_context):
            context = kwargs['context']
            normed_context = self.norm_context(context)
            kwargs.update(context = normed_context)

        return self.fn(x, **kwargs)

# GEGLU activation
class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

# feedforward network
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# attention
class Attention(nn.Module):
    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context = None, mask = None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h = h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention weights over the keys
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)
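Take one query and one head. Before the softmax, `Attention.forward` fills the masked key positions with the most negative finite value, so those positions end up with (numerically) zero weight. The stdlib-only sketch below computes that single row of scaled dot-product attention. `attend` and its list-based inputs are hypothetical:

```python
import math
import sys

def attend(q, keys, values, mask):
    # scaled dot-product scores, as in einsum('b i d, b j d -> b i j') * dim_head ** -0.5
    scale = len(q) ** -0.5
    sims = [sum(a * b for a, b in zip(q, k)) * scale for k in keys]
    # masked positions get -max float, mirroring sim.masked_fill_(~mask, -finfo.max)
    sims = [s if m else -sys.float_info.max for s, m in zip(sims, mask)]
    # numerically stable softmax over the keys
    top = max(sims)
    exps = [math.exp(s - top) for s in sims]
    total = sum(exps)
    attn = [e / total for e in exps]
    # weighted sum of the value vectors
    out = [sum(w * v[d] for w, v in zip(attn, values)) for d in range(len(values[0]))]
    return attn, out

attn, out = attend(
    q = [1.0, 0.0],
    keys = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]],
    values = [[1.0], [2.0], [100.0]],
    mask = [True, True, False],   # the third key is masked out
)
print(attn[2])  # 0.0
```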

# main class

class Perceiver(nn.Module):
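    # constructor: builds the latent array, the cross-/self-attention blocks and the classifier head
    def __init__(
        self,
        *,
        num_freq_bands,               # number of Fourier frequency bands
        depth,                        # number of (cross-attention -> self-attention) stages
        max_freq,                     # maximum frequency of the Fourier encoding
        input_channels = 3,           # channels per input token, default 3
        input_axis = 2,               # number of input axes, default 2
        num_latents = 512,            # number of latent vectors, default 512
        latent_dim = 512,             # dimension of each latent, default 512
        cross_heads = 1,              # heads in cross-attention, default 1
        latent_heads = 8,             # heads in latent self-attention, default 8
        cross_dim_head = 64,          # dimension per cross-attention head, default 64
        latent_dim_head = 64,         # dimension per latent self-attention head, default 64
        num_classes = 1000,           # number of output classes, default 1000
        attn_dropout = 0.,            # attention dropout, default 0
        ff_dropout = 0.,              # feed-forward dropout, default 0
        weight_tie_layers = False,    # share weights across all stages after the first, default False
        fourier_encode_data = True,   # append Fourier position features to the input, default True
        self_per_cross_attn = 1,      # self-attention blocks per cross-attention block, default 1
        final_classifier_head = True  # mean-pool and project to num_classes at the end, default True
    ):
        """The shape of the final attention mechanism will be: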
        depth * (cross attention -> self_per_cross_attn * self attention)

        Args:
          num_freq_bands: Number of freq bands, with original value (2 * K + 1)
          depth: Depth of net.
          max_freq: Maximum frequency, hyperparameter depending on how
              fine the data is.
          freq_base: Base for the frequency
          input_channels: Number of channels for each token of the input.
          input_axis: Number of axes for input data (2 for images, 3 for video)
          num_latents: Number of latents, or induced set points, or centroids.
              Different papers giving it different names.
          latent_dim: Latent dimension.
          cross_heads: Number of heads for cross attention. Paper said 1.
          latent_heads: Number of heads for latent self attention, 8.
          cross_dim_head: Number of dimensions per cross attention head.
          latent_dim_head: Number of dimensions per latent self attention head.
          num_classes: Output number of classes.
          attn_dropout: Attention dropout
          ff_dropout: Feedforward dropout
          weight_tie_layers: Whether to weight tie layers (optional).
          fourier_encode_data: Whether to auto-fourier encode the data, using
              the input_axis given. defaults to True, but can be turned off
              if you are fourier encoding the data yourself.
          self_per_cross_attn: Number of self attention blocks per cross attn.
          final_classifier_head: mean pool and project embeddings to number of classes (num_classes) at the end
        """
        super().__init__()
        self.input_axis = input_axis
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands

        self.fourier_encode_data = fourier_encode_data
        fourier_channels = (input_axis * ((num_freq_bands * 2) + 1)) if fourier_encode_data else 0
        input_dim = fourier_channels + input_channels

        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

        get_cross_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim)
        get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
        get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout))
        get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))

        get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])
        for i in range(depth):
            should_cache = i > 0 and weight_tie_layers
            cache_args = {'_cache': should_cache}

            self_attns = nn.ModuleList([])

            for block_ind in range(self_per_cross_attn):
                self_attns.append(nn.ModuleList([
                    get_latent_attn(**cache_args, key = block_ind),
                    get_latent_ff(**cache_args, key = block_ind)
                ]))

            self.layers.append(nn.ModuleList([
                get_cross_attn(**cache_args),
                get_cross_ff(**cache_args),
                self_attns
            ]))

        self.to_logits = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, num_classes)
        ) if final_classifier_head else nn.Identity()

    def forward(
        self,
        data,
        mask = None,
        return_embeddings = False
        ):
        # unpack the shape: batch size b, the spatial axes, and the (ignored) channel dimension
        b, *axis, _, device, dtype = *data.shape, data.device, data.dtype
        # the number of spatial axes must equal input_axis
        assert len(axis) == self.input_axis, 'input data must have the right number of axes'

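        if self.fourier_encode_data:
            # Fourier-encode the position of every element along each axis,
            # with positions scaled to the range [-1, 1]

            # evenly spaced positions along each axis
            axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps=size, device=device, dtype=dtype), axis))
            # combine the per-axis positions into a grid of shape (*axis, input_axis)
            pos = torch.stack(torch.meshgrid(*axis_pos, indexing='ij'), dim=-1)
            # Fourier-encode the positions: (*axis, input_axis, 2 * num_freq_bands + 1)
            enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
            # merge the last two dimensions into a single feature dimension
            enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
            # repeat the encoding for every item in the batch
            enc_pos = repeat(enc_pos, '... -> b ...', b=b)

            # concatenate the position features onto the data channels
            data = torch.cat((data, enc_pos), dim=-1)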

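        # flatten all spatial axes into one sequence axis: (b, n, d)
        data = rearrange(data, 'b ... d -> b (...) d')

        # repeat the learned latents for every item in the batch
        x = repeat(self.latents, 'n d -> b n d', b=b)

        # run each stage
        for cross_attn, cross_ff, self_attns in self.layers:
            # cross-attention from the latents to the data, then feed-forward (both residual)
            x = cross_attn(x, context=data, mask=mask) + x
            x = cross_ff(x) + x

            # latent self-attention and feed-forward blocks
            for self_attn, self_ff in self_attns:
                x = self_attn(x) + x
                x = self_ff(x) + x

        # optionally return the latent embeddings
        if return_embeddings:
            return x

        # project to class logits
        return self.to_logits(x)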
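The forward pass takes an input of shape `(b, *axis, input_channels)` and turns it into a sequence of `prod(axis)` tokens. With Fourier encoding on, each token carries `input_channels + input_axis * (2 * num_freq_bands + 1)` features. The bookkeeping sketch below applies this to the 224 x 224 RGB example from the README. The helper name `perceiver_input_shape` is hypothetical:

```python
import math

def perceiver_input_shape(shape, num_freq_bands, fourier_encode_data = True):
    # shape is (batch, *spatial_axes, channels), as expected by Perceiver.forward
    b, *axis, channels = shape
    fourier_channels = len(axis) * (2 * num_freq_bands + 1) if fourier_encode_data else 0
    # after torch.cat along the channel dim and rearrange('b ... d -> b (...) d')
    return (b, math.prod(axis), channels + fourier_channels)

print(perceiver_input_shape((1, 224, 224, 3), num_freq_bands = 6))  # (1, 50176, 29)
```

Cross-attention then compares the `num_latents` latent queries against these 50176 tokens. Its cost therefore grows linearly with the number of input tokens, and the self-attention blocks only ever operate on the latents.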

.\lucidrains\perceiver-pytorch\perceiver_pytorch\__init__.py

# import the Perceiver class from perceiver_pytorch.perceiver_pytorch
from perceiver_pytorch.perceiver_pytorch import Perceiver
# import the PerceiverIO and PerceiverLM classes from perceiver_pytorch.perceiver_io
from perceiver_pytorch.perceiver_io import PerceiverIO, PerceiverLM

Perceiver - Pytorch

Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch

Yannic Kilcher explanation!

Install

$ pip install perceiver-pytorch

Usage

import torch
from perceiver_pytorch import Perceiver

model = Perceiver(
    input_channels = 3,          # number of channels for each token of the input
    input_axis = 2,              # number of axis for input data (2 for images, 3 for video)
    num_freq_bands = 6,          # number of freq bands, with original value (2 * K + 1)
    max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
    depth = 6,                   # depth of net. The shape of the final attention mechanism will be:
                                 #   depth * (cross attention -> self_per_cross_attn * self attention)
    num_latents = 256,           # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,         # number of dimensions per cross attention head
    latent_dim_head = 64,        # number of dimensions per latent self attention head
    num_classes = 1000,          # output number of classes
    attn_dropout = 0.,
    ff_dropout = 0.,
    weight_tie_layers = False,   # whether to weight tie layers (optional, as indicated in the diagram)
    fourier_encode_data = True,  # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself
    self_per_cross_attn = 2      # number of self attention blocks per cross attention
)

img = torch.randn(1, 224, 224, 3) # 1 imagenet image, pixelized

model(img) # (1, 1000)
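With `depth = 6` and `self_per_cross_attn = 2`, the model runs 6 stages, each made of one cross-attention block followed by two latent self-attention blocks (every block is paired with a feed-forward). Look at the `Perceiver.__init__` listing earlier in this article: with `weight_tie_layers`, factories are only cached for stages `i > 0`. So the first stage keeps its own weights and every later stage shares a second set. The stdlib-only sketch below reproduces that bookkeeping. `layer_schedule` is a hypothetical helper that labels which parameter set each block uses:

```python
def layer_schedule(depth, self_per_cross_attn, weight_tie_layers = False):
    # returns (block_type, parameter_owner) pairs in execution order
    schedule = []
    for i in range(depth):
        tied = weight_tie_layers and i > 0
        # tied stages reuse the cached instance per key; untied stages get fresh weights
        owner = 'shared' if tied else f'stage{i}'
        schedule.append(('cross_attn', owner))
        schedule.append(('cross_ff', owner))
        for block_ind in range(self_per_cross_attn):
            schedule.append((f'self_attn{block_ind}', owner))
            schedule.append((f'self_ff{block_ind}', owner))
    return schedule

untied = layer_schedule(6, 2)
tied = layer_schedule(6, 2, weight_tie_layers = True)
print(len(untied))                       # 36 blocks: 6 * (2 + 2 * 2)
print(len(set(untied)), len(set(tied)))  # 36 12
```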

For the backbone of Perceiver IO, the follow-up paper that allows a flexible output sequence length, just import PerceiverIO instead

import torch
from perceiver_pytorch import PerceiverIO

model = PerceiverIO(
    dim = 32,                    # dimension of sequence to be encoded
    queries_dim = 32,            # dimension of decoder queries
    logits_dim = 100,            # dimension of final logits
    depth = 6,                   # depth of net
    num_latents = 256,           # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,         # number of dimensions per cross attention head
    latent_dim_head = 64,        # number of dimensions per latent self attention head
    weight_tie_layers = False,   # whether to weight tie layers (optional, as indicated in the diagram)
    seq_dropout_prob = 0.2       # fraction of the tokens from the input sequence to dropout (structured dropout, for saving compute and regularizing effects)
)

seq = torch.randn(1, 512, 32)
queries = torch.randn(128, 32)

logits = model(seq, queries = queries) # (1, 128, 100) - (batch, decoder seq, logits dim)

As an example, here is PerceiverIO used as a language model, via PerceiverLM

import torch
from perceiver_pytorch import PerceiverLM

model = PerceiverLM(
    num_tokens = 20000,          # number of tokens
    dim = 32,                    # dimension of sequence to be encoded
    depth = 6,                   # depth of net
    max_seq_len = 2048,          # maximum sequence length
    num_latents = 256,           # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,         # number of dimensions per cross attention head
    latent_dim_head = 64,        # number of dimensions per latent self attention head
    weight_tie_layers = False    # whether to weight tie layers (optional, as indicated in the diagram)
)

seq = torch.randint(0, 20000, (1, 512))
mask = torch.ones(1, 512).bool()

logits = model(seq, mask = mask) # (1, 512, 20000)

Experimental

I have also included a version of Perceiver that adds bottom-up attention (in addition to top-down), using the same scheme the original Set Transformer paper presents as the Induced Set Attention Block.

You simply have to change the above import to

from perceiver_pytorch.experimental import Perceiver

Citations

@misc{jaegle2021perceiver,
    title   = {Perceiver: General Perception with Iterative Attention},
    author  = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira},
    year    = {2021},
    eprint  = {2103.03206},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{jaegle2021perceiverio,
    title   = {Perceiver IO: A General Architecture for Structured Inputs & Outputs},
    author  = {Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier Hénaff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and João Carreira},
    year    = {2021},
    eprint  = {2107.14795},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\perceiver-pytorch\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'perceiver-pytorch',  # package name
  packages = find_packages(),  # include all packages
  version = '0.8.8',  # version
  license='MIT',  # license
  description = 'Perceiver - Pytorch',  # short description
  long_description_content_type = 'text/markdown',  # format of the long description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/perceiver-pytorch',  # project URL
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'transformer',
    'attention mechanism'
  ],
  install_requires=[  # runtime dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[  # PyPI trove classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

Performer with Deepspeed for Enwik8

Deepspeed is the framework Microsoft used to train what was, at the time of writing, the world's largest attention model (17 billion parameters). They have open-sourced it, and it works with Performer Pytorch, as this example shows!

  1. First install Deepspeed following instructions from their official repository https://github.com/microsoft/DeepSpeed

  2. Run the following command in this folder

$ deepspeed train.py --deepspeed --deepspeed_config ds_config.json
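The repository's `ds_config.json` is not reproduced here. `train.py` passes `model.parameters()` to `deepspeed.initialize` but never builds an optimizer itself, so the config needs to define one. The snippet below builds a hypothetical minimal config in Python and prints it as JSON, ready to save as `ds_config.json`. `train_batch_size`, `gradient_accumulation_steps` and `optimizer` are standard DeepSpeed config fields. The values are assumptions and are not the repository's settings:

```python
import json

# hypothetical minimal DeepSpeed config; all values are illustrative assumptions
ds_config = {
    "train_batch_size": 32,             # global batch size across all GPUs
    "gradient_accumulation_steps": 1,   # micro-batches accumulated per optimizer step
    "optimizer": {                      # DeepSpeed builds this optimizer over model_parameters
        "type": "Adam",
        "params": {"lr": 1e-4},
    },
}

config_text = json.dumps(ds_config, indent = 2)
print(config_text)
```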

.\lucidrains\performer-pytorch\examples\enwik8_deepspeed\train.py

import deepspeed  # DeepSpeed training engine

from performer_pytorch import PerformerLM  # Performer language model
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper  # adds the next-token loss and sampling

import argparse  # command-line argument parsing
import random  # random choice of validation samples
import tqdm  # progress bars
import gzip  # reading the gzip-compressed data
import numpy as np  # array handling
import torch  # PyTorch
import torch.optim as optim  # optimizers
from torch.nn import functional as F  # functional neural-network ops
from torch.utils.data import DataLoader, Dataset  # dataset utilities

def add_argument():  # define and parse the command-line arguments
    parser=argparse.ArgumentParser(description='enwik8')  # parser described as 'enwik8'

    parser.add_argument('--with_cuda', default=False, action='store_true',  # boolean flag, False unless given
                        help='use CPU in case there\'s no GPU support')
    parser.add_argument('--use_ema', default=False, action='store_true',  # boolean flag, False unless given
                        help='whether to use exponential moving average')
    parser.add_argument('-b', '--batch_size', default=32, type=int,  # integer, default 32
                        help='mini-batch size (default: 32)')
    parser.add_argument('-e', '--epochs', default=30, type=int,  # integer, default 30
                        help='number of total epochs (default: 30)')
    parser.add_argument('--local_rank', type=int, default=-1,  # set by the distributed launcher
                       help='local rank passed from distributed launcher')

    parser = deepspeed.add_config_arguments(parser)  # add DeepSpeed's own flags (--deepspeed, --deepspeed_config)
    args=parser.parse_args()  # parse the command line
    return args  # return the parsed arguments

# constants

EPOCHS = 20  # number of training epochs
VALIDATE_EVERY  = 100  # run validation every 100 steps
GENERATE_EVERY  = 500  # sample text every 500 steps
GENERATE_LENGTH = 512  # number of tokens generated per sample
SEQ_LEN = 1024  # training sequence length

# helpers

def decode_token(token):  # map one byte token to a character
    return str(chr(max(32, token)))  # control characters (< 32) become a space

def decode_tokens(tokens):  # decode a token sequence into a string
    return ''.join(list(map(decode_token, tokens)))  # decode each token and join the characters
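enwik8 is modeled byte by byte (`num_tokens = 256`), so decoding is a single `chr` per token. Clamping to 32 stops newlines and other control bytes from scrambling the printed sample. A quick stdlib-only check of the helpers:

```python
def decode_token(token):
    # same mapping as above: bytes below 32 (control characters) become a space
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return ''.join(map(decode_token, tokens))

print(repr(decode_tokens([72, 105, 10, 33])))  # 'Hi !'
```

Bytes of 128 and above map to single Latin-1 code points. Multi-byte UTF-8 characters in a sample therefore print as several odd characters instead of being decoded.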

# instantiate model

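model = PerformerLM(  # create the Performer language model
    num_tokens = 256,  # vocabulary size: one token per byte
    dim = 512,  # model dimension
    depth = 6,  # number of layers
    max_seq_len = SEQ_LEN,  # maximum sequence length
    heads = 8,  # attention heads
    causal = True,  # causal (autoregressive) attention
    reversible = True,  # reversible layers to save activation memory
    nb_features = 256,  # number of random features for the FAVOR+ attention approximation
    use_scalenorm = True,  # use ScaleNorm instead of LayerNorm
)

model = AutoregressiveWrapper(model)  # wrap the model for autoregressive training and generation
model.cuda()  # move the model to the GPU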

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:  # open the compressed enwik8 dump
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()  # first 95M bytes as a writable uint8 array
    trX, vaX = np.split(X, [int(90e6)])  # first 90M bytes for training, the remaining 5M for validation
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)  # convert to PyTorch tensors

class TextSamplerDataset(Dataset):  # dataset that samples random windows of text
    def __init__(self, data, seq_len):  # store the data and the window length
        super().__init__()  # initialize the parent Dataset
        self.data = data  # the token tensor
        self.seq_len = seq_len  # window length

    def __getitem__(self, index):  # the index is ignored; a random window is returned
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))  # random start offset
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()  # seq_len + 1 tokens: inputs plus shifted targets
        return full_seq  # return the window

    def __len__(self):  # number of samples per epoch
        return self.data.size(0) // self.seq_len  # number of non-overlapping windows in the data

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)  # training dataset
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)  # validation dataset
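Each item is a random window of `seq_len + 1` tokens. `AutoregressiveWrapper` later uses the first `seq_len` of them as inputs and the last `seq_len` as targets. `__len__` only sets how many samples make up an epoch, and `index` is ignored. The stdlib-only sketch below shows the sampling rule using a hypothetical `sample_window` helper and a seeded RNG:

```python
import random

def sample_window(data, seq_len, rng):
    # start drawn from [0, len(data) - seq_len - 1), like torch.randint(0, high, (1,))
    start = rng.randrange(0, len(data) - seq_len - 1)
    # seq_len + 1 tokens: inputs are window[:-1], targets are window[1:]
    return data[start : start + seq_len + 1]

rng = random.Random(0)
data = list(range(100))
window = sample_window(data, seq_len = 8, rng = rng)
print(len(window))  # 9
```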

# setup deepspeed

cmd_args = add_argument()  # parse the command-line arguments
model_engine, optimizer, trainloader, _ = deepspeed.initialize(args=cmd_args, model=model, model_parameters=model.parameters(),  training_data=train_dataset)  # let DeepSpeed build the engine, optimizer and data loader

# training

for _ in range(EPOCHS):  # train for EPOCHS epochs
    for i, data in enumerate(trainloader):  # iterate over the training batches
        model_engine.train()  # training mode
        data = data.to(model_engine.local_rank)  # move the batch to this process's GPU
        loss = model_engine(data, return_loss = True)  # forward pass returns the loss
        model_engine.backward(loss)  # backward pass
        model_engine.step()  # optimizer step
        print(loss.item())  # print the training loss

        if model_engine.local_rank != 0:  # only the main process validates and generates
            continue

        if i % VALIDATE_EVERY == 0:  # validate every VALIDATE_EVERY steps
            model.eval()  # evaluation mode
            with torch.no_grad():  # no gradients needed
                inp = random.choice(val_dataset)[:-1]  # random validation sequence
                loss = model(inp[None, :].cuda(), return_loss = True)  # validation loss
                print(f'validation loss: {loss.item()}')  # print the validation loss

        if i % GENERATE_EVERY == 0:  # sample text every GENERATE_EVERY steps
            model.eval()  # evaluation mode
            inp = random.choice(val_dataset)[:-1]  # random prime sequence from the validation set
            prime = decode_tokens(inp)  # decode the prime to text
            print(f'{prime} \n\n {"*" * 100}')  # print the prime and a separator line

            sample = model.generate(inp.cuda(), GENERATE_LENGTH)  # generate GENERATE_LENGTH tokens
            output_str = decode_tokens(sample)  # decode the generated tokens
            print(output_str)  # print the generated text

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\performer-pytorch\examples\enwik8_simple\train.py

# imports
from performer_pytorch import PerformerLM
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 2048
SEQ_LEN = 4096

# helpers
def cycle(loader):
    # iterate over the loader forever
    while True:
        for data in loader:
            yield data

def decode_token(token):
    # map one byte token to a character; control characters (< 32) become a space
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    # decode a token sequence into a string
    return ''.join(list(map(decode_token, tokens)))

# instantiate the model
model = PerformerLM(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    max_seq_len = SEQ_LEN,
    heads = 8,
    causal = True,
    reversible = True,
    nb_features = 256,
    use_scalenorm = True,
    shift_tokens = True,
    local_attn_heads = (8, 8, 8, 6, 4, 2)
)

model = AutoregressiveWrapper(model)
model.cuda()

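# prepare the enwik8 data: first 90M bytes for training, next 5M for validation
with gzip.open('./data/enwik8.gz') as file:
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()  # read the bytes into a writable uint8 array
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)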

# dataset that samples random windows of seq_len + 1 tokens, returned on the GPU
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer and gradient scaler for mixed-precision training
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = GradScaler()

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        with autocast():
            loss = model(next(train_loader), return_loss = True)
        scaler.scale(loss).backward()

    print(f'training loss: {loss.item()}')

    scaler.unscale_(optim)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    scaler.step(optim)
    scaler.update()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0 and i != 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)
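`model.generate` (from `performer_pytorch/autoregressive_wrapper.py`, listed below) repeats the same few steps. It feeds the last `max_seq_len` tokens through the network and divides the final-position logits by `temperature`. It then samples one token and appends it. The stdlib-only sketch of that loop below uses a hypothetical toy `next_logits` function in place of the network and leaves out the top-k/top-p filtering step:

```python
import math
import random

def sample_loop(next_logits, start_tokens, seq_len, max_seq_len, temperature = 1.0, rng = random):
    out = list(start_tokens)
    for _ in range(seq_len):
        context = out[-max_seq_len:]                  # crop to the model's window
        logits = next_logits(context)                 # logits for the next token only
        scaled = [x / temperature for x in logits]    # temperature scaling
        top = max(scaled)
        weights = [math.exp(s - top) for s in scaled]  # unnormalized softmax
        token = rng.choices(range(len(logits)), weights = weights)[0]  # one multinomial draw
        out.append(token)
    return out[len(start_tokens):]                    # drop the prime, like out[:, t:]

# hypothetical toy "model" over a 4-token vocabulary that strongly prefers (last + 1) % 4
def next_logits(context):
    return [10.0 if i == (context[-1] + 1) % 4 else 0.0 for i in range(4)]

sample = sample_loop(next_logits, [0], seq_len = 6, max_seq_len = 3, rng = random.Random(0))
print(sample)
```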

.\lucidrains\performer-pytorch\examples\toy_tasks\enc_dec_copy.py

# imports
import tqdm
import torch
import torch.optim as optim
from performer_pytorch import PerformerEncDec
from torch.cuda.amp import autocast, GradScaler

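# constants
NUM_BATCHES = int(1e5)  # total number of training batches
BATCH_SIZE = 32  # batch size
LEARNING_RATE = 1e-4  # learning rate
GENERATE_EVERY  = 100  # generate a sample every this many batches
NUM_TOKENS = 16 + 2  # vocabulary: 16 symbols plus the reserved ids 0 and 1 (1 is the start token)
ENC_SEQ_LEN = 32  # encoder sequence length
DEC_SEQ_LEN = 64 + 1  # decoder sequence length: start token plus two copies of the source

# helper that yields random copy-task batches forever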
def cycle():
    while True:
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        tgt = torch.cat((prefix, src, src), 1)
        src_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool().cuda()
        tgt_mask = torch.ones(BATCH_SIZE, tgt.shape[1]).bool().cuda()
        yield (src, tgt, src_mask, tgt_mask)
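The copy-task target is a start token (`1`) followed by the source sequence twice. Its length is therefore `1 + 2 * ENC_SEQ_LEN`, which is where `DEC_SEQ_LEN = 64 + 1` comes from. Source symbols are drawn from `2..NUM_TOKENS-1`, which leaves ids `0` and `1` free. A stdlib-only sketch of one example, built with a hypothetical `make_copy_example` helper:

```python
import random

NUM_TOKENS = 16 + 2
ENC_SEQ_LEN = 32
DEC_SEQ_LEN = 64 + 1

def make_copy_example(rng):
    # source symbols in [2, NUM_TOKENS), like torch.randint(2, NUM_TOKENS, ...)
    src = [rng.randrange(2, NUM_TOKENS) for _ in range(ENC_SEQ_LEN)]
    # start token 1, then the source twice, like torch.cat((prefix, src, src), 1)
    tgt = [1] + src + src
    return src, tgt

src, tgt = make_copy_example(random.Random(0))
print(len(tgt) == DEC_SEQ_LEN)  # True
```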

# instantiate the model
model = PerformerEncDec(
    dim=512,
    enc_num_tokens=NUM_TOKENS,
    enc_depth=1,
    enc_heads=8,
    enc_max_seq_len=ENC_SEQ_LEN,
    enc_reversible=True,
    enc_feature_redraw_interval=1000,
    enc_nb_features = 64,
    dec_num_tokens=NUM_TOKENS,
    dec_depth=3,
    dec_heads=8,
    dec_max_seq_len=DEC_SEQ_LEN,
    dec_reversible=True,
    dec_feature_redraw_interval=1000,
    dec_nb_features=64
).cuda()

# optimizer and gradient scaler for mixed-precision training
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = GradScaler()

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask, tgt_mask = next(cycle())

    with autocast():
        loss = model(src, tgt, enc_mask=src_mask, dec_mask=tgt_mask)

    scaler.scale(loss).backward()
    print(f'{i}: {loss.item()}')

    scaler.step(optim)
    scaler.update()
    optim.zero_grad()

    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask, _ = next(cycle())
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = (torch.ones((1, 1)) * 1).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, enc_mask=src_mask)
        incorrects = (src != sample).sum()  # number of positions where the output differs from the input

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

.\lucidrains\performer-pytorch\examples\toy_tasks\enc_dec_copy_apex.py

# imports
import tqdm
import torch
import torch.optim as optim
from performer_pytorch import PerformerEncDec
from apex import amp

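# constants
NUM_BATCHES = int(1e5)  # total number of training batches
BATCH_SIZE = 32  # batch size
LEARNING_RATE = 1e-4  # learning rate
GENERATE_EVERY  = 100  # generate a sample every this many batches
NUM_TOKENS = 16 + 2  # vocabulary: 16 symbols plus the reserved ids 0 and 1 (1 is the start token)
ENC_SEQ_LEN = 32  # encoder sequence length
DEC_SEQ_LEN = 64 + 1  # decoder sequence length: start token plus two copies of the source

# helper that yields random copy-task batches forever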
def cycle():
    while True:
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        tgt = torch.cat((prefix, src, src), 1)
        src_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool().cuda()
        tgt_mask = torch.ones(BATCH_SIZE, tgt.shape[1]).bool().cuda()
        yield (src, tgt, src_mask, tgt_mask)

# instantiate the model
model = PerformerEncDec(
    dim=512,
    enc_num_tokens=NUM_TOKENS,
    enc_depth=1,
    enc_heads=8,
    enc_max_seq_len=ENC_SEQ_LEN,
    enc_reversible=True,
    enc_feature_redraw_interval=1000,
    enc_nb_features = 64,
    dec_num_tokens=NUM_TOKENS,
    dec_depth=3,
    dec_heads=8,
    dec_max_seq_len=DEC_SEQ_LEN,
    dec_reversible=True,
    dec_feature_redraw_interval=1000,
    dec_nb_features=64
).cuda()

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# mixed-precision training with NVIDIA Apex (opt level O1)
model, optim = amp.initialize(model, optim, opt_level = 'O1')

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    src, tgt, src_mask, tgt_mask = next(cycle())
    loss = model(src, tgt, enc_mask=src_mask, dec_mask=tgt_mask)

    with amp.scale_loss(loss, optim) as scaled_loss:
        scaled_loss.backward()

    print(f'{i}: {loss.item()}')
    optim.step()
    optim.zero_grad()

    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask, _ = next(cycle())
        src, src_mask = src[:1], src_mask[:1]
        start_tokens = (torch.ones((1, 1)) * 1).long().cuda()

        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, enc_mask=src_mask)
        incorrects = (src != sample).sum()  # number of positions where the output differs from the input

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

.\lucidrains\performer-pytorch\performer_pytorch\autoregressive_wrapper.py

from functools import partial
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

# return True if a value is not None
def exists(val):
    return val is not None

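# nucleus (top-p) filtering: keep the smallest set of highest-probability tokens whose cumulative probability exceeds 1 - thres, and set the rest to -inf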
def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)
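Note the convention: `thres` is the fraction to discard, so the nucleus mass is `1 - thres`. The default 0.9 keeps only the smallest set of top tokens whose probability mass exceeds 0.1, and the highest-scoring token is always kept. The stdlib-only sketch below applies the same rule to one row of logits. `top_p_filter` is a hypothetical name:

```python
import math

def top_p_filter(logits, thres = 0.9):
    # sort token ids by logit, highest first
    order = sorted(range(len(logits)), key = lambda i: logits[i], reverse = True)
    top = logits[order[0]]
    exps = [math.exp(logits[i] - top) for i in order]
    total = sum(exps)
    filtered = [float('-inf')] * len(logits)
    cum = 0.0
    for rank, i in enumerate(order):
        # stop once the cumulative probability *before* this token exceeds 1 - thres
        if rank > 0 and cum > 1 - thres:
            break
        filtered[i] = logits[i]
        cum += exps[rank] / total
    return filtered

print(top_p_filter([2.0, 1.0, 0.0, -1.0], thres = 0.3))  # [2.0, 1.0, -inf, -inf]
```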

# top-k filtering: keep the k = (1 - thres) * vocab_size largest logits and set the rest to -inf
def top_k(logits, thres = 0.9):
    k = max(int((1 - thres) * logits.shape[-1]), 1)  # keep at least one token
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
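`top_k` ties k to the vocabulary size. With the default `thres = 0.9` it keeps about 10% of the logits: 25 of 256 for enwik8 bytes, because `int()` truncates 25.6. Because of that truncation, computing k as `int((1 - thres) * vocab_size)` can give k = 0 for a very small vocabulary, leaving every logit at -inf. The sketch clamps k to at least 1. A stdlib-only version on one row (`top_k_filter` is a hypothetical name):

```python
def top_k_filter(logits, thres = 0.9):
    # k is a fraction (1 - thres) of the vocabulary, clamped to at least 1
    k = max(int((1 - thres) * len(logits)), 1)
    keep = set(sorted(range(len(logits)), key = lambda i: logits[i], reverse = True)[:k])
    # everything outside the top k becomes -inf, so softmax gives it zero probability
    return [x if i in keep else float('-inf') for i, x in enumerate(logits)]

print(top_k_filter([0.1, 3.0, -2.0, 1.5, 0.7, 2.2, -0.4, 0.0, 1.1, 0.9], thres = 0.5))
# [-inf, 3.0, -inf, 1.5, -inf, 2.2, -inf, -inf, 1.1, 0.9]
```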

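# repetition penalty: make tokens that already appear in ctx less likely, by a factor of theta
def repetition_penalty_fn(logits, ctx, theta=1.2):
    w = torch.ones(logits.shape[-1], dtype=torch.float, device=logits.device)
    for i in torch.unique(ctx):
        w[i] = theta
    # divide positive logits and multiply negative ones, so a seen token's logit always decreases
    return torch.where(logits > 0, logits / w, logits * w)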
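The penalty above scales the logits of recently seen tokens by `theta`. Plain division lowers a token's probability only when its logit is positive. A negative logit divided by `theta > 1` moves toward zero, which makes the token *more* likely. The CTRL-style rule handles this by dividing positive logits and multiplying negative ones. A stdlib-only sketch of that rule (`penalize` is a hypothetical name):

```python
def penalize(logits, seen, theta = 1.2):
    # CTRL-style repetition penalty: every already-seen token becomes less likely
    out = []
    for i, x in enumerate(logits):
        if i in seen:
            x = x / theta if x > 0 else x * theta
        out.append(x)
    return out

print(penalize([2.4, -1.0, 0.5], seen = {0, 1}))  # [2.0, -1.2, 0.5]
```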

# wrapper that adds autoregressive training (next-token loss) and sampling to a network
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = 0, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.max_seq_len = net.max_seq_len

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, repetition_penalty=1.0, repetition_penalty_ctx=32, **kwargs):
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        input_mask = kwargs.pop('mask', None)

        if input_mask is None:
            input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)
        
        # for conditional generation, if no context mask is provided, build one that attends to the whole context
        context_mask = kwargs.pop('context_mask', None)

        if 'context' in kwargs and not exists(context_mask):
            context = kwargs['context']
            context_mask = torch.full(context.shape[:2], True, dtype=torch.bool, device=out.device)

        kwargs.update(context_mask = context_mask)

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]
            input_mask = input_mask[:, -self.max_seq_len:]
            logits = self.net(x, mask=input_mask, **kwargs)[:, -1, :]
            if repetition_penalty > 1.0:
                logits = repetition_penalty_fn(logits, out[:, -repetition_penalty_ctx:], theta=repetition_penalty)  # penalize tokens from the last repetition_penalty_ctx positions
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)
            input_mask = F.pad(input_mask, (0, 1), value=True)

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out

    def forward(self, x, **kwargs):
        xi = x[:, :-1]
        xo = x[:, 1:]

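        # resolve a confusing aspect of input masks in autoregressive training:
        # if the user's mask covers the full sequence (one longer than the shifted input), drop its last position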
        mask = kwargs.pop('mask', None)
        if mask is not None and mask.shape[1] == x.shape[1]:
            mask = mask[:, :-1]
        kwargs.update(mask = mask)

        out = self.net(xi, **kwargs)

        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss
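`generate` relies on a `top_k` filter that is not part of this excerpt. As a minimal NumPy sketch, assuming `top_k` keeps the `int((1 - filter_thres) * vocab_size)` largest logits and masks the rest with `-inf` (the helpers `top_k_filter` and `sample_next` are hypothetical), one step of the sampling loop looks like this:

```python
import numpy as np

def top_k_filter(logits, thres=0.9):
    # hypothetical helper: keep the k largest logits per row, set the rest to -inf
    # (assumed to mirror the top_k filter used by generate)
    k = max(1, int((1 - thres) * logits.shape[-1]))  # guard so at least one token survives
    kth = np.sort(logits, axis=-1)[:, -k][:, None]    # k-th largest value in each row
    return np.where(logits >= kth, logits, -np.inf)

def sample_next(logits, rng, thres=0.9, temperature=1.0):
    # filter, apply temperature, softmax, then draw one token per row
    scaled = top_k_filter(logits, thres) / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scaled)                                 # exp(-inf) = 0 for filtered tokens
    probs /= probs.sum(axis=-1, keepdims=True)
    # one categorical draw per row, analogous to torch.multinomial(probs, 1)
    return np.array([[rng.choice(len(p), p=p)] for p in probs])

rng = np.random.default_rng(0)
logits = np.array([[0.1, 2.0, 0.3, 5.0, 1.0, 0.0, -1.0, 0.2, 0.4, 3.0]])
filtered = top_k_filter(logits, thres=0.5)   # vocab 10, thres 0.5 -> keep 5 tokens
next_token = sample_next(logits, rng, thres=0.5)
```

The `max(1, ...)` guard is specific to this sketch: for vocabularies smaller than 10, `int(0.1 * vocab_size)` is 0 and no token would survive. Ties at the cutoff may keep slightly more than k tokens.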

.\lucidrains\performer-pytorch\performer_pytorch\performer_enc_dec.py

# import the required libraries
import re
import torch
from torch import nn
from performer_pytorch.performer_pytorch import PerformerLM
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# keyword-argument prefixes for the encoder and the decoder
ENC_PREFIX = 'enc_'
DEC_PREFIX = 'dec_'

# split a dict in two according to a predicate on its keys (matching keys first)
def group_dict_by_key(cond, d):
    return_val = [dict(),dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# check whether a string starts with the given prefix
def string_begins_with(prefix, str):
    return bool(re.match(f'^{prefix}', str))

# split a dict by key prefix
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(lambda x: string_begins_with(prefix, x), d)

# split a dict by key prefix and strip the prefix from the matching keys
def group_by_key_prefix_and_remove_prefix(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: string_begins_with(prefix, x), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))
    return kwargs_without_prefix, kwargs

# extract the encoder and decoder keyword arguments
def extract_enc_dec_kwargs(kwargs):
    enc_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(ENC_PREFIX, kwargs)
    dec_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(DEC_PREFIX, kwargs)
    return enc_kwargs, dec_kwargs, kwargs

# extract encoder/decoder kwargs; an encoder mask also becomes the decoder's default context_mask
def extract_and_set_enc_dec_kwargs(kwargs):
    enc_kwargs, dec_kwargs, kwargs = extract_enc_dec_kwargs(kwargs)
    if 'mask' in enc_kwargs:
        dec_kwargs.setdefault('context_mask', enc_kwargs['mask'])
    return enc_kwargs, dec_kwargs, kwargs
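To see how keyword arguments are routed, here is a small stand-alone sketch of the same idea that uses `str.startswith` instead of a regex (the names `split_by_prefix` and `split_enc_dec_kwargs` are hypothetical). Keys prefixed with `enc_` or `dec_` go to the encoder or decoder with the prefix stripped, everything else is left over, and an encoder `mask` becomes the decoder's default `context_mask`:

```python
def split_by_prefix(prefix, d):
    # keys starting with prefix (prefix stripped) vs. all remaining keys
    matched = {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
    rest = {k: v for k, v in d.items() if not k.startswith(prefix)}
    return matched, rest

def split_enc_dec_kwargs(kwargs):
    enc, rest = split_by_prefix('enc_', kwargs)
    dec, rest = split_by_prefix('dec_', rest)
    if 'mask' in enc:
        # the encoder's padding mask doubles as the decoder's cross-attention mask
        dec.setdefault('context_mask', enc['mask'])
    return enc, dec, rest

enc, dec, rest = split_enc_dec_kwargs({'enc_depth': 6, 'dec_depth': 4, 'enc_mask': 'M', 'dim': 512})
print(enc)   # {'depth': 6, 'mask': 'M'}
print(dec)   # {'depth': 4, 'context_mask': 'M'}
print(rest)  # {'dim': 512}
```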

# encoder-decoder Performer
class PerformerEncDec(nn.Module):
    def __init__(
        self,
        dim,
        ignore_index = 0,
        pad_value = 0,
        tie_token_embeds = False,
        no_projection = False,
        **kwargs
    ):
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)
        
        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'set dim once at the top level; it is shared by the encoder and decoder'

        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['no_projection'] = dec_kwargs['no_projection'] = no_projection

        dec_kwargs['causal'] = True
        dec_kwargs['cross_attend'] = True

        enc = PerformerLM(**enc_kwargs)
        dec = PerformerLM(**dec_kwargs)

        if tie_token_embeds:
            enc.token_emb = dec.token_emb

        self.enc = enc
        self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index, pad_value = pad_value)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in, return_encodings = True, **enc_kwargs)
        return self.dec.generate(seq_out_start, seq_len, context = encodings, **{**dec_kwargs, **kwargs})

    def forward(self, seq_in, seq_out, enc_mask = None, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        encodings = self.enc(seq_in, mask = enc_mask, return_encodings = True, **enc_kwargs)
        return self.dec(seq_out, context = encodings, context_mask = enc_mask, **dec_kwargs)

.\lucidrains\performer-pytorch\performer_pytorch\performer_pytorch.py

# import the math library
import math
# import torch
import torch
# import torch's functional API
import torch.nn.functional as F
# import the nn module from torch
from torch import nn
# import autocast from torch.cuda.amp
from torch.cuda.amp import autocast
# import rearrange and repeat from einops
from einops import rearrange, repeat

# import partial from functools
from functools import partial
# import contextmanager from contextlib
from contextlib import contextmanager

# import the local_attention package
from local_attention import LocalAttention
# import the axial_positional_embedding package
from axial_positional_embedding import AxialPositionalEmbedding
# import the reversible module from performer_pytorch
from performer_pytorch.reversible import ReversibleSequence, SequentialSequence

# import LooseVersion from distutils.version (note: distutils was removed in Python 3.12)
from distutils.version import LooseVersion

# check whether the torch version is at least 1.8.0 (torch.linalg.qr is used from that version on)
TORCH_GE_1_8_0 = LooseVersion(torch.__version__) >= LooseVersion('1.8.0')

try:
    # try to import the amp module from NVIDIA apex
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    # if the import fails, mark apex as unavailable
    APEX_AVAILABLE = False

# helper functions

# check whether a value is not None
def exists(val):
    return val is not None

# check whether a tensor has no elements
def empty(tensor):
    return tensor.numel() == 0

# return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# a context manager that does nothing
@contextmanager
def null_context():
    yield

# wrap val in a tuple unless it already is one
def cast_tuple(val):
    return (val,) if not isinstance(val, tuple) else val

# get the device of a module (from its first parameter)
def get_module_device(module):
    return next(module.parameters()).device

# find all submodules of nn_module of the given type
def find_modules(nn_module, type):
    return [module for module in nn_module.modules() if isinstance(module, type)]

# a module that always returns the same value
class Always(nn.Module):
    def __init__(self, val):
        super().__init__()
        self.val = val

    def forward(self, *args, **kwargs):
        return self.val

# token shifting helpers and classes

# shift tensor t by `amount` positions along the sequence dimension, zero-filling; masked positions are zeroed first
def shift(t, amount, mask = None):
    if amount == 0:
        return t

    if exists(mask):
        t = t.masked_fill(~mask[..., None], 0.)

    return F.pad(t, (0, 0, amount, -amount), value = 0.)

# shift feature segments by different amounts before applying fn
class PreShiftTokens(nn.Module):
    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        shifts = self.shifts
        segments = len(shifts)
        feats_per_shift = x.shape[-1] // segments
        splitted = x.split(feats_per_shift, dim = -1)
        segments_to_shift, rest = splitted[:segments], splitted[segments:]
        segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
        x = torch.cat((*segments_to_shift, *rest), dim = -1)
        return self.fn(x, **kwargs)
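A NumPy sketch of what `shift` and `PreShiftTokens` do, assuming a `(batch, seq, dim)` layout and ignoring the mask: the feature dimension is cut into one segment per shift amount, each segment is moved along the sequence axis with zero fill (like `F.pad(t, (0, 0, amount, -amount))`), and leftover features pass through unchanged. The helper names `shift_seq` and `pre_shift_tokens` are hypothetical.

```python
import numpy as np

def shift_seq(t, amount):
    # move features `amount` steps along the sequence axis (-2), zero-filling the gap;
    # amount = 1 lets token i see token i-1, amount = -1 lets it see token i+1
    if amount == 0:
        return t
    out = np.zeros_like(t)
    if amount > 0:
        out[..., amount:, :] = t[..., :-amount, :]
    else:
        out[..., :amount, :] = t[..., -amount:, :]
    return out

def pre_shift_tokens(x, shifts):
    # one feature segment per shift amount; leftover features are not shifted
    size = x.shape[-1] // len(shifts)
    pieces = [x[..., i * size:(i + 1) * size] for i in range(len(shifts))]
    shifted = [shift_seq(p, s) for p, s in zip(pieces, shifts)]
    return np.concatenate(shifted + [x[..., size * len(shifts):]], axis=-1)

x = np.arange(4 * 6, dtype=float).reshape(1, 4, 6)   # (batch, seq, dim)
y = pre_shift_tokens(x, (0, 1))                       # the causal setting uses shifts (0, 1)
```

In the causal case only non-negative shifts are used, so no token can read features from a later position.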

# kernel functions

# softmax kernel, transcribed from jax to pytorch
def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.

    ratio = (projection_matrix.shape[0] ** -0.5)

    projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
    projection = projection.type_as(data)

    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)

    diag_data = data ** 2
    diag_data = torch.sum(diag_data, dim=-1)
    diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
    diag_data = diag_data.unsqueeze(dim=-1)

    if is_query:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data -
                    torch.amax(data_dash, dim=-1, keepdim=True).detach()) + eps)
    else:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data - torch.amax(data_dash, dim=(-1, -2), keepdim=True).detach()) + eps)

    return data_dash.type_as(data)
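`softmax_kernel` builds positive random features: with `x' = d^(-1/4) x` and `m` projection rows `w`, the feature map is `exp(w·x' - |x'|²/2) / sqrt(m)`, so `phi(q)·phi(k)` is an unbiased estimate of `exp(q·k / sqrt(d))`, the unnormalized softmax score. The subtracted `amax` terms only rescale by a constant (per query row, or per batch and head for keys) that cancels after normalization, and `eps` guards against zeros. This NumPy sketch drops both and uses plain i.i.d. Gaussian rows instead of orthogonal ones, so it is an illustration of the estimator rather than a copy of the function:

```python
import numpy as np

def positive_random_features(x, w):
    # phi(x) = exp(W x - |x|^2 / 2) / sqrt(m) for each row of x (no stabilizer, no eps)
    m = w.shape[0]
    return np.exp(x @ w.T - 0.5 * np.sum(x ** 2, axis=-1, keepdims=True)) / np.sqrt(m)

rng = np.random.default_rng(0)
d, m = 16, 50000
q = 0.25 * rng.normal(size=(3, d))
k = 0.25 * rng.normal(size=(5, d))
scale = d ** -0.25                      # data_normalizer, applied to both q and k
w = rng.normal(size=(m, d))             # i.i.d. Gaussian projection rows

approx = positive_random_features(scale * q, w) @ positive_random_features(scale * k, w).T
exact = np.exp(q @ k.T / np.sqrt(d))    # unnormalized softmax attention scores
```

Because every feature is positive, the estimated attention weights stay positive, which is the main reason for preferring this map over sin/cos random features.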

# generalized kernel
def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.

    if projection_matrix is None:
        return kernel_fn(data_normalizer * data) + kernel_epsilon

    projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
    projection = projection.type_as(data)

    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)

    data_prime = kernel_fn(data_dash) + kernel_epsilon
    # cast the features back to the input's dtype
    return data_prime.type_as(data)

# build one orthogonal block
def orthogonal_matrix_chunk(cols, device = None):
    # sample a random Gaussian square matrix
    unstructured_block = torch.randn((cols, cols), device = device)
    # QR decomposition (done on the CPU) yields an orthogonal matrix q
    if TORCH_GE_1_8_0:
        q, r = torch.linalg.qr(unstructured_block.cpu(), mode = 'reduced')
    else:
        q, r = torch.qr(unstructured_block.cpu(), some = True)
    # move q and r to the target device
    q, r = map(lambda t: t.to(device), (q, r))
    # return the transpose of q
    return q.t()

# build a Gaussian orthogonal random matrix
def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = None):
    # number of full square blocks
    nb_full_blocks = int(nb_rows / nb_columns)

    block_list = []

    # generate the full blocks
    for _ in range(nb_full_blocks):
        q = orthogonal_matrix_chunk(nb_columns, device = device)
        block_list.append(q)

    # handle the remaining rows with part of one extra block
    remaining_rows = nb_rows - nb_full_blocks * nb_columns
    if remaining_rows > 0:
        q = orthogonal_matrix_chunk(nb_columns, device = device)
        block_list.append(q[:remaining_rows])

    # stack all blocks
    final_matrix = torch.cat(block_list)

    # choose the row norms according to scaling
    if scaling == 0:
        multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1)
    elif scaling == 1:
        multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device)
    else:
        raise ValueError(f'Invalid scaling {scaling}')

    # scale each row by its multiplier
    return torch.diag(multiplier) @ final_matrix
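The construction above can be mirrored in NumPy to check what it guarantees: rows inside one square block are exactly orthogonal, while rows from different blocks are independent. This is a sketch (the name `gaussian_orthogonal_rows` is hypothetical); it covers the remainder by truncating one extra block, as the code above does, and uses `multiplier[:, None] * mat`, which equals `diag(multiplier) @ mat`:

```python
import numpy as np

def gaussian_orthogonal_rows(nb_rows, nb_columns, rng, scaling=0):
    # stack square orthogonal blocks; rows are orthonormal within each block
    blocks = []
    for _ in range(-(-nb_rows // nb_columns)):  # ceil division
        q, _ = np.linalg.qr(rng.normal(size=(nb_columns, nb_columns)))
        blocks.append(q.T)
    mat = np.concatenate(blocks)[:nb_rows]
    if scaling == 0:
        # row norms distributed like the norms of Gaussian vectors
        multiplier = np.linalg.norm(rng.normal(size=(nb_rows, nb_columns)), axis=1)
    elif scaling == 1:
        # every row gets norm sqrt(nb_columns)
        multiplier = np.sqrt(nb_columns) * np.ones(nb_rows)
    else:
        raise ValueError(f'Invalid scaling {scaling}')
    return multiplier[:, None] * mat

rng = np.random.default_rng(0)
w = gaussian_orthogonal_rows(10, 4, rng, scaling=1)
gram = w[:4] @ w[:4].T   # first block: 4 * identity
```

With `scaling=0` the rows keep the length distribution of i.i.d. Gaussian vectors, so only their directions are decorrelated.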

# linear attention, using the softmax kernel

# non-causal linear attention
def linear_attention(q, k, v):
    # sum of the keys over the sequence
    k_cumsum = k.sum(dim = -2)
    # D_inv: inverse of each query's normalizer
    D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))
    # context: keys^T @ values
    context = torch.einsum('...nd,...ne->...de', k, v)
    # output: (queries @ context), normalized
    out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
    return out
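`linear_attention` relies on associativity: instead of forming the n x n matrix `phi(Q) phi(K)^T`, it computes `K^T V` first, then multiplies by the queries and divides by `phi(Q) · sum(phi(K))`. A small NumPy check (single head, no batch; the helper names are hypothetical) confirms that both orders give the same normalized attention:

```python
import numpy as np

def linear_attention_np(q, k, v):
    # mirrors linear_attention: O(n * d * e) instead of O(n^2)
    k_sum = k.sum(axis=-2)                  # (d,)
    d_inv = 1.0 / (q @ k_sum)               # (n,)
    context = k.T @ v                       # (d, e)
    return (q @ context) * d_inv[:, None]   # (n, e)

def quadratic_attention_np(q, k, v):
    # same result through the explicit n x n attention matrix
    a = q @ k.T
    a = a / a.sum(axis=-1, keepdims=True)
    return a @ v

rng = np.random.default_rng(0)
q = rng.random((6, 3)) + 0.1   # positive features, like the kernel outputs above
k = rng.random((6, 3)) + 0.1
v = rng.normal(size=(6, 2))
print(np.allclose(linear_attention_np(q, k, v), quadratic_attention_np(q, k, v)))  # True
```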

# efficient causal linear attention, created by EPFL
# TODO: rewrite EPFL's CUDA kernel to support mixed precision and remove the half-to-single precision conversion
def causal_linear_attention(q, k, v, eps = 1e-6):
    from fast_transformers.causal_product import CausalDotProduct
    autocast_enabled = torch.is_autocast_enabled()
    is_half = isinstance(q, torch.cuda.HalfTensor)
    assert not is_half or APEX_AVAILABLE, 'half tensors can only be used if nvidia apex is available'
    cuda_context = null_context if not autocast_enabled else partial(autocast, enabled = False)

    causal_dot_product_fn = amp.float_function(CausalDotProduct.apply) if is_half else CausalDotProduct.apply

    k_cumsum = k.cumsum(dim=-2) + eps
    D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q))

    with cuda_context():
        if autocast_enabled:
            q, k, v = map(lambda t: t.float(), (q, k, v))

        out = causal_dot_product_fn(q, k, v)

    out = torch.einsum('...nd,...n->...nd', out, D_inv)
    return out

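# inefficient causal linear attention without CUDA code, kept for readers' reference;
# it is only used as a fallback when the CUDA kernel cannot be imported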
def causal_linear_attention_noncuda(q, k, v, chunk_size = 128, eps = 1e-6):
    last_k_cumsum = 0
    last_context_cumsum = 0
    outs = []

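    # split the sequence into blocks of chunk_size positions (t.chunk would instead create chunk_size blocks)
    for q, k, v in zip(*map(lambda t: t.split(chunk_size, dim = -2), (q, k, v))):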
        k_cumsum = last_k_cumsum + k.cumsum(dim=-2)

        D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q) + eps)
        context = torch.einsum('...nd,...ne->...nde', k, v)
        context_cumsum = last_context_cumsum + context.cumsum(dim=-3)
        out = torch.einsum('...nde,...nd,...n->...ne', context_cumsum, q, D_inv)

        last_k_cumsum = k_cumsum[:, :, -1:]
        last_context_cumsum = context_cumsum[:, :, -1:]
        outs.append(out)

    return torch.cat(outs, dim = -2)
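The causal variant replaces the global sums with prefix sums, so position i only aggregates keys and values at positions up to i. The NumPy sketch below (single head, hypothetical helper names) keeps the whole `(n, d, e)` running context in memory, which is exactly the cost the chunked loop and the CUDA kernel try to avoid, and checks the result against masked quadratic attention:

```python
import numpy as np

def causal_linear_attention_np(q, k, v, eps=1e-6):
    # prefix sums over the sequence axis
    k_cumsum = np.cumsum(k, axis=0)                                      # (n, d)
    d_inv = 1.0 / np.sum(q * (k_cumsum + eps), axis=-1)                 # (n,)
    context_cumsum = np.cumsum(k[:, :, None] * v[:, None, :], axis=0)   # (n, d, e)
    out = np.einsum('nde,nd->ne', context_cumsum, q)
    return out * d_inv[:, None]

def masked_quadratic_attention_np(q, k, v):
    n = q.shape[0]
    a = (q @ k.T) * np.tril(np.ones((n, n)))   # zero out future positions
    a = a / a.sum(axis=-1, keepdims=True)
    return a @ v

rng = np.random.default_rng(0)
q = rng.random((6, 3)) + 0.1
k = rng.random((6, 3)) + 0.1
v = rng.normal(size=(6, 2))
out = causal_linear_attention_np(q, k, v)
print(np.allclose(out, masked_quadratic_attention_np(q, k, v)))  # True
```

The first output row only sees the first value vector, so it is (up to `eps`) equal to `v[0]`.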

class FastAttention(nn.Module):
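    # set up the head dimension, number of random features, orthogonal scaling, causality, generalized attention, kernel function and whether to skip the projection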
    def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), no_projection = False):
        # call the parent initializer
        super().__init__()
        # by default, use dim_heads * ln(dim_heads) random features
        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))

        # store the head dimension, number of features and orthogonal scaling
        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling

        # factory for projection matrices (Gaussian orthogonal random matrices)
        self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling)
        # draw the projection matrix and register it as a buffer
        projection_matrix = self.create_projection()
        self.register_buffer('projection_matrix', projection_matrix)

        # generalized attention flag and kernel function
        self.generalized_attention = generalized_attention
        self.kernel_fn = kernel_fn

        # if this is turned on, no projection will be used:
        # queries and keys will be softmax-ed as in the original efficient attention paper
        self.no_projection = no_projection
        self.no_projection = no_projection

        # causal mode: use the CUDA causal linear attention if available, otherwise the non-CUDA fallback
        self.causal = causal
        if causal:
            try:
                import fast_transformers.causal_product.causal_product_cuda
                self.causal_linear_fn = partial(causal_linear_attention)
            except ImportError:
                print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
                self.causal_linear_fn = causal_linear_attention_noncuda

    # redraw the projection matrix (called periodically during training)
    @torch.no_grad()
    def redraw_projection_matrix(self, device):
        # draw a new projection matrix and copy it into the existing buffer
        projections = self.create_projection(device = device)
        self.projection_matrix.copy_(projections)
        del projections

    # forward pass: take queries, keys and values and return the attention output
    def forward(self, q, k, v):
        device = q.device

        # no projection: softmax queries over features and keys over the sequence (plain exp for keys in causal mode)
        if self.no_projection:
            q = q.softmax(dim = -1)
            k = torch.exp(k) if self.causal else k.softmax(dim = -2)

        # generalized attention: use the generalized kernel
        elif self.generalized_attention:
            create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
            q, k = map(create_kernel, (q, k))

        # otherwise: use the softmax kernel
        else:
            create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
            q = create_kernel(q, is_query = True)
            k = create_kernel(k, is_query = False)

        # choose the causal or non-causal attention function
        attn_fn = linear_attention if not self.causal else self.causal_linear_fn
        out = attn_fn(q, k, v)
        return out
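In the non-causal `no_projection` branch, queries are softmaxed over the feature axis and keys over the sequence axis. One consequence worth noticing: every key feature then sums to 1 over the sequence, so the normalizer `q · k.sum(dim=-2)` that `linear_attention` divides by is exactly 1 for every query. A NumPy check (the `softmax` helper is hypothetical):

```python
import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
q = softmax(rng.normal(size=(7, 4)), axis=-1)   # over features, like q.softmax(dim = -1)
k = softmax(rng.normal(size=(7, 4)), axis=-2)   # over the sequence, like k.softmax(dim = -2)
normalizer = q @ k.sum(axis=-2)                 # the denominator used by linear_attention
print(np.allclose(normalizer, 1.0))  # True
```

A softmax over the sequence would mix in future positions, which is presumably why the causal branch falls back to a plain `exp` for the keys.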

# module that keeps track of when to redraw the projections

class ProjectionUpdater(nn.Module):
    def __init__(self, instance, feature_redraw_interval):
        super().__init__()
        self.instance = instance
        self.feature_redraw_interval = feature_redraw_interval
        self.register_buffer('calls_since_last_redraw', torch.tensor(0))

    def fix_projections_(self):
        # fix the projections by disabling redraws
        self.feature_redraw_interval = None

    def redraw_projections(self):
        model = self.instance

        if not self.training:
            return

        # redraw once the number of calls since the last redraw reaches the redraw interval
        if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
            device = get_module_device(model)

            # find all FastAttention modules in the model
            fast_attentions = find_modules(model, FastAttention)
            for fast_attention in fast_attentions:
                fast_attention.redraw_projection_matrix(device)

            self.calls_since_last_redraw.zero_()
            return

        self.calls_since_last_redraw += 1

    def forward(self, x):
        raise NotImplementedError

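The counting in `redraw_projections` can be traced without any tensors. This pure-Python stand-in (the class `RedrawCounter` is hypothetical and assumes training mode) reproduces the same branch logic and shows that, because the redrawing call itself does not increment the counter, a redraw happens every `feature_redraw_interval + 1` calls:

```python
class RedrawCounter:
    # hypothetical stand-in for ProjectionUpdater's counting logic (training mode only)
    def __init__(self, feature_redraw_interval):
        self.feature_redraw_interval = feature_redraw_interval
        self.calls_since_last_redraw = 0

    def step(self):
        # returns True when this call would redraw the projection matrices
        if self.feature_redraw_interval is not None and self.calls_since_last_redraw >= self.feature_redraw_interval:
            self.calls_since_last_redraw = 0
            return True
        self.calls_since_last_redraw += 1
        return False

counter = RedrawCounter(feature_redraw_interval=3)
print([i for i in range(10) if counter.step()])  # [3, 7]
```

Setting the interval to `None`, as `fix_projections_` does, disables redraws entirely.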
# classes

class ReZero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.g = nn.Parameter(torch.tensor(1e-3))
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.g

class PreScaleNorm(nn.Module):
    def __init__(self, dim, fn, eps=1e-5):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.ones(1))
        self.eps = eps

    def forward(self, x, **kwargs):
        n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
        x = x / n * self.g
        return self.fn(x, **kwargs)

class PreLayerNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class Chunk(nn.Module):
    def __init__(self, chunks, fn, along_dim = -1):
        super().__init__()
        self.dim = along_dim
        self.chunks = chunks
        self.fn = fn

    def forward(self, x, **kwargs):
        if self.chunks == 1:
            return self.fn(x, **kwargs)
        chunks = x.chunk(self.chunks, dim = self.dim)
        return torch.cat([self.fn(c, **kwargs) for c in chunks], dim = self.dim)
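`Performer` wraps its feed-forward blocks in `Chunk(ff_chunks, ..., along_dim = 1)`, i.e. it chunks over the sequence. Because the feed-forward network acts on each position independently, processing the sequence piece by piece lowers peak activation memory without changing the output. A NumPy sketch with a toy position-wise MLP (ReLU instead of the module's default GELU, purely for brevity):

```python
import numpy as np

rng = np.random.default_rng(0)
w1, w2 = rng.normal(size=(8, 32)), rng.normal(size=(32, 8))

def feed_forward(x):
    # toy position-wise MLP standing in for FeedForward
    return np.maximum(x @ w1, 0) @ w2

x = rng.normal(size=(2, 10, 8))             # (batch, seq, dim)
full = feed_forward(x)
chunked = np.concatenate([feed_forward(c) for c in np.array_split(x, 4, axis=1)], axis=1)
print(np.allclose(full, chunked))  # True
```

`np.array_split` splits 10 positions as 3, 3, 2, 2, whereas `torch.chunk` would produce 3, 3, 3, 1; the result is the same either way.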

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False):
        super().__init__()
        activation = default(activation, nn.GELU)

        self.glu = glu
        self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
        self.act = activation()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(dim * mult, dim)

    def forward(self, x, **kwargs):
        if not self.glu:
            x = self.w1(x)
            x = self.act(x)
        else:
            x, v = self.w1(x).chunk(2, dim=-1)
            x = self.act(x) * v

        x = self.dropout(x)
        x = self.w2(x)
        return x

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        heads = 8,
        dim_head = 64,
        local_heads = 0,
        local_window_size = 256,
        nb_features = None,
        feature_redraw_interval = 1000,
        generalized_attention = False,
        kernel_fn = nn.ReLU(),
        dropout = 0.,
        no_projection = False,
        qkv_bias = False,
        attn_out_bias = True
    ):
        super().__init__()
        # the dimension must be divisible by the number of heads
        assert dim % heads == 0, 'dimension must be divisible by number of heads'
        # per-head dimension
        dim_head = default(dim_head, dim // heads)
        # total inner dimension
        inner_dim = dim_head * heads
        # fast (linear) attention used by the global heads
        self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, no_projection = no_projection)

        # number of heads, and how many of them are global
        self.heads = heads
        self.global_heads = heads - local_heads
        # local attention for the remaining heads, if any
        self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None

        # linear projections to queries, keys and values, and back to the model dimension
        self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias)
        self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias)
        self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias)
        self.to_out = nn.Linear(inner_dim, dim, bias = attn_out_bias)
        self.dropout = nn.Dropout(dropout)

    # forward pass
    def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, **kwargs):
        # input shape, number of heads and number of global heads
        b, n, _, h, gh = *x.shape, self.heads, self.global_heads

        # cross-attention if a context is given
        cross_attend = exists(context)

        # default the context to x; for self-attention, the context mask defaults to mask
        context = default(context, x)
        context_mask = default(context_mask, mask) if not cross_attend else context_mask

        # project the input to queries, keys and values
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)

        # split heads, then separate the global heads from the local heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))

        # collect the attention outputs
        attn_outs = []

        # global heads
        if not empty(q):
            # zero out masked values
            if exists(context_mask):
                global_mask = context_mask[:, None, :, None]
                v.masked_fill_(~global_mask, 0.)

            # apply rotary position embeddings for self-attention
            if exists(pos_emb) and not cross_attend:
                q, k = apply_rotary_pos_emb(q, k, pos_emb)

            # compute the output with fast attention
            out = self.fast_attention(q, k, v)
            attn_outs.append(out)

        # local heads
        if not empty(lq):
            # local attention cannot be combined with cross-attention
            assert not cross_attend, 'local attention is not compatible with cross attention'
            # compute the output with local attention
            out = self.local_attn(lq, lk, lv, input_mask = mask)
            attn_outs.append(out)

        # concatenate the outputs of all heads
        out = torch.cat(attn_outs, dim = 1)
        # merge the heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        # project back to the model dimension
        out = self.to_out(out)
        # apply dropout
        return self.dropout(out)

# SelfAttention: an Attention that must not receive a context
class SelfAttention(Attention):
    # override forward; context defaults to None
    def forward(self, *args, context = None, **kwargs):
        # self-attention must not receive a context
        assert not exists(context), 'self attention should not receive context'
        # delegate to the parent forward
        return super().forward(*args, **kwargs)

# CrossAttention: an Attention that requires a context
class CrossAttention(Attention):
    # override forward; context defaults to None
    def forward(self, *args, context = None, **kwargs):
        # cross-attention must receive a context
        assert exists(context), 'cross attention should receive context'
        # delegate to the parent forward, passing the context along
        return super().forward(*args, context = context, **kwargs)

# positional embeddings

# AbsolutePositionalEmbedding: one learned embedding per position
class AbsolutePositionalEmbedding(nn.Module):
    # dim: embedding dimension, max_seq_len: maximum sequence length
    def __init__(self, dim, max_seq_len):
        super().__init__()
        # one learned vector for each position
        self.emb = nn.Embedding(max_seq_len, dim)

    # forward pass on input x
    def forward(self, x):
        # position indices 0..n-1 on x's device
        t = torch.arange(x.shape[1], device=x.device)
        # look up their embeddings
        return self.emb(t)

# rotary positional embedding helpers

# rotate_every_two: map each feature pair (x1, x2) to (-x2, x1)
def rotate_every_two(x):
    # split the last dimension into pairs
    x = rearrange(x, '... (d j) -> ... d j', j = 2)
    # separate the two elements of each pair
    x1, x2 = x.unbind(dim = -1)
    # rotate each pair by 90 degrees: (x1, x2) -> (-x2, x1)
    x = torch.stack((-x2, x1), dim = -1)
    # merge the pairs back into the last dimension
    return rearrange(x, '... d j -> ... (d j)')

# apply_rotary_pos_emb: rotate queries q and keys k using the sinusoidal embedding sinu_pos
def apply_rotary_pos_emb(q, k, sinu_pos):
    # split the last dimension into its sin and cos halves
    sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2)
    # separate sin and cos
    sin, cos = sinu_pos.unbind(dim = -2)
    # repeat each frequency twice so it lines up with the interleaved feature pairs
    sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j = 2), (sin, cos))
    # rotate each feature pair of q and k by its position-dependent angle
    q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
    # return the rotated q and k
    return q, k
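Rotary embeddings rotate each feature pair by an angle proportional to the position, so the dot product of a rotated query at position m and a rotated key at position n depends only on the offset n - m. The NumPy sketch below mirrors `rotate_every_two`, the frequencies of `FixedPositionalEmbedding` and the interleaved repetition in `apply_rotary_pos_emb` for a single vector (the helper `rotary` is hypothetical):

```python
import numpy as np

def rotate_every_two(x):
    # (x1, x2) pairs -> (-x2, x1), same interleaved layout as the torch version
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    return np.stack((-pairs[..., 1], pairs[..., 0]), axis=-1).reshape(x.shape)

def rotary(x, pos):
    # x: (dim,) vector at integer position pos
    dim = x.shape[-1]
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    angles = np.repeat(pos * inv_freq, 2)   # one angle per (x1, x2) pair
    return x * np.cos(angles) + rotate_every_two(x) * np.sin(angles)

rng = np.random.default_rng(0)
q, k = rng.normal(size=8), rng.normal(size=8)
score_a = rotary(q, 3) @ rotary(k, 7)
score_b = rotary(q, 10) @ rotary(k, 14)    # same offset of 4
print(np.isclose(score_a, score_b))  # True
```

Since each step is a pure rotation, vector norms are also preserved.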

# sinusoidal positional embeddings

# FixedPositionalEmbedding: fixed sinusoidal embeddings
class FixedPositionalEmbedding(nn.Module):
    # dim: embedding dimension, max_seq_len: maximum sequence length
    def __init__(self, dim, max_seq_len):
        super().__init__()
        # inverse frequencies
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # outer product of positions and inverse frequencies
        position = torch.arange(0, max_seq_len, dtype=torch.float)
        sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)
        # concatenate sin and cos to form the embedding
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        # register the embedding as a buffer
        self.register_buffer('emb', emb)

    # forward pass on input x
    def forward(self, x):
        # return the embeddings of the first n positions, cast to x's dtype and device
        return self.emb[None, :x.shape[1], :].to(x)

# performer

# Performer: a stack of attention and feed-forward layers
class Performer(nn.Module):
    # constructor with the model hyperparameters
    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        local_attn_heads = 0,
        local_window_size = 256,
        causal = False,
        ff_mult = 4,
        nb_features = None,
        feature_redraw_interval = 1000,
        reversible = False,
        ff_chunks = 1,
        generalized_attention = False,
        kernel_fn = nn.ReLU(),
        use_scalenorm = False,
        use_rezero = False,
        ff_glu = False,
        ff_dropout = 0.,
        attn_dropout = 0.,
        cross_attend = False,
        no_projection = False,
        auto_check_redraw = True,
        qkv_bias = True,
        attn_out_bias = True,
        shift_tokens = False
    ):
        super().__init__()
        # start with an empty module list
        layers = nn.ModuleList([])
        # turn the number of local attention heads into a tuple
        local_attn_heads = cast_tuple(local_attn_heads)
        # a single value is repeated for every layer
        local_attn_heads = local_attn_heads * depth if len(local_attn_heads) == 1 else local_attn_heads
        # one entry per layer is required
        assert len(local_attn_heads) == depth, 'tuple specifying number of local attention heads per depth must be equal to the total depth'
        # each entry must be between 0 and the total number of heads
        assert all(map(lambda n: n >= 0 and n <= heads, local_attn_heads)), 'local attention head value must be between 0 and the total number of heads'

        # pick the wrapper according to the normalization scheme
        if use_scalenorm:
            wrapper_fn = partial(PreScaleNorm, dim)
        elif use_rezero:
            wrapper_fn = ReZero
        else:
            wrapper_fn = partial(PreLayerNorm, dim)

        # build each layer
        for _, local_heads in zip(range(depth), local_attn_heads):

            # self-attention and feed-forward blocks
            attn = SelfAttention(dim, causal = causal, heads = heads, dim_head = dim_head, local_heads = local_heads, local_window_size = local_window_size, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias, attn_out_bias = attn_out_bias)
            ff = Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1)

            # optionally shift tokens before self-attention and feed-forward
            if shift_tokens:
                shift = (0, 1) if causal else (-1, 0, 1)
                attn, ff = map(lambda t: PreShiftTokens(shift, t), (attn, ff))

            # wrap both blocks with the chosen normalization
            attn, ff = map(wrapper_fn, (attn, ff))
            # add them to the module list
            layers.append(nn.ModuleList([attn, ff]))

            # without cross-attention, move on to the next layer
            if not cross_attend:
                continue

            # add a cross-attention block and another feed-forward block
            layers.append(nn.ModuleList([
                wrapper_fn(CrossAttention(dim, heads = heads, dim_head = dim_head, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias, attn_out_bias = attn_out_bias)),
                wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1))
            ]))

        # reversible or plain sequential execution
        execute_type = ReversibleSequence if reversible else SequentialSequence

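        # argument routing: mask and pos_emb go to every attention block, context and context_mask only to cross-attention blocks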
        route_attn = ((True, False),) * depth * (2 if cross_attend else 1)
        route_context = ((False, False), (True, False)) * depth
        attn_route_map = {'mask': route_attn, 'pos_emb': route_attn}
        context_route_map = {'context': route_context, 'context_mask': route_context} if cross_attend else {}
        # build the network
        self.net = execute_type(layers, args_route = {**attn_route_map, **context_route_map})

        # keep track of when to redraw the projection matrices of all attention layers
        self.auto_check_redraw = auto_check_redraw
        self.proj_updater = ProjectionUpdater(self.net, feature_redraw_interval)

    # fix the projection matrices (stop redrawing)
    def fix_projection_matrices_(self):
        self.proj_updater.feature_redraw_interval = None

    # forward pass
    def forward(self, x, **kwargs):
        # redraw the projection matrices if automatic checking is enabled
        if self.auto_check_redraw:
            self.proj_updater.redraw_projections()
        return self.net(x, **kwargs)

# PerformerLM: token and position embeddings, a Performer stack, and an output projection
class PerformerLM(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        dim,
        depth,
        heads,
        dim_head = 64,
        local_attn_heads = 0,
        local_window_size = 256,
        causal = False,
        ff_mult = 4,
        nb_features = None,
        feature_redraw_interval = 1000,
        reversible = False,
        ff_chunks = 1,
        ff_glu = False,
        emb_dropout = 0.,
        ff_dropout = 0.,
        attn_dropout = 0.,
        generalized_attention = False,
        kernel_fn = nn.ReLU(),
        use_scalenorm = False,
        use_rezero = False,
        cross_attend = False,
        no_projection = False,
        tie_embed = False,
        rotary_position_emb = True,
        axial_position_emb = False,
        axial_position_shape = None,
        auto_check_redraw = True,
        qkv_bias = False,
        attn_out_bias = False,
        shift_tokens = False
    ):
        # 初始化函数,接收多个参数
        super().__init__()
        local_attn_heads = cast_tuple(local_attn_heads)

        self.max_seq_len = max_seq_len
        self.token_emb = nn.Embedding(num_tokens, dim)
        # 创建 token embedding 层

        if rotary_position_emb:
            self.pos_emb = FixedPositionalEmbedding(dim, max_seq_len)
            self.layer_pos_emb = FixedPositionalEmbedding(dim_head, max_seq_len)
        elif axial_position_emb:
            axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / 64), 64))
            self.pos_emb = AxialPositionalEmbedding(dim, axial_position_shape)
            self.layer_pos_emb = Always(None)
        else:
            self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
            self.layer_pos_emb = Always(None)
        # 根据不同的位置编码方式创建位置编码层

        self.dropout = nn.Dropout(emb_dropout)
        # 创建 dropout 层

        self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias, attn_out_bias, shift_tokens)
        # 创建 Performer 模型

        self.norm = nn.LayerNorm(dim)
        # 创建 LayerNorm 层

        self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None
        # 创建线性层,如果 tie_embed 为 False,则创建线性层,否则为 None

    def check_redraw_projections(self):
        # redraw the random-feature projection matrices if the redraw interval has elapsed
        self.performer.check_redraw_projections()

    def fix_projection_matrices_(self):
        # freeze the projection matrices so that the model's output is deterministic
        self.performer.fix_projection_matrices_()

    def forward(self, x, return_encodings = False, **kwargs):
        # x: token ids of shape (batch, seq_len)
        b, n, device = *x.shape, x.device
        assert n <= self.max_seq_len, f'sequence length {n} must be less than or equal to the max sequence length {self.max_seq_len}'

        # token and positional embeddings
        x = self.token_emb(x)
        x += self.pos_emb(x)

        x = self.dropout(x)

        # performer layers
        # per-layer positional embedding (Always(None) for the axial and absolute schemes)
        layer_pos_emb = self.layer_pos_emb(x)
        x = self.performer(x, pos_emb = layer_pos_emb, **kwargs)

        # norm and to logits
        x = self.norm(x)

        # optionally return the normalized encodings instead of logits
        if return_encodings:
            return x

        # separate output projection
        if exists(self.to_out):
            return self.to_out(x)

        # tied embeddings: project onto the transposed token embedding matrix
        return x @ self.token_emb.weight.t()
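With `tie_embed` set, `to_out` is `None` and the logits come from multiplying the normalized encodings by the transposed token-embedding matrix, so input and output embeddings share one weight. A minimal sketch, with arbitrary sizes chosen for illustration, checking that this equals a bias-free linear layer reusing that weight:

```python
import torch
from torch import nn
import torch.nn.functional as F

num_tokens, dim = 100, 16            # arbitrary sizes for illustration
token_emb = nn.Embedding(num_tokens, dim)

x = torch.randn(2, 5, dim)           # stand-in for the normalized encodings (batch, seq, dim)

# tied projection, as in PerformerLM.forward when to_out is None
logits_tied = x @ token_emb.weight.t()

# equivalent bias-free linear layer reusing the (num_tokens, dim) embedding weight
logits_linear = F.linear(x, token_emb.weight)

print(logits_tied.shape)             # torch.Size([2, 5, 100])
```

Tying saves a `num_tokens x dim` parameter matrix, which is why the separate `nn.Linear` is only created when `tie_embed` is False.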

.\lucidrains\performer-pytorch\performer_pytorch\reversible.py

import torch
import torch.nn as nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

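# for routing arguments into the functions of the reversible layer
def route_args(router, args, depth):
    # one (f_args, g_args) pair of keyword dicts per layer
    routed_args = [(dict(), dict()) for _ in range(depth)]
    # keyword arguments that the router knows how to dispatch
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        # router[key] holds one (to_f, to_g) pair of booleans per layer
        for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            # add the argument to f's and/or g's keyword dict according to the route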
            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
            routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args
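To see what `route_args` produces, the sketch below copies it (renaming the inner loop variable to `layer` so it no longer shadows `depth`) and runs it on a small router. The keyword names `mask` and `context` and their routes are hypothetical choices for illustration, not necessarily the routes the library configures:

```python
def route_args(router, args, depth):
    # copy of route_args above: one (f_args, g_args) pair per layer
    routed_args = [(dict(), dict()) for _ in range(depth)]
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        for layer, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
            routed_args[layer] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args

depth = 2
# hypothetical router: 'mask' goes only to f in every layer, 'context' goes to both f and g
router = {
    'mask': ((True, False),) * depth,
    'context': ((True, True),) * depth,
}

# 'unused' is not in the router, so it is silently dropped
routed = route_args(router, {'mask': 'M', 'context': 'C', 'unused': 'U'}, depth)
print(routed[0])  # ({'mask': 'M', 'context': 'C'}, {'context': 'C'})
```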

# following the example for saving and restoring the RNG state at https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send a PR back to the source
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx

class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, args):
        ctx.args = args
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    # backward pass: walk the blocks in reverse, reconstructing each block's input from its output
    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        args = ctx.args
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            # recover the block's input and the gradient with respect to it
            y, dy = block.backward_pass(y, dy, **kwargs)
        # only x receives a gradient; blocks and args do not
        return dy, None, None

class SequentialSequence(nn.Module):
    # runs the (f, g) layer pairs one after another with ordinary residual connections
    def __init__(self, layers, args_route = {}):
        super().__init__()
        # each argument route must have one entry per layer
        assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route

    def forward(self, x, **kwargs):
        # dispatch the keyword arguments to each layer's f and g
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        for (f, g), (f_args, g_args) in layers_and_args:
            # residual connection around f
            x = x + f(x, **f_args)
            # residual connection around g
            x = x + g(x, **g_args)
        return x

class ReversibleSequence(nn.Module):
    # a stack of reversible blocks whose activations are recomputed during the backward pass instead of being stored
    def __init__(self, blocks, args_route = {}):
        super().__init__()
        self.args_route = args_route
        self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])

    def forward(self, x, **kwargs):
        # duplicate the input along the feature dimension to form the two streams (x1, x2)
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        args = route_args(self.args_route, kwargs, len(blocks))
        # package the routed arguments as keyword dicts for each block
        args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))

        # run all blocks through the custom autograd function
        out = _ReversibleFunction.apply(x, blocks, args)
        # split the two streams back apart and sum them
        return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)
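The memory savings come from the fact that each block's inputs can be recomputed exactly from its outputs, which is what `backward_pass` does before backpropagating through `f` and `g`. A minimal sketch of that inversion, assuming plain `nn.Linear` layers as stand-ins for the attention and feedforward sublayers (and no dropout, so no RNG state has to be replayed):

```python
import torch
from torch import nn

torch.manual_seed(0)
dim = 8
f = nn.Linear(dim, dim)   # stand-in for the attention sublayer
g = nn.Linear(dim, dim)   # stand-in for the feedforward sublayer

x = torch.randn(2, 4, 2 * dim)        # (batch, seq, 2 * dim): the two streams side by side

with torch.no_grad():
    # forward, as in ReversibleBlock.forward
    x1, x2 = torch.chunk(x, 2, dim=2)
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)

    # inversion, as in ReversibleBlock.backward_pass: recover the inputs from the outputs
    x2_rec = y2 - g(y1)
    x1_rec = y1 - f(x2_rec)
    x_rec = torch.cat([x1_rec, x2_rec], dim=2)

print(torch.allclose(x, x_rec, atol=1e-5))  # True
```

Because the inputs are recoverable up to floating-point error, only the final output (`ctx.y`) needs to be kept in memory, regardless of depth.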

.\lucidrains\performer-pytorch\performer_pytorch\__init__.py

# export PerformerLM, Performer, FastAttention, SelfAttention, CrossAttention and ProjectionUpdater,
# plus AutoregressiveWrapper and PerformerEncDec
from performer_pytorch.performer_pytorch import PerformerLM, Performer, FastAttention, SelfAttention, CrossAttention, ProjectionUpdater
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from performer_pytorch.performer_enc_dec import PerformerEncDec

Performer - Pytorch

An implementation of Performer, a linear attention-based transformer variant with a Fast Attention Via positive Orthogonal Random features approach (FAVOR+).
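FAVOR+ replaces the softmax kernel exp(q·k/√d) with an inner product of positive random features, so attention can be computed as φ(Q)(φ(K)ᵀV) in time linear in the sequence length. A minimal single-head sketch of that idea, written from the published description rather than taken from this library: it uses i.i.d. Gaussian projections and skips the orthogonalization and periodic redraws that performer-pytorch performs, and `softmax_kernel_features` / `favor_attention` are illustrative names, not the package's API.

```python
import math
import torch

def softmax_kernel_features(x, proj):
    # positive random features phi(x) = exp(w^T x - |x|^2 / 2) / sqrt(m), with w ~ N(0, I),
    # chosen so that E[phi(q) . phi(k)] = exp(q . k)
    m = proj.shape[0]
    x_proj = x @ proj.t()                                   # (n, m)
    sq_norm = (x ** 2).sum(dim=-1, keepdim=True) / 2        # (n, 1)
    return torch.exp(x_proj - sq_norm) / math.sqrt(m)

def favor_attention(q, k, v, nb_features=256):
    d = q.shape[-1]
    # i.i.d. Gaussian projection; FAVOR+ additionally orthogonalizes these rows (omitted here)
    proj = torch.randn(nb_features, d)
    # scale so that phi(q) . phi(k) estimates exp(q . k / sqrt(d)), the softmax numerator
    q, k = q * d ** -0.25, k * d ** -0.25
    q_prime = softmax_kernel_features(q, proj)               # (n, m)
    k_prime = softmax_kernel_features(k, proj)               # (n, m)
    # linear-time attention: form K'^T V (m x d_v) first, never the n x n matrix
    kv = k_prime.t() @ v                                     # (m, d_v)
    normalizer = q_prime @ k_prime.sum(dim=0)                # (n,)
    return (q_prime @ kv) / normalizer.unsqueeze(-1)

torch.manual_seed(0)
n, d = 32, 16
q, k, v = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)

approx = favor_attention(q, k, v, nb_features=2048)
exact = torch.softmax(q @ k.t() / math.sqrt(d), dim=-1) @ v
# max error of the random-feature estimate; it tends to shrink as nb_features grows
print((approx - exact).abs().max())
```

Because the normalizer uses the same features as the numerator, each output row is a weighted average of the values, so a constant value matrix is reproduced exactly.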

Install

$ pip install performer-pytorch

If you plan on training an autoregressive model, you must also run the following

$ pip install -r requirements.txt

Usage

Performer Language Model

import torch
from performer_pytorch import PerformerLM

model = PerformerLM(
    num_tokens = 20000,
    max_seq_len = 2048,             # max sequence length
    dim = 512,                      # dimension
    depth = 12,                     # layers
    heads = 8,                      # heads
    causal = False,                 # auto-regressive or not
    nb_features = 256,              # number of random features, if not set, will default to (d * log(d)), where d is the dimension of each head
    feature_redraw_interval = 1000, # how frequently to redraw the projection matrix, the more frequent, the slower the training
    generalized_attention = False,  # defaults to softmax approximation, but can be set to True for generalized attention
    kernel_fn = torch.nn.ReLU(),    # the kernel function to be used, if generalized attention is turned on, defaults to ReLU
    reversible = True,              # reversible layers, from Reformer paper
    ff_chunks = 10,                 # chunk feedforward layer, from Reformer paper
    use_scalenorm = False,          # use scale norm, from 'Transformers without Tears' paper
    use_rezero = False,             # use rezero, from 'Rezero is all you need' paper
    ff_glu = True,                  # use GLU variant for feedforward
    emb_dropout = 0.1,              # embedding dropout
    ff_dropout = 0.1,               # feedforward dropout
    attn_dropout = 0.1,             # post-attn dropout
    local_attn_heads = 4,           # 4 heads are local attention, 4 others are global performers
    local_window_size = 256,        # window size of local attention
    rotary_position_emb = True,     # use rotary positional embedding, which endows linear attention with relative positional encoding with no learned parameters. It should always be turned on unless you want to go back to the old absolute positional encoding
    shift_tokens = True             # shift tokens by 1 along sequence dimension before each block, for better convergence
)

x = torch.randint(0, 20000, (1, 2048))
mask = torch.ones_like(x).bool()

model(x, mask = mask) # (1, 2048, 20000)

Plain Performer, if you are working with, say, images or other modalities

import torch
from performer_pytorch import Performer

model = Performer(
    dim = 512,
    depth = 1,
    heads = 8,
    causal = True
)

x = torch.randn(1, 2048, 512)
model(x) # (1, 2048, 512)

Encoder / Decoder - Made possible by Thomas Melistas

import torch
from performer_pytorch import PerformerEncDec

SRC_SEQ_LEN = 4096
TGT_SEQ_LEN = 4096
GENERATE_LEN = 512

enc_dec = PerformerEncDec(
    dim = 512,
    tie_token_embed = True,
    enc_num_tokens = 20000,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = SRC_SEQ_LEN,
    dec_num_tokens = 20000,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = TGT_SEQ_LEN,
)

src = torch.randint(0, 20000, (1, SRC_SEQ_LEN))
tgt = torch.randint(0, 20000, (1, TGT_SEQ_LEN))
src_mask = torch.ones_like(src).bool()
tgt_mask = torch.ones_like(tgt).bool()

# train
enc_dec.train()
loss = enc_dec(src, tgt, enc_mask = src_mask, dec_mask = tgt_mask)
loss.backward()

# generate
generate_in = torch.randint(0, 20000, (1, SRC_SEQ_LEN)).long()
generate_out_prime = torch.tensor([[0.]]).long() # prime with <bos> token
samples = enc_dec.generate(generate_in, generate_out_prime, seq_len = GENERATE_LEN, eos_token = 1) # assume 1 is id of stop token
print(samples.shape) # (1, <= GENERATE_LEN) decode the tokens

Standalone self-attention layer with linear complexity with respect to sequence length, for replacing the self-attention layers of trained full-attention transformers.

import torch
from performer_pytorch import SelfAttention

attn = SelfAttention(
    dim = 512,
    heads = 8,
    causal = False,
).cuda()

x = torch.randn(1, 1024, 512).cuda()
attn(x) # (1, 1024, 512)
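Linear complexity also holds in the causal setting (`causal = True`), where the sums over keys become prefix sums over positions. A minimal sketch of causal linear attention, using the elu(x) + 1 feature map from Katharopoulos et al. (listed in the citations) instead of FAVOR+ features; it illustrates the technique and is not the library's implementation. It checks the prefix-sum form against the same kernel computed with an explicit lower-triangular n x n matrix:

```python
import torch
import torch.nn.functional as F

def feature_map(x):
    # positive feature map from Katharopoulos et al.: elu(x) + 1
    return F.elu(x) + 1

def causal_linear_attention(q, k, v):
    # q, k: (n, d), v: (n, e); position i attends only to positions j <= i
    q, k = feature_map(q), feature_map(k)
    # running sums of k_j v_j^T and of k_j along the sequence (linear in n)
    kv = torch.cumsum(k.unsqueeze(-1) * v.unsqueeze(-2), dim=0)   # (n, d, e)
    k_sum = torch.cumsum(k, dim=0)                                  # (n, d)
    num = torch.einsum('nd,nde->ne', q, kv)
    den = torch.einsum('nd,nd->n', q, k_sum).unsqueeze(-1)
    return num / den

def causal_quadratic_reference(q, k, v):
    # same kernel, computed with an explicit n x n lower-triangular weight matrix
    scores = feature_map(q) @ feature_map(k).t()
    scores = scores * torch.tril(torch.ones_like(scores))
    return (scores @ v) / scores.sum(dim=-1, keepdim=True)

torch.manual_seed(0)
q, k, v = torch.randn(10, 4), torch.randn(10, 4), torch.randn(10, 3)
out = causal_linear_attention(q, k, v)
ref = causal_quadratic_reference(q, k, v)
print(torch.allclose(out, ref, atol=1e-5))  # True
```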

Cross attention works similarly

import torch
from performer_pytorch import CrossAttention

attn = CrossAttention(
    dim = 512,
    heads = 8
).cuda()

x = torch.randn(1, 1024, 512).cuda()
context = torch.randn(1, 512, 512).cuda()

attn(x, context = context) # (1, 1024, 512)

To minimize model surgery, you could also simply rewrite the code, so that the attention step is done by the FastAttention module, as follows.

import torch
from performer_pytorch import FastAttention

# queries / keys / values with heads already split and transposed to first dimension
# 8 heads, dimension of head is 64, sequence length of 512
q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 8, 512, 64)
v = torch.randn(1, 8, 512, 64)

attn_fn = FastAttention(
    dim_heads = 64,
    nb_features = 256,
    causal = False
)

out = attn_fn(q, k, v) # (1, 8, 512, 64)
# now merge heads and combine outputs with Wo
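The last comment leaves merging the heads to you. One way to do it, assuming a model dimension of 512 and a plain `nn.Linear` as Wo (both assumptions for illustration); the reshape is equivalent to einops `rearrange(out, 'b h n d -> b n (h d)')`:

```python
import torch
from torch import nn

batch, heads, seq_len, dim_head = 1, 8, 512, 64
out = torch.randn(batch, heads, seq_len, dim_head)   # stand-in for attn_fn(q, k, v)

# (b, h, n, d) -> (b, n, h, d) -> (b, n, h * d)
merged = out.transpose(1, 2).reshape(batch, seq_len, heads * dim_head)

# hypothetical output projection Wo back to the model dimension
to_out = nn.Linear(heads * dim_head, 512)
y = to_out(merged)
print(y.shape)  # torch.Size([1, 512, 512])
```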

Advanced

At the end of training, if you wish to fix the projection matrices to get the model to output deterministically, you can invoke the following

model.fix_projection_matrices_()

Now your model will have fixed projection matrices across all layers

Citations

@misc{choromanski2020rethinking,
    title   = {Rethinking Attention with Performers},
    author  = {Krzysztof Choromanski and Valerii Likhosherstov and David Dohan and Xingyou Song and Andreea Gane and Tamas Sarlos and Peter Hawkins and Jared Davis and Afroz Mohiuddin and Lukasz Kaiser and David Belanger and Lucy Colwell and Adrian Weller},
    year    = {2020},
    eprint  = {2009.14794},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@inproceedings{katharopoulos_et_al_2020,
    author  = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
    title   = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
    booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
    year    = {2020}
}
@misc{bachlechner2020rezero,
    title   = {ReZero is All You Need: Fast Convergence at Large Depth},
    author  = {Thomas Bachlechner and Bodhisattwa Prasad Majumder and Huanru Henry Mao and Garrison W. Cottrell and Julian McAuley},
    year    = {2020},
    url     = {https://arxiv.org/abs/2003.04887}
}
@article{1910.05895,
    author  = {Toan Q. Nguyen and Julian Salazar},
    title   = {Transformers without Tears: Improving the Normalization of Self-Attention},
    year    = {2019},
    eprint  = {1910.05895},
    archivePrefix = {arXiv},
    doi     = {10.5281/zenodo.3525484},
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{roy*2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://arxiv.org/pdf/2003.05997.pdf}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL},
    url     = {https://arxiv.org/abs/2104.09864}
}

.\lucidrains\performer-pytorch\setup.py

# import the setuptools helpers for defining the package and finding its subpackages
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'performer-pytorch',
  # include every package except 'examples'
  packages = find_packages(exclude=['examples']),
  # version
  version = '1.1.4',
  # license
  license='MIT',
  # description
  description = 'Performer - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project URL
  url = 'https://github.com/lucidrains/performer-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'efficient attention',
    'transformers'
  ],
  # runtime dependencies
  install_requires=[
    'einops>=0.3',
    'local-attention>=1.1.1',
    'torch>=1.6',
    'axial-positional-embedding>=0.1.0'
  ],
  # PyPI classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\perfusion-pytorch\perfusion_pytorch\embedding.py

import torch
from torch import nn, Tensor
from torch.nn import Module

from collections import namedtuple

from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Optional, Tuple, Union, Callable, List

from einops import rearrange

from open_clip import tokenizer

# EmbeddingReturn: a named tuple with the fields 'embed_with_concept', 'embed_with_superclass', 'embed_mask' and 'concept_indices'
EmbeddingReturn = namedtuple('EmbeddingReturn', [
    'embed_with_concept',
    'embed_with_superclass',
    'embed_mask',
    'concept_indices'
])

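# helper functions

# whether a value is not None
def exists(val):
    return val is not None

# return the value, or the default d if the value is None
def default(val, d):
    return val if exists(val) else d

# whether all elements of a list are unique
def is_all_unique(arr):
    return len(set(arr)) == len(arr)

# select the elements of a tuple at the given indices
def filter_tuple_indices(tup, indices):
    return tuple(tup[i] for i in indices)

# boolean mask that is True wherever x equals any of the given ids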
@beartype
def get_mask(
    x: Tensor,
    ids: Tuple[int, ...]
):
    masks = tuple(x == i for i in ids)
    mask, *rest_masks = masks

    for rest_mask in rest_masks:
        mask = mask | rest_mask

    return mask
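In `EmbeddingWrapper.__init__`, `get_mask` strips the pad, start and end tokens from the tokenized superclass string so that exactly one token id remains. A minimal sketch of that step, with the beartype decorator omitted and a made-up token id (1929) standing in for real tokenizer output:

```python
import torch

def get_mask(x, ids):
    # same logic as get_mask above: True wherever x equals any of the given ids
    masks = tuple(x == i for i in ids)
    mask, *rest_masks = masks
    for rest_mask in rest_masks:
        mask = mask | rest_mask
    return mask

pad_id, sos_id, eos_id = 0, 49406, 49407   # the defaults used by EmbeddingWrapper
# hypothetical tokenization of a one-word superclass string (1929 is a made-up id)
ids = torch.tensor([sos_id, 1929, eos_id, pad_id, pad_id])

special = get_mask(ids, (pad_id, sos_id, eos_id))
remaining = ids[~special]
print(remaining.tolist())  # [1929]
```

If the superclass string tokenizes to more than one word token, `remaining` would hold several ids, which is exactly what the wrapper's assertion rejects.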

# embedding wrapper

class EmbeddingWrapper(Module):

    # constructor
    @beartype
    def __init__(
        self,
        embed: nn.Embedding,
        num_concepts = 1,
        superclass_embed_id: Optional[Union[int, Tuple[int, ...]]] = None,
        superclass_string: Optional[str] = None,
        tokenize: Callable[[List[str]], Tensor] = tokenizer.tokenize,
        tokenizer_pad_id: int = 0,
        tokenizer_sos_eos_id: Tuple[int, int] = (49406, 49407)
    ):
        super().__init__()
        self.embed = embed
        num_embeds, dim = embed.weight.shape

        self.num_embeds = num_embeds
        self.num_concepts = num_concepts
        self.concepts = nn.Parameter(torch.zeros(num_concepts, dim))

        assert not (exists(superclass_embed_id) and exists(superclass_string)), 'either superclass embed id is given, or the superclass string'

        self.pad_id = tokenizer_pad_id
        self.tokenize = None

        if exists(superclass_string):
            self.tokenize = tokenize

            ids = tokenize([superclass_string])[0]

            mask_for_ids = get_mask(ids, (tokenizer_pad_id, *tokenizer_sos_eos_id))
            ids = ids[~mask_for_ids]

            assert ids.shape[-1] == 1, f'your superclass concept string must map exactly one token id'
            superclass_embed_id = ids[0].item()

            print(f'super class embed for "{superclass_string}" set as {superclass_embed_id}')
            print(f'you can now pass in a list of strings containing superclass concept, and this wrapper will return the embedding w/ concept and superclass required for finetuning')

        self.superclass_embed_id = superclass_embed_id

        assert not (exists(superclass_embed_id) and num_concepts > 1), 'cannot do multi concept with superclass embed id given'

        if exists(superclass_embed_id):
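            # the authors found that initializing the concept embedding to the superclass embedding gives better results, so allow for that option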

            if not isinstance(superclass_embed_id, tuple):
                superclass_embed_id = (superclass_embed_id,)

            superclass_embed_indices = torch.tensor(list(superclass_embed_id))
            superclass_embeds = embed(superclass_embed_indices)
            self.concepts.data.copy_(superclass_embeds)
        else:
            # otherwise, use the small initialization usually applied to embeddings

            nn.init.normal_(self.concepts, std = 0.02)

        self.concept_embed_ids = tuple(range(num_embeds, num_embeds + num_concepts))

    # only the concept embeddings are trainable
    def parameters(self):
        return [self.concepts]

    # device of the concept parameters
    @property
    def device(self):
        return self.concepts.device

    # forward pass
    @beartype
    def forward(
        self,
        x: Union[Tensor, List[str]],
        concept_id: Optional[Union[int, Tuple[int, ...]]] = None,
        return_embed_with_superclass = True,
        clip_transformer_fn: Optional[Callable[[Tensor], Tensor]] = None
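# a wrapper for CLIP
# that automatically wraps the token embedding with the new concepts,
# and on forward passes the concept embeddings and the superclass concept embeddings through the text transformer and the final layernorm,
# passing the ids and the superclass_ids through the modified text encoder twice (will attempt to substitute the nn.Embedding with an nn.Identity)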

class OpenClipEmbedWrapper(Module):
    @beartype
    def __init__(
        self,
        clip: Module,
        text_transformer_path = 'transformer',
        ln_final_path = 'ln_final',  # in CLIP, the final layernorm is separate from the transformer
        **embedding_wrapper_kwargs
    ):
        super().__init__()
        # wrap CLIP's token embedding so that it can inject the new concept embeddings
        self.wrapped_embed = EmbeddingWrapper(clip.token_embedding, **embedding_wrapper_kwargs)

        # map from module path to module for every submodule of CLIP
        path_to_modules = dict([(path, mod) for path, mod in clip.named_modules()])

        # the text transformer must exist at the given path
        assert text_transformer_path in path_to_modules

        # text transformer, plus the final layernorm (identity if it is absent)
        text_transformer = path_to_modules[text_transformer_path]
        ln_final = path_to_modules.get(ln_final_path, nn.Identity())

        # chain the text transformer and the final layernorm
        self.text_transformer = nn.Sequential(
            text_transformer,
            ln_final
        )

    # forward: embed the input (with concepts), then encode it with the text transformer
    def forward(
        self,
        x,
        **kwargs
    ) -> EmbeddingReturn:
        # text embeddings, superclass text embeddings, text mask and concept indices from the embedding wrapper
        text_embeds, superclass_text_embeds, text_mask, concept_indices = self.wrapped_embed(x, **kwargs)

        # encode the text embeddings
        text_enc = self.text_transformer(text_embeds)

        superclass_text_enc = None

        # also encode the superclass text embeddings, if present
        if exists(superclass_text_embeds):
            superclass_text_enc = self.text_transformer(superclass_text_embeds)

        # return the text encodings, superclass text encodings, text mask and concept indices
        return EmbeddingReturn(text_enc, superclass_text_enc, text_mask, concept_indices)

# merge multiple embedding wrappers (each with one concept) into a single wrapper with multiple concepts

@beartype
def merge_embedding_wrappers(
    *embeds: EmbeddingWrapper
) -> EmbeddingWrapper:

    # total number of concepts
    total_concepts = sum([embed.num_concepts for embed in embeds])

    # all wrapped embeddings must have the same weight shape
    assert len(set([tuple(embed.embed.weight.shape) for embed in embeds])) == 1

    # reuse the token embedding of the first wrapper
    embed = embeds[0].embed

    # new wrapper sized for all the concepts
    merged_concepts = EmbeddingWrapper(
        embed = embed,
        num_concepts = total_concepts
    )

    # put the merged wrapper in evaluation mode
    merged_concepts.eval()

    # concatenate the learned concepts of every wrapper
    concepts = torch.cat(tuple(embed.concepts.data for embed in embeds), dim = 0)

    # use them as the concepts of the merged wrapper
    merged_concepts.concepts = nn.Parameter(concepts)

    return merged_concepts

.\lucidrains\perfusion-pytorch\perfusion_pytorch\open_clip.py

# imports
from beartype import beartype
from beartype.typing import List, Optional

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

import open_clip

# whether a value is not None
def exists(val):
    return val is not None

# L2-normalize a tensor along its last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# adapter around an OpenCLIP model
class OpenClipAdapter(nn.Module):
    @beartype
    def __init__(
        self,
        name = 'ViT-B/32',
        pretrained = 'laion400m_e32',
        tokenizer_name = 'ViT-B-32-quickgelu',
        eos_id = 49407
    ):
        super().__init__()

        # create the OpenCLIP model, its preprocessing transforms and the tokenizer
        clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
        tokenizer = open_clip.get_tokenizer(tokenizer_name)

        self.clip = clip
        self.tokenizer = tokenizer
        self.eos_id = eos_id

        # forward hook for capturing the final text representation

        text_attention_final = self.find_layer('ln_final')
        self._dim_latent = text_attention_final.weight.shape[0]
        self.text_handle = text_attention_final.register_forward_hook(self._text_hook)

        # normalization transform (the last step of the preprocessing pipeline)

        self.clip_normalize = preprocess.transforms[-1]
        self.cleared = False

    @property
    def device(self):
        return next(self.parameters()).device

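    # look up a submodule of the CLIP model by its name
    def find_layer(self, layer):
        modules = dict([*self.clip.named_modules()])
        return modules.get(layer, None)

    # remove the forward hook (only once)
    def clear(self):
        if self.cleared:
            return

        self.text_handle.remove()
        self.cleared = True

    # forward hook: store the output of ln_final as the per-token text encodings
    def _text_hook(self, _, inputs, outputs):
        self.text_encodings = outputs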

    @property
    def dim_latent(self):
        return self._dim_latent

    @property
    def max_text_len(self):
        return self.clip.positional_embedding.shape[0]

    @beartype
    def embed_texts(
        self,
        texts: List[str]
    ):
        # tokenize and truncate to the maximum text length
        ids = self.tokenizer(texts)
        ids = ids.to(self.device)
        ids = ids[..., :self.max_text_len]

        # the mask covers every token up to and including the first <eos>, excluding padding (id 0)
        is_eos_id = (ids == self.eos_id)
        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
        text_mask = text_mask & (ids != 0)
        assert not self.cleared

        # encode the text; the forward hook on ln_final captures the per-token encodings, which are then masked
        text_embed = self.clip.encode_text(ids)
        text_encodings = self.text_encodings
        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
        return text_encodings.float(), text_mask
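The text mask in `embed_texts` marks everything before the first `<eos>`, then shifts right by one so the `<eos>` position itself is included, and finally drops padding. A minimal sketch on a hypothetical token sequence (the ids 320 and 1125 are made up); the shift is written with `torch.cat` instead of `F.pad(mask, (1, -1), value = True)`, which should give the same result:

```python
import torch

eos_id = 49407
# hypothetical token ids: <sos>, two word tokens, <eos>, then padding (0)
ids = torch.tensor([[49406, 320, 1125, eos_id, 0, 0]])

is_eos_id = ids == eos_id
# True strictly before the first <eos>
text_mask_excluding_eos = is_eos_id.cumsum(dim=-1) == 0
# shift right by one, filling with True, so that the <eos> position is included
first = torch.ones_like(text_mask_excluding_eos[..., :1])
text_mask = torch.cat([first, text_mask_excluding_eos[..., :-1]], dim=-1)
# padding tokens are never kept
text_mask = text_mask & (ids != 0)

print(text_mask.tolist())  # [[True, True, True, True, False, False]]
```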

.\lucidrains\perfusion-pytorch\perfusion_pytorch\optimizer.py

from torch.nn import Module
from torch.optim import AdamW, Adam, Optimizer

from beartype import beartype

from perfusion_pytorch.embedding import EmbeddingWrapper
from perfusion_pytorch.perfusion import Rank1EditModule

# automatically find all the parameters needed for finetuning
@beartype
def get_finetune_parameters(text_image_model: Module):
    params = []
    # collect parameters only from the concept-specific modules
    for module in text_image_model.modules():
        if isinstance(module, (EmbeddingWrapper, Rank1EditModule)):
            params.extend(module.parameters())

    return params

# build the optimizer used for finetuning
@beartype
def get_finetune_optimizer(
    text_image_model: Module,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.99),
    eps = 1e-8,
    **kwargs
) -> Optimizer:
    params = get_finetune_parameters(text_image_model)

    assert len(params) > 0, 'no finetuneable parameters found'
    total_params = sum([p.numel() for p in params])
    print(f'optimizing {total_params} parameters')

    # use AdamW when weight decay is requested, plain Adam otherwise
    has_weight_decay = wd > 0
    adam_klass = AdamW if has_weight_decay else Adam
    adam_kwargs = dict(lr = lr, betas = betas, eps = eps)

    if has_weight_decay:
        adam_kwargs.update(weight_decay = wd)

    return adam_klass(params, **adam_kwargs, **kwargs)
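The optimizer helper trains only the parameters exposed by the concept-specific modules, and switches to AdamW only when weight decay is requested. A minimal sketch of the same selection logic on a toy model, where `Trainable` is a hypothetical stand-in for `EmbeddingWrapper` / `Rank1EditModule`:

```python
import torch
from torch import nn
from torch.optim import Adam, AdamW

class Trainable(nn.Module):
    # hypothetical stand-in for the concept-specific modules
    def __init__(self, dim):
        super().__init__()
        self.concepts = nn.Parameter(torch.zeros(1, dim))

model = nn.Sequential(nn.Linear(8, 8), Trainable(8), nn.Linear(8, 8))

# gather parameters only from the targeted module type; the Linear layers stay frozen
params = []
for module in model.modules():
    if isinstance(module, Trainable):
        params.extend(module.parameters())

wd = 1e-2
adam_klass = AdamW if wd > 0 else Adam
adam_kwargs = dict(lr = 1e-4, betas = (0.9, 0.99), eps = 1e-8)
if wd > 0:
    adam_kwargs.update(weight_decay = wd)

optimizer = adam_klass(params, **adam_kwargs)
print(sum(p.numel() for p in params))  # 8
```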

.\lucidrains\perfusion-pytorch\perfusion_pytorch\perfusion.py

# imports

from math import ceil
from copy import deepcopy
from pathlib import Path

from beartype import beartype
from beartype.typing import Union, List, Optional, Tuple

import torch
from torch import nn, einsum, Tensor
from torch.nn import Module
import torch.nn.functional as F

from einops import rearrange, reduce

from opt_einsum import contract as opt_einsum

from perfusion_pytorch.open_clip import OpenClipAdapter

# paths to the precomputed covariances
# more models will be added should the paper's results check out

CURRENT_DIR = Path(__file__).parents[0]
DATA_DIR = CURRENT_DIR / 'data'

assert DATA_DIR.is_dir()

COVARIANCE_FILENAME_BY_TEXT_IMAGE_MODEL = dict(
    SD15 = DATA_DIR / 'covariance_CLIP_ViT-L-14.pt'
)

assert all([filepath.exists() for filepath in COVARIANCE_FILENAME_BY_TEXT_IMAGE_MODEL.values()])

# helper functions

def exists(val):
    return val is not None

def is_all_unique(arr):
    return len(set(arr)) == len(arr)

# function for calculating C, the input covariance

@beartype
@torch.no_grad()
def calculate_input_covariance(
    clip: OpenClipAdapter,
    texts: List[str],
    batch_size = 32,
    **cov_kwargs
):
    num_batches = ceil(len(texts) / batch_size)

    all_embeds = []

    length = len(texts)

    for batch_ind in range(num_batches):
        start_index = batch_ind * batch_size
        batch_texts = texts[start_index:(start_index + batch_size)]

        embeds, mask = clip.embed_texts(batch_texts)
        all_embeds.append(embeds[mask])

    all_embeds = torch.cat(all_embeds, dim = 0)

    return einsum('n d, n e -> d e', all_embeds, all_embeds) / length
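The covariance here is the uncentered second-moment matrix of the non-padding token encodings, XᵀX scaled by a constant (note the function divides by `len(texts)`, the number of texts, not by the number of token rows). A minimal sketch confirming that the einsum form equals the plain matrix product, with random tensors as hypothetical stand-ins for real CLIP encodings:

```python
import torch

torch.manual_seed(0)
# hypothetical stand-ins for the masked token encodings of three text batches
batches = [torch.randn(5, 4), torch.randn(7, 4), torch.randn(3, 4)]

# accumulate, as calculate_input_covariance does, then form the uncentered second moment
all_embeds = torch.cat(batches, dim=0)                  # (n, d) = (15, 4)
C_einsum = torch.einsum('nd,ne->de', all_embeds, all_embeds)
C_matmul = all_embeds.t() @ all_embeds                  # the same quantity: X^T X
print(C_einsum.shape)                                   # torch.Size([4, 4])
```

Scaling C by a constant scales its inverse `C_inv` by the reciprocal constant.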

# loss weighted by a mask

@beartype
def loss_fn_weighted_by_mask(
    pred: Tensor,
    target: Tensor,
    mask: Tensor,
    normalized_mask_min_value = 0.
):
    assert mask.shape[-2:] == pred.shape[-2:] == target.shape[-2:]
    assert mask.shape[0] == pred.shape[0] == target.shape[0]

    assert (mask.amin() >= 0.).all(), 'mask should not have values below 0'

    if mask.ndim == 4:
        assert mask.shape[1] == 1
        mask = rearrange(mask, 'b 1 h w -> b h w')

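    loss = F.mse_loss(pred, target, reduction = 'none')
    # average over channels (einops reduce requires an explicit reduction)
    loss = reduce(loss, 'b c h w -> b h w', 'mean')

    # normalize the mask by its max, taken along the last (width) dimension

    normalized_mask = mask / mask.amax(dim = -1, keepdim = True).clamp(min = 1e-5)
    normalized_mask = normalized_mask.clamp(min = normalized_mask_min_value)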

    loss = loss * normalized_mask

    return loss.mean()
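Two limiting cases make the mask weighting easy to sanity-check: a constant positive mask normalizes to all ones and recovers the plain MSE, while an all-zero mask removes every pixel from the loss. A minimal restatement for 4-d inputs and a 3-d mask, assuming the channel reduction is a mean; `masked_mse` is a hypothetical name, not part of the package:

```python
import torch
import torch.nn.functional as F

def masked_mse(pred, target, mask, normalized_mask_min_value = 0.):
    # per-pixel MSE averaged over channels: (b, c, h, w) -> (b, h, w)
    loss = F.mse_loss(pred, target, reduction = 'none').mean(dim = 1)
    # normalize by the max along the last dimension, as loss_fn_weighted_by_mask does
    normalized_mask = mask / mask.amax(dim = -1, keepdim = True).clamp(min = 1e-5)
    normalized_mask = normalized_mask.clamp(min = normalized_mask_min_value)
    return (loss * normalized_mask).mean()

torch.manual_seed(0)
pred, target = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)

# constant positive mask: every weight becomes 1, recovering the plain MSE
uniform = masked_mse(pred, target, torch.full((2, 8, 8), 0.5))
# all-zero mask (with the default min value of 0): every pixel is ignored
zeroed = masked_mse(pred, target, torch.zeros(2, 8, 8))
```

Raising `normalized_mask_min_value` keeps a floor of supervision on pixels outside the mask.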

# a module that wraps the key or value projection of a cross attention layer attending to the text encodings

class Rank1EditModule(Module):

    @beartype
    def __init__(
        self,
        key_or_values_proj: nn.Linear,
        *,
        num_concepts: int = 1,
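        C: Optional[Tensor] = None,          # covariance of the input, precomputed from 100K LAION texts
        default_model = 'SD15',
        text_seq_len: int = 77,
        is_key_proj: bool = False,
        input_decay = 0.99,
        train_beta = 0.75,
        train_temperature = 0.1,
        eval_beta = 0.70,                    # the paper specifies a range of 0.6 - 0.75 for local key-locking and 0.4 - 0.6 for global key-locking
        eval_temperature = 0.15,
        frac_gradient_concept_embed = 0.1,   # they use a slower learning rate for the embedding; this is done with a trick that reduces the gradient flowing back
        multi_concepts_use_cholesky = False  # whether to use the Cholesky-based technique for multiple concepts; otherwise an approximation without the Cholesky root is used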
        ):
        super().__init__()
        # the key / value projection in the attention should not have a bias
        assert not exists(key_or_values_proj.bias), 'key value projection in attention should not have bias'

        # set before num_concepts, whose setter reads this flag
        self.multi_concepts_use_cholesky = multi_concepts_use_cholesky

        # weight of the key or value projection
        self.weight = key_or_values_proj.weight
        dim_output, dim_input = self.weight.shape

        # beta and temperature of the gate, for training and for evaluation
        self.train_beta = train_beta
        self.train_temperature = train_temperature
        self.eval_beta = eval_beta
        self.eval_temperature = eval_temperature

        # decay for the input (the exponentially smoothed concept text encodings)
        self.input_decay = input_decay

        # length of the text sequence
        self.text_seq_len = text_seq_len

        # fraction of the gradient let through to the concept embedding (a slower effective learning rate)
        assert 0 < frac_gradient_concept_embed <= 1.
        self.frac_gradient_concept_embed = frac_gradient_concept_embed

        # exponentially smoothed concept text encodings
        self.register_buffer('initted', torch.zeros(num_concepts, 1).bool())
        self.register_buffer('ema_concept_text_encs', torch.zeros(num_concepts, dim_input))

        # concept outputs - only optimize values, not keys
        self.is_key_proj = is_key_proj # lock the output to the superclass, and turn off gradients

        self.concept_outputs = nn.Parameter(torch.zeros(num_concepts, dim_output), requires_grad = not is_key_proj)

        # C_inv, the inverse of the input covariance; fall back to the precomputed default when C is not given
        if not exists(C):
            covariance_filepath = COVARIANCE_FILENAME_BY_TEXT_IMAGE_MODEL.get(default_model, None)

            assert exists(covariance_filepath), f'{default_model} not found in the list of precomputed covariances {tuple(COVARIANCE_FILENAME_BY_TEXT_IMAGE_MODEL.keys())}'

            C = torch.load(str(covariance_filepath))
            print(f'precomputed covariance loaded from {str(covariance_filepath)}')

        C_inv = torch.inverse(C)
        self.register_buffer('C_inv', C_inv)

        # assigned last, since for multiple concepts the setter needs C_inv
        self.num_concepts = num_concepts

    @property
    def num_concepts(self):
        return self._num_concepts

    @num_concepts.setter
    def num_concepts(self, value):
        self._num_concepts = value

        if value == 1 or not self.multi_concepts_use_cholesky:
            return

        # for multiple concepts, the Cholesky factor of C_inv is needed (cached below as L_T and L_T_inv)
        try:
            L = torch.linalg.cholesky(self.C_inv)
        except RuntimeError as err:
            raise ValueError('unable to perform cholesky. please make sure input covariance matrix is properly calculated') from err

        L_T = L.T
        L_T_inv = torch.inverse(L_T)

        self.register_buffer('L_T', L_T, persistent = False)
        self.register_buffer('L_T_inv', L_T_inv, persistent = False)

    @property
    def device(self):
        return next(self.buffers()).device

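    # return only the trainable parameters: concept outputs are learned for
    # value projections, while key outputs are locked (requires_grad = False)
    def parameters(self):
        if self.is_key_proj:
            return []

        return [self.concept_outputs]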

    @beartype
    def forward(
        self,
        text_enc: Tensor,
        *,
        concept_indices: Optional[Tensor] = None,
        text_enc_with_superclass: Optional[Tensor] = None,
        concept_id: Union[int, Tuple[int, ...]] = 0
# function for merging trained Rank1EditModule(s)

@beartype
def merge_rank1_edit_modules(
    *modules: Rank1EditModule,  # one or more Rank1EditModule instances to merge
    use_cholesky = False  # whether the merged module uses the Cholesky factor for multiple concepts, default False
) -> Rank1EditModule:  # the merged Rank1EditModule

    # all modules must be initialized, and ideally trained
    assert all([m.initted.all() for m in modules]), 'all modules must be initialized and ideally trained'
    # the concept output dimension must be the same
    assert len(set([m.concept_outputs.shape[-1] for m in modules])) == 1, 'concept output dimension must be the same'
    # all modules must be for keys, or all for values; key and value Rank1EditModules cannot be merged together
    assert len(set([m.is_key_proj for m in modules])) == 1, 'all modules must be either for keys, or values. you cannot merge rank 1 edit modules of keys and values together'

    # take the first module
    first_module = modules[0]
    # deep-copy it as the starting point of the merged module
    merged_module = deepcopy(first_module)
    # set whether to use the Cholesky factor
    merged_module.multi_concepts_use_cholesky = use_cholesky

    # total number of concepts (the setter computes the Cholesky factors when needed)
    total_concepts = sum([m.num_concepts for m in modules])
    merged_module.num_concepts = total_concepts

    # concatenate the concept outputs of all modules
    concept_outputs = torch.cat(tuple(m.concept_outputs.data for m in modules), dim = 0)
    merged_module.concept_outputs = nn.Parameter(concept_outputs, requires_grad = not first_module.is_key_proj)

    # concatenate the EMA concept text encodings of all modules
    ema_concept_text_encs = torch.cat(tuple(m.ema_concept_text_encs.data for m in modules), dim = 0)
    merged_module.register_buffer('ema_concept_text_encs', ema_concept_text_encs)

    # mark every concept as initialized
    merged_module.register_buffer('initted', torch.ones(total_concepts, 1).bool())

    # return the merged module
    return merged_module
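`merge_rank1_edit_modules` sets `multi_concepts_use_cholesky` and then assigns `num_concepts`; when the flag is on and there is more than one concept, the setter factors `C_inv` with `torch.linalg.cholesky` and caches `L_T` and `L_T_inv`. A minimal sketch of the identity that makes this factor useful, assuming `C_inv` is symmetric positive definite: the C⁻¹-weighted inner product of two encodings equals an ordinary dot product after mapping both through `L_T`. How the truncated `forward` actually uses `L_T` is not shown here, and the helpers `cholesky_factors` and `c_inv_gram` are hypothetical.

```python
import torch

def cholesky_factors(C_inv: torch.Tensor):
    # hypothetical helper: C_inv = L @ L.T with L lower-triangular,
    # mirroring what the num_concepts setter computes (L_T and its inverse)
    L = torch.linalg.cholesky(C_inv)
    L_T = L.T
    L_T_inv = torch.inverse(L_T)
    return L_T, L_T_inv

def c_inv_gram(encs: torch.Tensor, C_inv: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: pairwise inner products of k encodings (k, d) under C^-1
    return encs @ C_inv @ encs.T

# toy symmetric positive-definite covariance standing in for the precomputed one
torch.manual_seed(0)
d, k = 8, 3
A = torch.randn(d, d, dtype=torch.float64)
C = A @ A.T + d * torch.eye(d, dtype=torch.float64)
C_inv = torch.inverse(C)

L_T, L_T_inv = cholesky_factors(C_inv)
encs = torch.randn(k, d, dtype=torch.float64)  # k toy concept text encodings

# map every row u to L_T @ u; plain dot products now equal C^-1 inner products
whitened = encs @ L_T.T
print(torch.allclose(whitened @ whitened.T, c_inv_gram(encs, C_inv)))  # True
# L_T_inv undoes the mapping
print(torch.allclose(whitened @ L_T_inv.T, encs))  # True
```

Because `whitened @ whitened.T` equals `encs @ C_inv @ encs.T`, several concept encodings can be compared or orthogonalized with ordinary Euclidean tools once they are mapped through `L_T`.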

# function for wiring Rank1EditModules into a cross-attention module

@beartype
def make_key_value_proj_rank1_edit_modules_(
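    cross_attention: nn.Module,  # the cross-attention module, patched in place
    *,
    input_covariance: Tensor,  # input covariance C, passed to both Rank1EditModules
    key_proj_name: str,  # attribute name of the key projection
    value_proj_name: str,  # attribute name of the value projection
    **rank1_edit_module_kwargs  # any other Rank1EditModule arguments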
):
    # fetch the key and value projection linear layers
    linear_key = getattr(cross_attention, key_proj_name, None)
    linear_values = getattr(cross_attention, value_proj_name, None)

    # both projections must be nn.Linear layers
    assert isinstance(linear_key, nn.Linear), f'{key_proj_name} must point to where the keys projection is (ex. self.to_keys = nn.Linear(in, out, bias = False) -> key_proj_name = "to_keys")'
    assert isinstance(linear_values, nn.Linear), f'{value_proj_name} must point to where the values projection is (ex. self.to_values = nn.Linear(in, out, bias = False) -> value_proj_name = "to_values")'

    # create the Rank1EditModules for keys and values
    # (the constructor takes the input covariance as its keyword argument `C`)
    rank1_edit_module_keys = Rank1EditModule(linear_key, C = input_covariance, is_key_proj = True, **rank1_edit_module_kwargs)
    rank1_edit_module_values = Rank1EditModule(linear_values, C = input_covariance, is_key_proj = False, **rank1_edit_module_kwargs)

    # replace the original key and value projections with the Rank1EditModules
    setattr(cross_attention, key_proj_name, rank1_edit_module_keys)
    setattr(cross_attention, value_proj_name, rank1_edit_module_values)
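Each module installed this way wraps a bias-free projection `W` together with `C_inv`, a concept output, and the `beta` / `temperature` pair from the constructor. Since the `forward` body is cut off in the listing above, the following is only a sketch of a gated rank-1 edit in the spirit of Perfusion, assuming a single concept with input encoding `i_star` and target output `o_star`; the function names are hypothetical. `rank1_edit_weight` is the ungated closed form: the edited weight maps `i_star` exactly to `o_star`, while inputs orthogonal to `i_star` under the C⁻¹ metric keep their old output. `gated_rank1_edit` replaces the linear coefficient with a sigmoid gate.

```python
import torch

def rank1_edit_weight(W, C_inv, i_star, o_star):
    # hypothetical helper, closed form: W' = W + (o* - W i*) (C^-1 i*)^T / (i*^T C^-1 i*)
    u = C_inv @ i_star
    return W + torch.outer(o_star - W @ i_star, u) / (i_star @ u)

def gated_rank1_edit(x, W, C_inv, i_star, o_star, beta=0.75, temperature=0.1):
    # hypothetical helper; x: (n, d_in), W: (d_out, d_in), i_star: (d_in,), o_star: (d_out,)
    energy = i_star @ C_inv @ i_star           # i*^T C^-1 i*
    ratio = (x @ C_inv @ i_star) / energy      # (n,), equals 1 where x == i*
    gate = torch.sigmoid((ratio - beta) / temperature)
    # remove the i*-aligned part of W x, then add the target output scaled by the gate
    base = x @ W.T - ratio[:, None] * (W @ i_star)[None, :]
    return base + gate[:, None] * o_star[None, :]

torch.manual_seed(0)
d_in, d_out = 6, 4
W = torch.randn(d_out, d_in, dtype=torch.float64)
A = torch.randn(d_in, d_in, dtype=torch.float64)
C_inv = torch.inverse(A @ A.T + d_in * torch.eye(d_in, dtype=torch.float64))
i_star = torch.randn(d_in, dtype=torch.float64)
o_star = torch.randn(d_out, dtype=torch.float64)

W_new = rank1_edit_weight(W, C_inv, i_star, o_star)
print(torch.allclose(W_new @ i_star, o_star))  # True: the edited weight maps i* to o*

out = gated_rank1_edit(i_star[None, :], W, C_inv, i_star, o_star)
# at x == i*, ratio == 1, so the output is sigmoid((1 - 0.75) / 0.1) * o*
print(torch.allclose(out[0], torch.sigmoid(torch.tensor(2.5, dtype=torch.float64)) * o_star))  # True
```

If `gate` were replaced by `ratio`, `gated_rank1_edit` would reproduce `x @ W_new.T` exactly. In this sketch, lowering `beta` (the constructor defaults are 0.75 for training and 0.70 for evaluation) opens the gate for encodings less similar to the concept, and a smaller `temperature` makes the gate closer to a hard switch.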

posted @ 2024-06-28 14:02 by 绝不原创的飞龙