Lucidrains Series Projects: Source Code Analysis (Part 25)

.\lucidrains\med-seg-diff-pytorch\setup.py

# import setuptools' setup and find_packages helpers
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'med-seg-diff-pytorch',
  # find all packages, excluding none
  packages = find_packages(exclude=[]),
  # version number
  version = '0.3.3',
  # license
  license='MIT',
  # description
  description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # long description content type
  long_description_content_type = 'text/markdown',
  # project URL
  url = 'https://github.com/lucidrains/med-seg-diff-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'denoising diffusion',
    'medical segmentation'
  ],
  # install dependencies
  install_requires = [
    'beartype',
    'einops',
    'lion-pytorch',
    'torch',
    'torchvision',
    'tqdm',
    'accelerate>=0.25.0',
    'wandb'
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Medical AI Experiments (wip)

A repository to house some personal attempts to beat some state-of-the-art for medical datasets. Will start with basic arrhythmia detection and work my way up to EEG seizure classification / detection.

I will apply everything that I know from the attention field.

.\lucidrains\medical-chatgpt\medical_chatgpt\medical_chatgpt.py

# import torch
import torch
# torch functional API
import torch.nn.functional as F
# nn and einsum from torch
from torch import nn, einsum
# isfunction, used by the default helper below
from inspect import isfunction

# rearrange and repeat from einops (repeat is used for the null key / values)
from einops import rearrange, repeat

# helper to check whether a value exists
def exists(val):
    return val is not None

# helper returning a default when the value is missing
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# attention class
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        dim_head = 64,
        dim_context = None,
        heads = 8,
        norm_context = False,
        num_null_kv = 0,
        dropout = 0.1
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal
        inner_dim = dim_head * heads

        dim_context = default(dim_context, dim)

        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(dim_context) if norm_context else nn.Identity()

        self.attn_dropout = nn.Dropout(dropout)

        self.num_null_kv = num_null_kv
        self.null_kv = nn.Parameter(torch.randn(2, num_null_kv, dim_head))

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None
    ):
        b = x.shape[0]

        if exists(context):
            context = self.context_norm(context)

        kv_input = default(context, x)

        x = self.norm(x)

        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)

        if self.num_null_kv > 0:
            null_k, null_v = repeat(self.null_kv, 'kv n d -> kv b n d', b = b).unbind(dim = 0)
            k = torch.cat((null_k, k), dim = -2)
            v = torch.cat((null_v, v), dim = -2)

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        q = q * self.scale

        sim = einsum('b h i d, b j d -> b h i j', q, k)

        if exists(attn_bias):
            attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value = 0.)
            sim = sim + attn_bias

        if exists(mask):
            mask = F.pad(mask, (self.num_null_kv, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
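
As a quick illustration of the Attention module above, here is a minimal usage sketch. The import path is assumed from the file location, and the shapes are arbitrary example values:

import torch
from medical_chatgpt.medical_chatgpt import Attention

# causal self-attention over a 128-token sequence with model dimension 512
self_attn = Attention(dim = 512, causal = True)
x = torch.randn(2, 128, 512)
out = self_attn(x)   # (2, 128, 512)

# cross-attention to a 77-token context of dimension 256, with 2 learned null key / values
cross_attn = Attention(dim = 512, dim_context = 256, norm_context = True, num_null_kv = 2)
context = torch.randn(2, 77, 256)
out = cross_attn(x, context = context)   # (2, 128, 512)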

.\lucidrains\medical-chatgpt\medical_chatgpt\__init__.py

# define a function named calculate_area for computing the area of a rectangle
def calculate_area(length, width):
    # compute the rectangle's area
    area = length * width
    # return the computed area
    return area

Explorations into training a ChatGPT, but tailored towards primary care medicine, with the reward being able to collect patient histories in a thorough and efficient manner and come up with a differential diagnosis. May also explore to see if it can be further fine-tuned on pirated copies of Up-To-Date for specialist knowledge

Sadly, I no longer think this is possible in its current state. It will probably see some utility with scribing of basic bread and butter cases; however assess and plan it cannot.

Citations

@inproceedings{Singhal2022LargeLM,
    title   = {Large Language Models Encode Clinical Knowledge},
    author  = {Karan Singhal and Shekoofeh Azizi and Tao Tu and Said Mahdavi and Jason Lee Kai Wei and Hyung Won Chung and Nathan Scales and Ajay Kumar Tanwani and Heather J. Cole-Lewis and Stephen J. Pfohl and P A Payne and Martin G. Seneviratne and Paul Gamble and Chris Kelly and Nathaneal Scharli and Aakanksha Chowdhery and P. D. Mansfield and Blaise Ag{\"u}era y Arcas and Dale R. Webster and Greg S. Corrado and Y. Matias and Katherine Hui-Ling Chou and Juraj Gottweis and Nenad Toma{\vs}ev and Yun Liu and Alvin Rajkomar and Jo{\"e}lle K. Barral and Christopher Semturs and Alan Karthikesalingam and Vivek Natarajan},
    year    = {2022}
}
@article {Kung2022.12.19.22283643,
    author  = {Kung, Tiffany H. and Cheatham, Morgan and , and Medenilla, Arielle and Sillos, Czarina and De Leon, Lorie and Elepa{\~n}o, Camille and Madriaga, Maria and Aggabao, Rimel and Diaz-Candido, Giezel and Maningo, James and Tseng, Victor},
    title   = {Performance of ChatGPT on USMLE: Potential for AI-Assisted Medical Education Using Large Language Models},
    elocation-id = {2022.12.19.22283643},
    year    = {2022},
    doi     = {10.1101/2022.12.19.22283643},
    publisher = {Cold Spring Harbor Laboratory Press},
    URL     = {https://www.medrxiv.org/content/early/2022/12/21/2022.12.19.22283643},
    eprint  = {https://www.medrxiv.org/content/early/2022/12/21/2022.12.19.22283643.full.pdf},
    journal = {medRxiv}
}
@misc{https://doi.org/10.48550/arxiv.2301.10035,
    doi     = {10.48550/ARXIV.2301.10035},
    url     = {https://arxiv.org/abs/2301.10035},
    author  = {Nov, Oded and Singh, Nina and Mann, Devin},
    keywords = {Human-Computer Interaction (cs.HC), FOS: Computer and information sciences, FOS: Computer and information sciences},
    title   = {Putting ChatGPT's Medical Advice to the (Turing) Test},
    publisher = {arXiv},
    year    = {2023},  
    copyright = {Creative Commons Attribution Share Alike 4.0 International}
}
@inproceedings{Schick2023ToolformerLM,
    title   = {Toolformer: Language Models Can Teach Themselves to Use Tools},
    author  = {Timo Schick and Jane Dwivedi-Yu and Roberto Dessi and Roberta Raileanu and Maria Lomeli and Luke Zettlemoyer and Nicola Cancedda and Thomas Scialom},
    year    = {2023}
}
@inproceedings{Peng2023CheckYF,
    title     = {Check Your Facts and Try Again: Improving Large Language Models with External Knowledge and Automated Feedback},
    author    = {Baolin Peng and Michel Galley and Pengcheng He and Hao Cheng and Yujia Xie and Yu Hu and Qiuyuan Huang and Lars Lid{\'e}n and Zhou Yu and Weizhu Chen and Jianfeng Gao},
    year      = {2023}
}
@inproceedings{Nori2023CapabilitiesOG,
    title   = {Capabilities of GPT-4 on Medical Challenge Problems},
    author  = {Harsha Nori and Nicholas King and Scott Mayer McKinney and Dean Carignan and Eric Horvitz},
    year    = {2023}
}

.\lucidrains\medical-chatgpt\setup.py

# import setuptools' setup and find_packages helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'medical-chatgpt',  # package name
  packages = find_packages(exclude=[]),  # find all packages
  version = '0.0.1',  # version number
  license='MIT',  # license
  description = 'Medical ChatGPT',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  long_description_content_type = 'text/markdown',  # long description content type
  url = 'https://github.com/lucidrains/medical-chatgpt',  # URL
  keywords = [  # keyword list
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'reinforcement learning with human feedback'
  ],
  install_requires=[  # install dependencies
    'einops>=0.6',
    'django-ninja',
    'torch>=1.6',
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\Mega-pytorch\mega_pytorch\autoregressive_wrapper.py

# import torch
import torch
# nn module from torch
from torch import nn
# torch functional API, aliased as F
import torch.nn.functional as F

# rearrange from einops
from einops import rearrange

# helper to check whether a value exists
def exists(val):
    return val is not None

# decorator that switches the model to eval mode for the wrapped call, then restores its training state
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# top-k filtering: keep only the top-k logits, setting the rest to -inf
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
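
For intuition, top_k with thres = 0.9 keeps only the largest 10% of the logits and sets everything else to -inf, so the softmax afterwards only assigns probability to those entries. A quick standalone check of that behavior:

logits = torch.randn(2, 100)
filtered = top_k(logits, thres = 0.9)
assert (filtered > float('-inf')).sum(dim = -1).eq(10).all()   # 10 surviving logits per row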

# AutoregressiveWrapper class, wrapping a model for autoregressive training and sampling
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.net = net

    # generation method, samples a sequence autoregressively
    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, temperature = 1., filter_thres = 0.9, **kwargs):
        b, t, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            logits = self.net(out, **kwargs)[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

        out = out[:, t:]
        return out

    # forward pass, computes the next-token cross entropy loss
    def forward(self, x, **kwargs):
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        logits = self.net(x_inp, **kwargs)
        return F.cross_entropy(rearrange(logits, 'b c n -> b n c'), x_labels)
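
Combined with the Mega model defined in mega_pytorch.py below, the wrapper is used the same way as in train.py; a minimal sketch:

import torch
from mega_pytorch.mega_pytorch import Mega
from mega_pytorch.autoregressive_wrapper import AutoregressiveWrapper

model = AutoregressiveWrapper(Mega(num_tokens = 256, dim = 512, depth = 8))

seq = torch.randint(0, 256, (1, 1024))
loss = model(seq)                     # next-token cross entropy over the sequence
loss.backward()

prime = seq[:, :16]
sampled = model.generate(prime, 64)   # (1, 64) newly sampled tokens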

.\lucidrains\Mega-pytorch\mega_pytorch\mega_pytorch.py

# import math
import math
# partial from functools
from functools import partial

# import torch
import torch
# torch functional API
import torch.nn.functional as F
# nn and einsum from torch
from torch import nn, einsum
# rfft and irfft from torch.fft
from torch.fft import rfft, irfft

# rearrange and the Rearrange layer from einops
from einops import rearrange
from einops.layers.torch import Rearrange

# next_fast_len from scipy.fftpack, used by conv1d_fft below
from scipy.fftpack import next_fast_len

# functions

# helper to check whether a value exists
def exists(val):
    return val is not None

# identity function, returns its input unchanged
def identity(t, *args, **kwargs):
    return t

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# append the given number of singleton dimensions to the end of a tensor
def append_dims(x, num_dims):
    if num_dims <= 0:
        return x
    return x.view(*x.shape, *((1,) * num_dims))

# O(N log N) 1d convolution using a Fourier-domain trick
def conv1d_fft(x, weights, dim = -2, weight_dim = -1):
    # O(N log(N)) 1d convolution using some fourier trick

    assert weight_dim >= dim

    N = x.shape[dim]
    M = weights.shape[weight_dim]

    fast_len = next_fast_len(N + M - 1)

    f_x = rfft(x, n = fast_len, dim = dim)
    f_weight = rfft(weights, n = fast_len, dim = weight_dim)

    f_v_weight = f_x * append_dims(f_weight.conj(), weight_dim - dim)
    out = irfft(f_v_weight, fast_len, dim = dim)
    out = out.roll(-1, dims = (dim,))

    indices = torch.arange(start = fast_len - N, end = fast_len, dtype = torch.long, device = x.device)
    out = out.index_select(dim, indices)
    return out
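
To make the Fourier trick concrete: when the kernel is the damped EMA kernel built the same way as in MultiHeadedEMA below, conv1d_fft reproduces the recurrence y_t = alpha * x_t + (1 - alpha) * damp * y_{t-1}. A small standalone check, with made-up numbers and assuming the definitions above are in scope:

seq_len, alpha, damp = 8, 0.3, 0.9
x = torch.randn(1, seq_len, 1, 1)                                   # (batch, seq, heads, dim)
powers = torch.arange(seq_len - 1, -1, -1).float()
K = alpha * ((1 - alpha) * damp) ** rearrange(powers, 'l -> l 1')   # (seq, heads)

out = conv1d_fft(x, K, dim = -3, weight_dim = -2)

y, ys = torch.tensor(0.), []
for t in range(seq_len):
    y = alpha * x[0, t, 0, 0] + (1 - alpha) * damp * y
    ys.append(y)

assert torch.allclose(out[0, :, 0, 0], torch.stack(ys), atol = 1e-5)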

# relative position bias for single headed attention (T5 style)
class T5RelativePositionBias(nn.Module):
    def __init__(
        self,
        scale,
        causal = False,
        num_buckets = 32,
        max_distance = 128
    ):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, 1)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        causal = True,
        num_buckets = 32,
        max_distance = 128
    ):
        ret = 0
        n = -relative_position
        if not causal:
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, x):
        i, j, device = *x.shape[-2:], x.device
        q_pos = torch.arange(i, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j 1 -> i j')
        return bias * self.scale

# classes

# laplacian attention function class
class LaplacianAttnFn(nn.Module):
    def forward(self, x):
        mu = math.sqrt(0.5)
        std = math.sqrt((4 * math.pi) ** -1)
        return (1 + torch.special.erf((x - mu) / (std * math.sqrt(2)))) * 0.5
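
In other words, this activation is the cumulative distribution function of a normal distribution with mean sqrt(1/2) and standard deviation sqrt(1 / (4 * pi)), applied elementwise to the attention scores. A quick standalone check:

x = torch.linspace(-3, 3, steps = 7)
mu, std = math.sqrt(0.5), math.sqrt((4 * math.pi) ** -1)
assert torch.allclose(LaplacianAttnFn()(x), torch.distributions.Normal(mu, std).cdf(x), atol = 1e-6)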

# offset and scale class
class OffsetScale(nn.Module):
    def __init__(self, dim, heads = 1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        nn.init.normal_(self.gamma, std = 0.02)

    def forward(self, x):
        out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
        return out.unbind(dim = -2)

# single-headed attention class
class SingleHeadedAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_qk,
        dim_value,
        causal = False,
        laplacian_attn_fn = False
    ):
        # call the parent constructor
        super().__init__()
        # whether to use causal masking and the laplacian attention function
        self.causal = causal
        self.laplacian_attn_fn = laplacian_attn_fn

        # pick the attention function: softmax, or laplacian if requested
        self.attn_fn = partial(F.softmax, dim = -1) if not laplacian_attn_fn else LaplacianAttnFn()

        # relative position bias
        self.rel_pos_bias = T5RelativePositionBias(causal = causal, scale = dim_qk ** 0.5)

        # project the input to queries / keys
        self.to_qk = nn.Sequential(
            nn.Linear(dim, dim_qk),
            nn.SiLU()
        )

        # offset and scale layer, producing separate queries and keys
        self.offsetscale = OffsetScale(dim_qk, heads = 2)

        # project the input to values
        self.to_v = nn.Sequential(
            nn.Linear(dim, dim_value),
            nn.SiLU()
        )

    # forward pass
    def forward(self, x, v_input = None):
        # sequence length, dimension, device and dtype
        seq_len, dim, device, dtype = *x.shape[-2:], x.device, x.dtype

        # if no value input is given, use x
        v_input = default(v_input, x)

        # project the inputs to queries / keys and values
        qk, v = self.to_qk(x), self.to_v(v_input)
        q, k = self.offsetscale(qk)

        # scale factor
        scale = (seq_len ** -1) if self.laplacian_attn_fn else (dim ** -0.5)

        # attention similarity matrix
        sim = einsum('b i d, b j d -> b i j', q, k) * scale

        # add the relative position bias
        sim = sim + self.rel_pos_bias(sim)

        # build the causal mask if needed
        if self.causal:
            causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).triu(1)

        # with softmax attention, mask with a large negative value before the softmax
        if self.causal and not self.laplacian_attn_fn:
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention weights
        attn = self.attn_fn(sim)

        # with the laplacian attention function, zero out the upper triangle after the activation
        if self.causal and self.laplacian_attn_fn:
            attn = attn.masked_fill(causal_mask, 0.)

        # aggregate the values
        return einsum('b i j, b j d -> b i d', attn, v)
class MultiHeadedEMA(nn.Module):
    # multi-headed exponential moving average (EMA) module
    def __init__(
        self,
        *,
        dim,
        heads,
        bidirectional = False,
        norm_mhesa_heads = False
    ):
        # initialize
        super().__init__()
        self.bidirectional = bidirectional

        # initialize the expansion / reduction parameters
        self.expansion = nn.Parameter(torch.randn(heads * (2 if bidirectional else 1), dim))
        self.reduction = nn.Parameter(torch.randn(heads * (2 if bidirectional else 1), dim))

        # learned alphas and dampening factors

        self.alphas = nn.Parameter(torch.randn(heads))
        self.dampen_factors = nn.Parameter(torch.randn(heads))

        if bidirectional:
            self.reverse_alphas = nn.Parameter(torch.randn(heads))
            self.reverse_dampen_factors = nn.Parameter(torch.randn(heads))

        self.heads = heads

        self.norm_heads = nn.Identity()

        if norm_mhesa_heads:
            # use a group norm over the heads as sub-layer normalization
            self.norm_heads = nn.Sequential(
                Rearrange('b n h d -> b (h d) n'),
                nn.GroupNorm(heads, dim * heads),
                Rearrange('b (h d) n -> b n h d', h = heads)
            )

    def forward(self, x):
        # forward pass
        device, seq_len = x.device, x.shape[1]

        # project in and split out the heads
        x = einsum('... d, h d -> ... h d', x, self.expansion)

        if self.bidirectional:
            x, x_reversed = x.chunk(2, dim = -2)
            x_reversed = torch.flip(x_reversed, dims = (1,))

        # weights derived from the alphas (learned exponential smoothing decay rates)
        def apply_learned_ema_with_damping(x, alphas, dampen_factors):
            alphas = alphas.sigmoid()
            dampen_factors = dampen_factors.sigmoid()

            reversed_powers = torch.arange(seq_len - 1, -1, -1, device = device)
            K = alphas * (((1 - alphas) * dampen_factors) ** rearrange(reversed_powers, '... l -> ... l 1'))

            # compute with the fft-based conv1d
            return conv1d_fft(x, K, dim = -3, weight_dim = -2)

        x = apply_learned_ema_with_damping(x, self.alphas, self.dampen_factors)

        if self.bidirectional:
            x_reversed = apply_learned_ema_with_damping(x_reversed, self.reverse_alphas, self.reverse_dampen_factors)
            x_reversed = torch.flip(x_reversed, dims = (1,))
            x = torch.cat((x, x_reversed), dim = -2)

        # maybe normalize the heads
        x = self.norm_heads(x)

        # combine the heads and project out
        return einsum('... h d, h d -> ... d', x, self.reduction)

# Mega Layer
# single-headed attention + multi-headed EMA, followed by GRU-like gating

class MegaLayer(nn.Module):
    # the Mega layer module
    def __init__(
        self,
        *,
        dim = 128,
        ema_heads = 16,
        attn_dim_qk = 64,
        attn_dim_value = 256,
        laplacian_attn_fn = False,
        causal = True,
        norm_mhesa_heads = False
    ):
        # initialize
        super().__init__()

        # single-headed attention
        self.single_headed_attn = SingleHeadedAttention(
            dim = dim,
            dim_qk = attn_dim_qk,
            dim_value = attn_dim_value,
            causal = causal,
            laplacian_attn_fn = laplacian_attn_fn
        )

        # multi-headed EMA
        self.multi_headed_ema = MultiHeadedEMA(
            dim = dim,
            heads = ema_heads,
            bidirectional = not causal,
            norm_mhesa_heads = norm_mhesa_heads
        )

        # reset gate
        self.to_reset_gate = nn.Sequential(
            nn.Linear(dim, attn_dim_value),
            nn.SiLU()
        )

        # update gate
        self.to_update_gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Sigmoid()
        )

        # parameters for computing H, equation 14 of the paper
        self.Wh = nn.Parameter(torch.randn(dim, dim))
        self.Uh = nn.Parameter(torch.randn(attn_dim_value, dim))
        self.bh = nn.Parameter(torch.randn(dim))
    # forward pass, taking the input x and an optional residual (defaults to x)
    def forward(self, x, residual = None):
        # if no residual is passed in, fall back to x
        residual = default(residual, x)

        # run the input through the multi-headed EMA
        ema_output = self.multi_headed_ema(x)
        # single-headed attention over the EMA output, with x as the value input
        attn_output = self.single_headed_attn(ema_output, x)

        # compute the reset and update gates
        reset_gate = self.to_reset_gate(ema_output)
        update_gate = self.to_update_gate(ema_output)

        # gate the attention output with the reset gate
        gated_attn_output = attn_output * reset_gate

        # compute H, equation 14 of the paper
        H = F.silu(ema_output @ self.Wh + gated_attn_output @ self.Uh + self.bh)

        # gate between H and the residual with the update gate
        return update_gate * H + (1 - update_gate) * residual
# feedforward layer: linear -> GELU -> linear
def FeedForward(dim, ff_mult):
    # hidden dimension
    dim_hidden = int(dim * ff_mult)
    return nn.Sequential(
        nn.Linear(dim, dim_hidden),  # project from dim to dim_hidden
        nn.GELU(),                   # GELU activation
        nn.Linear(dim_hidden, dim)   # project back from dim_hidden to dim
    )

# the Mega model, a stack of Mega layers
class Mega(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        ff_mult = 2,
        pre_norm = False,
        **kwargs
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)  # embedding layer mapping tokens to dim-dimensional vectors
        self.pre_norm = pre_norm  # whether to use pre-layernorm

        self.layers = nn.ModuleList([])  # empty ModuleList holding the Mega layers

        # for each of the depth layers, create a MegaLayer and its companion layers
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                MegaLayer(dim = dim, **kwargs),             # Mega layer
                nn.LayerNorm(dim),                          # layernorm for the Mega layer
                FeedForward(dim = dim, ff_mult = ff_mult),  # feedforward layer
                nn.LayerNorm(dim)                           # layernorm for the feedforward
            ]))

        # project the model output to logits over num_tokens
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim) if pre_norm else nn.Identity(),  # final layernorm only when pre-norm is used
            nn.Linear(dim, num_tokens)                          # linear projection from dim to num_tokens
        )

    # forward pass
    def forward(self, x):
        pre_norm = self.pre_norm
        post_norm = not self.pre_norm

        x = self.token_emb(x)  # embed the input tokens into dim-dimensional vectors

        # iterate over each MegaLayer and its companion layers
        for mega_layer, mega_norm, ff, ff_norm in self.layers:
            mega_maybe_prenorm = mega_norm if pre_norm else identity
            ff_maybe_prenorm = ff_norm if pre_norm else identity

            mega_maybe_postnorm = mega_norm if post_norm else identity
            ff_maybe_postnorm = ff_norm if post_norm else identity

            x = mega_layer(mega_maybe_prenorm(x), x)  # Mega layer, with x passed as the residual

            x = mega_maybe_postnorm(x)  # possible post-layernorm

            x = ff(ff_maybe_prenorm(x)) + x  # feedforward with residual connection

            x = ff_maybe_postnorm(x)  # possible post-layernorm

        return self.to_logits(x)  # project to logits over num_tokens

.\lucidrains\Mega-pytorch\mega_pytorch\__init__.py

# import the MegaLayer, Mega and MultiHeadedEMA classes from mega_pytorch.mega_pytorch
from mega_pytorch.mega_pytorch import MegaLayer, Mega, MultiHeadedEMA

Mega - Moving Average Equipped Gated Attention - Pytorch

Implementation of the Mega layer, the Single-head Attention with Multi-headed EMA layer that exists in the architecture that currently holds SOTA on Long Range Arena, beating S4 on Pathfinder-X and all the other tasks save for audio.

Install

$ pip install mega-pytorch

Usage

The Mega Layer with combination of attention and learned EMA

import torch
from mega_pytorch import MegaLayer

layer = MegaLayer(
    dim = 128,                   # model dimensions
    ema_heads = 16,              # number of EMA heads
    attn_dim_qk = 64,            # dimension of queries / keys in attention
    attn_dim_value = 256,        # dimension of values in attention
    laplacian_attn_fn = False,   # whether to use softmax (false) or laplacian attention activation fn (true)
)

x = torch.randn(1, 1024, 128)     # (batch, seq, dim)

out = layer(x) # (1, 1024, 128)

Full Mega (with layernorm for now)

import torch
from mega_pytorch import Mega

mega = Mega(
    num_tokens = 256,            # number of tokens
    dim = 128,                   # model dimensions
    depth = 6,                   # depth
    ema_heads = 16,              # number of EMA heads
    attn_dim_qk = 64,            # dimension of queries / keys in attention
    attn_dim_value = 256,        # dimension of values in attention
    laplacian_attn_fn = True,    # whether to use softmax (false) or laplacian attention activation fn (true)
)

x = torch.randint(0, 256, (1, 1024))

logits = mega(x) # (1, 1024, 256)

Todo

Citations

@inproceedings{Ma2022MegaMA,
    title   = {Mega: Moving Average Equipped Gated Attention},
    author  = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer},
    year    = {2022}
}

.\lucidrains\Mega-pytorch\setup.py

# import setuptools' setup and find_packages helpers
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'Mega-pytorch',
  # find all packages, excluding none
  packages = find_packages(exclude=[]),
  # version number
  version = '0.1.0',
  # license
  license='MIT',
  # description
  description = 'Mega - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # long description content type
  long_description_content_type = 'text/markdown',
  # project URL
  url = 'https://github.com/lucidrains/Mega-pytorch',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'attention mechanism',
    'exponential moving average',
    'long range arena'
  ],
  # install dependencies
  install_requires=[
    'einops>=0.4',
    'scipy',
    'torch>=1.6',
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\Mega-pytorch\train.py

# import required libraries
from mega_pytorch.mega_pytorch import Mega
from mega_pytorch.autoregressive_wrapper import AutoregressiveWrapper

import argparse
import random
import tqdm
import gzip
import numpy as np

import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512

# helper functions

# infinite generator over a dataloader (used by the loaders below)
def cycle(loader):
    while True:
        for data in loader:
            yield data

# decode a token into a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens into a string
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate a GPT-like decoder model
model = Mega(
    num_tokens = 256,
    dim = 512,
    depth = 8
)

model = AutoregressiveWrapper(model)

model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    x = np.array(np.frombuffer(file.read(int(95e6)), dtype = np.uint8))
    train_x, valid_x = np.split(x, [int(90e6)])
    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)

# text sampling dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create dataloaders for the training and validation sets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f"\n\n {prime} \n\n {'-' * 80} \n")

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str + "\n\n")

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\MEGABYTE-pytorch\MEGABYTE_pytorch\attend.py

# import required libraries
from collections import namedtuple
from functools import wraps
from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# named tuple EfficientAttentionConfig with three boolean flags
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helper to check whether a value exists
def exists(val):
    return val is not None

# decorator ensuring the wrapped function only runs once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# wrap print with once, so the message is only printed a single time
print_once = once(print)

# main Attend class
class Attend(nn.Module):
    def __init__(
        self,
        causal = False,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    # build the causal mask
    def get_mask(self, i, j, device):
        return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)

    # flash attention
    def flash_attn(self, q, k, v, mask = None, attn_bias = None):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # single headed key / values
        if k.ndim == 3:
            k = rearrange(k, 'b n d -> b 1 n d')

        if v.ndim == 3:
            v = rearrange(v, 'b n d -> b 1 n d')

        # if a mask exists and is not already 4-dimensional, expand it to a compatible shape
        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # pick the sdp config appropriate for the device
        config = self.cuda_config if is_cuda else self.cpu_config

        # use torch.backends.cuda.sdp_kernel(**config._asdict()) to dispatch to pytorch 2.0 flash attention
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0., 
                is_causal = self.causal
            )

        return out
    # forward pass over queries, keys, values and an optional mask
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # query and key sequence lengths, and device
        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        # scale factor
        scale = q.shape[-1] ** -0.5

        # einsum equation for the key / values, depending on whether they are single headed
        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        # if flash attention is enabled, dispatch to it
        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # causal masking
        if self.causal:
            # build the causal mask
            causal_mask = self.get_mask(q_len, k_len, device)
            # apply it to the similarity matrix
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention weights
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out
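
A minimal usage sketch for the Attend module above, with arbitrary example shapes (single-headed keys and values are broadcast across the query heads):

import torch
from MEGABYTE_pytorch.attend import Attend

attend = Attend(causal = True, dropout = 0., flash = False)

q = torch.randn(1, 8, 1024, 64)   # (batch, heads, seq, dim_head)
k = torch.randn(1, 1024, 64)      # single-headed keys
v = torch.randn(1, 1024, 64)      # single-headed values

out = attend(q, k, v)             # (1, 8, 1024, 64)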

.\lucidrains\MEGABYTE-pytorch\MEGABYTE_pytorch\megabyte.py

# import math
import math
# import functools
import functools
# zip_longest from itertools
from itertools import zip_longest

# import torch
import torch
# torch functional API
import torch.nn.functional as F
# nn and einsum from torch
from torch import nn, einsum

# rearrange, reduce, repeat, pack, unpack from einops
from einops import rearrange, reduce, repeat, pack, unpack
# the Rearrange layer from einops
from einops.layers.torch import Rearrange

# beartype for runtime type checking
from beartype import beartype
# Tuple and Union from beartype.typing
from beartype.typing import Tuple, Union

# Attend from MEGABYTE_pytorch.attend
from MEGABYTE_pytorch.attend import Attend

# tqdm progress bars
from tqdm import tqdm

# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# amount of padding needed to bring num up to a multiple of mult
def remainder_to_mult(num, mult):
    return (mult - num % mult) % mult
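
For example, a flattened sequence of length 10 needs 6 more padding tokens to reach a multiple of 16, which is how the forward pass below pads flattened inputs:

assert remainder_to_mult(10, 16) == 6
assert remainder_to_mult(32, 16) == 0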

# cast the input to a tuple, repeating it length times if it is not already a tuple
def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

# product of a sequence of numbers
def reduce_mult(nums):
    return functools.reduce(lambda x, y: x * y, nums, 1)

# tensor helpers

# log of a tensor, clamped to avoid values below eps
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# sample from logits with gumbel noise
def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

# keep only the top-k logits, setting the rest to -inf
def top_k(logits, thres = 0.5):
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# token shift, borrowed from Peng et al. of RWKV
def token_shift(t):
    t, t_shift = t.chunk(2, dim = -1)
    t_shift = F.pad(t_shift, (0, 0, 1, -1))
    return torch.cat((t, t_shift), dim = -1)
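
A quick standalone check of what token_shift does: the first half of the feature dimension passes through unchanged, while the second half is shifted one position back in time, zero padded at the first position:

t = torch.arange(2 * 3 * 4, dtype = torch.float).reshape(2, 3, 4)
shifted = token_shift(t)
assert torch.equal(shifted[..., :2], t[..., :2])           # first half untouched
assert torch.equal(shifted[:, 0, 2:], torch.zeros(2, 2))   # second half zeroed at the first position
assert torch.equal(shifted[:, 1:, 2:], t[:, :-1, 2:])      # and shifted by one position elsewhere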

# rotary positional embedding
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    @property
    def device(self):
        return next(self.buffers()).device

    def forward(self, seq_len):
        t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)
        return freqs

# rotate half of the feature dimension
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# apply rotary positional embedding
def apply_rotary_pos_emb(pos, t):
    return t * pos.cos() + rotate_half(t) * pos.sin()

# normalization
class RMSNorm(nn.Module):
    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
        return x / norm.clamp(min = self.eps) * self.g

# helper classes

# feedforward network constructor
def FeedForward(*, dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, dim * mult),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim)
    )

# attention
class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.attend = Attend(
            causal = True,
            flash = flash,
            dropout = dropout
        )

        self.dropout = nn.Dropout(dropout)
        self.norm = RMSNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)
    # forward pass of the attention block
    def forward(self, x, rotary_emb = None):
        # number of heads and device
        h, device = self.heads, x.device

        # normalize the input
        x = self.norm(x)
        # project the input x to queries q, keys k, values v
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
        # rearrange the queries to shape 'b h n d'
        q = rearrange(q, 'b n (h d) -> b h n d', h = h)

        # if rotary positional embeddings are given, apply them to the queries and keys
        if exists(rotary_emb):
            q, k = map(lambda t: apply_rotary_pos_emb(rotary_emb, t), (q, k))

        # compute attention
        out = self.attend(q, k, v)

        # rearrange the output back to shape 'b n (h d)'
        out = rearrange(out, 'b h n d -> b n (h d)')
        # final output projection
        return self.to_out(out)
# Transformer class
class Transformer(nn.Module):
    # constructor
    def __init__(
        self,
        *,
        dim,  # model dimension
        layers,  # number of layers
        dim_head = 64,  # dimension per head
        heads = 8,  # number of attention heads
        attn_dropout = 0.,  # attention dropout
        ff_dropout = 0.,  # feedforward dropout
        ff_mult = 4,  # feedforward expansion factor
        rel_pos = True,  # whether to use relative (rotary) positional embeddings
        flash_attn = False  # whether to use flash attention
    ):
        super().__init__()  # call the parent constructor
        self.rotary_emb = RotaryEmbedding(dim_head) if rel_pos else None  # rotary embedding if relative positions are used, otherwise None
        self.layers = nn.ModuleList([])  # empty ModuleList

        # build the requested number of attention / feedforward layer pairs
        for _ in range(layers):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn),  # attention block
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)  # feedforward block
            ]))

        self.norm = RMSNorm(dim)  # final RMS normalization

    # forward pass over input x
    def forward(self, x):
        n = x.shape[-2]  # sequence length of the input
        rotary_emb = self.rotary_emb(n) if exists(self.rotary_emb) else None  # rotary embeddings for n positions, if enabled

        # run each attention and feedforward layer with residual connections
        for attn, ff in self.layers:
            x = attn(token_shift(x), rotary_emb = rotary_emb) + x  # attention with token shift and residual
            x = ff(token_shift(x)) + x  # feedforward with token shift and residual

        return self.norm(x)  # return the normalized result
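
A minimal shape sketch for the Transformer defined above, assuming the surrounding definitions in this file are in scope:

transformer = Transformer(dim = 512, layers = 2)
x = torch.randn(1, 128, 512)
out = transformer(x)   # (1, 128, 512)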

# main MEGABYTE class
class MEGABYTE(nn.Module):

    @beartype
    # constructor
    def __init__(
        self,
        *,
        num_tokens,  # number of tokens
        dim: Union[Tuple, int],  # model dimension(s)
        depth: Tuple,  # depth per stage
        max_seq_len: Tuple,  # maximum sequence length per stage
        dim_head = 64,  # dimension per head
        heads = 8,  # number of attention heads
        attn_dropout = 0.,  # attention dropout
        ff_mult = 4,  # feedforward expansion factor
        ff_dropout = 0.,  # feedforward dropout
        pad_id = 0,  # id of the padding token
        rel_pos = False,  # whether to use relative (rotary) positional embeddings
        pos_emb = False,  # whether to use absolute positional embeddings
        flash_attn = False  # whether to use flash attention
    ):
        # call the parent constructor
        super().__init__()

        # per-stage configuration
        # depth = (2, 2, 4) means depth 2 for the first stage, 2 for the second, 4 for the third
        # max_seq_len = (16, 8, 4) means max sequence length 16 for the first stage, 8 for the second, 4 for the last
        assert isinstance(depth, tuple) and isinstance(max_seq_len, tuple)
        assert len(depth) == len(max_seq_len)

        self.stages = len(depth)
        dim = cast_tuple(dim, self.stages)

        assert len(dim) == self.stages

        coarsest_dim, *_, fine_dim = dim

        self.max_seq_len = max_seq_len

        # learned start tokens, one per stage
        self.start_tokens = nn.ParameterList([nn.Parameter(torch.randn(h_dim)) for h_dim, seq_len in zip(dim, max_seq_len)])
        # absolute positional embeddings, one per stage (optional)
        self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, h_dim) for h_dim, seq_len in zip(dim, max_seq_len)]) if pos_emb else None

        self.token_embs = nn.ModuleList([])

        patch_size = 1
        # add the token embedding for the finest stage
        self.token_embs.append(nn.Embedding(num_tokens, fine_dim))

        for dim_out, seq_len in zip(reversed(dim[:-1]), reversed(max_seq_len[1:])):
            patch_size *= seq_len

            # build the patched token embedding for this stage
            self.token_embs.append(nn.Sequential(
                nn.Embedding(num_tokens, fine_dim),
                Rearrange('... r d -> ... (r d)'),
                nn.LayerNorm(patch_size * fine_dim),
                nn.Linear(patch_size * fine_dim, dim_out),
                nn.LayerNorm(dim_out)
            ))

        self.transformers = nn.ModuleList([])
        self.to_next_transformer_projections = nn.ModuleList([])

        for h_dim, next_h_dim, stage_depth, next_seq_len in zip_longest(dim, dim[1:], depth, max_seq_len[1:]):
            # add the transformer for this stage
            self.transformers.append(Transformer(
                dim = h_dim,
                layers = stage_depth,
                dim_head = dim_head,
                heads = heads,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                ff_mult = ff_mult,
                rel_pos = rel_pos,
                flash_attn = flash_attn
            ))

            proj = nn.Identity()

            if exists(next_h_dim) and next_h_dim != dim:
                proj = nn.Sequential(
                    Rearrange('b ... d -> b (...) d'),
                    nn.Linear(h_dim, next_h_dim * next_seq_len),
                    Rearrange('b m (n d) -> (b m) n d', n = next_seq_len)
                )

            self.to_next_transformer_projections.append(proj)

        # linear layer for the output logits
        self.to_logits = nn.Linear(fine_dim, num_tokens)
        self.pad_id = pad_id

    # generate text
    def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1):
        total_seq_len = reduce_mult(self.max_seq_len)
        device = next(self.parameters()).device

        if not exists(prime):
            prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device)

        seq = prime
        batch = seq.shape[0]

        # autoregressively sample the token sequence
        for _ in tqdm(range(total_seq_len - seq.shape[-1])):
            logits = self.forward(seq)[:, -1]
            logits = top_k(logits, thres = filter_thres)
            sampled = gumbel_sample(logits, dim = -1, temperature = temperature)
            seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1)

        return seq.reshape(batch, *self.max_seq_len)
    # handle the special case of sampling from an empty input (start tokens only)
    def forward_empty(self, batch_size):
        # representation of the previous stage's tokens, starts as None
        prev_stage_tokens_repr = None

        # iterate over the start tokens, transformer and projection for each stage
        for stage_start_tokens, transformer, proj in zip(self.start_tokens, self.transformers, self.to_next_transformer_projections):
            # repeat the stage's start token across the batch
            tokens = repeat(stage_start_tokens, 'd -> b 1 d', b = batch_size)

            # if the previous stage's representation exists, add it to the current tokens
            if exists(prev_stage_tokens_repr):
                tokens = tokens + prev_stage_tokens_repr[..., :tokens.shape[-2], :]

            # run the tokens through this stage's transformer
            tokens = transformer(tokens)
            # project to the representation consumed by the next stage
            prev_stage_tokens_repr = proj(tokens)

        # project the tokens to logits
        return self.to_logits(tokens)
    # forward pass, taking the input ids and a flag for whether to return the loss
    def forward(self, ids, return_loss = False):
        # batch size
        batch = ids.shape[0]

        # the input ids must have either 2 dimensions or self.stages + 1 dimensions
        assert ids.ndim in {2, self.stages + 1}
        # whether the input is given with flattened dimensions
        flattened_dims = ids.ndim == 2
        ids_orig_ndim = ids.ndim

        # if ids is empty, fall back to forward_empty
        if ids.numel() == 0:
            return self.forward_empty(ids.shape[0])

        # if flattened, automatically pad to the nearest multiple of the depth sequence lengths
        if flattened_dims:
            # sequence length
            seq_len = ids.shape[-1]
            # amount of padding needed
            multiple_of = reduce_mult(self.max_seq_len[1:])
            padding = remainder_to_mult(seq_len, multiple_of)
            # pad the ids
            ids = F.pad(ids, (0, padding), value = self.pad_id)
            ids = ids.reshape(batch, -1, *self.max_seq_len[1:])

        # shape and device of the ids
        b, *prec_dims, device = *ids.shape, ids.device

        # check the dimensions

        assert prec_dims[0] <= self.max_seq_len[0], 'the first dimension of your axial autoregressive transformer must be less than the first tuple element of max_seq_len (like any autoregressive transformer)'
        assert tuple(prec_dims[1:]) == tuple(self.max_seq_len[1:]), 'all subsequent dimensions must match exactly'

        # get the tokens for all hierarchy stages, reducing the appropriate dimensions
        # and adding absolute positional embeddings

        tokens_at_stages = []
        pos_embs = default(self.pos_embs, (None,))

        for ind, pos_emb, token_emb in zip_longest(range(len(prec_dims)), pos_embs, self.token_embs):
            is_first = ind == 0

            tokens = token_emb(ids)

            if exists(pos_emb):
                positions = pos_emb(torch.arange(tokens.shape[-2], device = device))
                tokens = tokens + positions

            tokens_at_stages.insert(0, tokens)

            if is_first:
                continue

            ids = rearrange(ids, '... m n -> ... (m n)')

        # representation from the previous hierarchy stage, starts off as None

        prev_stage_tokens_repr = None

        # spatial tokens are the depth-reduced tokens plus spatial positions

        for stage_start_tokens, stage_tokens, transformer, proj in zip(self.start_tokens, tokens_at_stages, self.transformers, self.to_next_transformer_projections):
            stage_tokens, ps = pack_one(stage_tokens, '* n d')
            stage_start_tokens = repeat(stage_start_tokens, 'f -> b 1 f', b = stage_tokens.shape[0])

            # concat the start token

            stage_tokens = torch.cat((
                stage_start_tokens,
                stage_tokens,
            ), dim = -2)

            # sum in the representation from the previous hierarchy stage

            if exists(prev_stage_tokens_repr):
                prev_stage_tokens_repr = F.pad(prev_stage_tokens_repr, (0, 0, 1, 0), value = 0.)
                stage_tokens = stage_tokens + prev_stage_tokens_repr

            attended = transformer(stage_tokens)

            attended = unpack_one(attended, ps, '* n d')

            # project for the next stage in the hierarchy

            prev_stage_tokens_repr = proj(attended[..., :-1, :])

        # project to logits

        logits = self.to_logits(attended)

        start_tokens = logits[(slice(None), *((0,) * (logits.ndim - 2)), slice(None))]
        start_tokens = rearrange(start_tokens, 'b d -> b 1 d')

        logits = logits[..., 1:, :]

        if not return_loss:

            if flattened_dims:
                logits = rearrange(logits, 'b ... c -> b (...) c')
                logits = logits[:, :seq_len]

            return logits

        logits = rearrange(logits, 'b ... c -> b (...) c')
        logits = torch.cat((start_tokens, logits), dim = -2)

        preds = rearrange(logits, 'b n c -> b c n')
        labels = rearrange(ids, 'b ... -> b (...)')

        loss = F.cross_entropy(
            preds[..., :-1],
            labels,
            ignore_index = self.pad_id
        )

        return loss

.\lucidrains\MEGABYTE-pytorch\MEGABYTE_pytorch\__init__.py

# import the MEGABYTE class from the MEGABYTE_pytorch package
from MEGABYTE_pytorch.megabyte import MEGABYTE

MEGABYTE - Pytorch

Implementation of MEGABYTE, Predicting Million-byte Sequences with Multiscale Transformers, in Pytorch. Took the liberty to generalize it even further so one can have multiple local models.

Similar independent research that is a further generalization

Appreciation

  • Stability and 🤗 Huggingface for the generous sponsorship to work on and open source cutting edge artificial intelligence research

Install

$ pip install MEGABYTE-pytorch

Usage

import torch
from MEGABYTE_pytorch import MEGABYTE

model = MEGABYTE(
    num_tokens = 16000,             # number of tokens
    dim = (512, 256),               # transformer model dimension (512 for coarsest, 256 for fine in this example)
    max_seq_len = (1024, 4),        # sequence length for global and then local. this can be more than 2
    depth = (6, 4),                 # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
    dim_head = 64,                  # dimension per head
    heads = 8,                      # number of attention heads
    flash_attn = True               # use flash attention
)

x = torch.randint(0, 16000, (1, 1024, 4))

loss = model(x, return_loss = True)
loss.backward()

# then after much training

logits = model(x)

# and sample from the logits accordingly
# or you can use the generate function

sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)

Test

Train on character-level enwik8 with patches of size 4 - length 8192

$ python train.py

Citations

@misc{yu2023megabyte,
    title   = {MEGABYTE: Predicting Million-byte Sequences with Multiscale Transformers}, 
    author  = {Lili Yu and Dániel Simig and Colin Flaherty and Armen Aghajanyan and Luke Zettlemoyer and Mike Lewis},
    year    = {2023},
    eprint  = {2305.07185},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
@article{Kazemnejad2023TheIO,
    title   = {The Impact of Positional Encoding on Length Generalization in Transformers},
    author  = {Amirhossein Kazemnejad and Inkit Padhi and Karthikeyan Natesan Ramamurthy and Payel Das and Siva Reddy},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.19466}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

.\lucidrains\MEGABYTE-pytorch\setup.py

# import setuptools' setup and find_packages helpers
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'MEGABYTE-pytorch',  # package name
  packages = find_packages(),  # find all packages
  version = '0.2.1',  # version number
  license='MIT',  # license
  description = 'MEGABYTE - Pytorch',  # description
  long_description_content_type = 'text/markdown',  # long description content type
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/MEGABYTE-pytorch',  # project URL
  keywords = [  # keyword list
    'artificial intelligence',
    'attention mechanism',
    'transformers'
  ],
  install_requires=[  # install dependencies
    'beartype',
    'einops>=0.6.1',
    'torch>=1.10',
    'tqdm'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\MEGABYTE-pytorch\train.py

# import required libraries
from MEGABYTE_pytorch import MEGABYTE
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
PRIME_LEN = 100
SEQ_LEN = 8192

# helper functions
def cycle(loader):
    # loop over the dataloader forever
    while True:
        for data in loader:
            yield data

def decode_token(token):
    # decode a single token into a character
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    # decode a sequence of tokens into a string
    return ''.join(list(map(decode_token, tokens)))

# instantiate a GPT-like decoder model
model = MEGABYTE(
    num_tokens = 256,
    dim = (768, 512, 256),
    depth = (6, 4, 2),
    max_seq_len = (512, 4, 4),
    flash_attn = True
).cuda()

# prepare enwik8 data
with gzip.open('./data/enwik8.gz') as file:
    x = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    train_x, valid_x = np.split(x, [int(90e6)])
    data_train, data_val = map(torch.from_numpy, (train_x, valid_x))

# dataset class
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# create the training and validation sets
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training loop
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime_inp = inp[:PRIME_LEN]
        prime = decode_tokens(prime_inp)
        print('%s \n\n %s' % (prime, '*' * 100))

        sample = model.generate(prime_inp[None, :])
        sample = sample.flatten(1)

        output_str = decode_tokens(sample[0][PRIME_LEN:])
        print(output_str)
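
上面的训练脚本中,MEGABYTE 的 dim、depth、max_seq_len 都是元组,对应由粗到细的多级(分层)建模;SEQ_LEN 取为各级 max_seq_len 的乘积(512 * 4 * 4 = 8192)。下面给出一个最小示意(模型规模等超参数为假设值,仅用于说明这一对应关系和单步前向):

import torch
from MEGABYTE_pytorch import MEGABYTE

# 假设的小规模超参数,序列长度与各级 max_seq_len 的乘积保持一致
max_seq_len = (32, 4, 4)
seq_len = 32 * 4 * 4                          # 512

model = MEGABYTE(
    num_tokens = 256,
    dim = (128, 64, 32),
    depth = (2, 2, 2),
    max_seq_len = max_seq_len
)

tokens = torch.randint(0, 256, (1, seq_len))  # 随机字节序列,仅作演示
loss = model(tokens, return_loss = True)      # 与训练脚本相同,返回交叉熵损失
loss.backward()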

.\lucidrains\memformer\memformer\autoregressive_wrapper.py

# 导入必要的库
from functools import partial
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

# 定义函数,根据给定的概率阈值选择最高概率的元素
def top_p(logits, thres = 0.9):
    # 对logits进行降序排序
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    # 计算累积概率
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # 根据阈值确定要移除的元素
    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    # 将超过阈值的元素设置为负无穷
    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# 定义函数,根据给定的概率阈值选择最高的k个元素
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# 定义一个自回归的包装器类
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.max_seq_len = net.max_seq_len

    # 生成序列
    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        input_mask = kwargs.pop('input_mask', None)

        if input_mask is None:
            input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]
            input_mask = input_mask[:, -self.max_seq_len:]

            logits = self.net(x, input_mask=input_mask, **kwargs)[:, -1, :]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)
            input_mask = F.pad(input_mask, (0, 1), value=True)

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out

    # 前向传播
    def forward(self, x, return_loss = False, **kwargs):
        pad = partial(pad_sequence, batch_first = True, padding_value = self.pad_value)

        if not return_loss:
            if not isinstance(x, torch.Tensor):
                x = pad(x)
            return self.net(x, **kwargs)

        if isinstance(x, torch.Tensor):
            xi = x[:, :-1]
            xo = x[:, 1:]

            mask = kwargs.pop('src_mask', None)
            if mask is not None and mask.shape[1] == x.shape[1]:
                mask = mask[:, :-1]
                kwargs.update(src_mask = mask)
        else:
            xi = pad(list(map(lambda t: t[:-1], x)))
            xo = pad(list(map(lambda t: t[1:], x)))

        out = self.net(xi, **kwargs)

        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss
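
作为补充,下面用一个玩具 logits 张量演示上面 top_k 过滤的行为(数值为假设,top_k 沿用本文件中的定义):thres = 0.9 时只保留前 10% 的 logits,其余位置被置为 -inf,softmax 之后概率为 0,因此采样只会落在被保留的位置上。

import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 3.0, 0.5, 2.0, -1.0, 0.0, 0.2, 0.1, 0.3, 0.4]])

filtered = top_k(logits, thres = 0.9)   # 词表大小 10,k = int(0.1 * 10) = 1,只保留索引 1 处的 3.0
probs = F.softmax(filtered, dim = -1)   # 索引 1 的概率为 1,其余为 0
sample = torch.multinomial(probs, 1)    # 必然采样到索引 1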

.\lucidrains\memformer\memformer\memformer.py

# 导入数学库和 PyTorch 库
import math
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 functools 库中导入 partial 函数
from functools import partial
# 从 torch.nn.functional 库中导入 F
import torch.nn.functional as F
# 从 inspect 库中导入 isfunction 函数
from inspect import isfunction
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 collections 库中导入 namedtuple 类
from collections import namedtuple
# 从 memformer.autoregressive_wrapper 模块中导入 AutoregressiveWrapper 类
from memformer.autoregressive_wrapper import AutoregressiveWrapper

# 常量

# 创建一个名为 Results 的命名元组,包含 enc_out、mem 和 dec_out 三个字段
Results = namedtuple('Results', ['enc_out', 'mem', 'dec_out'])
# 创建一个名为 EncOnlyResults 的命名元组,包含 enc_out 和 mem 两个字段
EncOnlyResults = namedtuple('EncOnlyResults', ['enc_out', 'mem'])

# 辅助函数

# 判断值是否存在
def exists(val):
    return val is not None

# 返回值或默认值
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# 返回张量的最大负值
def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

# 关键字参数辅助函数

# 从字典中选择指定键的值并弹出这些键
def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key, None), keys))
    return dict(zip(keys, values))

# 根据条件将字典分组
def group_dict_by_key(cond, d):
    return_val = [dict(),dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# 判断字符串是否以指定前缀开头
def string_begins_with(prefix, str):
    return str.startswith(prefix)

# 根据前缀将字典分组
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

# 根据前缀将字典分组并去除前缀
def group_by_key_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

# 辅助类

# 带残差连接的模块
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# 带预层归一化的模块
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# 位置嵌入

# 相对位置偏置模块
class RelativePositionBias(nn.Module):
    def __init__(self, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
        super().__init__()
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        ret = 0
        n = -relative_position
        if causal:
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, qlen, klen):
        device = self.relative_attention_bias.weight.device
        q_pos = torch.arange(qlen, dtype = torch.long, device = device)
        k_pos = torch.arange(klen, dtype = torch.long, device = device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets)
        values = self.relative_attention_bias(rp_bucket)
        return rearrange(values, 'i j h -> () h i j')

# 主要类

# 前馈神经网络模块
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

# 注意力模块
class Attention(nn.Module):
    # 初始化函数,设置模型参数
    def __init__(self, dim, heads = 8, causal = False, rel_pos_emb = False):
        # 调用父类的初始化函数
        super().__init__()
        # 确保维度可以被头数整除
        assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
        # 计算每个头的维度
        dim_head = dim // heads
        # 缩放因子
        self.scale = dim_head ** -0.5
        # 头数
        self.heads = heads
        # 是否使用自回归
        self.causal = causal

        # 线性变换,将输入转换为查询向量
        self.to_q = nn.Linear(dim, dim)
        # 线性变换,将输入转换为键值对
        self.to_kv = nn.Linear(dim, dim * 2)
        # 线性变换,将输出转换为最终结果
        self.to_out = nn.Linear(dim, dim)

    # 前向传播函数
    def forward(self, x, context = None, pos_emb = None, mask = None, query_mask = None, kv_mask = None, attend_self = False):
        # 获取输入张量的形状和设备信息
        b, n, _, h, scale, device = *x.shape, self.heads, self.scale, x.device

        # 如果需要自注意力机制
        if attend_self:
            # 将输入和上下文拼接在一起
            kv_input = torch.cat((x, context), dim = 1)
        else:
            # 否则使用默认的上下文
            kv_input = default(context, x)

        # 计算查询向量
        q = self.to_q(x)
        # 计算键值对
        kv = self.to_kv(kv_input).chunk(2, dim = -1)

        # 重排查询、键、值张量的形状
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv))
        # 计算点积注意力
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale

        # 如果存在位置编码
        if exists(pos_emb):
            # 添加位置编码偏置
            pos_emb_bias = pos_emb(*dots.shape[-2:])
            dots += pos_emb_bias

        # 设置掩码值为最大负值
        mask_value = max_neg_value(dots)

        # 如果是自回归模型
        if self.causal:
            # 创建自回归掩码
            causal_mask = torch.ones((n, n), device = device).triu_(1).bool()
            dots.masked_fill_(causal_mask, mask_value)
            del causal_mask

        # 如果存在查询掩码或键值掩码
        if any(map(exists, (query_mask, kv_mask))):
            # 默认查询掩码为全 1
            query_mask = default(query_mask, lambda: torch.ones((b, n), device = device).bool())

            # 如果存在上下文
            if exists(context):
                # 默认键值掩码为全 1
                kv_mask = default(kv_mask, lambda: torch.ones((b, context.shape[1]), device = device).bool())
            else:
                kv_mask = default(kv_mask, query_mask)

            # 重排查询掩码和键值掩码的形状
            query_mask = rearrange(query_mask, 'b i -> b () i ()')
            kv_mask = rearrange(kv_mask, 'b j -> b () () j')
            seq_mask = query_mask * kv_mask
            dots.masked_fill_(~seq_mask, mask_value)
            del seq_mask

        # 如果存在额外掩码
        if exists(mask):
            mask = rearrange(mask, 'b i j -> b () i j')
            dots.masked_fill_(~mask, mask_value)
            del mask

        # 计算注意力权重
        attn = dots.softmax(dim = -1)
        # 计算输出
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # 重排输出形状
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class Encoder(nn.Module):
    # 编码器类,包含初始化函数
    def __init__(self, dim, depth, heads = 8):
        super().__init__()
        # 初始化相对位置偏置
        self.rel_pos_emb = RelativePositionBias(heads = heads)
        # 初始化层列表
        self.layers = nn.ModuleList([])
        # 循环创建指定数量的层
        for _ in range(depth):
            # 向层列表中添加编码器层
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, rel_pos_emb = True))),
                Residual(PreNorm(dim, Attention(dim, heads = heads))),
                Residual(PreNorm(dim, FeedForward(dim)))
            ]))
    # 前向传播函数
    def forward(self, x, context = None, src_mask = None):
        # 遍历编码器层
        for (self_attn, cross_attn, ff) in self.layers:
            # 自注意力机制
            x = self_attn(x, pos_emb = self.rel_pos_emb, query_mask = src_mask)
            # 交叉注意力机制
            x = cross_attn(x, context = context)
            # 前馈神经网络
            x = ff(x)
        return x

class Decoder(nn.Module):
    # 解码器类,包含初始化函数
    def __init__(self, dim, depth, heads = 8):
        super().__init__()
        # 初始化相对位置偏置
        self.rel_pos_emb = RelativePositionBias(heads = heads, causal = True)
        # 初始化层列表
        self.layers = nn.ModuleList([])
        # 循环创建指定数量的层
        for _ in range(depth):
            # 向层列表中添加解码器层
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, causal = True, rel_pos_emb = True))),
                Residual(PreNorm(dim, Attention(dim, heads = heads))),
                Residual(PreNorm(dim, FeedForward(dim))),
            ]))
    # 前向传播函数
    def forward(self, x, context = None, src_mask = None, tgt_mask = None):
        # 遍历解码器层
        for (self_attn, cross_attn, ff) in self.layers:
            # 自注意力机制
            x = self_attn(x, pos_emb = self.rel_pos_emb, query_mask = src_mask)
            # 交叉注意力机制
            x = cross_attn(x, context = context, query_mask = src_mask, kv_mask = tgt_mask)
            # 前馈神经网络
            x = ff(x)
        return x

class TransformerWrapper(nn.Module):
    # 转换器包装器类,包含初始化函数
    def __init__(self, *, num_tokens, max_seq_len, dim, layer_blocks, heads = 8, return_logits = True):
        super().__init__()
        # 初始化标记嵌入
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.max_seq_len = max_seq_len
        self.layer_blocks = layer_blocks
        self.norm = nn.LayerNorm(dim)
        self.to_logits = nn.Linear(dim, num_tokens) if return_logits else nn.Identity()

    # 前向传播函数
    def forward(self, x, **kwargs):
        _, n, device = *x.shape, x.device
        # 标记嵌入
        x = self.token_emb(x)
        # 层块
        x = self.layer_blocks(x, **kwargs)
        x = self.norm(x)
        return self.to_logits(x)

class Memformer(nn.Module):
    # 记忆形式类,包含初始化函数
    def __init__(
        self,
        *,
        dim,
        num_memory_slots,
        num_mem_updates = 1,
        encoder_only = False,
        mem_update_attn_heads = 8,
        **kwargs):
        super().__init__()
        # 分组关键字参数
        enc_kwargs, kwargs = group_by_key_prefix_and_trim('enc_', kwargs)
        dec_kwargs, kwargs = group_by_key_prefix_and_trim('dec_', kwargs)
        assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
        enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
        dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)

        # 初始化编码器
        self.encoder = TransformerWrapper(
            dim = dim,
            layer_blocks = Encoder(dim = dim, **enc_kwargs),
            return_logits = False,
            **enc_transformer_kwargs
        )

        # 初始化解码器
        self.decoder = TransformerWrapper(
            dim = dim,
            layer_blocks = Decoder(dim = dim, **dec_kwargs),
            return_logits = True,
            **dec_transformer_kwargs
        ) if not encoder_only else None

        if exists(self.decoder):
            self.decoder = AutoregressiveWrapper(self.decoder)

        self.num_mem = num_memory_slots
        self.memory_slots = nn.Parameter(torch.randn(num_memory_slots, dim))

        self.num_mem_updates = num_mem_updates
        self.mem_updater = Attention(dim, heads = mem_update_attn_heads)
        self.gru = nn.GRUCell(dim, dim)
        self.mem_ff = Residual(PreNorm(dim, FeedForward(dim)))
    # 获取初始记忆,将记忆槽复制多份以适应批处理大小
    def get_initial_mem(self, batch_size):
        return repeat(self.memory_slots, 'n d -> b n d', b = batch_size)

    # 前向传播函数,接收源数据、目标数据、记忆、源数据掩码、目标数据掩码等参数
    def forward(self, src, tgt = None, mems = None, src_mask = None, tgt_mask = None):
        # 获取源数据的形状信息
        b, n, num_mem, device = *src.shape, self.num_mem, src.device
        # 如果没有传入记忆,则使用默认的初始记忆
        mems = default(mems, lambda: self.get_initial_mem(b))

        # 编码器处理源数据和记忆,生成编码结果
        enc = self.encoder(src, context = mems, src_mask = src_mask)

        # 如果存在解码器和目标数据,则进行解码操作
        if exists(self.decoder) and exists(tgt):
            dec_out = self.decoder(tgt, context = enc, src_mask = tgt_mask, tgt_mask = src_mask, return_loss = True)
        else:
            # 否则创建一个梯度可求的张量作为占位符
            dec_out = torch.tensor(0., requires_grad = True, device = device)

        # 更新记忆,使用注意力机制
        mem_mask = torch.eye(num_mem, num_mem, device = device).bool()
        mem_mask = repeat(mem_mask, 'i j -> b i j', b = b)
        mem_mask = F.pad(mem_mask, (0, n), value = True)

        # 如果存在源数据掩码,则将其与记忆掩码相结合
        if exists(src_mask):
            src_mask = rearrange(src_mask, 'b j -> b () j')
            mem_enc_mask = F.pad(src_mask, (num_mem, 0), value = True)
            mem_mask &= mem_enc_mask

        # 多次更新记忆
        for _ in range(self.num_mem_updates):
            prev_mems = mems
            updated_mems = self.mem_updater(mems, enc, mask = mem_mask, attend_self = True)

            next_mems = self.gru(
                rearrange(updated_mems, 'b n d -> (b n) d'),
                rearrange(prev_mems, 'b n d -> (b n) d')
            )

            mems = rearrange(next_mems, '(b n) d -> b n d', b = b)
            mems = self.mem_ff(mems)

        # 如果没有解码器,则返回编码结果和记忆
        if not exists(self.decoder):
            return EncOnlyResults(enc, mems)

        # 否则返回编码结果、记忆和解码结果
        return Results(enc, mems, dec_out)
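
Memformer 的构造函数依赖上面的 group_by_key_prefix_and_trim,把带 enc_ / dec_ 前缀的关键字参数拆开并去掉前缀,分别路由给编码器和解码器。下面是一个最小示意(输入为假设值):

kwargs = dict(enc_num_tokens = 256, enc_depth = 2, dec_depth = 2, num_memory_slots = 128)

enc_kwargs, rest = group_by_key_prefix_and_trim('enc_', kwargs)
# enc_kwargs == {'num_tokens': 256, 'depth': 2}
# rest       == {'dec_depth': 2, 'num_memory_slots': 128}

dec_kwargs, rest = group_by_key_prefix_and_trim('dec_', rest)
# dec_kwargs == {'depth': 2}
# rest       == {'num_memory_slots': 128}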

.\lucidrains\memformer\memformer\mrbp.py

# 导入 torch 库
import torch
# 从 operator 库中导入 itemgetter 函数
from operator import itemgetter

# 定义内存回放反向传播函数,接受模型、源数据、目标数据、源数据掩码和目标数据掩码作为参数
def memory_replay_backprop(
    model,
    src,
    tgt,
    src_mask = None,
    tgt_mask = None
):
    # 获取源数据的 batch 大小
    b, *_ = src.shape

    # 从编码器获取初始内存和最大序列长度
    mem_init = model.get_initial_mem(b)
    max_seq_len = model.encoder.max_seq_len

    # 实例化内存回放缓冲区
    replay_buffer = [mem_init]

    # 拆分序列和掩码
    src_segs = src.split(max_seq_len, dim = 1)
    num_segs = len(src_segs)
    src_mask_segs = src_mask.split(max_seq_len, dim = 1) if src_mask is not None else ((None,) * num_segs)

    # 目前假设目标序列和掩码在最后一个段中传递
    # 待办事项 - 允许在任何段中连接目标序列
    # 并将自定义损失附加到编码器输出
    tgt_segs = ((None,) * (num_segs - 1)) + (tgt,)
    tgt_mask_segs = ((None,) * (num_segs - 1)) + (tgt_mask,)

    # 运行前向传播并收集所有内存
    prev_mem = mem_init
    with torch.no_grad():
        for i in range(num_segs - 1):
            src, src_mask = map(itemgetter(i), (src_segs, src_mask_segs))
            _, mem, _ = model(src, src_mask = src_mask, mems = prev_mem)
            replay_buffer.append(mem)
            prev_mem = mem

    # 逐个段进行反向传播,从最后一步到第一步
    mem_grad = torch.zeros_like(prev_mem)
    for i in reversed(range(num_segs)):
        src, src_mask, tgt, tgt_mask, mems = map(itemgetter(i), (src_segs, src_mask_segs, tgt_segs, tgt_mask_segs, replay_buffer))
        mems = mems.requires_grad_()

        _, mems_next, tgt_loss = model(src = src, tgt = tgt, src_mask = src_mask, tgt_mask = tgt_mask, mems = mems)
        tgt_loss.backward(retain_graph = True)
        mems_next.backward(mem_grad, retain_graph = True)

        # 如果还没有回传到第一个(最早的)段,则保存当前段输入记忆的梯度,供更早一段的输出记忆继续回传
        if i != 0:
            mem_grad.copy_(mems.grad.data)
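
memory_replay_backprop 先在 no_grad 下逐段前向、缓存每段的记忆,再从最后一段开始逐段反向传播,把记忆的梯度一段一段传回更早的段,从而避免为整条长序列保留计算图。下面是一个最小调用示意(模型与优化器配置为假设值),梯度由该函数直接累积到参数上,调用方只需执行优化器步骤:

import torch
from memformer import Memformer, memory_replay_backprop

model = Memformer(
    dim = 512,
    num_memory_slots = 128,
    enc_num_tokens = 256,
    enc_depth = 2,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 2,
    dec_max_seq_len = 1024
)

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

src = torch.randint(0, 256, (1, 4096))   # 会按 enc_max_seq_len 自动切成 4 段
tgt = torch.randint(0, 256, (1, 512))

memory_replay_backprop(model, src = src, tgt = tgt)  # 各段的梯度已累积到参数的 .grad
optim.step()
optim.zero_grad()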

.\lucidrains\memformer\memformer\__init__.py

# 从 memformer 包中导入 Memformer 类
from memformer.memformer import Memformer
# 从 memformer 包中导入 memory_replay_backprop 函数
from memformer.mrbp import memory_replay_backprop

Memformer - Pytorch

Implementation of Memformer, a Memory-augmented Transformer, in Pytorch. It includes memory slots, which are updated with attention, learned efficiently through Memory-Replay BackPropagation (MRBP) through time.

Install

$ pip install memformer

Usage

Full encoder / decoder, as in the paper

import torch
from memformer import Memformer

model = Memformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 2,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 2,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    num_memory_slots = 128
)

src_seg_1 = torch.randint(0, 256, (1, 1024))
src_seg_2 = torch.randint(0, 256, (1, 1024))
src_seg_3 = torch.randint(0, 256, (1, 1024))

tgt = torch.randint(0, 256, (1, 1024))

enc_out1, mems1,    _ = model(src_seg_1) # (1, 1024, 512), (1, 128, 512), _
enc_out2, mems2,    _ = model(src_seg_2, mems = mems1)
enc_out3, mems3, loss = model(src_seg_3, tgt, mems = mems2)

loss.backward()

Encoder only

import torch
from memformer import Memformer

model = Memformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_heads = 8,
    enc_depth = 2,
    enc_max_seq_len = 1024,
    num_memory_slots = 128,
    num_mem_updates = 2,
    encoder_only = True       # only use encoder, in which output is encoded output
)

src1 = torch.randint(0, 256, (1, 1024))
src2 = torch.randint(0, 256, (1, 1024))

enc1, mems1 = model(src1) # (1, 1024, 512), (1, 128, 512)
enc2, mems2 = model(src2, mems = mems1)

Memory Replay Back-Propagation

import torch
from memformer import Memformer, memory_replay_backprop

model = Memformer(
    dim = 512,
    num_memory_slots = 128,
    enc_num_tokens = 256,
    enc_depth = 2,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 2,
    dec_max_seq_len = 1024
).cuda()

seq = torch.randint(0, 256, (1, 8192)).cuda()
seq_mask = torch.ones_like(seq).bool().cuda()

tgt = torch.randint(0, 256, (1, 512)).cuda()
tgt_mask = torch.ones_like(tgt).bool().cuda()

# will automatically split the source sequence to 8 segments
memory_replay_backprop(
    model,
    src = seq,
    tgt = tgt,
    src_mask = seq_mask,
    tgt_mask = tgt_mask
)

Citations

@inproceedings{
    anonymous2021memformer,
    title={Memformer: The Memory-Augmented Transformer},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=_adSMszz_g9},
    note={under review}
}

.\lucidrains\memformer\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'memformer',  # 包的名称
  packages = find_packages(exclude=['examples']),  # 查找并包含除了 examples 之外的所有包
  version = '0.3.1',  # 版本号
  license='MIT',  # 许可证信息
  description = 'Memformer - Pytorch',  # 描述信息
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/memformer',  # 项目链接
  keywords = [  # 关键词列表
    'artificial intelligence',
    'attention mechanism',
    'transformers',
    'memory'
  ],
  install_requires=[  # 安装依赖列表
    'torch>=1.6',
    'einops>=0.3'
  ],
  classifiers=[  # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\memorizing-transformers-pytorch\memorizing_transformers_pytorch\knn_memory.py

# 导入必要的库
import os
import math
import torch
import faiss
import numpy as np
from pathlib import Path
from functools import wraps

# 导入上下文管理器相关的库
from contextlib import ExitStack, contextmanager

# 导入 einops 库
from einops import rearrange, pack, unpack

# 导入 multiprocessing 相关库
from joblib import Parallel, delayed, cpu_count

# 定义常量
FAISS_INDEX_GPU_ID = int(os.getenv('FAISS_INDEX_GPU_ID', 0))
DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY = './.tmp/knn.memories'

# 定义一些辅助函数

# 检查变量是否存在
def exists(val):
    return val is not None

# 如果变量存在则返回该变量,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 将变量转换为列表
def cast_list(val):
    return val if isinstance(val, list) else [val]

# 检查数组中的元素是否全部唯一
def all_el_unique(arr):
    return len(set(arr)) == len(arr)

# 定义一个多上下文管理器
@contextmanager
def multi_context(*cms):
    with ExitStack() as stack:
        yield [stack.enter_context(cls) for cls in cms]

# 计算两个数组的交集
def count_intersect(x, y):
    return np.sum(rearrange(x, 'i -> i 1') == rearrange(y, 'j -> 1 j'), axis = -1)

# 检查张量的形状是否符合指定的模式
def check_shape(tensor, pattern, **kwargs):
    return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs)

# 定义一个 KNN 类,封装了 faiss IndexIVFFlat,并自动处理旧键的过期
class KNN():
    def __init__(
        self,
        dim,
        max_num_entries,
        cap_num_entries = False,
        M = 15,
        keep_stats = False
    ):
        index = faiss.IndexHNSWFlat(dim, M, faiss.METRIC_INNER_PRODUCT)
        self.index = index
        self.max_num_entries = max_num_entries
        self.cap_num_entries = cap_num_entries
        self.is_trained = False
        self.keep_stats = keep_stats

        self.reset()

    def __del__(self):
        if hasattr(self, 'index'):
            del self.index

    def reset(self):
        self.ids = np.empty((0,), dtype = np.int32)

        if self.keep_stats:
            self.hits = np.empty((0,), dtype = np.int32)
            self.age_num_iterations = np.empty((0,), dtype = np.int32)
            self.ages_since_last_hit = np.empty((0,), dtype = np.int32)

        self.index.reset()
        self.is_trained = False

    def train(self, x):
        self.index.train(x)
        self.is_trained = True

    def add(self, x, ids):
        if not self.is_trained:
            self.train(x)

        self.ids = np.concatenate((ids, self.ids))

        if self.keep_stats:
            self.hits = np.concatenate((np.zeros_like(ids), self.hits))
            self.age_num_iterations = np.concatenate((np.zeros_like(ids), self.age_num_iterations))
            self.ages_since_last_hit = np.concatenate((np.zeros_like(ids), self.ages_since_last_hit))

        if self.cap_num_entries and len(self.ids) > self.max_num_entries:
            self.reset()

        return self.index.add(x)

    def search(
        self,
        x,
        topk,
        nprobe = 8,
        return_distances = False,
        increment_hits = False,
        increment_age = True
    ):
        if not self.is_trained:
            return np.full((x.shape[0], topk), -1)

        distances, indices = self.index.search(x, k = topk)

        if increment_hits and self.keep_stats:
            hits = count_intersect(self.ids, rearrange(indices, '... -> (...)'))
            self.hits += hits

            self.ages_since_last_hit += 1
            self.ages_since_last_hit *= (hits == 0)

        if increment_age and self.keep_stats:
            self.age_num_iterations += 1

        if return_distances:
            return indices, distances

        return indices

# 定义一个 KNNMemory 类,用于存储键/值记忆,可以自动处理一组 faiss 索引(跨批次维度)
class KNNMemory():
    def __init__(
        self,
        dim,
        max_memories = 16000,
        num_indices = 1,
        memmap_filename = './knn.memory.memmap',
        multiprocessing = True
    # 初始化方法,设置对象的维度、索引数量、索引范围、最大内存、形状、数据库偏移量等属性
    ):
        self.dim = dim
        self.num_indices = num_indices
        self.scoped_indices = list(range(num_indices))

        self.max_memories = max_memories
        self.shape = (num_indices, max_memories, 2, dim)
        self.db_offsets = np.zeros(num_indices, dtype = np.int32)

        # 创建一个内存映射对象,用于存储数据
        self.db = np.memmap(memmap_filename, mode = 'w+', dtype = np.float32, shape = self.shape)
        # 创建一个 KNN 对象列表
        self.knns = [KNN(dim = dim, max_num_entries = max_memories, cap_num_entries = True) for _ in range(num_indices)]
    
        # 根据是否使用多进程设置并行任务数
        self.n_jobs = cpu_count() if multiprocessing else 1

    # 设置作用域索引
    def set_scoped_indices(self, indices):
        indices = list(indices)
        # 检查索引是否唯一
        assert all_el_unique(indices), f'all scoped batch indices must be unique, received: {indices}'
        # 检查索引范围是否在有效范围内
        assert all([0 <= i < self.num_indices for i in indices]), f'each batch index must be between 0 and less than {self.num_indices}: received {indices}'
        self.scoped_indices = indices

    # 上下文管理器,用于设置作用域索引
    @contextmanager
    def at_batch_indices(self, indices):
        prev_indices = self.scoped_indices
        self.set_scoped_indices(indices)
        yield self
        self.set_scoped_indices(prev_indices)

    # 清空指定批次的数据
    def clear(self, batch_indices = None):
        if not exists(batch_indices):
            batch_indices = list(range(self.num_indices))

        batch_indices = cast_list(batch_indices)

        # 重置指定批次的 KNN 对象
        for index in batch_indices:
            knn = self.knns[index]
            knn.reset()

        self.db_offsets[batch_indices] = 0

    # 添加新的记忆数据
    def add(self, memories):
        # 检查记忆数据的形状
        check_shape(memories, 'b n kv d', d = self.dim, kv = 2, b = len(self.scoped_indices))

        memories = memories.detach().cpu().numpy()
        memories = memories[:, -self.max_memories:]
        num_memories = memories.shape[1]

        knn_insert_ids = np.arange(num_memories)

        keys = np.ascontiguousarray(memories[..., 0, :])
        knns = [self.knns[i] for i in self.scoped_indices]
        db_offsets = [self.db_offsets[i] for i in self.scoped_indices]

        # 使用 joblib 将新的键/值记忆插入到 faiss 索引中

        @delayed
        def knn_add(knn, key, db_offset):
            knn.add(key, ids = knn_insert_ids + db_offset)
            return knn

        updated_knns = Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))
        for knn_idx, scoped_idx in enumerate(self.scoped_indices):
            self.knns[scoped_idx] = updated_knns[knn_idx]

        # 将新的记忆数据添加到内存映射的数据库中

        add_indices = (rearrange(np.arange(num_memories), 'j -> 1 j') + rearrange(self.db_offsets[list(self.scoped_indices)], 'i -> i 1')) % self.max_memories
        self.db[rearrange(np.array(self.scoped_indices), 'i -> i 1'), add_indices] = memories
        self.db.flush()

        self.db_offsets += num_memories

    # 搜索方法,用于查询最近邻
    def search(
        self,
        queries,
        topk,
        nprobe = 8,
        increment_hits = True,
        increment_age = True
        ):
        # 检查查询数据的形状是否符合要求
        check_shape(queries, 'b ... d', d = self.dim, b = len(self.scoped_indices))
        # 将查询数据打包成指定格式
        queries, ps = pack([queries], 'b * d')

        # 获取查询数据的设备信息
        device = queries.device
        # 将查询数据转换为 numpy 数组
        queries = queries.detach().cpu().numpy()

        # 初始化空列表用于存储掩码和键值对
        all_masks = []
        all_key_values = []

        # 获取指定索引处的 knn 对象
        knns = [self.knns[i] for i in self.scoped_indices]

        # 并行化 faiss 搜索

        @delayed
        def knn_search(knn, query):
            return knn.search(query, topk, nprobe, increment_hits = increment_hits, increment_age = increment_age)

        # 并行执行 knn_search 函数,获取搜索结果
        fetched_indices = Parallel(n_jobs = self.n_jobs)(knn_search(*args) for args in zip(knns, queries))

        # 从内存映射 'database' 中获取所有的键/值对
        # 待办事项 - 移除下面的 for 循环

        for batch_index, indices in zip(self.scoped_indices, fetched_indices):
            # 创建掩码,将无效索引替换为 0
            mask = indices !=  -1
            db_indices = np.where(mask, indices, 0)

            # 将掩码转换为 PyTorch 张量并添加到列表中
            all_masks.append(torch.from_numpy(mask))

            # 获取键值对并添加到列表中
            key_values = self.db[batch_index, db_indices % self.max_memories]
            all_key_values.append(torch.from_numpy(key_values))

        # 将所有掩码和键值对堆叠成张量
        all_masks = torch.stack(all_masks)
        all_key_values = torch.stack(all_key_values)
        # 使用掩码填充键值对中的无效值为 0
        all_key_values = all_key_values.masked_fill(~rearrange(all_masks, '... -> ... 1 1'), 0.)

        # 拆分键值对张量
        all_key_values, = unpack(all_key_values, ps, 'b * n kv d')
        all_masks, = unpack(all_masks, ps, 'b * n')

        # 返回结果并将其发送到指定设备
        return all_key_values.to(device), all_masks.to(device)

    def __del__(self):
        # 在对象销毁时,删除 knns 和 db 属性
        if hasattr(self, 'knns'):
            for knn in self.knns:
                del knn
        del self.db
# 为 KNN 记忆集合扩展了一些额外的方法

class KNNMemoryList(list):
    # 清理方法,用于清理所有记忆
    def cleanup(self):
        for memory in self:
            del memory

    # 创建记忆方法,用于创建多个记忆对象
    @classmethod
    def create_memories(
        self,
        *,
        batch_size,
        num_memory_layers,
        memories_directory = DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY
    ):
        # 设置记忆路径
        memories_path = Path(memories_directory)
        memories_path.mkdir(exist_ok = True, parents = True)

        # 内部方法,用于创建多个记忆对象
        def inner(*args, **kwargs):
            return self([KNNMemory(*args, num_indices = batch_size, memmap_filename = str(memories_path / f'knn.memory.layer.{ind + 1}.memmap'), **kwargs) for ind in range(num_memory_layers)])
        return inner

    # 批量索引上下文管理器,用于在多个记忆对象上进行批量索引
    @contextmanager
    def at_batch_indices(
        self,
        indices
    ):
        knn_batch_indices_contexts = [memory.at_batch_indices(indices) for memory in self]
        with multi_context(*knn_batch_indices_contexts):
            yield

    # 清除记忆方法,用于清除指定的记忆对象
    def clear_memory(
        self,
        batch_indices = None,
        memory_indices = None
    ):
        # 默认情况下清除所有记忆对象
        memory_indices = default(memory_indices, tuple(range(len(self))))

        # 遍历指定的记忆对象,清除指定的批次索引
        for memory_index in memory_indices:
            memory = self[memory_index]
            memory.clear(batch_indices)
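
KNNMemoryList.create_memories 采用两段式调用:先固定 batch 大小、记忆层数和 memmap 目录,返回一个工厂函数,再传入每层 KNNMemory 的参数完成实例化。下面是一个最小示意(参数为假设值):

import torch

create = KNNMemoryList.create_memories(
    batch_size = 2,
    num_memory_layers = 2,
    memories_directory = './.tmp/knn.memories'
)

memories = create(dim = 64, max_memories = 2048)    # 返回含 2 个 KNNMemory 的 KNNMemoryList

for layer_memory in memories:
    layer_memory.add(torch.randn(2, 512, 2, 64))    # (batch, seq, key | value, dim)

key_values, mask = memories[0].search(torch.randn(2, 512, 64), topk = 32)
# key_values: (2, 512, 32, 2, 64),mask: (2, 512, 32)

memories.cleanup()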

.\lucidrains\memorizing-transformers-pytorch\memorizing_transformers_pytorch\memorizing_transformers_pytorch.py

# 导入数学库
import math
# 从 functools 模块导入 partial 函数
from functools import partial
# 从 contextlib 模块导入 contextmanager 上下文管理器
from contextlib import contextmanager
# 从 pathlib 模块导入 Path 类
from pathlib import Path
# 从 filelock 模块导入 FileLock 类
from filelock import FileLock

# 导入 torch 库
import torch
# 从 torch.nn.functional 中导入 F 模块
import torch.nn.functional as F
# 从 torch 中导入 nn 模块和 einsum 函数
from torch import nn, einsum

# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 einops_exts 库中导入 repeat_many 函数
from einops_exts import repeat_many
# 从 einops.layers.torch 中导入 Rearrange 类
from einops.layers.torch import Rearrange

# 从 memorizing_transformers_pytorch.knn_memory 模块中导入 KNNMemoryList 类和 DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY 常量
from memorizing_transformers_pytorch.knn_memory import KNNMemoryList, DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY

# 辅助函数

# 定义一个返回输入的函数
def identity(t):
    return t

# 判断输入是否存在的函数
def exists(val):
    return val is not None

# 返回输入列表中唯一元素的函数
def unique(arr):
    return list({el: True for el in arr}.keys())

# 返回输入值或默认值的函数
def default(val, d):
    return val if exists(val) else d

# 将输入值转换为元组的函数
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# 对输入张量进行 L2 归一化的函数
def l2norm(t):
    return F.normalize(t, dim = -1)

# 辅助类

# 实现预层归一化残差连接的类
class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        out = self.fn(self.norm(x), **kwargs)

        if not isinstance(out, tuple):
            return out + x

        head, *tail = out
        return (head + x, *tail)

# T5 相对位置偏置类

class T5RelativePositionBias(nn.Module):
    def __init__(
        self,
        scale,
        num_buckets = 32,
        max_distance = 128,
        heads = 8
    ):
        super().__init__()
        self.scale = scale
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        num_buckets = 32,
        max_distance = 128
    ):
        n = -relative_position
        n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
        return torch.where(is_small, n, val_if_large)

    def forward(self, i, j, *, device):
        q_pos = torch.arange(i, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j h -> () h i j')
        return bias * self.scale

# 前馈网络类

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

# 注意力机制类

class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        xl_max_memories = 0.,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        self.xl_max_memories = xl_max_memories

        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
    # 定义一个前向传播函数,接受输入 x,可选的 xl_memory 和 rel_pos_bias 参数
    def forward(self, x, *, xl_memory = None, rel_pos_bias = None):
        # 获取头数和设备信息
        h, device = self.heads, x.device
        # 将输入 x 分别转换为查询 q,键 k,值 v
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        # 重新排列查询 q 的维度
        q = rearrange(q, 'b n (h d) -> b h n d', h = h)

        # 对查询 q 进行缩放
        q = q * self.scale

        # 如果存在 xl_memory,则将其拆分为键值对,并与当前的 k 和 v 连接起来
        if exists(xl_memory):
            k_xl_mem, v_xl_mem = xl_memory.unbind(dim = -2)
            k = torch.cat((k_xl_mem, k), dim = -2)
            v = torch.cat((v_xl_mem, v), dim = -2)

        # 计算查询和键之间的相似度
        sim = einsum('b h i d, b j d -> b h i j', q, k)
        i, j = sim.shape[-2:]

        # 如果存在相对位置偏置,则加到相似度上
        if exists(rel_pos_bias):
            sim = rel_pos_bias[..., -i:, -j:] + sim

        # 创建一个因果掩码,用于屏蔽未来信息
        causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # 对相似度进行 softmax 操作
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # 根据注意力权重计算输出
        out = einsum('b h i j, b j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        # 创建新的 xl 记忆
        new_kv_memories = torch.stack((k, v), dim = -2).detach()

        # 如果设置了最大 xl 记忆数,则保留最新的 xl 记忆
        if self.xl_max_memories > 0:
            new_xl_kv_memories = new_kv_memories[:, -self.xl_max_memories:]
        else:
            new_xl_kv_memories = None

        # 返回输出和新的 xl 记忆
        return self.to_out(out), new_xl_kv_memories
# 定义一个近似最近邻注意力机制的类 KNNAttention
class KNNAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,  # 输入特征的维度
        heads = 8,  # 多头注意力的头数
        dim_head = 64,  # 每个头的维度
        dropout = 0.,  # dropout 概率
        num_retrieved_memories = 32,  # 检索的记忆数量
        xl_max_memories = 0.,  # 最大记忆数量
        attn_scale_init = 20,  # 注意力缩放初始化值
        gate_output = False  # 是否使用输出门
    ):
        super().__init__()
        self.heads = heads  # 头数
        self.scale = nn.Parameter(torch.ones(heads, 1, 1) * math.log(attn_scale_init))  # 缩放参数

        inner_dim = heads * dim_head  # 内部维度
        self.xl_max_memories = xl_max_memories  # 最大记忆数量

        self.num_retrieved_memories = num_retrieved_memories  # 检索的记忆数量

        self.dropout = nn.Dropout(dropout)  # dropout 操作
        self.knn_mem_dropout = nn.Dropout(dropout)  # knn 记忆的 dropout 操作

        self.to_q = nn.Linear(dim, inner_dim, bias = False)  # 查询映射
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)  # 键值映射
        self.to_out = nn.Linear(inner_dim, dim, bias = False)  # 输出映射

        self.output_gate = nn.Parameter(torch.zeros(1)) if gate_output else None  # 输出门参数

    def forward(
        self,
        x,  # 输入张量
        *,
        knn_memory,  # KNN 记忆
        xl_memory = None,  # XL 记忆
        add_knn_memory = True,  # 是否添加 KNN 记忆
        rel_pos_bias = None  # 相对位置偏置
        ):
            # 解包 x 的形状,获取 batch size, 序列长度, 头数, 设备信息
            b, n, h, device = *x.shape[:2], self.heads, x.device
            # 将输入 x 分别转换为查询 q, 键 k, 值 v
            q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

            # 重排查询 q 的形状,以适应多头注意力计算
            q = rearrange(q, 'b n (h d) -> b h n d', h = h)

            # 根据论文,对键进行归一化以提高训练稳定性
            # 这里采用完全余弦相似度注意力 https://arxiv.org/abs/2010.04245
            q, k = map(l2norm, (q, k))

            # 处理 XL 内存
            if exists(xl_memory):
                k_xl_mem, v_xl_mem = xl_memory.unbind(dim = -2)
                k = torch.cat((k_xl_mem, k), dim = -2)
                v = torch.cat((v_xl_mem, v), dim = -2)

            # 计算局部注意力
            scale = self.scale.exp()

            sim = einsum('b h i d, b j d -> b h i j', q, k) * scale
            i, j = sim.shape[-2:]

            # 如果存在相对位置偏置,则加入到注意力矩阵中
            if exists(rel_pos_bias):
                sim = rel_pos_bias[..., -i:, -j:] + sim

            mask_value = -torch.finfo(sim.dtype).max

            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, mask_value)

            # 如果传入索引,则计算记忆中的 knn 注意力
            mem_kv, mem_mask = knn_memory.search(q, self.num_retrieved_memories)
            mem_k, mem_v = mem_kv.unbind(dim = -2)

            sim_mem = einsum('b h i d, b h i j d -> b h i j', q, mem_k) * scale
            sim_mem = sim_mem.masked_fill(~mem_mask, mask_value)

            # 计算新的 XL 记忆,以及要丢弃的记忆
            new_kv_memories = torch.stack((k, v), dim = -2).detach()

            if self.xl_max_memories > 0:
                new_kv_memories_discarded, new_xl_kv_memories = new_kv_memories[:, :-self.xl_max_memories], new_kv_memories[:, -self.xl_max_memories:]
            else:
                new_kv_memories_discarded, new_xl_kv_memories = new_kv_memories, None

            # 将要丢弃的记忆添加到 KNN 记忆中
            if add_knn_memory and new_kv_memories_discarded.numel() > 0:
                knn_memory.add(new_kv_memories_discarded)

            # 组合局部和远程注意力
            sim = torch.cat((sim_mem, sim), dim = -1)
            attn = sim.softmax(dim = -1)
            attn = self.dropout(attn)

            local_attn, mem_attn = attn[..., self.num_retrieved_memories:], attn[..., :self.num_retrieved_memories]
            local_out = einsum('b h i j, b j d -> b h i d', local_attn, v)
            mem_out = einsum('b h i j, b h i j d -> b h i d', mem_attn, mem_v)

            out = local_out + mem_out

            # 合并头部并进行投影
            out = rearrange(out, 'b h n d -> b n (h d)')
            out = self.to_out(out)

            # 使用 flamingo 风格的输出门控制输出,以便将记忆化 transformer 门控到现有的 LLM 中
            if exists(self.output_gate):
                out = out * self.output_gate.tanh()

            return out, new_xl_kv_memories
# 主类
class MemorizingTransformer(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        num_tokens,  # 标记数量
        dim,  # 维度
        depth,  # 深度
        dim_head = 64,  # 头维度
        heads = 8,  # 头数
        knn_attn_heads = None,  # KNN注意力头数
        attn_dropout = 0.,  # 注意力丢弃率
        ff_mult = 4,  # 前馈倍数
        ff_dropout = 0.,  # 前馈丢弃率
        memorizing_layers = None,  # 记忆层
        max_knn_memories = 250000,  # 最大KNN记忆
        num_retrieved_memories = 32,  # 检索的记忆数
        clear_memories_on_sos_token_id = None,  # SOS标记时清除记忆
        clear_memories_on_eos_token_id = None,  # EOS标记时清除记忆
        knn_memories_directory = DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY,  # KNN记忆目录
        shift_knn_memories_down = 0.,  # KNN记忆下移
        pad_id = 0,  # 填充标记
        xl_max_memories = 0,  # XL最大记忆
        xl_memory_layers = None,  # XL记忆层
        shift_xl_memories_down = 0.,  # XL记忆下移
        knn_memory_multiprocessing = False  # KNN记忆多进程
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)  # 标记嵌入
        self.pad_id = pad_id  # 填充标记

        block_wrapper = partial(PreNormResidual, dim)  # 块包装器
        valid_layers = set(range(1, depth + 1))  # 有效层范围

        memorizing_layers = default(memorizing_layers, (depth // 2,))  # 默认KNN注意力层为变压器中点
        memorizing_layers = cast_tuple(memorizing_layers)  # 转换为元组
        memorizing_layers = tuple(filter(lambda i: i in valid_layers, memorizing_layers))  # 过滤有效层

        self.dim_head = dim_head  # 头维度

        knn_attn_heads = default(knn_attn_heads, heads)  # 默认KNN注意力头数

        # XL记忆超参数
        if xl_max_memories > 0:
            xl_memory_layers = default(xl_memory_layers, tuple(range(1, depth + 1)))  # 默认XL记忆层为所有层
            xl_memory_layers = unique(xl_memory_layers)  # 唯一值
            self.xl_memory_layers = tuple(filter(lambda i: i in valid_layers, xl_memory_layers))  # 过滤有效层
            self.num_xl_memory_layers = len(self.xl_memory_layers)  # XL记忆层数
        else:
            self.xl_memory_layers = tuple()
            self.num_xl_memory_layers = 0

        # KNN记忆超参数
        self.max_knn_memories = max_knn_memories  # 最大KNN记忆
        self.knn_memories_directory = knn_memories_directory  # KNN记忆目录
        self.memorizing_layers = unique(memorizing_layers)  # 唯一值
        self.num_memory_layers = len(memorizing_layers)  # 记忆层数

        self.clear_memories_on_sos_token_id = clear_memories_on_sos_token_id  # SOS标记时清除记忆
        self.clear_memories_on_eos_token_id = clear_memories_on_eos_token_id  # EOS标记时清除记忆

        # 相对位置偏置
        self.rel_pos_bias = T5RelativePositionBias(scale = dim_head ** 0.5, heads = heads)  # 相对位置偏置
        self.knn_rel_pos_bias = T5RelativePositionBias(scale = dim_head ** 0.5, heads = heads)  # KNN相对位置偏置

        # 层
        self.layers = nn.ModuleList([])
        for idx in range(depth):
            layer_num = idx + 1

            use_xl_memories = layer_num in self.xl_memory_layers  # 使用XL记忆
            use_knn_attention = layer_num in memorizing_layers  # 使用KNN注意力
            xl_max_memories_layer = 0 if not use_xl_memories else xl_max_memories  # XL最大记忆层

            if use_knn_attention:
                attn = KNNAttention(dim = dim, dim_head = dim_head, heads = knn_attn_heads, dropout = attn_dropout, num_retrieved_memories = num_retrieved_memories, xl_max_memories = xl_max_memories_layer)  # KNN注意力
            else:
                attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, xl_max_memories = xl_max_memories_layer)  # 注意力

            self.layers.append(nn.ModuleList([
                block_wrapper(attn),
                block_wrapper(FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)),
            ]))

        # 记忆层移动
        # 来自一篇鲜为人知的论文 https://arxiv.org/abs/2012.15688

        self.shift_knn_memories_down = shift_knn_memories_down  # KNN记忆下移
        self.shift_xl_memories_down = shift_xl_memories_down  # XL记忆下移

        # 转换为logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

        # KNN记忆初始化
        self.knn_mem_kwargs = dict(
            dim = self.dim_head,
            max_memories = self.max_knn_memories,
            multiprocessing = knn_memory_multiprocessing
        )
    # 创建 KNN 记忆体列表
    def create_knn_memories(
        self,
        *,
        batch_size
    ):
        # 调用 KNNMemoryList 类的 create_memories 方法创建记忆体
        return KNNMemoryList.create_memories(
            batch_size = batch_size,
            num_memory_layers = self.num_memory_layers,
            memories_directory = self.knn_memories_directory,
        )(**self.knn_mem_kwargs)

    # 上下文管理器,用于处理 KNN 记忆体
    @contextmanager
    def knn_memories_context(
        self,
        **kwargs
    ):
        # 获取 KNN 记忆体目录路径
        knn_dir = Path(self.knn_memories_directory)
        # 如果目录不存在则创建
        knn_dir.mkdir(exist_ok = True, parents = True)
        # 创建文件锁
        lock = FileLock(str(knn_dir / 'mutex'))

        # 使用文件锁
        with lock:
            # 创建 KNN 记忆体
            knn_memories = self.create_knn_memories(**kwargs)
            # 通过 yield 将 KNN 记忆体传递给调用者
            yield knn_memories
            # 清理 KNN 记忆体
            knn_memories.cleanup()

    # 清除记忆体中包含指定 token id 的批次行
    def clear_memory(self, x, token_id, knn_memories):
        """ clears the KNN memories based on if the batch row contains the specified token id """
        """ for auto-clearing KNN memories based on start and end of strings """

        # 判断每个批次行是否包含指定的 token id
        clear_memory = (x == token_id).any(dim = -1)
        # 获取需要清除的批次索引(nonzero 在一维张量上返回只含一个元素的元组)
        batch_indices, = clear_memory.nonzero(as_tuple = True)
        batch_indices_to_clear = batch_indices.tolist()

        # 如果没有需要清除的批次索引,则直接返回
        if len(batch_indices_to_clear) == 0:
            return

        # 清除传入的 KNN 记忆体中对应批次索引的记忆
        knn_memories.clear_memory(batch_indices_to_clear)

    # 前向传播函数
    def forward(
        self,
        x,
        knn_memories,
        xl_memories = None,
        labels = None,
        add_knn_memory = True
        ):
            # 解构输入张量 x 的形状,获取批量大小、序列长度和设备信息
            batch_size, seq_len, *_, device = *x.shape, x.device
            # 保留原始 token id,供按 <sos> / <eos> 自动清除记忆时使用
            token_ids = x
            # 使用 token_emb 对象对输入张量 x 进行 token 嵌入
            x = self.token_emb(x)

            # 验证 KNN memories 是否有足够的索引来匹配批量大小

            assert all([memory.num_indices == batch_size for memory in knn_memories]), f'you passed in an input with batch size {batch_size} but your memories were not instantiated with that number of KNN indices'

            # 如果传入了 KNN memories,并且研究人员希望在检测到 <sos> 标记时自动清除 memories
            # 执行适当的逻辑

            if exists(self.clear_memories_on_sos_token_id):
                self.clear_memory(token_ids, self.clear_memories_on_sos_token_id, knn_memories)

            # 处理 XL memories

            xl_memories = default(xl_memories, (None,) * self.num_xl_memory_layers)
            assert len(xl_memories) == self.num_xl_memory_layers
            has_xl_memories = len(xl_memories) > 0

            # 将 memories 向下移动若干层,这是 Ernie-Doc 论文中展示的增强 memories 的鲜为人知的技术

            if len(knn_memories) > 0 and self.shift_knn_memories_down > 0:
                knn_memories = [*knn_memories[self.shift_knn_memories_down:], *knn_memories[:self.shift_knn_memories_down]]

            if len(xl_memories) > 0 and self.shift_xl_memories_down > 0:
                xl_memories = [*xl_memories[self.shift_xl_memories_down:], *xl_memories[:self.shift_xl_memories_down]]

            # 按照包含 KNNAttention 的升序层次顺序遍历 memories

            xl_memories_iter = iter(xl_memories)
            knn_memories_iter = iter(knn_memories)

            # 位置偏置

            max_context_len = max([seq_len, *map(lambda t: (t.shape[-3] if exists(t) else 0) + seq_len, xl_memories)])

            rel_pos_bias = self.rel_pos_bias(seq_len, max_context_len, device = device)
            knn_rel_pos_bias = self.knn_rel_pos_bias(seq_len, max_context_len, device = device)

            # 跟踪新的 XL memories

            new_xl_memories = [] if has_xl_memories else None

            # 遍历所有层

            for ind, (attn, ff) in enumerate(self.layers):
                layer_num = ind + 1

                is_memorizing_layer = layer_num in self.memorizing_layers
                is_xl_memory_layer = layer_num in self.xl_memory_layers

                attn_kwargs = dict(rel_pos_bias = rel_pos_bias if not is_memorizing_layer else knn_rel_pos_bias)

                if is_memorizing_layer:
                    attn_kwargs = {**attn_kwargs, 'knn_memory': next(knn_memories_iter), 'add_knn_memory': add_knn_memory}

                if is_xl_memory_layer:
                    attn_kwargs = {**attn_kwargs, 'xl_memory': next(xl_memories_iter)}

                # 注意力机制

                x, xl_mem = attn(x, **attn_kwargs)

                # 如果需要,添加新的 XL memories

                if exists(xl_mem):
                    new_xl_memories.append(xl_mem)

                # 前馈网络

                x = ff(x)

            # 转换为 logits

            logits = self.to_logits(x)

            # 在字符串结束标记时自动清除 KNN memories

            if exists(self.clear_memories_on_eos_token_id):
                self.clear_memory(token_ids, self.clear_memories_on_eos_token_id, knn_memories)

            # 对于训练

            if not exists(labels):
                if exists(new_xl_memories):
                    return logits, new_xl_memories

                return logits

            loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = self.pad_id)

            if exists(new_xl_memories):
                return loss, new_xl_memories

            return loss
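
上面的 forward 在传入 labels 时会直接返回交叉熵损失(而不是 logits)。下面是一个最小训练示意(超参数为假设值),用 knn_memories_context 自动创建并清理 KNN 记忆:

import torch

model = MemorizingTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 8,
    memorizing_layers = (4, 5),
    max_knn_memories = 64000,
    num_retrieved_memories = 32
)

data   = torch.randint(0, 20000, (2, 1024))
labels = torch.randint(0, 20000, (2, 1024))

with model.knn_memories_context(batch_size = 2) as knn_memories:
    loss = model(data, knn_memories = knn_memories, labels = labels)
    loss.backward()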

.\lucidrains\memorizing-transformers-pytorch\memorizing_transformers_pytorch\__init__.py

# 从 memorizing_transformers_pytorch 包中导入 MemorizingTransformer 和 KNNAttention 类
# 以及从 knn_memory 模块中导入 KNNMemory 类
from memorizing_transformers_pytorch.memorizing_transformers_pytorch import MemorizingTransformer, KNNAttention
from memorizing_transformers_pytorch.knn_memory import KNNMemory

Memorizing Transformers - Pytorch

Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch

This repository deviates from the paper slightly, using a hybrid attention across attention logits local and distant (rather than the sigmoid gate setup). It also uses cosine similarity attention (with learned temperature) for the KNN attention layer.

Install

$ pip install memorizing-transformers-pytorch

Usage

import torch
from memorizing_transformers_pytorch import MemorizingTransformer

model = MemorizingTransformer(
    num_tokens = 20000,                 # number of tokens
    dim = 512,                          # dimension
    dim_head = 64,                      # dimension per attention head
    depth = 8,                          # number of layers
    memorizing_layers = (4, 5),         # which layers to have ANN memories
    max_knn_memories = 64000,           # maximum ANN memories to keep (once it hits this capacity, it will be reset for now, due to limitations in faiss' ability to remove entries)
    num_retrieved_memories = 32,        # number of ANN memories to retrieve
    clear_memories_on_sos_token_id = 1, # clear passed in ANN memories automatically for batch indices which contain this specified SOS token id - otherwise, you can also manually iterate through the ANN memories and clear the indices before the next iteration
)

data = torch.randint(0, 20000, (2, 1024)) # mock data

knn_memories = model.create_knn_memories(batch_size = 2) # create collection of KNN memories with the correct batch size (2 in example)

logits = model(data, knn_memories = knn_memories) # (2, 1024, 20000)

You can make the KNN memories read-only by setting add_knn_memory on forward to False

ex.

logits = model(data, knn_memories = knn_memories, add_knn_memory = False) # knn memories will not be updated

With Transformer-XL memories (only the memories that will be discarded will be added to the KNN memory)

import torch
from memorizing_transformers_pytorch import MemorizingTransformer

model = MemorizingTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 8,
    memorizing_layers = (4, 5),
    max_knn_memories = 64000,
    num_retrieved_memories = 32,
    clear_memories_on_sos_token_id = 1,
    xl_memory_layers = (2, 3, 4, 5),      # xl memory layers - (https://arxiv.org/abs/2007.03356 shows you do not need XL memory on all layers, just the latter ones) - if a KNNAttention layer ends up using XL memories, only the XL memories that will be discarded will be added to long term memory
    xl_max_memories = 512,                # number of xl memories to keep
    shift_knn_memories_down = 1,          # let a layer look at the KNN memories this number of layers above
    shift_xl_memories_down = 1,           # let a layer look at the XL memories this number of layers above, shown to enhance receptive field in ernie-doc paper
)

data = torch.randint(0, 20000, (2, 1024)) # mock data

xl_memories = None

with model.knn_memories_context(batch_size = 2) as knn_memories:
    logits1, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)
    logits2, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)
    logits3, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)

    # ... and so on

KNN Memory

This repository contains a wrapper around Faiss that can automatically store and retrieve key / values

import torch
from memorizing_transformers_pytorch import KNNMemory

memory = KNNMemory(
    dim = 64,                   # dimension of key / values
    max_memories = 64000,       # maximum number of memories to keep (will throw out the oldest memories for now if it overfills)
    num_indices = 2             # this should be equivalent to batch dimension, as each batch keeps track of its own memories, expiring when it sees a new document
)

memory.add(torch.randn(2, 512, 2, 64))  # (batch, seq, key | value, feature dim)
memory.add(torch.randn(2, 512, 2, 64))

memory.clear([0]) # clear batch 0, if it saw an <sos>

memory.add(torch.randn(2, 512, 2, 64))
memory.add(torch.randn(2, 512, 2, 64))

key_values, mask = memory.search(torch.randn(2, 512, 64), topk = 32)
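
A hedged sketch of how the retrieved memories could be consumed downstream (the wiring below is illustrative, assuming key_values comes back shaped (batch, seq, topk, key | value, dim) and mask shaped (batch, seq, topk); in practice this happens inside KNNAttention):

queries = torch.randn(2, 512, 64)

retrieved_k, retrieved_v = key_values.unbind(dim = -2)            # each (2, 512, 32, 64)

sim = torch.einsum('b n d, b n k d -> b n k', queries, retrieved_k)
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)         # ignore empty memory slots
attn = sim.softmax(dim = -1)

out = torch.einsum('b n k, b n k d -> b n d', attn, retrieved_v)  # (2, 512, 64)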

Training

Enwik8 training

$ python train.py

Todo

Citations

@article{wu2022memorizing,
  title   = {Memorizing transformers},
  author  = {Wu, Yuhuai and Rabe, Markus N and Hutchins, DeLesley and Szegedy, Christian},
  journal = {arXiv preprint arXiv:2203.08913},
  year    = {2022}
}
@article{Shazeer2019FastTD,
  title   = {Fast Transformer Decoding: One Write-Head is All You Need},
  author  = {Noam M. Shazeer},
  journal = {ArXiv},
  year    = {2019},
  volume  = {abs/1911.02150}
}
@Article{AlphaFold2021,
  author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
  journal = {Nature},
  title   = {Highly accurate protein structure prediction with {AlphaFold}},
  year    = {2021},
  doi     = {10.1038/s41586-021-03819-2},
  note    = {(Accelerated article preview)},
}
@inproceedings{Rae2020DoTN,
  title   = {Do Transformers Need Deep Long-Range Memory?},
  author  = {Jack W. Rae and Ali Razavi},
  booktitle = {ACL},
  year    = {2020}
}
@misc{ding2021erniedoc,
  title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
  author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
  year    = {2021},
  eprint  = {2012.15688},
  archivePrefix = {arXiv},
  primaryClass = {cs.CL}
}
@misc{henry2020querykey,
    title   = {Query-Key Normalization for Transformers},
    author  = {Alex Henry and Prudhvi Raj Dachapally and Shubham Pawar and Yuxuan Chen},
    year    = {2020},
    eprint  = {2010.04245},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

Memory is Attention through Time - Alex Graves

.\lucidrains\memorizing-transformers-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'memorizing-transformers-pytorch',  # 包的名称
  packages = find_packages(exclude=[]),  # 查找所有包
  version = '0.4.1',  # 版本号
  license='MIT',  # 许可证
  description = 'Memorizing Transformer - Pytorch',  # 描述
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/memorizing-transformers-pytorch',  # 项目链接
  keywords = [  # 关键词列表
    'artificial intelligence',
    'deep learning',
    'transformers',
    'memory',
    'retrieval'
  ],
  install_requires=[  # 安装依赖
    'einops>=0.6',
    'filelock',
    'joblib',
    'faiss-gpu',
    'numpy',
    'torch>=1.6',
  ],
  classifiers=[  # 分类器
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\memorizing-transformers-pytorch\train.py

# 导入所需的库
from memorizing_transformers_pytorch import MemorizingTransformer

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# 常量定义
NUM_BATCHES = int(1e5)
BATCH_SIZE = 16
SEQ_LEN = 512
SEGMENTS = 5

LEARNING_RATE = 2e-4
MAX_GRAD_CLIP_NORM = 0.5

VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512

# helper functions

# decode a single token into a character
def decode_token(token):
    return str(chr(max(32, token)))

# decode a sequence of tokens into a string
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# cycle through a dataloader endlessly (used by the train / valid loaders below)
def cycle(loader):
    while True:
        for data in loader:
            yield data

# 实例化类似 GPT 的解码器模型
model = MemorizingTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    memorizing_layers = 4,
    max_knn_memories = 512 * 15,
    num_retrieved_memories = 32,
    xl_memory_layers = (7, 8),
    xl_max_memories = 512,
).cuda()

# 准备 enwik8 数据
with gzip.open('./data/enwik8.gz') as file:
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# 定义文本采样数据集类
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# 数据集和数据加载器
train_dataset = TextSamplerDataset(data_train, SEQ_LEN * SEGMENTS)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
valid_dataset = TextSamplerDataset(data_val, SEQ_LEN * SEGMENTS)
valid_loader = cycle(DataLoader(valid_dataset, batch_size = BATCH_SIZE, drop_last = True))

# 优化器
optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# 训练
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    data = next(train_loader)

    train_loss = 0.
    with model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
        xl_memories = None    
        seq, labels = data[:, :-1], data[:, 1:]

        for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
            loss, xl_memories = model(
                seq_segment,
                labels = labels_segment,
                knn_memories = knn_memories,
                xl_memories = xl_memories
            )

            train_loss += loss.item() / SEGMENTS
            (loss / SEGMENTS).backward()    

    print(f'training loss: {train_loss}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_CLIP_NORM)
    optim.step()
    optim.zero_grad()

    if not (i % VALIDATE_EVERY):
        model.eval()

        valid_data = next(valid_loader)
        valid_loss = 0.

        with torch.no_grad(), model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
            xl_memories = None    
            seq, labels = valid_data[:, :-1], valid_data[:, 1:]

            for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
                loss, xl_memories = model(
                    seq_segment,
                    labels = labels_segment,
                    knn_memories = knn_memories,
                    xl_memories = xl_memories
                )

                valid_loss += loss.item() / SEGMENTS

        print(f'valid loss: {valid_loss}')

.\lucidrains\memory-compressed-attention\memory_compressed_attention\memory_compressed_attention.py

import torch
from torch import nn
import torch.nn.functional as F

# 定义卷积压缩类
class ConvCompress(nn.Module):
    def __init__(self, dim, ratio = 3, groups = 1):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, ratio, stride = ratio, groups = groups)

    def forward(self, mem):
        mem = mem.transpose(1, 2)
        compressed_mem = self.conv(mem)
        return compressed_mem.transpose(1, 2)

# 主类
class MemoryCompressedAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        causal = False,
        compression_factor = 3,
        dropout = 0.):
        super().__init__()
        assert (dim % heads) == 0, 'dimension must be divisible by number of heads'

        self.heads = heads
        self.causal = causal

        self.compression_factor = compression_factor
        self.compress_fn = ConvCompress(dim, compression_factor, groups = heads)

        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim)

        self.dropout = nn.Dropout(dropout)

        self.null_k = nn.Parameter(torch.zeros(1, 1, dim))
        self.null_v = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x, input_mask = None):
        b, t, d, h, cf, device = *x.shape, self.heads, self.compression_factor, x.device
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # 确保键和值的序列长度可以被压缩因子整除
        padding = cf - (t % cf)
        if padding < cf:
            k, v = map(lambda t: F.pad(t, (0, 0, padding, 0)), (k, v))

        # 压缩键和值
        k, v = map(self.compress_fn, (k, v))

        # 在第一个查询没有键需要关注的情况下,附加一个空键和值
        nk, nv = map(lambda t: t.expand(b, -1, -1), (self.null_k, self.null_v))
        k = torch.cat((nk, k), dim=1)
        v = torch.cat((nv, v), dim=1)

        # 合并头部
        q, k, v = map(lambda t: t.reshape(*t.shape[:2], h, -1).transpose(1, 2), (q, k, v))

        # 注意力计算
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * d ** -0.5

        mask_value = -torch.finfo(dots.dtype).max

        # 如果需要,进行因果遮罩
        if self.causal:
            mask_q = mask_k = torch.arange(t, device=device)

            if padding < cf:
                mask_k = F.pad(mask_k, (padding, 0))

            mask_k, _ = mask_k.reshape(-1, cf).max(dim=-1)
            mask = mask_q[:, None] < mask_k[None, :]
            mask = F.pad(mask, (1, 0), value=False)

            dots.masked_fill_(mask[None, None, ...], mask_value)
            del mask

        # 输入遮罩
        if input_mask is not None:
            mask_q = mask_k = input_mask
            if padding < cf:
                mask_k = F.pad(mask_k, (padding, 0), value=True)
            mask_k = mask_k.reshape(b, -1, cf).sum(dim=-1) > 0
            mask = mask_q[:, None, :, None] < mask_k[:, None, None, :]
            mask = F.pad(mask, (1, 0), value=True)

            dots.masked_fill_(~mask, mask_value)
            del mask

        # 注意力权重
        attn = dots.softmax(dim=-1)

        # dropout
        attn = self.dropout(attn)
        out = torch.einsum('bhij,bhjd->bhid', attn, v)

        # 分割头部并合并
        out = out.transpose(1, 2).reshape(b, t, d)
        return self.to_out(out)

.\lucidrains\memory-compressed-attention\memory_compressed_attention\__init__.py

# 从 memory_compressed_attention.memory_compressed_attention 模块中导入 MemoryCompressedAttention 类
from memory_compressed_attention.memory_compressed_attention import MemoryCompressedAttention

Memory Compressed Attention

Implementation of the Self-Attention layer of the proposed Memory-Compressed Attention, in Pytorch. This repository offers both the causal and non-causal variants, and will take care of the padding if the sequence length is not divisible by the compression ratio.

The code also resolves an edge case where the very first query has no keys to attend to in the auto-regressive scenario. The solution is to use null key / values, appended to the final compressed set, so that there is always at least one key for all queries to attend to.
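
Below is a minimal sketch of that fix, mirroring the null key / value handling in the module listed above (the tensors are illustrative; in the module the null key / value are learned parameters):

import torch

b, dim, compression_factor = 2, 512, 3

null_k = torch.zeros(1, 1, dim)
null_v = torch.zeros(1, 1, dim)

k = torch.randn(b, 1024 // compression_factor, dim)    # compressed keys
v = torch.randn(b, 1024 // compression_factor, dim)    # compressed values

# prepend the null key / value so even the very first query has something to attend to
k = torch.cat((null_k.expand(b, -1, -1), k), dim = 1)  # (2, 342, 512)
v = torch.cat((null_v.expand(b, -1, -1), v), dim = 1)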

Install

$ pip install memory_compressed_attention

Usage

import torch
from memory_compressed_attention import MemoryCompressedAttention

attn = MemoryCompressedAttention(
    dim = 512,
    heads = 8,                 # number of heads
    causal = False,            # auto-regressive or not
    compression_factor = 3,    # compression ratio
    dropout = 0.1              # dropout post-attention
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

attn(x, input_mask = mask) # (1, 1024, 512)

Citations

@misc{liu2018generating,
    title={Generating Wikipedia by Summarizing Long Sequences},
    author={Peter J. Liu and Mohammad Saleh and Etienne Pot and Ben Goodrich and Ryan Sepassi and Lukasz Kaiser and Noam Shazeer},
    year={2018},
    eprint={1801.10198},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

.\lucidrains\memory-compressed-attention\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'memory_compressed_attention',  # 包的名称
  packages = find_packages(),  # 查找所有包
  version = '0.0.7',  # 版本号
  license='MIT',  # 许可证
  description = 'Memory-Compressed Self Attention',  # 描述
  long_description_content_type = 'text/markdown',  # 长描述内容类型
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/memory-compressed-attention',  # 项目链接
  keywords = ['transformers', 'artificial intelligence', 'attention mechanism'],  # 关键词
  install_requires=[
    'torch'  # 安装所需的依赖
  ],
  classifiers=[
    'Development Status :: 4 - Beta',  # 开发状态
    'Intended Audience :: Developers',  # 预期受众
    'Topic :: Scientific/Engineering :: Artificial Intelligence',  # 主题
    'License :: OSI Approved :: MIT License',  # 许可证
    'Programming Language :: Python :: 3.6',  # 编程语言
  ],
)

My explorations into editing the knowledge and memories of an attention network.

Citations

@article{meng2022memit,
  title   = {Mass Editing Memory in a Transformer},
  author  = {Kevin Meng and Sen Sharma, Arnab and Alex Andonian and Yonatan Belinkov and David Bau},
  journal = {arXiv preprint arXiv:2210.07229},
  year    = {2022}
}
@inproceedings{Burns2022DiscoveringLK,
  title  = {Discovering Latent Knowledge in Language Models Without Supervision},
  author = {Collin Burns and Hao-Tong Ye and Dan Klein and Jacob Steinhardt},
  year   = {2022}
}

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\autoregressive_wrapper.py

import torch
from torch import nn
import torch.nn.functional as F

# helper function

# 检查值是否存在的辅助函数
def exists(val):
    return val is not None

# 评估装饰器函数
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        # 保存模型原始训练状态
        was_training = model.training
        # 将模型设置为评估模式
        model.eval()
        # 调用传入的函数
        out = fn(model, *args, **kwargs)
        # 恢复模型原始训练状态
        model.train(was_training)
        return out
    return inner

# top k filtering

# 根据阈值过滤 logits 中的 top k 值
def top_k(logits, thres = 0.9):
    # 计算 top k 的数量
    k = int((1 - thres) * logits.shape[-1])
    # 获取 top k 的值和索引
    val, ind = torch.topk(logits, k)
    # 创建与 logits 相同形状的全为负无穷的张量
    probs = torch.full_like(logits, float('-inf'))
    # 根据索引将 top k 的值填充到 probs 中
    probs.scatter_(1, ind, val)
    return probs

# 自回归包装器类
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.net = net
        self.max_seq_len = net.max_seq_len

    # 生成序列的方法
    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_thres = 0.9, **kwargs):
        # 获取起始 tokens 的形状和设备信息
        b, t, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            # 获取最后 self.max_seq_len 个 token
            x = out[:, -self.max_seq_len:]

            # 获取模型预测的 logits
            logits = self.net(x, **kwargs)[:, -1, :]

            # 过滤 logits 中的 top k 值
            filtered_logits = top_k(logits, thres = filter_thres)
            # 计算 softmax 温度调节后的概率
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            # 从概率分布中采样一个 token
            sample = torch.multinomial(probs, 1)

            # 将采样的 token 添加到输出序列中
            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                # check where the eos_token appears in the output
                is_eos_tokens = (out == eos_token)

                if is_eos_tokens.any(dim = -1).all():
                    # if every sequence in the batch contains an eos_token, stop generating
                    # build an eos_token mask shifted one position to the right
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    # mark every position after the eos_token
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    # fill the masked positions with pad_value
                    out = out.masked_fill(mask, self.pad_value)
                    break

        # 去除起始 tokens,返回生成的序列
        out = out[:, t:]
        return out

    # 前向传播方法
    def forward(self, x, **kwargs):
        # 将输入拆分为输入和标签
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        return self.net(x_inp, labels = x_labels, **kwargs)
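
A hedged usage sketch of the wrapper above, assuming the wrapped net exposes max_seq_len and returns a loss when given labels (which is what the wrapper relies on); net here stands in for any such model:

import torch

wrapper = AutoregressiveWrapper(net)

seq = torch.randint(0, 256, (1, 1025))
loss = wrapper(seq)                                # teacher forcing: net(seq[:, :-1], labels = seq[:, 1:])
loss.backward()

prime = torch.randint(0, 256, (1, 1))
sampled = wrapper.generate(prime, seq_len = 512)   # (1, 512)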

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\cosine_sim_flash_attention.py

# 导入所需的库
import math
import torch
from functools import partial
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd.function import Function

from einops import rearrange

# 定义常量
EPSILON = 1e-6

# 辅助函数

# 检查变量是否存在
def exists(val):
    return val is not None

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 对输入张量进行 L2 归一化
def l2norm(t):
    return F.normalize(t, dim = -1)

# FlashAttentionFunction 类,实现了自定义的 PyTorch 函数
class FlashAttentionFunction(Function):
    # 前向传播函数
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size):
        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        k_len = k.shape[-2] # 在余弦相似度注意力中,行和受到键/值序列长度的限制

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)

        # 处理输入的 mask
        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, 'b n -> b 1 1 n')
            mask = mask.split(q_bucket_size, dim = -1)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            mask,
            all_row_sums.split(q_bucket_size, dim = -2),
        )

        # 遍历每个分块的行
        for ind, (qc, oc, row_mask, row_sums) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
            )

            # 遍历每个分块的列
            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # 计算注意力权重
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                # 如果存在行 mask,则进行填充
                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                # 如果启用因果注意力,并且当前位置不应该看到后续位置的信息,则进行填充
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                attn_weights -= scale
                exp_weights = torch.exp(attn_weights)

                # 如果存在行 mask,则进行填充
                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.)

                block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                oc.add_(exp_values / k_len)
                row_sums.add_(block_row_sums)

        # 保存参数和中间结果,用于反向传播
        ctx.args = (scale, causal, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums)

        # 对输出进行缩放
        o.mul_(k_len / all_row_sums)

        return o

    @staticmethod
    @torch.no_grad()
    # 定义一个反向传播函数,接收上下文和梯度作为参数
    def backward(ctx, do):
        # 解包上下文参数
        scale, causal, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l = ctx.saved_tensors

        # 获取设备信息
        device = q.device

        # 计算最大负值
        max_neg_value = -torch.finfo(q.dtype).max
        # 计算 q 和 k 的长度差
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # 初始化梯度变量
        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        # 按照 q_bucket_size 分割张量
        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            do.split(q_bucket_size, dim = -2),
            mask,
            l.split(q_bucket_size, dim = -2),
            dq.split(q_bucket_size, dim = -2)
        )

        # 遍历分割后的张量
        for ind, (qc, oc, doc, row_mask, lc, dqc) in enumerate(row_splits):
            # 计算 q 的起始索引
            q_start_index = ind * q_bucket_size - qk_len_diff

            # 按照 k_bucket_size 分割张量
            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                dk.split(k_bucket_size, dim = -2),
                dv.split(k_bucket_size, dim = -2),
            )

            # 遍历分割后的张量
            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                # 计算 k 的起始索引
                k_start_index = k_ind * k_bucket_size

                # 计算注意力权重
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                # 如果是因果注意力机制,进行掩码处理
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # 计算指数化的注意力权重
                exp_attn_weights = torch.exp(attn_weights - scale)

                # 如果存在行掩码,进行填充
                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.)

                # 计算概率
                p = exp_attn_weights / lc

                # 计算 dv_chunk
                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                # 计算 dp
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                # 计算 D
                D = (doc * oc).sum(dim = -1, keepdims = True)
                # 计算 ds
                ds = p * scale * (dp - D)

                # 计算 dq_chunk
                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                # 计算 dk_chunk
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                # 累加梯度
                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        # 返回梯度
        return dq, dk, dv, None, None, None, None, None
# main class
# flash attention applied to cosine similarity attention
# relatively simpler - there is no longer any need to worry about softmax numerical stability, since the row sums are bounded (see the short numeric check after this file)

class FlashAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        scale = 16,
        heads = 8,
        dim_head = 64,
        causal = False,
        q_bucket_size = 512,
        k_bucket_size = 1024
    ):
        super().__init__()
        self.heads = heads

        self.scale = scale
        self.causal = causal

        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # 内存高效的注意力相关参数
        # 可以在前向传播中被覆盖
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,
        context = None,
        mask = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q, k = map(l2norm, (q, k))

        out = FlashAttentionFunction.apply(q, k, v, mask, self.scale, self.causal, q_bucket_size, k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
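
A quick numeric check of the comment above the FlashAttention class in this file: with l2-normalized queries / keys and a fixed scale, every exponentiated logit is at most 1, so each row sum is bounded by the key / value sequence length. This is a standalone sketch, not part of the module:

import torch
import torch.nn.functional as F

scale = 16
q = F.normalize(torch.randn(1, 8, 256, 64), dim = -1)
k = F.normalize(torch.randn(1, 8, 256, 64), dim = -1)

logits = torch.einsum('... i d, ... j d -> ... i j', q, k) * scale
exp_weights = torch.exp(logits - scale)                 # cosine similarity <= 1, so logits - scale <= 0

assert exp_weights.max() <= 1.0
assert exp_weights.sum(dim = -1).max() <= k.shape[-2]   # bounded by the key / value length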

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\flash_attention.py

# 导入数学库和 PyTorch 库
import math
import torch
# 导入 partial 函数
from functools import partial
# 从 torch 模块中导入 nn 和 einsum 函数
from torch import nn, einsum
# 从 torch.autograd.function 模块中导入 Function 类
from torch.autograd.function import Function
# 从 einops 库中导入 rearrange 函数

from einops import rearrange

# 定义常量 EPSILON
EPSILON = 1e-10

# 定义辅助函数

# 判断变量是否存在的函数
def exists(val):
    return val is not None

# 如果变量存在则返回其值,否则返回默认值的函数
def default(val, d):
    return val if exists(val) else d

# flash attention 前向和后向

# flash attention v1 - https://arxiv.org/abs/2205.14135
# flash attention v2 - https://tridao.me/publications/flash2/flash2.pdf

# 定义 FlashAttentionFunction 类,继承自 Function 类
class FlashAttentionFunction(Function):
    # 静态方法,用 @torch.no_grad() 装饰
    @staticmethod
    @torch.no_grad()
    # 前向传播函数,接收参数 q, k, v, mask, causal, q_bucket_size, k_bucket_size
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """ Algorithm 1 in the v2 paper """

        # 获取设备信息
        device = q.device
        # 获取最大负值
        max_neg_value = -torch.finfo(q.dtype).max
        # 计算 q 和 k 的长度差
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # 初始化输出 o,所有行的和和最大值
        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device=device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device=device)

        # 缩放因子
        scale = (q.shape[-1] ** -0.5)

        # 计算行和列的分块数量
        num_row_tiles = math.ceil(q.shape[-2] / q_bucket_size)
        num_col_tiles = math.ceil(k.shape[-2] / k_bucket_size)

        # 处理 mask
        if exists(mask) and mask.ndim == 2:
            mask = rearrange(mask, 'b n -> b 1 1 n')

        if not exists(mask):
            col_masks = (None,) * num_col_tiles
            mask = (col_masks,) * num_row_tiles 
        else:
            mask = ((mask,) * num_row_tiles) if mask.shape[-2] == 1 else mask.split(q_bucket_size, dim=-2)
            mask = tuple(((row_mask,) * num_col_tiles) if row_mask.shape[-1] == 1 else row_mask.split(k_bucket_size, dim=-1) for row_mask in mask)

        # 按行分块
        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            # 按列分块
            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                row_mask
            )

            for k_ind, (kc, vc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # 计算注意力权重
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if exists(col_mask):
                    attn_weights.masked_fill_(~col_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_weights = torch.exp(attn_weights - new_row_maxes)

                if exists(col_mask):
                    exp_weights.masked_fill_(~col_mask, 0.)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + block_row_sums

                oc.mul_(exp_row_max_diff).add_(exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

            oc.div_(row_sums)

        lse = all_row_sums.log() + all_row_maxes

        # 保存参数并返回输出 o
        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, lse)

        return o

    # 静态方法,用 @torch.no_grad() 装饰
    @staticmethod
    @torch.no_grad()
    # 定义一个向后传播函数,实现 v2 论文中的算法 2
    def backward(ctx, do):
        """ Algorithm 2 in the v2 paper """

        # 从上下文中获取参数
        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, lse = ctx.saved_tensors

        # 获取计算设备
        device = q.device

        # 获取最大负值
        max_neg_value = -torch.finfo(q.dtype).max
        # 计算 q 和 k 的长度差
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # 初始化 dq, dk, dv
        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        # 按照 q_bucket_size 分割 q, o, do, mask, lse, dq
        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            do.split(q_bucket_size, dim = -2),
            mask,
            lse.split(q_bucket_size, dim = -2),
            dq.split(q_bucket_size, dim = -2)
        )

        # 遍历每个分割后的行
        for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            # 按照 k_bucket_size 分割 k, v, dk, dv, row_mask
            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                dk.split(k_bucket_size, dim = -2),
                dv.split(k_bucket_size, dim = -2),
                row_mask
            )

            # 遍历每个分割后的列
            for k_ind, (kc, vc, dkc, dvc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # 计算注意力权重
                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                # 如果是因果注意力机制,并且 q_start_index 小于 (k_start_index + k_bucket_size - 1)
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # 计算概率
                p = torch.exp(attn_weights - lsec)

                # 如果存在列掩码,则将概率中对应位置置零
                if exists(col_mask):
                    p.masked_fill_(~col_mask, 0.)

                # 计算 dv_chunk
                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                # 计算 D 和 ds
                D = (doc * oc).sum(dim = -1, keepdims = True)
                ds = p * scale * (dp - D)

                # 计算 dq_chunk, dk_chunk
                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                # 累加到梯度中
                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        # 返回梯度 dq, dk, dv
        return dq, dk, dv, None, None, None, None
# main FlashAttention class, implementing the attention mechanism
# a pure PyTorch implementation like this is much slower than a CUDA implementation
# intended for debugging and educational purposes

class FlashAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,  # 输入维度
        heads = 8,  # 头数
        dim_head = 64,  # 每个头的维度
        causal = False,  # 是否使用因果注意力
        q_bucket_size = 512,  # 查询桶大小
        k_bucket_size = 1024  # 键值桶大小
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal

        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)  # 查询线性层
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)  # 键值线性层
        self.to_out = nn.Linear(inner_dim, dim, bias = False)  # 输出线性层

        # 内存高效的注意力相关参数
        # 可以在前向传播中被覆盖
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,  # 输入张量
        context = None,  # 上下文张量
        mask = None,  # 掩码张量
        q_bucket_size = None,  # 查询桶大小
        k_bucket_size = None,  # 键值桶大小
    ):
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)  # 设置查询桶大小
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)  # 设置键值桶大小

        h = self.heads
        context = default(context, x)  # 如果上下文为空,则使用输入张量作为上下文

        q = self.to_q(x)  # 计算查询张量
        k, v = self.to_kv(context).chunk(2, dim = -1)  # 计算键值张量并分割为键和值

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))  # 重排张量形状

        out = FlashAttentionFunction.apply(q, k, v, mask, self.causal, q_bucket_size, k_bucket_size)  # 调用自定义的注意力函数

        out = rearrange(out, 'b h n d -> b n (h d)')  # 重排输出张量形状
        return self.to_out(out)  # 返回输出结果

.\lucidrains\memory-efficient-attention-pytorch\memory_efficient_attention_pytorch\memory_efficient_attention.py

import torch
from functools import partial
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F

from einops import rearrange

# 导入所需的库

def exists(val):
    return val is not None

# 检查值是否存在的辅助函数

def default(val, d):
    return val if exists(val) else d

# 如果值存在则返回该值,否则返回默认值的辅助函数

# regular attention

def attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    **kwargs
):
    scale = q.shape[-1] ** -0.5
    q = q * scale

    # 缩放查询向量

    sim = einsum('b h i d, b h j d -> b h i j', q, k)

    # 计算注意力分数

    if exists(attn_bias):
        sim = sim + attn_bias

    # 添加注意力偏置

    mask_value = -torch.finfo(sim.dtype).max

    # 计算掩码值

    if exists(mask):
        if mask.ndim == 2:
            mask = rearrange(mask, 'b j -> b 1 1 j')
        sim = sim.masked_fill(~mask, mask_value)

    # 应用掩码

    if causal:
        i, j = sim.shape[-2:]
        mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
        sim = sim.masked_fill(mask, mask_value)

    # 应用因果掩码

    sim = sim - sim.amax(dim = -1, keepdim = True).detach()
    attn = sim.softmax(dim = -1)

    # 计算注意力权重

    out = einsum('b h i j, b h j d -> b h i d', attn, v)
    return out

    # 计算输出

# memory efficient attention

def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout):
    q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device

    weight = einsum('b h i d, b h j d -> b h i j', q, k)

    # 计算权重

    if exists(attn_bias_chunk):
        weight = weight + attn_bias_chunk

    # 添加注意力偏置

    mask_value = -torch.finfo(weight.dtype).max

    # 计算掩码值

    if exists(mask):
        mask = rearrange(mask, 'b j -> b 1 1 j')
        weight = weight.masked_fill(~mask, mask_value)

    # 应用掩码

    if causal and q_start_index < (k_start_index + k_chunk_size - 1):
        causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
        weight = weight.masked_fill(causal_mask, mask_value)

    # 应用因果掩码

    weight_max = weight.amax(dim = -1, keepdim = True).detach()
    weight = weight - weight_max

    exp_weight = weight.exp()

    exp_weight = F.dropout(exp_weight, p = dropout)

    weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)

    return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')

checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)

# 创建检查点函数

def memory_efficient_attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    q_bucket_size = 512,
    k_bucket_size = 1024,
    eps = 1e-8,
    dropout = 0.,
    training = False
):
    scale = q.shape[-1] ** -0.5
    q = q * scale

    # 缩放查询向量

    needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
    summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    # 根据是否需要反向传播选择函数

    q_chunks = q.split(q_bucket_size, dim = -2)
    k_chunks = k.split(k_bucket_size, dim = -2)
    v_chunks = v.split(k_bucket_size, dim = -2)
    mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))

    if exists(attn_bias):
        i, j = attn_bias.shape[-2:]
        attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
        attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks))

    # 将输入分块

    out = []

    # 初始化输出列表
    # 遍历查询块列表,获取索引和查询块
    for q_index, q_chunk in enumerate(q_chunks):
        # 初始化空列表,用于存储期望权重、加权值和权重最大值
        exp_weights = []
        weighted_values = []
        weight_maxes = []

        # 遍历键值块、值块和掩码块的元组列表
        for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
            # 计算查询块和键块的起始索引
            q_start_index = q_index * q_bucket_size
            k_start_index = k_index * k_bucket_size

            # 如果是因果的且键块的起始索引大于查询块的结束索引,则跳过当前循环
            if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1):
                continue

            # 如果存在注意力偏置,则获取当前注意力偏置块
            attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None

            # 调用 summarize_qkv_fn 函数,计算期望权重、加权值和权重最大值
            exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
                q_chunk,
                k_chunk,
                v_chunk,
                mask_chunk,
                attn_bias_chunk,
                causal,
                (q_start_index, k_start_index),
                dropout if training else 0.
            )

            # 将计算得到的结果添加到对应的列表中
            exp_weights.append(exp_weight_chunk)
            weighted_values.append(weighted_value_chunk)
            weight_maxes.append(weight_max_chunk)

        # 将权重最大值堆叠在一起
        weight_maxes = torch.stack(weight_maxes, dim=-1)

        # 将加权值堆叠在一起
        weighted_values = torch.stack(weighted_values, dim=-1)
        # 将期望权重堆叠在一起
        exp_weights = torch.stack(exp_weights, dim=-1)

        # 计算全局最大值
        global_max = weight_maxes.amax(dim=-1, keepdim=True)
        # 计算重新归一化因子
        renorm_factor = (weight_maxes - global_max).exp().detach()

        # 期望权重乘以重新归一化因子
        exp_weights = exp_weights * renorm_factor
        # 加权值乘以重新排列的重新归一化因子
        weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')

        # 对所有加权值进行求和
        all_values = weighted_values.sum(dim=-1)
        # 对所有期望权重进行求和
        all_weights = exp_weights.sum(dim=-1)

        # normalize the accumulated values by the accumulated weights
        normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
        # 将归一化后的值添加到输出列表中
        out.append(normalized_values)

    # 沿着指定维度连接输出列表中的张量
    return torch.cat(out, dim=-2)
# 主要的注意力机制类

class Attention(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        dim,  # 输入维度
        heads = 8,  # 头数,默认为8
        dim_head = 64,  # 每个头的维度,默认为64
        dropout = 0.,  # 丢弃概率,默认为0
        causal = False,  # 是否使用因果注意力,默认为False
        memory_efficient = False,  # 是否使用内存高效的注意力,默认为False
        q_bucket_size = 512,  # 查询桶大小,默认为512
        k_bucket_size = 1024  # 键值桶大小,默认为1024
    ):
        super().__init__()
        self.heads = heads  # 头数
        self.causal = causal  # 是否因果
        self.dropout = dropout  # 丢弃概率
        inner_dim = heads * dim_head  # 内部维度为头数乘以每个头的维度

        self.to_q = nn.Linear(dim, inner_dim, bias = False)  # 输入到查询的线性层
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)  # 输入到键值的线性层
        self.to_out = nn.Linear(inner_dim, dim, bias = False)  # 输出的线性层

        # 内存高效注意力相关参数
        # 可在前向传播中覆盖
        self.memory_efficient = memory_efficient  # 是否内存高效
        self.q_bucket_size = q_bucket_size  # 查询桶大小
        self.k_bucket_size = k_bucket_size  # 键值桶大小

    # 前向传播函数
    def forward(
        self,
        x,  # 输入张量
        context = None,  # 上下文,默认为None
        mask = None,  # 掩码,默认为None
        attn_bias = None,  # 注意力偏置,默认为None
        memory_efficient = None,  # 是否内存高效,默认为None
        q_bucket_size = None,  # 查询桶大小,默认为None
        k_bucket_size = None,  # 键值桶大小,默认为None
    ):
        memory_efficient = default(memory_efficient, self.memory_efficient)  # 使用默认值或者自定义值
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)  # 使用默认值或者自定义值
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)  # 使用默认值或者自定义值

        h = self.heads  # 头数
        context = default(context, x)  # 上下文,默认为输入张量

        q = self.to_q(x)  # 查询张量
        k, v = self.to_kv(context).chunk(2, dim = -1)  # 键值张量拆分为k和v

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))  # 重排张量形状

        attn_fn = attention if not memory_efficient else memory_efficient_attention  # 根据内存高效性选择不同的注意力函数

        out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, 
                    k_bucket_size = k_bucket_size, dropout = self.dropout, training = self.training)  # 注意力计算

        out = rearrange(out, 'b h n d -> b n (h d)')  # 重排输出形状
        return self.to_out(out)  # 输出结果
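
A short usage sketch of the chunked attention above (bucket sizes here are illustrative):

import torch

attn = Attention(
    dim = 512,
    heads = 8,
    causal = True,
    memory_efficient = True,
    q_bucket_size = 256,
    k_bucket_size = 512
)

x = torch.randn(1, 4096, 512)
out = attn(x)   # (1, 4096, 512)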