
Lucidrains Series Projects: Source Code Analysis (Part 45)

.\lucidrains\triton-transformer\triton_transformer\cross_entropy.py

# import torch
import torch
# functional API from torch
import torch.nn.functional as F
# rearrange from einops
from einops import rearrange

# import triton
import triton
# import the triton language module as tl
import triton.language as tl

# cross entropy loss, taking logits (predictions), labels (targets), ignore_index (label to ignore, default 0) and use_triton (whether to use the triton kernel, default False)
def cross_entropy_fn(logits, labels, ignore_index = 0, use_triton = False):
    # flatten the batch and sequence dimensions of logits, 'b n c' -> '(b n) c'
    logits = rearrange(logits, 'b n c -> (b n) c')
    # flatten labels, 'b n' -> '(b n)'
    labels = rearrange(labels, 'b n -> (b n)')

    # if use_triton is True, compute the unreduced loss with triton's cross_entropy op
    if use_triton:
        loss = triton.ops.cross_entropy(logits, labels)
    # otherwise fall back to torch.nn.functional.cross_entropy without reduction
    else:
        loss = F.cross_entropy(logits, labels, reduction = 'none')

    # build a mask for positions whose label is not ignore_index
    mask = (labels != ignore_index)
    # return the mean of the masked per-token losses
    return loss[mask].mean()
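
For reference, a minimal usage sketch (not part of the repository; it assumes triton is installed, since this module imports it at the top level, while the use_triton = True path additionally needs a CUDA device and a Triton build that ships triton.ops.cross_entropy):

import torch
from triton_transformer.cross_entropy import cross_entropy_fn

logits = torch.randn(2, 16, 512, requires_grad = True)  # (batch, seq, num_tokens)
labels = torch.randint(0, 512, (2, 16))                 # (batch, seq); positions labeled 0 are ignored
loss = cross_entropy_fn(logits, labels, ignore_index = 0, use_triton = False)
loss.backward()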

.\lucidrains\triton-transformer\triton_transformer\dropout.py

# import required libraries
import torch
from torch import autograd
import torch.nn.functional as F
import triton
import triton.language as tl
from random import randrange

# define the BLOCK_SIZE constant
BLOCK_SIZE = 1024

# Triton JIT-compiled kernel implementing dropout with an explicit random seed
@triton.jit
def _seeded_dropout(x_ptr, output_ptr, n_elements, p, seed, **meta):
    BLOCK_SIZE = meta['BLOCK_SIZE']
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE * 4

    off0 = block_start + BLOCK_SIZE * 0 + tl.arange(0, BLOCK_SIZE)
    off1 = block_start + BLOCK_SIZE * 1 + tl.arange(0, BLOCK_SIZE)
    off2 = block_start + BLOCK_SIZE * 2 + tl.arange(0, BLOCK_SIZE)
    off3 = block_start + BLOCK_SIZE * 3 + tl.arange(0, BLOCK_SIZE)

    mask0 = off0 < n_elements
    mask1 = off1 < n_elements
    mask2 = off2 < n_elements
    mask3 = off3 < n_elements

    x0 = tl.load(x_ptr + off0, mask = mask0)
    x1 = tl.load(x_ptr + off1, mask = mask1)
    x2 = tl.load(x_ptr + off2, mask = mask2)
    x3 = tl.load(x_ptr + off3, mask = mask3)

    r0, r1, r2, r3 = tl.random.rand4x(seed, off0)
    keep0, keep1, keep2, keep3 = r0 > p, r1 > p, r2 > p, r3 > p

    o0 = tl.where(keep0, x0 / (1 - p), 0.0)
    o1 = tl.where(keep1, x1 / (1 - p), 0.0)
    o2 = tl.where(keep2, x2 / (1 - p), 0.0)
    o3 = tl.where(keep3, x3 / (1 - p), 0.0)

    tl.store(output_ptr + off0, o0, mask = mask0)
    tl.store(output_ptr + off1, o1, mask = mask1)
    tl.store(output_ptr + off2, o2, mask = mask2)
    tl.store(output_ptr + off3, o3, mask = mask3)

# wrapper around the seeded dropout kernel
def seeded_dropout(x, p, seed):
    output = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE'] * 4),)
    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE = BLOCK_SIZE)
    return output

# custom autograd.Function implementing dropout (the same seed is reused in backward, so the same mask is applied to the gradient)
class dropout_(autograd.Function):
    @classmethod
    def forward(cls, ctx, x, p):
        seed = randrange(int(1e6))
        ctx.p = p
        ctx.seed = seed
        return seeded_dropout(x, p, seed)

    @classmethod
    def backward(cls, ctx, dy):
        p = ctx.p
        seed = ctx.seed
        return seeded_dropout(dy, p, seed), None

# dropout entry point; selects the Triton implementation or PyTorch's built-in dropout based on use_triton
def dropout_fn(x, p, use_triton = False):
    if p == 0. or not x.requires_grad:
        return x

    if not use_triton:
        return F.dropout(x, p, training = True)

    return dropout_.apply(x, p)
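
A quick sketch of what the seed buys (an illustrative example, assuming a CUDA device and a Triton version compatible with this meta-dict kernel style): rerunning seeded_dropout with the same seed reproduces the exact same mask, which is what lets backward reapply the forward mask to the incoming gradient.

import torch
from triton_transformer.dropout import seeded_dropout

x = torch.randn(4, 1024, device = 'cuda')
a = seeded_dropout(x, 0.1, seed = 123)
b = seeded_dropout(x, 0.1, seed = 123)
assert torch.equal(a, b)  # identical seeds produce identical masks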

.\lucidrains\triton-transformer\triton_transformer\layernorm.py

# import torch
import torch
# autograd module from torch
from torch import autograd
# functional API from torch
import torch.nn.functional as F

# import triton
import triton
# import the triton language module as tl
import triton.language as tl

# helpers from triton_transformer.utils
from triton_transformer.utils import calc_num_warps, exists

# block size (columns) for the gamma backward kernel
GAMMA_BLOCK_SIZE = 64
# block size (rows) for the gamma backward kernel
GAMMA_ROW_BLOCK_SIZE = 64

# forward layernorm kernel used during training (also writes out the intermediates needed by the backward pass)
@triton.jit
def layernorm_kernel_forward_training(
    output_ptr,
    mean_centered_ptr,
    normed_ptr,
    input_ptr,
    gamma_ptr,
    input_row_stride,
    gamma_row_stride,
    output_row_stride,
    mean_centered_row_stride,
    normed_row_stride,
    n_cols,
    stable,
    eps,
    **meta
):
    # each program instance handles one row
    row_idx = tl.program_id(0)
    # fetch the BLOCK_SIZE constant from meta
    BLOCK_SIZE = meta['BLOCK_SIZE']

    # pointer to the start of this row of the input
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # pointer to the start of this row of gamma
    gamma_row_start_ptr = gamma_ptr + row_idx * gamma_row_stride

    # column offsets within the row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    # per-column input pointers
    input_ptrs = row_start_ptr + col_offsets
    # per-column gamma pointers
    gamma_ptrs = gamma_row_start_ptr + col_offsets

    # mask out columns beyond n_cols
    mask = col_offsets < n_cols
    # load the row, filling masked-out columns with 0.
    row = tl.load(input_ptrs, mask=mask, other=0.)
    # load gamma for this row, filling masked-out columns with 0.
    gammas = tl.load(gamma_ptrs, mask=mask, other=0.)

    # optionally stabilize by dividing the row by its maximum
    if stable:
        # maximum of the row (masked columns contribute -inf)
        row_max = tl.max(tl.where(mask, row, float('-inf')), axis=0)
        # scale the row by its maximum
        row /= row_max

    # mean of the row
    row_mean = tl.sum(row, axis=0) / n_cols
    # mean-centered row
    row_mean_centered = tl.where(mask, row - row_mean, 0.)
    # variance of the row
    row_var = tl.sum(row_mean_centered * row_mean_centered, axis=0) / n_cols
    # inverse standard deviation
    inv_var = 1. / tl.sqrt(row_var + eps)
    # normalized row
    normed = row_mean_centered * inv_var

    # scale by gamma to get the output
    output = normed * gammas

    # pointer to the start of this row of the output
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    # per-column output pointers
    output_ptrs = output_row_start_ptr + col_offsets
    # store the output
    tl.store(output_ptrs, output, mask=mask)

    # pointer to the start of this row of the mean-centered buffer
    mean_centered_row_start_ptr = mean_centered_ptr + row_idx * mean_centered_row_stride
    # per-column mean-centered pointers
    mean_centered_ptrs = mean_centered_row_start_ptr + col_offsets
    # store the mean-centered values (saved for backward)
    tl.store(mean_centered_ptrs, row_mean_centered, mask=mask)

    # pointer to the start of this row of the normed buffer
    normed_row_start_ptr = normed_ptr + row_idx * normed_row_stride
    # per-column normed pointers
    normed_ptrs = normed_row_start_ptr + col_offsets
    # store the normalized values (saved for backward)
    tl.store(normed_ptrs, normed, mask=mask)

# forward layernorm kernel used at inference (no intermediates are written out)
@triton.jit
def layernorm_kernel_forward_inference(
    output_ptr,
    input_ptr,
    gamma_ptr,
    input_row_stride,
    gamma_row_stride,
    output_row_stride,
    n_cols,
    stable,
    eps,
    **meta
):
    # each program instance handles one row
    row_idx = tl.program_id(0)
    # fetch the BLOCK_SIZE constant from meta
    BLOCK_SIZE = meta['BLOCK_SIZE']

    # pointer to the start of this row of the input
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # pointer to the start of this row of gamma
    gamma_row_start_ptr = gamma_ptr + row_idx * gamma_row_stride

    # column offsets within the row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    # per-column input pointers
    input_ptrs = row_start_ptr + col_offsets
    # per-column gamma pointers
    gamma_ptrs = gamma_row_start_ptr + col_offsets

    # mask out columns beyond n_cols
    mask = col_offsets < n_cols
    # load the row, filling masked-out columns with 0.
    row = tl.load(input_ptrs, mask=mask, other=0.)
    # load gamma for this row, filling masked-out columns with 0.
    gammas = tl.load(gamma_ptrs, mask=mask, other=0.)

    # optionally stabilize by dividing the row by its maximum
    if stable:
        # maximum of the row (masked columns contribute -inf)
        row_max = tl.max(tl.where(mask, row, float('-inf')), axis=0)
        # scale the row by its maximum
        row /= row_max

    # mean of the row
    row_mean = tl.sum(row, axis=0) / n_cols
    # mean-centered row
    row_mean_centered = tl.where(mask, row - row_mean, 0.)
    # variance of the row
    row_var = tl.sum(row_mean_centered * row_mean_centered, axis=0) / n_cols
    # inverse standard deviation
    inv_var = 1. / tl.sqrt(row_var + eps)
    # normalized row
    normed = row_mean_centered * inv_var

    # scale by gamma to get the output
    output = normed * gammas

    # pointer to the start of this row of the output
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    # per-column output pointers
    output_ptrs = output_row_start_ptr + col_offsets
    # store the output
    tl.store(output_ptrs, output, mask=mask)

# backward layernorm kernel, computing the gradient with respect to the input
@triton.jit
def layernorm_kernel_backward(
    output_ptr,
    dy_ptr,
    mean_centered_ptr,
    output_row_stride,
    dy_row_stride,
    mean_centered_row_stride,
    n_cols,
    eps,
    **meta
):
    # each program instance handles one row
    row_idx = tl.program_id(0)
    # fetch the BLOCK_SIZE constant from meta
    BLOCK_SIZE = meta['BLOCK_SIZE']

    # pointer to the start of this row of dy
    dy_row_start_ptr = dy_ptr + row_idx * dy_row_stride
    # pointer to the start of this row of the saved mean-centered values
    mean_centered_row_start_ptr = mean_centered_ptr + row_idx * mean_centered_row_stride

    # column offsets within the row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    # per-column dy pointers
    dy_ptrs = dy_row_start_ptr + col_offsets
    # per-column mean-centered pointers
    mean_centered_ptrs = mean_centered_row_start_ptr + col_offsets

    # mask out columns beyond n_cols
    mask = col_offsets < n_cols

    # load dy, filling masked-out columns with 0.
    dy = tl.load(dy_ptrs, mask=mask, other=0.)
    # load the mean-centered values, filling masked-out columns with 0.
    mean_centered = tl.load(mean_centered_ptrs, mask=mask, other=0.)
    # recompute the row variance
    row_var = tl.sum(mean_centered * mean_centered, axis=0) / n_cols
    # inverse standard deviation
    inv_var = 1. / tl.sqrt(row_var + eps)
    # recompute the normalized values
    normed = mean_centered * inv_var

    # standard layernorm input-gradient formula
    output = 1. / n_cols * inv_var * (n_cols * dy - tl.sum(dy, axis=0) - normed * tl.sum(dy * normed, axis=0))

    # pointer to the start of this row of the output
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    # per-column output pointers
    output_ptrs = output_row_start_ptr + col_offsets
    # store the gradient, filtered by the mask
    tl.store(output_ptrs, output, mask=mask)

# Triton JIT-compiled kernel computing the gamma gradient for layernorm
@triton.jit
def layernorm_gamma_kernel_backward(
    dgamma_ptr,  # pointer to the dgamma output
    norm_ptr,  # pointer to the saved normalized values
    dy_ptr,  # pointer to the incoming gradient
    norm_stride,  # row stride of the normalized values
    dy_stride,  # row stride of dy
    dgamma_row_stride,  # row stride of dgamma
    n_rows,  # number of rows
    n_cols,  # number of columns
    **meta  # extra compile-time metadata
):
    # column-block and row-block indices for this program instance
    col_idx = tl.program_id(0)
    row_idx = tl.program_id(1)
    # fetch BLOCK_SIZE and BLOCK_SIZE_ROW from meta
    BLOCK_SIZE = meta['BLOCK_SIZE']
    ROW_BLOCK_SIZE = meta['BLOCK_SIZE_ROW']

    # offsets within the column block and row block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    row_offsets = tl.arange(0, ROW_BLOCK_SIZE)

    # absolute column and row ranges covered by this program
    col_range = col_idx * BLOCK_SIZE + col_offsets
    row_range = row_idx * ROW_BLOCK_SIZE + row_offsets

    # column mask
    col_mask = col_range < n_cols
    # combined mask filtering out-of-range rows and columns
    mask = (row_range < n_rows)[:, None] & col_mask[None, :]

    # advance dy_ptr and norm_ptr to the tile handled by this program
    dy_ptr += row_range[:, None] * dy_stride + col_range[None, :]
    norm_ptr += row_range[:, None] * norm_stride + col_range[None, :]

    # load the dy and norm tiles
    dy = tl.load(dy_ptr, mask=mask, other=0.)
    norm = tl.load(norm_ptr, mask=mask, other=0.)

    # partial dgamma: sum over the rows of this tile
    dgamma = tl.sum(dy * norm, axis=0)

    # advance dgamma_ptr to this program's output slot
    dgamma_ptr += row_idx * dgamma_row_stride + col_range

    # store the partial dgamma
    tl.store(dgamma_ptr, dgamma, mask=col_mask)

# autograd Function wrapping the triton layernorm kernels
class _layernorm(autograd.Function):
    @classmethod
    def forward(cls, ctx, x, gamma, eps, stable):
        # flatten all leading dimensions, keeping the feature dimension
        shape = x.shape
        dim = shape[-1]
        x = x.view(-1, dim)
        n_rows, n_cols = x.shape

        # broadcast gamma so every row sees the same parameters
        expanded_gamma = gamma[None, :].expand(n_rows, -1)

        # pick BLOCK_SIZE and num_warps from the number of columns
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)

        # allocate the output tensor
        out = torch.empty_like(x)

        # stash eps on the context
        ctx.eps = eps

        if x.requires_grad:
            # allocate buffers for the intermediates needed in backward
            scaled_x = torch.empty_like(x)
            normed = torch.empty_like(x)

            # launch the training forward kernel, one program per row
            layernorm_kernel_forward_training[(n_rows,)](
                out,
                scaled_x,
                normed,
                x,
                expanded_gamma,
                x.stride(0),
                expanded_gamma.stride(0),
                out.stride(0),
                scaled_x.stride(0),
                normed.stride(0),
                n_cols,
                stable,
                eps,
                num_warps=num_warps,
                BLOCK_SIZE=BLOCK_SIZE,
            )
            # save scaled_x, gamma and the pre-gamma normalized values for backward
            ctx.save_for_backward(scaled_x, gamma, normed)
        else:
            # launch the inference forward kernel (no intermediates saved)
            layernorm_kernel_forward_inference[(n_rows,)](
                out,
                x,
                expanded_gamma,
                x.stride(0),
                expanded_gamma.stride(0),
                out.stride(0),
                n_cols,
                stable,
                eps,
                num_warps=num_warps,
                BLOCK_SIZE=BLOCK_SIZE,
            )

        # restore the original shape and return
        return out.view(*shape)

    @classmethod
    def backward(cls, ctx, dy):
        # flatten dy the same way the input was flattened in forward
        shape, device = dy.shape, dy.device
        dim = shape[-1]
        dy = dy.view(-1, dim)

        # retrieve the saved scaled_x, gamma and normalized values
        scaled_x, gamma, normed = ctx.saved_tensors

        n_rows, n_cols = dy.shape

        # number of column-blocks and row-blocks for the gamma kernel
        num_col_programs = triton.cdiv(n_cols, GAMMA_BLOCK_SIZE)
        num_row_programs = triton.cdiv(n_rows, GAMMA_ROW_BLOCK_SIZE)

        # buffer holding one partial dgamma row per row-block
        dgamma = torch.empty((num_row_programs, n_cols), device=device)

        # launch the gamma backward kernel over a 2d grid of programs
        layernorm_gamma_kernel_backward[(num_col_programs, num_row_programs)](
            dgamma,
            normed,
            dy,
            normed.stride(0),
            dy.stride(0),
            dgamma.stride(0),
            n_rows,
            n_cols,
            num_warps=4,
            BLOCK_SIZE=GAMMA_BLOCK_SIZE,
            BLOCK_SIZE_ROW=GAMMA_ROW_BLOCK_SIZE
        )

        # reduce the partial results to the final dgamma
        dgamma = dgamma.sum(dim=0)

        # gradient w.r.t. the normalized activations, and the dx buffer
        dxhat = dy * gamma
        dx = torch.empty_like(dy)

        # pick BLOCK_SIZE and num_warps from the number of columns
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)

        # launch the input backward kernel, one program per row
        layernorm_kernel_backward[(n_rows,)](
            dx,
            dxhat,
            scaled_x,
            dx.stride(0),
            dxhat.stride(0),
            scaled_x.stride(0),
            n_cols,
            ctx.eps,
            num_warps=num_warps,
            BLOCK_SIZE=BLOCK_SIZE,
        )

        # restore the original shape and return dx, dgamma (None for eps and stable)
        dx = dx.view(*shape)
        return dx, dgamma, None, None

# layer normalization entry point
def layernorm(x, gamma, eps = 1e-5, use_triton = False, stable = False):
    # use the custom triton kernels
    if use_triton:
        out = _layernorm.apply(x, gamma, eps, stable)
    else:
        # plain PyTorch path
        if stable:
            # stabilize by dividing each row by its maximum
            x = x / torch.amax(x, dim = -1, keepdim = True)
        # use PyTorch's layer_norm with a zero bias
        out = F.layer_norm(x, (x.shape[-1],), gamma, torch.zeros_like(gamma), eps = eps)
    # return the normalized output
    return out
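
As a quick sanity check (a minimal sketch, assuming a CUDA device and a Triton version compatible with this meta-dict kernel style), the Triton path can be compared against the PyTorch fallback:

import torch
from triton_transformer.layernorm import layernorm

x = torch.randn(2, 16, 512, device = 'cuda', requires_grad = True)
gamma = torch.ones(512, device = 'cuda', requires_grad = True)

out_triton = layernorm(x, gamma, use_triton = True)
out_torch = layernorm(x, gamma, use_triton = False)
assert torch.allclose(out_triton, out_torch, atol = 1e-4)  # both paths should agree up to floating point error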

.\lucidrains\triton-transformer\triton_transformer\softmax.py

# import torch
import torch
# autograd module from torch
from torch import autograd
# functional API from torch
import torch.nn.functional as F

# import triton
import triton
# import the triton language module as tl
import triton.language as tl
# calc_num_warps helper from triton_transformer.utils
from triton_transformer.utils import calc_num_warps

# forward softmax kernel, decorated with triton.jit
@triton.jit
def softmax_kernel_forward(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    causal,
    **meta
):
    # each program instance handles one row
    row_idx = tl.program_id(0)
    # fetch the BLOCK_SIZE value from meta
    BLOCK_SIZE = meta['BLOCK_SIZE']

    # pointer to the start of this row of the input
    row_start_ptr = input_ptr + row_idx * input_row_stride

    # column offsets within the row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    # per-column input pointers
    input_ptrs = row_start_ptr + col_offsets

    # mask out columns beyond n_cols
    mask = col_offsets < n_cols

    # load the row, filling masked-out columns with -inf
    row = tl.load(input_ptrs, mask = mask, other = -float('inf'))

    # apply the causal mask if requested
    if causal:
        causal_mask = col_offsets > (row_idx % n_cols)
        row = row + tl.where(causal_mask, -float('inf'), 0.)

    # subtract the row maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)

    # softmax numerator
    numerator = tl.exp(row_minus_max)
    # softmax denominator
    denominator = tl.sum(numerator, axis=0)
    # softmax output
    softmax_output = numerator / denominator

    # pointer to the start of this row of the output
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    # per-column output pointers
    output_ptrs = output_row_start_ptr + col_offsets
    # store the softmax output, filtered by the mask
    tl.store(output_ptrs, softmax_output, mask = mask)

# backward softmax kernel, decorated with triton.jit
@triton.jit
def softmax_kernel_backward(
    output_ptr,
    input_ptr,
    grad_ptr,
    grad_row_stride,
    input_row_stride,
    output_row_stride,
    n_cols,
    **meta
):
    # each program instance handles one row
    row_idx = tl.program_id(0)
    # fetch the BLOCK_SIZE value from meta
    BLOCK_SIZE = meta['BLOCK_SIZE']

    # pointers to the start of this row of the probabilities and the gradient
    row_start_ptr = input_ptr + row_idx * input_row_stride
    grad_row_start_ptr = grad_ptr + row_idx * grad_row_stride

    # column offsets within the row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    # per-column probability and gradient pointers
    input_ptrs = row_start_ptr + col_offsets
    grad_ptrs = grad_row_start_ptr + col_offsets

    # mask out columns beyond n_cols
    mask = col_offsets < n_cols

    # load the softmax probabilities, filling masked-out columns with 0.
    probs_row = tl.load(input_ptrs, mask = mask, other = 0.)
    # load the incoming gradient, filling masked-out columns with 0.
    grad_row = tl.load(grad_ptrs, mask = mask, other = 0.)

    # elementwise product of probabilities and gradient
    dxhat = probs_row * grad_row
    # softmax gradient: dxhat - probs * sum(dxhat)
    softmax_grad_output = dxhat - probs_row * tl.sum(dxhat, axis = 0)

    # pointer to the start of this row of the output
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    # per-column output pointers
    output_ptrs = output_row_start_ptr + col_offsets
    # store the softmax gradient, filtered by the mask
    tl.store(output_ptrs, softmax_grad_output, mask = mask)

# autograd Function wrapping the triton softmax kernels
class _softmax(autograd.Function):
    # forward pass
    @classmethod
    def forward(self, ctx, x, causal):
        # remember the original shape
        shape = x.shape
        # flatten all leading dimensions into rows
        x = x.view(-1, shape[-1])
        n_rows, n_cols = x.shape

        # pick BLOCK_SIZE and num_warps from the number of columns
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)

        # allocate the output tensor
        y = torch.empty_like(x)

        # launch the forward kernel, one program per row
        softmax_kernel_forward[(n_rows,)](
            y,
            x,
            x.stride(0),
            y.stride(0),
            n_cols,
            causal,
            num_warps = num_warps,
            BLOCK_SIZE = BLOCK_SIZE,
        )

        # save the probabilities for backward if a gradient is needed
        if x.requires_grad:
            ctx.save_for_backward(y)
        return y.view(*shape)

    # backward pass
    @classmethod
    def backward(self, ctx, grad_probs):
        # remember the shape of the incoming gradient
        shape = grad_probs.shape
        # retrieve the saved probabilities
        probs, = ctx.saved_tensors

        # flatten the gradient into rows
        grad_probs = grad_probs.view(-1, grad_probs.shape[-1])
        n_rows, n_cols = grad_probs.shape

        # pick BLOCK_SIZE and num_warps from the number of columns
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)

        # allocate the gradient w.r.t. the input
        dx = torch.empty_like(probs)

        # launch the backward kernel, one program per row
        softmax_kernel_backward[(n_rows,)](
            dx,
            probs,
            grad_probs,
            grad_probs.stride(0),
            probs.stride(0),
            dx.stride(0),
            n_cols,
            num_warps = num_warps,
            BLOCK_SIZE = BLOCK_SIZE
        )

        # return dx, and None for the causal flag (it needs no gradient)
        return dx.view(*shape), None

# convenience alias for the custom autograd function
triton_softmax = _softmax.apply

# softmax entry point
def softmax(x, causal = False, use_triton = False):
    # use the triton kernels
    if use_triton:
        return triton_softmax(x, causal)
    else:
        # fall back to PyTorch's softmax over the last dimension
        return F.softmax(x, dim = -1)
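
For reference, a minimal sketch (assuming a CUDA device and a compatible Triton version) comparing the non-causal Triton path against the PyTorch fallback:

import torch
from triton_transformer.softmax import softmax

sim = torch.randn(8, 128, 128, device = 'cuda')

out_triton = softmax(sim, causal = False, use_triton = True)
out_torch = softmax(sim, causal = False, use_triton = False)
assert torch.allclose(out_triton, out_torch, atol = 1e-4)

Note that the PyTorch branch ignores the causal flag; in the Transformer below, the causal mask is instead applied explicitly to the attention logits whenever use_triton is False.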

.\lucidrains\triton-transformer\triton_transformer\transformer.py

# import required libraries
from functools import partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange

# import the project's custom modules
from triton_transformer.layernorm import layernorm
from triton_transformer.softmax import softmax
from triton_transformer.cross_entropy import cross_entropy_fn
from triton_transformer.bmm import fused_relu_squared
from triton_transformer.dropout import dropout_fn
from triton_transformer.utils import exists, default

# classes

class PreNormResidual(nn.Module):
    def __init__(self, dim, fn, use_triton = False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
        self.use_triton = use_triton

    def forward(self, x, **kwargs):
        use_triton = kwargs.get('use_triton', self.use_triton)
        normed = layernorm(x, self.norm.weight, use_triton = use_triton)
        return self.fn(normed, **kwargs) + x

# helper classes

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        causal = False,
        dropout = 0.,
        use_triton = False
    ):
        super().__init__()
        self.use_triton = use_triton
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal
        inner_dim = dim_head * heads
        self.dropout = dropout

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, mask = None, use_triton = None):
        use_triton = default(use_triton, self.use_triton)
        h = self.heads

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        q = q * self.scale
        sim = einsum('b i d, b j d -> b i j', q, k)

        if exists(mask):
            mask_value = -torch.finfo(sim.dtype).max
            sim = sim.masked_fill(mask, mask_value)

        attn = softmax(sim, causal = self.causal, use_triton = use_triton)
        attn = dropout_fn(attn, self.dropout, use_triton = use_triton)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        out = self.to_out(out)
        return dropout_fn(out, self.dropout, use_triton = use_triton)

class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.,
        use_triton = False
    ):
        super().__init__()
        self.use_triton = use_triton
        inner_dim = dim * mult
        self.dropout = dropout
        self.proj_in_weight = nn.Parameter(torch.randn(dim, inner_dim))
        self.proj_out = nn.Linear(inner_dim, dim)

    def forward(self, x, use_triton = None):
        use_triton = default(use_triton, self.use_triton)

        x = fused_relu_squared(x, self.proj_in_weight, use_triton = use_triton)
        x = dropout_fn(x, self.dropout, use_triton = use_triton)

        x = self.proj_out(x)
        return x

# main class

class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        max_seq_len,
        depth,
        causal = False,
        heads = 8,
        dim_head = 64,
        ff_dropout = 0.,
        ff_mult = 4,
        attn_dropout = 0.,
        use_triton = False
    ):
        # parent constructor
        super().__init__()
        # maximum sequence length
        self.max_seq_len = max_seq_len
        # token embedding
        self.token_emb = nn.Embedding(num_tokens, dim)
        # positional embedding
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # list of layers
        self.layers = nn.ModuleList([])
        # partial for wrapping blocks in pre-norm residuals
        wrapper = partial(PreNormResidual, dim)

        # stack `depth` pairs of attention and feedforward layers
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                wrapper(Attention(dim, heads = heads, dim_head = dim_head, causal = causal, dropout = attn_dropout, use_triton = use_triton)),
                wrapper(FeedForward(dim, dropout = ff_dropout, mult = ff_mult, use_triton = use_triton))
            ]))

        # final layer norm
        self.norm = nn.LayerNorm(dim)
        # projection to logits
        self.to_logits = nn.Linear(dim, num_tokens)

        # mask

        self.use_triton = use_triton
        self.causal = causal
        # build the causal mask up front if the model is autoregressive
        mask = torch.ones(max_seq_len, max_seq_len, dtype = torch.bool).triu(1) if causal else None
        self.register_buffer('mask', mask, persistent = False)

    def forward(
        self,
        x,
        mask = None,
        *,
        labels = None,
        use_triton = None
    ):
        # decide whether to use the Triton kernels
        use_triton = default(use_triton, self.use_triton)
        # sequence length and device
        n, device = x.shape[1], x.device

        # embed the tokens and add the positional embedding

        x = self.token_emb(x)
        pos_emb = self.pos_emb(torch.arange(n, device = device))
        x = x + rearrange(pos_emb, 'n d -> () n d')

        # build the mask, depending on whether the model is autoregressive

        assert not (self.causal and exists(mask)), 'mask is not needed during autoregressive mode'

        if self.causal and not use_triton:
            mask = self.mask[:n, :n]
            mask = rearrange(mask, 'i j -> () i j')
        elif not self.causal and exists(mask):
            mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
            mask = ~mask

        # run through the layers

        for attn, ff in self.layers:
            x = attn(x, mask = mask, use_triton = use_triton)
            x = ff(x, use_triton = use_triton)

        # final layer norm
        x = layernorm(x, self.norm.weight, use_triton = use_triton, stable = True)
        # project to logits
        logits = self.to_logits(x)

        if not exists(labels):
            return logits

        # compute the cross entropy loss
        loss = cross_entropy_fn(logits, labels, ignore_index = 0, use_triton = use_triton)
        return loss
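
A minimal usage sketch (assumes triton is installed, since these modules import it at the top level; with use_triton = False every submodule, including the bmm-based feedforward, is expected to take its PyTorch fallback path):

import torch
from triton_transformer import Transformer

model = Transformer(
    num_tokens = 256,
    dim = 512,
    max_seq_len = 1024,
    depth = 2,
    heads = 8,
    dim_head = 64,
    causal = True,
    use_triton = False   # flip to True on a CUDA device
)

seq = torch.randint(0, 256, (1, 1024))
labels = torch.randint(0, 256, (1, 1024))

loss = model(seq, labels = labels)   # scalar cross entropy loss
loss.backward()

logits = model(seq)                  # (1, 1024, 256) when no labels are given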

.\lucidrains\triton-transformer\triton_transformer\utils.py

# check that a value is not None
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# choose the number of warps from the block size
def calc_num_warps(block_size):
    # default to 4 warps
    num_warps = 4
    # 8 warps for block sizes of 2048 and above
    if block_size >= 2048:
        num_warps = 8
    # 16 warps for block sizes of 4096 and above
    if block_size >= 4096:
        num_warps = 16
    # return the number of warps
    return num_warps

.\lucidrains\triton-transformer\triton_transformer\__init__.py

# expose the Transformer class from triton_transformer.transformer
from triton_transformer.transformer import Transformer

Uformer - Pytorch

Implementation of Uformer, Attention-based Unet, in Pytorch. It will only offer the concat-cross-skip connection.

This repository will be geared towards use in a project for learning protein structures. Specifically, it will include the ability to condition on time steps (needed for DDPM), as well as 2d relative positional encoding using rotary embeddings (instead of the bias on the attention matrix in the paper).

Install

$ pip install uformer-pytorch

Usage

import torch
from uformer_pytorch import Uformer

model = Uformer(
    dim = 64,           # initial dimensions after input projection, which increases by 2x each stage
    stages = 4,         # number of stages
    num_blocks = 2,     # number of transformer blocks per stage
    window_size = 16,   # set window size (along one side) for which to do the attention within
    dim_head = 64,
    heads = 8,
    ff_mult = 4
)

x = torch.randn(1, 3, 256, 256)
pred = model(x) # (1, 3, 256, 256)

To condition on time for DDPM training

import torch
from uformer_pytorch import Uformer

model = Uformer(
    dim = 64,
    stages = 4,
    num_blocks = 2,
    window_size = 16,
    dim_head = 64,
    heads = 8,
    ff_mult = 4,
    time_emb = True    # set this to true
)

x = torch.randn(1, 3, 256, 256)
time = torch.arange(1)
pred = model(x, time = time) # (1, 3, 256, 256)

Citations

@misc{wang2021uformer,
    title   = {Uformer: A General U-Shaped Transformer for Image Restoration}, 
    author  = {Zhendong Wang and Xiaodong Cun and Jianmin Bao and Jianzhuang Liu},
    year    = {2021},
    eprint  = {2106.03106},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\uformer-pytorch\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'uformer-pytorch',  # package name
  packages = find_packages(),  # find and include all packages
  version = '0.0.8',  # version number
  license='MIT',  # license
  description = 'Uformer - Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/uformer-pytorch',  # project url
  keywords = [  # keywords
    'artificial intelligence',
    'deep learning',
    'transformers',
    'image segmentation',
    'unet'
  ],
  install_requires=[  # dependencies
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[  # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\uformer-pytorch\uformer_pytorch\uformer_pytorch.py

# import math
import math
# log, pi, sqrt from math
from math import log, pi, sqrt
# partial from functools
from functools import partial

# import torch
import torch
# nn and einsum from torch
from torch import nn, einsum
# functional API from torch
import torch.nn.functional as F

# rearrange and repeat from einops
from einops import rearrange, repeat

# alias nn.ModuleList as List
List = nn.ModuleList

# helper functions

# check that a value is not None
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# cast a value to a tuple of the given depth
def cast_tuple(val, depth = 1):
    return val if isinstance(val, tuple) else (val,) * depth

# positional embeddings

# apply rotary positional embeddings to the queries and keys
def apply_rotary_emb(q, k, pos_emb):
    sin, cos = pos_emb
    dim_rotary = sin.shape[-1]
    (q, q_pass), (k, k_pass) = map(lambda t: (t[..., :dim_rotary], t[..., dim_rotary:]), (q, k))
    q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
    q, k = map(lambda t: torch.cat(t, dim = -1), ((q, q_pass), (k, k_pass)))
    return q, k

# rotate every pair of feature dimensions, (x1, x2) -> (-x2, x1)
def rotate_every_two(x):
    x = rearrange(x, '... (d j) -> ... d j', j = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d j -> ... (d j)')
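
A quick illustrative check of the pairwise rotation (a sketch, importing directly from this module file):

import torch
from uformer_pytorch.uformer_pytorch import rotate_every_two

x = torch.tensor([1., 2., 3., 4.])
print(rotate_every_two(x))  # tensor([-2., 1., -4., 3.])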

# axial rotary embedding class
class AxialRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_freq = 10):
        super().__init__()
        self.dim = dim
        scales = torch.linspace(1., max_freq / 2, self.dim // 4)
        self.register_buffer('scales', scales)

    def forward(self, x):
        device, dtype, h, w = x.device, x.dtype, *x.shape[-2:]

        seq_x = torch.linspace(-1., 1., steps = h, device = device)
        seq_x = seq_x.unsqueeze(-1)

        seq_y = torch.linspace(-1., 1., steps = w, device = device)
        seq_y = seq_y.unsqueeze(-1)

        scales = self.scales[(*((None,) * (len(seq_x.shape) - 1)), Ellipsis)]
        scales = scales.to(x)

        scales = self.scales[(*((None,) * (len(seq_y.shape) - 1)), Ellipsis)]
        scales = scales.to(x)

        seq_x = seq_x * scales * pi
        seq_y = seq_y * scales * pi

        x_sinu = repeat(seq_x, 'i d -> i j d', j = w)
        y_sinu = repeat(seq_y, 'j d -> i j d', i = h)

        sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
        cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)

        sin, cos = map(lambda t: rearrange(t, 'i j d -> i j d'), (sin, cos))
        sin, cos = map(lambda t: repeat(t, 'i j d -> () i j (d r)', r = 2), (sin, cos))
        return sin, cos

# sinusoidal time positional embedding class
class TimeSinuPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = device) * -emb)
        emb = einsum('i, j -> i  j', x, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
        return emb
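
For intuition, a small sketch: a batch of timesteps maps to a (batch, dim) sinusoidal embedding.

import torch
from uformer_pytorch.uformer_pytorch import TimeSinuPosEmb

time_emb = TimeSinuPosEmb(dim = 64)
t = torch.arange(4).float()
print(time_emb(t).shape)  # torch.Size([4, 64])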

# helper classes

# layer normalization over the channel dimension of (b, c, h, w) feature maps
class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (std + self.eps) * self.g + self.b

# pre-normalization wrapper
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# attention class
class Attention(nn.Module):
    def __init__(self, dim, dim_head = 64, heads = 8, window_size = 16):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.window_size = window_size
        inner_dim = dim_head * heads

        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)
    # forward pass, taking the input x, an optional skip connection, a time embedding and a positional embedding
    def forward(self, x, skip = None, time_emb = None, pos_emb = None):
        # number of heads h, window size w, batch size b
        h, w, b = self.heads, self.window_size, x.shape[0]

        # if a time embedding is given, broadcast it over the spatial dimensions and add it to the input
        if exists(time_emb):
            time_emb = rearrange(time_emb, 'b c -> b c () ()')
            x = x + time_emb

        # project the input to queries
        q = self.to_q(x)

        # the keys and values are derived from the input itself
        kv_input = x

        # if a skip connection is given, concatenate it along the batch dimension for the key/value projection
        if exists(skip):
            kv_input = torch.cat((kv_input, skip), dim = 0)

        # project to keys and values, then split the two along the channel dimension
        k, v = self.to_kv(kv_input).chunk(2, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) x y c', h = h), (q, k, v))

        # if rotary positional embeddings are given, apply them to the queries and keys
        if exists(pos_emb):
            q, k = apply_rotary_emb(q, k, pos_emb)

        # partition queries, keys and values into non-overlapping windows
        q, k, v = map(lambda t: rearrange(t, 'b (x w1) (y w2) c -> (b x y) (w1 w2) c', w1 = w, w2 = w), (q, k, v))

        # if a skip connection is given, fold its keys and values back into the sequence dimension
        if exists(skip):
            k, v = map(lambda t: rearrange(t, '(r b) n d -> b (r n) d', r = 2), (k, v))

        # attention similarity matrix
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        # softmax over the keys to get the attention weights
        attn = sim.softmax(dim = -1)
        # aggregate the values with the attention weights
        out = einsum('b i j, b j d -> b i d', attn, v)

        # merge the windows and heads back into a (b, c, h, w) feature map
        out = rearrange(out, '(b h x y) (w1 w2) c -> b (h c) (x w1) (y w2)', b = b, h = h, y = x.shape[-1] // w, w1 = w, w2 = w)
        # final output projection
        return self.to_out(out)

# feedforward module
class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4):
        super().__init__()
        hidden_dim = dim * mult
        # input projection from dim to the hidden dimension
        self.project_in = nn.Conv2d(dim, hidden_dim, 1)
        # output projection: 3x3 convolution, GELU, then projection back down to dim
        self.project_out = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding = 1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1)
        )

    def forward(self, x, time_emb = None):
        # project the input up to the hidden dimension
        x = self.project_in(x)
        # if a time embedding is given, broadcast it over the spatial dimensions and add it
        if exists(time_emb):
            time_emb = rearrange(time_emb, 'b c -> b c () ()')
            x = x + time_emb
        # project back down to the output dimension
        return self.project_out(x)

# a stage of the Uformer: a stack of attention + feedforward transformer blocks
class Block(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        window_size = 16,
        time_emb_dim = None,
        rotary_emb = True
    ):
        super().__init__()
        self.attn_time_emb = None
        self.ff_time_emb = None
        # if a time embedding dimension is given, create projections of the time embedding for attention and feedforward
        if exists(time_emb_dim):
            self.attn_time_emb = nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            self.ff_time_emb = nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim * ff_mult))

        # create the axial rotary positional embedding, if requested
        self.pos_emb = AxialRotaryEmbedding(dim_head) if rotary_emb else None

        # stack `depth` transformer layers
        self.layers = List([])
        for _ in range(depth):
            self.layers.append(List([
                PreNorm(dim, Attention(dim, dim_head = dim_head, heads = heads, window_size = window_size)),
                PreNorm(dim, FeedForward(dim, mult = ff_mult))
            ]))

    def forward(self, x, skip = None, time = None):
        attn_time_emb = None
        ff_time_emb = None
        # if a time is given, compute the time embeddings for attention and feedforward
        if exists(time):
            assert exists(self.attn_time_emb) and exists(self.ff_time_emb), 'time_emb_dim must be given on init if you are conditioning based on time'
            attn_time_emb = self.attn_time_emb(time)
            ff_time_emb = self.ff_time_emb(time)

        pos_emb = None
        # if rotary embeddings are enabled, compute the positional embedding for this feature map
        if exists(self.pos_emb):
            pos_emb = self.pos_emb(x)

        # run through the attention and feedforward layers with residual connections
        for attn, ff in self.layers:
            x = attn(x, skip = skip, time_emb = attn_time_emb, pos_emb = pos_emb) + x
            x = ff(x, time_emb = ff_time_emb) + x
        # return the processed feature map
        return x

# the Uformer module
class Uformer(nn.Module):
    def __init__(
        self,
        dim = 64,
        channels = 3,
        stages = 4,
        num_blocks = 2,
        dim_head = 64,
        window_size = 16,
        heads = 8,
        ff_mult = 4,
        time_emb = False,
        input_channels = None,
        output_channels = None
    ):
        # parent constructor
        super().__init__()
        # default the input and output channel counts to `channels`
        input_channels = default(input_channels, channels)
        output_channels = default(output_channels, channels)

        self.to_time_emb = None
        time_emb_dim = None

        # if time conditioning is enabled
        if time_emb:
            time_emb_dim = dim
            # build the time embedding MLP
            self.to_time_emb = nn.Sequential(
                TimeSinuPosEmb(dim),
                nn.Linear(dim, dim * 4),
                nn.GELU(),
                nn.Linear(dim * 4, dim)
            )

        # project the input channels to the model dimension
        self.project_in = nn.Sequential(
            nn.Conv2d(input_channels, dim, 3, padding = 1),
            nn.GELU()
        )

        # project the model dimension back to the output channels
        self.project_out = nn.Sequential(
            nn.Conv2d(dim, output_channels, 3, padding = 1),
        )

        # lists of downsampling and upsampling stages
        self.downs = List([])
        self.ups = List([])

        # cast per-stage hyperparameters to tuples of length `stages`
        heads, window_size, dim_head, num_blocks = map(partial(cast_tuple, depth = stages), (heads, window_size, dim_head, num_blocks))

        # build each stage
        for ind, heads, window_size, dim_head, num_blocks in zip(range(stages), heads, window_size, dim_head, num_blocks):
            is_last = ind == (stages - 1)

            # downsampling path: a transformer block followed by a strided convolution
            self.downs.append(List([
                Block(dim, depth = num_blocks, dim_head = dim_head, heads = heads, ff_mult = ff_mult, window_size = window_size, time_emb_dim = time_emb_dim),
                nn.Conv2d(dim, dim * 2, 4, stride = 2, padding = 1)
            ]))

            # upsampling path: a transposed convolution followed by a transformer block
            self.ups.append(List([
                nn.ConvTranspose2d(dim * 2, dim, 2, stride = 2),
                Block(dim, depth = num_blocks, dim_head = dim_head, heads = heads, ff_mult = ff_mult, window_size = window_size, time_emb_dim = time_emb_dim)
            ]))

            dim *= 2

            # at the last stage, build the bottleneck block
            if is_last:
                self.mid = Block(dim = dim, depth = num_blocks, dim_head = dim_head, heads = heads, ff_mult = ff_mult, window_size = window_size, time_emb_dim = time_emb_dim)

    # forward pass
    def forward(
        self,
        x,
        time = None
    ):
        # if a time is given, embed it
        if exists(time):
            assert exists(self.to_time_emb), 'time_emb must be set to true to condition on time'
            time = time.to(x)
            time = self.to_time_emb(time)

        # project the input to the model dimension
        x = self.project_in(x)

        skips = []
        # downsampling path, saving the skip connections
        for block, downsample in self.downs:
            x = block(x, time = time)
            skips.append(x)
            x = downsample(x)

        # bottleneck block
        x = self.mid(x, time = time)

        # upsampling path, consuming the skip connections in reverse order
        for (upsample, block), skip in zip(reversed(self.ups), reversed(skips)):
            x = upsample(x)
            x = block(x, skip = skip, time = time)

        # project back to the output channels
        x = self.project_out(x)
        return x

.\lucidrains\uformer-pytorch\uformer_pytorch\__init__.py

# expose the Uformer class from uformer_pytorch.uformer_pytorch
from uformer_pytorch.uformer_pytorch import Uformer

UNet Stylegan2

An implementation of Stylegan2 with a UNet discriminator. This repository works largely the same way as Stylegan2 Pytorch. Simply replace every stylegan2_pytorch command with unet_stylegan2 instead.

Update: Results have been very good. Will need to investigate combining this with a few other techniques, and then I will write up full instructions for use.

Install

$ pip install unet-stylegan2

Usage

$ unet_stylegan2 --data ./path/to/data

Citations

@misc{karras2019analyzing,
    title={Analyzing and Improving the Image Quality of StyleGAN},
    author={Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
    year={2019},
    eprint={1912.04958},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
@misc{schnfeld2020unet,
    title={A U-Net Based Discriminator for Generative Adversarial Networks},
    author={Edgar Schönfeld and Bernt Schiele and Anna Khoreva},
    year={2020},
    eprint={2002.12655},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

.\lucidrains\unet-stylegan2\setup.py

# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'unet_stylegan2',  # package name
  packages = find_packages(),  # find and include all packages
  scripts=['bin/unet_stylegan2'],  # executable script
  version = '0.5.1',  # version number
  license='GPLv3+',  # license
  description = 'StyleGan2 with UNet Discriminator, in Pytorch',  # description
  author = 'Phil Wang',  # author
  author_email = 'lucidrains@gmail.com',  # author email
  url = 'https://github.com/lucidrains/unet-stylegan2',  # project url
  keywords = ['generative adversarial networks', 'artificial intelligence'],  # keywords
  install_requires=[  # dependencies
      'fire',
      'numpy',
      'retry',
      'tqdm',
      'torch',
      'torchvision',
      'pillow',
      'linear_attention_transformer>=0.12.1'
  ],
  classifiers=[  # classifiers
      'Development Status :: 4 - Beta',
      'Intended Audience :: Developers',
      'Topic :: Scientific/Engineering :: Artificial Intelligence',
      'License :: OSI Approved :: MIT License',
      'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\unet-stylegan2\unet_stylegan2\diff_augment.py

# import torch
import torch
# functional API from torch
import torch.nn.functional as F

# apply the requested types of differentiable augmentation to the input
def DiffAugment(x, types=[]):
    # iterate over the requested augmentation types
    for p in types:
        # iterate over the functions registered for that type
        for f in AUGMENT_FNS[p]:
            # apply the augmentation
            x = f(x)
    # return the augmented batch in contiguous memory format
    return x.contiguous(memory_format=torch.contiguous_format)

# random brightness augmentation
def rand_brightness(x):
    # add a random per-sample brightness offset
    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
    return x

# random saturation augmentation
def rand_saturation(x):
    # per-pixel mean over the channel dimension
    x_mean = x.mean(dim=1, keepdim=True)
    # scale the deviation from the mean by a random per-sample factor
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
    return x

# random contrast augmentation
def rand_contrast(x):
    # per-sample mean over all channels and pixels
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    # scale the deviation from the mean by a random per-sample factor
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
    return x

# random translation augmentation
def rand_translation(x, ratio=0.125):
    # maximum shift in pixels along each axis
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # sample per-sample translations
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    # build the index grid
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    # shift the grid coordinates
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    # pad and gather to apply the translation
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous(memory_format=torch.contiguous_format)
    return x

# random cutout augmentation
def rand_cutout(x, ratio=0.5):
    # size of the cutout region
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    # sample per-sample offsets for the cutout region
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    # build the index grid for the cutout region
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    # shift the grid by the offsets and clamp to the image bounds
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    # build the cutout mask
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    # apply the mask to the input
    x = x * mask.unsqueeze(1)
    return x

# registry of augmentation functions per augmentation type
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}
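
A usage sketch (illustrative only), applying all three augmentation types to a batch of images:

import torch
from unet_stylegan2.diff_augment import DiffAugment

images = torch.randn(4, 3, 128, 128)
augmented = DiffAugment(images, types = ['color', 'translation', 'cutout'])
print(augmented.shape)  # torch.Size([4, 3, 128, 128])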

.\lucidrains\unet-stylegan2\unet_stylegan2\unet_stylegan2.py

# import required libraries
import os
import sys
import math
import fire
import json
from tqdm import tqdm
from math import floor, log2
from random import random
from shutil import rmtree
from functools import partial
import multiprocessing

import numpy as np
import torch
from torch import nn
from torch.utils import data
import torch.nn.functional as F

from torch.optim import Adam
from torch.autograd import grad as torch_grad

import torchvision
from torchvision import transforms

from linear_attention_transformer import ImageLinearAttention

from PIL import Image
from pathlib import Path

# try to import apex, setting the APEX_AVAILABLE flag
try:
    from apex import amp
    APEX_AVAILABLE = True
except:
    APEX_AVAILABLE = False

# make sure a CUDA device is available
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'

# number of CPU cores
num_cores = multiprocessing.cpu_count()

# constants

# supported image file extensions
EXTS = ['jpg', 'jpeg', 'png', 'webp']
# small constant to avoid division by zero
EPS = 1e-8

# helper classes

# custom exception raised when a NaN is encountered
class NanException(Exception):
    pass

# exponential moving average helper
class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

# randomly apply one of two functions according to a probability
class RandomApply(nn.Module):
    def __init__(self, prob, fn, fn_else = lambda x: x):
        super().__init__()
        self.fn = fn
        self.fn_else = fn_else
        self.prob = prob
    def forward(self, x):
        fn = self.fn if random() < self.prob else self.fn_else
        return fn(x)

# residual connection
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x):
        return self.fn(x) + x

# flatten from a given dimension onward
class Flatten(nn.Module):
    def __init__(self, index):
        super().__init__()
        self.index = index
    def forward(self, x):
        return x.flatten(self.index)

# Rezero wrapper: scales the wrapped function's output by a learned scalar initialized at zero
class Rezero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        return self.fn(x) * self.g

# linear self-attention plus feedforward layers for image feature maps
attn_and_ff = lambda chan: nn.Sequential(*[
    Residual(Rezero(ImageLinearAttention(chan, norm_queries = True))),
    Residual(Rezero(nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
])

# helper functions

# return the default when the value is None
def default(value, d):
    return d if value is None else value

# infinitely cycle over an iterable
def cycle(iterable):
    while True:
        for i in iterable:
            yield i

# cast an element to a list
def cast_list(el):
    return el if isinstance(el, list) else [el]

# check whether a tensor is empty
def is_empty(t):
    if isinstance(t, torch.Tensor):
        return t.nelement() == 0
    return t is None

# raise an exception if the tensor contains a NaN
def raise_if_nan(t):
    if torch.isnan(t):
        raise NanException

# backward pass helper supporting mixed precision training
def loss_backwards(fp16, loss, optimizer, **kwargs):
    if fp16:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward(**kwargs)
    else:
        loss.backward(**kwargs)

# gradient penalty term for the discriminator
def gradient_penalty(images, outputs, weight = 10):
    batch_size = images.shape[0]
    gradients = torch_grad(outputs=outputs, inputs=images,
                           grad_outputs=list(map(lambda t: torch.ones(t.size()).cuda(), outputs)),
                           create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# path length regularization: lengths of the gradients of the generated images with respect to the styles
def calc_pl_lengths(styles, images):
    num_pixels = images.shape[2] * images.shape[3]
    pl_noise = torch.randn(images.shape).cuda() / math.sqrt(num_pixels)
    outputs = (images * pl_noise).sum()

    pl_grads = torch_grad(outputs=outputs, inputs=styles,
                          grad_outputs=torch.ones(outputs.shape).cuda(),
                          create_graph=True, retain_graph=True, only_inputs=True)[0]

    return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()

# sample random latent noise
def noise(n, latent_dim):
    return torch.randn(n, latent_dim).cuda()

# a list with a single (noise, num_layers) tuple
def noise_list(n, layers, latent_dim):
    # return a list of (noise, number of layers) tuples
    return [(noise(n, latent_dim), layers)]

# a mixed noise list, splitting the layers between two latents (style mixing)
def mixed_list(n, layers, latent_dim):
    # randomly choose the split point
    tt = int(torch.rand(()).numpy() * layers)
    # concatenate the two noise lists
    return noise_list(n, tt, latent_dim) + noise_list(n, layers - tt, latent_dim)

# map latent descriptions to (style vector, num_layers) tuples via the style vectorizer
def latent_to_w(style_vectorizer, latent_descr):
    return [(style_vectorizer(z), num_layers) for z, num_layers in latent_descr]

# sample per-pixel image noise of the given size
def image_noise(n, im_size):
    return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda()

# leaky relu activation
def leaky_relu(p=0.2):
    return nn.LeakyReLU(p)

# evaluate the model on arguments split into chunks of at most max_batch_size
def evaluate_in_chunks(max_batch_size, model, *args):
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)

# expand a styles definition into a single tensor
def styles_def_to_tensor(styles_def):
    return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)

# toggle requires_grad for all parameters of a model
def set_requires_grad(model, bool):
    for p in model.parameters():
        p.requires_grad = bool

# spherical linear interpolation (slerp)
def slerp(val, low, high):
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
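
A sketch of slerp used for latent interpolation (this module asserts CUDA availability at import time, so the example assumes a GPU machine with the package's dependencies installed):

import torch
from unet_stylegan2.unet_stylegan2 import slerp

low = torch.randn(4, 512)
high = torch.randn(4, 512)
mid = slerp(0.5, low, high)  # spherical midpoint between the two latent batches, shape (4, 512)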

# warmup helper that linearly interpolates from start to end over max_steps
def warmup(start, end, max_steps, current_step):
    if current_step > max_steps:
        return end
    return (end - start) * (current_step / max_steps) + start

# numerically safe log
def log(t, eps = 1e-6):
    return torch.log(t + eps)

# sample the coordinates of a CutMix box
def cutmix_coordinates(height, width, alpha = 1.):
    lam = np.random.beta(alpha, alpha)

    cx = np.random.uniform(0, width)
    cy = np.random.uniform(0, height)
    w = width * np.sqrt(1 - lam)
    h = height * np.sqrt(1 - lam)
    x0 = int(np.round(max(cx - w / 2, 0)))
    x1 = int(np.round(min(cx + w / 2, width)))
    y0 = int(np.round(max(cy - h / 2, 0)))
    y1 = int(np.round(min(cy + h / 2, height)))

    return ((y0, y1), (x0, x1)), lam

# apply CutMix: paste the target's pixels into the source inside the sampled box
def cutmix(source, target, coors, alpha = 1.):
    source, target = map(torch.clone, (source, target))
    ((y0, y1), (x0, x1)), _ = coors
    source[:, :, y0:y1, x0:x1] = target[:, :, y0:y1, x0:x1]
    return source

# blend source and target according to a mask
def mask_src_tgt(source, target, mask):
    return source * mask + (1 - mask) * target
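
A sketch of how the CutMix helpers fit together (same import-time CUDA caveat as above): cutmix_coordinates samples a random box together with its area fraction lam, and cutmix pastes the target's pixels into the source inside that box.

import torch
from unet_stylegan2.unet_stylegan2 import cutmix_coordinates, cutmix

real = torch.randn(4, 3, 128, 128)
fake = torch.randn(4, 3, 128, 128)

coords = cutmix_coordinates(128, 128)   # (((y0, y1), (x0, x1)), lam)
mixed = cutmix(real, fake, coords)      # real images with a fake patch pasted in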

# dataset

# convert an RGB image to one with an alpha channel
def convert_rgb_to_transparent(image):
    if image.mode == 'RGB':
        return image.convert('RGBA')
    return image

# convert an image with an alpha channel to RGB
def convert_transparent_to_rgb(image):
    if image.mode == 'RGBA':
        return image.convert('RGB')
    return image

# expand a greyscale tensor to the desired number of channels
class expand_greyscale(object):
    def __init__(self, num_channels):
        self.num_channels = num_channels
    def __call__(self, tensor):
        return tensor.expand(self.num_channels, -1, -1)

# resize an image up to a minimum size if it is too small
def resize_to_minimum_size(min_size, image):
    if max(*image.size) < min_size:
        return torchvision.transforms.functional.resize(image, min_size)
    return image

# dataset class
class Dataset(data.Dataset):
    def __init__(self, folder, image_size, transparent = False, aug_prob = 0.):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        convert_image_fn = convert_transparent_to_rgb if not transparent else convert_rgb_to_transparent
        num_channels = 3 if not transparent else 4

        self.transform = transforms.Compose([
            transforms.Lambda(convert_image_fn),
            transforms.Lambda(partial(resize_to_minimum_size, image_size)),
            transforms.Resize(image_size),
            RandomApply(aug_prob, transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)), transforms.CenterCrop(image_size)),
            transforms.ToTensor(),
            transforms.Lambda(expand_greyscale(num_channels))
        ])

    def __len__(self):
        return len(self.paths)

    # fetch the item at the given index
    def __getitem__(self, index):
        # path at the given index
        path = self.paths[index]
        # open the image at that path
        img = Image.open(path)
        # apply the transforms and return the tensor
        return self.transform(img)

# rgb block: maps a feature map to an rgb (or rgba) image and accumulates it with the previous rgb output
class RGBBlock(nn.Module):
    # constructor
    def __init__(self, latent_dim, input_channel, upsample, rgba=False):
        super().__init__()
        self.input_channel = input_channel
        # project the latent style vector to the number of input channels
        self.to_style = nn.Linear(latent_dim, input_channel)

        # 4 output channels for RGBA, otherwise 3
        out_filters = 3 if not rgba else 4
        # modulated 1x1 convolution, without demodulation
        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)

        # optional bilinear upsampling
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None

    # forward pass
    def forward(self, x, prev_rgb, istyle):
        b, c, h, w = x.shape
        # map the style vector to per-channel modulation
        style = self.to_style(istyle)
        # modulated convolution to rgb
        x = self.conv(x, style)

        # accumulate with the previous rgb output if there is one
        if prev_rgb is not None:
            x = x + prev_rgb

        # optionally upsample
        if self.upsample is not None:
            x = self.upsample(x)

        return x

# 定义一个生成器块类
class GeneratorBlock(nn.Module):
    # 初始化函数,定义生成器块的结构
    def __init__(self, latent_dim, input_channels, filters, upsample = True, upsample_rgb = True, rgba = False):
        # 调用父类的初始化函数
        super().__init__()
        # 如果需要上采样,则创建上采样层
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None

        # 创建将潜在向量映射到输入通道的全连接层
        self.to_style1 = nn.Linear(latent_dim, input_channels)
        # 创建将噪声映射到滤波器数量的全连接层
        self.to_noise1 = nn.Linear(1, filters)
        # 创建卷积层,使用自定义的Conv2DMod类
        self.conv1 = Conv2DMod(input_channels, filters, 3)
        
        # 创建将潜在向量映射到滤波器数量的全连接层
        self.to_style2 = nn.Linear(latent_dim, filters)
        # 创建将噪声映射到滤波器数量的全连接层
        self.to_noise2 = nn.Linear(1, filters)
        # 创建卷积层,使用自定义的Conv2DMod类
        self.conv2 = Conv2DMod(filters, filters, 3)

        # 定义激活函数为LeakyReLU
        self.activation = leaky_relu()
        # 创建RGBBlock实例,用于生成RGB输出
        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)

    # 前向传播函数,定义生成器的前向传播过程
    def forward(self, x, prev_rgb, istyle, inoise):
        # 如果需要上采样,则对输入进行上采样
        if self.upsample is not None:
            x = self.upsample(x)

        # 裁剪噪声张量,使其与输入张量的尺寸相匹配
        inoise = inoise[:, :x.shape[2], :x.shape[3], :]
        # 将噪声映射到滤波器数量,并进行维度变换
        noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
        noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))

        # 将潜在向量映射到输入通道,并进行卷积操作
        style1 = self.to_style1(istyle)
        x = self.conv1(x, style1)
        x = self.activation(x + noise1)

        # 将潜在向量映射到滤波器数量,并进行卷积操作
        style2 = self.to_style2(istyle)
        x = self.conv2(x, style2)
        x = self.activation(x + noise2)

        # 生成RGB输出
        rgb = self.to_rgb(x, prev_rgb, istyle)
        return x, rgb
# 定义一个包含两个卷积层和激活函数的序列模块
def double_conv(chan_in, chan_out):
    return nn.Sequential(
        nn.Conv2d(chan_in, chan_out, 3, padding=1),  # 3x3卷积层,输入通道数为chan_in,输出通道数为chan_out,填充为1
        leaky_relu(),  # 使用LeakyReLU激活函数
        nn.Conv2d(chan_out, chan_out, 3, padding=1),  # 3x3卷积层,输入通道数为chan_out,输出通道数为chan_out,填充为1
        leaky_relu()  # 使用LeakyReLU激活函数
    )

# 定义一个下采样块模块
class DownBlock(nn.Module):
    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))  # 1x1卷积层,输入通道数为input_channels,输出通道数为filters,步长为2或1

        self.net = double_conv(input_channels, filters)  # 使用double_conv函数创建卷积层序列
        self.down = nn.Conv2d(filters, filters, 3, padding=1, stride=2) if downsample else None  # 下采样卷积层,输入通道数为filters,输出通道数为filters,填充为1,步长为2或None

    def forward(self, x):
        res = self.conv_res(x)  # 对输入x进行1x1卷积
        x = self.net(x)  # 使用卷积层序列处理输入x
        unet_res = x

        if self.down is not None:
            x = self.down(x)  # 如果存在下采样卷积层,则对x进行下采样

        x = x + res  # 将1x1卷积结果与处理后的x相加
        return x, unet_res

# 定义一个上采样块模块
class UpBlock(nn.Module):
    def __init__(self, input_channels, filters):
        super().__init__()
        self.conv_res = nn.ConvTranspose2d(input_channels // 2, filters, 1, stride=2)  # 转置卷积层,输入通道数为input_channels的一半,输出通道数为filters,步长为2
        self.net = double_conv(input_channels, filters)  # 使用double_conv函数创建卷积层序列
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # 上采样层,尺度因子为2,插值模式为双线性插值,不对齐角点

    def forward(self, x, res):
        *_, h, w = x.shape
        conv_res = self.conv_res(x, output_size=(h * 2, w * 2))  # 对输入x进行转置卷积
        x = self.up(x)  # 对输入x进行上采样
        x = torch.cat((x, res), dim=1)  # 在通道维度上拼接x和res
        x = self.net(x)  # 使用卷积层序列处理拼接后的x
        x = x + conv_res  # 将转置卷积结果与处理后的x相加
        return x
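
下面是一个示意性的形状检查,展示 DownBlock 与 UpBlock 如何通过跳跃连接(unet_res)配合使用(假设在本文件的模块环境中运行,double_conv、leaky_relu 等辅助函数均已定义,非原仓库代码):

import torch

down = DownBlock(16, 32)                # 通道 16 -> 32,空间尺寸减半
up = UpBlock(64, 16)                    # 输入为拼接后的 64 通道(32 上采样 + 32 跳跃连接)

x = torch.randn(2, 16, 32, 32)
x_down, skip = down(x)                  # x_down: (2, 32, 16, 16), skip: (2, 32, 32, 32)
x_up = up(x_down, skip)                 # (2, 16, 32, 32)
assert x_up.shape == (2, 16, 32, 32)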

# 定义一个生成器模块
class Generator(nn.Module):
    def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, no_const=False, fmap_max=512):
        super().__init__()
        self.image_size = image_size
        self.latent_dim = latent_dim
        self.num_layers = int(log2(image_size) - 1)

        filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1]

        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        init_channels = filters[0]
        filters = [init_channels, *filters]

        in_out_pairs = zip(filters[:-1], filters[1:])
        self.no_const = no_const

        if no_const:
            self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False)  # 转置卷积层,输入通道数为latent_dim,输出通道数为init_channels,核大小为4,步长为1,填充为0,无偏置
        else:
            self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4, 4)))  # 初始化块参数为随机张量

        self.initial_conv = nn.Conv2d(filters[0], filters[0], 3, padding=1)  # 3x3卷积层,输入通道数为filters[0],输出通道数为filters[0],填充为1

        self.blocks = nn.ModuleList([])  # 创建模块列表
        self.attns = nn.ModuleList([])  # 创建模块列表

        for ind, (in_chan, out_chan) in enumerate(in_out_pairs):
            not_first = ind != 0
            not_last = ind != (self.num_layers - 1)
            num_layer = self.num_layers - ind

            attn_fn = attn_and_ff(in_chan)  # 获取注意力函数
            self.attns.append(attn_fn)  # 添加到注意力模块列表

            block = GeneratorBlock(
                latent_dim,
                in_chan,
                out_chan,
                upsample=not_first,
                upsample_rgb=not_last,
                rgba=transparent
            )
            self.blocks.append(block)  # 添加生成器块模块到模块列表

    def forward(self, styles, input_noise):
        batch_size = styles.shape[0]
        image_size = self.image_size

        if self.no_const:
            avg_style = styles.mean(dim=1)[:, :, None, None]
            x = self.to_initial_block(avg_style)  # 使用平均风格向量生成初始块
        else:
            x = self.initial_block.expand(batch_size, -1, -1, -1)  # 扩展初始块参数

        x = self.initial_conv(x)  # 对初始块进行卷积
        styles = styles.transpose(0, 1)  # 转置风格张量

        rgb = None
        for style, block, attn in zip(styles, self.blocks, self.attns):
            if attn is not None:
                x = attn(x)  # 如果存在注意力模块,则应用注意力
            x, rgb = block(x, rgb, style, input_noise)  # 使用生成器块模块处理x和rgb

        return rgb  # 返回rgb

class Discriminator(nn.Module):
    # 初始化函数,设置神经网络的参数
    def __init__(self, image_size, network_capacity = 16, transparent = False, fmap_max = 512):
        # 调用父类的初始化函数
        super().__init__()
        # 计算网络层数
        num_layers = int(log2(image_size) - 3)
        # 初始化滤波器数量
        num_init_filters = 3 if not transparent else 4

        blocks = []
        # 计算每一层的滤波器数量
        filters = [num_init_filters] + [(network_capacity) * (2 ** i) for i in range(num_layers + 1)]

        # 设置最大滤波器数量
        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        filters[-1] = filters[-2]

        # 组合输入输出通道数
        chan_in_out = list(zip(filters[:-1], filters[1:]))
        chan_in_out = list(map(list, chan_in_out))

        down_blocks = []
        attn_blocks = []

        # 遍历每一层,创建下采样块和注意力块
        for ind, (in_chan, out_chan) in enumerate(chan_in_out):
            num_layer = ind + 1
            is_not_last = ind != (len(chan_in_out) - 1)

            block = DownBlock(in_chan, out_chan, downsample = is_not_last)
            down_blocks.append(block)

            attn_fn = attn_and_ff(out_chan)
            attn_blocks.append(attn_fn)

        # 将下采样块和注意力块转换为 ModuleList
        self.down_blocks = nn.ModuleList(down_blocks)
        self.attn_blocks = nn.ModuleList(attn_blocks)

        last_chan = filters[-1]

        # 定义输出层
        self.to_logit = nn.Sequential(
            leaky_relu(),
            nn.AvgPool2d(image_size // (2 ** num_layers)),
            Flatten(1),
            nn.Linear(last_chan, 1)
        )

        self.conv = double_conv(last_chan, last_chan)

        # 反向遍历通道输入输出,创建上采样块
        dec_chan_in_out = chan_in_out[:-1][::-1]
        self.up_blocks = nn.ModuleList(list(map(lambda c: UpBlock(c[1] * 2, c[0]), dec_chan_in_out)))
        self.conv_out = nn.Conv2d(3, 1, 1)

    # 前向传播函数
    def forward(self, x):
        b, *_ = x.shape

        residuals = []

        # 遍历下采样块和注意力块
        for (down_block, attn_block) in zip(self.down_blocks, self.attn_blocks):
            x, unet_res = down_block(x)
            residuals.append(unet_res)

            if attn_block is not None:
                x = attn_block(x)

        x = self.conv(x) + x
        enc_out = self.to_logit(x)

        # 反向遍历上采样块,生成解码输出
        for (up_block, res) in zip(self.up_blocks, residuals[:-1][::-1]):
            x = up_block(x, res)

        dec_out = self.conv_out(x)
        return enc_out.squeeze(), dec_out
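
下面是一个示意性的使用例子,展示该带 U-Net 解码分支的判别器同时输出整图判别 logit 与逐像素判别图(以 image_size=64 为例,假设在本文件的模块环境中运行,非原仓库代码):

import torch

disc = Discriminator(image_size = 64)
images = torch.randn(2, 3, 64, 64)

enc_out, dec_out = disc(images)
# enc_out: (2,)            每张图一个真假 logit(编码器分支)
# dec_out: (2, 1, 64, 64)  每个像素一个真假 logit(U-Net 解码器分支)
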
class StyleGAN2(nn.Module):
    # 定义 StyleGAN2 类,继承自 nn.Module
    def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, steps = 1, lr = 1e-4, ttur_mult = 2, no_const = False, lr_mul = 0.1, aug_types = ['translation', 'cutout']):
        # 初始化函数,接受多个参数
        super().__init__()
        # 调用父类的初始化函数

        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)
        # 设置学习率、步数和指数移动平均更新器

        self.S = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mul)
        self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent, no_const = no_const, fmap_max = fmap_max)
        self.D = Discriminator(image_size, network_capacity, transparent = transparent, fmap_max = fmap_max)
        # 创建 StyleVectorizer、Generator 和 Discriminator 实例

        self.SE = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mul)
        self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent, no_const = no_const)
        # 创建额外的 StyleVectorizer 和 Generator 实例

        self.D_aug = AugWrapper(self.D, image_size, aug_types)
        # 创建用于增强所有输入到鉴别器的包装器

        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)
        # 设置 SE 和 GE 的梯度计算为 False

        generator_params = list(self.G.parameters()) + list(self.S.parameters())
        self.G_opt = Adam(generator_params, lr = self.lr, betas=(0.5, 0.9))
        self.D_opt = Adam(self.D.parameters(), lr = self.lr * ttur_mult, betas=(0.5, 0.9))
        # 设置生成器和鉴别器的优化器

        self._init_weights()
        self.reset_parameter_averaging()
        # 初始化权重和参数平均化

        self.cuda()
        # 将模型移至 GPU

        self.fp16 = fp16
        if fp16:
            (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize([self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level='O1')
        # 如果启用混合精度训练,则初始化混合精度训练

    def _init_weights(self):
        # 初始化权重函数
        for m in self.modules():
            if type(m) in {nn.Conv2d, nn.Linear}:
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        # 对卷积层和全连接层进行权重初始化

        for block in self.G.blocks:
            nn.init.zeros_(block.to_noise1.weight)
            nn.init.zeros_(block.to_noise2.weight)
            nn.init.zeros_(block.to_noise1.bias)
            nn.init.zeros_(block.to_noise2.bias)
        # 初始化生成器中的噪声层参数

    def EMA(self):
        # 指数移动平均函数
        def update_moving_average(ma_model, current_model):
            for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
                old_weight, up_weight = ma_params.data, current_params.data
                ma_params.data = self.ema_updater.update_average(old_weight, up_weight)
        # 更新移动平均参数

        update_moving_average(self.SE, self.S)
        update_moving_average(self.GE, self.G)
        # 更新 SE 和 GE 的移动平均参数

    def reset_parameter_averaging(self):
        # 重置参数平均化函数
        self.SE.load_state_dict(self.S.state_dict())
        self.GE.load_state_dict(self.G.state_dict())
        # 将 SE 和 GE 的状态字典加载到 S 和 G 中

    def forward(self, x):
        # 前向传播函数
        return x
        # 返回输入 x

class Trainer():
    # 定义 Trainer 类
    # 初始化函数,设置模型参数和训练参数
    def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, ttur_mult = 2, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, no_const = False, aug_prob = 0., dataset_aug_prob = 0., cr_weight = 0.2, apply_pl_reg = False, lr_mul = 0.1, *args, **kwargs):
        # 存储 GAN 参数
        self.GAN_params = [args, kwargs]
        self.GAN = None

        # 设置模型名称、结果目录、模型目录、配置文件路径
        self.name = name
        self.results_dir = Path(results_dir)
        self.models_dir = Path(models_dir)
        self.config_path = self.models_dir / name / '.config.json'

        # 检查图像大小是否为2的幂次方
        assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
        self.image_size = image_size
        self.network_capacity = network_capacity
        self.transparent = transparent

        self.no_const = no_const
        self.aug_prob = aug_prob

        # 设置学习率、TTUR倍数、学习率倍数、批量大小、工作进程数、混合概率
        self.lr = lr
        self.ttur_mult = ttur_mult
        self.lr_mul = lr_mul
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.mixed_prob = mixed_prob

        self.save_every = save_every
        self.steps = 0

        self.av = None
        self.trunc_psi = trunc_psi

        self.apply_pl_reg = apply_pl_reg
        self.pl_mean = None

        self.gradient_accumulate_every = gradient_accumulate_every

        # 检查是否支持混合精度训练
        assert not fp16 or fp16 and APEX_AVAILABLE, 'Apex is not available for you to use mixed precision training'
        self.fp16 = fp16

        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = 0
        self.last_cr_loss = 0

        # 初始化指数移动平均
        self.pl_length_ma = EMA(0.99)
        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.cr_weight = cr_weight

    # 初始化 GAN 模型
    def init_GAN(self):
        args, kwargs = self.GAN_params
        self.GAN = StyleGAN2(lr = self.lr, ttur_mult = self.ttur_mult, lr_mul = self.lr_mul, image_size = self.image_size, network_capacity = self.network_capacity, transparent = self.transparent, fp16 = self.fp16, no_const = self.no_const, *args, **kwargs)

    # 写入配置文件
    def write_config(self):
        self.config_path.write_text(json.dumps(self.config()))

    # 加载配置文件
    def load_config(self):
        config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text())
        self.image_size = config['image_size']
        self.network_capacity = config['network_capacity']
        self.transparent = config['transparent']
        self.no_const = config.pop('no_const', False)
        del self.GAN
        self.init_GAN()

    # 返回配置信息
    def config(self):
        return {'image_size': self.image_size, 'network_capacity': self.network_capacity, 'transparent': self.transparent, 'no_const': self.no_const}

    # 设置数据源
    def set_data_src(self, folder):
        self.dataset = Dataset(folder, self.image_size, transparent = self.transparent, aug_prob = self.dataset_aug_prob)
        self.loader = cycle(data.DataLoader(self.dataset, num_workers = default(self.num_workers, num_cores), batch_size = self.batch_size, drop_last = True, shuffle=True, pin_memory=True))

    # 禁用梯度计算
    @torch.no_grad()
    # 定义评估函数,用于生成图像
    def evaluate(self, num = 0, num_image_tiles = 8, trunc = 1.0):
        # 将 GAN 设置为评估模式
        self.GAN.eval()
        # 根据是否透明设置文件扩展名
        ext = 'jpg' if not self.transparent else 'png'
        num_rows = num_image_tiles

        latent_dim = self.GAN.G.latent_dim
        image_size = self.GAN.G.image_size
        num_layers = self.GAN.G.num_layers

        # latents and noise

        # 生成潜在向量和噪声
        latents = noise_list(num_rows ** 2, num_layers, latent_dim)
        n = image_noise(num_rows ** 2, image_size)

        # regular

        # 生成正常图像
        generated_images = self.generate_truncated(self.GAN.S, self.GAN.G, latents, n, trunc_psi = self.trunc_psi)
        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows)
        
        # moving averages

        # 生成移动平均图像
        generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi = self.trunc_psi)
        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows)

        # mixing regularities

        # 定义瓷砖函数
        def tile(a, dim, n_tile):
            init_dim = a.size(dim)
            repeat_idx = [1] * a.dim()
            repeat_idx[dim] = n_tile
            a = a.repeat(*(repeat_idx))
            order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).cuda()
            return torch.index_select(a, dim, order_index)

        nn = noise(num_rows, latent_dim)
        tmp1 = tile(nn, 0, num_rows)
        tmp2 = nn.repeat(num_rows, 1)

        tt = int(num_layers / 2)
        mixed_latents = [(tmp1, tt), (tmp2, num_layers - tt)]

        # 生成混合图像
        generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, mixed_latents, n, trunc_psi = self.trunc_psi)
        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-mr.{ext}'), nrow=num_rows)

    @torch.no_grad()
    # 生成截断图像
    def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8):
        latent_dim = G.latent_dim

        if self.av is None:
            z = noise(2000, latent_dim)
            samples = evaluate_in_chunks(self.batch_size, S, z).cpu().numpy()
            self.av = np.mean(samples, axis = 0)
            self.av = np.expand_dims(self.av, axis = 0)
            
        w_space = []
        for tensor, num_layers in style:
            tmp = S(tensor)
            av_torch = torch.from_numpy(self.av).cuda()
            tmp = trunc_psi * (tmp - av_torch) + av_torch
            w_space.append((tmp, num_layers))

        w_styles = styles_def_to_tensor(w_space)
        generated_images = evaluate_in_chunks(self.batch_size, G, w_styles, noi)
        return generated_images.clamp_(0., 1.)

    @torch.no_grad()
    # 生成插值图像序列
    def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, save_frames = False):
        # 将 GAN 设置为评估模式
        self.GAN.eval()
        # 确定文件扩展名
        ext = 'jpg' if not self.transparent else 'png'
        # 设置图像行数
        num_rows = num_image_tiles

        # 获取潜在空间维度、图像尺寸和层数
        latent_dim = self.GAN.G.latent_dim
        image_size = self.GAN.G.image_size
        num_layers = self.GAN.G.num_layers

        # 生成潜在向量和噪声
        latents_low = noise(num_rows ** 2, latent_dim)
        latents_high = noise(num_rows ** 2, latent_dim)
        n = image_noise(num_rows ** 2, image_size)

        # 创建插值比例
        ratios = torch.linspace(0., 8., 100)

        frames = []
        # 遍历插值比例
        for ratio in tqdm(ratios):
            # 线性插值生成插值潜在向量
            interp_latents = slerp(ratio, latents_low, latents_high)
            latents = [(interp_latents, num_layers)]
            # 生成经过截断的图像
            generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi = self.trunc_psi)
            # 将生成的图像拼接成网格
            images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows)
            # 转换为 PIL 图像
            pil_image = transforms.ToPILImage()(images_grid.cpu())
            frames.append(pil_image)

        # 保存为 GIF 动画
        frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True)

        # 如果需要保存每一帧图像
        if save_frames:
            folder_path = (self.results_dir / self.name / f'{str(num)}')
            folder_path.mkdir(parents=True, exist_ok=True)
            for ind, frame in enumerate(frames):
                frame.save(str(folder_path / f'{str(ind)}.{ext}'))

    # 打印日志信息
    def print_log(self):
        pl_mean = default(self.pl_mean, 0)
        print(f'G: {self.g_loss:.2f} | D: {self.d_loss:.2f} | GP: {self.last_gp_loss:.2f} | PL: {pl_mean:.2f} | CR: {self.last_cr_loss:.2f}')

    # 返回模型文件名
    def model_name(self, num):
        return str(self.models_dir / self.name / f'model_{num}.pt')

    # 初始化结果和模型文件夹
    def init_folders(self):
        (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
        (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)

    # 清空结果和模型文件夹
    def clear(self):
        rmtree(f'./models/{self.name}', True)
        rmtree(f'./results/{self.name}', True)
        rmtree(str(self.config_path), True)
        self.init_folders()

    # 保存模型
    def save(self, num):
        save_data = {'GAN': self.GAN.state_dict()}

        if self.GAN.fp16:
            save_data['amp'] = amp.state_dict()

        torch.save(save_data, self.model_name(num))
        self.write_config()

    # 加载模型
    def load(self, num = -1):
        self.load_config()

        name = num
        if num == -1:
            file_paths = [p for p in Path(self.models_dir / self.name).glob('model_*.pt')]
            saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths))
            if len(saved_nums) == 0:
                return
            name = saved_nums[-1]
            print(f'continuing from previous epoch - {name}')

        self.steps = name * self.save_every

        load_data = torch.load(self.model_name(name))

        self.GAN.load_state_dict(load_data['GAN'])

        if self.GAN.fp16 and 'amp' in load_data:
            amp.load_state_dict(load_data['amp'])

.\lucidrains\unet-stylegan2\unet_stylegan2\__init__.py

# 从 unet_stylegan2 模块中导入 Trainer, StyleGAN2 和 NanException 类
from unet_stylegan2.unet_stylegan2 import Trainer, StyleGAN2, NanException

Uniformer - Pytorch

Implementation of Uniformer, a simple attention and 3d convolutional net that achieved SOTA in a number of video classification tasks

Install

$ pip install uniformer-pytorch

Usage

Uniformer-S

import torch
from uniformer_pytorch import Uniformer

model = Uniformer(
    num_classes = 1000,                 # number of output classes
    dims = (64, 128, 256, 512),         # feature dimensions per stage (4 stages)
    depths = (3, 4, 8, 3),              # depth at each stage
    mhsa_types = ('l', 'l', 'g', 'g')   # aggregation type at each stage, 'l' stands for local, 'g' stands for global
)

video = torch.randn(1, 3, 8, 224, 224)  # (batch, channels, time, height, width)

logits = model(video) # (1, 1000)

Uniformer-B

import torch
from uniformer_pytorch import Uniformer

model = Uniformer(
    num_classes = 1000,
    depths = (5, 8, 20, 7)
)

Citations

@inproceedings{anonymous2022uniformer,
    title   = {UniFormer: Unified Transformer for Efficient Spatial-Temporal Representation Learning},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations },
    year    = {2022},
    url     = {https://openreview.net/forum?id=nBU_u6DLvoK},
    note    = {under review}
}

.\lucidrains\uniformer-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'uniformer-pytorch', # 包的名称
  packages = find_packages(), # 查找所有包
  version = '0.0.4', # 版本号
  license='MIT', # 许可证
  description = 'Uniformer - Pytorch', # 描述
  author = 'Phil Wang', # 作者
  author_email = 'lucidrains@gmail.com', # 作者邮箱
  url = 'https://github.com/lucidrains/uniformer-pytorch', # 项目链接
  keywords = [ # 关键词列表
    'artificial intelligence',
    'attention mechanism',
    'video classification'
  ],
  install_requires=[ # 安装依赖
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[ # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\uniformer-pytorch\uniformer_pytorch\uniformer_pytorch.py

import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Reduce

# helpers

# 检查值是否存在的辅助函数
def exists(val):
    return val is not None

# classes

# LayerNorm 类
class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))

    def forward(self, x):
        # 计算标准差
        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
        # 计算均值
        mean = torch.mean(x, dim = 1, keepdim = True)
        # LayerNorm 操作
        return (x - mean) / (std + self.eps) * self.g + self.b
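
一个简单的数值检查示意(假设在本文件的模块环境中运行,非原仓库代码):该 LayerNorm 在通道维(dim=1)上做归一化,初始化时 g 为全 1、b 为全 0,因此输出在通道维上的均值应接近 0:

import torch

norm = LayerNorm(32)
x = torch.randn(2, 32, 4, 8, 8)         # (batch, channel, time, height, width)
out = norm(x)

assert out.shape == x.shape
assert torch.allclose(out.mean(dim = 1), torch.zeros_like(out.mean(dim = 1)), atol = 1e-5)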

# FeedForward 函数
def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        LayerNorm(dim),
        nn.Conv3d(dim, dim * mult, 1),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Conv3d(dim * mult, dim, 1)
    )

# MHRAs (multi-head relation aggregators)

# LocalMHRA 类
class LocalMHRA(nn.Module):
    def __init__(
        self,
        dim,
        heads,
        dim_head = 64,
        local_aggr_kernel = 5
    ):
        super().__init__()
        self.heads = heads
        inner_dim = dim_head * heads

        # 使用 BatchNorm3d 代替 LayerNorm
        self.norm = nn.BatchNorm3d(dim)

        # 仅使用值,因为注意力矩阵由卷积处理
        self.to_v = nn.Conv3d(dim, inner_dim, 1, bias = False)

        # 通过相对位置聚合
        self.rel_pos = nn.Conv3d(heads, heads, local_aggr_kernel, padding = local_aggr_kernel // 2, groups = heads)

        # 合并所有头部的输出
        self.to_out = nn.Conv3d(inner_dim, dim, 1)

    def forward(self, x):
        x = self.norm(x)

        b, c, *_, h = *x.shape, self.heads

        # 转换为值
        v = self.to_v(x)

        # 分割头部
        v = rearrange(v, 'b (c h) ... -> (b c) h ...', h = h)

        # 通过相对位置聚合
        out = self.rel_pos(v)

        # 合并头部
        out = rearrange(out, '(b c) h ... -> b (c h) ...', b = b)
        return self.to_out(out)

# GlobalMHRA 类
class GlobalMHRA(nn.Module):
    def __init__(
        self,
        dim,
        heads,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)
        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv1d(inner_dim, dim, 1)

    def forward(self, x):
        x = self.norm(x)

        shape, h = x.shape, self.heads

        x = rearrange(x, 'b c ... -> b c (...)')

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h d) n -> b h n d', h = h), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # 注意力
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b (h d) n', h = h)

        out = self.to_out(out)
        return out.view(*shape)

# Transformer 类
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        mhsa_type = 'g',
        local_aggr_kernel = 5,
        dim_head = 64,
        ff_mult = 4,
        ff_dropout = 0.,
        attn_dropout = 0.
    ):
        # 调用父类的构造函数初始化对象
        super().__init__()

        # 初始化一个空的神经网络模块列表
        self.layers = nn.ModuleList([])

        # 循环创建指定数量的层
        for _ in range(depth):
            # 根据不同的注意力类型创建不同的注意力模块
            if mhsa_type == 'l':
                attn = LocalMHRA(dim, heads = heads, dim_head = dim_head, local_aggr_kernel = local_aggr_kernel)
            elif mhsa_type == 'g':
                attn = GlobalMHRA(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)
            else:
                raise ValueError('unknown mhsa_type')

            # 将卷积层、注意力层和前馈网络层组成一个模块列表,并添加到神经网络模块列表中
            self.layers.append(nn.ModuleList([
                nn.Conv3d(dim, dim, 3, padding = 1),
                attn,
                FeedForward(dim, mult = ff_mult, dropout = ff_dropout),
            ]))

    # 前向传播函数
    def forward(self, x):
        # 遍历每个层,依次进行前向传播
        for dpe, attn, ff in self.layers:
            # 执行卷积层、注意力层和前馈网络层的操作,并将结果与输入相加
            x = dpe(x) + x
            x = attn(x) + x
            x = ff(x) + x

        # 返回最终的输出结果
        return x
# 主类定义
class Uniformer(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        num_classes,  # 类别数量
        dims = (64, 128, 256, 512),  # 不同层的维度
        depths = (3, 4, 8, 3),  # 不同层的深度
        mhsa_types = ('l', 'l', 'g', 'g'),  # 多头自注意力类型
        local_aggr_kernel = 5,  # 局部聚合核大小
        channels = 3,  # 输入通道数
        ff_mult = 4,  # FeedForward 层的倍数
        dim_head = 64,  # 头部维度
        ff_dropout = 0.,  # FeedForward 层的 dropout
        attn_dropout = 0.  # 注意力层的 dropout
    ):
        super().__init__()
        init_dim, *_, last_dim = dims
        # 将输入视频转换为 tokens
        self.to_tokens = nn.Conv3d(channels, init_dim, (3, 4, 4), stride = (2, 4, 4), padding = (1, 0, 0))

        dim_in_out = tuple(zip(dims[:-1], dims[1:]))
        mhsa_types = tuple(map(lambda t: t.lower(), mhsa_types))

        self.stages = nn.ModuleList([])

        # 遍历不同层的深度和多头自注意力类型
        for ind, (depth, mhsa_type) in enumerate(zip(depths, mhsa_types)):
            is_last = ind == len(depths) - 1
            stage_dim = dims[ind]
            heads = stage_dim // dim_head

            # 添加 Transformer 层和下采样层到 stages
            self.stages.append(nn.ModuleList([
                Transformer(
                    dim = stage_dim,
                    depth = depth,
                    heads = heads,
                    mhsa_type = mhsa_type,
                    ff_mult = ff_mult,
                    ff_dropout = ff_dropout,
                    attn_dropout = attn_dropout
                ),
                nn.Sequential(
                    nn.Conv3d(stage_dim, dims[ind + 1], (1, 2, 2), stride = (1, 2, 2)),
                    LayerNorm(dims[ind + 1]),
                ) if not is_last else None
            ]))

        # 输出层
        self.to_logits = nn.Sequential(
            Reduce('b c t h w -> b c', 'mean'),
            nn.LayerNorm(last_dim),
            nn.Linear(last_dim, num_classes)
        )

    # 前向传播函数
    def forward(self, video):
        x = self.to_tokens(video)

        # 遍历不同层的 Transformer 和下采样层
        for transformer, conv in self.stages:
            x = transformer(x)

            if exists(conv):
                x = conv(x)

        return self.to_logits(x)
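
以默认配置(dims = (64, 128, 256, 512))和输入 (1, 3, 8, 224, 224) 为例,下面的示意给出各阶段的大致形状变化(根据上述 stride 设置推算,非原仓库代码):

import torch
from uniformer_pytorch import Uniformer

model = Uniformer(num_classes = 1000)

video = torch.randn(1, 3, 8, 224, 224)

# to_tokens: Conv3d 核 (3,4,4)、步长 (2,4,4) ->  (1,  64, 4, 56, 56)
# stage 1 之后的下采样层 (1,2,2)             ->  (1, 128, 4, 28, 28)
# stage 2 之后的下采样层                     ->  (1, 256, 4, 14, 14)
# stage 3 之后的下采样层                     ->  (1, 512, 4,  7,  7)
# 最后一个 stage 不再下采样,随后做全局平均池化并接线性分类头
logits = model(video)
assert logits.shape == (1, 1000)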

.\lucidrains\uniformer-pytorch\uniformer_pytorch\__init__.py

# 从 uniformer_pytorch 包中导入 Uniformer 类
from uniformer_pytorch.uniformer_pytorch import Uniformer

.\lucidrains\vector-quantize-pytorch\examples\autoencoder.py

# 导入所需的库
from tqdm.auto import trange
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from vector_quantize_pytorch import VectorQuantize

# 设置超参数
lr = 3e-4
train_iter = 1000
num_codes = 256
seed = 1234
device = "cuda" if torch.cuda.is_available() else "cpu"

# 定义简单的 VQ 自编码器模型
class SimpleVQAutoEncoder(nn.Module):
    def __init__(self, **vq_kwargs):
        super().__init__()
        # 定义模型的层
        self.layers = nn.ModuleList(
            [
                nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.GELU(),
                nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                VectorQuantize(dim=32, accept_image_fmap=True, **vq_kwargs),
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
                nn.GELU(),
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
            ]
        )
        return

    # 前向传播函数
    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, VectorQuantize):
                x, indices, commit_loss = layer(x)
            else:
                x = layer(x)

        return x.clamp(-1, 1), indices, commit_loss
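
一个示意性的前向形状检查(基于下文 FashionMNIST 的 28x28 单通道输入,假设在本示例脚本环境中运行,非原仓库代码):

import torch

toy_model = SimpleVQAutoEncoder(codebook_size = num_codes)
imgs = torch.randn(8, 1, 28, 28)

recon, indices, commit_loss = toy_model(imgs)
# recon:       (8, 1, 28, 28)  重建结果,数值被截断到 [-1, 1]
# indices:     (8, 7, 7)       两次 2x 下采样后的 7x7 特征图上,每个位置一个码本索引
# commit_loss: 承诺损失,训练时加权叠加到重建损失上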

# 训练函数
def train(model, train_loader, train_iterations=1000, alpha=10):
    def iterate_dataset(data_loader):
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    # 迭代训练数据集
    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(iterate_dataset(train_loader))
        out, indices, cmt_loss = model(x)
        rec_loss = (out - x).abs().mean()
        (rec_loss + alpha * cmt_loss).backward()

        opt.step()
        # 更新进度条显示
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"cmt loss: {cmt_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )
    return

# 数据预处理
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_dataset = DataLoader(
    datasets.FashionMNIST(
        root="~/data/fashion_mnist", train=True, download=True, transform=transform
    ),
    batch_size=256,
    shuffle=True,
)

# 打印信息并开始训练
print("baseline")
torch.random.manual_seed(seed)
model = SimpleVQAutoEncoder(codebook_size=num_codes).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)

.\lucidrains\vector-quantize-pytorch\examples\autoencoder_fsq.py

# 导入所需的库
from tqdm.auto import trange
import math
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from vector_quantize_pytorch import FSQ

# 设置超参数
lr = 3e-4
train_iter = 1000
levels = [8, 6, 5] # 目标大小为 2^8,实际大小为 240
num_codes = math.prod(levels) # 计算编码数量
seed = 1234
device = "cuda" if torch.cuda.is_available() else "cpu"

# 定义简单的自动编码器类
class SimpleFSQAutoEncoder(nn.Module):
    def __init__(self, levels: list[int]):
        super().__init__()
        # 定义网络层
        self.layers = nn.ModuleList(
            [
                nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.GELU(),
                nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, len(levels), kernel_size=1),
                FSQ(levels), # 使用自定义的 FSQ 模块
                nn.Conv2d(len(levels), 32, kernel_size=3, stride=1, padding=1),
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
                nn.GELU(),
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
            ]
        )
        return

    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, FSQ):
                x, indices = layer(x) # 使用 FSQ 模块
            else:
                x = layer(x)

        return x.clamp(-1, 1), indices

# 训练函数
def train(model, train_loader, train_iterations=1000):
    def iterate_dataset(data_loader):
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(iterate_dataset(train_loader))
        out, indices = model(x)
        rec_loss = (out - x).abs().mean()
        rec_loss.backward()

        opt.step()
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )
    return

# 数据预处理
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_dataset = DataLoader(
    datasets.FashionMNIST(
        root="~/data/fashion_mnist", train=True, download=True, transform=transform
    ),
    batch_size=256,
    shuffle=True,
)

# 打印信息并开始训练
print("baseline")
torch.random.manual_seed(seed)
model = SimpleFSQAutoEncoder(levels).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)

.\lucidrains\vector-quantize-pytorch\examples\autoencoder_lfq.py

# 导入所需的库
from tqdm.auto import trange
from math import log2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 导入自定义的 LFQ 模块
from vector_quantize_pytorch import LFQ

# 设置训练参数
lr = 3e-4
train_iter = 1000
seed = 1234
codebook_size = 2 ** 8
entropy_loss_weight = 0.02
diversity_gamma = 1.
device = "cuda" if torch.cuda.is_available() else "cpu"

# 定义 LFQAutoEncoder 类,继承自 nn.Module
class LFQAutoEncoder(nn.Module):
    def __init__(
        self,
        codebook_size,
        **vq_kwargs
    ):
        super().__init__()
        assert log2(codebook_size).is_integer()
        quantize_dim = int(log2(codebook_size))

        # 编码器部分
        self.encode = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GroupNorm(4, 32, affine=False),  # 添加规范化层
            nn.Conv2d(32, quantize_dim, kernel_size=1),
        )

        # LFQ 模块
        self.quantize = LFQ(dim=quantize_dim, **vq_kwargs)

        # 解码器部分
        self.decode = nn.Sequential(
            nn.Conv2d(quantize_dim, 32, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
        )
        return

    # 前向传播函数
    def forward(self, x):
        x = self.encode(x)
        x, indices, entropy_aux_loss = self.quantize(x)
        x = self.decode(x)
        return x.clamp(-1, 1), indices, entropy_aux_loss

# 训练函数
def train(model, train_loader, train_iterations=1000):
    def iterate_dataset(data_loader):
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    # 迭代训练数据集
    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(iterate_dataset(train_loader))
        out, indices, entropy_aux_loss = model(x)

        rec_loss = F.l1_loss(out, x)
        (rec_loss + entropy_aux_loss).backward()

        opt.step()
        pbar.set_description(
              f"rec loss: {rec_loss.item():.3f} | "
            + f"entropy aux loss: {entropy_aux_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / codebook_size * 100:.3f}"
        )
    return

# 数据预处理
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# 加载 FashionMNIST 数据集
train_dataset = DataLoader(
    datasets.FashionMNIST(
        root="~/data/fashion_mnist", train=True, download=True, transform=transform
    ),
    batch_size=256,
    shuffle=True,
)

# 打印提示信息
print("baseline")

# 设置随机种子
torch.random.manual_seed(seed)

# 创建 LFQAutoEncoder 模型实例
model = LFQAutoEncoder(
    codebook_size = codebook_size,
    entropy_loss_weight = entropy_loss_weight,
    diversity_gamma = diversity_gamma
).to(device)

# 定义优化器
opt = torch.optim.AdamW(model.parameters(), lr=lr)

# 训练模型
train(model, train_dataset, train_iterations=train_iter)

Vector Quantization - Pytorch

A vector quantization library originally transcribed from Deepmind's tensorflow implementation, made conveniently into a package. It uses exponential moving averages to update the dictionary.

VQ has been successfully used by Deepmind and OpenAI for high quality generation of images (VQ-VAE-2) and music (Jukebox).

Install

$ pip install vector-quantize-pytorch

Usage

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,     # codebook size
    decay = 0.8,             # the exponential moving average decay, lower means the dictionary will change faster
    commitment_weight = 1.   # the weight on the commitment loss
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)

Residual VQ

This paper proposes to use multiple vector quantizers to recursively quantize the residuals of the waveform. You can use this with the ResidualVQ class and one extra initialization parameter.

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 8,      # specify number of quantizers
    codebook_size = 1024,    # codebook size
)

x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)

# (1, 1024, 256), (1, 1024, 8), (1, 8)
# (batch, seq, dim), (batch, seq, quantizer), (batch, quantizer)

# if you need all the codes across the quantization layers, just pass return_all_codes = True

quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)

# *_, (8, 1, 1024, 256)
# all_codes - (quantizer, batch, seq, dim)

Furthermore, this paper uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes.

They make two modifications. The first is to share the codebook across all quantizers. The second is to stochastically sample the codes rather than always taking the closest match. You can use both of these features with two extra keyword arguments.

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 8,
    codebook_size = 1024,
    stochastic_sample_codes = True,
    sample_codebook_temp = 0.1,         # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic
    shared_codebook = True              # whether to share the codebooks for all quantizers or not
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)

# (1, 1024, 256), (8, 1, 1024), (8, 1)
# (batch, seq, dim), (quantizer, batch, seq), (quantizer, batch)

A recent paper further proposes to do residual VQ on groups of the feature dimension, showing equivalent results to Encodec while using far fewer codebooks. You can use it by importing GroupedResidualVQ

import torch
from vector_quantize_pytorch import GroupedResidualVQ

residual_vq = GroupedResidualVQ(
    dim = 256,
    num_quantizers = 8,      # specify number of quantizers
    groups = 2,
    codebook_size = 1024,    # codebook size
)

x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)

# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)
# (batch, seq, dim), (groups, batch, seq, quantizer), (groups, batch, quantizer)

Initialization

The SoundStream paper proposes that the codebook should be initialized by the kmeans centroids of the first batch. You can easily turn on this feature with one flag kmeans_init = True, for either VectorQuantize or ResidualVQ class

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    codebook_size = 256,
    num_quantizers = 4,
    kmeans_init = True,   # set to True
    kmeans_iters = 10     # number of kmeans iterations to calculate the centroids for the codebook on init
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)

Increasing codebook usage

This repository will contain a few techniques from various papers to combat "dead" codebook entries, which is a common problem when using vector quantizers.

Lower codebook dimension

The Improved VQGAN paper proposes to have the codebook kept in a lower dimension. The encoder values are projected down before being projected back to high dimensional after quantization. You can set this with the codebook_dim hyperparameter.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    codebook_dim = 16      # paper proposes setting this to 32 or as low as 8 to increase codebook usage
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)

Cosine similarity

The Improved VQGAN paper also proposes to l2 normalize the codes and the encoded vectors, which boils down to using cosine similarity for the distance. They claim enforcing the vectors on a sphere leads to improvements in code usage and downstream reconstruction. You can turn this on by setting use_cosine_sim = True

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    use_cosine_sim = True   # set this to True
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)

Expiring stale codes

Finally, the SoundStream paper has a scheme where they replace codes that have hits below a certain threshold with a randomly selected vector from the current batch. You can set this threshold with the threshold_ema_dead_code keyword.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    threshold_ema_dead_code = 2  # should actively replace any codes that have an exponential moving average cluster size less than 2
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)

Orthogonal regularization loss

VQ-VAE / VQ-GAN is quickly gaining popularity. A recent paper proposes that when using vector quantization on images, enforcing the codebook to be orthogonal leads to translation equivariance of the discretized codes, leading to large improvements in downstream text to image generation tasks.
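
A minimal sketch of the idea behind the orthogonal regularization term (one common formulation for illustration, not necessarily the exact expression used internally by this library): the codebook vectors are l2-normalized and their pairwise cosine similarities are pushed towards the identity matrix.

import torch
import torch.nn.functional as F

def orthogonal_reg_sketch(codebook):
    # codebook: (num_codes, dim)
    n = codebook.shape[0]
    normed = F.normalize(codebook, dim = -1)
    cosine_sim = normed @ normed.t()                  # (num_codes, num_codes)
    identity = torch.eye(n, device = codebook.device)
    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)

loss = orthogonal_reg_sketch(torch.randn(256, 32))    # scalar regularization loss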

You can use this feature by simply setting the orthogonal_reg_weight to be greater than 0, in which case the orthogonal regularization will be added to the auxiliary loss outputted by the module.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    accept_image_fmap = True,                   # set this true to be able to pass in an image feature map
    orthogonal_reg_weight = 10,                 # in paper, they recommended a value of 10
    orthogonal_reg_max_codes = 128,             # this would randomly sample from the codebook for the orthogonal regularization loss, for limiting memory usage
    orthogonal_reg_active_codes_only = False    # set this to True if you have a very large codebook, and would only like to enforce the loss on the activated codes per batch
)

img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)
# loss now contains the orthogonal regularization loss with the weight as assigned

Multi-headed VQ

There have been a number of papers proposing variants of discrete latent representations with a multi-headed approach (multiple codes per feature). I have decided to offer one variant where the same codebook is used to vector quantize across the input dimension head times.

You can also use a more proven approach (memcodes) from NWT paper

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_dim = 32,                  # a number of papers have shown smaller codebook dimension to be acceptable
    heads = 8,                          # number of heads to vector quantize, codebook shared across all heads
    separate_codebook_per_head = True,  # whether to have a separate codebook per head. False would mean 1 shared codebook
    codebook_size = 8196,
    accept_image_fmap = True
)

img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32, 8), (1,)

# indices shape - (batch, height, width, heads)

Random Projection Quantizer

This paper first proposed to use a random projection quantizer for masked speech modeling, where signals are projected with a randomly initialized matrix and then matched with a randomly initialized codebook. One therefore does not need to learn the quantizer. This technique was used by Google's Universal Speech Model to achieve SOTA for speech-to-text modeling.

USM further proposes to use multiple codebooks, together with masked speech modeling and a multi-softmax objective. You can do this easily by setting num_codebooks to be greater than 1

import torch
from vector_quantize_pytorch import RandomProjectionQuantizer

quantizer = RandomProjectionQuantizer(
    dim = 512,               # input dimensions
    num_codebooks = 16,      # in USM, they used up to 16 for 5% gain
    codebook_dim = 256,      # codebook dimension
    codebook_size = 1024     # codebook size
)

x = torch.randn(1, 1024, 512)
indices = quantizer(x) # (1, 1024, 16) - (batch, seq, num_codebooks)

This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting sync_codebook = True | False

Finite Scalar Quantization

                 | VQ                                                     | FSQ
Quantization     | argmin_c || z - c ||                                   | round(f(z))
Gradients        | Straight Through Estimation (STE)                      | STE
Auxiliary Losses | Commitment, codebook, entropy loss, ...                | N/A
Tricks           | EMA on codebook, codebook splitting, projections, ...  | N/A
Parameters       | Codebook                                               | N/A

This work out of Google Deepmind aims to vastly simplify the way vector quantization is done for generative modeling, removing the need for commitment losses and EMA updating of the codebook, while also tackling the issues of codebook collapse and insufficient utilization. They simply round each scalar into discrete levels with straight-through gradients; the codes become uniform points in a hypercube.
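
Below is a conceptual sketch of the round-with-straight-through idea behind FSQ (illustration only, not the library's implementation; the sketch assumes an odd number of levels, while the paper and library additionally use a small offset for even level counts):

import torch

def round_ste(z):
    # round to the nearest integer, but let gradients pass straight through
    return z + (z.round() - z).detach()

def fsq_quantize_sketch(z, levels = 5):
    # bound each scalar to (-1, 1), scale to the grid, round with STE,
    # then rescale so the codes lie on a uniform grid in [-1, 1]
    half = (levels - 1) / 2
    return round_ste(torch.tanh(z) * half) / half

z = torch.randn(1, 1024, 4, requires_grad = True)
zhat = fsq_quantize_sketch(z)
zhat.sum().backward()     # gradients flow as if rounding were the identity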

Thanks goes out to @sekstini for porting over this implementation in record time!

import torch
from vector_quantize_pytorch import FSQ

levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
quantizer = FSQ(levels)

x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
xhat, indices = quantizer(x)

print(xhat.shape)    # (1, 1024, 4) - (batch, seq, dim)
print(indices.shape) # (1, 1024)    - (batch, seq)

assert xhat.shape == x.shape
assert torch.all(xhat == quantizer.indices_to_codes(indices))

An improvised Residual FSQ, for an attempt to improve audio encoding.

Credit goes to @sekstini for originally incepting the idea here

import torch
from vector_quantize_pytorch import ResidualFSQ

residual_fsq = ResidualFSQ(
    dim = 256,
    levels = [8, 5, 5, 3],
    num_quantizers = 8
)

x = torch.randn(1, 1024, 256)

residual_fsq.eval()

quantized, indices = residual_fsq(x)

# (1, 1024, 256), (1, 1024, 8)
# (batch, seq, dim), (batch, seq, quantizers)

quantized_out = residual_fsq.get_output_from_indices(indices)

# (1, 1024, 256)
# (batch, seq, dim) - same shape as quantized, per the assert below

assert torch.all(quantized == quantized_out)

Lookup Free Quantization

The research team behind MagViT has released new SOTA results for generative video modeling. A core change between v1 and v2 include a new type of quantization, look-up free quantization (LFQ), which eliminates the codebook and embedding lookup entirely.

This paper presents a simple LFQ quantizer that uses independent binary latents. Other implementations of LFQ exist. However, the team shows that MAGVIT-v2 with LFQ significantly improves on the ImageNet benchmark. The differences between LFQ and 2-level FSQ include entropy regularization as well as a maintained commitment loss.

Developing a more advanced method of LFQ quantization without codebook-lookup could revolutionize generative modeling.
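
As a conceptual sketch of the "lookup free" binary-latent idea (illustration only; the actual LFQ module also adds the entropy and commitment terms and may order the bits differently): each dimension is quantized to ±1 with a straight-through sign, and the index is simply the resulting bit pattern, so no codebook lookup is needed.

import torch

def lfq_sketch(z):
    # z: (..., num_bits)
    codes = torch.where(z > 0, torch.ones_like(z), -torch.ones_like(z))
    quantized = z + (codes - z).detach()              # straight-through estimator
    bits = (codes > 0).long()
    powers = 2 ** torch.arange(z.shape[-1], device = z.device)
    indices = (bits * powers).sum(dim = -1)           # integer code index
    return quantized, indices

z = torch.randn(1, 32, 16)
quantized, indices = lfq_sketch(z)                    # quantized: (1, 32, 16), indices: (1, 32) in [0, 2**16)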

You can use it simply as follows. Will be dogfooded at MagViT2 pytorch port

import torch
from vector_quantize_pytorch import LFQ

# you can specify either dim or codebook_size
# if both specified, will be validated against each other

quantizer = LFQ(
    codebook_size = 65536,      # codebook size, must be a power of 2
    dim = 16,                   # this is the input feature dimension, defaults to log2(codebook_size) if not defined
    entropy_loss_weight = 0.1,  # how much weight to place on entropy loss
    diversity_gamma = 1.        # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
)

image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature=100.)  # you may want to experiment with temperature

# (1, 16, 32, 32), (1, 32, 32), (1,)

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()

You can also pass in video features as (batch, feat, time, height, width) or sequences as (batch, seq, feat)


seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)

assert seq.shape == quantized.shape

video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)

assert video_feats.shape == quantized.shape

Or support multiple codebooks

import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(
    codebook_size = 4096,
    dim = 16,
    num_codebooks = 4  # 4 codebooks, total codebook dimension is log2(4096) * 4
)

image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats)

# (1, 16, 32, 32), (1, 32, 32, 4), (1,)

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()

An improvised Residual LFQ, to see if it can lead to an improvement for audio compression.

import torch
from vector_quantize_pytorch import ResidualLFQ

residual_lfq = ResidualLFQ(
    dim = 256,
    codebook_size = 256,
    num_quantizers = 8
)

x = torch.randn(1, 1024, 256)

residual_lfq.eval()

quantized, indices, commit_loss = residual_lfq(x)

# (1, 1024, 256), (1, 1024, 8), (8)
# (batch, seq, dim), (batch, seq, quantizers), (quantizers)

quantized_out = residual_lfq.get_output_from_indices(indices)

# (1, 1024, 256)
# (batch, seq, dim) - same shape as quantized, per the assert below

assert torch.all(quantized == quantized_out)

Latent Quantization

Disentanglement is essential for representation learning as it promotes interpretability, generalization, improved learning, and robustness. It aligns with the goal of capturing meaningful and independent features of the data, facilitating more effective use of learned representations across various applications. For better disentanglement, the challenge is to disentangle underlying variations in a dataset without explicit ground truth information. This work introduces a key inductive bias aimed at encoding and decoding within an organized latent space. The strategy incorporated encompasses discretizing the latent space by assigning discrete code vectors through the utilization of an individual learnable scalar codebook for each dimension. This methodology enables their models to surpass robust prior methods effectively.

Be aware they had to use a very high weight decay for the results in this paper.

import torch
from vector_quantize_pytorch import LatentQuantize

# you can specify either dim or codebook_size
# if both specified, will be validated against each other

quantizer = LatentQuantize(
    levels = [5, 5, 8],      # number of levels per codebook dimension
    dim = 16,                   # input dim
    commitment_loss_weight=0.1,  
    quantization_loss_weight=0.1,
)

image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, loss = quantizer(image_feats)

# (1, 16, 32, 32), (1, 32, 32), (1,)

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()

You can also pass in video features as (batch, feat, time, height, width) or sequences as (batch, seq, feat)


seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)

assert seq.shape == quantized.shape

video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)

assert video_feats.shape == quantized.shape

Or support multiple codebooks

import torch
from vector_quantize_pytorch import LatentQuantize

levels = [4, 8, 16]
dim = 9
num_codebooks = 3

model = LatentQuantize(levels, dim, num_codebooks=num_codebooks)

input_tensor = torch.randn(2, 3, dim)
output_tensor, indices, loss = model(input_tensor)

assert output_tensor.shape == input_tensor.shape
assert indices.shape == (2, 3, num_codebooks)
assert loss.item() >= 0
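
Under the hood, latent quantization gives each latent dimension its own small scalar codebook and snaps that dimension to the nearest entry. A toy sketch of the per-dimension step with fixed values (in the library the per-latent values are learnable parameters; see the quantize method of latent_quantization.py later in this section):

import torch

levels = [4, 8, 16]
# one scalar codebook per latent dimension (fixed here, learnable in the library)
values_per_latent = [torch.linspace(-0.5, 0.5, l) for l in levels]

z = torch.randn(len(levels))   # a single latent vector
quantized = torch.stack([
    vals[(z[i] - vals).abs().argmin()]
    for i, vals in enumerate(values_per_latent)
])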

Citations

@misc{oord2018neural,
    title   = {Neural Discrete Representation Learning},
    author  = {Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
    year    = {2018},
    eprint  = {1711.00937},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{zeghidour2021soundstream,
    title   = {SoundStream: An End-to-End Neural Audio Codec},
    author  = {Neil Zeghidour and Alejandro Luebs and Ahmed Omran and Jan Skoglund and Marco Tagliasacchi},
    year    = {2021},
    eprint  = {2107.03312},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}
@inproceedings{anonymous2022vectorquantized,
    title   = {Vector-quantized Image Modeling with Improved {VQGAN}},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations },
    year    = {2022},
    url     = {https://openreview.net/forum?id=pfNyExj7z2},
    note    = {under review}
}
@misc{lee2022autoregressive,
    author  = {Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
    year    = {2022},
    month   = {03},
    title   = {Autoregressive Image Generation using Residual Quantization}
}
@article{Defossez2022HighFN,
    title   = {High Fidelity Neural Audio Compression},
    author  = {Alexandre D'efossez and Jade Copet and Gabriel Synnaeve and Yossi Adi},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.13438}
}
@inproceedings{Chiu2022SelfsupervisedLW,
    title   = {Self-supervised Learning with Random-projection Quantizer for Speech Recognition},
    author  = {Chung-Cheng Chiu and James Qin and Yu Zhang and Jiahui Yu and Yonghui Wu},
    booktitle = {International Conference on Machine Learning},
    year    = {2022}
}
@inproceedings{Zhang2023GoogleUS,
    title   = {Google USM: Scaling Automatic Speech Recognition Beyond 100 Languages},
    author  = {Yu Zhang and Wei Han and James Qin and Yongqiang Wang and Ankur Bapna and Zhehuai Chen and Nanxin Chen and Bo Li and Vera Axelrod and Gary Wang and Zhong Meng and Ke Hu and Andrew Rosenberg and Rohit Prabhavalkar and Daniel S. Park and Parisa Haghani and Jason Riesa and Ginger Perng and Hagen Soltau and Trevor Strohman and Bhuvana Ramabhadran and Tara N. Sainath and Pedro J. Moreno and Chung-Cheng Chiu and Johan Schalkwyk and Franccoise Beaufays and Yonghui Wu},
    year    = {2023}
}
@inproceedings{Shen2023NaturalSpeech2L,
    title   = {NaturalSpeech 2: Latent Diffusion Models are Natural and Zero-Shot Speech and Singing Synthesizers},
    author  = {Kai Shen and Zeqian Ju and Xu Tan and Yanqing Liu and Yichong Leng and Lei He and Tao Qin and Sheng Zhao and Jiang Bian},
    year    = {2023}
}
@inproceedings{Yang2023HiFiCodecGV,
    title   = {HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec},
    author  = {Dongchao Yang and Songxiang Liu and Rongjie Huang and Jinchuan Tian and Chao Weng and Yuexian Zou},
    year    = {2023}
}
@article{Liu2023BridgingDA,
    title   = {Bridging Discrete and Backpropagation: Straight-Through and Beyond},
    author  = {Liyuan Liu and Chengyu Dong and Xiaodong Liu and Bin Yu and Jianfeng Gao},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2304.08612}
}
@inproceedings{huh2023improvedvqste,
    title   = {Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks},
    author  = {Huh, Minyoung and Cheung, Brian and Agrawal, Pulkit and Isola, Phillip},
    booktitle = {International Conference on Machine Learning},
    year    = {2023},
    organization = {PMLR}
}
@inproceedings{rogozhnikov2022einops,
    title   = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author  = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{shin2021translationequivariant,
    title   = {Translation-equivariant Image Quantizer for Bi-directional Image-Text Generation},
    author  = {Woncheol Shin and Gyubok Lee and Jiyoung Lee and Joonseok Lee and Edward Choi},
    year    = {2021},
    eprint  = {2112.00384},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{mentzer2023finite,
    title   = {Finite Scalar Quantization: VQ-VAE Made Simple},
    author  = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
    year    = {2023},
    eprint  = {2309.15505},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yu2023language,
    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
    year    = {2023},
    eprint  = {2310.05737},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{hsu2023disentanglement,
    title   = {Disentanglement via Latent Quantization}, 
    author  = {Kyle Hsu and Will Dorrell and James C. R. Whittington and Jiajun Wu and Chelsea Finn},
    year    = {2023},
    eprint  = {2305.18378},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

.\lucidrains\vector-quantize-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  # 包的名称
  name = 'vector_quantize_pytorch',
  # 查找并包含所有包
  packages = find_packages(),
  # 版本号
  version = '1.14.5',
  # 许可证
  license='MIT',
  # 描述
  description = 'Vector Quantization - Pytorch',
  # 长描述内容类型
  long_description_content_type = 'text/markdown',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/vector-quantizer-pytorch',
  # 关键词
  keywords = [
    'artificial intelligence',
    'deep learning',
    'pytorch',
    'quantization'
  ],
  # 安装依赖
  install_requires=[
    'einops>=0.7.0',
    'einx[torch]>=0.1.3',
    'torch'
  ],
  # 分类
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\finite_scalar_quantization.py

"""
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
Code adapted from Jax version in Appendix A.1
"""

from typing import List, Tuple, Optional

import torch
import torch.nn as nn
from torch.nn import Module
from torch import Tensor, int32
from torch.cuda.amp import autocast

from einops import rearrange, pack, unpack

# helper functions

# 检查变量是否存在
def exists(v):
    return v is not None

# 返回第一个存在的参数
def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None

# 将单个张量按照指定模式打包
def pack_one(t, pattern):
    return pack([t], pattern)

# 将单个张量按照指定模式解包
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# tensor helpers

# 使用直通梯度进行四舍五入
def round_ste(z: Tensor) -> Tensor:
    """Round with straight through gradients."""
    zhat = z.round()
    return z + (zhat - z).detach()

# main class

class FSQ(Module):
    def __init__(
        self,
        levels: List[int],
        dim: Optional[int] = None,
        num_codebooks = 1,
        keep_num_codebooks_dim: Optional[bool] = None,
        scale: Optional[float] = None,
        allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64)
    ):
        super().__init__()
        _levels = torch.tensor(levels, dtype=int32)
        self.register_buffer("_levels", _levels, persistent = False)

        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
        self.register_buffer("_basis", _basis, persistent = False)

        self.scale = scale

        codebook_dim = len(levels)
        self.codebook_dim = codebook_dim

        effective_codebook_dim = codebook_dim * num_codebooks
        self.num_codebooks = num_codebooks
        self.effective_codebook_dim = effective_codebook_dim

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        self.dim = default(dim, len(_levels) * num_codebooks)

        has_projections = self.dim != effective_codebook_dim
        self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
        self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.codebook_size = self._levels.prod().item()

        implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out = False)
        self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)

        self.allowed_dtypes = allowed_dtypes

    def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
        """Bound `z`, an array of shape (..., d)."""
        half_l = (self._levels - 1) * (1 + eps) / 2
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: Tensor) -> Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2 # Renormalize to [-1, 1].
        return quantized / half_width
    
    def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width
    
    def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
        half_width = self._levels // 2
        return (zhat - half_width) / half_width
    
    def codes_to_indices(self, zhat: Tensor) -> Tensor:
        """Converts a `code` to an index in the codebook."""
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat)
        return (zhat * self._basis).sum(dim=-1).to(int32)
    
    def indices_to_codes(
        self,
        indices: Tensor,
        project_out = True
    ) -> Tensor:
        """Inverse of `codes_to_indices`."""
        
        # 检查输入张量的维度是否大于等于3(图片或视频)
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        # 将输入张量的维度调整为 '... -> ... 1'
        indices = rearrange(indices, '... -> ... 1')
        
        # 计算非中心化的编码
        codes_non_centered = (indices // self._basis) % self._levels
        # 对编码进行缩放和偏移
        codes = self._scale_and_shift_inverse(codes_non_centered)

        # 如果需要保留编码簇维度
        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, '... c d -> ... (c d)')

        # 如果需要进行投影
        if project_out:
            codes = self.project_out(codes)

        # 如果是图片或视频
        if is_img_or_video:
            codes = rearrange(codes, 'b ... d -> b d ...')

        # 返回编码
        return codes

    @autocast(enabled = False)
    def forward(self, z: Tensor) -> Tensor:
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension
        c - number of codebook dim
        """

        # 保存原始数据类型
        orig_dtype = z.dtype
        # 检查输入张量的维度是否大于等于4(图片或视频)
        is_img_or_video = z.ndim >= 4

        # 确保输入张量的数据类型在允许的范围内
        if z.dtype not in self.allowed_dtypes:
            z = z.float()

        # 标准化图片或视频数据为 (batch, seq, dimension) 的形式
        if is_img_or_video:
            z = rearrange(z, 'b d ... -> b ... d')
            z, ps = pack_one(z, 'b * d')

        # 断言输入张量的最后一个维度是否与指定的维度相匹配
        assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'

        # 对输入张量进行投影
        z = self.project_in(z)

        # 调整输入张量的维度为 'b n (c d)'
        z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)

        # 对输入张量进行量化
        codes = self.quantize(z)
        # 将编码转换为索引
        indices = self.codes_to_indices(codes)

        # 调整编码的维度为 'b n (c d)'
        codes = rearrange(codes, 'b n c d -> b n (c d)')

        # 对输出进行投影
        out = self.project_out(codes)

        # 恢复图片或视频的维度
        if is_img_or_video:
            out = unpack_one(out, ps, 'b * d')
            out = rearrange(out, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        # 如果不需要保留编码簇维度
        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        # 将输出转换回原始数据类型
        if out.dtype != orig_dtype:
            out = out.type(orig_dtype)

        # 返回量化输出和索引
        return out, indices

.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\latent_quantization.py

"""
Disentanglement via Latent Quantization
 - https://arxiv.org/abs/2305.18378
Code adapted from Jax version in https://github.com/kylehkhsu/latent_quantization
"""

# 导入所需的库
from typing import List, Optional, Union, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module
from torch import Tensor, int32
from torch.optim import Optimizer
from einops import rearrange, pack, unpack

# 辅助函数

# 检查变量是否存在
def exists(v):
    return v is not None

# 返回第一个非空参数
def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None

# 将单个张量按指定模式打包
def pack_one(t, pattern):
    return pack([t], pattern)

# 将单个张量按指定模式解包
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# 主类

class LatentQuantize(Module):
    # 计算量化损失
    def quantization_loss(self, z: Tensor, zhat: Tensor, reduce="mean") -> Tensor:
        """Computes the quantization loss."""
        return F.mse_loss(zhat.detach(), z, reduction=reduce)

    # 计算约束损失
    def commitment_loss(self, z: Tensor, zhat: Tensor, reduce="mean") -> Tensor:
        """Computes the commitment loss."""
        return F.mse_loss(z.detach(), zhat, reduction=reduce)    

    # 对 z 进行量化
    def quantize(self, z: Tensor) -> Tensor:
        """Quantizes z, returns quantized zhat, same shape as z.
        The quantization is done by measuring the distance between the input and the codebook values per latent dimension
        and returning the index of the closest codebook value.
        """
        def distance(x, y):
            return torch.abs(x - y)
        
        if self._equal_levels:
            index = torch.argmin(distance(z[..., None], self.values_per_latent), dim=-1)
            quantize = self.values_per_latent[torch.arange(self.dim), index]
        else:
            index = torch.stack([torch.argmin(distance(z[..., i, None], self.values_per_latent[i]), dim=-1) for i in range(self.codebook_dim)], dim=-1)
            quantize = torch.stack([self.values_per_latent[i][index[..., i]] for i in range(self.codebook_dim)], dim=-1)

        quantize = z + (quantize - z).detach()
        #half_width = self._levels // 2 / 2  # Renormalize to [-0.5, 0.5].
        return quantize #/ half_width
    
    # 缩放和移位 zhat 从 [-0.5, 0.5] 到 [0, level_per_dim]
    def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
        """ scale and shift zhat from [-0.5, 0.5] to [0, level_per_dim]"""
        half_width = self._levels // 2
        return (zhat_normalized * 2 * half_width) + half_width
    
    # 将 zhat 反向缩放和移位为 [-0.5, 0.5]
    def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
        """normalize zhat to [-0.5, 0.5]"""
        half_width = self._levels // 2
        return (zhat - half_width) / half_width / 2
    
    # 将编码转换为索引
    def codes_to_indices(self, zhat: Tensor) -> Tensor:
        """Converts a `code` which contains the number per latent to an index in the codebook."""
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat)
        return (zhat * self._basis).sum(dim=-1).to(int32)
    
    # 将索引转换为编码
    def indices_to_codes(
        self,
        indices: Tensor,
        project_out = True
    ) -> Tensor:
        """Inverse of `codes_to_indices`."""

        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        indices = rearrange(indices, '... -> ... 1')
        codes_non_centered = (indices // self._basis) % self._levels
        codes = self._scale_and_shift_inverse(codes_non_centered)

        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, '... c d -> ... (c d)')

        if project_out:
            codes = self.project_out(codes)

        if is_img_or_video:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes
    # 对输入张量进行量化和投影操作
    def quantize_and_project(self, z: Tensor, is_img_or_video, ps) -> Tensor:
        # 对输入张量进行量化操作
        codes = self.quantize(z)
        # 将量化后的结果转换为索引
        indices = self.codes_to_indices(codes)

        # 重排列张量维度
        codes = rearrange(codes, 'b n c d -> b n (c d)')

        # 对量化后的结果进行投影操作
        out = self.project_out(codes)

        # 重新构建图像或视频的维度

        if is_img_or_video:
            # 解包张量
            out = unpack_one(out, ps, 'b * d')
            out = rearrange(out, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')
        return codes, out, indices

    # 前向传播函数
    def forward(self,
                 z: Tensor) -> Tensor:
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension 
        c - number of codebook dim
        """

        # 判断输入张量是否为图像或视频
        is_img_or_video = z.ndim >= 4
        original_input = z
        # 标准化图像或视频为 (batch, seq, dimension) 格式
        should_inplace_optimize = exists(self.in_place_codebook_optimizer)

        if is_img_or_video:
            z = rearrange(z, 'b d ... -> b ... d')
            z, ps = pack_one(z, 'b * d')

        assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'

        # 投影输入张量
        z = self.project_in(z)
        z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)

        # 对输入张量进行量化操作
        codes = self.quantize(z)
        # 将量化后的结果转换为索引
        indices = self.codes_to_indices(codes)

        # 重排列张量维度
        codes = rearrange(codes, 'b n c d -> b n (c d)')

        # 对量化后的结果进行投影操作
        out = self.project_out(codes)

        # 重新构建图像或视频的维度
        if is_img_or_video:
            out = unpack_one(out, ps, 'b * d')
            out = rearrange(out, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')
            
        if should_inplace_optimize and self.training and not self.optimize_values:
            # 更新码书
            loss = self.commitment_loss(z, out) if self.commitment_loss_weight!=0  else torch.tensor(0.)
            loss+= self.quantization_loss(z, out) if self.quantization_loss_weight!=0 else torch.tensor(0.)
            loss.backward()
            self.in_place_codebook_optimizer.step()
            self.in_place_codebook_optimizer.zero_grad()
            # 再次对输入张量进行量化
            codes = self.quantize(z)
            indices = self.codes_to_indices(codes)
            codes = rearrange(codes, 'b n c d -> b n (c d)')
            out = self.project_out(codes)
            
            if is_img_or_video:
                out = unpack_one(out, ps, 'b * d')
                out = rearrange(out, 'b ... d -> b d ...')

                indices = unpack_one(indices, ps, 'b * c')

            if not self.keep_num_codebooks_dim:
                indices = rearrange(indices, '... 1 -> ...')


        # 计算损失
        commitment_loss = self.commitment_loss(original_input, out) if self.training and self.commitment_loss_weight!=0  else torch.tensor(0.)
        quantization_loss = self.quantization_loss(original_input, out) if self.training and self.quantization_loss_weight!=0 else torch.tensor(0.)


        loss = self.commitment_loss_weight * commitment_loss + self.quantization_loss_weight * quantization_loss 

        return out, indices, loss

.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\lookup_free_quantization.py

"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737

In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.
"""

from math import log2, ceil
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module
from torch.cuda.amp import autocast

from einops import rearrange, reduce, pack, unpack

# constants

Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])

LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])

# helper functions

def exists(v):
    return v is not None

def default(*args):
    for arg in args:
        if exists(arg):
            return arg() if callable(arg) else arg
    return None

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# entropy

def log(t, eps = 1e-5):
    return t.clamp(min = eps).log()

def entropy(prob):
    return (-prob * log(prob)).sum(dim=-1)

# class

class LFQ(Module):
    def __init__(
        self,
        *,
        dim = None,
        codebook_size = None,
        entropy_loss_weight = 0.1,
        commitment_loss_weight = 0.25,
        diversity_gamma = 1.,
        straight_through_activation = nn.Identity(),
        num_codebooks = 1,
        keep_num_codebooks_dim = None,
        codebook_scale = 1.,            # for residual LFQ, codebook scaled down by 2x at each layer
        frac_per_sample_entropy = 1.    # make less than 1. to only use a random fraction of the probs for per sample entropy
    ):
        super().__init__()

        # some assert validations

        assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
        assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'

        codebook_size = default(codebook_size, lambda: 2 ** dim)
        codebook_dim = int(log2(codebook_size))

        codebook_dims = codebook_dim * num_codebooks
        dim = default(dim, codebook_dims)

        has_projections = dim != codebook_dims
        self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity()
        self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.dim = dim
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        # straight through activation

        self.activation = straight_through_activation

        # entropy aux loss related weights

        assert 0 < frac_per_sample_entropy <= 1.
        self.frac_per_sample_entropy = frac_per_sample_entropy

        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

        # codebook scale

        self.codebook_scale = codebook_scale

        # commitment loss

        self.commitment_loss_weight = commitment_loss_weight

        # for no auxiliary loss, during inference

        self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

        # codes

        all_codes = torch.arange(codebook_size)
        bits = ((all_codes[..., None].int() & self.mask) != 0).float()
        codebook = self.bits_to_codes(bits)

        self.register_buffer('codebook', codebook, persistent = False)

    def bits_to_codes(self, bits):
        return bits * self.codebook_scale * 2 - self.codebook_scale

    @property
    # 返回当前对象的数据类型
    def dtype(self):
        return self.codebook.dtype

    # 将索引转换为代码
    def indices_to_codes(
        self,
        indices,
        project_out = True
    ):
        # 判断输入的索引是否为图像或视频数据
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        # 如果不保留代码簿的数量维度,则重新排列索引
        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... -> ... 1')

        # 将索引转换为代码,代码为-1或1的位
        bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)

        # 将位转换为代码
        codes = self.bits_to_codes(bits)

        # 重新排列代码的维度
        codes = rearrange(codes, '... c d -> ... (c d)')

        # 是否将代码投影到原始维度
        # 如果输入特征维度不是log2(代码簿大小)
        if project_out:
            codes = self.project_out(codes)

        # 将代码重新排列回原始形状
        if is_img_or_video:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    # 前向传播函数
    @autocast(enabled = False)
    def forward(
        self,
        x,
        inv_temperature = 100.,
        return_loss_breakdown = False,
        mask = None,

.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\random_projection_quantizer.py

import torch
from torch import nn, einsum
import torch.nn.functional as F
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize

from einops import rearrange, repeat, pack, unpack

def exists(val):
    return val is not None

class RandomProjectionQuantizer(nn.Module):
    """ https://arxiv.org/abs/2202.01855 """

    def __init__(
        self,
        *,
        dim,
        codebook_size,
        codebook_dim,
        num_codebooks = 1,
        norm = True,
        **kwargs
    ):
        super().__init__()
        self.num_codebooks = num_codebooks

        # 初始化随机投影矩阵,形状为(num_codebooks, dim, codebook_dim)
        rand_projs = torch.empty(num_codebooks, dim, codebook_dim)
        nn.init.xavier_normal_(rand_projs)

        # 将随机投影矩阵注册为模型的缓冲区
        self.register_buffer('rand_projs', rand_projs)

        # 根据输入参数决定是否进行归一化
        self.norm = nn.LayerNorm(dim, elementwise_affine = False) if norm else nn.Identity()

        # 创建向量量化层
        self.vq = VectorQuantize(
            dim = codebook_dim * num_codebooks,
            heads = num_codebooks,
            codebook_size = codebook_size,
            use_cosine_sim = True,
            separate_codebook_per_head = True,
            **kwargs
        )

    def forward(
        self,
        x,
        indices = None
    ):
        return_loss = exists(indices)

        # 对输入数据进行归一化
        x = self.norm(x)

        # 进行随机投影
        x = einsum('b n d, h d e -> b n h e', x, self.rand_projs)
        x, ps = pack([x], 'b n *')

        # 将向量量化层设置为评估模式
        self.vq.eval()
        # 使用向量量化层处理输入数据
        out = self.vq(x, indices = indices)

        if return_loss:
            _, ce_loss = out
            return ce_loss

        _, indices, _ = out
        return indices

.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\residual_fsq.py

import random
from math import ceil, log2
from functools import partial

from typing import List

import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.cuda.amp import autocast

from vector_quantize_pytorch.finite_scalar_quantization import FSQ

from einops import rearrange, repeat, reduce, pack, unpack

from einx import get_at

# helper functions

# 检查值是否存在
def exists(val):
    return val is not None

# 返回列表的第一个元素
def first(l):
    return l[0]

# 如果值存在则返回值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 将数字向上取整到最接近的倍数
def round_up_multiple(num, mult):
    return ceil(num / mult) * mult

# main class

class ResidualFSQ(Module):
    """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """

    def __init__(
        self,
        *,
        dim,
        levels: List[int],
        num_quantizers,
        quantize_dropout = False,
        quantize_dropout_cutoff_index = 0,
        quantize_dropout_multiple_of = 1,
        **kwargs
    ):
        super().__init__()
        codebook_dim = len(levels)

        requires_projection = codebook_dim != dim
        # 如果需要投影,则创建输入和输出的线性层
        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
        self.has_projections = requires_projection

        self.num_quantizers = num_quantizers

        self.levels = levels
        self.layers = nn.ModuleList([])

        levels_tensor = torch.Tensor(levels)

        scales = []

        for ind in range(num_quantizers):
            scales.append((levels_tensor - 1) ** -ind)

            fsq = FSQ(
                levels = levels,
                dim = codebook_dim,
                **kwargs
            )

            self.layers.append(fsq)

        assert all([not fsq.has_projections for fsq in self.layers])

        self.codebook_size = self.layers[0].codebook_size

        # 将尺度存储为缓冲区
        self.register_buffer('scales', torch.stack(scales), persistent = False)

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec paper proposes structured dropout, believe this was set to 4

    @property
    def codebooks(self):
        # 获取所有量化器的隐式码书
        codebooks = [layer.implicit_codebook for layer in self.layers]
        codebooks = torch.stack(codebooks, dim = 0)
        return codebooks

    def get_codes_from_indices(self, indices):

        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # 可能会接收到形状为 'b h w q' 的索引(accept_image_fmap)

        indices, ps = pack([indices], 'b * q')

        # 由于量化丢失,可能会传入粗糙的索引,网络应该能够重建

        if quantize_dim < self.num_quantizers:
            assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)

        # 处理量化器丢失

        mask = indices == -1
        indices = indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later

        all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)

        # 屏蔽任何被丢弃的代码

        all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)

        # 缩放代码

        scales = rearrange(self.scales, 'q d -> q 1 1 d')
        all_codes = all_codes * scales

        # 如果(accept_image_fmap = True),则返回形状(量化,批量,高度,宽度,维度)

        all_codes, = unpack(all_codes, ps, 'q b * d')

        return all_codes
    # 从给定的索引中获取输出
    def get_output_from_indices(self, indices):
        # 从索引中获取编码
        codes = self.get_codes_from_indices(indices)
        # 对编码进行求和
        codes_summed = reduce(codes, 'q ... -> ...', 'sum')
        # 对求和后的编码进行投影
        return self.project_out(codes_summed)

    # 前向传播函数
    def forward(
        self,
        x,
        return_all_codes = False,
        rand_quantize_dropout_fixed_seed = None
    ):
        # 获取量化器数量、量化丢弃倍数、设备信息
        num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device

        # 对输入进行投影
        x = self.project_in(x)

        quantized_out = 0.
        residual = first(self.layers).bound(x)

        all_indices = []

        should_quantize_dropout = self.training and self.quantize_dropout

        # 从中随机选择一个层索引,用于进一步丢弃残差量化
        # 同时准备空索引
        if should_quantize_dropout:
            rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random

            rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1

            null_indices = torch.full(x.shape[:2], -1., device = device, dtype = torch.long)

        # 遍历所有层
        with autocast(enabled = False):
            for quantizer_index, (layer, scale) in enumerate(zip(self.layers, self.scales)):

                if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                    all_indices.append(null_indices)
                    continue

                quantized, indices = layer(residual / scale)
                quantized = quantized * scale

                residual = residual - quantized.detach()
                quantized_out = quantized_out + quantized

                all_indices.append(indices)

        # 如果需要,进行投影
        quantized_out = self.project_out(quantized_out)

        # 将所有索引堆叠在一起
        all_indices = torch.stack(all_indices, dim = -1)

        ret = (quantized_out, all_indices)

        if not return_all_codes:
            return ret

        # 是否返回所有层中所有码书的所有编码
        all_codes = self.get_codes_from_indices(all_indices)

        # 返回所有编码的形状为 (量化器,批次,序列长度,码书维度)
        return (*ret, all_codes)
# 定义一个名为 GroupedResidualFSQ 的类,继承自 Module 类
class GroupedResidualFSQ(Module):
    # 初始化函数,接收参数 dim、groups、accept_image_fmap 和 kwargs
    def __init__(
        self,
        *,
        dim,
        groups = 1,
        accept_image_fmap = False,
        **kwargs
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 初始化类的属性 dim 和 groups
        self.dim = dim
        self.groups = groups
        # 断言 dim 能够被 groups 整除
        assert (dim % groups) == 0
        # 计算每个组的维度
        dim_per_group = dim // groups

        # 初始化类的属性 accept_image_fmap
        self.accept_image_fmap = accept_image_fmap

        # 初始化一个空的 ModuleList 对象 rvqs
        self.rvqs = nn.ModuleList([])

        # 循环创建 groups 个 ResidualFSQ 对象并添加到 rvqs 中
        for _ in range(groups):
            self.rvqs.append(ResidualFSQ(
                dim = dim_per_group,
                **kwargs
            ))

        # 获取第一个 ResidualFSQ 对象的 codebook_size 属性作为类的 codebook_size 属性
        self.codebook_size = self.rvqs[0].codebook_size

    # 定义 codebooks 属性,返回所有 rvqs 中的 codebooks 组成的张量
    @property
    def codebooks(self):
        return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))

    # 定义 split_dim 属性,根据 accept_image_fmap 的值返回不同的维度
    @property
    def split_dim(self):
        return 1 if self.accept_image_fmap else -1

    # 定义 get_codes_from_indices 方法,根据 indices 获取对应的 codes
    def get_codes_from_indices(self, indices):
        codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
        return torch.stack(codes)

    # 定义 get_output_from_indices 方法,根据 indices 获取对应的 outputs
    def get_output_from_indices(self, indices):
        outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
        return torch.cat(outputs, dim = self.split_dim)

    # 定义前向传播函数 forward,接收参数 x 和 return_all_codes
    def forward(
        self,
        x,
        return_all_codes = False
    ):
        # 获取输入 x 的形状和 split_dim
        shape, split_dim = x.shape, self.split_dim
        # 断言输入 x 在 split_dim 维度上的大小等于 dim

        assert shape[split_dim] == self.dim

        # 将特征维度分成 groups 组

        x = x.chunk(self.groups, dim = split_dim)

        forward_kwargs = dict(
            return_all_codes = return_all_codes,
            rand_quantize_dropout_fixed_seed = random.randint(0, int(1e7))
        )

        # 对每个组分别调用对应的 ResidualFSQ 对象进行前向传播

        out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
        out = tuple(zip(*out))

        # 否则,获取所有的 zipped 输出并将它们组合起来

        quantized, all_indices, *maybe_all_codes = out

        quantized = torch.cat(quantized, dim = split_dim)
        all_indices = torch.stack(all_indices)

        ret = (quantized, all_indices, *maybe_all_codes)
        return ret

.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\residual_lfq.py

# 导入所需的库
import random
from math import ceil, log2
from functools import partial

import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.cuda.amp import autocast

# 导入自定义的 LFQ 模块
from vector_quantize_pytorch.lookup_free_quantization import LFQ

# 导入 einops 库中的函数
from einops import rearrange, repeat, reduce, pack, unpack

# 导入自定义的 get_at 函数
from einx import get_at

# 辅助函数

# 检查变量是否存在
def exists(val):
    return val is not None

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 将数字向上取整到最接近的倍数
def round_up_multiple(num, mult):
    return ceil(num / mult) * mult

# 主类

class ResidualLFQ(Module):
    """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """

    def __init__(
        self,
        *,
        dim,
        num_quantizers,
        codebook_size,
        quantize_dropout = False,
        quantize_dropout_cutoff_index = 0,
        quantize_dropout_multiple_of = 1,
        **kwargs
    ):
        super().__init__()
        codebook_dim = int(log2(codebook_size))

        requires_projection = codebook_dim != dim
        # 如果 codebook_dim 不等于 dim,则需要进行投影
        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
        self.has_projections = requires_projection

        self.num_quantizers = num_quantizers

        self.layers = nn.ModuleList([])

        # 创建 num_quantizers 个 LFQ 层
        for ind in range(num_quantizers):
            codebook_scale = 2 ** -ind

            lfq = LFQ(
                dim = codebook_dim,
                codebook_scale = codebook_scale,
                **kwargs
            )

            self.layers.append(lfq)

        # 断言所有 LFQ 层都没有投影
        assert all([not lfq.has_projections for lfq in self.layers])

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        # 断言 quantize_dropout_cutoff_index 大于等于 0
        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec 论文提出结构化的 dropout,据信论文中设置为 4

    @property
    def codebooks(self):
        # 获取所有 LFQ 层的 codebook,并按维度 0 进行堆叠
        codebooks = [layer.codebook for layer in self.layers]
        codebooks = torch.stack(codebooks, dim = 0)
        return codebooks

    def get_codes_from_indices(self, indices):

        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # 可能接收到 'b h w q' 形状的 indices(accept_image_fmap)

        indices, ps = pack([indices], 'b * q')

        # 由于 quantize dropout,可能传入粗糙的 indices,网络应该能够重构

        if quantize_dim < self.num_quantizers:
            assert self.quantize_dropout > 0., '如果希望从较少的精细量化信号重构,则 quantize dropout 必须大于 0'
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)

        # 处理量化器 dropout

        mask = indices == -1.
        indices = indices.masked_fill(mask, 0)  # 有一个虚拟代码被掩盖

        all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)

        # 掩盖任何被 dropout 的代码

        all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)

        # 如果(accept_image_fmap = True),则返回形状为(quantize,batch,height,width,dimension)

        all_codes, = unpack(all_codes, ps, 'q b * d')

        return all_codes

    def get_output_from_indices(self, indices):
        codes = self.get_codes_from_indices(indices)
        codes_summed = reduce(codes, 'q ... -> ...', 'sum')
        return self.project_out(codes_summed)

    def forward(
        self,
        x,
        mask = None,
        return_all_codes = False,
        rand_quantize_dropout_fixed_seed = None
        ):
            # 获取量化器数量、量化丢弃的倍数、设备信息
            num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device

            # 对输入进行投影
            x = self.project_in(x)

            # 初始化量化输出和残差
            quantized_out = 0.
            residual = x

            # 初始化损失列表和索引列表
            all_losses = []
            all_indices = []

            # 是否需要进行量化丢弃
            should_quantize_dropout = self.training and self.quantize_dropout

            # 随机选择一个层索引,用于进一步丢弃残差量化
            # 同时准备空索引和损失
            if should_quantize_dropout:
                rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random

                rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)

                if quant_dropout_multiple_of != 1:
                    rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1

                null_indices = torch.full(x.shape[:2], -1., device=device, dtype=torch.long)
                null_loss = torch.tensor(0., device=device, dtype=x.dtype)

            # 遍历所有层
            with autocast(enabled=False):
                for quantizer_index, layer in enumerate(self.layers):

                    # 如果需要进行量化丢弃且当前层索引大于随机选择的丢弃索引
                    if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                        all_indices.append(null_indices)
                        all_losses.append(null_loss)
                        continue

                    # 进行量化操作,获取量化结果、索引和损失
                    quantized, indices, loss = layer(residual, mask=mask)

                    # 更新残差和量化输出
                    residual = residual - quantized.detach()
                    quantized_out = quantized_out + quantized

                    # 添加索引和损失到列表中
                    all_indices.append(indices)
                    all_losses.append(loss)

            # 对输出进行投影
            quantized_out = self.project_out(quantized_out)

            # 合并所有损失和索引
            all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))

            # 返回结果
            ret = (quantized_out, all_indices, all_losses)

            # 如果不需要返回所有编码,则直接返回结果
            if not return_all_codes:
                return ret

            # 是否返回所有层中所有码书的所有编码
            all_codes = self.get_codes_from_indices(all_indices)

            # 返回所有编码的形状为(量化器,批次,序列长度,码书维度)
            return (*ret, all_codes)
# 定义一个名为 GroupedResidualLFQ 的类,继承自 Module 类
class GroupedResidualLFQ(Module):
    # 初始化函数,接受一些参数
    def __init__(
        self,
        *,
        dim,
        groups = 1,
        accept_image_fmap = False,
        **kwargs
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 初始化类的属性
        self.dim = dim
        self.groups = groups
        # 确保 dim 能够被 groups 整除
        assert (dim % groups) == 0
        dim_per_group = dim // groups

        self.accept_image_fmap = accept_image_fmap

        # 创建一个空的 ModuleList 对象
        self.rvqs = nn.ModuleList([])

        # 根据 groups 的数量循环创建 ResidualLFQ 对象并添加到 rvqs 中
        for _ in range(groups):
            self.rvqs.append(ResidualLFQ(
                dim = dim_per_group,
                **kwargs
            ))

    # 定义 codebooks 属性,返回所有 rvq 对象的 codebooks 组成的张量
    @property
    def codebooks(self):
        return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))

    # 定义 split_dim 属性,根据 accept_image_fmap 的值返回不同的维度
    @property
    def split_dim(self):
        return 1 if self.accept_image_fmap else -1

    # 根据 indices 获取每个 rvq 对象的 codes,并返回组合后的张量
    def get_codes_from_indices(self, indices):
        codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
        return torch.stack(codes)

    # 根据 indices 获取每个 rvq 对象的 output,并返回组合后的张量
    def get_output_from_indices(self, indices):
        outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
        return torch.cat(outputs, dim = self.split_dim)

    # 前向传播函数,接受输入 x 和一些参数
    def forward(
        self,
        x,
        mask = None,
        return_all_codes = False
    ):
        shape, split_dim = x.shape, self.split_dim
        assert shape[split_dim] == self.dim

        # 将特征维度按 split_dim 分成 groups 组

        x = x.chunk(self.groups, dim = split_dim)

        forward_kwargs = dict(
            mask = mask,
            return_all_codes = return_all_codes,
            rand_quantize_dropout_fixed_seed = random.randint(0, int(1e7))
        )

        # 对每个 group 调用 residual vq

        out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
        out = tuple(zip(*out))

        # 否则,获取所有的 zipped 输出并组合它们

        quantized, all_indices, commit_losses, *maybe_all_codes = out

        quantized = torch.cat(quantized, dim = split_dim)
        all_indices = torch.stack(all_indices)
        commit_losses = torch.stack(commit_losses)

        ret = (quantized, all_indices, commit_losses, *maybe_all_codes)
        return ret