Lucidrains-系列项目源码解析-二十七-

Lucidrains 系列项目源码解析(二十七)

.\lucidrains\metaformer-gpt\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# Package metadata for metaformer-gpt, gathered into one mapping and handed to setuptools
_setup_kwargs = dict(
  name = 'metaformer-gpt',
  packages = find_packages(exclude=[]),   # every package found, nothing excluded
  version = '0.0.5',
  license='MIT',
  description = 'Metaformer - GPT',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/metaformer-gpt',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention-less'
  ],
  # runtime dependencies
  install_requires=[
    'einops>=0.4',
    'scipy',
    'torch>=1.6',
  ],
  # PyPI trove classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

setup(**_setup_kwargs)

.\lucidrains\metaformer-gpt\train.py

# 导入所需的库
import gzip
import random
import numpy as np
import torch
import torch.optim as optim
import tqdm
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from metaformer_gpt import MetaformerGPT
from metaformer_gpt.autoregressive_wrapper import AutoregressiveWrapper

# Training hyperparameters
NUM_BATCHES = 100_000              # optimizer steps
BATCH_SIZE = 4                     # sequences per micro-batch
GRADIENT_ACCUMULATE_EVERY = 4      # micro-batches per optimizer step
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100               # steps between validation-loss prints
GENERATE_EVERY = 500               # steps between sample generations
GENERATE_LENGTH = 512              # tokens to sample each time
SEQ_LEN = 1024                     # context length in tokens

# Helper functions

def cycle(loader):
    """Yield items from `loader` forever, restarting iteration each time it is exhausted."""
    while True:
        yield from loader

def decode_token(token):
    """Map one byte token to a character; values of 32 or below (control chars) become a space."""
    return chr(token if token > 32 else 32)

def decode_tokens(tokens):
    """Decode an iterable of byte tokens into a string."""
    return "".join(decode_token(token) for token in tokens)

# Build a GPT-like decoder over 256 byte-level tokens and wrap it for autoregressive training/sampling
decoder = MetaformerGPT(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    heads = 16,
    dim_head = 32
)

model = AutoregressiveWrapper(decoder, max_seq_len=SEQ_LEN)
model.cuda()

# Prepare enwik8: read up to the first 95M decompressed bytes, use the first 90M for training
# and the remainder for validation.
with gzip.open("./data/enwik8.gz") as file:
    # np.fromstring's binary mode is deprecated; np.frombuffer is its replacement. frombuffer returns a
    # read-only view over the bytes object, so copy it to get a writable array (torch.from_numpy warns
    # when given a non-writable array).
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# Dataset of random fixed-length windows over a 1-D token tensor
class TextSamplerDataset(Dataset):
    """Serve random windows of `seq_len + 1` tokens from `data` (input plus next-token target).

    The requested `index` is ignored: every access draws a fresh random start position.
    Windows are cast to int64 and moved to the GPU.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        # number of non-overlapping seq_len chunks; only sets the epoch length
        return self.data.size(0) // self.seq_len

    def __getitem__(self, index):
        total = self.data.size(0)
        start = torch.randint(0, total - self.seq_len, (1,))
        window = self.data[start : start + self.seq_len + 1]
        return window.long().cuda()

# Random-window datasets for each split, each wrapped in a never-ending batch iterator
train_dataset, val_dataset = (TextSamplerDataset(split, SEQ_LEN) for split in (data_train, data_val))
train_loader, val_loader = (cycle(DataLoader(ds, batch_size=BATCH_SIZE)) for ds in (train_dataset, val_dataset))

# Optimizer. Note: this rebinds the name `optim`, shadowing the `torch.optim` module alias imported above.
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop: each step accumulates gradients over GRADIENT_ACCUMULATE_EVERY micro-batches,
# clips them, takes one optimizer step, and periodically validates and samples text.
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    # only the last micro-batch's loss is reported
    print(f"training loss: {loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        # prime with a random validation window minus its final (target) token
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # previously mixed an f-string with %-style placeholders, printing the raw
        # template and a tuple instead of the prime text and separator
        print(f"{prime} \n\n {'*' * 100}")

        # sampling needs no autograd graph
        with torch.no_grad():
            sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str)

.\lucidrains\metnet3-pytorch\metnet3_pytorch\metnet3_pytorch.py

# 导入必要的库
from pathlib import Path
from functools import partial
from collections import namedtuple
from contextlib import contextmanager

import torch
from torch import nn, Tensor, einsum
import torch.distributed as dist
from torch.autograd import Function
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Sequential

# 导入 einops 库中的函数和层
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

# 导入 beartype 库中的类型注解
from beartype import beartype
from beartype.typing import Tuple, Union, List, Optional, Dict, Literal

import pickle

# Helper functions

def exists(val):
    """Return True when `val` is anything other than None."""
    return val is not None

def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if exists(val):
        return val
    return d

# Pack a single tensor with einops.pack; returns (packed_tensor, packed_shapes)
def pack_one(x, pattern):
    packed, packed_shapes = pack([x], pattern)
    return packed, packed_shapes

# Inverse of pack_one: unpack with the saved shapes and return the single tensor
def unpack_one(x, ps, pattern):
    tensors = unpack(x, ps, pattern)
    return tensors[0]

# Wrap a non-tuple value into a tuple of `length` copies; tuples pass through unchanged
def cast_tuple(val, length = 1):
    if isinstance(val, tuple):
        return val
    return (val,) * length

# Division with the denominator clamped from below at `eps`
def safe_div(num, den, eps = 1e-10):
    """Return num / den with `den` (a tensor) clamped to at least `eps`.

    Note the clamp also replaces zero and negative denominators with `eps`.
    """
    bounded_den = den.clamp(min = eps)
    return num / bounded_den

# L2-normalise along the last dimension
def l2norm(t):
    return F.normalize(t, p = 2., dim = -1)

# Batchnorm class selection for (possibly) distributed training
def MaybeSyncBatchnorm2d(is_distributed = None):
    """Return the batchnorm class to use: nn.SyncBatchNorm when distributed, else nn.BatchNorm2d.

    Args:
        is_distributed: force the choice; when None, auto-detect a multi-process torch.distributed run.

    Returns:
        The class itself (not an instance), to be called with the number of channels.
    """
    if is_distributed is None:
        # On PyTorch builds without distributed support, `dist.is_initialized` may not be defined,
        # so check availability first. The check only runs when auto-detection is actually needed.
        is_distributed = dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm2d

# Temporarily freeze a parameter-free batchnorm layer
@contextmanager
def freeze_batchnorm(bn):
    """Context manager: put `bn` in eval mode with running-stat tracking disabled, then restore.

    Args:
        bn: a batchnorm module with no learnable parameters (e.g. built with affine = False).

    Yields:
        `bn` itself.

    The previous `training` flag and `track_running_stats` setting are restored on exit,
    including when the body of the `with` statement raises.
    """
    # only parameter-free batchnorms are supported
    assert next(bn.parameters(), None) is None

    was_training = bn.training
    was_tracking_stats = bn.track_running_stats
    bn.eval()
    bn.track_running_stats = False

    try:
        yield bn
    finally:
        # restore even on exceptions, otherwise the layer would stay frozen
        bn.train(was_training)
        bn.track_running_stats = was_tracking_stats

# Loss scaling

class LossScaleFunction(Function):
    """Identity in the forward pass; rescales the gradient of a 4-D tensor in the backward pass.

    Backward: for each (batch, channel) the gradient is weighted by the inverse of its L2 norm over
    the last two dims, those weights are normalised to sum to 1 across channels, and the result is
    multiplied by the channel count. Denominators are clamped from below at `eps`.
    """

    @staticmethod
    def forward(ctx, x, eps):
        assert x.ndim == 4
        ctx.eps = eps
        return x

    @staticmethod
    def backward(ctx, grads):
        eps = ctx.eps
        channels = grads.shape[1]

        # inverse L2 norm over the last two dims, one value per (batch, channel)
        spatial_norm = grads.norm(p = 2, keepdim = True, dim = (-1, -2))
        inv_norm = 1. / spatial_norm.clamp(min = eps)

        # normalise the inverse norms so they sum to 1 over channels
        channel_total = inv_norm.sum(keepdim = True, dim = 1)
        channel_weight = inv_norm / channel_total.clamp(min = eps)

        # no gradient flows to `eps`
        return channels * channel_weight * grads, None

class LossScaler(Module):
    """Module wrapper that applies LossScaleFunction with a fixed `eps`."""

    def __init__(self, eps = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        return LossScaleFunction.apply(x, self.eps)

# Center padding / cropping

class CenterPad(Module):
    """Zero-pad the last two dims of `x` up to `target_dim` x `target_dim`, keeping the content centred.

    When a padding amount is odd, the extra row/column goes to the bottom/right.
    """

    def __init__(self, target_dim):
        super().__init__()
        self.target_dim = target_dim

    def forward(self, x):
        target_dim = self.target_dim
        *_, height, width = x.shape
        assert target_dim >= height and target_dim >= width

        height_pad = target_dim - height
        width_pad = target_dim - width
        left_height_pad = height_pad // 2
        left_width_pad = width_pad // 2

        # F.pad lists pads starting from the LAST dim: (left, right) for width, then (top, bottom) for height.
        # The previous order swapped them, producing a non-square output for non-square inputs.
        return F.pad(x, (left_width_pad, width_pad - left_width_pad, left_height_pad, height_pad - left_height_pad), value = 0.)

# Center cropping
class CenterCrop(Module):
    """Crop the last two dims of `x` to a centred `crop_dim` x `crop_dim` window.

    When the excess is odd, one extra row/column is dropped from the bottom/right.
    """

    def __init__(self, crop_dim):
        super().__init__()
        self.crop_dim = crop_dim

    def forward(self, x):
        size = self.crop_dim
        height, width = x.shape[-2:]
        assert (height >= size) and (width >= size)

        top = (height - size) // 2
        left = (width - size) // 2
        return x[..., top:top + size, left:left + size]

# Downsampling and upsampling

# Downsampling uses max pooling; upsampling uses a transposed convolution.
# TODO: figure out the 4x upsampling from 4km to 1km

# 2x downsampling: a factory for nn.MaxPool2d with a 2x2 window and stride 2.
# Calling it with extra keyword arguments forwards them to nn.MaxPool2d (functools.partial semantics).
Downsample2x = partial(nn.MaxPool2d, kernel_size = 2, stride = 2)

# 2x upsampling
def Upsample2x(dim, dim_out = None):
    """Return a 2x2, stride-2 transposed convolution mapping `dim` -> `dim_out` channels.

    `dim_out` defaults to `dim`. The layer doubles both spatial dimensions.
    """
    out_channels = dim if dim_out is None else dim_out
    return nn.ConvTranspose2d(dim, out_channels, kernel_size = 2, stride = 2)
# Basic conv unit used inside ResnetBlock
class Block(Module):
    """3x3 conv -> channel layer norm -> optional (scale, shift) modulation -> ReLU."""

    def __init__(self, dim, dim_out):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = ChanLayerNorm(dim_out)
        self.act = nn.ReLU()

    def forward(self, x, scale_shift = None):
        h = self.norm(self.proj(x))

        # optional conditioning: scale is applied as (1 + scale) so zeros mean "no change"
        if scale_shift is not None:
            scale, shift = scale_shift
            h = h * (scale + 1) + shift

        return self.act(h)

# Residual block of two Blocks, optionally conditioned on a vector
class ResnetBlock(Module):
    """Two Blocks with a residual connection.

    When `cond_dim` is given, `cond` of shape (b, cond_dim) is mapped by a small MLP to a per-channel
    (scale, shift) pair that modulates the first Block. `cond` must be passed iff `cond_dim` was set.
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        cond_dim = None
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.mlp = Sequential(nn.ReLU(), nn.Linear(cond_dim, dim_out * 2)) if exists(cond_dim) else None

        self.block1 = Block(dim, dim_out)
        self.block2 = Block(dim_out, dim_out)
        # 1x1 conv only when the residual needs a channel change
        self.res_conv = nn.Identity() if dim == dim_out else nn.Conv2d(dim, dim_out, 1)

    def forward(self, x, cond = None):
        # conditioning must be supplied exactly when the block was built with cond_dim
        assert not (exists(self.mlp) ^ exists(cond))

        scale_shift = None
        if exists(self.mlp) and exists(cond):
            modulation = rearrange(self.mlp(cond), 'b c -> b c 1 1')
            scale_shift = modulation.chunk(2, dim = 1)

        h = self.block2(self.block1(x, scale_shift = scale_shift))
        return h + self.res_conv(x)

# A stack of `depth` ResnetBlocks sharing the same optional conditioning
class ResnetBlocks(Module):
    """Apply `depth` ResnetBlocks in sequence; the first maps dim_in -> dim, the rest keep `dim`."""

    def __init__(
        self,
        dim,
        *,
        dim_in = None,
        depth = 1,
        cond_dim = None
    ):
        super().__init__()
        first_dim_in = default(dim_in, dim)

        self.blocks = ModuleList([
            ResnetBlock(dim = first_dim_in if layer_index == 0 else dim, dim_out = dim, cond_dim = cond_dim)
            for layer_index in range(depth)
        ])

    def forward(self, x, cond = None):
        out = x
        for resnet_block in self.blocks:
            out = resnet_block(out, cond = cond)
        return out

# Multi-head RMS norm, used for query/key normalisation in attention
class RMSNorm(Module):
    """L2-normalise the last dim, rescale by sqrt(dim), and apply a learned per-head gain.

    `gamma` has shape (heads, 1, dim), so inputs are expected to broadcast against it,
    e.g. (batch, heads, seq, dim).
    """

    def __init__(
        self,
        dim,
        *,
        heads
    ):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return unit * self.scale * self.gamma

# Layer norm over the channel dim of a (b, c, h, w) feature map; used by Block after its conv
class ChanLayerNorm(nn.Module):
    """Normalise each spatial position across channels, then apply per-channel gain `g` and bias `b`.

    Uses the biased variance, clamped from below at `eps` before the reciprocal square root.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        mean = x.mean(dim = 1, keepdim = True)
        var = x.var(dim = 1, unbiased = False, keepdim = True)
        inv_std = var.clamp(min = self.eps).rsqrt()
        return (x - mean) * inv_std * self.g + self.b

# MBConv

class SqueezeExcitation(Module):
    """Channel gating: global-average-pool, bottleneck MLP, sigmoid, then rescale the input channels."""

    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        bottleneck_dim = int(dim * shrinkage_rate)

        self.gate = Sequential(
            Reduce('b c h w -> b c', 'mean'),                  # squeeze: mean over spatial positions
            nn.Linear(dim, bottleneck_dim, bias = False),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, dim, bias = False),
            nn.Sigmoid(),                                       # per-channel weights in (0, 1)
            Rearrange('b c -> b c 1 1')
        )

    def forward(self, x):
        channel_weights = self.gate(x)
        return x * channel_weights

# Residual wrapper for MBConv: x + dropsample(fn(x))
class MBConvResidual(Module):
    """Add a residual connection around `fn`, with per-sample dropping of the branch output."""

    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        return self.dropsample(self.fn(x)) + x

# Per-sample dropout: zero out whole samples of the batch (used on MBConv residual branches)
class Dropsample(Module):
    """During training, zero each sample of `x` with probability `prob` and rescale survivors.

    Survivors are divided by (1 - prob) so the expected value is unchanged. Identity when
    `prob == 0` or in eval mode.
    """

    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        device = x.device

        if self.prob == 0. or (not self.training):
            return x

        # One uniform draw per sample, broadcast over all remaining dims. The previous
        # `torch.FloatTensor((b, 1, 1, 1), device=device)` built a 4-element 1-D tensor (not a
        # (b, 1, 1, 1) mask) and the legacy constructor rejects CUDA devices.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = torch.rand(mask_shape, device = device) > self.prob
        return x * keep_mask / (1 - self.prob)
# MBConv (inverted residual) block: 1x1 expand -> 3x3 depthwise -> squeeze-excitation -> 1x1 project
def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    """Build an MBConv block as a Sequential, wrapped in MBConvResidual when shapes allow.

    `downsample` gives the depthwise conv stride 2. `dropout` is only used when the residual wrapper
    is applied, i.e. when dim_in == dim_out and there is no downsampling.
    """
    hidden_dim = int(expansion_rate * dim_out)
    depthwise_stride = 2 if downsample else 1
    norm_klass = MaybeSyncBatchnorm2d()

    layers = [
        nn.Conv2d(dim_in, hidden_dim, 1),                                                                   # expand
        norm_klass(hidden_dim),
        nn.GELU(),
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = depthwise_stride, padding = 1, groups = hidden_dim),  # depthwise
        norm_klass(hidden_dim),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),                                                                  # project
        norm_klass(dim_out)
    ]
    net = Sequential(*layers)

    has_residual = dim_in == dim_out and not downsample
    return MBConvResidual(net, dropout = dropout) if has_residual else net

# attention related classes

class XCAttention(Module):
    """
    this specific linear attention was proposed in https://arxiv.org/abs/2106.09681 (El-Nouby et al.)

    Cross-covariance attention on a (b, c, h, w) feature map: queries/keys are laid out as
    (b, heads, dim_head, tokens), so the attention matrix is dim_head x dim_head per head rather than
    tokens x tokens. Queries and keys are l2-normalised along the token axis, with a learned per-head
    temperature and a fixed `scale`.

    Args:
        dim: channel dim of the input.
        cond_dim: if given, `cond` of shape (b, cond_dim) modulates the normed input (scale/shift).
        dim_head, heads: attention head size and count.
        scale: fixed multiplier on the similarity logits.
        flash: accepted but unused here.
        dropout: dropout applied to the attention weights.
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        cond_dim: Optional[int] = None,
        dim_head = 32,
        heads = 8,
        scale = 8,
        flash = False,
        dropout = 0.
    ):
        super().__init__()
        dim_inner = dim_head * heads

        self.has_cond = exists(cond_dim)

        self.film = None

        # conditioning network producing (gamma, beta), each shaped (b, 1, dim)
        if self.has_cond:
            self.film = Sequential(
                nn.Linear(cond_dim, dim * 2),
                nn.SiLU(),
                nn.Linear(dim * 2, dim * 2),
                Rearrange('b (r d) -> r b 1 d', r = 2)
            )

        # no affine params when conditioned: gamma/beta from `film` take that role
        self.norm = nn.LayerNorm(dim, elementwise_affine = not self.has_cond)

        # project to q, k, v laid out as (3, b, heads, dim_head, tokens)
        self.to_qkv = Sequential(
            nn.Linear(dim, dim_inner * 3, bias = False),
            Rearrange('b n (qkv h d) -> qkv b h d n', qkv = 3, h = heads)
        )

        self.scale = scale

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attn_dropout = nn.Dropout(dropout)

        self.to_out = Sequential(
            Rearrange('b h d n -> b n (h d)'),
            nn.Linear(dim_inner, dim)
        )

    def forward(
        self,
        x,
        cond: Optional[Tensor] = None
    ):
        # (b, c, h, w) -> (b, h*w, c) token sequence
        x = rearrange(x, 'b c h w -> b h w c')
        x, ps = pack_one(x, 'b * c')

        x = self.norm(x)

        if exists(self.film):
            assert exists(cond)

            gamma, beta = self.film(cond)
            x = x * gamma + beta

        q, k, v = self.to_qkv(x)

        # cosine similarity between channels (normalised across tokens), learned per-head temperature
        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        # attn_dropout was previously constructed but never applied, silently ignoring `dropout`
        attn = self.attn_dropout(attn)

        out = einsum('b h i j, b h j n -> b h i n', attn, v)

        out = self.to_out(out)

        # back to (b, c, h, w)
        out = unpack_one(out, ps, 'b * c')
        return rearrange(out, 'b h w c -> b c h w')

# Windowed multi-head self-attention with qk RMS-norm, optional FiLM-style conditioning,
# and a learned 2-D relative position bias; register tokens are prepended to every window.
class Attention(Module):
    def __init__(
        self,
        dim,
        cond_dim = None,
        heads = 32,
        dim_head = 32,
        dropout = 0.,
        window_size = 8,
        num_registers = 1
    ):
        """
        Args:
            dim: token feature dim.
            cond_dim: if given, `cond` passed to forward modulates the normed input via (gamma, beta).
            heads, dim_head: attention head count and size.
            dropout: dropout on the attention weights and on the output projection.
            window_size: side length of the square window; the sequence in forward must have length
                window_size ** 2 + num_registers for the position bias to line up.
            num_registers: number of register tokens expected at the start of each sequence.
        """
        super().__init__()
        assert num_registers > 0
        # NOTE(review): this checks dim % dim_head, although dim_inner below is dim_head * heads and
        # does not depend on dim
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        dim_inner = dim_head * heads
        self.heads = heads
        # NOTE(review): self.scale is not used in forward; q and k are scaled by RMSNorm instead
        self.scale = dim_head ** -0.5

        self.has_cond = exists(cond_dim)

        self.film = None

        # conditioning network producing (gamma, beta), each shaped (b, 1, dim)
        if self.has_cond:
            self.film = Sequential(
                nn.Linear(cond_dim, dim * 2),
                nn.SiLU(),
                nn.Linear(dim * 2, dim * 2),
                Rearrange('b (r d) -> r b 1 d', r = 2)
            )

        # no affine params when conditioned: gamma/beta from `film` take that role
        self.norm = nn.LayerNorm(dim, elementwise_affine = not self.has_cond)

        # joint projection to queries, keys, values
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)

        # per-head RMS norm of queries and keys
        self.q_norm = RMSNorm(dim_head, heads = heads)
        self.k_norm = RMSNorm(dim_head, heads = heads)

        # softmax over keys, then dropout on the attention weights
        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        # merge heads back to `dim`
        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias = False),
            nn.Dropout(dropout)
        )

        # relative position bias

        # number of distinct (dy, dx) offsets inside a window_size x window_size window
        num_rel_pos_bias = (2 * window_size - 1) ** 2

        # one learned bias per offset and head, plus one extra slot (index num_rel_pos_bias)
        # shared by every pair involving a register token
        self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)

        # 2-D offsets between every pair of window positions, shifted to be non-negative and
        # flattened to a single index in [0, num_rel_pos_bias)
        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        # prepend num_registers rows and columns pointing at the extra register bias slot
        rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(
        self,
        x: Tensor,
        cond: Optional[Tensor] = None
    ):
        """
        Args:
            x: (batch, seq, dim) tokens; seq must equal window_size ** 2 + num_registers.
            cond: conditioning vector, required iff cond_dim was given. The (gamma, beta) it produces are
                (cond_batch, 1, dim), so cond's batch must broadcast against x's batch
                (NOTE(review): i.e. one row per element of x's batch dim, or a single row).

        Returns:
            (batch, seq, dim) tensor.
        """
        # `device` is unused below
        device, h, bias_indices = x.device, self.heads, self.rel_pos_indices

        x = self.norm(x)

        # scale/shift conditioning
        if exists(self.film):
            assert exists(cond)

            gamma, beta = self.film(cond)
            x = x * gamma + beta

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # split heads: (b, n, h * d) -> (b, h, n, d)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # qk normalisation (this replaces the usual 1/sqrt(d) scaling)
        q, k = self.q_norm(q), self.k_norm(k)

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add the learned relative position bias, shared across the batch
        bias = self.rel_pos_bias(bias_indices)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        attn = self.attend(sim)

        # weighted sum of values
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# MaxViT-style backbone: every layer is MBConv -> block (local window) attention -> grid attention,
# with learned register tokens prepended to each attention window
class MaxViT(Module):
    def __init__(
        self,
        *,
        dim,                            # channel dim of the input (and output dim of the first stage)
        depth,                          # int, or tuple of per-stage depths
        cond_dim = 32,                  # dim of the conditioning vector (lead time embedding)
        heads = 32,                     # attention heads
        dim_head = 32,                  # dim per attention head
        window_size = 8,                # side length of the square attention windows
        mbconv_expansion_rate = 4,      # MBConv hidden-dim expansion rate
        mbconv_shrinkage_rate = 0.25,   # squeeze-excitation shrinkage rate inside MBConv
        dropout = 0.1,                  # dropout inside the attention layers
        num_register_tokens = 4         # register tokens prepended to every window
    ):
        super().__init__()
        depth = (depth,) if isinstance(depth, int) else depth
        assert num_register_tokens > 0

        self.cond_dim = cond_dim

        num_stages = len(depth)

        # Stage i produces (2 ** i) * dim channels. The input arrives with `dim` channels, so prepend it to
        # get exactly one (dim_in, dim_out) pair per stage. Without this, zip() below silently dropped the
        # last stage, and an int `depth` (one stage) built no layers at all.
        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (dim, *dims)
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        self.layers = nn.ModuleList([])

        self.window_size = window_size

        # one learned set of register tokens per layer
        self.register_tokens = nn.ParameterList([])

        for (layer_dim_in, layer_dim), layer_depth in zip(dim_pairs, depth):
            for stage_ind in range(layer_depth):
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim

                # the first MBConv of each stage changes channels and downsamples (stride 2)
                conv = MBConv(
                    stage_dim_in,
                    layer_dim,
                    downsample = is_first,
                    expansion_rate = mbconv_expansion_rate,
                    shrinkage_rate = mbconv_shrinkage_rate
                )

                block_attn = Attention(dim = layer_dim, cond_dim = cond_dim, heads = heads, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)

                grid_attn = Attention(dim = layer_dim, cond_dim = cond_dim, heads = heads, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)

                register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))

                self.layers.append(ModuleList([
                    conv,
                    block_attn,
                    grid_attn
                ]))

                self.register_tokens.append(register_tokens)

    def forward(
        self,
        x: Tensor,
        cond: Tensor
    ):
        """
        Args:
            x: (b, dim, H, W) feature map; after each MBConv, H and W must be divisible by window_size.
            cond: (b, cond_dim) conditioning vector.

        Returns:
            The feature map after all layers, in (b, d, h, w) layout.
        """
        assert cond.shape == (x.shape[0], self.cond_dim)

        b, w = x.shape[0], self.window_size

        for (conv, block_attn, grid_attn), register_tokens in zip(self.layers, self.register_tokens):
            x = conv(x)

            # block-like attention: non-overlapping w x w local windows

            x = rearrange(x, 'b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w)
            num_x, num_y = x.shape[1], x.shape[2]

            # Attention sees a batch of (b * num_x * num_y) windows, packed batch-major below.
            # Its FiLM output is (cond_batch, 1, d) and must broadcast against that batch, so repeat the
            # per-sample condition once per window in the same (b x y) order.
            window_cond = repeat(cond, 'b d -> (b x y) d', x = num_x, y = num_y)

            # prepend register tokens to every window
            r = repeat(register_tokens, 'n d -> b x y n d', b = b, x = num_x, y = num_y)
            r, register_batch_ps = pack_one(r, '* n d')

            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps  = pack_one(x, '* n d')
            x, register_ps = pack([r, x], 'b * d')

            x = block_attn(x, cond = window_cond) + x

            r, x = unpack(x, register_ps, 'b * d')

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (x w1) (y w2)')

            r = unpack_one(r, register_batch_ps, '* n d')

            # grid-like attention: each window gathers pixels strided across the whole feature map

            x = rearrange(x, 'b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w)
            num_x, num_y = x.shape[1], x.shape[2]
            window_cond = repeat(cond, 'b d -> (b x y) d', x = num_x, y = num_y)

            # registers from block attention are averaged over windows, then shared by every grid window
            r = reduce(r, 'b x y n d -> b n d', 'mean')
            r = repeat(r, 'b n d -> b x y n d', x = num_x, y = num_y)
            r, register_batch_ps = pack_one(r, '* n d')

            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps  = pack_one(x, '* n d')
            x, register_ps = pack([r, x], 'b * d')

            x = grid_attn(x, cond = window_cond) + x

            r, x = unpack(x, register_ps, 'b * d')

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (w1 x) (w2 y)')

        return x
# Named container for the surface, hrrr and precipitation predictions
Predictions = namedtuple('Predictions', 'surface hrrr precipitation')

# Named container for the surface, hrrr and precipitation loss terms
LossBreakdown = namedtuple('LossBreakdown', 'surface hrrr precipitation')

# 定义一个类 MetNet3,继承自 Module
class MetNet3(Module):
    # 初始化方法
    @beartype
    def __init__(
        self,
        *,
        dim = 512,
        num_lead_times = 722,
        lead_time_embed_dim = 32,
        input_spatial_size = 624,
        attn_depth = 12,
        attn_dim_head = 64,
        attn_heads = 32,
        attn_dropout = 0.1,
        vit_window_size = 8,
        vit_mbconv_expansion_rate = 4,
        vit_mbconv_shrinkage_rate = 0.25,
        input_2496_channels = 2 + 14 + 1 + 2 + 20,
        input_4996_channels = 16 + 1,
        surface_and_hrrr_target_spatial_size = 128,
        precipitation_target_bins: Dict[str, int] = dict(
            mrms_rate = 512,
            mrms_accumulation = 512
        ),
        surface_target_bins: Dict[str, int] = dict(
            omo_temperature = 256,
            omo_dew_point = 256,
            omo_wind_speed = 256,
            omo_wind_component_x = 256,
            omo_wind_component_y = 256,
            omo_wind_direction = 180
        ),
        hrrr_norm_strategy: Union[
            Literal['none'],
            Literal['precalculated'],
            Literal['sync_batchnorm']
        ] = 'none',
        hrrr_channels = 617,
        hrrr_norm_statistics: Optional[Tensor] = None,
        hrrr_loss_weight = 10,
        crop_size_post_16km = 48,
        resnet_block_depth = 2,
    
    # Alternate constructor: rebuild the model from a saved checkpoint
    @classmethod
    def init_and_load_from(cls, path, strict = True):
        """Instantiate `cls` from the config stored in the checkpoint at `path`, then load its weights.

        Args:
            path: checkpoint file written by `save`.
            strict: forwarded to `load_state_dict`.

        Returns:
            The constructed model with weights loaded.

        Raises:
            AssertionError: if the file is missing, has no 'config', or has no 'model_state_dict'.
        """
        path = Path(path)
        assert path.exists()

        # SECURITY NOTE: torch.load and pickle.loads can execute arbitrary code -- only load trusted checkpoints
        pkg = torch.load(str(path), map_location = 'cpu')

        assert 'config' in pkg, 'model configs were not found in this saved checkpoint'

        config = pickle.loads(pkg['config'])
        model = cls(**config)

        # Reuse the package already in memory: reading the file again (as `load` does) is redundant,
        # and this keeps the CPU mapping used above.
        state_dict = pkg.get('model_state_dict')
        assert exists(state_dict)

        model.load_state_dict(state_dict, strict = strict)
        return model

    # Persist weights plus the constructor config
    def save(self, path, overwrite = True):
        """Write {'model_state_dict', 'config'} to `path` with torch.save.

        `config` is taken from `self._configs`. When `overwrite` is False, an existing file
        triggers an AssertionError.
        """
        path = Path(path)
        if not overwrite:
            assert not path.exists(), f'{str(path)} already exists'

        torch.save(dict(model_state_dict = self.state_dict(), config = self._configs), str(path))

    # Load weights from a checkpoint written by `save`
    def load(self, path, strict = True):
        """Load 'model_state_dict' from the checkpoint at `path` into this module.

        Args:
            path: checkpoint file.
            strict: forwarded to `load_state_dict`.

        Raises:
            AssertionError: if the file is missing or has no 'model_state_dict'.
        """
        path = Path(path)
        assert path.exists()

        # Map to CPU so checkpoints saved from GPU also load on CPU-only machines; load_state_dict then
        # copies the values into this module's existing parameters and buffers.
        # SECURITY NOTE: torch.load can unpickle arbitrary objects -- only load trusted checkpoints.
        pkg = torch.load(str(path), map_location = 'cpu')
        state_dict = pkg.get('model_state_dict')

        assert exists(state_dict)

        self.load_state_dict(state_dict, strict = strict)

    # 前向传播方法
    @beartype
    def forward(
        self,
        *,
        lead_times,
        hrrr_input_2496,
        hrrr_stale_state,
        input_2496,
        input_4996,
        surface_targets: Optional[Dict[str, Tensor]] = None,
        precipitation_targets: Optional[Dict[str, Tensor]] = None,
        hrrr_target: Optional[Tensor] = None,

.\lucidrains\metnet3-pytorch\metnet3_pytorch\__init__.py

# 从 metnet3_pytorch 包中导入 MetNet3 类
from metnet3_pytorch.metnet3_pytorch import (
    MetNet3
)

MetNet-3 - Pytorch

Implementation of MetNet-3, the state-of-the-art (SOTA) neural weather model from Google DeepMind, in PyTorch.

The model architecture is fairly unremarkable: it is basically a U-Net with a specific, well-performing vision transformer. The most interesting part of the paper may end up being the loss scaling described in section 4.3.2.

Appreciation

Install

$ pip install metnet3-pytorch

Usage

import torch
from metnet3_pytorch import MetNet3

# construct the model
metnet3 = MetNet3(
    dim = 512,
    num_lead_times = 722,
    lead_time_embed_dim = 32,
    input_spatial_size = 624,
    attn_dim_head = 8,
    hrrr_channels = 617,
    input_2496_channels = 2 + 14 + 1 + 2 + 20,
    input_4996_channels = 16 + 1,
    precipitation_target_bins = dict(
        mrms_rate = 512,
        mrms_accumulation = 512,
    ),
    surface_target_bins = dict(
        omo_temperature = 256,
        omo_dew_point = 256,
        omo_wind_speed = 256,
        omo_wind_component_x = 256,
        omo_wind_component_y = 256,
        omo_wind_direction = 180
    ),
    hrrr_loss_weight = 10,
    hrrr_norm_strategy = 'sync_batchnorm',  # uses a sync batchnorm to normalize the input hrrr and target, so the per-channel mean and variance of the hrrr dataset need not be precalculated
    hrrr_norm_statistics = None             # alternatively set `hrrr_norm_strategy = "precalculated"` and pass the mean and variance, shaped `(2, 617)`, through this keyword argument
)

# inputs (batch of 2; spatial size 624 matches input_spatial_size above)

lead_times = torch.randint(0, 722, (2,))
hrrr_input_2496 = torch.randn((2, 617, 624, 624))
hrrr_stale_state = torch.randn((2, 1, 624, 624))
input_2496 = torch.randn((2, 39, 624, 624))
input_4996 = torch.randn((2, 17, 624, 624))

# targets: integer indices in [0, num_bins) for the binned heads, real values for hrrr

precipitation_targets = dict(
    mrms_rate = torch.randint(0, 512, (2, 512, 512)),
    mrms_accumulation = torch.randint(0, 512, (2, 512, 512)),
)

surface_targets = dict(
    omo_temperature = torch.randint(0, 256, (2, 128, 128)),
    omo_dew_point = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_speed = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_component_x = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_component_y = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_direction = torch.randint(0, 180, (2, 128, 128))
)

hrrr_target = torch.randn(2, 617, 128, 128)

# training step: with targets supplied, the model returns the total loss and a per-head breakdown
total_loss, loss_breakdown = metnet3(
    lead_times = lead_times,
    hrrr_input_2496 = hrrr_input_2496,
    hrrr_stale_state = hrrr_stale_state,
    input_2496 = input_2496,
    input_4996 = input_4996,
    precipitation_targets = precipitation_targets,
    surface_targets = surface_targets,
    hrrr_target = hrrr_target
)

total_loss.backward()

# after training as above, call the model without targets to get predictions

metnet3.eval()

surface_preds, hrrr_pred, precipitation_preds = metnet3(
    lead_times = lead_times,
    hrrr_input_2496 = hrrr_input_2496,
    hrrr_stale_state = hrrr_stale_state,
    input_2496 = input_2496,
    input_4996 = input_4996,
)


# returned types: surface_preds: Dict[str, Tensor], hrrr_pred: Tensor, precipitation_preds: Dict[str, Tensor]

Todo

Citations

@article{Andrychowicz2023DeepLF,
    title   = {Deep Learning for Day Forecasts from Sparse Observations},
    author  = {Marcin Andrychowicz and Lasse Espeholt and Di Li and Samier Merchant and Alexander Merose and Fred Zyda and Shreya Agrawal and Nal Kalchbrenner},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.06079},
    url     = {https://api.semanticscholar.org/CorpusID:259129311}
}
@inproceedings{ElNouby2021XCiTCI,
    title   = {XCiT: Cross-Covariance Image Transformers},
    author  = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
    booktitle = {Neural Information Processing Systems},
    year    = {2021},
    url     = {https://api.semanticscholar.org/CorpusID:235458262}
}

.\lucidrains\metnet3-pytorch\setup.py

# 导入设置安装包和查找包的函数
from setuptools import setup, find_packages

# 设置安装包的信息
setup(
  # --- identity ---
  name = 'metnet3-pytorch',
  version = '0.0.12',
  description = 'MetNet 3 - Pytorch',
  url = 'https://github.com/lucidrains/metnet3-pytorch',
  license='MIT',
  # --- authorship ---
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  # --- contents: every package found under the project root ---
  packages = find_packages(exclude=[]),
  long_description_content_type = 'text/markdown',
  # --- runtime dependencies ---
  install_requires=[
    'beartype',
    'einops>=0.7.0',
    'torch>=2.0',
  ],
  # --- index / search metadata ---
  keywords = [
    'artificial intelligence',
    'deep learning',
    'vision transformers',
    'unet',
    'weather forecasting'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\mirasol-pytorch\mirasol_pytorch\distributed.py

# 导入必要的库
from functools import cache

import torch
from torch.autograd import Function
import torch.distributed as distributed

from einops import rearrange

# 辅助函数

# 使用缓存装饰器缓存结果,判断当前是否处于分布式环境
@cache
def get_is_distributed():
    """Return True when a torch.distributed process group is initialized with more than one rank.

    NOTE: the result is memoized by functools.cache on the first call, so calling this before
    the process group is initialized pins the answer to False for the life of the process.
    """
    if not distributed.is_initialized():
        return False
    return distributed.get_world_size() > 1

# 在指定维度上对张量进行填充,使其达到指定长度
def pad_dim_to(t, length, dim = 0):
    """Right-pad tensor `t` with zeros along `dim` so that `t.shape[dim] == length`.

    Args:
        t: tensor to pad.
        length: target size along `dim` (expected to be >= t.shape[dim]).
        dim: dimension to pad; negative values count from the end.

    Returns:
        The padded tensor (a new tensor even when no padding is needed).
    """
    # this module never imports torch.nn.functional at top level, so bring F into scope here
    import torch.nn.functional as F

    pad_length = length - t.shape[dim]
    # F.pad lists (left, right) pairs starting from the LAST dim, so every dim after `dim` gets (0, 0)
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

# 分布式辅助函数

# 在所有进程中收集具有可变维度的张量,根据给定的维度和大小
def all_gather_variable_dim(t, dim = 0, sizes = None):
    """Gather `t` from every rank and concatenate along `dim`, allowing ranks to differ in length there.

    Args:
        t: local tensor. All dimensions other than `dim` are presumably identical across ranks
            (required by distributed.all_gather) — TODO confirm with callers.
        dim: dimension along which ranks may differ and along which results are concatenated.
        sizes: optional 1-D tensor holding every rank's length along `dim`. When None it is
            exchanged with an extra all_gather.

    Returns:
        (gathered, sizes): every rank's un-padded tensor concatenated in rank order along `dim`,
        and the per-rank sizes tensor.
    """
    device, world_size = t.device, distributed.get_world_size()

    # `exists` is not defined in this module, so test for None directly
    if sizes is None:
        size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
        sizes = [torch.empty_like(size, device = device, dtype = torch.long) for _ in range(world_size)]
        distributed.all_gather(sizes, size)
        sizes = torch.stack(sizes)

    # all_gather needs identically shaped tensors, so pad every rank to the longest length first
    max_size = sizes.amax().item()
    padded_t = pad_dim_to(t, max_size, dim = dim)

    gathered_tensors = [torch.empty(padded_t.shape, device = device, dtype = padded_t.dtype) for _ in range(world_size)]
    distributed.all_gather(gathered_tensors, padded_t)

    gathered_tensor = torch.cat(gathered_tensors, dim = dim)

    # (world_size, max_size) validity mask, flattened to match the concatenated layout;
    # keep only the positions that held real (non-padding) entries
    seq = torch.arange(max_size, device = device)
    mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
    mask = rearrange(mask, 'i j -> (i j)')
    indices = torch.arange(mask.shape[-1], device = device)[mask]

    gathered_tensor = gathered_tensor.index_select(dim, indices)

    return gathered_tensor, sizes

# 自定义的 Function 类,用于实现 all_gather 操作
class AllGather(Function):
    """Differentiable all-gather for tensors whose length along `dim` may differ between ranks.

    forward(x, dim, sizes) -> (gathered, sizes_per_rank), delegating to all_gather_variable_dim.
    backward hands each rank the slice of the output gradient that lines up with its own input.
    """

    @staticmethod
    def forward(ctx, x, dim, sizes):
        assert get_is_distributed()
        gathered, sizes_per_rank = all_gather_variable_dim(x, dim = dim, sizes = sizes)

        # saved so backward can cut the concatenated gradient back into per-rank chunks
        ctx.batch_sizes = sizes_per_rank.tolist()
        ctx.dim = dim
        return gathered, sizes_per_rank

    @staticmethod
    def backward(ctx, grads, _):
        # NOTE(review): only the local slice of `grads` is returned; gradients that other ranks
        # computed w.r.t. this rank's rows are not all-reduced in — confirm this is intended.
        per_rank_grads = grads.split(ctx.batch_sizes, dim = ctx.dim)
        own_rank = distributed.get_rank()
        return per_rank_grads[own_rank], None, None

# functional entry point: all_gather(x, dim, sizes) -> (gathered, sizes_per_rank)
all_gather = AllGather.apply

.\lucidrains\mirasol-pytorch\mirasol_pytorch\mirasol_pytorch.py

# 导入所需的模块和函数
import operator
from functools import partial
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn, einsum
from torch.nn import Module, ModuleList

# 导入 beartype 模块和相关类型
from beartype import beartype
from beartype.typing import Optional, Union, Tuple, Dict, Any

# 导入 einops 相关函数和层
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

# 导入 x_transformers 相关模块和类
from x_transformers import (
    Encoder,
    Decoder,
    TransformerWrapper,
    AutoregressiveWrapper
)

# 导入 x_transformers 中的 RotaryEmbedding 类
from x_transformers.x_transformers import RotaryEmbedding

# 导入 mirasol_pytorch 中的分布式函数
from mirasol_pytorch.distributed import all_gather, get_is_distributed

# 辅助函数

# 判断变量是否存在
def exists(v):
    """Return True when `v` is not None (falsy values such as 0 or '' still count as existing)."""
    return not (v is None)

# 返回参数中第一个存在的值
def default(*args):
    """Return the first argument that is not None; None when all are None or no arguments are given."""
    return next(filter(exists, args), None)

# 判断一个数是否可以被另一个数整除
def divisible_by(num, den):
    """Return True if `den` divides `num` exactly."""
    remainder = num % den
    return remainder == 0

# 判断参数中只有一个为 True
def only_one_true(*bools):
    """Return True when exactly one argument is truthy.

    Arguments are counted by truthiness, so None, 0 or arbitrary objects behave as their
    bool() value. The previous int()-based count raised on None and counted ints > 1 more than once.
    """
    return sum(map(bool, bools)) == 1

# 将张量打包成指定模式
def pack_one(t, pattern):
    """Pack the single tensor `t` with einops.pack; returns (packed_tensor, packed_shapes)."""
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes

# 将打包的张量解包成指定模式
def unpack_one(t, ps, pattern):
    """Undo pack_one: unpack `t` with the packed shapes `ps` and return the first tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]

# 张量操作函数

# 计算张量的 L2 范数
def l2norm(t):
    """Scale `t` to unit L2 norm along its last dimension."""
    return F.normalize(t, p = 2, dim = -1)

# 计算张量之间的余弦相似度损失
def cosine_sim_loss(x, y):
    """Return 1 minus the mean cosine similarity between matching vectors of `x` and `y`.

    Both inputs are combined with the einsum pattern 'b n d, b n d -> b n', so they must be
    3-D with matching shapes; vectors are compared along the last dimension.
    """
    x_unit = l2norm(x)
    y_unit = l2norm(y)
    cosine = einsum('b n d, b n d -> b n', x_unit, y_unit)
    return 1. - cosine.mean()

# 生成位置编码的正弦和余弦值
def posemb_sincos_nd(
    t: Tensor,
    temperature: int = 10000,
    dtype = torch.float32
):
    """Fixed n-dimensional sin / cos positional embedding for a (batch, *dims, feat_dim) tensor.

    Only the shape and device of `t` are used. Each spatial coordinate is multiplied by
    `feat_dim // (2 * len(dims))` frequencies `1 / temperature ** linspace(0, 1)` and the sines
    and cosines are concatenated.

    Returns:
        Tensor of shape (prod(dims), feat_dim) and dtype `dtype`. Positions are flattened in
        row-major ('ij') order. Features are laid out as [sin axis 0, ..., sin axis N-1,
        cos axis 0, ..., cos axis N-1], zero padded on the right up to `feat_dim`.
    """
    _, *dims, feat_dim, device = *t.shape, t.device

    arange = partial(torch.arange, device = device)

    num_dims = len(dims)
    two_times_num_dims = 2 * num_dims # 2 because sin and cos of same position

    # frequencies per (axis, sin / cos) band; features that do not divide evenly are zero padded.
    # (the padding must be computed against 2 * num_dims, otherwise the output can be narrower than feat_dim)
    freqs_per_band = feat_dim // two_times_num_dims
    feat_dim_remainder = feat_dim - freqs_per_band * two_times_num_dims

    # max(..., 1) avoids a 0 / 0 when only a single frequency fits
    omega = arange(freqs_per_band) / max(freqs_per_band - 1, 1)
    omega = 1.0 / (temperature ** omega)

    meshed = torch.meshgrid(*map(arange, dims), indexing = 'ij')
    coords = torch.stack([m.flatten() for m in meshed], dim = -1)       # (seq_len, num_dims)

    # keep the position axis first throughout, so every row really belongs to one position
    angles = coords[..., None] * omega                                  # (seq_len, num_dims, freqs_per_band)
    pos = torch.cat((angles.sin(), angles.cos()), dim = 1).flatten(1)   # (seq_len, 2 * num_dims * freqs_per_band)
    pos = pos.type(dtype)

    return F.pad(pos, (0, feat_dim_remainder))

# 生成具有一定概率的掩码张量
def mask_with_prob(
    shape: Tuple[int, ...],
    prob: float,
    device = None
) -> Tensor:
    """Return a random boolean keep-mask of `shape`.

    Along the last dimension exactly int(prob * shape[-1]) entries are False (masked) and the
    rest are True, with the masked positions chosen at random independently per row.
    """
    num_masked = int(prob * shape[-1])
    # argsort of i.i.d. gaussian noise yields a random permutation of 0 .. length - 1 per row
    random_ranks = torch.randn(shape, device = device).argsort(dim = -1)
    return random_ranks >= num_masked

# 主类

# 定义 Losses 命名元组,包含不同类型的损失
# named container for the individual loss terms
Losses = namedtuple('Losses', 'text_autoregressive av_autoregressive av_recon text_av_sim_reg')

# Mirasol 类,继承自 Module 类
class Mirasol(Module):
    """Mirasol multimodal autoregressive model over audio, video and text.

    NOTE(review): this excerpt is truncated. The end of the __init__ signature, the whole
    __init__ body and the body of forward are missing, so the class does not parse as shown.
    """

    @beartype
    # constructor (only the keyword-only signature is visible in this excerpt)
    def __init__(
        self,
        *,
        dim,
        num_text_tokens,
        video_image_size,
        video_frames_per_timechunk,
        audio_freq_dim,
        audio_time_dim_per_timechunk,
        audio_patch_size: Tuple[int, int],                          # audio patch size (frequency, time)
        video_patch_size: Tuple[int, int],                          # video patch size (spatial, time)
        video_recon_patch_size: Optional[Tuple[int, int]] = None,   # video reconstruction patch size (spatial, time) - smaller video used for the reconstruction loss
        video_recon_interpolate_mode = 'nearest',
        audio_encoder: Union[Module, Dict[str, Any]],
        video_encoder: Union[Module, Dict[str, Any]],
        num_audio_video_register_tokens = 8,                        # number of audio/video register tokens https://arxiv.org/abs/2309.16588
        audio_video_mask_prob = 0.15,                         # the paper uses masked tokens, but following the Berkeley forgetful-causal-mask paper a simple key/value mask should suffice
        text_max_seq_len = 2048,
        text_forgetful_causal_mask_prob = 0.1,                      # https://arxiv.org/abs/2210.13432
        encoder_depth = 6,
        decoder_depth = 6,
        combiner_depth = 2,
        combiner_output_num_tokens = 3,
        video_channels = 3,
        attn_dim_head = 64,
        attn_heads = 8,
        flash_attn = True,
        attn_layers_kwargs: dict = dict(),
        combiner: Optional[Module] = None,
        combiner_kwargs: dict = dict(),
        autoregressive_wrapper_kwargs: dict = dict(
            pad_value = 0,
            ignore_index = -100
        ),
        av_autoregressive_loss_weight = 1.,
        av_reconstruction_loss_weight = 1.,
        sim_reg_loss_weight = 0.
        # NOTE(review): the closing `):` and the entire __init__ body are missing from this excerpt
    
    # device of the module, taken from its first parameter
    @property
    def device(self):
        return next(self.parameters()).device

    # sample text autoregressively
    @torch.no_grad()
    def generate(
        self,
        *,
        seq_len: int,
        prompt: Optional[Tensor] = None,
        **kwargs
    ):
        """Generate `seq_len` text tokens, optionally continuing `prompt`.

        Switches to eval mode, then delegates to
        forward(text = prompt, generate = True, generate_seq_len = seq_len, **kwargs).
        Extra kwargs, such as audio / video in the README usage, are passed through.
        Finally it restores the previous training mode. `kwargs` must not contain `generate`
        or `generate_seq_len` (asserted).
        NOTE(review): the training mode is not restored if forward raises.
        """
        was_training = self.training
        self.eval()

        assert 'generate' not in kwargs
        assert 'generate_seq_len' not in kwargs

        # forward does the actual sampling when generate = True
        out = self.forward(
            text = prompt,
            generate = True,
            generate_seq_len = seq_len,
            **kwargs
        )

        self.train(was_training)
        return out

    # forward pass (only part of the signature is visible in this excerpt)
    @beartype
    def forward(
        self,
        *,
        audio: Optional[Tensor] = None,
        video: Optional[Tensor] = None,
        encoded_audio: Optional[Tensor] = None,
        encoded_video: Optional[Tensor] = None,
        text: Optional[Tensor] = None,
        text_mask: Optional[Tensor] = None,
        return_loss = True,
        return_loss_breakdown = False,
        generate = False,
        generate_seq_len = None
        # NOTE(review): the closing `):` and the forward body are missing from this excerpt

.\lucidrains\mirasol-pytorch\mirasol_pytorch\__init__.py

# 从 mirasol_pytorch 包中导入 Mirasol 类
from mirasol_pytorch.mirasol_pytorch import Mirasol

🌻 Mirasol - Pytorch

Implementation of Mirasol, a SOTA multimodal autoregressive model from Google DeepMind, in PyTorch.

This implementation includes only the Transformer Combiner and omits the other combiner variants.

Appreciation

Install

$ pip install mirasol-pytorch

Usage

import torch
from mirasol_pytorch import Mirasol

model = Mirasol(
    dim = 512,
    num_text_tokens = 256,
    video_image_size = 128,
    video_frames_per_timechunk = 2,
    audio_freq_dim = 64,
    audio_time_dim_per_timechunk = 32,
    audio_patch_size = (32, 16),
    video_patch_size = (64, 2),
    audio_encoder = dict(
        dim = 512,
        depth = 2
    ),
    video_encoder = dict(
        dim = 512,
        depth = 2
    )
)

audio = torch.randn(1, 64, 1024)
video = torch.randn(1, 3, 12, 128, 128)

text = torch.randint(0, 256, (1, 1024))

loss = model(
    audio = audio,
    video = video,
    text = text
)

loss.backward()

# after much training

sampled_text = model.generate(
    audio = audio,
    video = video,
    seq_len = 512
)

Todo

Citations

@article{Piergiovanni2023Mirasol3BAM,
    title   = {Mirasol3B: A Multimodal Autoregressive model for time-aligned and contextual modalities},
    author  = {A. J. Piergiovanni and Isaac Noble and Dahun Kim and Michael S. Ryoo and Victor Gomes and Anelia Angelova},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2311.05698},
    url     = {https://api.semanticscholar.org/CorpusID:265129010}
}
@inproceedings{Liu2022TowardsBF,
    title   = {Towards Better Few-Shot and Finetuning Performance with Forgetful Causal Language Models},
    author  = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel},
    year    = {2022},
    url     = {https://api.semanticscholar.org/CorpusID:256416540}
}
@article{Darcet2023VisionTN,
    title   = {Vision Transformers Need Registers},
    author  = {Timoth{\'e}e Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2309.16588},
    url     = {https://api.semanticscholar.org/CorpusID:263134283}
}
@article{Bondarenko2023QuantizableTR,
    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.12929},
    url     = {https://api.semanticscholar.org/CorpusID:259224568}
}
@misc{shi2023enhance,
    title   = {Enhance audio generation controllability through representation similarity regularization}, 
    author  = {Yangyang Shi and Gael Le Lan and Varun Nagaraja and Zhaoheng Ni and Xinhao Mei and Ernie Chang and Forrest Iandola and Yang Liu and Vikas Chandra},
    year    = {2023},
    eprint  = {2309.08773},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}

.\lucidrains\mirasol-pytorch\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的信息
setup(
  # package name
  name = 'mirasol-pytorch',
  # include every package found, excluding none
  packages = find_packages(exclude=[]),
  # version
  version = '0.0.16',
  # license
  license='MIT',
  # short description
  description = 'Mirasol - Pytorch',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # content type of the long description
  long_description_content_type = 'text/markdown',
  # project homepage
  url = 'https://github.com/lucidrains/mirasol-pytorch',
  # search keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'multimodality'
  ],
  # mirasol_pytorch/distributed.py uses functools.cache, which only exists on Python 3.9+,
  # so the previously advertised Python 3.6 support was incorrect
  python_requires = '>=3.9',
  # runtime dependencies
  install_requires=[
    'beartype',
    'einops>=0.7.0',
    'x-transformers>=1.25.10',
    'torch>=2.0'
  ],
  # trove classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.9',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter Prize page: http://prize.hutter1.net/

.\lucidrains\mixture-of-attention\mixture_of_attention\attend.py

# 导入必要的库
from collections import namedtuple
from functools import wraps
from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# 定义一个命名元组EfficientAttentionConfig,用于存储配置信息
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', 'enable_flash enable_math enable_mem_efficient')  # kwargs for torch.backends.cuda.sdp_kernel

# 定义辅助函数exists,用于检查变量是否存在
def exists(val):
    """Return False for None and True for every other value."""
    if val is None:
        return False
    return True

# 定义装饰器once,确保函数只被调用一次
def once(fn):
    """Wrap `fn` so that only its first call runs; every later call is a no-op returning None.

    All positional and keyword arguments are forwarded to `fn`. The previous wrapper accepted
    exactly one positional argument; that form is still supported.
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        called = True
        return fn(*args, **kwargs)

    return inner

# print only the first message ever passed through print_once
print_once = once(print)

# 主要类Attend
class Attend(nn.Module):
    """Scaled dot-product attention over (batch, heads, seq, dim_head) tensors.

    Uses F.scaled_dot_product_attention when `flash` is set, otherwise an explicit
    einsum / softmax path. Supports an optional boolean key mask (True = attend) and causal
    masking. The causal mask is aligned to the bottom-right, so extra leading keys (e.g. a
    prepended null key / value) stay visible to every query.
    """

    def __init__(
        self,
        dropout = 0.,
        causal = False,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # efficient attention kernel selection for cpu and cuda

        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    def get_mask(self, i, j, device):
        """Return an (i, j) bool mask that is True where a query must NOT attend (future keys).

        Bottom-right aligned: query r may attend to keys 0 .. r + (j - i).
        """
        return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)

    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        causal = self.causal

        # F.scaled_dot_product_attention rejects attn_mask together with is_causal, and its
        # is_causal mask is top-left aligned, which hides a query's own key when there are more
        # keys than queries. In those cases fold the bottom-right causal mask into attn_mask.
        # NOTE: an explicit mask rules out the flash kernel, so on the flash-only (A100) config
        # such calls may report that no kernel is available.
        if causal and (exists(mask) or k_len != q_len):
            causal_mask = self.get_mask(q_len, k_len, device)
            mask = ~causal_mask if not exists(mask) else (mask & ~causal_mask)
            causal = False

        config = self.cuda_config if is_cuda else self.cpu_config

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        return out

    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension

        mask: optional bool key mask, (b, j) or already broadcastable (b, h, i, j); True = attend
        """

        i, j, device = q.shape[-2], k.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        if self.flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # key padding mask (accept the same 2-D or 4-D masks as the flash path)

        if exists(mask):
            if mask.ndim != 4:
                mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask - get_mask needs both the query and key lengths, which differ when
        # null key / values are prepended

        if self.causal:
            causal_mask = self.get_mask(i, j, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out

.\lucidrains\mixture-of-attention\mixture_of_attention\autoregressive_wrapper.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch 库中导入 nn.functional 模块,并重命名为 F
import torch.nn.functional as F

# 从 einops 库中导入 rearrange 函数
from einops import rearrange

# 辅助函数

# 判断变量是否存在的函数
def exists(val):
    """Tell whether `val` holds a value, i.e. is not None."""
    is_missing = val is None
    return not is_missing

# 评估装饰器函数
def eval_decorator(fn):
    """Decorate `fn(model, ...)` so it runs with `model` in eval mode.

    The model's previous training flag is restored afterwards, even if `fn` raises, and the
    wrapper keeps `fn`'s name and docstring via functools.wraps.
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # restore the caller's mode even when fn fails
            model.train(was_training)

    return inner

# top k 过滤

# 根据阈值过滤 logits 中的 top k 值
def top_k(logits, thres = 0.9):
    """Keep the top-k logits along the last dimension and set all others to the dtype's most negative value.

    k = int((1 - thres) * vocab_size), clamped to at least 1 so that a high `thres` or a small
    vocabulary never masks out every candidate (which made sampling uniform over the vocabulary).
    Works for logits of any rank.
    """
    k = max(int((1 - thres) * logits.shape[-1]), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, -torch.finfo(logits.dtype).max)
    # scatter along the same (last) axis topk worked on
    probs.scatter_(-1, ind, val)
    return probs

# 自回归包装器类
class AutoregressiveWrapper(nn.Module):
    """Next-token training and sampling wrapper around a token-level network.

    `net` must expose a `seq_len` attribute, used as the context window while sampling, and
    map token ids to logits whose last axis indexes the vocabulary. The logits are presumably
    (batch, seq, num_tokens) — TODO confirm.
    """

    def __init__(
        self,
        net,
        pad_value = 0
    ):
        super().__init__()
        self.seq_len = net.seq_len
        self.pad_value = pad_value  # stored, but not used by this wrapper
        self.net = net

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prompt,
        seq_len,
        temperature=1.0,
        filter_thres=0.9,
        **kwargs
    ):
        """Sample `seq_len` new tokens after a (batch, prompt_len) `prompt`; return only the new tokens.

        At each step the last `self.seq_len` tokens are fed to the net. The final position's
        logits are top-k filtered (top_k with `filter_thres`), divided by `temperature` and
        sampled from their softmax.
        """
        _, prompt_len = prompt.shape

        tokens = prompt
        for _ in range(seq_len):
            window = tokens[:, -self.seq_len:]
            next_logits = self.net(window, **kwargs)[:, -1]

            probs = F.softmax(top_k(next_logits, thres = filter_thres) / temperature, dim = -1)
            next_token = torch.multinomial(probs, 1)

            tokens = torch.cat((tokens, next_token), dim = -1)

        return tokens[:, prompt_len:]

    def forward(self, x, **kwargs):
        """Return the mean next-token cross-entropy of the (batch, seq) token tensor `x`."""
        inputs, targets = x[:, :-1], x[:, 1:]
        logits = self.net(inputs, **kwargs)
        # swap the last two axes; assuming the net returns (batch, seq, num_tokens), this puts
        # the vocabulary at dim 1 as F.cross_entropy expects
        logits = rearrange(logits, "b c n -> b n c")
        return F.cross_entropy(logits, targets)

.\lucidrains\mixture-of-attention\mixture_of_attention\mixture_of_attention.py

# 导入数学库
import math

# 导入 PyTorch 库
import torch
import torch.nn.functional as F
from torch import Tensor, nn, einsum

# 导入类型提示
from typing import Tuple, Optional

# 导入 einops 库中的函数
from einops import rearrange, repeat, reduce, pack, unpack

# 导入自定义模块
from mixture_of_attention.attend import Attend
from mixture_of_attention.rotary_emb import apply_rotary_pos_emb

from local_attention import LocalMHA

from colt5_attention import CoordinateDescentRouter

# 辅助函数

# 判断变量是否存在
def exists(val):
    """Return True if `val` is anything but None."""
    return not (val is None)

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if exists(val):
        return val
    return d

# 将张量打包成指定模式的形状
def pack_one(t, pattern):
    """einops.pack for a single tensor; returns the (packed, packed_shapes) pair as-is."""
    tensors = [t]
    return pack(tensors, pattern)

# 将打包后的张量解包成原始形状
def unpack_one(t, ps, pattern):
    """Inverse of pack_one: return the first tensor produced by einops.unpack."""
    unpacked = unpack(t, ps, pattern)
    only = unpacked[0]
    return only

# 将张量填充到指定的倍数
def pad_to_multiple(tensor, multiple, dim = -1, value = 0):
    """Right-pad `tensor` along `dim` with `value` until its length there is a multiple of `multiple`.

    Args:
        tensor: tensor to pad.
        multiple: positive integer the padded length must be divisible by.
        dim: dimension to pad; negative or non-negative.
        value: fill value for the padding.

    Returns:
        (padded_tensor, seq_len), where seq_len is the ORIGINAL length along `dim`. The input
        tensor itself is returned when no padding is needed.
    """
    seq_len = tensor.shape[dim]
    remainder = -seq_len % multiple  # amount needed to reach the next multiple (exact integer math)
    if remainder == 0:
        return tensor, seq_len

    # F.pad lists (left, right) pairs starting from the LAST dim, so every dim after `dim` gets (0, 0).
    # Counting those dims via tensor.ndim makes non-negative `dim` work too; the old
    # `(-1 - dim)` formula silently padded the last axis instead.
    dims_after = (-dim - 1) if dim < 0 else (tensor.ndim - dim - 1)
    pad_offset = (0, 0) * dims_after
    padded_tensor = F.pad(tensor, (*pad_offset, 0, remainder), value = value)
    return padded_tensor, seq_len

# 归一化

# RMS 归一化模块
class RMSNorm(nn.Module):
    """Normalization along dim -2 with a learned gain per (group, channel).

    forward computes F.normalize(x, dim = -2) * sqrt(dim) * gamma. gamma has shape
    (groups, dim, 1), so inputs are presumably laid out as (..., groups, dim, seq).
    """

    def __init__(self, dim, groups = 1):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(groups, dim, 1))

    def forward(self, x):
        unit_x = F.normalize(x, dim = -2)
        scaled = unit_x * self.scale
        return scaled * self.gamma

# 注意力机制

# 注意力模块
class Attention(nn.Module):
    """Multi-head attention computed for several independent experts (groups) at once.

    Each group has its own query / key-value / output projections, implemented as grouped
    1x1 Conv1d layers over a (batch, groups * features, seq) layout. A learned null key / value
    is prepended for every group and head, and the key mask is padded with True for it, so no
    query row is ever fully masked.
    """
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        dim_context = None,
        heads = 8,
        causal = False,
        groups = 1, # number of experts
        dropout = 0.,
        flash = False,
        prenorm = False
    ):
        super().__init__()
        self.heads = heads
        self.groups = groups

        dim_inner = dim_head * heads
        dim_context = default(dim_context, dim)

        # per-group RMSNorm on queries / context only when prenorm is requested
        self.norm = RMSNorm(dim, groups = groups) if prenorm else nn.Identity()
        self.context_norm = RMSNorm(dim_context, groups = groups) if prenorm else nn.Identity()

        self.attend = Attend(
            dropout = dropout,
            causal = causal,
            flash = flash
        )

        # null key / value, so that a query row is never entirely masked out

        self.null_kv = nn.Parameter(torch.randn(2, groups, heads, 1, dim_head))

        # grouped 1x1 convolutions process all experts in parallel

        self.to_q = nn.Conv1d(dim * groups, dim_inner * groups, 1, bias = False, groups = groups)
        self.to_kv = nn.Conv1d(dim_context * groups, dim_inner * 2 * groups, 1, bias = False, groups = groups)
        self.to_out = nn.Conv1d(dim_inner * groups, dim * groups, 1, bias = False, groups = groups)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        queries_scale = None,
        keys_scale = None,
        values_scale = None,
        output_scale = None,
        rotary_emb: Optional[Tuple[Tensor, Tensor]] = None
        ):
            """
            einops
            b - batch
            g - groups
            n - sequence
            d - feature dimension

            x: (b, n, d) for a single expert (requires groups == 1) or (b, g, n, d)
            context: optional, same layout convention as x; defaults to x (self attention)
            mask: optional bool key mask, (b, n) shared by all groups or (b, g, n); True = attend
            queries_scale / keys_scale / values_scale: optional multipliers applied to the
                (b, g, h, n, d) queries / keys / values; they must broadcast against that shape
            output_scale: optional multiplier applied to the final output
            rotary_emb: optional (q_emb, k_emb) pair passed to apply_rotary_pos_emb
            returns: same layout as x
            """
            # batch size, number of groups (experts) and heads
            b, g, h = x.shape[0], self.groups, self.heads

            # a 3-dim input means a single expert without an explicit group axis
            one_expert = x.ndim == 3

            # single expert: add the group axis so the rest of the code sees (b, 1, n, d)
            if one_expert:
                assert g == 1
                x = rearrange(x, 'b n d -> b 1 n d')

            # input must now be (batch, groups, seq, dim)
            assert x.ndim == 4
            # and its group axis must match the configured number of groups
            assert x.shape[1] == g

            # move features before the sequence axis: (b, g, n, d) -> (b, g, d, n)
            x = rearrange(x, 'b g n d -> b g d n')

            # the cross attention context gets the same treatment as x
            if exists(context):
                context_one_expert = context.ndim == 3

                if context_one_expert:
                    assert g == 1
                    context = rearrange(context, 'b n d -> b 1 n d')

                assert context.ndim == 4
                assert context.shape[1] == g

                context = rearrange(context, 'b g n d -> b g d n')

            # self attention: attend over x itself
            context = default(context, x)

            # key mask: broadcast a (b, n) mask to every group, or fold a per-group (b, g, n) mask into the batch
            if exists(mask):
                if mask.ndim == 2:
                    mask = repeat(mask, 'b n -> (b g) n', g = g)
                elif mask.ndim == 3:
                    mask = rearrange(mask, 'b g n -> (b g) n')

                # prepend True for the null key / value concatenated in front of the keys below
                mask = F.pad(mask, (1, 0), value = True)

            # pre-normalization (identity when prenorm = False)
            x = self.norm(x)
            context = self.context_norm(context)

            # fold groups into the channel axis for the grouped 1x1 convolutions
            x, context = map(lambda t: rearrange(t, 'b g d n -> b (g d) n'), (x, context))

            # project to queries, keys and values
            q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = 1))

            # split out groups and heads: (b, g, h, n, d)
            q, k, v = map(lambda t: rearrange(t, 'b (g h d) n -> b g h n d', h = h, g = g), (q, k, v))

            # rotary embeddings; per-group embeddings (b, g, n, d) get a heads axis inserted
            if exists(rotary_emb):
                q_rotary_emb, k_rotary_emb = rotary_emb

                if q_rotary_emb.ndim > 2:
                    q_rotary_emb = rearrange(q_rotary_emb, 'b g n d -> b g 1 n d')

                if k_rotary_emb.ndim > 2:
                    k_rotary_emb = rearrange(k_rotary_emb, 'b g n d -> b g 1 n d')

                q = apply_rotary_pos_emb(q_rotary_emb, q)
                k = apply_rotary_pos_emb(k_rotary_emb, k)

            # optional elementwise scaling of the queries
            if exists(queries_scale):
                q = q * queries_scale

            # optional elementwise scaling of the keys
            if exists(keys_scale):
                k = k * keys_scale

            # optional elementwise scaling of the values
            if exists(values_scale):
                v = v * values_scale

            # merge groups into the batch axis: (b * g, h, n, d)
            q, k, v = map(lambda t: rearrange(t, 'b g ... -> (b g) ...'), (q, k, v))

            # prepend the learned null key / value, so that no query row is ever fully masked
            nk, nv = map(lambda t: repeat(t, 'g h 1 d -> (b g) h 1 d', b = b), self.null_kv)

            k = torch.cat((nk, k), dim = -2)
            v = torch.cat((nv, v), dim = -2)

            # attention
            out = self.attend(q, k, v, mask = mask)

            # merge heads and fold groups into channels for the grouped output projection
            out = rearrange(out, '(b g) h n d -> b (g h d) n', g = g)

            out = self.to_out(out)

            out = rearrange(out, 'b (g d) n -> b g n d', g = g)

            # drop the group axis again if the input had none
            if one_expert:
                out = rearrange(out, 'b 1 n d -> b n d')

            # optional elementwise scaling of the output
            if exists(output_scale):
                out = out * output_scale

            return out
# 定义混合注意力机制的类
class MixtureOfAttention(nn.Module):
    """Mixture-of-attention layer that routes subsets of tokens to several attention experts.

    Two CoordinateDescentRouter instances pick, per expert, `num_routed_queries` query tokens
    and `num_routed_key_values` key / value tokens. A grouped `Attention` (one group per expert)
    attends between the routed sets, and the results are scatter-added back to the original
    query positions. An optional LocalMHA branch (self attention only) contributes to every
    position.
    """

    def __init__(
        self,
        dim,
        *,
        num_routed_queries,
        num_routed_key_values,
        dim_context = None,
        local_attn = False,
        local_attn_window_size = None,
        num_experts = 2,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        use_triton = True,
        flash_attn = True,
        prenorm = True,
        average_routed = False,
        **kwargs
    ):
        super().__init__()
        dim_context = default(dim_context, dim)
        self.num_routed_queries = num_routed_queries
        self.num_routed_key_values = num_routed_key_values

        # learned fill value for positions no expert routed to; None when local attention is
        # enabled, because it then covers every position
        self.null_routed_token = nn.Parameter(torch.randn(1, 1, dim)) if not local_attn else None

        self.average_routed = average_routed

        self.local_attn = None

        # optional local (windowed) multi-head self attention branch
        if local_attn:
            assert exists(local_attn_window_size)
            self.local_attn = LocalMHA(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                prenorm = prenorm,
                window_size = local_attn_window_size
            )

        # router choosing which queries each expert receives
        self.query_router = CoordinateDescentRouter(
            dim,
            num_routing_tokens = num_experts,
            use_triton = use_triton,
            **kwargs
        )

        # router choosing which key / values each expert receives
        self.key_value_router = CoordinateDescentRouter(
            dim_context,
            num_routing_tokens = num_experts,
            use_triton = use_triton,
            **kwargs
        )

        # one attention group per expert
        self.attn = Attention(
            dim = dim,
            dim_context = dim_context,
            dim_head = dim_head,
            heads = heads,
            groups = num_experts,
            dropout = dropout,
            flash = flash_attn,
            prenorm = prenorm
        )

    # device the module's parameters live on
    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        num_routed_queries = None,
        num_routed_key_values = None,
        rotary_emb = None
    ):
        """Route tokens to the attention experts and merge the expert outputs back.

        Args:
            x: query tokens, indexed below as (batch, seq, dim).
            context: optional tokens to cross attend to; defaults to `x` (self attention).
                Cross attention is not allowed with local attention or rotary embeddings.
            mask: optional bool mask over `x` positions (True = keep); masked outputs are zeroed.
            context_mask: optional bool mask over `context`; for self attention `mask` is used instead.
            num_routed_queries / num_routed_key_values: per-call overrides of the routing budgets.
            rotary_emb: optional rotary embedding indexed by position; it is gathered at the routed indices.

        Returns:
            Tensor with the same shape as `x`.
        """
        num_routed_queries = default(num_routed_queries, self.num_routed_queries)
        num_routed_key_values = default(num_routed_key_values, self.num_routed_key_values)

        is_cross_attn = exists(context)

        assert not (exists(self.local_attn) and is_cross_attn), 'cannot do cross attention with local attention (only for self attention)'

        if not is_cross_attn:
            # self attention: route key / values from x itself, under the same mask
            context = x
            context_mask = mask

        # route queries and key / values to each expert

        query_indices, query_scores, queries, query_mask = self.query_router(x, mask = mask, num_tokens = num_routed_queries, keep_one_route_dim = True)
        query_scores = rearrange(query_scores, 'b g n -> b g n 1')

        kv_indices, key_value_scores, key_values, key_value_mask = self.key_value_router(context, mask = context_mask, num_tokens = num_routed_key_values, keep_one_route_dim = True)
        key_value_scores = rearrange(key_value_scores, 'b g n -> b g 1 n 1')

        # rotary embeddings, gathered at the routed positions

        if exists(rotary_emb):
            assert not is_cross_attn, 'rotary embedding should not be used for cross attending'
            q_rotary_emb = rotary_emb[query_indices] if exists(query_indices) else rotary_emb
            k_rotary_emb = rotary_emb[kv_indices] if exists(kv_indices) else rotary_emb
            rotary_emb = (q_rotary_emb, k_rotary_emb)

        # attention within each expert; router scores gate the values and the outputs

        attn_out = self.attn(
            queries,
            rotary_emb = rotary_emb,
            context = key_values,
            mask = key_value_mask,
            values_scale = key_value_scores,
            output_scale = query_scores
        )

        local_out = None
        if exists(self.local_attn):
            local_out = self.local_attn(x, mask = mask)

        need_route_queries = exists(query_indices)

        if not need_route_queries:
            # the router returned no query indices (presumably every query was kept), so there is
            # nothing to scatter: average the expert outputs, plus the local attention output if any
            out = attn_out

            if exists(local_out):
                local_out = rearrange(local_out, 'b n d -> b 1 n d')
                out = torch.cat((local_out, out), dim = 1)

            # reduce `out`, not `attn_out`, otherwise the local attention output concatenated
            # above would be silently dropped
            out = reduce(out, 'b e n d -> b n d', 'mean')

            if exists(mask):
                out = out.masked_fill(~mask[..., None], 0.)

            return out

        # scatter-add each expert's outputs back to the query positions they were routed from

        out = torch.zeros_like(x)
        counts = torch.zeros(x.shape[:-1], device = x.device)

        query_indices = rearrange(query_indices, 'b g n -> b (g n)')
        attn_out = rearrange(attn_out, 'b g n d -> b (g n) d')

        expanded_query_indices = repeat(query_indices, 'b n -> b n d', d = x.shape[-1])

        attn_out_summed = out.scatter_add(1, expanded_query_indices, attn_out)

        # count how many (unmasked) routes landed on each position

        ones = torch.ones(attn_out.shape[:-1], device = self.device)

        if exists(query_mask):
            ones = ones * rearrange(query_mask, 'b g n -> b (g n)')

        counts = counts.scatter_add(1, query_indices, ones)
        counts = rearrange(counts, '... -> ... 1')

        has_unrouted = not exists(local_out)

        if not has_unrouted:
            # local attention covers every position, so it counts as one extra route everywhere
            counts = counts + 1
            attn_out_summed = attn_out_summed + local_out
        else:
            not_routed_mask = counts == 0
            attn_out_summed = attn_out_summed.masked_fill(not_routed_mask, 0.)

        out = attn_out_summed

        # optionally average over the number of routes per position

        if self.average_routed:
            out = out / counts.clamp(min = 1e-5)

        # positions no expert routed to get the learned null token instead of zeros

        if has_unrouted:
            out = torch.where(
                not_routed_mask,
                self.null_routed_token,
                out,
            )

        if exists(mask):
            out = out.masked_fill(~mask[..., None], 0.)

        return out
# 定义一个混合自回归注意力模型类
class MixtureOfAutoregressiveAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_routed_queries,  # number of routed queries (stored on self)
        num_routed_key_values,  # number of routed key/values (stored on self)
        local_attn_window_size,  # window size of the causal local attention branch
        routed_window_size = None,  # routed window size; falls back to local_attn_window_size when None
        num_experts = 2,  # number of experts: router token count and number of attention groups
        dim_head = 64,  # dimension per attention head
        heads = 8,  # number of attention heads
        dropout = 0.,  # dropout passed to the routed Attention module
        use_triton = False,  # forwarded to both CoordinateDescentRouter instances
        flash_attn = True,  # forwarded to Attention as `flash`
        prenorm = True,  # forwarded to both LocalMHA and Attention
        average_routed = False,  # stored as self.average_routed; consumed in forward (not visible here)
        **kwargs  # any extra keyword arguments are forwarded to both routers
    ):
        """Build the local attention branch, the two routers and the grouped routed attention.

        NOTE(review): only the constructor is visible in this listing; how the local and
        routed outputs are combined lives in `forward`, which is truncated here.
        """
        super().__init__()
        self.num_routed_queries = num_routed_queries  # kept for use in forward
        self.num_routed_key_values = num_routed_key_values  # kept for use in forward

        self.num_experts = num_experts
        # one randomly initialized learned vector of size `dim` per expert; presumably
        # used in forward for positions no expert routed (TODO confirm)
        self.null_tokens = nn.Parameter(torch.randn(num_experts, dim))

        routed_window_size = default(routed_window_size, local_attn_window_size)  # None -> local_attn_window_size

        self.routed_window_size = routed_window_size
        self.average_routed = average_routed

        # causal local multi-head attention over a fixed window
        self.local_attn = LocalMHA(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            prenorm = prenorm,
            causal = True,
            window_size = local_attn_window_size
        )

        # router that selects queries, with one routing token per expert
        self.query_router = CoordinateDescentRouter(
            dim,
            num_routing_tokens = num_experts,
            use_triton = use_triton,
            **kwargs
        )

        # separate router that selects key/values, with one routing token per expert
        self.key_value_router = CoordinateDescentRouter(
            dim,
            num_routing_tokens = num_experts,
            use_triton = use_triton,
            **kwargs
        )

        # routed attention, grouped so each expert attends within its own group
        self.attn = Attention(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            groups = num_experts,
            dropout = dropout,
            flash = flash_attn,
            prenorm = prenorm
        )

    # Device holding this module's parameters (read from the first registered parameter).
    @property
    def device(self):
        first_param = next(self.parameters())
        return first_param.device

    # 前向传播函数
    def forward(
        self,
        x,
        rotary_emb = None,
        num_routed_queries = None,
        num_routed_key_values = None

.\lucidrains\mixture-of-attention\mixture_of_attention\transformer.py

# 导入所需的库
import torch
import torch.nn.functional as F
from torch import nn, einsum

# 导入重排操作库
from einops import rearrange

# 导入自定义的注意力机制类
from mixture_of_attention.mixture_of_attention import MixtureOfAutoregressiveAttention

# 导入自定义的旋转嵌入类
from mixture_of_attention.rotary_emb import RotaryEmbedding

# 辅助函数

# Return True when `val` is anything other than None (falsy values such as 0 count as present).
def exists(val):
    return not (val is None)

# 类定义

# RMS-style normalization: L2-normalize the last dimension, then rescale by sqrt(dim)
# and a learned per-feature gain (initialized to ones).
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

# Pre-norm MLP block: RMSNorm -> Linear(dim, dim * mult) -> GELU -> Linear(dim * mult, dim).
def FeedForward(dim, mult = 4):
    inner_dim = dim * mult
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, inner_dim),
        nn.GELU(),
        nn.Linear(inner_dim, dim),
    )

# 主类定义

# Transformer language model
class Transformer(nn.Module):
    """Decoder-only language model built from mixture-of-autoregressive-attention layers.

    Token ids are embedded and summed with learned absolute position embeddings. They then
    pass through ``depth`` residual (attention, feedforward) pairs and are projected to
    ``num_tokens`` logits after a final RMSNorm.

    Args:
        dim: model width.
        num_tokens: vocabulary size.
        depth: number of (attention, feedforward) layer pairs.
        seq_len: maximum sequence length (size of the position-embedding table).
        local_attn_window_size, num_routed_queries, num_routed_key_values, num_experts,
        cosine_sim_routing, routed_window_size, dim_head, heads, use_triton:
            forwarded to every ``MixtureOfAutoregressiveAttention`` layer.
        ff_mult: hidden-width multiplier of each feedforward block.
        routed_rotary_emb: if True, build ``RotaryEmbedding(dim_head)`` and pass its output
            to every attention layer; otherwise the layers receive ``rotary_emb = None``.
    """

    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        seq_len,
        local_attn_window_size,
        num_routed_queries,
        num_routed_key_values,
        num_experts,
        cosine_sim_routing = True,
        routed_window_size = None,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        use_triton = True,
        routed_rotary_emb = True
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(seq_len, dim)
        self.seq_len = seq_len

        self.rotary_emb = RotaryEmbedding(dim_head) if routed_rotary_emb else None

        self.layers = nn.ModuleList([])

        # one (attention, feedforward) pair per layer
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                MixtureOfAutoregressiveAttention(
                    dim = dim,
                    local_attn_window_size = local_attn_window_size,
                    routed_window_size = routed_window_size,
                    num_routed_queries = num_routed_queries,
                    num_routed_key_values = num_routed_key_values,
                    cosine_sim_routing = cosine_sim_routing,
                    num_experts = num_experts,
                    dim_head = dim_head,
                    heads = heads,
                    use_triton = use_triton
                ),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        # final norm + projection to vocabulary logits
        self.to_logits = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    # device of the first registered parameter
    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, x):
        """Return vocabulary logits for integer token ids ``x``.

        ``x`` is looked up in ``nn.Embedding`` tables, so it must hold integer ids. It is
        presumably shaped (batch, seq_len); confirm with callers.

        Raises:
            AssertionError: if the sequence is longer than ``seq_len``. Without this check the
                position-embedding lookup fails with an opaque index error (a device-side
                assert on CUDA).
        """
        x = self.token_emb(x)
        n = x.shape[-2]
        assert n <= self.seq_len, f'sequence length {n} exceeds the maximum sequence length {self.seq_len}'
        x = x + self.pos_emb(torch.arange(n, device = self.device))

        rotary_emb = None
        if exists(self.rotary_emb):
            rotary_emb = self.rotary_emb(x.shape[1])

        # residual attention and feedforward per layer
        for attn, ff in self.layers:
            x = attn(x, rotary_emb = rotary_emb) + x

            x = ff(x) + x

        return self.to_logits(x)

.\lucidrains\mixture-of-attention\mixture_of_attention\__init__.py

# 从mixture_of_attention包中导入MixtureOfAttention、MixtureOfAutoregressiveAttention和Attention类
from mixture_of_attention.mixture_of_attention import (
    MixtureOfAttention,
    MixtureOfAutoregressiveAttention,
    Attention
)

Mixture-of-Attention

Some personal experiments around routing tokens to different autoregressive attention, akin to mixture-of-experts

I learned from a researcher friend that this was tried unsuccessfully in Switch Transformers, but I'll give it a go, bringing in some lessons from recent papers like CoLT5.

In my opinion, the CoLT5 paper basically demonstrates mixture of attention already for 2 experts. This just has to be generalized to more than 2 experts, and to the autoregressive case. The local attention branch would just be a special case of one expert with fixed routing. If I route only half the tokens, attention (whose cost is quadratic in the number of tokens) becomes 4x cheaper. If I can show even ~4 experts being better than 1 attention, that should be a win.

Appreciation

  • Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research

  • einops for making tensor manipulation fun and easy

Install

$ pip install mixture-of-attention

Usage

import torch
from mixture_of_attention import MixtureOfAttention

mixture_of_attn = MixtureOfAttention(
    dim = 512,
    dim_context = 256,
    num_routed_queries = 16,
    num_routed_key_values = 16,
    num_experts = 2,
    dim_head = 64,
    heads = 8
)

x = torch.randn(1, 1024, 512)
mask = torch.ones((1, 1024)).bool()

context = torch.randn(1, 512, 256)
context_mask = torch.ones((1, 512)).bool()

mixture_of_attn(x, context = context, mask = mask, context_mask = context_mask) # (1, 1024, 512)

Autoregressive flavor

import torch
from mixture_of_attention import MixtureOfAutoregressiveAttention

mixture_of_attn = MixtureOfAutoregressiveAttention(
    dim = 512,
    local_attn_window_size = 64,       # local attention window size
    routed_window_size = None,         # will be set to the same as local_attn_window_size if None. ideally less than or equal to local attention window size for full receptive field
    num_routed_queries = 12,
    num_routed_key_values = 12,
    num_experts = 2,
    dim_head = 64,
    heads = 8
)

x = torch.randn(1, 1023, 512)

out = mixture_of_attn(x) # (1, 1023, 512)

Todo

Citations

@inproceedings{Ainslie2023CoLT5FL,
    title   = {CoLT5: Faster Long-Range Transformers with Conditional Computation},
    author  = {Joshua Ainslie and Tao Lei and Michiel de Jong and Santiago Ontan'on and Siddhartha Brahma and Yury Zemlyanskiy and David Uthus and Mandy Guo and James Lee-Thorp and Yi Tay and Yun-Hsuan Sung and Sumit Sanghai},
    year    = {2023}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@article{Wright2015CoordinateDA,
    title   = {Coordinate descent algorithms},
    author  = {Stephen J. Wright},
    journal = {Mathematical Programming},
    year    = {2015},
    volume  = {151},
    pages   = {3-34}
}
@article{Schmitzer2016StabilizedSS,
    title   = {Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems},
    author  = {Bernhard Schmitzer},
    journal = {ArXiv},
    year    = {2016},
    volume  = {abs/1610.06519}
}
@inproceedings{rogozhnikov2022einops,
    title   = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author  = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=oapKSVM2bcj}
}

.\lucidrains\mixture-of-attention\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# Package metadata for the `mixture-of-attention` distribution.
setup(
  name = 'mixture-of-attention',
  version = '0.0.24',
  description = 'Mixture of Attention',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/mixture-of-attention',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  license = 'MIT',
  packages = find_packages(exclude = []),
  install_requires = [
    'colt5-attention>=0.10.14',
    'einops>=0.6.1',
    'local-attention>=1.8.6',
    'torch>=1.6',
  ],
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention mechanism',
    'mixture-of-experts',
    'routed attention'
  ],
  classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\mixture-of-attention\train.py

# 导入必要的库
import gzip
import random
import tqdm
import numpy as np

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from mixture_of_attention.transformer import Transformer
from mixture_of_attention.autoregressive_wrapper import AutoregressiveWrapper

# Sequence / data
SEQ_LEN = 512

# Optimization
NUM_BATCHES = 100_000
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4

# Evaluation schedule
VALIDATE_EVERY = 100
GENERATE_EVERY = 500

# Sampling
PRIME_LENGTH = 128
GENERATE_LENGTH = 512

# 定义辅助函数

# Map a byte token to a printable character; anything below 32 (control codes) becomes a space.
def decode_token(token):
    code = token if token > 32 else 32
    return chr(code)

# Decode a sequence of byte tokens into a string, one character per token.
def decode_tokens(tokens):
    return "".join(decode_token(tok) for tok in tokens)

# Byte-level language model (vocabulary of 256) with two routed attention experts.
model = Transformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    seq_len = SEQ_LEN,
    num_experts = 2,
    local_attn_window_size = 64,
    num_routed_queries = 32,
    num_routed_key_values = 64,
    cosine_sim_routing = True,
    use_triton = True
)

# Wrap for autoregressive loss / generation, then move to the GPU.
model = AutoregressiveWrapper(model)
model = model.cuda()

# enwik8: read the first 95M bytes; bytes [0, 90M) train, the remainder validates.
with gzip.open("./data/enwik8.gz") as file:
    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
    np_train, np_valid = np.split(data, [int(90e6)])
    data_train = torch.from_numpy(np_train)
    data_val = torch.from_numpy(np_valid)

class TextSamplerDataset(Dataset):
    """Random fixed-length windows drawn from a 1-D token tensor.

    Each item is ``seq_len + 1`` consecutive tokens, as a LongTensor on the GPU, starting
    at a uniformly random offset. The requested index is ignored. ``len`` reports how many
    non-overlapping windows fit in the data.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        upper = self.data.size(0) - self.seq_len
        start = torch.randint(0, upper, (1,))
        window = self.data[start : start + self.seq_len + 1]
        return window.long().cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# Yield batches from `loader` forever, restarting it each time it is exhausted.
# (This script used `cycle` without defining it, which raised NameError.)
def cycle(loader):
    while True:
        yield from loader

# Training and validation loaders; wrapped in `cycle` so `next()` never runs dry.
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))

# Optimizer over all model parameters.
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# Training loop.
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    # Gradient accumulation: scale each micro-batch loss by 1 / GRADIENT_ACCUMULATE_EVERY
    # before backward, so the accumulated gradient is that of the mean loss.
    # Passing the scaled loss as backward's `gradient` argument would instead multiply
    # every gradient by (loss / GRADIENT_ACCUMULATE_EVERY).
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print(f"training loss: {loss.item()}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f"validation loss: {loss.item()}")

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        prime = decode_tokens(inp)
        # print the prime followed by a separator line (values must be interpolated, not
        # passed as a separate tuple argument)
        print(f"{prime} \n\n {'*' * 100}")

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str, "\n")

.\lucidrains\mixture-of-experts\mixture_of_experts\mixture_of_experts.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch 库中导入 nn.functional 模块,并使用别名 F
import torch.nn.functional as F

# 导入 math 库
import math
# 从 inspect 库中导入 isfunction 函数

# Constants

# Lower bound on per-expert capacity. No code visible in this listing references it;
# presumably consumed by the Top2Gating forward pass (TODO confirm).
MIN_EXPERT_CAPACITY = 4

# 辅助函数

# Return `val` unless it is None, otherwise the fallback `default_val`.
def default(val, default_val):
    """Return ``val`` if it is not None, else ``default_val``.

    If ``default_val`` is a plain function (for example a lambda), it is treated as a lazy
    factory. It is called only when ``val`` is None, so callers can defer expensive
    construction (e.g. building ``Experts``) until it is actually needed. Other callables,
    such as classes, are returned as-is.
    """
    # Local import: `isfunction` was referenced here without ever being imported.
    from inspect import isfunction

    if val is not None:
        return val
    return default_val() if isfunction(default_val) else default_val

# Wrap `el` in a 1-tuple unless it is already a tuple.
def cast_tuple(el):
    if isinstance(el, tuple):
        return el
    return (el,)

# 与张量相关的辅助函数

# Largest value and its index along the last dim, with that dim squeezed away.
def top1(t):
    top_values, top_indices = t.topk(k=1, dim=-1)
    return top_values.squeeze(dim=-1), top_indices.squeeze(dim=-1)

# Exclusive cumulative sum along `dim` (each element excludes itself).
def cumsum_exclusive(t, dim=-1):
    """Exclusive cumulative sum of ``t`` along ``dim``.

    Element ``i`` of the result is the sum of elements ``0 .. i-1`` along ``dim``, so the
    first element is 0. Non-negative ``dim`` values are converted to their negative
    equivalent. The padding and slicing below count dims from the end, so a non-negative
    ``dim`` previously targeted the wrong axis.
    """
    if dim >= 0:
        dim -= t.ndim
    # number of dims after `dim`; F.pad lists (left, right) pairs starting from the last dim
    num_pad_dims = -dim - 1
    pre_padding = (0, 0) * num_pad_dims
    pre_slice = (slice(None),) * num_pad_dims
    # prepend one zero along `dim`, cumsum, then drop the last element along `dim`
    padded_t = F.pad(t, (*pre_padding, 1, 0)).cumsum(dim=dim)
    return padded_t[(..., slice(None, -1), *pre_slice)]

# One-hot encoding that tolerates indexes >= max_length (they become all-zero rows).
def safe_one_hot(indexes, max_length):
    """One-hot encode ``indexes`` into exactly ``max_length`` classes.

    ``F.one_hot`` rejects indexes >= ``num_classes``, so the encoding is first built wide
    enough to hold the largest index and then truncated to ``max_length`` columns. Any
    index >= ``max_length`` therefore yields an all-zero row instead of an error.
    Empty input returns an empty ``(..., max_length)`` tensor.
    """
    # plain int for num_classes; exactly max + 1 classes are needed to hold every index
    max_index = int(indexes.max()) + 1 if indexes.numel() > 0 else 0
    return F.one_hot(indexes, max(max_index, max_length))[..., :max_length]

# Fill `t` in place from U(-1/sqrt(fan), 1/sqrt(fan)), fan = size of the last dim; returns `t`.
def init_(t):
    bound = 1 / math.sqrt(t.shape[-1])
    return t.uniform_(-bound, bound)

# 激活函数

# Tanh approximation of GELU, used only when this torch build lacks nn.GELU.
class GELU_(nn.Module):
    def forward(self, x):
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))

# Prefer the built-in nn.GELU when available, otherwise fall back to GELU_.
GELU = getattr(nn, 'GELU', GELU_)

# 专家类

class Experts(nn.Module):
    """Bank of independent two-layer MLP experts evaluated with batched einsums.

    Args:
        dim: input/output feature size.
        num_experts: int or tuple of ints, used as the leading dims of both weight tensors
            (a tuple gives a nested grid of experts).
        hidden_dim: hidden width, defaults to ``dim * 4``.
        activation: activation class, instantiated once and shared by all experts.
    """
    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = GELU):
        super().__init__()

        hidden_dim = default(hidden_dim, dim * 4)
        expert_dims = cast_tuple(num_experts)

        # w1 is initialized before w2, keeping the original random-number consumption order
        self.w1 = nn.Parameter(init_(torch.zeros(*expert_dims, dim, hidden_dim)))
        self.w2 = nn.Parameter(init_(torch.zeros(*expert_dims, hidden_dim, dim)))
        self.act = activation()

    def forward(self, x):
        # x: (*expert_dims, tokens, dim); its leading dims pair with the weights' expert dims
        # through the '...' in each einsum
        hidden = self.act(torch.einsum('...nd,...dh->...nh', x, self.w1))
        return torch.einsum('...nh,...hd->...nd', hidden, self.w2)

# 下面的代码几乎完全从官方的 tensorflow 版本转录而来,相关论文也是基于此版本编写
# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/research/moe.py

# 门控网络

class Top2Gating(nn.Module):
    """Top-2 gating network for mixture of experts (constructor only in this listing).

    Holds a learned gating weight of shape (*outer_expert_dims, dim, num_gates), drawn from
    ``torch.randn``. It also stores separate train/eval settings for the second-expert
    policy, its threshold and the capacity factor.

    NOTE(review): no ``forward`` is defined here, yet ``MoE.forward`` and
    ``HeirarchicalMoE.forward`` call ``self.gate(...)`` and unpack three values
    (dispatch tensor, combine tensor, loss). As written, calling this module would hit
    ``nn.Module``'s unimplemented forward and raise ``NotImplementedError``. Confirm the
    forward implementation exists upstream.
    """
    def __init__(
        self,
        dim,
        num_gates,
        eps = 1e-9,  # small constant; presumably for numerical stability in forward (not visible)
        outer_expert_dims = tuple(),  # extra leading dims of the gating weight (one gate set per outer expert)
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.):
        super().__init__()

        self.eps = eps
        self.num_gates = num_gates
        # gating projection from `dim` features to `num_gates` scores
        self.w_gating = nn.Parameter(torch.randn(*outer_expert_dims, dim, num_gates))

        # train/eval variants; which one applies presumably depends on self.training in forward
        self.second_policy_train = second_policy_train
        self.second_policy_eval = second_policy_eval
        self.second_threshold_train = second_threshold_train
        self.second_threshold_eval = second_threshold_eval
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval

# 普通的专家混合模型

class MoE(nn.Module):
    """Single-level sparsely gated mixture of experts with top-2 gating.

    ``forward(inputs)`` takes a 3-D tensor indexed 'bnd' (batch, positions, dim). It returns
    ``(output, aux_loss)``, where ``output`` has the same 'bnd' layout and ``aux_loss`` is the
    gate's loss scaled by ``loss_coef``. Pass ``experts`` to supply a custom expert network;
    otherwise an ``Experts`` bank is built.
    """
    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = nn.ReLU,
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.,
        loss_coef = 1e-2,
        experts = None):
        super().__init__()

        self.num_experts = num_experts

        # the gate is built before the experts, matching the original initialization order
        self.gate = Top2Gating(
            dim,
            num_gates = num_experts,
            second_policy_train = second_policy_train,
            second_policy_eval = second_policy_eval,
            second_threshold_train = second_threshold_train,
            second_threshold_eval = second_threshold_eval,
            capacity_factor_train = capacity_factor_train,
            capacity_factor_eval = capacity_factor_eval,
        )
        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
        self.loss_coef = loss_coef

    def forward(self, inputs, **kwargs):
        _, _, d, e = *inputs.shape, self.num_experts
        dispatch_tensor, combine_tensor, loss = self.gate(inputs)

        # scatter tokens into per-expert capacity slots: 'bnd' x 'bnec' -> 'ebcd'
        expert_inputs = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor)

        # each expert processes its (batch * capacity, d) slab, then restore 'ebcd'
        flat_outputs = self.experts(expert_inputs.reshape(e, -1, d))
        expert_outputs = flat_outputs.reshape(*expert_inputs.shape)

        # gather back to token positions, weighted by the combine tensor
        output = torch.einsum('ebcd,bnec->bnd', expert_outputs, combine_tensor)
        return output, loss * self.loss_coef
# Two-level hierarchical mixture of experts (the class name keeps the original spelling).
class HeirarchicalMoE(nn.Module):
    """Two-level mixture of experts: an outer top-2 gate routes to outer groups, and an inner
    gate (one per outer expert) routes within each group.

    ``forward(inputs)`` takes a 3-D tensor indexed 'bnd' and returns ``(output, aux_loss)``.
    ``output`` has the same 'bnd' layout; ``aux_loss`` is (outer loss + inner loss) * ``loss_coef``.
    """
    def __init__(self,
        dim,
        num_experts = (4, 4),  # (outer, inner) expert counts; exactly two levels are required
        hidden_dim = None,  # expert hidden width; None lets Experts choose its own default
        activation = nn.ReLU,  # activation class passed to the default Experts
        second_policy_train = 'random',  # second-expert policy during training (both gates)
        second_policy_eval = 'random',  # second-expert policy during evaluation (both gates)
        second_threshold_train = 0.2,  # second-expert threshold during training
        second_threshold_eval = 0.2,  # second-expert threshold during evaluation
        capacity_factor_train = 1.25,  # capacity factor during training
        capacity_factor_eval = 2.,  # capacity factor during evaluation
        loss_coef = 1e-2,  # multiplier applied to the summed auxiliary gate losses
        experts = None):  # optional custom expert network; defaults to Experts over the (outer, inner) grid
        super().__init__()

        assert len(num_experts) == 2, 'only 2 levels of heirarchy for experts allowed for now'  # only two levels are supported
        num_experts_outer, num_experts_inner = num_experts
        self.num_experts_outer = num_experts_outer
        self.num_experts_inner = num_experts_inner

        gating_kwargs = {'second_policy_train': second_policy_train, 'second_policy_eval': second_policy_eval, 'second_threshold_train': second_threshold_train, 'second_threshold_eval': second_threshold_eval, 'capacity_factor_train': capacity_factor_train, 'capacity_factor_eval': capacity_factor_eval}

        # outer gate over the outer groups; inner gate carries one weight set per outer expert
        self.gate_outer = Top2Gating(dim, num_gates = num_experts_outer, **gating_kwargs)
        self.gate_inner = Top2Gating(dim, num_gates = num_experts_inner, outer_expert_dims = (num_experts_outer,), **gating_kwargs)

        # expert bank laid out as an (outer, inner) grid
        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
        self.loss_coef = loss_coef

    def forward(self, inputs, **kwargs):
        b, n, d, eo, ei = *inputs.shape, self.num_experts_outer, self.num_experts_inner
        dispatch_tensor_outer, combine_tensor_outer, loss_outer = self.gate_outer(inputs)
        # outer dispatch: 'bnd' x 'bnec' -> per-outer-expert capacity slots 'ebcd'
        expert_inputs_outer = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor_outer)

        # "importance" for the inner gate: combine weights 'bnec' -> 'ebcn' summed over n -> 'ebc'.
        # It is then quantized to 1.0 (> 0.5), 0.5 (in (0, 0.5]) or 0.0 (== 0). This presumably
        # separates first-choice from second-choice outer assignments (TODO confirm against
        # the gate's forward, which is not visible here).
        importance = combine_tensor_outer.permute(2, 0, 3, 1).sum(dim=-1)
        importance = 0.5 * ((importance > 0.5).float() + (importance > 0.).float())

        dispatch_tensor_inner, combine_tensor_inner, loss_inner = self.gate_inner(expert_inputs_outer, importance = importance)
        # inner dispatch: outer slots 'ebnd' -> (outer, inner, batch, inner capacity, dim) 'efbcd'
        expert_inputs = torch.einsum('ebnd,ebnfc->efbcd', expert_inputs_outer, dispatch_tensor_inner)

        # run every expert on its flattened (tokens, d) slab, then restore the 'efbcd' layout
        orig_shape = expert_inputs.shape
        expert_inputs = expert_inputs.reshape(eo, ei, -1, d)
        expert_outputs = self.experts(expert_inputs)
        expert_outputs = expert_outputs.reshape(*orig_shape)

        # combine inner outputs back into outer slots, then outer slots back into token positions
        expert_outputs_outer = torch.einsum('efbcd,ebnfc->ebnd', expert_outputs, combine_tensor_inner)
        output = torch.einsum('ebcd,bnec->bnd', expert_outputs_outer, combine_tensor_outer)
        return output, (loss_outer + loss_inner) * self.loss_coef

.\lucidrains\mixture-of-experts\mixture_of_experts\__init__.py

# 从mixture_of_experts包中导入MoE、HeirarchicalMoE和Experts类
from mixture_of_experts.mixture_of_experts import MoE, HeirarchicalMoE, Experts

Sparsely Gated Mixture of Experts - Pytorch

A Pytorch implementation of Sparsely Gated Mixture of Experts, for massively increasing the capacity (parameter count) of a language model while keeping the computation constant.

It will mostly be a line-by-line transcription of the tensorflow implementation here, with a few enhancements.

Update: You should now use ST Mixture of Experts

PyPI version

Install

$ pip install mixture_of_experts

Usage

import torch
from torch import nn
from mixture_of_experts import MoE

moe = MoE(
    dim = 512,
    num_experts = 16,               # increase the experts (# parameters) of your model without increasing computation
    hidden_dim = 512 * 4,           # size of hidden dimension in each expert, defaults to 4 * dimension
    activation = nn.LeakyReLU,      # use your preferred activation; MoE defaults to nn.ReLU
    second_policy_train = 'random', # in top_2 gating, policy for whether to use a second-place expert
    second_policy_eval = 'random',  # all (always) | none (never) | threshold (if gate value > the given threshold) | random (if gate value > threshold * random_uniform(0, 1))
    second_threshold_train = 0.2,
    second_threshold_eval = 0.2,
    capacity_factor_train = 1.25,   # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
    capacity_factor_eval = 2.,      # capacity_factor_* should be set to a value >=1
    loss_coef = 1e-2                # multiplier on the auxiliary expert balancing auxiliary loss
)

inputs = torch.randn(4, 1024, 512)
out, aux_loss = moe(inputs) # (4, 1024, 512), (1,)

The above should suffice for a single machine, but if you want a hierarchical mixture of experts (2 levels), as used in the GShard paper, please follow the instructions below (note that the class is spelled `HeirarchicalMoE`)

import torch
from mixture_of_experts import HeirarchicalMoE

moe = HeirarchicalMoE(
    dim = 512,
    num_experts = (4, 4),       # 4 gates on the first layer, then 4 experts on the second, equaling 16 experts
)

inputs = torch.randn(4, 1024, 512)
out, aux_loss = moe(inputs) # (4, 1024, 512), (1,)

1 billion parameters

import torch
from mixture_of_experts import HeirarchicalMoE

moe = HeirarchicalMoE(
    dim = 512,
    num_experts = (22, 22)
).cuda()

inputs = torch.randn(1, 1024, 512).cuda()
out, aux_loss = moe(inputs)

total_params = sum(p.numel() for p in moe.parameters())
print(f'number of parameters - {total_params}')

If you want a more sophisticated network for the experts, you can define your own and pass it into the MoE class through the `experts` keyword

import torch
from torch import nn
from mixture_of_experts import MoE

# a 3 layered MLP as the experts

class Experts(nn.Module):
    def __init__(self, dim, num_experts = 16):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(num_experts, dim, dim * 4))
        self.w2 = nn.Parameter(torch.randn(num_experts, dim * 4, dim * 4))
        self.w3 = nn.Parameter(torch.randn(num_experts, dim * 4, dim))
        self.act = nn.LeakyReLU(inplace = True)

    def forward(self, x):
        hidden1 = self.act(torch.einsum('end,edh->enh', x, self.w1))
        hidden2 = self.act(torch.einsum('end,edh->enh', hidden1, self.w2))
        out = torch.einsum('end,edh->enh', hidden2, self.w3)
        return out

experts = Experts(512, num_experts = 16)

moe = MoE(
    dim = 512,
    num_experts = 16,
    experts = experts
)

inputs = torch.randn(4, 1024, 512)
out, aux_loss = moe(inputs) # (4, 1024, 512), (1,)

Citation

@misc{shazeer2017outrageously,
    title   = {Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer},
    author  = {Noam Shazeer and Azalia Mirhoseini and Krzysztof Maziarz and Andy Davis and Quoc Le and Geoffrey Hinton and Jeff Dean},
    year    = {2017},
    eprint  = {1701.06538},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{lepikhin2020gshard,
    title   = {GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding},
    author  = {Dmitry Lepikhin and HyoukJoong Lee and Yuanzhong Xu and Dehao Chen and Orhan Firat and Yanping Huang and Maxim Krikun and Noam Shazeer and Zhifeng Chen},
    year    = {2020},
    eprint  = {2006.16668},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

.\lucidrains\mixture-of-experts\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# Package metadata for the `mixture-of-experts` distribution.
setup(
  name = 'mixture-of-experts',
  version = '0.2.3',
  description = 'Sparsely-Gated Mixture of Experts for Pytorch',
  url = 'https://github.com/lucidrains/mixture-of-experts',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  license = 'MIT',
  packages = find_packages(),
  install_requires = ['torch'],
  keywords = [
      'artificial intelligence',
      'deep learning',
      'transformers',
      'mixture of experts',
  ],
  classifiers = [
      'Development Status :: 4 - Beta',
      'Intended Audience :: Developers',
      'Topic :: Scientific/Engineering :: Artificial Intelligence',
      'License :: OSI Approved :: MIT License',
      'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\mlm-pytorch\mlm_pytorch\mlm_pytorch.py

# 导入数学库
import math
# 从 functools 库中导入 reduce 函数
from functools import reduce

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 torch 库中导入 functional 模块
import torch.nn.functional as F

# 辅助函数

# Bool tensor shaped like `t`, each entry independently True with probability `prob`.
def prob_mask_like(t, prob):
    noise = torch.zeros_like(t, dtype=torch.float).uniform_(0, 1)
    return noise < prob

# Bool mask that is True wherever `t` equals any of `token_ids`.
def mask_with_tokens(t, token_ids):
    hits = torch.zeros_like(t, dtype=torch.bool)
    for token_id in token_ids:
        hits = hits | (t == token_id)
    return hits

# Randomly select a `prob` fraction of the True positions of `mask`, per row.
def get_mask_subset_with_prob(mask, prob):
    """Randomly pick ``ceil(prob * num_eligible)`` eligible positions per row.

    Args:
        mask: bool tensor of shape (batch, seq_len); True marks eligible positions.
        prob: fraction of eligible positions to select (expected in [0, 1]).

    Returns:
        bool tensor of shape (batch, seq_len), True only at selected positions. Every
        selected position is eligible.
    """
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    # topk (sorted, descending) yields candidates in rank order, so keep candidate k only
    # while k < ceil(num_eligible * prob). Ineligible positions score -1e9 and rank last, so
    # they always fall in the discarded tail. The previous mask.cumsum(...) test indexed
    # sequence positions rather than candidate ranks, and could keep ineligible candidates.
    mask_excess = torch.arange(max_masked, device=device) >= (num_tokens * prob).ceil()

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    # shift indices by one so column 0 serves as a discard slot for excess candidates
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()

# 主类

class MLM(nn.Module):
    """Masked-language-modeling trainer around a token-level transformer.

    ``forward(seq)`` corrupts a random subset of the token ids in ``seq``, BERT style. It
    returns the cross-entropy of ``transformer`` predicting the original tokens at the
    selected positions.

    Args:
        transformer: module mapping token ids (plus ``**kwargs``) to logits. The logits are
            transposed with ``transpose(1, 2)`` before ``F.cross_entropy``, so they are
            presumably shaped (batch, seq_len, num_tokens); confirm for your model.
        mask_prob: fraction of eligible tokens selected for prediction.
        replace_prob: probability that a selected token is replaced by ``mask_token_id``;
            otherwise it is left unchanged but still predicted.
        num_tokens: vocabulary size; required only when ``random_token_prob > 0``.
        random_token_prob: per-position probability of substituting a random token id.
        mask_token_id: id of the [mask] token.
        pad_token_id: id of padding; excluded from selection and used as ``ignore_index``.
        mask_ignore_token_ids: extra token ids excluded from selection (e.g. [cls], [sep]).
            Defaults to an immutable empty tuple rather than a shared mutable list.
    """
    def __init__(
        self,
        transformer,
        mask_prob = 0.15,
        replace_prob = 0.9,
        num_tokens = None,
        random_token_prob = 0.,
        mask_token_id = 2,
        pad_token_id = 0,
        mask_ignore_token_ids = ()):
        super().__init__()

        self.transformer = transformer

        # MLM probabilities

        self.mask_prob = mask_prob
        self.replace_prob = replace_prob

        self.num_tokens = num_tokens
        self.random_token_prob = random_token_prob

        # token ids

        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id])

    def forward(self, seq, **kwargs):

        # pad and ignored tokens are excluded from the eligible set passed to the selector

        no_mask = mask_with_tokens(seq, self.mask_ignore_token_ids)
        mask = get_mask_subset_with_prob(~no_mask, self.mask_prob)

        # working copy of the input that will receive the corruptions

        masked_seq = seq.clone().detach()

        # labels: original tokens at selected positions, pad (ignored by the loss) elsewhere

        labels = seq.masked_fill(~mask, self.pad_token_id)

        # optional random-token replacement

        if self.random_token_prob > 0:
            assert self.num_tokens is not None, 'num_tokens keyword must be supplied when instantiating MLM if using random token replacement'
            # NOTE(review): this probability is drawn over every position, not only the
            # selected ones, so unselected (and even pad) positions can receive random
            # tokens without contributing to the loss. Confirm this is intended.
            random_token_prob = prob_mask_like(seq, self.random_token_prob)
            random_tokens = torch.randint(0, self.num_tokens, seq.shape, device=seq.device)
            random_no_mask = mask_with_tokens(random_tokens, self.mask_ignore_token_ids)
            random_token_prob &= ~random_no_mask
            masked_seq = torch.where(random_token_prob, random_tokens, masked_seq)

            # randomly replaced positions are not additionally overwritten with [mask]
            mask = mask & ~random_token_prob

        # replace selected positions with [mask] with probability replace_prob

        replace_prob = prob_mask_like(seq, self.replace_prob)
        masked_seq = masked_seq.masked_fill(mask * replace_prob, self.mask_token_id)

        # run the model and compute the MLM loss

        logits = self.transformer(masked_seq, **kwargs)

        mlm_loss = F.cross_entropy(
            logits.transpose(1, 2),
            labels,
            ignore_index = self.pad_token_id
        )

        return mlm_loss

.\lucidrains\mlm-pytorch\mlm_pytorch\__init__.py

# 从 mlm_pytorch 包中导入 MLM 类
from mlm_pytorch.mlm_pytorch import MLM

MLM (Masked Language Modeling) Pytorch

This repository allows you to quickly set up unsupervised training for your transformer on a corpus of sequence data.

Install

$ pip install mlm-pytorch

Usage

First pip install x-transformers, then run the following example to see what one iteration of the unsupervised training is like

import torch
from torch import nn
from torch.optim import Adam
from mlm_pytorch import MLM

# instantiate the language model

from x_transformers import TransformerWrapper, Encoder

transformer = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

# plugin the language model into the MLM trainer

trainer = MLM(
    transformer,
    mask_token_id = 2,          # the token id reserved for masking
    pad_token_id = 0,           # the token id for padding
    mask_prob = 0.15,           # masking probability for masked language modeling
    replace_prob = 0.90,        # ~10% probability that token will not be masked, but included in loss, as detailed in the paper
    mask_ignore_token_ids = []  # other tokens to exclude from masking, include the [cls] and [sep] here
).cuda()

# optimizer

opt = Adam(trainer.parameters(), lr=3e-4)

# one training step (do this for many steps in a for loop, getting new `data` each time)

data = torch.randint(0, 20000, (8, 1024)).cuda()

loss = trainer(data)
loss.backward()
opt.step()
opt.zero_grad()

# after much training, the model should have improved for downstream tasks

torch.save(transformer, f'./pretrained-model.pt')

Do the above for many steps, and your model should improve.

Citation

@misc{devlin2018bert,
    title   = {BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding},
    author  = {Jacob Devlin and Ming-Wei Chang and Kenton Lee and Kristina Toutanova},
    year    = {2018},
    eprint  = {1810.04805},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

.\lucidrains\mlm-pytorch\setup.py

# 导入设置安装和查找包的函数
from setuptools import setup, find_packages

# Package metadata for the mlm-pytorch distribution
setup(
  name = 'mlm-pytorch',
  version = '0.1.0',
  description = 'MLM (Masked Language Modeling) - Pytorch',
  url = 'https://github.com/lucidrains/mlm-pytorch',
  license = 'MIT',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  packages = find_packages(),
  install_requires = ['torch>=1.1.0'],
  keywords = [
    'transformers',
    'artificial intelligence',
    'pretraining',
    'unsupervised learning',
  ],
  classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

.\lucidrains\mlp-gpt-jax\mlp_gpt_jax\mlp_gpt_jax.py

# 导入必要的库
from functools import partial

import jax
from jax import random
from jax import nn
import jax.numpy as np

import haiku as hk
from haiku import initializers
from einops import rearrange

# Module-level constants
EPS = 1e-3               # SGU spatial weights are initialised in [-EPS / seq_len, EPS / seq_len]
ATTN_MASK_VALUE = -1e10  # logit written into causally masked attention positions

# LayerNorm over the last axis with a learnable scale and offset
LayerNorm = partial(
    hk.LayerNorm,
    axis = -1,
    create_scale = True,
    create_offset = True,
)

def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    return not (val is None)

class Attention(hk.Module):
    """Single-head causal self-attention over one (unbatched) sequence.

    Args:
        dim_out: width of the output projection.
        dim_head: width of each of the query/key/value projections.
    """

    def __init__(
        self,
        *,
        dim_out,
        dim_head
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.to_qkv = hk.Linear(dim_head * 3)
        self.to_out = hk.Linear(dim_out)

    def __call__(self, x):
        # x is indexed as (seq, features): axis 0 is the sequence length
        seq_len = x.shape[0]

        queries, keys, values = np.split(self.to_qkv(x), 3, axis = -1)
        scores = np.einsum('i d, j d -> i j', queries, keys) * self.scale

        # entries strictly above the diagonal are future positions; block them
        future = np.triu(np.ones((seq_len, seq_len), dtype = bool), 1)
        scores = np.where(future, ATTN_MASK_VALUE, scores)

        weights = nn.softmax(scores, axis = -1)
        mixed = np.einsum('i j, j d -> i d', weights, values)
        return self.to_out(mixed)

class SGU(hk.Module):
    """Spatial Gating Unit with a causal (lower-triangular) token-mixing matrix.

    The last axis of the input is split in half into ``x`` and ``gate``. The gate
    is layer-normalised, then mixed across the ``seq_len`` positions by a learned
    matrix masked to its lower triangle. A learned bias and, optionally,
    ``gate_res`` are added to it. The result multiplies ``x``, and the product is
    projected to ``dim_out``. ``dim`` is accepted but not used.

    NOTE(review): the einsum assumes the input is (seq_len, features) with
    exactly ``seq_len`` rows.
    """

    def __init__(
        self,
        *,
        dim,
        dim_out,
        seq_len
    ):
        super().__init__()
        self.seq_len = seq_len
        self.norm = LayerNorm()
        self.proj_out = hk.Linear(dim_out)

    def __call__(self, x, gate_res = None):
        n = self.seq_len
        x, gate = np.split(x, 2, axis = -1)
        gate = self.norm(gate)

        # near-zero uniform init for the mixing weights, ones for the bias
        bound = EPS / n
        weights = hk.get_parameter(
            'spatial_weights',
            shape = (n, n),
            init = initializers.RandomUniform(minval = -bound, maxval = bound),
        )
        biases = hk.get_parameter('spatial_biases', shape = (n, 1), init = np.ones)

        # row m may only read positions <= m (causal mixing)
        weights = weights * np.tril(np.ones((n, n)))

        gate = np.einsum('n d, m n -> m d', gate, weights) + biases
        if exists(gate_res):
            gate = gate + gate_res

        return self.proj_out(x * gate)

class gMLP(hk.Module):
    """One gMLP block: LayerNorm -> Linear(dim_ff) -> GELU -> SGU -> Linear(dim).

    When ``attn_dim`` is given, a single-head causal Attention over the
    normalised input is added into the SGU gate as ``gate_res``.
    """

    def __init__(
        self,
        *,
        dim,
        dim_ff,
        seq_len,
        name,
        attn_dim = None
    ):
        super().__init__(name = name)
        if exists(attn_dim):
            self.attn = Attention(dim_head = attn_dim, dim_out = dim_ff // 2)
        else:
            self.attn = None
        self.norm = LayerNorm()
        self.proj_in = hk.Linear(dim_ff)
        self.sgu = SGU(dim = dim_ff, dim_out = dim_ff // 2, seq_len = seq_len)
        self.proj_out = hk.Linear(dim)

    def __call__(self, x):
        normed = self.norm(x)

        gate_res = None
        if exists(self.attn):
            gate_res = self.attn(normed)

        hidden = nn.gelu(self.proj_in(normed))
        hidden = self.sgu(hidden, gate_res)
        return self.proj_out(hidden)

class MaybeExecute(hk.Module):
    """Run ``fn`` with probability ``prob_execute`` and rescale by ``1 / prob_execute``.

    Each call draws one Bernoulli sample using ``hk.next_rng_key()``, so an rng
    key must be passed to ``apply``. When the sample is False, the (finite)
    output of ``fn`` is zeroed.
    """

    def __init__(
        self,
        *,
        prob_execute,
        fn
    ):
        super().__init__()
        self.fn = fn
        self.prob_execute = prob_execute

    def __call__(self, x):
        run = random.bernoulli(hk.next_rng_key(), p = self.prob_execute)
        result = self.fn(x) * run + 0 * (1 - run)
        return result / self.prob_execute

class MLPGpt(hk.Module):
    """Autoregressive language model built from gMLP blocks.

    Args:
        num_tokens: vocabulary size (embedding rows and logit width).
        dim: model width.
        seq_len: sequence length the spatial gating units are built for.
        depth: number of gMLP blocks.
        heads: accepted but not used by this implementation.
        ff_mult: each gMLP's hidden width is ``dim * ff_mult``.
        attn_dim: if set, every gMLP adds a tiny single-head attention of this width.
        clamp_gate: accepted but not used by this implementation.
        layer_survival_prob: probability that each block is executed (MaybeExecute).
    """

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        seq_len,
        depth,
        heads = 1,
        ff_mult = 4,
        attn_dim = None,
        clamp_gate = True,
        layer_survival_prob = 1.
    ):
        super().__init__()
        self.embed = hk.Embed(num_tokens, dim)

        blocks = [
            gMLP(dim = dim, dim_ff = dim * ff_mult, seq_len = seq_len, name = f'gmlp{i}', attn_dim = attn_dim)
            for i in range(depth)
        ]
        self.layers = [MaybeExecute(prob_execute = layer_survival_prob, fn = block) for block in blocks]

        self.to_logits = hk.Sequential([LayerNorm(), hk.Linear(num_tokens)])

    def __call__(self, x):
        """Map token ids to per-position logits over ``num_tokens``."""
        hidden = self.embed(x)

        # residual stack: each (possibly skipped) block adds to the stream
        for layer in self.layers:
            hidden = hidden + layer(hidden)

        return self.to_logits(hidden)
def TransformedMLPGpt(**kwargs):
    """Build a haiku-transformed MLPGpt.

    Returns the ``hk.Transformed`` pair produced by ``hk.transform``:
    ``init(rng, seq)`` creates the parameters and ``apply(params, rng, seq)``
    runs the model. ``apply`` needs an rng key because MaybeExecute draws one
    with ``hk.next_rng_key()`` on every call.

    Args:
        **kwargs: forwarded unchanged to ``MLPGpt``.
    """
    # hk.transform is what gives the returned object its .init / .apply methods.
    # Without it, callers such as ``model.init(...)`` fail with AttributeError.
    @hk.transform
    def inner(seq):
        return MLPGpt(**kwargs)(seq)

    return inner

.\lucidrains\mlp-gpt-jax\mlp_gpt_jax\utils.py

# 导入所需的库
from jax import random, nn, value_and_grad, vmap, jit
from jax.lax import top_k
import jax.numpy as np

# 辅助函数

def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    return not (val is None)

def log(t, eps = 1e-20):
    """Natural log of ``t + eps``; the shift keeps ``log(0)`` finite."""
    shifted = t + eps
    return np.log(shifted)

# 训练函数

def cross_entropy(logits, targets, axis = -1):
    """Mean negative log-likelihood of integer ``targets`` under ``logits``.

    ``targets`` has the shape of ``logits`` with ``axis`` removed; the mean is
    taken over every remaining element.
    """
    log_probs = nn.log_softmax(logits, axis = axis)
    index = np.expand_dims(targets, axis = axis)
    picked = np.take_along_axis(log_probs, index, axis = axis)
    return -np.mean(picked)

def get_train_loss_fn(model):
    """Return ``loss_fn(params, key, data) -> (loss, grads)`` for next-token training.

    ``model.apply`` is vmapped over axis 0 of its third argument, sharing
    ``params`` and ``key`` across the batch, and then jit-compiled. Each row of
    ``data`` is split into inputs ``row[:-1]`` and targets ``row[1:]``.
    Gradients are taken with respect to ``params``.
    """
    apply_over_batch = jit(vmap(model.apply, in_axes = (None, None, 0), out_axes = 0))

    def loss_fn(params, key, data):
        inputs, targets = data[:, :-1], data[:, 1:]
        logits = apply_over_batch(params, key, inputs)
        return cross_entropy(logits, targets, axis = -1)

    return value_and_grad(loss_fn)

# 采样函数

def select_top_k(tensor, k):
    """Keep the ``k`` largest entries of ``tensor`` along its last axis.

    Args:
        tensor: array of scores (e.g. logits).
        k: number of entries to keep per row.

    Returns:
        ``(mask, filtered)``. ``mask`` is True for kept entries. ``filtered``
        holds the kept values with every other entry set to ``-inf``, so a later
        argmax can never select a discarded entry, even when all kept scores
        are negative. Entries tied with the k-th largest value are all kept.
    """
    values, _ = top_k(tensor, k)
    # top_k returns values sorted in descending order, so the last one is the
    # k-th largest. Use >= so that entry itself is kept.
    kth_largest = values[..., -1:]
    mask = tensor >= kth_largest
    return mask, np.where(mask, tensor, -np.inf)

def gumbel_noise(rng, shape):
    """Sample Gumbel noise ``-log(-log(U))`` with ``U`` uniform on [0, 1).

    Uses the eps-shifted ``log`` helper so both logarithms stay finite.
    """
    uniform = random.uniform(rng, shape = shape, minval = 0., maxval = 1.)
    return -log(-log(uniform))

def sample(rng, fn, params, prime, length, top_k = None):
    """Extend ``prime`` to ``length`` tokens with Gumbel-max sampling.

    Args:
        rng: iterator of PRNG keys (``next(rng)`` is called twice per step).
        fn: ``fn(params, key, seq)`` returning per-position logits for ``seq``.
        params: model parameters passed through to ``fn``.
        prime: 1-D prompt of token ids.
        length: total length of the returned sequence.
        top_k: if set, sample only among the ``top_k`` highest logits.

    Returns:
        The length-``length`` sequence: ``prime`` followed by sampled tokens.
    """
    prompt_len = prime.shape[-1]
    # zero-padded to the full length; each step fills one zero slot
    seq = np.pad(prime, (0, length - prompt_len))
    position_one_hots = np.eye(length, dtype = int)

    for pos in range(prompt_len, length):
        # the logits at pos - 1 predict the token at pos
        next_logits = fn(params, next(rng), seq)[pos - 1]
        gumbel = gumbel_noise(next(rng), next_logits.shape)

        if exists(top_k):
            keep, next_logits = select_top_k(next_logits, top_k)
            gumbel *= keep

        token = np.argmax(next_logits + gumbel, axis = -1)
        seq += position_one_hots[pos] * token

    return seq

.\lucidrains\mlp-gpt-jax\mlp_gpt_jax\__init__.py

# 从mlp_gpt_jax.mlp_gpt_jax模块中导入MLPGpt和TransformedMLPGpt类
from mlp_gpt_jax.mlp_gpt_jax import MLPGpt, TransformedMLPGpt

MLP GPT - Jax

A GPT made only of MLPs, in Jax. The MLPs used are gMLPs with Spatial Gating Units.

Working Pytorch implementation

Install

$ pip install mlp-gpt-jax

Usage

from jax import random
from haiku import PRNGSequence
from mlp_gpt_jax import TransformedMLPGpt

model = TransformedMLPGpt(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024
)

rng = PRNGSequence(0)
seq = random.randint(next(rng), (1024,), 0, 20000)

params = model.init(next(rng), seq)
logits = model.apply(params, next(rng), seq) # (1024, 20000)

To use the tiny attention (also made autoregressive with a causal mask), set attn_dim to the head dimension you'd like to use; 64 was recommended in the paper.

from jax import random
from haiku import PRNGSequence
from mlp_gpt_jax import TransformedMLPGpt

model = TransformedMLPGpt(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    attn_dim = 64     # set this to 64
)

rng = PRNGSequence(0)
seq = random.randint(next(rng), (1024,), 0, 20000)

params = model.init(next(rng), seq)
logits = model.apply(params, next(rng), seq) # (1024, 20000)

Citations

@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\mlp-gpt-jax\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# Package metadata for the mlp-gpt-jax distribution
setup(
    name="mlp-gpt-jax",
    version="0.0.20",
    description="MLP GPT - Jax",
    url="https://github.com/lucidrains/mlp-gpt-jax",
    license="MIT",
    author="Phil Wang",
    author_email="",
    packages=find_packages(),
    install_requires=[
        "click",
        "click-option-group",
        "einops>=0.3",
        "dm-haiku",
        "jax",
        "jaxlib",
        "optax",
        "torch",
        "tqdm",
    ],
    keywords=[
        "artificial intelligence",
        "deep learning",
        "language model",
        "multilayered-perceptron",
        "jax",
    ],
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
    ],
)

.\lucidrains\mlp-gpt-jax\train.py

# 从 random 模块中导入 randrange 函数
# 从 tqdm 模块中导入 tqdm 函数
# 从 gzip 模块中导入 gzip 模块
# 从 numpy 模块中导入 np 别名
from random import randrange
import tqdm
import gzip
import numpy as np

# 从 torch.utils.data 模块中导入 DataLoader, Dataset 类
# 从 jax 模块中导入 nn, random, jit 模块
# 从 optax 模块中导入 adam, clip_by_global_norm, chain, apply_updates, apply_every 模块
# 从 haiku 模块中导入 PRNGSequence 类
# 从 mlp_gpt_jax 模块中导入 TransformedMLPGpt 类
# 从 mlp_gpt_jax.utils 模块中导入 sample, get_train_loss_fn 函数
from torch.utils.data import DataLoader, Dataset
import jax
from jax import nn, random, jit
from optax import adam, clip_by_global_norm, chain, apply_updates, apply_every
from haiku import PRNGSequence
from mlp_gpt_jax import TransformedMLPGpt
from mlp_gpt_jax.utils import sample, get_train_loss_fn

# Training hyper-parameters
NUM_BATCHES = 100_000
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4   # passed to optax apply_every and used as the loss-print interval
LEARNING_RATE = 2e-4
MAX_GRAD_NORM = 0.5             # global-norm gradient clipping threshold
VALIDATE_EVERY = 100            # NOTE(review): not referenced by the training loop below
SAMPLE_EVERY = 500              # generate a text sample every this many steps
SEQ_LEN = 768                   # model sequence length / training window size

# 辅助函数定义

def cycle(loader):
    """Yield items from ``loader`` forever, restarting iteration after each pass."""
    while True:
        yield from loader

def decode_token(token):
    """Convert a byte value to its character; values below 32 become a space."""
    code = token if token > 32 else 32
    return chr(code)

def decode_tokens(tokens):
    """Decode a sequence of byte values into a string via ``decode_token``."""
    return ''.join(decode_token(token) for token in tokens)

# Prepare enwik8: read the first 95M bytes and split them into 90M train / 5M validation
with gzip.open('./data/enwik8.gz') as file:
    # np.fromstring on binary data is deprecated; np.frombuffer is the supported
    # replacement. frombuffer returns a read-only view of the bytes, so .copy()
    # gives the writable array PyTorch expects when the DataLoader collates slices.
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    data_train, data_val = np.split(X, [int(90e6)])

class TextSamplerDataset(Dataset):
    """Random contiguous windows of ``seq_len + 1`` tokens drawn from ``data``.

    ``index`` is ignored: every access picks a fresh random start offset.
    The reported length is ``len(data) // seq_len``.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data, self.seq_len = data, seq_len

    def __len__(self):
        return self.data.shape[0] // self.seq_len

    def __getitem__(self, index):
        window = self.seq_len + 1  # inputs plus one shifted target token
        start = randrange(0, self.data.shape[0] - window)
        return self.data[start:start + window]

# Data: infinite iterators over random training / validation windows
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader, val_loader = (
    cycle(DataLoader(dataset, batch_size = BATCH_SIZE))
    for dataset in (train_dataset, val_dataset)
)

# Model: the training variant executes each layer with probability 0.95
# (MaybeExecute); the eval variant uses the default of 1.0
model_kwargs = {
    'num_tokens': 256,
    'dim': 512,
    'seq_len': SEQ_LEN,
    'depth': 8,
    'attn_dim': 32,
}
train_model = TransformedMLPGpt(**model_kwargs, layer_survival_prob = 0.95)
eval_model = TransformedMLPGpt(**model_kwargs)

rng = PRNGSequence(42)
# initialise from one training window with its final (target-only) token dropped
params = train_model.init(next(rng), train_dataset[0][:-1])

loss_fn = get_train_loss_fn(train_model)

# Optimizer: global-norm clipping -> Adam -> apply_every(GRADIENT_ACCUMULATE_EVERY)
optim = chain(
    clip_by_global_norm(MAX_GRAD_NORM),
    adam(LEARNING_RATE),
    apply_every(GRADIENT_ACCUMULATE_EVERY),
)
optim_state = optim.init(params)

# Training loop

# Build the jitted sampling function once. Previously jit(eval_model.apply)
# was re-wrapped inside the loop every SAMPLE_EVERY steps.
sample_fn = jit(eval_model.apply)

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    # one optimisation step on the next training batch
    data = next(train_loader).numpy()
    loss, grads = loss_fn(params, next(rng), data)
    updates, optim_state = optim.update(grads, optim_state, params)
    params = apply_updates(params, updates)

    if i % GRADIENT_ACCUMULATE_EVERY == 0:
        print(f'loss: {loss.item()}')

    if i % SAMPLE_EVERY == 0:
        # prime with the first 100 tokens of a validation window and continue it
        valid_data = next(val_loader).numpy()
        prime = valid_data[0][:100]
        prime_str = decode_tokens(prime)
        print(prime_str, "\n", "*" * 40)

        sampled = sample(rng, sample_fn, params, prime, SEQ_LEN, top_k = 25)
        sampled_str = decode_tokens(sampled[100:])
        print(sampled_str)

.\lucidrains\mlp-mixer-pytorch\mlp_mixer_pytorch\mlp_mixer_pytorch.py

# 导入需要的模块
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce

def pair(x):
    """Return ``x`` unchanged if it is a tuple, otherwise ``(x, x)``."""
    if isinstance(x, tuple):
        return x
    return (x, x)

class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: returns ``fn(LayerNorm(x)) + x``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        normed = self.norm(x)
        return self.fn(normed) + x

def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):
    """Two-layer MLP: dense -> GELU -> Dropout -> dense -> Dropout.

    The hidden width is ``int(dim * expansion_factor)``. ``dense`` builds each
    projection as ``dense(in_features, out_features)``; pass a 1x1 Conv1d
    factory to mix along the channel axis of a (batch, channels, length) tensor.
    """
    hidden = int(dim * expansion_factor)
    layers = [
        dense(dim, hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(hidden, dim),
        nn.Dropout(dropout),
    ]
    return nn.Sequential(*layers)

def MLPMixer(*, image_size, channels, patch_size, dim, depth, num_classes, expansion_factor = 4, expansion_factor_token = 0.5, dropout = 0.):
    """Build an MLP-Mixer image classifier.

    Args:
        image_size: int or (height, width); each side must be divisible by patch_size.
        channels: number of input image channels.
        patch_size: side length of the square patches.
        dim: per-patch feature width.
        depth: number of mixer blocks.
        num_classes: number of output logits.
        expansion_factor: hidden multiplier of the token-mixing MLP.
        expansion_factor_token: hidden multiplier of the channel-mixing MLP.
        dropout: dropout probability inside both MLPs.

    Input is (batch, channels, height, width), as fixed by the first Rearrange;
    output is (batch, num_classes).
    """
    image_h, image_w = pair(image_size)
    assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size'
    num_patches = (image_h // patch_size) * (image_w // patch_size)

    # A 1x1 Conv1d on (batch, patches, dim) treats patches as channels,
    # so it mixes across the patch axis.
    token_mixing = partial(nn.Conv1d, kernel_size = 1)
    # nn.Linear acts on the last (feature) axis
    channel_mixing = nn.Linear

    def mixer_block():
        return nn.Sequential(
            PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, token_mixing)),
            PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, channel_mixing)),
        )

    return nn.Sequential(
        # cut the image into flattened patches: (b, num_patches, patch_size**2 * channels)
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
        nn.Linear((patch_size ** 2) * channels, dim),
        *(mixer_block() for _ in range(depth)),
        nn.LayerNorm(dim),
        # average over patches
        Reduce('b n c -> b c', 'mean'),
        nn.Linear(dim, num_classes),
    )

.\lucidrains\mlp-mixer-pytorch\mlp_mixer_pytorch\permutator.py

# 导入需要的模块
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce

class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: returns ``fn(LayerNorm(x)) + x``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        normed = self.norm(x)
        return self.fn(normed) + x

class ParallelSum(nn.Module):
    """Apply every module in ``fns`` to the same input and sum the results."""

    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        total = 0
        for branch in self.fns:
            total = total + branch(x)
        return total

def Permutator(*, image_size, patch_size, dim, depth, num_classes, segments, expansion_factor = 4, dropout = 0., channels = 3):
    """Build a Vision Permutator image classifier.

    Args:
        image_size: side length of the square input image; must be divisible by patch_size.
        patch_size: side length of each square patch.
        dim: token feature width; must be divisible by segments.
        depth: number of permutator blocks.
        num_classes: number of output logits.
        segments: number of channel segments used by the height/width mixing branches.
        expansion_factor: hidden-width multiplier of the channel MLP.
        dropout: dropout probability inside the channel MLP.
        channels: number of input image channels. Previously hard-coded to 3;
            the default keeps the old behaviour.

    Input is (batch, channels, image_size, image_size), as fixed by the first
    Rearrange; output is (batch, num_classes).
    """
    assert (image_size % patch_size) == 0, 'image must be divisible by patch size'
    assert (dim % segments) == 0, 'dimension must be divisible by the number of segments'
    height = width = image_size // patch_size
    s = segments

    return nn.Sequential(
        # patchify: (b, h, w, patch_size**2 * channels)
        Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
        nn.Linear((patch_size ** 2) * channels, dim),
        *[nn.Sequential(
            PreNormResidual(dim, nn.Sequential(
                # sum of height-mixing, width-mixing and channel-mixing branches
                ParallelSum(
                    nn.Sequential(
                        # mix along the height axis, per channel segment
                        Rearrange('b h w (c s) -> b w c (h s)', s = s),
                        nn.Linear(height * s, height * s),
                        Rearrange('b w c (h s) -> b h w (c s)', s = s),
                    ),
                    nn.Sequential(
                        # mix along the width axis, per channel segment
                        Rearrange('b h w (c s) -> b h c (w s)', s = s),
                        nn.Linear(width * s, width * s),
                        Rearrange('b h c (w s) -> b h w (c s)', s = s),
                    ),
                    nn.Linear(dim, dim)
                ),
                nn.Linear(dim, dim)
            )),
            # channel MLP
            PreNormResidual(dim, nn.Sequential(
                nn.Linear(dim, dim * expansion_factor),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(dim * expansion_factor, dim),
                nn.Dropout(dropout)
            ))
        ) for _ in range(depth)],
        nn.LayerNorm(dim),
        # average over all spatial positions
        Reduce('b h w c -> b c', 'mean'),
        nn.Linear(dim, num_classes)
    )

.\lucidrains\mlp-mixer-pytorch\mlp_mixer_pytorch\__init__.py

# 从mlp_mixer_pytorch包中导入MLPMixer类
from mlp_mixer_pytorch.mlp_mixer_pytorch import MLPMixer
# 从mlp_mixer_pytorch包中导入Permutator类
from mlp_mixer_pytorch.permutator import Permutator

MLP Mixer - Pytorch

An All-MLP solution for Vision, from Google AI, in Pytorch.

No convolutions nor attention needed!

Yannic Kilcher video

Install

$ pip install mlp-mixer-pytorch

Usage

import torch
from mlp_mixer_pytorch import MLPMixer

model = MLPMixer(
    image_size = 256,
    channels = 3,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)

Rectangular image

import torch
from mlp_mixer_pytorch import MLPMixer

model = MLPMixer(
    image_size = (256, 128),
    channels = 3,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 128)
pred = model(img) # (1, 1000)

Citations

@misc{tolstikhin2021mlpmixer,
    title   = {MLP-Mixer: An all-MLP Architecture for Vision},
    author  = {Ilya Tolstikhin and Neil Houlsby and Alexander Kolesnikov and Lucas Beyer and Xiaohua Zhai and Thomas Unterthiner and Jessica Yung and Daniel Keysers and Jakob Uszkoreit and Mario Lucic and Alexey Dosovitskiy},
    year    = {2021},
    eprint  = {2105.01601},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{hou2021vision,
    title   = {Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition},
    author  = {Qibin Hou and Zihang Jiang and Li Yuan and Ming-Ming Cheng and Shuicheng Yan and Jiashi Feng},
    year    = {2021},
    eprint  = {2106.12368},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\mlp-mixer-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# Package metadata for the mlp-mixer-pytorch distribution
setup(
  name = 'mlp-mixer-pytorch',
  version = '0.1.1',
  description = 'MLP Mixer - Pytorch',
  url = 'https://github.com/lucidrains/mlp-mixer-pytorch',
  license = 'MIT',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  packages = find_packages(exclude = []),
  install_requires = [
    'einops>=0.3',
    'torch>=1.6',
  ],
  keywords = [
    'artificial intelligence',
    'deep learning',
    'image recognition',
  ],
  classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\mogrifier\mogrifier\mogrifier.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn

def weight(dim_in, dim_out, factorize_k = None):
    """Bias-free linear map ``dim_in -> dim_out``.

    With ``factorize_k`` set, the map is factorised as two bias-free linears
    ``dim_in -> k -> dim_out``; ``k`` must be smaller than both dimensions.
    """
    if factorize_k is not None:
        assert factorize_k < dim_in and factorize_k < dim_out, 'k must be of relative lower rank'
        return nn.Sequential(
            nn.Linear(dim_in, factorize_k, bias = False),
            nn.Linear(factorize_k, dim_out, bias = False)
        )

    return nn.Linear(dim_in, dim_out, bias = False)

class Mogrifier(nn.Module):
    """Mogrifier gating: ``x`` and ``h`` alternately gate each other for ``iters`` rounds.

    Even rounds (0, 2, ...) set ``x = 2 * sigmoid(Q(h)) * x``. Odd rounds set
    ``h = 2 * sigmoid(R(x)) * h``. ``R`` is only created when ``iters > 1``.
    ``factorize_k`` is forwarded to ``weight`` for low-rank Q/R.
    """

    def __init__(self, dim, iters = 5, factorize_k = None):
        super().__init__()
        self.dim = dim
        self.iters = iters
        self.Q = weight(dim, dim, factorize_k)
        self.R = weight(dim, dim, factorize_k) if iters > 1 else None

    def forward(self, x, h):
        """Return the mogrified ``(x, h)``, both reshaped to ``x``'s original shape."""
        orig_shape = x.shape
        *_, dim = orig_shape
        assert dim == self.dim, f'mogrifier accepts a dimension of {self.dim}'

        # flatten leading axes so the gates act row-wise
        x = x.reshape(-1, dim)
        h = h.reshape(-1, dim)

        for step in range(self.iters):
            if step % 2:
                h = 2 * self.R(x).sigmoid() * h
            else:
                x = 2 * self.Q(h).sigmoid() * x

        return x.reshape(*orig_shape), h.reshape(*orig_shape)

.\lucidrains\mogrifier\mogrifier\__init__.py

# 从 mogrifier.mogrifier 模块中导入 Mogrifier 类
from mogrifier.mogrifier import Mogrifier

PyPI version

Mogrifier

A complete implementation of Mogrifier, a circuit for enhancing LSTMs and potentially other networks. It allows two vectors to modulate each other by having each gate the other in an interleaved, iterative fashion.

Install

$ pip install mogrifier

Usage

import torch
from mogrifier import Mogrifier

m = Mogrifier(
    dim = 512,
    iters = 5,          # number of iterations, defaults to 5 as paper recommended for LSTM
    factorize_k = 16    # factorize weight matrices into (dim x k) and (k x dim), if specified
)

x = torch.randn(1, 16, 512)
h = torch.randn(1, 16, 512)

x_out, h_out = m(x, h) # (1, 16, 512), (1, 16, 512)

Citation

@inproceedings{Melis2020Mogrifier,
    title={Mogrifier LSTM},
    author={Gábor Melis and Tomáš Kočiský and Phil Blunsom},
    booktitle={International Conference on Learning Representations},
    year={2020},
    url={https://openreview.net/forum?id=SJe5P6EYvS}
}

.\lucidrains\mogrifier\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# Package metadata for the mogrifier distribution
setup(
  name = 'mogrifier',
  version = '0.0.3',
  description = 'Implementation of Mogrifier circuit from Deepmind',
  url = 'https://github.com/lucidrains/mogrifier',
  license = 'MIT',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  packages = find_packages(),
  install_requires = ['torch'],
  keywords = ['artificial intelligence', 'natural language processing'],
  classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\molecule-attention-transformer\molecule_attention_transformer\molecule_attention_transformer.py

# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 导入 functools 库中的 partial 函数
from functools import partial
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 einops 库中导入 rearrange 函数
from einops import rearrange

# Distance kernels for MAT attention, keyed by name:
#   fn            - turns the distance matrix into attention weights
#   mask_value_fn - value written into masked distance entries. It is +max for
#                   'exp' (exp(-max) == 0) and -max for 'softmax', so masked
#                   entries get ~zero weight (for softmax: unless a whole row is masked)
DIST_KERNELS = dict(
    exp = dict(
        fn = lambda t: torch.exp(-t),
        mask_value_fn = lambda t: torch.finfo(t.dtype).max,
    ),
    softmax = dict(
        fn = lambda t: torch.softmax(t, dim = -1),
        mask_value_fn = lambda t: -torch.finfo(t.dtype).max,
    ),
)

# 辅助函数

def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    return not (val is None)

def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return ``d``."""
    if exists(val):
        return val
    return d

# 辅助类

class Residual(nn.Module):
    """Skip connection: returns ``x + fn(x, **kwargs)``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch = self.fn(x, **kwargs)
        return x + branch

class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input, then ``fn`` (keyword args go to ``fn``)."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP: Linear(dim -> dim*mult) -> GELU -> Linear(dim*mult -> dim_out).

    ``dim_out`` defaults to ``dim`` when omitted.
    """

    def __init__(self, dim, dim_out = None, mult = 4):
        super().__init__()
        hidden = dim * mult
        out_features = dim if dim_out is None else dim_out
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_features)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Molecule self-attention (MAT).

    Combines three (batch, heads, i, j) weight maps before aggregating values:

        attn = La * softmax(q k^T * scale) + Lg * adjacency + Ld * kernel(distance)

    where ``kernel`` is ``DIST_KERNELS[dist_kernel_fn]['fn']``.

    Args:
        dim: input/output feature width.
        heads: number of attention heads.
        dim_head: per-head width of q, k and v.
        Lg: weight of the adjacency-matrix term.
        Ld: weight of the distance-kernel term.
        La: weight of the softmax self-attention term.
        dist_kernel_fn: name of the distance kernel; must be a key of DIST_KERNELS.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, Lg = 0.5, Ld = 0.5, La = 1, dist_kernel_fn = 'exp'):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        # weights of the linear combination of attention / distance / adjacency
        self.La = La
        self.Ld = Ld
        self.Lg = Lg

        self.dist_kernel_fn = dist_kernel_fn

    def forward(self, x, mask = None, adjacency_mat = None, distance_mat = None):
        """Attend over nodes.

        Args:
            x: node features (batch, nodes, dim); rank 3 is required by the rearrange below.
            mask: optional node mask indexed as (batch, nodes). Assumed boolean with
                True marking real nodes, since it is inverted with ``~``.
            adjacency_mat: optional (batch, nodes, nodes) matrix, broadcast across heads.
            distance_mat: optional (batch, nodes, nodes) matrix, broadcast across heads.

        Returns:
            (batch, nodes, dim) tensor.

        Masking of ``adjacency_mat`` / ``distance_mat`` is done out of place, so the
        caller's tensors are never modified. This also lets the same matrices be
        reused across layers and works for tensors that require grad.
        """
        h, La, Ld, Lg, dist_kernel_fn = self.heads, self.La, self.Ld, self.Lg, self.dist_kernel_fn

        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b n (h qkv d) -> b h n qkv d', h = h, qkv = 3).unbind(dim = -2)
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        assert dist_kernel_fn in DIST_KERNELS, f'distance kernel function needs to be one of {DIST_KERNELS.keys()}'
        dist_kernel_config = DIST_KERNELS[dist_kernel_fn]

        if exists(distance_mat):
            distance_mat = rearrange(distance_mat, 'b i j -> b () i j')

        if exists(adjacency_mat):
            adjacency_mat = rearrange(adjacency_mat, 'b i j -> b () i j')

        if exists(mask):
            mask_value = torch.finfo(dots.dtype).max
            # pairwise (b, 1, i, j) mask: True only where both nodes are real
            mask = mask[:, None, :, None] * mask[:, None, None, :]

            # mask attention logits (dots is local to this call, so in-place is safe)
            dots.masked_fill_(~mask, -mask_value)

            if exists(distance_mat):
                # mask_value_fn picks a value the kernel maps to ~zero weight
                dist_mask_value = dist_kernel_config['mask_value_fn'](dots)
                distance_mat = distance_mat.masked_fill(~mask, dist_mask_value)

            if exists(adjacency_mat):
                adjacency_mat = adjacency_mat.masked_fill(~mask, 0.)

        attn = dots.softmax(dim = -1)

        # weighted sum of self-attention, adjacency and distance-kernel contributions
        attn = attn * La

        if exists(adjacency_mat):
            attn = attn + Lg * adjacency_mat

        if exists(distance_mat):
            distance_mat = dist_kernel_config['fn'](distance_mat)
            attn = attn + Ld * distance_mat

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class MAT(nn.Module):
    """Molecule Attention Transformer: per-node features -> one prediction per graph.

    Args:
        dim_in: input node-feature width.
        model_dim: hidden width.
        dim_out: output width.
        depth: number of (attention, feed-forward) layer pairs.
        heads: attention heads per layer.
        Lg, Ld, La: weights of the adjacency, distance and self-attention terms.
        dist_kernel_fn: distance kernel name ('exp' or 'softmax').
    """

    def __init__(
        self,
        *,
        dim_in,
        model_dim,
        dim_out,
        depth,
        heads = 8,
        Lg = 0.5,
        Ld = 0.5,
        La = 1,
        dist_kernel_fn = 'exp'
    ):
        super().__init__()

        self.embed_to_model = nn.Linear(dim_in, model_dim)
        self.layers = nn.ModuleList([])

        # each layer: pre-norm residual attention, then pre-norm residual feed-forward
        for _ in range(depth):
            layer = nn.ModuleList([
                Residual(PreNorm(model_dim, Attention(model_dim, heads = heads, Lg = Lg, Ld = Ld, La = La, dist_kernel_fn = dist_kernel_fn))),
                Residual(PreNorm(model_dim, FeedForward(model_dim)))
            ])
            self.layers.append(layer)

        self.norm_out = nn.LayerNorm(model_dim)
        self.ff_out = FeedForward(model_dim, dim_out)

    def forward(
        self,
        x,
        mask = None,
        adjacency_mat = None,
        distance_mat = None
    ):
        """Encode nodes and pool them into one vector per graph.

        Args:
            x: node features (batch, nodes, dim_in).
            mask: optional (batch, nodes) node mask; assumed True for real nodes.
            adjacency_mat: optional (batch, nodes, nodes) adjacency matrix.
            distance_mat: optional (batch, nodes, nodes) distance matrix.

        Returns:
            (batch, dim_out) tensor.
        """
        x = self.embed_to_model(x)

        for (attn, ff) in self.layers:
            x = attn(
                x,
                mask = mask,
                adjacency_mat = adjacency_mat,
                distance_mat = distance_mat
            )
            x = ff(x)

        x = self.norm_out(x)

        if exists(mask):
            # Average only over real nodes, so padded positions do not change the
            # pooled vector. The count is clamped to 1, so a fully masked graph
            # pools to zeros instead of dividing by zero.
            weights = mask[..., None].to(x.dtype)
            x = (x * weights).sum(dim = -2) / weights.sum(dim = -2).clamp(min = 1.)
        else:
            x = x.mean(dim = -2)

        x = self.ff_out(x)
        return x

.\lucidrains\molecule-attention-transformer\molecule_attention_transformer\__init__.py

# 从 molecule_attention_transformer 包中导入 MAT 类
from molecule_attention_transformer.molecule_attention_transformer import MAT

Molecule Attention Transformer - Pytorch (wip)

PyTorch reimplementation of the Molecule Attention Transformer, which uses a slightly modified transformer to tackle the graph-like structure of molecules. The repository is also meant to be educational: a way to understand the limitations (or perhaps the lack thereof) of transformers for processing graphs.

Update: I reread the paper, and the results do look convincing. However, I do not like that it still takes hyperparameter sweeps over the relative contributions of the distance, adjacency, and self-attention matrices to achieve good results. There must be a more hands-off way.

Install

$ pip install molecule-attention-transformer

Usage

import torch
from molecule_attention_transformer import MAT

model = MAT(
    dim_in = 26,
    model_dim = 512,
    dim_out = 1,
    depth = 6,
    Lg = 0.5,                   # lambda (g)raph - weight for adjacency matrix
    Ld = 0.5,                   # lambda (d)istance - weight for distance matrix
    La = 1,                     # lambda (a)ttention - weight for usual self-attention
    dist_kernel_fn = 'exp'      # distance kernel fn - either 'exp' or 'softmax'
)

atoms           = torch.randn(2, 100, 26)
mask            = torch.ones(2, 100).bool()
adjacency_mat   = torch.empty(2, 100, 100).random_(2).float()
distance_mat    = torch.randn(2, 100, 100)

out = model(
    atoms,
    mask = mask,
    adjacency_mat = adjacency_mat,
    distance_mat = distance_mat
) # (2, 1)

Citations

@misc{maziarka2020molecule,
    title={Molecule Attention Transformer}, 
    author={Łukasz Maziarka and Tomasz Danel and Sławomir Mucha and Krzysztof Rataj and Jacek Tabor and Stanisław Jastrzębski},
    year={2020},
    eprint={2002.08264},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

.\lucidrains\molecule-attention-transformer\setup.py

# 导入设置安装和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  # --- identity: what the distribution is called and where it lives ---
  name = 'molecule-attention-transformer',
  version = '0.0.4',
  description = 'Molecule Attention Transformer - Pytorch',
  url = 'https://github.com/lucidrains/molecule-attention-transformer',
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'molecules'
  ],
  # --- authorship and licensing ---
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  license = 'MIT',
  # --- packaging: every package found under the project root, plus runtime deps ---
  packages = find_packages(),
  install_requires = [
    'torch>=1.6',
    'einops>=0.3'
  ],
  # --- PyPI trove classifiers ---
  classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter Prize page: http://prize.hutter1.net/

.\lucidrains\multistream-transformers\multistream_transformers\autoregressive_wrapper.py

import torch
from torch import nn
import torch.nn.functional as F

# 定义一个装饰器函数,用于在模型评估时切换为eval模式
def eval_decorator(fn):
    """Decorate ``fn`` so it runs with ``model`` in eval mode.

    ``fn`` must take the model as its first positional argument. The model's
    previous ``training`` flag is restored afterwards, even if ``fn`` raises,
    and the wrapper keeps ``fn``'s name and docstring.
    """
    from functools import wraps

    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # always restore the caller's mode, including on exceptions
            model.train(was_training)
    return inner

# 定义一个函数用于对logits进行top k过滤
def top_k(logits, thres = 0.9):
    """Keep the k largest logits along the last dim and set the rest to -inf.

    ``k = int((1 - thres) * num_logits)``, floored at 1. Without the floor, a
    small vocabulary or a ``thres`` close to 1 gave k == 0, every entry became
    -inf, and the following softmax produced NaNs.

    Args:
        logits: tensor of logits; filtering happens over the last dimension.
        thres: fraction of logits to discard.

    Returns:
        Tensor shaped like ``logits`` with non-top-k entries set to -inf.
    """
    k = max(int((1 - thres) * logits.shape[-1]), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    # scatter along the same (last) dim that topk selected over
    probs.scatter_(-1, ind, val)
    return probs

# 定义一个包装类,用于自回归模型
class AutoregressiveWrapper(nn.Module):
    """Wrap a token-level network for next-token training and sampling.

    ``net`` must expose ``max_seq_len`` (read in ``__init__``). Judging by how it
    is indexed below, it maps a (batch, seq) tensor of token ids to
    (batch, seq, num_tokens) logits.
    """

    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        # stored for API compatibility; nothing in this wrapper reads it
        self.pad_value = pad_value
        # target id skipped by cross_entropy in forward()
        self.ignore_index = ignore_index

        self.net = net
        self.max_seq_len = net.max_seq_len

    @torch.no_grad()
    @eval_decorator
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        """Sample ``seq_len`` new tokens after ``start_tokens``.

        Args:
            start_tokens: (seq,) or (batch, seq) prompt token ids.
            seq_len: maximum number of tokens to sample.
            eos_token: if given, stop early once every row samples it.
            temperature: softmax temperature applied after filtering.
            filter_logits_fn: callable ``fn(logits, thres=...)`` returning
                filtered logits; defaults to ``top_k``.
            filter_thres: passed to ``filter_logits_fn`` as ``thres``.
            **kwargs: forwarded to ``net``. A ``mask`` entry is extended with
                True for every sampled token and cropped to the same window
                as the input, so it keeps matching the input length.

        Returns:
            Only the sampled tokens (the prompt is stripped), with the same
            number of dims as ``start_tokens``.
        """
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        t = start_tokens.shape[1]
        out = start_tokens

        mask = kwargs.pop('mask', None)
        if mask is not None and mask.dim() == 1:
            mask = mask[None, :]

        for _ in range(seq_len):
            # only the most recent max_seq_len tokens fit the model's context
            x = out[:, -self.max_seq_len:]
            if mask is not None:
                kwargs['mask'] = mask[:, -self.max_seq_len:]

            logits = self.net(x, **kwargs)[:, -1, :]
            # honour the caller-supplied filter (was hard-wired to top_k)
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim = -1)

            sample = torch.multinomial(probs, 1)
            out = torch.cat((out, sample), dim = -1)

            if mask is not None:
                # newly sampled tokens are always real (unmasked) positions
                mask = torch.cat((mask, mask.new_ones(mask.shape[0], 1)), dim = -1)

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        return out

    def forward(self, x, **kwargs):
        """Return the next-token cross-entropy loss for a (batch, seq) batch ``x``."""
        xi, xo = x[:, :-1], x[:, 1:]

        # a full-length mask must be trimmed to match the shifted input xi
        mask = kwargs.get('mask')
        if mask is not None and mask.shape[1] == x.shape[1]:
            kwargs['mask'] = mask[:, :-1]

        out = self.net(xi, **kwargs)
        # cross_entropy expects the class dimension second: (batch, num_tokens, seq)
        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss

.\lucidrains\multistream-transformers\multistream_transformers\multistream_transformers.py

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat, reduce

from einops.layers.torch import Rearrange

# 辅助函数

# 判断变量是否存在
def exists(val):
    """Return True unless ``val`` is None (falsy values like 0 or '' still count)."""
    if val is None:
        return False
    return True

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    """Return ``val``, or the fallback ``d`` when ``val`` is None."""
    if val is None:
        return d
    return val

# 返回给定数据类型的最小负值
def max_neg_value(t):
    """Return the most negative finite value of ``t``'s floating-point dtype."""
    info = torch.finfo(t.dtype)
    return -info.max

# 对所有张量进行重排列
def rearrange_all(tensors, *args, **kwargs):
    """Lazily apply the same einops ``rearrange`` pattern to every tensor.

    Returns a ``map`` iterator, so callers typically unpack it directly.
    """
    def _apply(tensor):
        return rearrange(tensor, *args, **kwargs)
    return map(_apply, tensors)

# building blocks: normalization, feedforward, attention

# 分组层归一化
class GroupLayerNorm(nn.Module):
    """Layer norm applied separately to each channel group of a (b, c, n) tensor.

    The channel axis is split into ``groups`` chunks of ``dim`` channels. Each
    chunk is normalized over its ``dim`` channels at every position, then scaled
    and shifted by a per-group, per-channel gain ``g`` and bias ``b``.
    Note: ``eps`` is added to the standard deviation, not inside the square root.
    """
    def __init__(self, dim, groups = 1, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.groups = groups
        param_shape = (1, groups, dim, 1)
        self.g = nn.Parameter(torch.ones(param_shape))
        self.b = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        grouped = rearrange(x, 'b (g d) n -> b g d n', g = self.groups)
        mu = torch.mean(grouped, dim = 2, keepdim = True)
        sigma = torch.var(grouped, dim = 2, unbiased = False, keepdim = True).sqrt()
        normed = (grouped - mu) / (sigma + self.eps)
        out = normed * self.g + self.b
        return rearrange(out, 'b g d n -> b (g d) n')

# 预归一化
class PreNorm(nn.Module):
    """Normalize the input with GroupLayerNorm, then call the wrapped module ``fn``.

    Extra keyword arguments to ``forward`` are passed straight through to ``fn``.
    """
    def __init__(self, dim, fn, groups = 1):
        super().__init__()
        self.norm = GroupLayerNorm(dim, groups = groups)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# 前馈网络
class FeedForward(nn.Module):
    """Position-wise MLP built from grouped 1x1 convolutions.

    Works on channel-first (batch, dim * groups, seq) tensors. Because both
    convolutions use ``groups=groups``, each channel group has its own
    ``dim -> dim * mult -> dim`` MLP with a GELU in between.
    """
    def __init__(self, *, dim, mult = 4, groups = 1):
        super().__init__()
        channels = dim * groups
        hidden = channels * mult

        layers = [
            nn.Conv1d(channels, hidden, 1, groups = groups),
            nn.GELU(),
            nn.Conv1d(hidden, channels, 1, groups = groups),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# 注意力机制
class Attention(nn.Module):
    """Multi-head attention over channel-first ``(batch, channels, seq)`` tensors.

    The channel axis holds ``groups`` streams of features. After projection,
    every (batch, group, head) triple is attended separately. With
    ``context=None`` this is self-attention; otherwise queries come from ``x``
    and keys/values from ``context``.

    NOTE(review): unlike FeedForward, the q/kv/out projections below are not
    grouped convolutions, so they mix features across streams. Confirm this is
    intended before relying on the streams being fully independent.
    """
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        causal = False,
        groups = 1
    ):
        """
        Args:
            dim: per-group feature size; the projections take ``dim * groups`` channels.
            dim_head: size of each attention head.
            heads: number of heads per group.
            causal: if True, query position i cannot attend to key positions j > i.
            groups: number of streams folded into the channel axis.
        """
        super().__init__()
        # standard 1/sqrt(d) scaling of the dot-product scores
        self.scale = dim_head ** -0.5
        self.groups = groups
        self.heads = heads
        self.causal = causal
        input_dim = dim * groups
        inner_dim = dim_head * heads * groups

        # 1x1 convolutions act as per-position linear projections on (b, c, n) inputs
        self.to_q = nn.Conv1d(input_dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv1d(input_dim, inner_dim * 2, 1, bias = False)
        self.to_out = nn.Conv1d(inner_dim, input_dim, 1)

    def forward(self, x, mask = None, context = None):
        """Attend from ``x`` to ``context`` (which defaults to ``x``).

        Args:
            x: queries, (batch, dim * groups, seq).
            mask: optional boolean (batch, seq) mask. Query/key pairs where either
                side is False are filled with the most negative finite value
                before the softmax. It is applied on both axes, so it presumably
                assumes ``context`` has the same length as ``x``. TODO confirm.
            context: optional key/value source with the same channel layout as ``x``.

        Returns:
            Tensor of shape (batch, dim * groups, seq) with the query length as seq.
        """
        n, device, h, g, causal = x.shape[2], x.device, self.heads, self.groups, self.causal
        context = default(context, x)

        # q from x; k and v are the two channel halves of the to_kv projection of context
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = 1))
        # split channels into (group, head, dim_head) and fold group/head into the batch axis
        q, k, v = rearrange_all((q, k, v), 'b (g h d) n -> (b g h) n d', g = g, h = h)

        q = q * self.scale

        # similarity scores: (b * g * h, query_len, key_len)
        sim = einsum('b i d, b j d -> b i j', q, k)

        if exists(mask):
            # broadcast the per-position mask to every (group, head) copy of the batch
            mask = repeat(mask, 'b n -> (b g h) n', h = h, g = g)
            # outer product: a pair is kept only if both query and key positions are kept
            mask = rearrange(mask, 'b n -> b n ()') * rearrange(mask, 'b n -> b () n')
            mask_value = max_neg_value(sim)
            # a finite fill value (not -inf) keeps fully-masked rows from becoming NaN
            sim = sim.masked_fill(~mask, mask_value)

        if causal:
            # strictly upper-triangular entries (j > i) are blocked; built from the query
            # length n on both axes, so this also presumes len(context) == n
            causal_mask = torch.ones((n, n), device = device).triu(1).bool()
            mask_value = max_neg_value(sim)
            sim = sim.masked_fill(causal_mask, mask_value)

        attn = sim.softmax(dim = -1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        # unfold groups/heads back into the channel axis, channel-first again
        out = rearrange(out, '(b g h) n d -> b (g h d) n', h = h, g = g)
        return self.to_out(out)

# Transformer 块
class TransformerBlock(nn.Module):
    """Pre-norm residual block: grouped attention, then a grouped feedforward.

    Both sublayers work on channel-first (batch, channels, seq) tensors, are
    wrapped in PreNorm, and add their output back onto their input.
    """
    def __init__(
        self,
        *,
        dim,
        causal = False,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        groups = 1
    ):
        super().__init__()
        attention = Attention(dim = dim, dim_head = dim_head, heads = heads, causal = causal, groups = groups)
        self.attn = PreNorm(dim, attention, groups = groups)
        feedforward = FeedForward(dim = dim, mult = ff_mult, groups = groups)
        self.ff = PreNorm(dim, feedforward, groups = groups)

    def forward(self, x, mask = None):
        residual = x
        x = self.attn(x, mask = mask) + residual
        residual = x
        x = self.ff(x) + residual
        return x
# 主类定义

class MultistreamTransformer(nn.Module):
    """Token-level transformer that can run several parallel streams per layer.

    The pipeline has four stages:

    1. Embed tokens and add learned positions.
    2. Run one single-stream ``pre_transformer_block``.
    3. If ``num_streams > 1``, copy the features into ``num_streams`` channel
       groups and run ``depth`` grouped blocks. Then, at every position, a
       learned query attention-pools over the pre-block output plus every
       stream of every grouped layer. Otherwise the ``depth`` blocks run
       single-stream.
    4. Run one single-stream ``post_transformer_block`` and project to
       per-token logits.
    """
    # constructor
    def __init__(
        self,
        *,
        dim,  # model width (features per stream)
        depth,  # number of grouped (multistream) blocks
        num_tokens,  # vocabulary size
        max_seq_len,  # number of learned position embeddings
        causal = False,  # forwarded to every TransformerBlock
        dim_head = 64,  # attention head size
        heads = 8,  # attention heads
        ff_mult = 4,  # feedforward hidden-size multiplier
        num_streams = 1  # parallel streams in the middle blocks
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.num_streams = num_streams
        self.token_emb = nn.Embedding(num_tokens, dim)  # token embedding
        self.pos_emb = nn.Embedding(max_seq_len, dim)  # learned absolute position embedding

        self.layers = nn.ModuleList([])
        self.pre_transformer_block = TransformerBlock(dim = dim, causal = causal, dim_head = dim_head, heads = heads)  # single-stream block before the streams split

        for _ in range(depth):
            self.layers.append(TransformerBlock(dim = dim, causal = causal, dim_head = dim_head, heads = heads, groups = num_streams))  # grouped block: one group per stream

        if num_streams > 1:
            self.query = nn.Parameter(torch.randn(dim))  # learned query used to pool across streams/layers
            self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)  # attention pooling (single group)

        self.post_transformer_block = TransformerBlock(dim = dim, causal = causal, dim_head = dim_head, heads = heads,)  # single-stream block after pooling

        self.to_logits = nn.Sequential(
            Rearrange('b d n -> b n d'),  # back to channel-last for LayerNorm/Linear
            nn.LayerNorm(dim),  # final normalization
            nn.Linear(dim, num_tokens)  # per-token logits
        )

    # forward pass
    def forward(self, x, mask = None):
        """Map (batch, seq) token ids to (batch, seq, num_tokens) logits.

        ``mask`` is forwarded to every TransformerBlock; their Attention treats it
        as a boolean (batch, seq) mask. It is not used by the cross-stream pooling.
        """
        # the star-unpack requires x to be exactly 2-D: (batch, seq)
        b, n, d, device, is_multistream = *x.shape, self.dim, x.device, (self.num_streams > 1)  # shapes, device, multistream flag
        x = self.token_emb(x)  # token embedding: (b, n, d)

        pos_emb = self.pos_emb(torch.arange(n, device = device))  # position embeddings for 0..n-1
        pos_emb = rearrange(pos_emb, 'n d -> () n d')  # add a broadcastable batch axis

        x = x + pos_emb  # add positions
        x = rearrange(x, 'b n d -> b d n')  # channel-first layout used by the Conv1d-based blocks

        x = self.pre_transformer_block(x, mask = mask)  # single-stream block
        layers = [x]  # keep every layer's output for the cross-layer pooling below

        if is_multistream:
            x = repeat(x, 'b d n -> b (s d) n', s = self.num_streams)  # copy features into num_streams channel groups

        for block in self.layers:
            x = block(x, mask = mask)  # grouped block
            layers.append(x)  # keep this layer's output

        if is_multistream:
            # each output becomes (b * n, d, streams): the pre-block output contributes
            # 1 token per position, each grouped layer contributes num_streams tokens
            layers = list(map(lambda t: rearrange(t, 'b (s d) n -> (b n) d s', d = d), layers))  # per-position token sets
            layer_tokens = torch.cat(layers, dim = -1)  # concatenate all layers' tokens along the token axis

            query = repeat(self.query, 'd -> b d ()', b = layer_tokens.shape[0])  # one length-1 query per (batch, position)
            x = self.attn_pool(query, context = layer_tokens)  # pool the tokens into one vector per position
            x = rearrange(x, '(b n) d () -> b d n', n = n)  # restore (batch, channels, seq)

        x = self.post_transformer_block(x, mask = mask)  # single-stream block after pooling
        return self.to_logits(x)  # (batch, seq, num_tokens)
posted @ 2024-06-28 14:01  绝不原创的飞龙  阅读(9)  评论(0编辑  收藏  举报