Lucidrains Series Source Code Analysis (Part 6)

.\lucidrains\CoCa-pytorch\coca_pytorch\__init__.py

# import the CoCa class from the coca_pytorch module
from coca_pytorch.coca_pytorch import CoCa

CoCa - Pytorch

Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch. They were able to elegantly fit contrastive learning into a conventional encoder / decoder (image to text) transformer, achieving SOTA 91.0% top-1 accuracy on ImageNet with a finetuned encoder.

This repository also adopts the specific transformer architecture from PaLM, for both the unimodal and multimodal transformers as well as the cross attention blocks (parallel SwiGLU feedforwards).

Update: CoCa has been trained by the good folks over at OpenClip

Install

$ pip install coca-pytorch

Usage

First install vit-pytorch for the image encoder, which needs to be pretrained

$ pip install "vit-pytorch>=0.40.2"

Then

import torch

# import vision transformer

from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
from vit_pytorch.extractor import Extractor

vit = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    patch_dropout = 0.5  # https://arxiv.org/abs/2212.00794
)

vit = Extractor(vit, return_embeddings_only = True, detach = False)

# the extractor makes the vision transformer return its embeddings

# import CoCa and instantiate it

from coca_pytorch.coca_pytorch import CoCa

coca = CoCa(
    dim = 512,                     # model dimension
    img_encoder = vit,             # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
    image_dim = 1024,              # image embedding dimension, if not the same as model dimensions
    num_tokens = 20000,            # number of text tokens
    unimodal_depth = 6,            # depth of the unimodal transformer
    multimodal_depth = 6,          # depth of the multimodal transformer
    dim_head = 64,                 # dimension per attention head
    heads = 8,                     # number of attention heads
    caption_loss_weight = 1.,      # weight on the autoregressive caption loss
    contrastive_loss_weight = 1.,  # weight on the contrastive loss between image and text CLS embeddings
).cuda()

# mock text and images

text = torch.randint(0, 20000, (4, 512)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train by giving CoCa your text and images with `return_loss = True`

loss = coca(
    text = text,
    images = images,
    return_loss = True  # set this to True to get the full caption + contrastive loss
)

loss.backward()

# do the above for as much text and images...
# then you can get the caption logits as so

logits = coca(
    text = text,
    images = images
) # (4, 512, 20000)

# and the CLIP-like text and image embeddings as

text_embeds, image_embeds = coca(
    text = text,
    images = images,
    return_embeddings = True
) # (4, 512), (4, 512)
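As a small follow-up sketch (not part of the repository API), the returned embeddings can be compared CLIP-style by L2-normalizing them and taking a similarity matrix; the retrieval step below is purely illustrative.

import torch.nn.functional as F

# L2-normalize both sets of embeddings
text_latents = F.normalize(text_embeds, dim = -1)
image_latents = F.normalize(image_embeds, dim = -1)

# (4, 4) similarity matrix between every caption and every image
sim = text_latents @ image_latents.t()

# retrieve the closest image for each caption
best_image_per_text = sim.argmax(dim = -1)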

Citations

@inproceedings{Yu2022CoCaCC,
  title   = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
  author  = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
  year    = {2022}
}
@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}

.\lucidrains\CoCa-pytorch\setup.py

# import setup and find_packages utilities
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'CoCa-pytorch', # package name
  packages = find_packages(exclude=[]), # find all packages
  version = '0.1.0', # version
  license='MIT', # license
  description = 'CoCa, Contrastive Captioners are Image-Text Foundation Models - Pytorch', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  long_description_content_type = 'text/markdown', # long description content type
  url = 'https://github.com/lucidrains/CoCa-pytorch', # project URL
  keywords = [ # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'contrastive learning',
    'multimodal'
  ],
  install_requires=[ # dependencies
    'einops>=0.4',
    'torch>=1.6',
  ],
  classifiers=[ # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\coco-lm-pytorch\coco_lm_pytorch\coco_lm_pytorch.py

# import the math library
import math
# import reduce from functools
from functools import reduce

# import torch
import torch
# import the nn module and einsum from torch
from torch import nn, einsum
# import torch.nn.functional as F
import torch.nn.functional as F

# helper functions

# take the log of a tensor, adding a small eps to avoid taking the log of zero
def log(t, eps=1e-9):
    return torch.log(t + eps)

# l2 normalize a tensor along the last dimension
def norm(t):
    return F.normalize(t, p = 2, dim = -1)

# generate gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# sample from a tensor using gumbel noise
def gumbel_sample(t, temperature = 1.):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim=-1)

# generate a boolean mask with the given probability
def prob_mask_like(t, prob):
    return torch.zeros_like(t).float().uniform_(0, 1) < prob

# generate a mask marking positions that match any of the given token ids
def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask

# sample a subset of the mask with the given probability
def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()
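A minimal sketch of how these masking helpers might be exercised together; the mock sequence and probability below are assumptions for illustration, not part of the module.

import torch

seq = torch.randint(0, 20000, (2, 8))             # mock token ids

no_mask = mask_with_tokens(seq, [0, 1, 2])        # True wherever a special token sits
candidate_mask = ~no_mask                         # positions eligible for masking

mlm_mask = get_mask_subset_with_prob(candidate_mask, 0.15)
# boolean (2, 8) tensor selecting roughly 15% of the eligible positions per row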

# hidden layer extractor class, for hooking into the language model and pulling out a hidden layer for pretraining

class HiddenLayerExtractor(nn.Module):
    def __init__(self, net, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        self.hidden = None
        self.hook_registered = False

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _hook(self, _, __, output):
        self.hidden = output

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    def forward(self, x):
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        _ = self.net(x)
        hidden = self.hidden
        self.hidden = None
        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

# main COCO class (Electra-style)

class COCO(nn.Module):
    def __init__(
        self,
        generator,
        discriminator,
        *,
        discr_dim,
        num_tokens = None,
        discr_layer = -1,
        mask_prob = 0.15,
        replace_prob = 0.85,
        random_token_prob = 0.,
        pad_token_id = 0,
        cls_token_id = 1,
        mask_token_id = 2,
        mask_ignore_token_ids = [],
        disc_weight = 50.,
        gen_weight = 1.,
        cl_weight = 1.,
        temperature = 1.,
        crop_percentage = 0.5
        ):
        # call the parent constructor
        super().__init__()

        # set the generator and discriminator
        self.generator = generator
        self.discriminator = discriminator

        # extract hidden layer features from the discriminator
        self.discriminator = HiddenLayerExtractor(discriminator, layer = discr_layer)
        # project the discriminator dimension down to 1 for the correction logits
        self.to_correction_logits = nn.Linear(discr_dim, 1)

        # mlm related probabilities
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob

        # number of tokens
        self.num_tokens = num_tokens
        self.random_token_prob = random_token_prob

        # token ids
        self.cls_token_id = cls_token_id
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id, cls_token_id])

        # sampling temperature
        self.temperature = temperature

        # loss weights
        self.disc_weight = disc_weight
        self.gen_weight = gen_weight
        self.cl_weight = cl_weight

        # contrastive loss temperature
        self.cl_temperature = nn.Parameter(torch.tensor(1.))

        # crop percentage
        self.crop_percentage = crop_percentage

.\lucidrains\coco-lm-pytorch\coco_lm_pytorch\__init__.py

# import the COCO class from the coco_lm_pytorch.coco_lm_pytorch module
from coco_lm_pytorch.coco_lm_pytorch import COCO

COCO LM Pretraining (wip)

Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in Pytorch. They were able to make contrastive learning work in a self-supervised manner for language model pretraining. Seems like a solid successor to Electra.

Install

$ pip install coco-lm-pytorch

Usage

An example using the x-transformers library

$ pip install x-transformers

Then

import torch
from coco_lm_pytorch import COCO

# (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator

from x_transformers import TransformerWrapper, Encoder

generator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 256,         # smaller hidden dimension
        heads = 4,         # fewer heads
        ff_mult = 2,       # smaller feedforward dimension
        depth = 1
    )
)

discriminator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 1024,
        heads = 16,
        ff_mult = 4,
        depth = 12
    )
)

# (2) weight tie the token and positional embeddings of generator and discriminator

generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb

# weight tie any other embeddings if available, token type embeddings, etc.

# (3) instantiate COCO

trainer = COCO(
    generator,
    discriminator,
    discr_dim = 1024,            # the embedding dimension of the discriminator
    discr_layer = 'norm',        # the layer name in the discriminator, whose output would be used for predicting token is still the same or replaced
    cls_token_id = 1,            # a token id must be reserved for [CLS], which is prepended to the sequence for contrastive learning
    mask_token_id = 2,           # the token id reserved for masking
    pad_token_id = 0,            # the token id for padding
    mask_prob = 0.15,            # masking probability for masked language modeling
    mask_ignore_token_ids = [],  # ids of tokens to ignore for mask modeling ex. (cls, sep)
    cl_weight = 1.,              # weight for the contrastive learning loss
    disc_weight = 1.,            # weight for the corrective learning loss
    gen_weight = 1.              # weight for the MLM loss
)

# (4) train

data = torch.randint(0, 20000, (1, 1024))

loss = trainer(data)
loss.backward()

# after much training, the discriminator should have improved

torch.save(discriminator, f'./pretrained-model.pt')

Citations

@misc{meng2021cocolm,
    title   = {COCO-LM: Correcting and Contrasting Text Sequences for Language Model Pretraining}, 
    author  = {Yu Meng and Chenyan Xiong and Payal Bajaj and Saurabh Tiwary and Paul Bennett and Jiawei Han and Xia Song},
    year    = {2021},
    eprint  = {2102.08473},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

.\lucidrains\coco-lm-pytorch\setup.py

# import setup and find_packages utilities
from setuptools import setup, find_packages

# package info
setup(
  name = 'coco-lm-pytorch',  # package name
  packages = find_packages(),  # find all packages
  version = '0.0.2',  # version
  license='MIT',  # license
  description = 'COCO - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/coco-lm-pytorch',  # project URL
  keywords = [  # keywords
    'transformers',
    'artificial intelligence',
    'deep learning',
    'pretraining'
  ],
  install_requires=[  # dependencies
    'torch>=1.6.0',
    'einops',
    'x-transformers'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.7',
  ],
)

coffee-neural-network

a simple neural network in coffeescript

running

$ npm install
$ npm install coffee-script -g
$ coffee nn.coffee

.\lucidrains\CoLT5-attention\colt5_attention\attend.py

# import necessary libraries
from collections import namedtuple
from functools import wraps
from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# named tuple Config for storing the efficient attention configuration
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper to check whether a value exists
def exists(val):
    return val is not None

# decorator that ensures the wrapped function runs only once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print helper that only prints once
print_once = once(print)

# main Attend class
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        causal = False,
        use_flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.register_buffer("mask", None, persistent=False)

        self.use_flash = use_flash
        assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    # get the causal mask
    def get_mask(self, n, device):
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    # flash attention
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # check if the mask exists and expand it to a compatible shape
        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # check which config to use, depending on whether the tensors are on a cuda device
        config = self.cuda_config if is_cuda else self.cpu_config

        # use pytorch 2.0 flash attention
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0., 
                is_causal = self.causal
            )

        return out
    # forward pass, accepting queries (q), keys (k), values (v) and an optional mask
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # get the query sequence length (n) and device
        n, device = q.shape[-2], q.device

        # scaling factor, inverse square root of the head dimension
        scale = q.shape[-1] ** -0.5

        # if flash attention is enabled, dispatch to flash_attn
        if self.use_flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask

        if exists(mask):
            if mask.ndim != 4:
                mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask

        if self.causal:
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
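A minimal usage sketch of the Attend module above, with mock q / k / v tensors; the shapes and hyperparameters are illustrative assumptions.

import torch

attend = Attend(dropout = 0.1, causal = True, use_flash = False)

q = torch.randn(1, 8, 256, 64)   # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 256, 64)
v = torch.randn(1, 8, 256, 64)

out = attend(q, k, v)            # (1, 8, 256, 64)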

.\lucidrains\CoLT5-attention\colt5_attention\coor_descent.py

# import torch
import torch
# import torch.nn.functional as F
import torch.nn.functional as F
# import autocast from torch.cuda.amp
from torch.cuda.amp import autocast
# import rearrange from einops
from einops import rearrange

# check whether a value exists
def exists(val):
    return val is not None

# return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# log of a tensor, clamped to avoid values that are too small
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# disable automatic mixed precision for this function
@autocast(enabled = False)
# coordinate descent routine
def coor_descent(
    s,
    *,
    n_iters,
    k,
    eps = 1e-1,
    eps_init = None,
    eps_decay = 1.,
    mask = None
):
    """
    coordinate descent  - https://arxiv.org/abs/1502.04759, utilized in https://arxiv.org/abs/2303.09752
    ε-scaling           - https://arxiv.org/abs/1610.06519, utilized in https://arxiv.org/abs/2304.04947

    in a follow up paper applying coordinate descent routing to efficient fine tuning
    they were able to cut n_iters from 50 -> 20 by setting eps_init = 4 and eps_decay = 0.7
    eps was dependent on the task, and ranged from 0.02 to 1
    """

    # number of iterations must be greater than 0
    assert n_iters > 0

    # the mask value is the most negative value for the dtype of s
    mask_value = -torch.finfo(s.dtype).max

    # if k is not a torch.Tensor, convert it to one
    if not isinstance(k, torch.Tensor):
        k = torch.Tensor([k]).to(s)
    else:
        k = rearrange(k, '... -> ... 1')

    # log of k
    logk = log(k)

    # if a mask is given, fill masked-out positions of s with mask_value
    if exists(mask):
        s = s.masked_fill(~mask, mask_value)

    # initialize a and b
    a = 0
    b = -s

    # initialize the current epsilon
    current_eps = max(default(eps_init, eps), eps)

    # iterate n_iters times
    for _ in range(n_iters):
        # compute sb
        sb = ((s + b) / current_eps)

        # if a mask is given, fill masked-out positions of sb with mask_value
        if exists(mask):
            sb = sb.masked_fill(~mask, mask_value)

        # update a and b
        a = current_eps * (logk - sb.logsumexp(dim = -1, keepdim = True))
        b = -F.relu(s + a)

        # decay the current epsilon
        current_eps = max(current_eps * eps_decay, eps)

    # compute the scores
    scores = ((s + a + b) / current_eps).exp()

    # if a mask is given, zero out masked positions of the scores
    if exists(mask):
        scores = scores.masked_fill(~mask, 0.)

    # return the scores
    return scores
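A minimal sketch of calling the routine on random scores; the shape, k, and epsilon schedule below are illustrative assumptions (the eps_init / eps_decay values mirror the ones mentioned in the docstring above).

import torch

s = torch.randn(1, 512)          # unnormalized routing scores for 512 tokens

scores = coor_descent(
    s,
    n_iters = 20,
    k = 64,                      # soft budget of tokens to select
    eps = 0.03,
    eps_init = 4.,
    eps_decay = 0.7
)

# scores is differentiable and sums to roughly k along the last dimension
print(scores.sum(dim = -1))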

.\lucidrains\CoLT5-attention\colt5_attention\topk.py

import torch
from torch.cuda.amp import autocast

from collections import namedtuple
from colt5_attention.coor_descent import coor_descent

TopkReturn = namedtuple('TopkReturn', ['values', 'indices', 'coor_descent_values', 'gates'])

@autocast(enabled = False)
def topk(
    x,
    k,
    coor_descent_k_ratio = 9 / 8,
    n_iters = 20,
    eps = 1e-1,
    eps_init = None,
    eps_decay = 1.,
    mask = None,
    fused = False,
    non_differentiable = False
):
    """
    differentiable top-k on last dimension
    """

    if non_differentiable:
        # if differentiability is not needed, just use torch.topk for the top k values and indices
        values, indices = torch.topk(x, k = k, dim = -1)
        return TopkReturn(values, indices, None, None)

    assert coor_descent_k_ratio >= 1.
    assert k > 0

    # whether to use the fused kernel or not

    fn = coor_descent

    if fused and x.is_cuda:
        # if fused is enabled and the tensor is on the gpu, use the triton coordinate descent kernel
        from colt5_attention.triton_coor_descent import triton_coor_descent
        fn = triton_coor_descent

    # do coordinate descent for gradients

    coor_descent_out = fn(
        x,
        k = min(k * coor_descent_k_ratio, x.shape[-1]),   # fetch a bit more than k for better learning, as described in the CoLT5 paper (they fetch 9/8 times k)
        mask = mask,
        n_iters = n_iters,
        eps = eps,
        eps_init = eps_init,
        eps_decay = eps_decay
    )

    # do straight through

    gates = coor_descent_out + (1 - coor_descent_out).detach()

    x = x * gates

    # hard topk

    values, indices = torch.topk(x, k, dim = -1)

    # return something that looks like a usual topk, but now differentiable

    coor_descent_values = coor_descent_out.gather(-1, indices)
    gates = gates.gather(-1, indices)

    return TopkReturn(values, indices, coor_descent_values, gates)
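A small sketch of the differentiable top-k in use; the mock input and k are assumptions for illustration.

import torch

x = torch.randn(2, 1024, requires_grad = True)

out = topk(x, k = 256, n_iters = 20)

# out.values / out.indices look like the output of torch.topk,
# but gradients flow back into x through the straight-through gates
out.values.sum().backward()
assert x.grad is not None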

.\lucidrains\CoLT5-attention\colt5_attention\transformer_block.py

import math
from functools import partial
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn, einsum

from typing import Tuple, Optional

from local_attention import LocalMHA
from einops import rearrange, repeat, pack, unpack

from colt5_attention.attend import Attend

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return a default value
def default(val, d):
    return val if exists(val) else d

# check divisibility
def divisible_by(numer, denom):
    return (numer % denom) == 0

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single packed tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# pad a tensor so its length is a multiple of `multiple`
def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    seq_len = tensor.shape[dim]
    m = seq_len / multiple
    if m.is_integer():
        return tensor, seq_len

    remainder = math.ceil(m) * multiple - seq_len
    pad_offset = (0,) * (-1 - dim) * 2
    padded_tensor = F.pad(tensor, (*pad_offset, 0, remainder), value=value)
    return padded_tensor, seq_len

# gather from a tensor using batched indices
def batched_gather(x, indices):
    batch_range = create_batch_range(indices, indices.ndim - 1)
    return x[batch_range, indices]

# identity function
def identity(t):
    return t

# l2 normalize a tensor along the last dimension
def l2norm(t):
    return F.normalize(t, dim=-1)

# tensor helpers

# create a batch range for indexing
def create_batch_range(t, right_pad_dims=1):
    b, device = t.shape[0], t.device
    batch_range = torch.arange(b, device=device)
    pad_dims = ((1,) * right_pad_dims)
    return batch_range.reshape(-1, *pad_dims)

# rotary positional embedding
# https://arxiv.org/abs/2104.09864

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, theta=10000):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    @property
    def device(self):
        return next(self.buffers()).device

    def forward(self, seq_len):
        t = torch.arange(seq_len, device=self.device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim=-1)
        return freqs

# rotate half of the tensor dimensions
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# apply rotary positional embedding
def apply_rotary_pos_emb(pos, t):
    return t * pos.cos() + rotate_half(t) * pos.sin()
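A quick sketch of how the two rotary helpers combine; the shapes are illustrative assumptions.

import torch

rotary = RotaryEmbedding(dim = 64)

q = torch.randn(1, 8, 1024, 64)              # (batch, heads, seq, dim_head)
freqs = rotary(1024)                         # (1024, 64) rotary frequencies

q_rotated = apply_rotary_pos_emb(freqs, q)   # same shape as q, positions encoded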

# normalization

# rms normalization
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = F.normalize(x, dim=-1)
        return normed * self.scale * self.gamma

# modules

# feedforward network
def FeedForward(dim, mult=4):
    dim_hidden = int(dim * mult)
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, dim_hidden),
        nn.GELU(),
        nn.Linear(dim_hidden, dim)
    )

# self attention
class SelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=64,
        heads=8,
        use_flash=False,
        prenorm=False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_hidden = dim_head * heads

        self.norm = RMSNorm(dim) if prenorm else nn.Identity()

        self.attend = Attend(use_flash=use_flash)

        self.to_qkv = nn.Linear(dim, dim_hidden * 3, bias=False)
        self.to_out = nn.Linear(dim_hidden, dim, bias=False)

    def forward(self, x):
        h = self.heads

        x = self.norm(x)

        # get queries, keys, values
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        # attention
        out = self.attend(q, k, v)

        # merge heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=64,
        heads=8,
        multiply_keys_by_score=False,
        use_flash=False
    ):
        # call the parent constructor
        super().__init__()
        # number of heads and per-head scale
        self.heads = heads
        self.scale = dim_head ** -0.5
        # hidden dimension
        dim_hidden = dim_head * heads

        # whether to multiply the keys by the routing score
        self.multiply_keys_by_score = multiply_keys_by_score

        # rms normalization
        self.norm = RMSNorm(dim)
        # null key / value parameters
        self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head))

        # Attend layer
        self.attend = Attend(use_flash = use_flash)

        # project the input to queries
        self.to_q = nn.Linear(dim, dim_hidden, bias = False)
        # project the input to keys / values
        self.to_kv = nn.Linear(dim, dim_hidden * 2, bias = False)
        # project the output back to the model dimension
        self.to_out = nn.Linear(dim_hidden, dim, bias = False)

    # forward
    def forward(
        self,
        x,
        context = None,
        mask = None,
        normalized_scores_kv = None,
        normalized_scores_q = None,
        rotary_emb: Optional[Tuple[Tensor, Tensor]] = None
        ):
            """
            einops:
            b - batch
            h - heads, or number of heads per route
            r - routing dimension, for routing different sets of key / values - should be more expressive
            n - sequence dimension
            d - head dimension
            i - input model dimension
            """

            # get the batch size and number of heads
            batch, h = x.shape[0], self.heads

            # normalize the input
            x = self.norm(x)

            # if a context tensor is given, normalize it as well
            if exists(context):
                context = self.norm(context)

            # if no context is given, default to the input itself
            context = default(context, x)

            # if the context is 3-dimensional, add a routing dimension
            if context.ndim == 3:
                context = rearrange(context, 'b n d -> b 1 n d')

            # if normalized key / value scores are given as a tensor
            if exists(normalized_scores_kv) and isinstance(normalized_scores_kv, torch.Tensor):
                # if normalized_scores_kv is 2-dimensional, add a routing dimension
                if normalized_scores_kv.ndim == 2:
                    normalized_scores_kv = rearrange(normalized_scores_kv, 'b n -> b 1 n')

                # rearrange normalized_scores_kv for broadcasting
                normalized_scores_kv = rearrange(normalized_scores_kv, 'b r n -> b r 1 n 1')

            # number of key / value routes in the context
            num_kv_routes = context.shape[1]

            # get the queries
            q = self.to_q(x)
            q = rearrange(q, 'b n (h d) -> b h n d', h = h)

            # if normalized query scores are given as a tensor
            if exists(normalized_scores_q) and isinstance(normalized_scores_q, torch.Tensor):
                # multiply the queries by the normalized query scores
                q = q * rearrange(normalized_scores_q, 'b n -> b 1 n 1')

            # handle key / values, with the routing dimension, dividing the number of heads across the routes
            assert divisible_by(h, num_kv_routes), 'number of heads must be divisible by the number of key / value routes'
            heads_per_route = h // num_kv_routes

            # rearrange the key / value projection weights by route
            kv_weight = rearrange(self.to_kv.weight, '(r h d) i -> r h d i', h = heads_per_route, r = num_kv_routes)

            # compute keys / values
            kv = einsum('r h d i, b r n i -> b r h n d', kv_weight, context)
            k, v = kv.chunk(2, dim = -1)

            # if normalized key / value scores are given
            if exists(normalized_scores_kv):
                # multiply the values by the normalized key / value scores
                v = v * normalized_scores_kv

                # optionally multiply the keys by the scores as well
                if self.multiply_keys_by_score:
                    k = k * normalized_scores_kv

            # if rotary embeddings are given
            if exists(rotary_emb):
                q_rotary_emb, k_rotary_emb = rotary_emb
                q = apply_rotary_pos_emb(q_rotary_emb, q)

                # if k_rotary_emb is 4-dimensional, repeat it across the routes
                if k_rotary_emb.ndim == 4:
                    k_rotary_emb = repeat(k_rotary_emb, 'b 1 n d -> b r 1 n d', r = k.shape[1])

                k = apply_rotary_pos_emb(k_rotary_emb, k)

            # merge the routing dimension of the keys / values into the heads
            k, v = map(lambda t: rearrange(t, 'b r h n d -> b (r h) n d'), (k, v))

            # null key / values
            nk, nv = map(lambda t: repeat(t, 'h d -> b h 1 d', b = batch), self.null_kv)

            # concat the null key / values
            k = torch.cat((nk, k), dim = -2)
            v = torch.cat((nv, v), dim = -2)

            # masking
            if exists(mask):
                if mask.ndim == 3:
                    mask = repeat(mask, 'b r j -> b (r h) 1 j', h = heads_per_route)
                else:
                    mask = rearrange(mask, 'b j -> b 1 1 j')

                mask = F.pad(mask, (1, 0), value = True)

            # attention
            out = self.attend(q, k, v, mask = mask)

            # merge heads
            out = rearrange(out, 'b h n d -> b n (h d)')
            return self.to_out(out)
# import the coordinate descent routine
from colt5_attention.coor_descent import coor_descent
# named tuple for what the router returns
RouterReturn = namedtuple('RouterReturn', ['indices', 'scores', 'routed_tokens', 'routed_mask'])

# router class implementing coordinate descent routing
class CoordinateDescentRouter(nn.Module):
    """
    from Wright et al. https://arxiv.org/abs/1502.04759
    then adopted by https://arxiv.org/abs/2211.01267 for multi-vector document retrieval by Qian et al
    finally, used successfully by this paper for routing to heavy branch attention / feedforward
    """

    def __init__(
        self,
        dim,
        straight_through = True,
        n_iters = 20,                   # use 20 iterations, with epsilon-scaling
        fetch_k_ratio = 9 / 8,          # in the paper, they fetch a slightly larger k (multiplied by this ratio) for better learning
        eps = 0.03,                     # the epsilon for coordinate descent. in a recent paper, 0.03 was used for text and 1.0 for speech
        eps_decay = 0.7,
        eps_init = 4.,
        num_routing_tokens = 1,
        learned_routing_tokens = False,
        use_triton = False,
        cosine_sim_routing = False,
        cosine_sim_scale = 8,
        route_block_size = None,
        triton_checkpoint_segments = None # whether to recompute the coordinate descent as multiple segments; with 4 segments and 50 iterations, the backward pass is ~3x faster, at the cost of the forward pass and some memory for saving the initial a and b
    ):
        super().__init__()
        assert fetch_k_ratio >= 1.

        self.n_iters = n_iters
        self.fetch_k_ratio = fetch_k_ratio

        self.coor_descent = coor_descent

        # hyperparameters related to epsilon-scaling

        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_init = eps_init

        if use_triton:
            from colt5_attention.triton_coor_descent import triton_coor_descent
            triton_checkpoint_segments = default(triton_checkpoint_segments, n_iters // 5)
            self.coor_descent = partial(triton_coor_descent, checkpoint_segments = triton_checkpoint_segments)

        self.is_one_routing_token = num_routing_tokens == 1
        self.num_routing_tokens = num_routing_tokens

        self.route_block_size = route_block_size

        self.routing_token = nn.Parameter(torch.randn(num_routing_tokens, dim)) if not learned_routing_tokens else None
        self.straight_through = straight_through

        # whether to use cosine similarity for routing

        self.cosine_sim_routing = cosine_sim_routing
        self.cosine_sim_scale = cosine_sim_scale

    # scatter the routed tokens back into the original tensor
    def route_back(self, src, routed_tokens, indices):
        batch_range = create_batch_range(routed_tokens)
        src[batch_range, indices] = routed_tokens
        return src

    # forward
    def forward(
        self,
        x,
        *,
        num_tokens,
        mask = None,
        random_route = False,
        routing_tokens = None,
        keep_one_route_dim = False  # if there is only one route, whether to keep the route dimension
# main classes

# conditionally routed feedforward
class ConditionalRoutedFeedForward(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_heavy_tokens,
        light_ff_mult = 0.5,
        heavy_ff_mult = 4,
        router_straight_through = True, # ensure all normalized scores are 1., still differentiable
        router_kwargs: dict = {},
        use_triton = False
    ):
        super().__init__()
        self.num_heavy_tokens = num_heavy_tokens

        if use_triton:
            router_kwargs = {**router_kwargs, 'use_triton': True}

        # router
        self.router = CoordinateDescentRouter(
            dim = dim,
            straight_through = router_straight_through,
            **router_kwargs
        )

        # light and heavy feedforward branches
        self.light_ff = FeedForward(dim, light_ff_mult)
        self.heavy_ff = FeedForward(dim, heavy_ff_mult)

    # forward
    def forward(
        self,
        x,
        mask = None,
        num_heavy_tokens = None
        ):
        # get the device and the number of heavy tokens
        device, num_heavy_tokens = x.device, default(num_heavy_tokens, self.num_heavy_tokens)

        # light feedforward sees all tokens (hidden dimension is only 1/2 of the model dimension)
        light_out = self.light_ff(x)

        # route tokens appropriately for the heavy branch
        indices, normalized_scores, routed_tokens, _ = self.router(x, num_tokens=num_heavy_tokens, mask=mask)

        # do the heavier branch with only the routed tokens
        routed_tokens_out = self.heavy_ff(routed_tokens) * rearrange(normalized_scores, '... -> ... 1')

        # scatter back the output of the heavy feedforward branch
        if exists(indices):
            heavy_out = torch.zeros_like(x)
            heavy_out = self.router.route_back(heavy_out, routed_tokens_out, indices)
        else:
            heavy_out = routed_tokens_out

        # sum the light and heavy branches and return
        return light_out + heavy_out
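A minimal usage sketch of the conditionally routed feedforward; the dimensions and token budget below are illustrative assumptions.

import torch

ff = ConditionalRoutedFeedForward(
    dim = 512,
    num_heavy_tokens = 32          # only 32 routed tokens go through the heavy feedforward
)

tokens = torch.randn(2, 256, 512)
out = ff(tokens)                   # (2, 256, 512)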
class ConditionalRoutedAttention(nn.Module):
    # conditionally routed attention
    def __init__(
        self,
        dim,
        *,
        num_heavy_tokens_q,
        num_heavy_tokens_kv,
        num_routed_kv = 1,
        light_dim_head = 64,
        light_heads = 8,
        light_window_size = 128,        # each token would see ~ 64 tokens to either side
        heavy_dim_head = 64,
        heavy_heads = 8,
        router_straight_through = True, # ensure all normalized scores are 1., still differentiable
        router_kwargs: dict = {},
        multiply_keys_by_score = False,
        multiply_queries_by_score = False,
        use_triton = False,
        use_null_q_tokens = True,
        use_flash_attn = False,
        rotary_emb = False
    ):
        super().__init__()

        if use_triton:
            router_kwargs = {**router_kwargs, 'use_triton': True}

        self.num_heavy_tokens_q = num_heavy_tokens_q
        self.num_heavy_tokens_kv = num_heavy_tokens_kv

        self.multiply_queries_by_score = multiply_queries_by_score

        self.light_attn = LocalMHA(
            dim = dim,
            dim_head = light_dim_head,
            heads = light_heads,
            window_size = light_window_size // 2,
            prenorm = True,
            causal = False,
            use_rotary_pos_emb = False,
            look_backward = 1,
            look_forward = 1
        )

        self.null_q_token = None
        if use_null_q_tokens:
            self.null_q_token = nn.Parameter(torch.randn(dim)) # a learned output embedding for query tokens not chosen by the router

        self.q_router = CoordinateDescentRouter(
            dim = dim,
            straight_through = router_straight_through,
            **router_kwargs
        )

        self.kv_router = CoordinateDescentRouter(
            dim = dim,
            num_routing_tokens = num_routed_kv,
            straight_through = router_straight_through,
            **router_kwargs
        )

        self.heavy_attn = Attention(
            dim = dim,
            dim_head = heavy_dim_head,
            heads = heavy_heads,
            multiply_keys_by_score = multiply_keys_by_score,
            use_flash = use_flash_attn
        )

        # rotary embedding

        self.rotary_emb = RotaryEmbedding(heavy_dim_head) if rotary_emb else None

    def forward(
        self,
        x,
        *,
        num_heavy_tokens_q = None,
        num_heavy_tokens_kv = None,
        mask = None
        ):
        # unpack the batch size, sequence length, and device
        batch, seq, device = *x.shape[:2], x.device

        # number of heavy tokens for queries and key / values, defaulting to the values set on the module
        num_heavy_tokens_q = default(num_heavy_tokens_q, self.num_heavy_tokens_q)
        num_heavy_tokens_kv = default(num_heavy_tokens_kv, self.num_heavy_tokens_kv)

        # light local attention sees all tokens in a limited context

        light_out = self.light_attn(x, mask = mask)

        # route tokens appropriately for the heavy branch

        indices_q, normalized_scores_q, routed_tokens_q, _ = self.q_router(x, num_tokens = num_heavy_tokens_q, mask = mask)
        indices_kv, normalized_scores_kv, routed_tokens_kv, routed_tokens_kv_mask = self.kv_router(x, num_tokens = num_heavy_tokens_kv, mask = mask)

        # get rotary embeddings if specified

        rotary_emb = None

        if exists(self.rotary_emb):
            seq_rotary_emb = self.rotary_emb(seq)
            q_rotary_emb = rearrange(seq_rotary_emb[indices_q], 'b n d -> b 1 n d') if exists(indices_q) else seq_rotary_emb
            k_rotary_emb = rearrange(seq_rotary_emb[indices_kv], '... n d -> ... 1 n d') if exists(indices_kv) else seq_rotary_emb
            rotary_emb = (q_rotary_emb, k_rotary_emb)

        # do the heavier branch with only routed tokens

        routed_tokens_out = self.heavy_attn(
            routed_tokens_q,
            mask = routed_tokens_kv_mask,
            context = routed_tokens_kv,
            rotary_emb = rotary_emb,
            normalized_scores_kv = normalized_scores_kv,
            normalized_scores_q = normalized_scores_q if self.multiply_queries_by_score else None
        )

        routed_tokens_out = routed_tokens_out * rearrange(normalized_scores_q, '... -> ... 1')

        # scatter back the output of the heavy branch

        if exists(indices_q):
            if exists(self.null_q_token):
                heavy_out = rearrange(self.null_q_token, 'd -> 1 1 d')
                heavy_out = heavy_out.expand_as(x).clone()
            else:
                heavy_out = torch.zeros_like(x)

            heavy_out = self.q_router.route_back(heavy_out, routed_tokens_out, indices_q)
        else:
            heavy_out = routed_tokens_out

        # sum the outputs of the light and heavy branches

        return light_out + heavy_out
# conditionally routed image feature map attention
class ConditionalRoutedImageAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_heavy_tokens_q,
        num_heavy_tokens_kv,
        num_routed_kv = 1,
        light_dim_head = 64,
        light_heads = 8,
        light_window_size = 128,        # each token would see ~ 64 tokens to either side
        heavy_dim_head = 64,
        heavy_heads = 8,
        router_straight_through = True, # ensure all normalized scores are 1., still differentiable
        router_kwargs: dict = {},
        multiply_keys_by_score = False,
        multiply_queries_by_score = False,
        use_triton = False,
        use_null_q_tokens = True,
        use_flash_attn = False,
        channel_first = False
    ):
        super().__init__()
        self.channel_first = channel_first

        # if using triton, set 'use_triton' in the router kwargs
        if use_triton:
            router_kwargs = {**router_kwargs, 'use_triton': True}

        self.num_heavy_tokens_q = num_heavy_tokens_q
        self.num_heavy_tokens_kv = num_heavy_tokens_kv

        self.multiply_queries_by_score = multiply_queries_by_score

        self.light_window_size = light_window_size

        # light self attention
        self.light_attn = SelfAttention(
            dim = dim,
            dim_head = light_dim_head,
            heads = light_heads,
            prenorm = True
        )

        self.null_q_token = None
        # if using a null query token, create a learned output embedding for query tokens not chosen by the router
        if use_null_q_tokens:
            self.null_q_token = nn.Parameter(torch.randn(dim))

        # query router
        self.q_router = CoordinateDescentRouter(
            dim = dim,
            straight_through = router_straight_through,
            **router_kwargs
        )

        # key / value router
        self.kv_router = CoordinateDescentRouter(
            dim = dim,
            num_routing_tokens = num_routed_kv,
            straight_through = router_straight_through,
            **router_kwargs
        )

        # heavy attention
        self.heavy_attn = Attention(
            dim = dim,
            dim_head = heavy_dim_head,
            heads = heavy_heads,
            multiply_keys_by_score = multiply_keys_by_score,
            use_flash = use_flash_attn
        )

    def forward(
        self,
        x,
        *,
        num_heavy_tokens_q = None,
        num_heavy_tokens_kv = None,
        mask = None
        ):
        # the input must be a 4-dimensional feature map
        assert x.ndim == 4
        # get the batch size, device, channel-first flag, and light window size
        batch, device, channel_first, w = x.shape[0], x.device, self.channel_first, self.light_window_size

        # if channel first, move the channel dimension last
        if channel_first:
            x = rearrange(x, 'b d ... -> b ... d')

        # number of heavy tokens for queries and key / values
        num_heavy_tokens_q = default(num_heavy_tokens_q, self.num_heavy_tokens_q)
        num_heavy_tokens_kv = default(num_heavy_tokens_kv, self.num_heavy_tokens_kv)

        # light local attention sees all tokens in a limited context

        # rearrange the input into windows for the light attention
        light_input = rearrange(x, 'b (h p1) (w p2) d -> b h w (p1 p2) d', p1 = w, p2 = w)
        x, ps = pack_one(light_input, '* n d')

        # compute the light attention output
        light_out = self.light_attn(x)
        light_out = unpack_one(light_out, ps, '* n d')
        light_out = rearrange(light_out, 'b h w (p1 p2) d -> b (h p1) (w p2) d', p1 = w, p2 = w)

        # route tokens appropriately for the heavy branch

        # route the queries
        indices_q, normalized_scores_q, routed_tokens_q, _ = self.q_router(x, num_tokens = num_heavy_tokens_q, mask = mask)
        # route the key / values
        indices_kv, normalized_scores_kv, routed_tokens_kv, routed_tokens_kv_mask = self.kv_router(x, num_tokens = num_heavy_tokens_kv, mask = mask)

        # do the heavier branch with only routed tokens

        routed_tokens_out = self.heavy_attn(
            routed_tokens_q,
            mask = routed_tokens_kv_mask,
            context = routed_tokens_kv,
            normalized_scores_kv = normalized_scores_kv,
            normalized_scores_q = normalized_scores_q if self.multiply_queries_by_score else None
        )

        routed_tokens_out = routed_tokens_out * rearrange(normalized_scores_q, '... -> ... 1')

        # scatter back the output of the heavy branch

        # if a null query token exists, use it to fill the output
        if exists(self.null_q_token):
            heavy_out = rearrange(self.null_q_token, 'd -> 1 1 d')
            heavy_out = heavy_out.expand_as(x).clone()
        else:
            heavy_out = torch.zeros_like(x)

        heavy_out = self.q_router.route_back(heavy_out, routed_tokens_out, indices_q)

        heavy_out = unpack_one(heavy_out, ps, '* n d')
        heavy_out = rearrange(heavy_out, 'b h w (p1 p2) d -> b (h p1) (w p2) d', p1 = w, p2 = w)

        # sum the outputs of the light and heavy branches

        out = light_out + heavy_out

        # if channel first, move the channel dimension back
        if channel_first:
            out = rearrange(out, 'b ... d -> b d ...')

        # return the output
        return out
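A minimal usage sketch of the routed image attention; the feature map size, window size, and token budgets are illustrative assumptions.

import torch

attn = ConditionalRoutedImageAttention(
    dim = 32,
    light_window_size = 32,       # the feature map is split into 32 x 32 windows for light attention
    num_heavy_tokens_q = 256,
    num_heavy_tokens_kv = 256,
    channel_first = True
)

fmap = torch.randn(1, 32, 64, 64)   # height / width must be divisible by the window size
out = attn(fmap)                    # (1, 32, 64, 64)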
# conditionally routed autoregressive attention
class ConditionalRoutedAutoregressiveAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_heavy_tokens_q,
        num_heavy_tokens_kv,
        num_routed_kv = 1,
        light_dim_head = 64,
        light_heads = 8,
        light_window_size = 128,        # each token would see ~ 64 tokens to either side
        heavy_window_size = None,
        heavy_dim_head = 64,
        heavy_heads = 8,
        router_straight_through = True, # ensure all normalized scores are 1., still differentiable
        router_kwargs: dict = {},
        multiply_keys_by_score = False,
        multiply_queries_by_score = False,
        use_triton = False,
        use_null_q_tokens = True,
        use_flash_attn = False,
        rotary_emb = False
    ):
        super().__init__()

        if use_triton:
            router_kwargs = {**router_kwargs, 'use_triton': True}

        self.num_heavy_tokens_q = num_heavy_tokens_q
        self.num_heavy_tokens_kv = num_heavy_tokens_kv

        self.multiply_queries_by_score = multiply_queries_by_score

        self.heavy_window_size = default(heavy_window_size, light_window_size)

        self.light_attn = LocalMHA(
            dim = dim,
            dim_head = light_dim_head,
            heads = light_heads,
            window_size = light_window_size,
            prenorm = True,
            causal = True,
            exact_windowsize = False,
            use_rotary_pos_emb = False
        )

        self.null_q_token = None
        if use_null_q_tokens:
            self.null_q_token = nn.Parameter(torch.randn(dim)) # a learned output embedding for query tokens not chosen by the router

        self.q_router = CoordinateDescentRouter(
            dim = dim,
            straight_through = router_straight_through,
            **router_kwargs
        )

        self.kv_router = CoordinateDescentRouter(
            dim = dim,
            num_routing_tokens = num_routed_kv,
            straight_through = router_straight_through,
            **router_kwargs
        )

        self.heavy_attn = Attention(
            dim = dim,
            dim_head = heavy_dim_head,
            heads = heavy_heads,
            multiply_keys_by_score = multiply_keys_by_score,
            use_flash = use_flash_attn
        )

        # rotary embedding

        self.rotary_emb = RotaryEmbedding(heavy_dim_head) if rotary_emb else None

    def forward(
        self,
        x,
        *,
        num_heavy_tokens_q = None,
        num_heavy_tokens_kv = None,
        random_route = False
# adapting the conditionally routed self attention to cross attention

# conditionally routed cross attention
class ConditionalRoutedCrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_tokens_q,
        num_tokens_kv,
        num_sets_kv = 1,                # if greater than 1, route multiple sets of key / values, each of size num_tokens_kv, using that many routing tokens
        dim_head = 64,
        heads = 8,
        router_straight_through = True, # ensure all normalized scores are 1., still differentiable
        router_kwargs: dict = {},
        kv_routing_tokens = 1,
        multiply_keys_by_score = False,
        use_triton = False,
        use_null_q_tokens = True,
        use_flash_attn = False,
        route_block_size = None
    ):
        super().__init__()

        if use_triton:
            router_kwargs = {**router_kwargs, 'use_triton': True}

        self.num_tokens_q = num_tokens_q
        self.num_tokens_kv = num_tokens_kv

        self.null_q_token = None
        if use_null_q_tokens:
            self.null_q_token = nn.Parameter(torch.randn(dim)) # a learned output embedding for query tokens not chosen by the router

        self.q_router = CoordinateDescentRouter(
            dim = dim,
            straight_through = router_straight_through,
            **router_kwargs
        )

        self.kv_router = CoordinateDescentRouter(
            dim = dim,
            straight_through = router_straight_through,
            num_routing_tokens = kv_routing_tokens,
            route_block_size = route_block_size,
            **router_kwargs
        )

        self.heavy_attn = Attention(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            multiply_keys_by_score = multiply_keys_by_score,
            use_flash = use_flash_attn
        )

    def forward(
        self,
        x,
        context,
        *,
        num_tokens_q = None,
        num_tokens_kv = None,
        mask = None,
        context_mask = None
    ):
        batch, device = x.shape[0], x.device

        # route the queries

        query_length = x.shape[-2]
        num_tokens_q = default(num_tokens_q, self.num_tokens_q)

        # only route the queries if the sequence is longer than the query token budget
        routed_tokens_q = x
        should_route_queries = query_length > num_tokens_q

        if should_route_queries:
            indices_q, normalized_scores_q, routed_tokens_q, _ = self.q_router(x, num_tokens = num_tokens_q, mask = mask)

        # route the long contexts

        key_value_length = context.shape[-2]
        num_tokens_kv = default(num_tokens_kv, self.num_tokens_kv)

        routed_tokens_kv = context
        routed_tokens_kv_mask = context_mask
        normalized_scores_kv = None

        should_route_kv = key_value_length > num_tokens_kv

        if should_route_kv:
            indices_kv, normalized_scores_kv, routed_tokens_kv, routed_tokens_kv_mask = self.kv_router(context, num_tokens = num_tokens_kv, mask = context_mask)

        # do the heavier branch with only routed tokens

        routed_tokens_out = self.heavy_attn(
            routed_tokens_q,
            mask = routed_tokens_kv_mask,
            context = routed_tokens_kv,
            normalized_scores_kv = normalized_scores_kv
        )

        if should_route_queries:
            routed_tokens_out = routed_tokens_out * rearrange(normalized_scores_q, '... -> ... 1')

        # early return if queries did not undergo routing

        if not should_route_queries:
            return routed_tokens_out

        # otherwise, scatter back the query outputs

        if exists(self.null_q_token):
            out = rearrange(self.null_q_token, 'd -> 1 1 d')
            out = out.expand_as(x).clone()
        else:
            out = torch.zeros_like(x)

        if exists(indices_q):
            out = self.q_router.route_back(out, routed_tokens_out, indices_q)

        return out
# conditionally routed transformer block
class ConditionalRoutedTransformerBlock(nn.Module):
    # initialization
    def __init__(
        self,
        dim,
        *,
        num_heavy_attn_tokens_q,
        num_heavy_attn_tokens_kv,
        num_routed_kv = 1,
        num_heavy_ff_tokens,
        light_dim_head = 64,
        light_heads = 8,
        light_window_size = 128,
        heavy_dim_head = 64,
        heavy_heads = 8,
        light_ff_mult = 0.5,
        heavy_ff_mult = 4,
        router_straight_through = True,
        router_kwargs: dict = {},
        multiply_keys_by_score = False,
        multiply_queries_by_score = False,
        use_triton = False,
        use_null_q_tokens = True,
        use_flash_attn = False
    ):
        # call the parent constructor
        super().__init__()
        # conditionally routed feedforward
        self.conditional_ff = ConditionalRoutedFeedForward(
            dim,
            num_heavy_tokens = num_heavy_ff_tokens,
            light_ff_mult = light_ff_mult,
            heavy_ff_mult = heavy_ff_mult,
            router_straight_through = router_straight_through,
            router_kwargs = router_kwargs,
            use_triton = use_triton
        )

        # conditionally routed attention
        self.conditional_attn = ConditionalRoutedAttention(
            dim,
            light_dim_head = light_dim_head,
            light_heads = light_heads,
            light_window_size = light_window_size,
            heavy_dim_head = heavy_dim_head,
            heavy_heads = heavy_heads,
            num_heavy_tokens_q = num_heavy_attn_tokens_q,
            num_heavy_tokens_kv = num_heavy_attn_tokens_kv,
            num_routed_kv = num_routed_kv,
            router_straight_through = router_straight_through,
            router_kwargs = router_kwargs,
            multiply_keys_by_score = multiply_keys_by_score,
            multiply_queries_by_score = multiply_queries_by_score,
            use_triton = use_triton,
            use_null_q_tokens = use_null_q_tokens,
            use_flash_attn = use_flash_attn
        )

    # forward
    def forward(
        self,
        x,
        mask = None,
        num_heavy_attn_tokens_q = None,
        num_heavy_attn_tokens_kv = None,
        num_heavy_ff_tokens = None
    ):
        # conditionally routed attention, with a residual connection
        x = self.conditional_attn(x, mask = mask, num_heavy_tokens_q = num_heavy_attn_tokens_q, num_heavy_tokens_kv = num_heavy_attn_tokens_kv) + x
        # conditionally routed feedforward, with a residual connection
        x = self.conditional_ff(x, mask = mask, num_heavy_tokens = num_heavy_ff_tokens) + x
        # return the output
        return x
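A minimal usage sketch of the full block, mirroring the instantiation style of the classes above; the sequence length and token budgets are illustrative assumptions.

import torch

block = ConditionalRoutedTransformerBlock(
    dim = 512,
    num_heavy_attn_tokens_q = 32,
    num_heavy_attn_tokens_kv = 1024,
    num_heavy_ff_tokens = 32
)

tokens = torch.randn(2, 2048, 512)
mask = torch.ones(2, 2048).bool()

out = block(tokens, mask = mask)   # (2, 2048, 512)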

.\lucidrains\CoLT5-attention\colt5_attention\triton_coor_descent.py

# import log from the math module
from math import log

# import torch and related classes / functions
import torch
from torch import Tensor
from torch import autograd
import torch.nn.functional as F
from torch.cuda.amp import autocast, custom_fwd, custom_bwd

# import the coor_descent function from colt5_attention
from colt5_attention.coor_descent import coor_descent
# import pack, unpack, repeat from einops
from einops import pack, unpack, repeat

# try importing triton and its language module
try:
    import triton
    import triton.language as tl
except ImportError as e:
    # if the import fails, print a hint
    print('triton is not installed, please install by running `pip install triton -U --pre`')
    # and exit
    exit()

# make sure a recent version of triton is being used

# import the version module to compare triton versions
from packaging import version
# assert the triton version is at least '2.0'
assert version.parse(triton.__version__) >= version.parse('2.0')

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# compute the number of warps for a given block size
def calc_num_warps(block_size):
    num_warps = 4
    if block_size >= 2048:
        num_warps = 8
    if block_size >= 4096:
        num_warps = 16
    return num_warps

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single packed tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# split a number into the given number of groups
def num_to_groups(num, groups):
    assert 0 < groups <= num
    floor = num // groups
    remainder = num % groups
    out = []
    for ind in range(groups):
        out.append(floor + int(ind < remainder))
    assert sum(out) == num
    return out
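As a quick worked example, splitting 10 items into 4 groups distributes the remainder over the first groups:

num_to_groups(10, 4)   # [3, 3, 2, 2]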

# forward

# triton kernel for the forward pass
@triton.jit
def coor_descent_kernel_forward(
    a_ptr,
    b_ptr,
    input_ptr,
    mask_ptr,
    k_ptr,
    a_iter_stride,
    b_row_stride,
    b_iter_stride,
    input_row_stride,
    mask_row_stride,
    n_iters,
    current_eps,
    eps_decay,
    eps,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < n_cols

    # 加载 mask 作为整数(因为布尔值会导致 Triton 出错)

    mask_start_ptr = mask_ptr + row_idx * mask_row_stride
    mask_ptrs = mask_start_ptr + col_offsets

    mask_ints = tl.load(mask_ptrs, mask = col_mask, other = 0)
    mask = mask_ints == 1

    # 加载 a 和 b

    a_ptr = a_ptr + row_idx
    a = tl.load(a_ptr)

    b_start_ptr = b_ptr + row_idx * b_row_stride
    b_ptrs = b_start_ptr + col_offsets
    b = tl.load(b_ptrs, mask = col_mask, other = 0)

    # 加载得分 s

    row_start_ptr = input_ptr + row_idx * input_row_stride
    input_ptrs = row_start_ptr + col_offsets
    s = tl.load(input_ptrs, mask = mask, other = -float('inf'))

    # load k - controls the sparsity of the output

    k_ptr = k_ptr + row_idx
    k = tl.load(k_ptr)

    # initialize some constants

    logk = tl.log(k)

    for _ in range(n_iters):        

        a = (s + b) / current_eps
        a = tl.where(mask, a, -float('inf'))

        # stable log-sum-exp

        a_max = tl.max(a, axis = 0)
        a_minus_max = tl.where(mask, a - a_max, -float('inf'))
        exp = tl.exp(a_minus_max)
        sum_exp = tl.sum(exp, axis = 0)
        log_sum_exp = tl.log(sum_exp) + a_max

        a = current_eps * (logk - log_sum_exp)

        # update b

        b = s + a
        b = tl.where(b >= 0., -b, 0.)

        # decay the epsilon (epsilon scaling)

        current_eps *= eps_decay

        if current_eps < eps:
            current_eps = eps

    # store a and b for the next round

    next_a_ptrs = a_ptr + a_iter_stride
    next_b_ptrs = b_ptrs + b_iter_stride

    tl.store(next_a_ptrs, a)
    tl.store(next_b_ptrs, b, mask = col_mask)

# backwards

# triton kernel for the backward pass of coordinate descent
@triton.jit
def coor_descent_kernel_backward(
    dk_ptr,
    input_ptr,
    a_ptr,
    b_ptr,
    mask_ptr,
    ds_ptr,
    db_ptr,
    k_ptr,
    last_da_ptr,
    input_row_stride,
    b_row_stride,
    mask_row_stride,
    ds_row_stride,
    db_row_stride,
    n_iters,
    eps_init,
    eps_decay,
    eps,
    n_cols,
    BLOCK_SIZE: tl.constexpr
):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)

    # load and generate the mask

    col_mask = col_offsets < n_cols

    # load the mask as integers (booleans error out in triton)

    mask_start_ptr = mask_ptr + row_idx * mask_row_stride
    mask_ptrs = mask_start_ptr + col_offsets

    mask_ints = tl.load(mask_ptrs, mask = col_mask, other = 0)
    mask = mask_ints == 1

    # load a and b

    a_ptr = a_ptr + row_idx
    init_a = tl.load(a_ptr)

    b_start_ptr = b_ptr + row_idx * b_row_stride
    b_ptrs = b_start_ptr + col_offsets
    init_b = tl.load(b_ptrs, mask = mask, other = 0)

    # load the input

    row_start_ptr = input_ptr + row_idx * input_row_stride
    input_ptrs = row_start_ptr + col_offsets
    s = tl.load(input_ptrs, mask = mask, other = -float('inf'))

    # load k - controls the sparsity of the output

    k_ptr = k_ptr + row_idx
    k = tl.load(k_ptr)
    logk = tl.log(k)

    # load the last da

    last_da_ptr = last_da_ptr + row_idx
    last_da = tl.load(last_da_ptr)

    # load the initial ds

    ds_row_start_ptr = ds_ptr + row_idx * ds_row_stride
    ds_ptrs = ds_row_start_ptr + col_offsets
    ds = tl.load(ds_ptrs, mask = mask, other = 0.)

    # load the initial db

    db_row_start_ptr = db_ptr + row_idx * db_row_stride
    db_ptrs = db_row_start_ptr + col_offsets
    db = tl.load(db_ptrs, mask = mask, other = 0.)

    # load the initial dk

    dk_ptr = dk_ptr + row_idx
    dk = tl.load(dk_ptr)

    # backpropagate through the iterations

    for ind in range(n_iters):
        a = init_a
        b = init_b

        sa = s * 0
        softmax = s * 0

        # compute the starting epsilon

        current_eps = eps_init / eps_decay

        # recompute the forward pass for this segment

        for _ in range(n_iters - ind):
            # update epsilon

            current_eps *= eps_decay

            if current_eps < eps:
                current_eps = eps

            # update a

            sb = (s + b) / current_eps
            sb = tl.where(mask, sb, -float('inf'))

            # stable log-sum-exp

            sb_max = tl.max(sb, axis = 0)
            sb_minus_max = tl.where(mask, sb - sb_max, -float('inf'))
            exp = tl.exp(sb_minus_max)
            sum_exp = tl.sum(exp, axis = 0)
            softmax = exp / sum_exp
            log_sum_exp = tl.log(sum_exp) + sb_max

            a = current_eps * (logk - log_sum_exp)

            # update b

            sa = s + a
            b = tl.where(sa > 0., -sa, 0.)

        # backpropagate one iteration

        dsa = db * tl.where(sa > 0, -1., 0.)

        ds += dsa

        da = tl.sum(dsa, axis = 0) + last_da

        dk += da * current_eps

        dsb = da * -softmax

        ds += dsb
        db = dsb

        last_da *= 0.

    # store dk

    tl.store(dk_ptr, dk)

    # store ds

    tl.store(ds_ptrs, ds, mask = col_mask)

    # store db

    tl.store(db_ptrs, db, mask = col_mask)

# autograd.Function implementing coordinate descent with the triton kernels above

class _coor_descent(autograd.Function):
    # forward pass
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        n_iters,
        k,
        eps,
        eps_init,
        eps_decay,
        mask,
        checkpoint_segments
    ):
        # the number of iterations must be greater than 0
        assert n_iters > 0
        # the input tensor must be on cuda
        assert x.is_cuda, 'triton coordinate descent must be on cuda'

        # get the batch size, whether gradients are required, the device and dtype of the input
        batch, requires_grad, device, dtype = x.shape[0], x.requires_grad, x.device, x.dtype

        # if no mask is given, default to an all-ones boolean mask shaped like x
        if not exists(mask):
            mask = torch.ones_like(x, dtype=torch.bool, device=x.device)

        # pack x and the mask into (rows, n) tensors
        x, shape = pack_one(x, '* n')
        mask, _ = pack_one(mask, '* n')

        # fill the masked out positions of x with the most negative finite value
        x = x.masked_fill(~mask, -torch.finfo(x.dtype).max)
        mask_ints = mask.int()

        epsilons = []
        eps_init = default(eps_init, eps)
        current_eps = float(max(eps_init, eps))

        n_rows, n_cols = x.shape

        # if k is an int or float, broadcast it to a tensor with one entry per row
        if isinstance(k, (int, float)):
            k = torch.full((n_rows,), k)

        # k must have exactly one entry per row
        assert k.numel() == n_rows

        k = k.to(x)

        BLOCK_SIZE = triton.next_power_of_2(n_cols)

        # the block size must not exceed 131072
        assert BLOCK_SIZE <= 131072, 'the maximum block size allowed is 131072 for triton cuda kernel - set the `route_block_size` for the CoordinateDescentRouter to be this value or less in order to uniformly route to get around this limitation'

        num_warps = calc_num_warps(BLOCK_SIZE)

        checkpointed_a = torch.empty((checkpoint_segments + 1, n_rows), device=device, dtype=dtype)
        checkpointed_b = torch.empty((checkpoint_segments + 1, n_rows, n_cols), device=device, dtype=dtype)

        checkpointed_a[0] = torch.zeros_like(k)
        checkpointed_b[0] = -x

        for ind, segment_iters in enumerate(num_to_groups(n_iters, checkpoint_segments)):
            is_last = ind == (checkpoint_segments - 1)

            epsilons.append(current_eps)

            # launch the cuda kernel for this segment of coordinate descent iterations
            coor_descent_kernel_forward[(n_rows,)](
                checkpointed_a[ind],
                checkpointed_b[ind],
                x,
                mask_ints,
                k,
                checkpointed_a.stride(0),
                n_cols,
                checkpointed_b.stride(0),
                x.stride(0),
                mask_ints.stride(0),
                segment_iters,
                current_eps,
                eps_decay,
                eps,
                n_cols,
                num_warps=num_warps,
                BLOCK_SIZE=BLOCK_SIZE,
            )

            current_eps *= (eps_decay ** segment_iters)
            current_eps = max(current_eps, eps)

        last_a, last_b = map(lambda t: t[-1], (checkpointed_a, checkpointed_b))
        y = torch.exp((last_a[..., None] + last_b + x) / current_eps)

        epsilons.append(current_eps)

        if requires_grad:
            checkpointed_a = checkpointed_a[:-1]
            checkpointed_b = checkpointed_b[:-1]

            ctx.args = (n_iters, checkpoint_segments, epsilons, eps_decay, eps)
            ctx.save_for_backward(x, y, k, mask, checkpointed_a, checkpointed_b)

        y = unpack_one(y, shape, '* n')

        return y

    # backward pass
    @staticmethod
    @custom_bwd
    def backward(
        ctx,
        grad_probs
    ):
        # the incoming gradients must be on cuda
        assert grad_probs.is_cuda

        # get the batch size
        batch = grad_probs.shape[0]

        # retrieve the arguments and tensors saved during the forward pass
        n_iters, checkpoint_segments, epsilons, eps_decay, eps = ctx.args
        x, y, k, mask, checkpointed_a, checkpointed_b = ctx.saved_tensors

        # pack the incoming gradients into (rows, n)
        grad_probs, shape = pack_one(grad_probs, '* n')

        # zero out the gradients at masked positions
        if exists(mask):
            grad_probs = grad_probs.masked_fill(~mask, 0.)

        # number of rows and columns
        n_rows, n_cols = grad_probs.shape

        # compute the block size and number of warps
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)

        # unpack the epsilon values
        *epsilons, last_eps = epsilons

        # derive ds, db, dk, last_da
        ds = grad_probs * y / last_eps
        db = ds.clone()
        dk = torch.zeros_like(k)
        last_da = ds.sum(dim=-1)

        # cast the mask to ints
        mask_int = mask.int()

        # zip the checkpointed states, segment lengths and epsilons, in reverse order
        items = zip(
            reversed(checkpointed_a.unbind(dim=0)),
            reversed(checkpointed_b.unbind(dim=0)),
            reversed(num_to_groups(n_iters, checkpoint_segments)),
            reversed(epsilons)
        )

        # walk backwards through the checkpointed segments
        for ind, (init_a, init_b, segment_iters, eps_init) in enumerate(items):
            is_first = ind == 0

            # launch the backward kernel for this segment
            coor_descent_kernel_backward[(n_rows,)](
                dk,
                x,
                init_a,
                init_b,
                mask_int,
                ds,
                db,
                k,
                last_da if is_first else torch.zeros_like(last_da),
                x.stride(0),
                init_b.stride(0),
                mask_int.stride(0),
                ds.stride(0),
                db.stride(0),
                segment_iters,
                eps_init,
                eps_decay,
                eps,
                n_cols,
                num_warps=num_warps,
                BLOCK_SIZE=BLOCK_SIZE
            )

        # finish accumulating ds
        ds += -db
        ds = unpack_one(ds, shape, '* n')

        # if k does not require gradients, set dk to None
        if not k.requires_grad:
            dk = None
        else:
            dk /= k

        # return the gradients (only the input and k receive gradients)
        return ds, None, dk, None, None, None, None, None

# decorator that disables autocast
@autocast(enabled = False)
# coordinate descent, accelerated with triton when on cuda
def triton_coor_descent(
    s,                        # input scores
    *,
    n_iters,                  # number of iterations
    k,                        # how many tokens to route
    eps = 1e-1,               # final epsilon, defaults to 0.1
    eps_init = None,          # initial epsilon
    eps_decay = 1.,           # epsilon decay rate
    mask = None,              # boolean mask
    checkpoint_segments = 1   # number of checkpointed segments
):
    # if the input is not on cuda, fall back to the plain pytorch coordinate descent
    if not s.is_cuda:
        return coor_descent(s, n_iters = n_iters, k = k, eps = eps, eps_init = eps_init, eps_decay = eps_decay, mask = mask)

    # otherwise use the custom triton-backed coordinate descent
    return _coor_descent.apply(s, n_iters, k, eps, eps_init, eps_decay, mask, checkpoint_segments)
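
# a minimal usage sketch (added for illustration, not part of the original file)

if __name__ == '__main__' and torch.cuda.is_available():
    scores = torch.randn(2, 8).cuda()    # routing scores for 2 rows of 8 candidates

    routed = triton_coor_descent(
        scores,
        n_iters = 20,                    # number of coordinate descent iterations
        k = 2,                           # route roughly 2 tokens per row
        checkpoint_segments = 4          # checkpoint every 5 iterations to save memory in the backward
    )                                    # (2, 8) - soft routing weights, differentiable w.r.t. the scores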

.\lucidrains\CoLT5-attention\colt5_attention\vit.py

import torch
from torch import nn

from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange, Reduce

from colt5_attention.transformer_block import (
    ConditionalRoutedImageAttention,
    ConditionalRoutedFeedForward
)

# helpers

# return t if it is already a tuple, otherwise duplicate it into a pair
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# 2d sinusoidal positional embedding
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    # get the shape, device and dtype of the patches
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    # generate the grid coordinates
    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    # the feature dimension must be a multiple of 4
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    # compute the omega values
    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)

    # compute the positional embedding
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    pe = pe.type(dtype)
    return rearrange(pe, '(h w) d -> h w d', h = h, w = w)

# classes

# Transformer
class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        attn_num_heavy_tokens_q,
        attn_num_heavy_tokens_kv,
        attn_light_dim_head,
        attn_light_heads,
        attn_light_window_size,
        attn_heavy_dim_head,
        attn_heavy_heads,
        ff_num_heavy_tokens,
        ff_light_mult,
        ff_heavy_mult,
        router_straight_through = True,
        router_kwargs: dict = {},
        router_use_triton = False,
        flash_attn = True,
        attn_num_routed_kv = 1
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):

            # conditionally routed feedforward
            ff = ConditionalRoutedFeedForward(
                dim,
                num_heavy_tokens = ff_num_heavy_tokens,
                light_ff_mult = ff_light_mult,
                heavy_ff_mult = ff_heavy_mult,
                router_straight_through = router_straight_through,
                router_kwargs = router_kwargs,
                use_triton = router_use_triton
            )

            # conditionally routed image attention
            attn = ConditionalRoutedImageAttention(
                dim,
                num_heavy_tokens_q = attn_num_heavy_tokens_q,
                num_heavy_tokens_kv = attn_num_heavy_tokens_kv,
                num_routed_kv = attn_num_routed_kv,
                light_dim_head = attn_light_dim_head,
                light_heads = attn_light_heads,
                light_window_size = attn_light_window_size,
                heavy_dim_head = attn_heavy_dim_head,
                heavy_heads = attn_heavy_heads,
                router_straight_through = router_straight_through,
                router_kwargs = router_kwargs,
                use_triton = router_use_triton,
                use_flash_attn = flash_attn,
                channel_first = False,
                use_null_q_tokens = True
            )

            self.layers.append(nn.ModuleList([attn, ff]))

    # forward pass
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x

            x, ps = pack([x], 'b * d')
            x = ff(x) + x            
            x, = unpack(x, ps, 'b * d')

        return x

# ConditionalRoutedViT - patches the image and processes it with the routed transformer
class ConditionalRoutedViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        attn_num_heavy_tokens_q,
        attn_num_heavy_tokens_kv,
        attn_heavy_dim_head,
        attn_heavy_heads,
        attn_light_dim_head,
        attn_light_heads,
        attn_light_window_size,
        ff_num_heavy_tokens,
        ff_heavy_mult,
        ff_light_mult,
        channels = 3,
        router_straight_through = True,
        router_kwargs: dict = {},
        router_use_triton = False,
        flash_attn = True,
        attn_num_routed_kv = 1,
        default_coor_descent_eps = 1.
    ):
        super().__init__()
        # get the image height and width
        image_height, image_width = pair(image_size)
        # get the patch height and width
        patch_height, patch_width = pair(patch_size)

        # the image dimensions must be divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # dimension of each flattened patch
        patch_dim = channels * patch_height * patch_width

        # module that turns the image into patch embeddings
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # router keyword arguments, including the default epsilon for coordinate descent
        router_kwargs = {'eps': default_coor_descent_eps, **router_kwargs}

        # the routed transformer
        self.transformer = Transformer(
            dim,
            depth,
            attn_num_heavy_tokens_q,
            attn_num_heavy_tokens_kv,
            attn_light_dim_head,
            attn_light_heads,
            attn_light_window_size,
            attn_heavy_dim_head,
            attn_heavy_heads,
            ff_num_heavy_tokens,
            ff_light_mult,
            ff_heavy_mult,
            router_straight_through,
            router_kwargs,
            router_use_triton,
            flash_attn,
            attn_num_routed_kv
        )

        # linear head for classification
        self.linear_head = nn.Sequential(
            Reduce('b h w c -> b c', 'mean'),
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    # forward pass
    def forward(self, img):
        # get the image height, width and dtype
        *_, h, w, dtype = *img.shape, img.dtype

        # convert the image into patch embeddings
        x = self.to_patch_embedding(img)
        # add the 2d sinusoidal positional embedding
        x = x + posemb_sincos_2d(x)

        # process the embeddings with the routed transformer
        x = self.transformer(x)

        # classify with the linear head
        return self.linear_head(x)

.\lucidrains\CoLT5-attention\colt5_attention\__init__.py

# import the conditionally routed modules and the coordinate descent router
# from colt5_attention.transformer_block

from colt5_attention.transformer_block import (
    ConditionalRoutedFeedForward,
    ConditionalRoutedAttention,
    ConditionalRoutedImageAttention,
    ConditionalRoutedAutoregressiveAttention,
    ConditionalRoutedCrossAttention,
    ConditionalRoutedTransformerBlock,
    CoordinateDescentRouter
)

# import the coor_descent function from colt5_attention.coor_descent
from colt5_attention.coor_descent import coor_descent

# import the topk function from colt5_attention.topk
from colt5_attention.topk import topk

# import the ConditionalRoutedViT class from colt5_attention.vit
from colt5_attention.vit import ConditionalRoutedViT

CoLT5 Attention - Pytorch

Implementation of the conditionally routed efficient attention in the proposed CoLT5 architecture, in Pytorch.

They used coordinate descent from this paper (main algorithm originally from Wright et al) to route a subset of tokens for 'heavier' branches of the feedforward and attention blocks.

Update: unsure of how the routing normalized scores for the key-values are used. Did some improvising there, scaling the projected values, but if you think you know the answer, please open an issue

Update 2: seems to work well with the improvisation above

Appreciation

  • Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research

  • einops for making my life easy

  • Triton for allowing me to speed up coordinate descent with a fused implementation in just 2 days, sparing me from having to write a thousand lines of CUDA code

Install

$ pip install colt5-attention

Usage

import torch

from colt5_attention import (
    ConditionalRoutedFeedForward,
    ConditionalRoutedAttention,
    ConditionalRoutedTransformerBlock
)

# mock input, say it is 32768 length

tokens = torch.randn(2, 32768, 512)
mask = torch.ones(2, 32768).bool()  # can handle variable lengthed sequences

# feedforward

ff = ConditionalRoutedFeedForward(
    dim = 512,
    light_ff_mult = 0.5,      # hidden dimension ratio of light branch
    heavy_ff_mult = 4,        # hidden dimension ratio of heavy branch
    num_heavy_tokens = 1024   # heavy branch receives only 1024 routed tokens of 32768
)

ff_out = ff(tokens, mask = mask)  # (2, 32768, 512) - light and heavy branch summed

# attention

attn = ConditionalRoutedAttention(
    dim = 512,
    light_dim_head = 64,       # attention head dimension of light branch
    light_heads = 8,           # number of attention heads for light branch
    light_window_size = 128,   # local attention receptive field for light
    heavy_dim_head = 64,       # attention head dimension of heavy branch
    heavy_heads = 8,           # number of attention heads for heavy branch
    num_heavy_tokens_q = 1024, # heavy branch receives only 1024 routed tokens of 32768
    num_heavy_tokens_kv = 1024 # heavy branch receives only 1024 routed tokens of 32768
)

attn_out = attn(tokens, mask = mask) # (2, 32768, 512) - light and heavy branch summed

# both attention and feedforward with residual
# the complete transformer block
# a stack of these would constitute the encoder of CoLT5

block = ConditionalRoutedTransformerBlock(
    dim = 512,
    light_dim_head = 64,
    light_heads = 8,
    light_window_size = 128,
    heavy_dim_head = 64,
    heavy_heads = 8,
    light_ff_mult = 0.5,
    heavy_ff_mult = 4,
    num_heavy_ff_tokens = 1024,
    num_heavy_attn_tokens_q = 1024,
    num_heavy_attn_tokens_kv = 1024
)

block_out = block(tokens, mask = mask) # (2, 32768, 512)

Also included a variation of the conditionally routed attention for cross attention, to be tried with long context memories in a transformer-xl

import torch
from colt5_attention import ConditionalRoutedCrossAttention

# mock input, let us say it is a transformer of 1024 length attending to 1 million context past memories

tokens = torch.randn(1, 1024, 512).cuda()
tokens_mask = torch.ones(1, 1024).bool().cuda()

memories = torch.randn(1, 1_048_576, 512).cuda()
memories_mask = torch.ones(1, 1_048_576).bool().cuda()

# conditionally routed cross attention

cross_attn = ConditionalRoutedCrossAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    num_tokens_q = 512,         # only 512 routed from 1024
    num_tokens_kv = 1024,       # only 1024 routed from 1 million
    kv_routing_tokens = 2,      # say you want 2 routing tokens to route different sets of key / values to the queries. 4 attention heads will be allocated to each routed set in this example (8 / 2)
    use_triton = True,          # use cuda kernel
    route_block_size = 131072   # route in blocks of 131072
).cuda()

cross_attn_out = cross_attn(
    tokens,
    context = memories,
    mask = tokens_mask,
    context_mask = memories_mask
)

cross_attn_out.shape # (1, 1024, 512) - same as tokens

This repository also has an improvised version for autoregressive attention. The way this was achieved was by viewing the sequence in windows. Each window can only attend to windows of key / values into the past. The local attention of the light branch covers the intra-window attention.

The coordinate descent is made viable through a CUDA kernel written in Triton. Finally, to get autoregressive generation to work well, I had to make sure for the unrouted tokens (for queries), outputs a learned output embedding rather than just zeros.

Currently I am seeing occasional differences between the gradients (as high as 1e-1 for a very small fraction of elements) once the number of iterations exceeds 20. However, enwik8 seems to train well and I can see the effects of the routing. Training is surprisingly stable too.

ex.

import torch
from colt5_attention import ConditionalRoutedAutoregressiveAttention

# mock input, say it is 8192 length

tokens = torch.randn(2, 8192, 512).cuda()

# attention

attn = ConditionalRoutedAutoregressiveAttention(
    dim = 512,
    light_dim_head = 64,          # attention head dimension of light branch
    light_heads = 8,              # number of attention heads for light branch
    light_window_size = 128,      # local attention receptive field for light
    heavy_window_size = 128,      # the windowing for the routed heavy attention, by default, will be equal to the light window size. be aware if this is any greater than the light window size, there may be tokens that would be missed by attention
    heavy_dim_head = 64,          # attention head dimension of heavy branch
    heavy_heads = 8,              # number of attention heads for heavy branch
    num_heavy_tokens_q = 32,      # heavy branch receives only 32 out of 128 of the windowed queries (1024 query tokens total)
    num_heavy_tokens_kv = 1024,   # heavy branch receives only 1024 routed tokens for key-values
    num_routed_kv = 2,            # one can split the attention heads so that groups of heads attend to different sets of key - values (2 routing tokens in this case)
    use_triton = True,            # will need to use Triton for this to be viable, otherwise it is too slow and memory inefficient for this number of iterations
    use_flash_attn = True         # use flash attention in heavy branch
).cuda()

attn_out = attn(tokens) + tokens # (2, 8192, 512) - output of attention with residual (prenorm is included)

Finally, this repository contains a version for image feature maps. Typically a lot of research papers cannot do attention on image feature maps with dimensions greater than 32 by 32. This routed attention will use a local window patch for the light branch, and routed attention for the heavy branch.

ex.

import torch
from colt5_attention import ConditionalRoutedImageAttention

attn = ConditionalRoutedImageAttention(
    dim = 32,
    light_dim_head = 64,       # attention head dimension of light branch
    light_heads = 8,           # number of attention heads for light branch
    light_window_size = 32,    # height and width of local window attention on the image feature map
    channel_first = True,      # whether to accept images with channel first than last
    heavy_dim_head = 64,       # attention head dimension of heavy branch
    heavy_heads = 8,           # number of attention heads for heavy branch
    num_heavy_tokens_q = 1024, # heavy branch receives only 1024 routed tokens of 65536
    num_heavy_tokens_kv = 1024 # heavy branch receives only 1024 routed tokens of 65536
).cuda()

fmap = torch.randn(1, 32, 256, 256).cuda() # image feature map is too large for attention, given 256 ^ 2  == 65536 tokens

out = attn(fmap)

Simple ViT using coordinate descent routed attention and feedforward

import torch
from colt5_attention.vit import ConditionalRoutedViT

vit = ConditionalRoutedViT(
    image_size = 256,                # image size
    patch_size = 32,                 # patch size
    num_classes = 1000,              # number of output classes
    dim = 1024,                      # feature dimension
    depth = 6,                       # depth
    attn_num_heavy_tokens_q = 16,    # number of routed queries for heavy attention
    attn_num_heavy_tokens_kv = 16,   # number of routed key/values for heavy attention
    attn_heavy_dim_head = 64,        # dimension per attention head for heavy
    attn_heavy_heads = 8,            # number of attention heads for heavy
    attn_light_window_size = 4,      # the local windowed attention for light branch
    attn_light_dim_head = 32,        # dimension per head for local light attention
    attn_light_heads = 4,            # number of attention heads for local windowed attention
    ff_num_heavy_tokens = 16,        # number of tokens routed for heavy feedforward
    ff_heavy_mult = 4,               # the expansion factor of the heavy feedforward branch
    ff_light_mult = 2                # expansion factor of the light feedforward branch
)

images = torch.randn(1, 3, 256, 256)

logits = vit(images) # (1, 1000)

Differentiable Topk

Use a small wrapper around coordinate descent for differentiable topk

import torch
from colt5_attention import topk

x = torch.randn(1024, 512)

values, indices, coor_descent_values, gates = topk(x, k = 10, fused = True)

# you can either use the topk indices + gates, or use the values directly (values have already been multiplied with the gates within the function)

Todo

Citations

@inproceedings{Ainslie2023CoLT5FL,
    title   = {CoLT5: Faster Long-Range Transformers with Conditional Computation},
    author  = {Joshua Ainslie and Tao Lei and Michiel de Jong and Santiago Ontan'on and Siddhartha Brahma and Yury Zemlyanskiy and David Uthus and Mandy Guo and James Lee-Thorp and Yi Tay and Yun-Hsuan Sung and Sumit Sanghai},
    year    = {2023}
}
@article{Tillet2019TritonAI,
    title   = {Triton: an intermediate language and compiler for tiled neural network computations},
    author  = {Philippe Tillet and H. Kung and D. Cox},
    journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
    year    = {2019}
}
@inproceedings{dao2022flashattention,
    title     = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author    = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year      = {2022}
}
@article{Lei2023ConditionalAP,
    title   = {Conditional Adapters: Parameter-efficient Transfer Learning with Fast Inference},
    author  = {Tao Lei and Junwen Bai and Siddhartha Brahma and Joshua Ainslie and Kenton Lee and Yanqi Zhou and Nan Du and Vincent Zhao and Yuexin Wu and Bo Li and Yu Zhang and Ming-Wei Chang},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2304.04947}
}
@article{Beyer2022BetterPV,
    title   = {Better plain ViT baselines for ImageNet-1k},
    author  = {Lucas Beyer and Xiaohua Zhai and Alexander Kolesnikov},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.01580}
}

.\lucidrains\CoLT5-attention\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'CoLT5-attention',
  # find and include all packages
  packages = find_packages(),
  # version
  version = '0.10.20',
  # license
  license='MIT',
  # description
  description = 'Conditionally Routed Attention',
  # long description content type
  long_description_content_type = 'text/markdown',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project url
  url = 'https://github.com/lucidrains/CoLT5-attention',
  # keywords
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'dynamic routing'
  ],
  # dependencies
  install_requires=[
    'einops>=0.6.1',
    'local-attention>=1.8.6',
    'packaging',
    'torch>=1.10'
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\complex-valued-transformer\complex_valued_transformer\attend.py

from functools import partial  # partial from functools

import torch  # torch
from torch import nn, einsum, Tensor  # nn, einsum, Tensor from torch
import torch.nn.functional as F  # functional as F

from collections import namedtuple  # namedtuple from collections
from functools import wraps  # wraps from functools
from packaging import version  # version from packaging

from einops import rearrange, repeat  # rearrange, repeat from einops

# named tuple EfficientAttentionConfig with three flags for the sdp kernels
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# decorator that runs a function only once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print only once
print_once = once(print)

# tensor functions

# create a causal mask
def create_causal_mask(i, j, device):
    return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)
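
# for example (added for illustration): create_causal_mask(3, 3, device) gives
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])
# where True marks the future positions a query may not attend to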

# main class

class Attend(nn.Module):
    def __init__(
        self,
        *,
        dropout=0.,
        causal=False,
        heads=None,
        scale=None,
        flash=False,
    ):
        super().__init__()
        self.scale = scale

        self.causal = causal
        self.create_causal_mask = create_causal_mask

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        # flash attention

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        major, minor = device_properties.major, device_properties.minor

        if (major, minor) == (8, 0):
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        elif (major, minor) == (9, 0):
            print_once('H100 GPU detected, using flash attention')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    def flash_attn(
        self,
        q, k, v,
        mask=None
    ):
        # unpack the shape of q, plus k_len, whether on cuda, and the device
        batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # check if a mask exists and expand it to a compatible shape
        # the mask is B L, so it needs to be expanded to B H N L

        causal = self.causal

        # when there is only one query token (q_len == 1), as with a kv cache, simply turn off the causal mask
        # in speculative decoding this may grow to 5-6 tokens, which would need a right-aligned causal mask instead

        if q_len == 1 and causal:
            causal = False

        # expand the key padding mask

        if exists(mask):
            assert mask.ndim == 4
            mask = mask.expand(batch, heads, q_len, k_len)

        # handle the kv cache - this should be avoidable with a newer flash attention 2

        if k_len > q_len and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device=device)
            if not exists(mask):
                mask = ~causal_mask
            else:
                mask = mask & ~causal_mask
            causal = False

        # manually fold the causal mask into the given mask

        row_is_entirely_masked = None

        if exists(mask) and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device=device)
            mask = mask & ~causal_mask

            # protect against rows that are entirely masked out

            row_is_entirely_masked = ~mask.any(dim=-1)
            mask[..., 0] = mask[..., 0] | row_is_entirely_masked

            causal = False

        # pick the attention config depending on whether the tensors are on cuda

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
        
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0., 
                is_causal=causal
            )

        # zero out the outputs of tokens whose rows were entirely masked out

        if exists(row_is_entirely_masked):
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out

    def forward(
        self,
        q, k, v,
        mask=None
    ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        if self.flash:
            return self.flash_attn(q, k, v, mask=mask)

        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale

        i, j, dtype = *sim.shape[-2:], sim.dtype

        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            sim = sim.masked_fill(~mask, mask_value)

        if self.causal and n > 1:
            causal_mask = self.create_causal_mask(i, j, device=device)
            sim = sim.masked_fill(causal_mask, mask_value)

        attn = sim.softmax(dim=-1)
        attn = attn.type(dtype)

        attn = self.attn_dropout(attn)

        out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

        return out
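
# a minimal usage sketch (added for illustration, not part of the original file)

if __name__ == '__main__':
    attend = Attend(causal = True, flash = False)

    q = torch.randn(1, 8, 1024, 64)    # (batch, heads, seq, dim_head)
    k = torch.randn(1, 8, 1024, 64)
    v = torch.randn(1, 8, 1024, 64)

    out = attend(q, k, v)              # (1, 8, 1024, 64)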

.\lucidrains\complex-valued-transformer\complex_valued_transformer\autoregressive_wrapper.py

# import torch
import torch
# import the nn module from torch
from torch import nn
# import torch.nn.functional as F
import torch.nn.functional as F

# import the rearrange function from einops
from einops import rearrange

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# identity function
def identity(t):
    return t

# decorator that runs the wrapped call in eval mode
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        # remember whether the model was training
        was_training = model.training
        # switch the model to eval mode
        model.eval()
        # call the wrapped function
        out = fn(model, *args, **kwargs)
        # restore the previous training mode
        model.train(was_training)
        return out
    return inner

# top k filtering

# keep only the top k logits, determined by the threshold
def top_k(logits, thres = 0.9):
    # number of logits to keep
    k = int((1 - thres) * logits.shape[-1])
    # get the top k values and their indices
    val, ind = torch.topk(logits, k)
    # tensor of the same shape as the logits, filled with the most negative finite value
    probs = torch.full_like(logits, -torch.finfo(logits.dtype).max)
    # scatter the top k values back in at their indices
    probs.scatter_(1, ind, val)
    return probs
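
# for example (added for illustration): with thres = 0.9 and 10 logits,
# k = int((1 - 0.9) * 10) = 1, so only the single highest logit survives
# and every other position becomes a large negative value before the softmax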

# autoregressive wrapper
class AutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net,        
        seq_len,
        pad_value = 0,
        logits_fn = identity
    ):
        super().__init__()
        self.seq_len = seq_len
        self.pad_value = pad_value
        self.net = net
        self.logits_fn = logits_fn

    # generate a sequence from a prompt
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prompt,
        seq_len,
        temperature = 1.0,
        filter_thres = 0.9,
        **kwargs
    ):
        # get the shape and device of the prompt
        b, t, device = *prompt.shape, prompt.device

        out = prompt

        # sample the sequence one token at a time
        for _ in range(seq_len):
            # feed at most the last seq_len tokens and take the logits at the final position
            logits = self.net(out[:, -self.seq_len:], **kwargs)[:, -1]
            logits = self.logits_fn(logits)

            # filter the logits with top k
            filtered_logits = top_k(logits, thres = filter_thres)
            # softmax with temperature
            probs = F.softmax(filtered_logits / temperature, dim = -1)

            # sample one token from the distribution
            sample = torch.multinomial(probs, 1)
            # append the sampled token to the output sequence
            out = torch.cat((out, sample), dim = -1)

        return out[:, t:]

    # forward pass
    def forward(self, x, **kwargs):
        # split the input into features and next-token labels
        x, labels = x[:, :-1], x[:, 1:]
        # get the logits from the network
        logits = self.net(x, **kwargs)
        # move the class dimension into position for cross entropy
        logits = rearrange(self.logits_fn(logits), "b c n -> b n c")
        # compute the cross entropy loss
        return F.cross_entropy(logits, labels)

.\lucidrains\complex-valued-transformer\complex_valued_transformer\complex_valued_transformer.py

from typing import Optional
from functools import partial

import torch
from torch import cfloat
import torch.nn.functional as F
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange

from complex_valued_transformer.attend import Attend

# helpers

# check whether a value exists
def exists(v):
    return v is not None

# return the value if it exists, otherwise the default
def default(v, d):
    return v if exists(v) else d

# helper tensor functions

# modulate the input tensor with a rotation
def modulate_with_rotation(x, m):
    if m.dtype == cfloat:
        m = m.abs()

    rot = m.cos() + 1.j * m.sin()
    return x * rot

# complex attention
# https://arxiv.org/abs/2306.09827

# complex attention using only the real component (Eilers et al.)
def complex_attention_real(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attend: Attend,
    mask: Optional[Tensor] = None
):
    """
    section 4.1 equation 8
    """

    assert all([t.dtype == cfloat for t in (q, k, v)])
    q, k, v = map(torch.view_as_real, (q, k, v))
    q, k, v = map(lambda t: rearrange(t, '... d c -> ... (d c)'), (q, k, v))

    o = attend(q, k, v, mask = mask)

    o = rearrange(o, '... (d c) -> ... d c', c = 2)
    return torch.view_as_complex(o)

# complex attention - Yang et al
# https://arxiv.org/abs/1910.10202

# complete complex attention formulation (Yang et al.)
def complex_attention_complete(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attend: Attend,
    mask: Optional[Tensor] = None
):
    """
    section 3.2 equation 3
    """
    batch, device = q.shape[0], q.device

    assert all([t.dtype == cfloat for t in (q, k, v)])
    q, k, v = map(torch.view_as_real, (q, k, v))

    # complex attention =    (MH(A, A, A) − MH(A, B, B) − MH(B, A, B) − MH(B, B, A))
    #                     + i(MH(A, A, B) + MH(A, B, A) + MH(B, A, A) − MH(B, B, B))

    q = repeat(q, 'b h n d c -> (c r b) h n d', r = 2)
    k = repeat(k, 'b h n d c -> (r c b) h n d', r = 2)
    v = repeat(v, 'b h n d c -> (r b) h n (d c)', r = 4)

    if exists(mask):
        mask = repeat(mask, 'b ... -> (r b) ...', r = 4)

    o = attend(q, k, v, mask = mask)

    o = rearrange(o, '(r b) ... (d c) -> (r c) b ... d', r = 4, c = 2)

    indices = torch.tensor([0, 3, 5, 6, 1, 2, 4, 7], dtype = torch.long, device = device)

    o = rearrange(o[indices], '(r c) ... -> ... c r', c = 2)

    sign = torch.tensor([
        [1., -1., -1., -1.],   # real component
        [1.,  1.,  1., -1.]    # imag component
    ], dtype = o.dtype, device = device)

    o = (o * sign).sum(dim = -1)

    return torch.view_as_complex(o)

# complex multihead attention

class ComplexMultiheadAttention(Module):
    def __init__(
        self,
        dim,
        *,
        causal = False,
        dim_head = 32,
        heads = 8,
        complete_complex = False, # whether to use complete complex formulation (Yang et al.) or just the real component, which reduces down to usual dot product on real and imaginary components flattened into the feature dimension
        flash = False
    ):
        super().__init__()
        dim_inner = heads * dim_head

        self.to_q = nn.Linear(dim, dim_inner, bias = False, dtype = cfloat)
        self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False, dtype = cfloat)
        self.to_out = nn.Linear(dim_inner, dim, bias = False, dtype = cfloat)

        maybe_flash_attn = Attend(
            causal = causal,
            heads = heads,
            flash = flash
        )

        complex_attention = complex_attention_complete if complete_complex else complex_attention_real
        self.attend = partial(complex_attention, attend = maybe_flash_attn)

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

    def forward(
        self,
        x,
        context = None,
        mask = None,
        rotary_emb = None
        ):
        # check whether a context was passed in
        has_context = exists(context)
        # default the context to x itself (self attention)
        context = default(context, x)

        # project the input to queries, keys and values
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        # split out the attention heads
        q, k, v = map(self.split_heads, (q, k, v))

        # if rotary embeddings are given, rotate the queries and keys
        if exists(rotary_emb):
            q = q * rotary_emb
            k = k * rotary_emb

        # compute attention
        o = self.attend(q, k, v, mask = mask)

        # merge the heads back together
        o = self.merge_heads(o)
        # project back to the model dimension
        return self.to_out(o)

# complex RMS norm
class ComplexRMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        # scale by the inverse square root of the dimension
        self.scale = dim ** -0.5
        # learned complex-valued gain of shape (dim,)
        self.gamma = nn.Parameter(torch.ones(dim, dtype=cfloat))

    def forward(self, x):
        # normalize along the feature dimension, then multiply by gamma and the scale
        return F.normalize(x, dim=-1) * self.gamma * self.scale

# modReLU activation
class ModReLU(Module):
    def __init__(self, relu_squared=False):
        super().__init__()
        # optionally square the relu output
        self.pow = 2 if relu_squared else 1
        # learned bias added to the magnitude
        self.bias = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        # real part: relu of the magnitude plus the bias, raised to pow
        real = F.relu(torch.abs(x) + self.bias) ** self.pow
        # complex rotation recovered from the angle of x
        imag = torch.exp(1.j * torch.angle(x))
        # combine the two
        return real + imag

# complex feedforward with a modReLU activation
def ComplexFeedForward(dim, mult=4, relu_squared=False):
    # inner dimension
    dim_inner = dim * mult
    # linear -> modReLU -> linear, all in the complex domain
    return nn.Sequential(
        nn.Linear(dim, dim_inner, dtype=cfloat),
        ModReLU(relu_squared=relu_squared),
        nn.Linear(dim_inner, dim, dtype=cfloat)
    )

# rotary embedding in the complex domain
class RotaryEmbedding(Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        # inverse frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim).float() / dim))
        # register the inverse frequencies as a buffer
        self.register_buffer('inv_freq', inv_freq)

    # device of the inverse frequencies
    @property
    def device(self):
        return self.inv_freq.device

    def forward(self, seq_len):
        # positions t, outer product with the inverse frequencies, returned as complex rotations
        t = torch.arange(seq_len, device=self.device).type_as(self.inv_freq)
        freqs = einsum('i, j -> i j', t, self.inv_freq)
        return torch.cos(freqs) + 1.j * torch.sin(freqs)
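
# for example (added for illustration): RotaryEmbedding(dim = 32)(1024) returns a
# (1024, 32) complex tensor of unit-magnitude rotations, broadcast-multiplied into
# the queries and keys of shape (batch, heads, 1024, 32) in the attention above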

# complex-valued transformer
class ComplexTransformer(Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        num_tokens: Optional[int] = None,
        causal=False,
        dim_head=32,
        heads=8,
        ff_mult=4,
        relu_squared=True,
        complete_complex=False,
        rotary_emb=True,
        flash_attn=True
    ):
        super().__init__()

        # whether token embeddings are used
        self.has_embed = exists(num_tokens)

        # if num_tokens is given, create a learned complex-valued embedding table
        if exists(num_tokens):
            self.embed = nn.Parameter(torch.randn((num_tokens, dim), dtype=cfloat))

        # optionally create the rotary embedding
        self.rotary_emb = None
        if rotary_emb:
            self.rotary_emb = RotaryEmbedding(dim_head)

        # stack of complex attention and feedforward layers, each with its own prenorm
        self.layers = ModuleList([])
        for _ in range(depth):
            self.layers.append(ModuleList([
                ComplexRMSNorm(dim),
                ComplexMultiheadAttention(dim=dim, dim_head=dim_head, heads=heads, causal=causal, complete_complex=complete_complex, flash=flash_attn),
                ComplexRMSNorm(dim),
                ComplexFeedForward(dim=dim, mult=ff_mult, relu_squared=relu_squared)
            ]))

        # final norm
        self.norm = ComplexRMSNorm(dim)

        # projection to the logits
        self.to_logits = nn.Linear(dim, num_tokens, dtype=cfloat)

    # forward pass
    def forward(
        self,
        x,
        context=None,
        mask=None,
        return_abs_logits=False,
        return_real_logits=False
    ):
        # if an embedding table exists, look up the token ids
        if self.has_embed:
            x = self.embed[x]

        # sequence length
        seq_len = x.shape[-2]
        rotary_emb = None

        # compute the rotary embedding, if enabled
        if exists(self.rotary_emb):
            rotary_emb = self.rotary_emb(seq_len)

        # attention and feedforward layers, each with prenorm and a residual
        for attn_norm, attn, ff_norm, ff in self.layers:
            x = attn(attn_norm(x), context=context, mask=mask, rotary_emb=rotary_emb) + x
            x = ff(ff_norm(x)) + x

        # final norm
        x = self.norm(x)

        # if there is no embedding table, return the normed representations
        if not self.has_embed:
            return x

        # compute the complex-valued logits
        logits = self.to_logits(x)

        # choose how to return the logits - as magnitudes or as real components
        assert (int(return_abs_logits) + int(return_real_logits)) <= 1
        if return_abs_logits:
            logits = logits.abs()
        elif return_real_logits:
            logits = logits.real

        return logits

.\lucidrains\complex-valued-transformer\complex_valued_transformer\__init__.py

# import the following classes and functions from complex_valued_transformer
from complex_valued_transformer.complex_valued_transformer import (
    ComplexMultiheadAttention,   # complex multihead attention
    ComplexRMSNorm,              # complex RMS norm
    ComplexFeedForward,          # complex feedforward
    ComplexTransformer,          # complex transformer
    complex_attention_real,      # real-component complex attention
    complex_attention_complete,  # complete complex attention
    modulate_with_rotation       # rotation modulation helper
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

Complex Valued Transformer

Implementation of the transformer proposed in Building Blocks for a Complex-Valued Transformer Architecture, plus a few other proposals from related papers. The full architecture will be evaluated on enwik8 character level language modeling as well as some algorithmic tasks (parity, binary addition).

Will not bother with complex layernorm, as RMS norm is now much more popular.

Update: It trains, seems to tolerate a much higher learning rate. Surprisingly stable, even when using softmax for complete complex formulation from Yang et al. This is likely because both papers are using the original transformer architecture with post-normalization instead of the recent pre-normalization.

Update 2: No difference between Eilers (just real component) vs Yang (real and imaginary) complex attention, at least for enwik8

Update 3: I am not seeing anything remarkable. YMMV

Install

$ pip install complex-valued-transformer

Usage

import torch
from complex_valued_transformer import ComplexTransformer

transformer = ComplexTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    dim_head = 32,
    heads = 8,
    causal = True,
    complete_complex = True
)

ids = torch.randint(0, 256, (2, 1024))

logits = transformer(ids) # (2, 1024, 256)
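
The logits above are complex valued. For a standard cross entropy objective, the forward pass also accepts flags to return real-valued logits (the training script achieves the same through its `logits_fn`); a small sketch based on the forward signature shown earlier:

logits = transformer(ids, return_abs_logits = True)   # (2, 1024, 256) - magnitudes of the complex logits
logits = transformer(ids, return_real_logits = True)  # (2, 1024, 256) - real components only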

Todo

Citations

@article{Eilers2023BuildingBF,
    title   = {Building Blocks for a Complex-Valued Transformer Architecture},
    author  = {Florian Eilers and Xiaoyi Jiang},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.09827},
    url     = {https://api.semanticscholar.org/CorpusID:258542729}
}
@article{Yang2019ComplexTA,
    title    = {Complex Transformer: A Framework for Modeling Complex-Valued Sequence},
    author   = {Muqiao Yang and Martin Q. Ma and Dongyu Li and Yao-Hung Hubert Tsai and Ruslan Salakhutdinov},
    journal  = {ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
    year     = {2019},
    pages    = {4232-4236},
    url      = {https://api.semanticscholar.org/CorpusID:204838137}
}
@article{Dong2021SignalTC,
    title   = {Signal Transformer: Complex-valued Attention and Meta-Learning for Signal Recognition},
    author  = {Yihong Dong and Ying Peng and Muqiao Yang and Songtao Lu and Qingjiang Shi},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2106.04392},
    url     = {https://api.semanticscholar.org/CorpusID:235367992}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@article{So2021PrimerSF,
    title   = {Primer: Searching for Efficient Transformers for Language Modeling},
    author  = {David R. So and Wojciech Ma'nke and Hanxiao Liu and Zihang Dai and Noam M. Shazeer and Quoc V. Le},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2109.08668},
    url     = {https://api.semanticscholar.org/CorpusID:237563187}
}

.\lucidrains\complex-valued-transformer\setup.py

# import setup and find_packages
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'complex-valued-transformer',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.0.14',  # version
  license='MIT',  # license
  description = 'Complex Valued Transformer / Attention',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/complex-valued-transformer',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'attention mechanisms',
    'transformers',
    'complex domain'
  ],
  install_requires=[  # dependencies
    'einops>=0.7.0',
    'torch>=1.12'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\complex-valued-transformer\train.py

# import the required libraries
import gzip
import random
import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# import the project modules
from complex_valued_transformer.autoregressive_wrapper import AutoregressiveWrapper
from complex_valued_transformer.complex_valued_transformer import ComplexTransformer

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-1
VALIDATE_EVERY = 100
PRIME_LENGTH = 128
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512

# helper functions
def cycle(loader):
    # yield batches from the loader indefinitely
    while True:
        for data in loader:
            yield data

def decode_token(token):
    # decode a single token into a character
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    # decode a sequence of tokens into a string
    return "".join(list(map(decode_token, tokens)))

# instantiate the transformer model
model = ComplexTransformer(
    num_tokens = 256,
    dim = 256,
    dim_head = 32,
    depth = 8,
    causal = True,
    complete_complex = True # setting this to True increases the compute within the MHA (the Yang et al. formulation)
)

model = AutoregressiveWrapper(
    model,
    seq_len = SEQ_LEN,
    logits_fn = lambda logits: logits.real
).cuda()

# prepare the enwik8 data
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create the training and validation datasets and loaders
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# optimizer
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print(f"training loss: {loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        print(f"%s \n\n %s", (prime, "*" * 100))

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str, "\n")

.\lucidrains\compositional-attention-pytorch\compositional_attention_pytorch\compositional_attention_pytorch.py

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange
from einops_exts import rearrange_many

# 检查变量是否存在的函数
def exists(val):
    return val is not None

# 计算稳定的 softmax 函数
def stable_softmax(t, dim = -1):
    t = t - t.amax(dim = dim, keepdim = True).detach()
    return t.softmax(dim = dim)
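
# illustrative check (not part of the original file): softmax is shift invariant,
# so subtracting the per-row max changes nothing mathematically, it just keeps the
# exponentials small, which matters under lower precision
_example = torch.randn(2, 8) * 10
assert torch.allclose(stable_softmax(_example, dim = -1), _example.softmax(dim = -1), atol = 1e-6)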

# compositional attention module
class CompositionalAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        num_searches = 8,
        num_retrievals = 2,
        dropout = 0.,
        prenorm = False,
        causal = False
    ):
        super().__init__()
        # use LayerNorm if prenorm is requested, otherwise Identity
        self.norm = nn.LayerNorm(dim) if prenorm else nn.Identity()

        self.scale = dim_head ** -0.5
        inner_search_dim = dim_head * num_searches
        inner_retrieval_dim = dim_head * num_retrievals

        self.num_searches = num_searches
        self.num_retrievals = num_retrievals

        # linear projections for the search queries and keys, and the retrieval values
        self.to_searches_queries = nn.Linear(dim, inner_search_dim, bias = False)
        self.to_searches_keys = nn.Linear(dim, inner_search_dim, bias = False)
        self.to_retrieval_values = nn.Linear(dim, inner_retrieval_dim, bias = False)

        # linear projections for the retrieval queries and keys
        self.to_retrieval_queries = nn.Linear(dim, inner_search_dim, bias = False)
        self.to_retrieval_keys = nn.Linear(dim_head, dim_head, bias = False)

        # linear projection mapping the combined retrievals back to the model dimension
        self.to_out = nn.Linear(inner_search_dim, dim, bias = False)

        self.search_dropout = nn.Dropout(dropout)
        self.retrieval_dropout = nn.Dropout(dropout)

        # whether to use the causal (autoregressive) variant, for the author's own experimentation
        self.causal = causal

    def forward(self, x, mask = None):
        """
        einstein notation:
        b - batch
        n - sequence dimension
        i - sequence dimension (source)
        j - sequence dimension (target, aggregation dimension)
        s - number of searches
        r - number of retrievals
        d - feature dimension
        """
        x = self.norm(x)

        s = self.num_searches
        r = self.num_retrievals

        # project to search queries and keys
        sq, sk = self.to_searches_queries(x), self.to_searches_keys(x)
        sq, sk = rearrange_many((sq, sk), 'b n (s d) -> b s n d', s = s)

        sq = sq * self.scale

        # compute search similarities and attention
        search_sim = einsum('b s i d, b s j d -> b s i j', sq, sk)

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            search_sim = search_sim.masked_fill(~mask, -torch.finfo(search_sim.dtype).max)

        if self.causal:
            i, j = search_sim.shape[-2:]
            causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)
            search_sim = search_sim.masked_fill(causal_mask, -torch.finfo(search_sim.dtype).max)

        search_attn = stable_softmax(search_sim, dim = -1)
        search_attn = self.search_dropout(search_attn)

        # project to retrieval values
        rv = self.to_retrieval_values(x)
        rv = rearrange(rv, 'b n (r d) -> b r n d', r = r)

        retrieved = einsum('b s i j, b r j d -> b s r i d', search_attn, rv)

        # project to retrieval queries and keys
        rq, rk = self.to_retrieval_queries(x), self.to_retrieval_keys(retrieved)
        rq = rearrange(rq, 'b n (s d) -> b s n d', s = s)
        rq = rq * self.scale

        # compute retrieval attention
        retrieval_sim = einsum('b s n d , b s r n d -> b s n r', rq, rk)

        retrieval_attn = stable_softmax(retrieval_sim, dim = -1)
        retrieval_attn = self.retrieval_dropout(retrieval_attn)

        # aggregate the retrieved values
        out = einsum('b s n r, b s r n d -> b s n d', retrieval_attn, retrieved)

        # combine the searches and project out
        out = rearrange(out, 'b s n d -> b n (s d)')
        return self.to_out(out)

.\lucidrains\compositional-attention-pytorch\compositional_attention_pytorch\__init__.py

# import the CompositionalAttention class from the compositional_attention_pytorch package
from compositional_attention_pytorch.compositional_attention_pytorch import CompositionalAttention

Compositional Attention - Pytorch

Implementation of Compositional Attention from MILA. They reframe the "heads" of multi-head attention as "searches", and once the multi-headed/searched values are aggregated, there is an extra retrieval step (using attention) off the searched results. They then show this variant of attention yields better OOD results on a toy task. Their ESBN results still leave a lot to be desired, but I like the general direction of the paper.

Install

$ pip install compositional-attention-pytorch

Usage

import torch
from compositional_attention_pytorch import CompositionalAttention

attn = CompositionalAttention(
    dim = 1024,            # input dimension
    dim_head = 64,         # dimension per attention 'head' - head is now either search or retrieval
    num_searches = 8,      # number of searches
    num_retrievals = 2,    # number of retrievals
    dropout = 0.,          # dropout of attention of search and retrieval
)

tokens = torch.randn(1, 512, 1024)  # tokens
mask = torch.ones((1, 512)).bool()  # mask

out = attn(tokens, mask = mask) # (1, 512, 1024)
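
As a rough, illustrative sketch of the two stages (the hyperparameters below are arbitrary, chosen only to keep the shapes small): each search head produces its own attention pattern over the sequence, every pattern is applied to all of the retrieval value projections, and a second, much smaller softmax then decides which retrieval each search keeps.

import torch
from compositional_attention_pytorch import CompositionalAttention

attn = CompositionalAttention(dim = 64, dim_head = 16, num_searches = 4, num_retrievals = 2, causal = True)

tokens = torch.randn(2, 128, 64)
out = attn(tokens)  # (2, 128, 64)

# internally: the search attention is (2, 4, 128, 128), the retrieved values are
# (2, 4, 2, 128, 16), and the retrieval softmax is only (2, 4, 128, 2)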

Citations

@article{Mittal2021CompositionalAD,
    title   = {Compositional Attention: Disentangling Search and Retrieval},
    author  = {Sarthak Mittal and Sharath Chandra Raparthy and Irina Rish and Yoshua Bengio and Guillaume Lajoie},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.09419}
}

.\lucidrains\compositional-attention-pytorch\setup.py

# import setup and find_packages utilities
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'compositional-attention-pytorch',
  # find all packages, excluding none
  packages = find_packages(exclude=[]),
  # version number
  version = '0.0.1',
  # license
  license='MIT',
  # description
  description = 'Compositional Attention - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project url
  url = 'https://github.com/lucidrains/compositional-attention-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'attention mechanism'
  ],
  # dependencies
  install_requires=[
    'einops>=0.4',
    'einops-exts',
    'torch>=1.6',
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\compressive-transformer-pytorch\compressive_transformer_pytorch\autoregressive_wrapper.py

# import math
import math
# import partial
from functools import partial
# import namedtuple
from collections import namedtuple

# import torch
import torch
# torch neural network module
from torch import nn
import torch.nn.functional as F
# pad_sequence utility
from torch.nn.utils.rnn import pad_sequence

# named tuple Return with fields loss, aux_loss and is_last_batch
Return = namedtuple('Return', ['loss', 'aux_loss', 'is_last_batch'])

# helper functions

# top_p (nucleus) filtering of the logits by cumulative probability threshold
def top_p(logits, thres = 0.9):
    # sort logits in descending order
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    # cumulative probabilities
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # indices that fall outside the nucleus are marked for removal
    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    # set the removed logits to negative infinity
    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# top_k filtering, keeping only the top fraction of the logits
def top_k(logits, thres = 0.9):
    # number of logits to keep
    k = int((1 - thres) * logits.shape[-1])
    # top k values and their indices
    val, ind = torch.topk(logits, k)
    # tensor of the same shape as logits, filled with negative infinity
    probs = torch.full_like(logits, float('-inf'))
    # scatter the top k values back into place
    probs.scatter_(1, ind, val)
    return probs
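
# illustrative example (not part of the original file): filter raw next-token
# logits with top_k, then sample from the renormalized distribution
_example_logits = torch.randn(2, 1000)                       # (batch, vocab)
_filtered = top_k(_example_logits, thres = 0.9)              # keep the top 10% of tokens
_probs = F.softmax(_filtered / 0.8, dim = -1)                # temperature of 0.8
_sampled = torch.multinomial(_probs, 1)                      # (2, 1) sampled token ids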

# main class

# AutoregressiveWrapper, subclass of nn.Module
class AutoregressiveWrapper(nn.Module):
    # constructor
    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.seq_len = net.seq_len

    # generation method, used to sample sequences
    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        # remember whether the network was in training mode
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        # put the network in eval mode
        self.net.eval()

        out = start_tokens

        # handle default masking

        full_mask_like = lambda x: torch.full_like(x, True, dtype=torch.bool, device=x.device)

        mask = kwargs.pop('mask', None)
        if mask is None:
            mask = full_mask_like(out)

        # handle a primed sequence of arbitrary length

        mem = None
        *primes, out = out.split(self.seq_len, dim=1)
        *prime_masks, mask = mask.split(self.seq_len, dim=1)

        for prime, prime_mask in zip(primes, prime_masks):
            _, mem, _ = self.net(prime, memories = mem, mask = prime_mask, **kwargs)

        # generate until the requested sequence length is reached

        input_len = out.shape[1]

        for _ in range(seq_len):
            logits, mem, aux_loss = self.net(out[:, -input_len:], memories = mem, mask = mask[:, -input_len:], **kwargs)
            logits = logits[:, -1, :]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

            # append the sample to the accumulated output

            out = torch.cat((out, sample), dim=-1)
            mask = F.pad(mask, (0, 1), value=True)

            # unlike most models, once the full sequence length is filled, the input length starts over from 1

            input_len = input_len % self.seq_len
            input_len += 1

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        # restore the network's original training mode
        self.net.train(was_training)
        return out
    # forward pass: takes input x, an optional max_batch_size, and return_loss (defaults to False); **kwargs are passed through to the network
    def forward(self, x, max_batch_size = None, return_loss = False, **kwargs):
        # padding function to pad sequences in a batch to the same length
        pad = partial(pad_sequence, batch_first = True, padding_value = self.pad_value)

        # if no loss is needed, just run the network
        if not return_loss:
            # pad the input if it is not already a tensor
            if not isinstance(x, torch.Tensor):
                x = pad(x)
            # return the network output
            return self.net(x, **kwargs)

        # when returning the loss, build the shifted input / target pairs
        if isinstance(x, torch.Tensor):
            # split into input and target sequences
            xi = x[:, :-1]
            xo = x[:, 1:]
        else:
            # pad, then split into input and target sequences
            xi = pad(list(map(lambda t: t[:-1], x)))
            xo = pad(list(map(lambda t: t[1:], x)))

        # handle the input mask, since in the autoregressive setting its length no longer matches the source sequence
        mask = kwargs.pop('mask', None)
        if mask is not None and mask.shape[1] == x.shape[1]:
            mask = mask[:, :-1]

        # function to split a sequence into segments of seq_len
        segment_fn = lambda x: x.split(self.seq_len, dim=1)
        # segment the input and target sequences
        (xi, xo) = map(segment_fn, (xi, xo))

        # number of segments
        num_segments = len(xi)
        # segment the mask as well, if one was given
        mask = segment_fn(mask) if mask is not None else ((None,) * num_segments)

        # if max_batch_size is not given, use the full batch size
        max_batch_size = x.shape[0] if max_batch_size is None else max_batch_size
        # function to split a batch into chunks of at most max_batch_size
        split_batch_fn = lambda x: x.split(max_batch_size, dim=0)

        # number of gradient accumulation steps
        grad_accumulate_every = math.ceil(x.shape[0] / max_batch_size)
        # initialize the list of memories, one per batch chunk
        mems = [None] * grad_accumulate_every

        # iterate over the segments
        for xi_seg, xo_seg, mask_seg in zip(xi, xo, mask):
            # split the segment's inputs and targets into batch chunks
            xi_seg, xo_seg = map(split_batch_fn, (xi_seg, xo_seg))
            mask_seg = split_batch_fn(mask_seg) if mask_seg is not None else ((None,) * grad_accumulate_every)

            new_mems = []
            # iterate over the batch chunks
            for ind, (xi_seg_b, xo_seg_b, mask_seg_b, mem) in enumerate(zip(xi_seg, xo_seg, mask_seg, mems)):
                is_last = ind == (grad_accumulate_every - 1)

                # run the network to get logits, new memories and the auxiliary loss
                logits, new_mem, aux_loss = self.net(xi_seg_b, mask = mask_seg_b, memories = mem, **kwargs)
                new_mems.append(new_mem)

                # cross entropy loss against the shifted targets
                loss = F.cross_entropy(logits.transpose(1, 2), xo_seg_b, ignore_index = self.ignore_index)
                # yield the loss, the auxiliary loss, and whether this is the last chunk (time to step the optimizer)
                yield Return(loss, aux_loss, is_last)

            mems = new_mems

.\lucidrains\compressive-transformer-pytorch\compressive_transformer_pytorch\compressive_transformer_pytorch.py

# import torch
import torch
# import the nn module from torch
from torch import nn
# import the functional module as F
import torch.nn.functional as F

# import the Mogrifier class from the mogrifier package
from mogrifier import Mogrifier

# import math
import math
# import namedtuple from collections
from collections import namedtuple
# import partial from functools
from functools import partial
# import isfunction from inspect (used by the default helper below)
from inspect import isfunction

# Memory named tuple
Memory = namedtuple('Memory', ['mem', 'compressed_mem'])

# helper functions

# to: return a dict with the tensor's dtype and device
def to(t):
    return {'dtype': t.dtype, 'device': t.device}

# cast_tuple: wrap an element in a tuple if it is not one already
def cast_tuple(el):
    return el if isinstance(el, tuple) else (el,)

# default: return x if it is not None, otherwise val (calling it if it is a function)
def default(x, val):
    if x is not None:
        return x
    return val if not isfunction(val) else val()

# max_neg_value: the most negative finite value for the tensor's dtype
def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

# reshape_dim: reshape a given dimension of a tensor into the provided split dimensions
def reshape_dim(t, dim, split_dims):
    shape = list(t.shape)
    num_dims = len(shape)
    dim = (dim + num_dims) % num_dims
    shape[dim:dim+1] = split_dims
    return t.reshape(shape)

# split_at_index: split a tensor into two parts at the given index along a dimension
def split_at_index(dim, index, t):
    pre_slices = (slice(None),) * dim
    l = (*pre_slices, slice(None, index))
    r = (*pre_slices, slice(index, None))
    return t[l], t[r]

# queue_fifo: concatenate along a dimension and keep only the trailing length entries, returning the evicted prefix and the new queue (first in, first out)
def queue_fifo(*args, length, dim=-2):
    queue = torch.cat(args, dim=dim)
    if length > 0:
        return split_at_index(dim, -length, queue)

    device = queue.device
    shape = list(queue.shape)
    shape[dim] = 0
    return queue, torch.empty(shape, device=device)
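
# illustrative example (not part of the original file): with a memory of length 4
# and 3 incoming hidden states, the queue keeps the newest 4 along dim 1 and
# returns the 3 displaced entries as the evicted "old" memory
_example_mem = torch.randn(1, 4, 8)
_example_x   = torch.randn(1, 3, 8)
_evicted, _kept = queue_fifo(_example_mem, _example_x, length = 4, dim = 1)
assert _evicted.shape == (1, 3, 8) and _kept.shape == (1, 4, 8)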

# shift: relative positional shift applied to the attention logits
def shift(x):
    *_, i, j = x.shape
    zero_pad = torch.zeros((*_, i, i), **to(x))
    x = torch.cat([x, zero_pad], -1)
    l = i + j - 1
    x = x.view(*_, -1)
    zero_pad = torch.zeros(*_, -x.size(-1) % l, **to(x))
    shifted = torch.cat([x, zero_pad], -1).view(*_, -1, l)
    return shifted[..., :i, i - 1:]

# iterate_tensor: yield slices of a tensor along its first dimension
def iterate_tensor(t):
    length = t.shape[0]
    for ind in range(length):
        yield t[ind]

# full attention, used for the auxiliary reconstruction loss

# full_attn: standard scaled dot-product attention
def full_attn(q, k, v, dropout_fn=None):
    *_, dim = q.shape
    dots = torch.einsum('bhid,bhjd->bhij', q, k) * (dim ** -0.5)
    attn = dots.softmax(dim=-1)
    if dropout_fn is not None:
        attn = dropout_fn(attn)
    return torch.einsum('bhij,bhjd->bhid', attn, v)
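
# illustrative example (not part of the original file): full_attn consumes tensors
# shaped (batch, heads, seq, dim_head) and returns the queries' shape
_example_q = torch.randn(1, 2, 4, 16)
_example_k = torch.randn(1, 2, 6, 16)
_example_v = torch.randn(1, 2, 6, 16)
assert full_attn(_example_q, _example_k, _example_v).shape == (1, 2, 4, 16)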

# helper classes

# Residual: residual connection wrapper
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        out = cast_tuple(out)
        ret = (out[0] + x), *out[1:]
        return ret

# GRUGating: gate the residual connection with a GRU cell
class GRUGating(nn.Module):
    def __init__(self, dim, fn, mogrify=False):
        super().__init__()
        self.dim = dim
        self.fn = fn
        self.gru = nn.GRUCell(dim, dim)
        self.mogrify = Mogrifier(dim, factorize_k=dim // 4) if mogrify else None

    def forward(self, x, **kwargs):
        batch, dim = x.shape[0], self.dim
        out = self.fn(x, **kwargs)
        (y, *rest) = cast_tuple(out)

        if self.mogrify is not None:
            y, x = self.mogrify(y, x)

        gated_output = self.gru(
            y.reshape(-1, dim),
            x.reshape(-1, dim)
        )

        gated_output = gated_output.reshape(batch, -1, dim)
        ret = gated_output, *rest
        return ret

# PreNorm: pre-layer normalization wrapper
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# ConvCompress: compress memories with a strided 1d convolution
class ConvCompress(nn.Module):
    def __init__(self, dim, ratio=4):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, ratio, stride=ratio)

    def forward(self, mem):
        mem = mem.transpose(1, 2)
        compressed_mem = self.conv(mem)
        return compressed_mem.transpose(1, 2)
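
# illustrative example (not part of the original file): with the default ratio of 4,
# eight evicted memory slots are compressed down to two
_example_compress = ConvCompress(dim = 16, ratio = 4)
_example_old_mem = torch.randn(1, 8, 16)
assert _example_compress(_example_old_mem).shape == (1, 2, 16)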

# feedforward

# GELU_: tanh approximation of the GELU activation
class GELU_(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

# use nn.GELU if available, otherwise fall back to GELU_
GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_

# FeedForward class
class FeedForward(nn.Module):
    # set up the module with input dimension, expansion multiple, dropout rate, activation, and whether to use GLU
    def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False):
        # call the parent constructor
        super().__init__()
        # default activation is GELU
        activation = default(activation, GELU)

        # whether to use the GLU variant
        self.glu = glu
        # first linear layer, projecting dim to dim * mult (doubled when using GLU)
        self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
        # activation
        self.act = activation()
        # dropout layer
        self.dropout = nn.Dropout(dropout)
        # second linear layer, projecting dim * mult back to dim
        self.w2 = nn.Linear(dim * mult, dim)

    # forward pass
    def forward(self, x, **kwargs):
        # without GLU
        if not self.glu:
            # first linear layer
            x = self.w1(x)
            # activation
            x = self.act(x)
        else:
            # with GLU
            # split the first projection into two halves
            x, v = self.w1(x).chunk(2, dim=-1)
            # apply the activation to one half and gate it with the other
            x = self.act(x) * v

        # dropout
        x = self.dropout(x)
        # second linear layer
        x = self.w2(x)
        # return the result
        return x
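
# illustrative example (not part of the original file): the GLU variant doubles the
# first projection, splits it in two, and gates one half with the activation of the
# other; the output shape still matches the input
_example_ff = FeedForward(dim = 16, mult = 4, glu = True)
assert _example_ff(torch.randn(1, 10, 16)).shape == (1, 10, 16)
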
# SelfAttention class, subclass of nn.Module
class SelfAttention(nn.Module):
    # constructor
    def __init__(self, dim, seq_len, mem_len, cmem_len, cmem_ratio = 4, heads = 8, attn_dropout = 0., dropout = 0., reconstruction_attn_dropout = 0.):
        super().__init__()
        # ensure the dimension is divisible by the number of heads
        assert (dim % heads) == 0, 'dimension must be divisible by the number of heads'

        # store the hyperparameters
        self.heads = heads
        self.dim_head = dim // heads
        self.seq_len = seq_len
        self.mem_len = mem_len
        self.cmem_len = cmem_len
        self.cmem_ratio = cmem_ratio
        self.scale = self.dim_head ** (-0.5)

        # ConvCompress module used to compress memories
        self.compress_mem_fn = ConvCompress(dim, cmem_ratio)

        # linear layers for queries, keys and values
        self.to_q = nn.Linear(dim, dim, bias = False)
        self.to_kv = nn.Linear(dim, dim * 2, bias = False)
        self.to_out = nn.Linear(dim, dim)

        # dropout for the attention weights and for the output
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.dropout = nn.Dropout(dropout)

        # dropout for the reconstruction attention
        self.reconstruction_attn_dropout = nn.Dropout(reconstruction_attn_dropout)
    # forward pass, taking input x plus optional memories, positional embedding, input mask and a flag for whether to update memory
    def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_memory = True, **kwargs):
        # unpack the shape of x along with the number of heads and the head dimension
        b, t, e, h, dim_h = *x.shape, self.heads, self.dim_head

        # unpack the memories
        memories = default(memories, (None, None))
        mem, cmem = memories

        # initialize empty memories if none were passed
        init_empty_mem = lambda: torch.empty(b, 0, e, **to(x))
        mem = default(mem, init_empty_mem)
        cmem = default(cmem, init_empty_mem)

        # lengths of the memory and compressed memory
        mem_len = mem.shape[1]
        cmem_len = cmem.shape[1]

        # project to queries
        q = self.to_q(x)

        # concatenate compressed memory, memory and the input x to form the keys and values
        kv_input = torch.cat((cmem, mem, x), dim=1)
        kv_len = kv_input.shape[1]
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        # split out the head dimension
        merge_heads = lambda x: reshape_dim(x, -1, (-1, dim_h)).transpose(1, 2)
        q, k, v = map(merge_heads, (q, k, v))

        # expand the keys and values across the heads
        k, v = map(lambda x: x.expand(-1, h, -1, -1), (k, v))

        # dot-product attention logits
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = max_neg_value(dots)

        # add the relative positional contribution
        if pos_emb is not None:
            pos_emb = pos_emb[:, -kv_len:].type(q.dtype)
            pos_dots = torch.einsum('bhid,hjd->bhij', q, pos_emb) * self.scale
            pos_dots = shift(pos_dots)
            dots = dots + pos_dots

        # apply the input mask
        if input_mask is not None:
            mask = input_mask[:, None, :, None] * input_mask[:, None, None, :]
            mask = F.pad(mask, (mem_len + cmem_len, 0), value = True)
            dots.masked_fill_(~mask, mask_value)

        # causal mask over the sequence and memories
        total_mem_len = mem_len + cmem_len
        mask = torch.ones(t, t + total_mem_len, **to(x)).triu_(diagonal = 1 + total_mem_len).bool()
        dots.masked_fill_(mask[None, None, ...], mask_value)

        # attention weights
        attn = dots.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate the values and project out
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = out.transpose(1, 2).reshape(b, t, -1)
        logits = self.to_out(out)
        logits = self.dropout(logits)

        # carry the memories forward unchanged by default
        new_mem = mem
        new_cmem = cmem
        aux_loss = torch.zeros(1, requires_grad = True, **to(q))

        # if the input is shorter than the full sequence length, or memory should not be updated, return early
        if self.seq_len > t or not calc_memory:
            return logits, Memory(new_mem, new_cmem), aux_loss

        # update the memory and determine which old entries get evicted
        old_mem, new_mem = queue_fifo(mem, x, length = self.mem_len, dim = 1)
        old_mem_padding = old_mem.shape[1] % self.cmem_ratio

        # pad the evicted memory so its length is divisible by the compression ratio
        if old_mem_padding != 0:
            old_mem = F.pad(old_mem, (0, 0, old_mem_padding, 0), value = 0.)

        # if there is nothing to compress, or compressed memory is disabled, return early
        if old_mem.shape[1] == 0 or self.cmem_len <= 0:
            return logits, Memory(new_mem, new_cmem), aux_loss

        # compress the evicted memory
        compressed_mem = self.compress_mem_fn(old_mem.detach())
        old_cmem, new_cmem = split_at_index(1, -self.cmem_len, torch.cat((cmem, compressed_mem), dim=1))

        # if not training, skip the reconstruction loss
        if not self.training:
            return logits, Memory(new_mem, new_cmem), aux_loss

        # compute the auxiliary compressed-memory reconstruction loss during training
        self.to_kv.weight.detach_()

        cmem_k, cmem_v = self.to_kv(compressed_mem).chunk(2, dim=-1)
        cmem_k, cmem_v = map(merge_heads, (cmem_k, cmem_v))
        cmem_k, cmem_v = map(lambda x: x.expand(-1, h, -1, -1), (cmem_k, cmem_v))

        old_mem_range = slice(- min(mem_len, self.mem_len) - self.seq_len, -self.seq_len)
        old_mem_k, old_mem_v = map(lambda x: x[:, :, old_mem_range].clone(), (k, v))

        q, old_mem_k, old_mem_v = map(torch.detach, (q, old_mem_k, old_mem_v))

        attn_fn = partial(full_attn, dropout_fn = self.reconstruction_attn_dropout)

        aux_loss = F.mse_loss(
            attn_fn(q, old_mem_k, old_mem_v),
            attn_fn(q, cmem_k, cmem_v)
        )

        return logits, Memory(new_mem, new_cmem), aux_loss
# CompressiveTransformer class, subclass of nn.Module
class CompressiveTransformer(nn.Module):
    # constructor, taking the model hyperparameters
    def __init__(
        self,
        num_tokens,  # number of tokens
        dim,  # model dimension
        seq_len,  # sequence length
        depth,  # number of layers
        emb_dim = None,  # embedding dimension, defaults to None
        memory_layers = None,  # which layers carry long-range memory, defaults to None
        enhanced_recurrence = True,  # enhanced recurrence, defaults to True
        mem_len = None,  # memory length, defaults to None
        cmem_len = None,  # compressed memory length, defaults to None
        cmem_ratio = 4,  # compression ratio, defaults to 4
        heads = 8,  # number of heads, defaults to 8
        gru_gated_residual = True,  # GRU gated residual, defaults to True
        mogrify_gru = False,  # mogrified GRU, defaults to False
        attn_dropout = 0.,  # attention dropout, defaults to 0
        ff_glu = False,  # GLU feedforward, defaults to False
        ff_dropout = 0.,  # feedforward dropout, defaults to 0
        attn_layer_dropout = 0.,  # attention layer output dropout, defaults to 0
        reconstruction_attn_dropout = 0.,  # reconstruction attention dropout, defaults to 0
        reconstruction_loss_weight = 1.  # reconstruction loss weight, defaults to 1
    ):
        super().__init__()  # call the parent constructor
        emb_dim = default(emb_dim, dim)  # fall back to the model dimension if no embedding dimension is given
        mem_len = default(mem_len, seq_len)  # memory length defaults to the sequence length
        cmem_len = default(cmem_len, mem_len // cmem_ratio)  # compressed memory length defaults to the memory length divided by the compression ratio
        memory_layers = default(memory_layers, list(range(1, depth + 1)))  # by default every layer carries memory

        assert mem_len >= seq_len, 'length of memory should be at least the sequence length'
        assert cmem_len >= (mem_len // cmem_ratio), f'length of compressed memory should be at least the memory length divided by the compression ratio {int(mem_len // cmem_ratio)}'
        assert all([layer > 0 and layer <= depth for layer in memory_layers]), 'one of the indicated memory layers is invalid'

        self.seq_len = seq_len  # store the sequence length

        self.depth = depth  # store the depth
        self.memory_layers = list(memory_layers)  # store the list of memory layers
        self.enhanced_recurrence = enhanced_recurrence  # store the enhanced recurrence flag

        self.token_emb = nn.Embedding(num_tokens, emb_dim)  # token embedding
        self.to_model_dim = nn.Identity() if emb_dim == dim else nn.Linear(emb_dim, dim)  # identity if the embedding dimension equals the model dimension, otherwise a linear projection

        seq_and_mem_len = seq_len + mem_len + cmem_len  # total length of sequence plus memories
        self.pos_emb = nn.Parameter(torch.zeros(heads, seq_and_mem_len, dim // heads))  # relative positional embedding parameter

        self.to_logits = nn.Sequential(
            nn.Identity() if emb_dim == dim else nn.Linear(dim, emb_dim),  # identity if the embedding dimension equals the model dimension, otherwise a linear projection
            nn.Linear(emb_dim, num_tokens)  # project to token logits
        )

        wrapper = partial(GRUGating, dim, mogrify = mogrify_gru) if gru_gated_residual else Residual  # choose the residual wrapper based on the GRU gated residual flag

        self.attn_layers = nn.ModuleList([wrapper(PreNorm(dim, SelfAttention(dim, seq_len, mem_len, cmem_len, cmem_ratio, heads, dropout = attn_layer_dropout, attn_dropout = attn_dropout, reconstruction_attn_dropout = reconstruction_attn_dropout))) for _ in range(depth)])  # attention layers
        self.ff_layers = nn.ModuleList([wrapper(PreNorm(dim, FeedForward(dim, dropout = ff_dropout, glu = ff_glu))) for _ in range(depth)])  # feedforward layers

        self.reconstruction_loss_weight = reconstruction_loss_weight  # store the reconstruction loss weight
    # forward pass, taking input x, memories, and an optional mask
    def forward(self, x, memories = None, mask = None):
        # token embedding
        x = self.token_emb(x)
        # project up to the model dimension if needed
        x = self.to_model_dim(x)
        b, t, d = x.shape

        # assert the input sequence is no longer than the designated maximum sequence length
        assert t <= self.seq_len, f'input contains a sequence length {t} that is greater than the designated maximum sequence length {self.seq_len}'

        # unpack the memories
        memories = default(memories, (None, None))
        mem, cmem = memories

        num_memory_layers = len(self.memory_layers)
        # initialize empty memories if none were passed
        init_empty_mem = lambda: torch.empty(num_memory_layers, b, 0, d, **to(x))
        mem = default(mem, init_empty_mem)
        cmem = default(cmem, init_empty_mem)

        total_len = mem.shape[2] + cmem.shape[2] + self.seq_len
        # slice out the relevant positional embeddings
        pos_emb = self.pos_emb[:, (self.seq_len - t):total_len]

        next_mem = []
        next_cmem = []
        aux_loss = torch.tensor(0., requires_grad = True, **to(x))

        # if enhanced recurrence is enabled, rotate the memories across layers
        if self.enhanced_recurrence:
            mem = torch.roll(mem, -1, 0)
            cmem = torch.roll(cmem, -1, 0)

        # iterators over the per-layer memories
        mem_iter, cmem_iter = map(iterate_tensor, (mem, cmem))

        # iterate over the attention and feedforward layers
        for ind, (attn, ff) in enumerate(zip(self.attn_layers, self.ff_layers)):
            layer_num = ind + 1

            use_memory = layer_num in self.memory_layers
            memories = (next(mem_iter), next(cmem_iter)) if use_memory else None

            # run attention followed by the feedforward
            x, (mem_out, cmem_out), layer_aux_loss = attn(x, memories = memories, calc_memory = use_memory, input_mask = mask, pos_emb = pos_emb)
            x,  = ff(x)

            aux_loss = aux_loss + layer_aux_loss

            # if this layer does not carry memory, skip collecting it
            if not use_memory:
                continue

            next_mem.append(mem_out)
            next_cmem.append(cmem_out)

        # project to output logits
        out = self.to_logits(x)

        # stack the next memories and compressed memories and detach them from the graph
        next_mem, next_cmem = map(torch.stack, (next_mem, next_cmem))
        next_mem, next_cmem = map(torch.detach, (next_mem, next_cmem))

        # scale the auxiliary loss
        aux_loss = aux_loss * self.reconstruction_loss_weight / num_memory_layers
        # return the output, the memories, and the auxiliary loss
        return out, Memory(mem = next_mem, compressed_mem = next_cmem), aux_loss

.\lucidrains\compressive-transformer-pytorch\compressive_transformer_pytorch\__init__.py

# import the CompressiveTransformer class from the compressive_transformer_pytorch package
# import the AutoregressiveWrapper class from the compressive_transformer_pytorch package
from compressive_transformer_pytorch.compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\compressive-transformer-pytorch\examples\enwik8_simple\train.py

# imports
from compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 16
MAX_BATCH_SIZE = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
PRIME_LENGTH    = 512
GENERATE_LENGTH = 1024
SEQ_LEN = 512
NUM_SEGMENTS = 4

# helper functions
def cycle(loader):
    # yield batches from the loader indefinitely
    while True:
        for data in loader:
            yield data

def decode_token(token):
    # decode a single byte token back into a character
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    # decode a sequence of tokens back into a string
    return ''.join(list(map(decode_token, tokens)))

# instantiate the model
model = CompressiveTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    seq_len = SEQ_LEN,
    mem_len = SEQ_LEN,
    cmem_len = SEQ_LEN // 4,
    heads = 8,
    memory_layers = [6,7,8]
)
model = AutoregressiveWrapper(model)
model.cuda()

# prepare enwik8 data
with gzip.open('./data/enwik8.gz') as file:
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len, segments):
        super().__init__()
        self.data = data
        self.seq_len = seq_len
        self.segments = segments
        self.total_len = seq_len * segments

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.total_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.total_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.total_len

# create the training and validation datasets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN, NUM_SEGMENTS)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN, NUM_SEGMENTS)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# train the model
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()
    grad_accum_every = BATCH_SIZE / MAX_BATCH_SIZE

    for mlm_loss, aux_loss, is_last in model(next(train_loader), max_batch_size = MAX_BATCH_SIZE, return_loss = True):
        loss = mlm_loss + aux_loss
        (loss / grad_accum_every).backward()

        print(f'training loss: {mlm_loss.item():.4f} | aux_loss: {aux_loss.item():.4f}')

        if is_last:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optim.step()
            optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            for loss, aux_loss, _ in model(next(val_loader), return_loss = True):
                print(f'validation loss: {loss.item():.4f}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        inp = inp[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

Compressive Transformer in Pytorch

Pytorch implementation of Compressive Transformers, a variant of Transformer-XL with compressed memory for long-range language modelling. I will also combine this with an idea from another paper that adds gating at the residual intersection. The memory and the gating may be synergistic, and lead to further improvements in both language modeling as well as reinforcement learning.


Install

$ pip install compressive_transformer_pytorch

Usage

import torch
from compressive_transformer_pytorch import CompressiveTransformer

model = CompressiveTransformer(
    num_tokens = 20000,
    emb_dim = 128,                 # embedding dimensions, embedding factorization from Albert paper
    dim = 512,
    depth = 12,
    seq_len = 1024,
    mem_len = 1024,                # memory length
    cmem_len = 1024 // 4,          # compressed memory buffer length
    cmem_ratio = 4,                # compressed memory ratio, 4 was recommended in paper
    reconstruction_loss_weight = 1, # weight to place on compressed memory reconstruction loss
    attn_dropout = 0.1,            # dropout post-attention
    ff_dropout = 0.1,              # dropout in feedforward
    attn_layer_dropout = 0.1,      # dropout for attention layer output
    gru_gated_residual = True,     # whether to gate the residual intersection, from 'Stabilizing Transformer for RL' paper
    mogrify_gru = False,           # experimental feature that adds a mogrifier for the update and residual before gating by the GRU
    memory_layers = range(6, 13),  # specify which layers to use long-range memory, from 'Do Transformers Need LR Memory' paper
    ff_glu = True                  # use GLU variant for feedforward
)

inputs = torch.randint(0, 256, (1, 2048))
masks = torch.ones_like(inputs).bool()

segments = inputs.reshape(1, -1, 1024).transpose(0, 1)
masks = masks.reshape(1, -1, 1024).transpose(0, 1)

logits, memories, aux_loss = model(segments[0], mask = masks[0])
logits,        _, aux_loss = model(segments[1], mask = masks[1], memories = memories)

# memories is a named tuple that contains the memory (mem) and the compressed memory (cmem)

When training, you can use the AutoregressiveWrapper to have memory management across segments taken care of for you. As easy as it gets.

import torch
from compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch import AutoregressiveWrapper

model = CompressiveTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    mem_len = 1024,
    cmem_len = 256,
    cmem_ratio = 4,
    memory_layers = [5,6]
).cuda()

model = AutoregressiveWrapper(model)

inputs = torch.randint(0, 20000, (1, 2048 + 1)).cuda()

for loss, aux_loss, _ in model(inputs, return_loss = True):
    (loss + aux_loss).backward()
    # optimizer step and zero grad

# ... after much training ...

# generation is also greatly simplified and automated away
# just pass in the prime, which can be 1 start token or any length
# all is taken care of for you

prime = torch.ones(1, 1).cuda()  # assume 1 is start token
sample = model.generate(prime, 4096)

Citations

@misc{rae2019compressive,
    title   = {Compressive Transformers for Long-Range Sequence Modelling},
    author  = {Jack W. Rae and Anna Potapenko and Siddhant M. Jayakumar and Timothy P. Lillicrap},
    year    = {2019},
    eprint  = {1911.05507},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{parisotto2019stabilizing,
    title   = {Stabilizing Transformers for Reinforcement Learning},
    author  = {Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
    year    = {2019},
    eprint  = {1910.06764},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{rae-razavi-2020-transformers,
    title   = "Do Transformers Need Deep Long-Range Memory?",
    author  = "Rae, Jack  and
      Razavi, Ali",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month   = jul,
    year    = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url     = "https://www.aclweb.org/anthology/2020.acl-main.672"
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{lan2019albert,
    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year        = {2019},
    url         = {https://arxiv.org/abs/1909.11942}
}
@misc{ding2021erniedoc,
    title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
    author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
    year    = {2021},
    eprint  = {2012.15688},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}