Lucidrains-系列项目源码解析-四十五-
Lucidrains 系列项目源码解析(四十五)
.\lucidrains\triton-transformer\triton_transformer\cross_entropy.py
# 导入 torch 库
import torch
# 导入 torch 中的函数库
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 导入 triton 库
import triton
# 从 triton 库中导入 language 模块并重命名为 tl
import triton.language as tl
def cross_entropy_fn(logits, labels, ignore_index = 0., use_triton = False):
    """Mean per-token cross entropy, skipping positions whose label equals ``ignore_index``.

    Args:
        logits: class scores with the class axis last, e.g. (batch, seq, num_classes);
            all leading axes are flattened.
        labels: integer targets whose shape matches ``logits`` without the class axis.
        ignore_index: label value excluded from the average (default 0).
        use_triton: compute the per-token loss with ``triton.ops.cross_entropy``
            instead of ``F.cross_entropy``.

    Returns:
        Scalar loss averaged over the non-ignored positions. Returns 0 (instead of
        NaN) when every position is ignored, e.g. an all-padding batch.
    """
    logits = logits.reshape(-1, logits.shape[-1])
    labels = labels.reshape(-1)

    if use_triton:
        loss = triton.ops.cross_entropy(logits, labels)
    else:
        loss = F.cross_entropy(logits, labels, reduction = 'none')

    mask = labels != ignore_index
    # masked_fill (not multiplication by the mask) so a non-finite loss at an ignored
    # position cannot leak into the sum; clamp avoids 0 / 0 when nothing is kept
    num_kept = mask.sum().clamp(min = 1)
    return loss.masked_fill(~mask, 0.).sum() / num_kept
.\lucidrains\triton-transformer\triton_transformer\dropout.py
# 导入所需的库
import torch
from torch import autograd
import torch.nn.functional as F
import triton
import triton.language as tl
from random import randrange
# Elements per chunk in the seeded dropout kernel; each program instance processes 4 chunks.
BLOCK_SIZE = 1 << 10
# Triton kernel: dropout whose keep/drop decisions are a function of (seed, element offset),
# so running it again with the same seed reproduces the same mask. The dropout_ autograd
# function relies on this to re-apply the forward mask to the gradient in backward.
# Each program instance handles 4 consecutive chunks of BLOCK_SIZE elements.
@triton.jit
def _seeded_dropout(x_ptr, output_ptr, n_elements, p, seed, **meta):
    BLOCK_SIZE = meta['BLOCK_SIZE']
    pid = tl.program_id(axis=0)
    # first element of this program's 4 * BLOCK_SIZE span
    block_start = pid * BLOCK_SIZE * 4
    off0 = block_start + BLOCK_SIZE * 0 + tl.arange(0, BLOCK_SIZE)
    off1 = block_start + BLOCK_SIZE * 1 + tl.arange(0, BLOCK_SIZE)
    off2 = block_start + BLOCK_SIZE * 2 + tl.arange(0, BLOCK_SIZE)
    off3 = block_start + BLOCK_SIZE * 3 + tl.arange(0, BLOCK_SIZE)
    # bounds masks: offsets at or past n_elements are neither loaded nor stored
    mask0 = off0 < n_elements
    mask1 = off1 < n_elements
    mask2 = off2 < n_elements
    mask3 = off3 < n_elements
    x0 = tl.load(x_ptr + off0, mask = mask0)
    x1 = tl.load(x_ptr + off1, mask = mask1)
    x2 = tl.load(x_ptr + off2, mask = mask2)
    x3 = tl.load(x_ptr + off3, mask = mask3)
    # a single call yields four random blocks keyed by (seed, off0), one per chunk
    r0, r1, r2, r3 = tl.random.rand4x(seed, off0)
    # an element is kept when its random draw exceeds p
    keep0, keep1, keep2, keep3 = r0 > p, r1 > p, r2 > p, r3 > p
    # kept values are rescaled by 1 / (1 - p); dropped values become 0
    o0 = tl.where(keep0, x0 / (1 - p), 0.0)
    o1 = tl.where(keep1, x1 / (1 - p), 0.0)
    o2 = tl.where(keep2, x2 / (1 - p), 0.0)
    o3 = tl.where(keep3, x3 / (1 - p), 0.0)
    tl.store(output_ptr + off0, o0, mask = mask0)
    tl.store(output_ptr + off1, o1, mask = mask1)
    tl.store(output_ptr + off2, o2, mask = mask2)
    tl.store(output_ptr + off3, o3, mask = mask3)
def seeded_dropout(x, p, seed):
    """Launch the seeded dropout kernel over every element of ``x``.

    The kernel indexes elements by flat offset, so ``x`` is treated as a flat
    buffer of ``x.numel()`` elements. Returns a new tensor like ``x``.
    """
    n_elements = x.numel()
    output = torch.empty_like(x)

    def grid(meta):
        # each program instance covers 4 chunks of BLOCK_SIZE elements
        return (triton.cdiv(n_elements, meta['BLOCK_SIZE'] * 4),)

    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE = BLOCK_SIZE)
    return output
class dropout_(autograd.Function):
    """Autograd op for Triton seeded dropout.

    Forward draws a fresh seed and applies dropout with probability ``p``.
    Backward reuses that seed and ``p``, so the gradient receives the identical
    keep-mask and 1 / (1 - p) scaling. ``p`` gets no gradient.
    """

    @classmethod
    def forward(cls, ctx, x, p):
        ctx.seed = randrange(int(1e6))
        ctx.p = p
        return seeded_dropout(x, p, ctx.seed)

    @classmethod
    def backward(cls, ctx, dy):
        grad_x = seeded_dropout(dy, ctx.p, ctx.seed)
        return grad_x, None
def dropout_fn(x, p, use_triton = False):
    """Apply dropout with probability ``p`` when ``x`` participates in autograd.

    Returns ``x`` unchanged when ``p == 0`` or ``x.requires_grad`` is False.
    Otherwise it uses the Triton seeded kernel (``use_triton=True``) or
    ``F.dropout`` with ``training=True``.
    """
    if p != 0. and x.requires_grad:
        if use_triton:
            return dropout_.apply(x, p)
        return F.dropout(x, p, training = True)
    return x
.\lucidrains\triton-transformer\triton_transformer\layernorm.py
# 导入 torch 库
import torch
# 从 torch 库中导入 autograd 模块
from torch import autograd
# 从 torch 库中导入 functional 模块
import torch.nn.functional as F
# 导入 triton 库
import triton
# 从 triton 库中导入 language 模块并重命名为 tl
import triton.language as tl
# 从 triton_transformer.utils 模块中导入 calc_num_warps 和 exists 函数
from triton_transformer.utils import calc_num_warps, exists
# Tile sizes for the dgamma reduction kernel: each program instance reduces a tile of
# GAMMA_ROW_BLOCK_SIZE rows x GAMMA_BLOCK_SIZE columns.
GAMMA_BLOCK_SIZE = 2 ** 6
GAMMA_ROW_BLOCK_SIZE = GAMMA_BLOCK_SIZE
# Triton kernel (training path): gain-only layer norm of one row per program instance.
# Besides the output it stores two intermediates for the backward pass: the
# mean-centered row and the normalized (pre-gamma) row.
# A whole row must fit in one block: only BLOCK_SIZE columns are loaded (this file's
# launcher passes BLOCK_SIZE = next_power_of_2(n_cols)).
@triton.jit
def layernorm_kernel_forward_training(
    output_ptr,
    mean_centered_ptr,
    normed_ptr,
    input_ptr,
    gamma_ptr,
    input_row_stride,
    gamma_row_stride,
    output_row_stride,
    mean_centered_row_stride,
    normed_row_stride,
    n_cols,
    stable,
    eps,
    **meta
):
    # one program per row
    row_idx = tl.program_id(0)
    BLOCK_SIZE = meta['BLOCK_SIZE']
    row_start_ptr = input_ptr + row_idx * input_row_stride
    gamma_row_start_ptr = gamma_ptr + row_idx * gamma_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    # columns are addressed with unit stride inside a row
    input_ptrs = row_start_ptr + col_offsets
    gamma_ptrs = gamma_row_start_ptr + col_offsets
    # lanes past the end of the row are masked out
    mask = col_offsets < n_cols
    row = tl.load(input_ptrs, mask=mask, other=0.)
    gammas = tl.load(gamma_ptrs, mask=mask, other=0.)
    if stable:
        # divide the row by its max (padding lanes excluded via -inf) before normalizing
        row_max = tl.max(tl.where(mask, row, float('-inf')), axis=0)
        row /= row_max
    # padding lanes hold 0, so they do not contribute to the sum
    row_mean = tl.sum(row, axis=0) / n_cols
    row_mean_centered = tl.where(mask, row - row_mean, 0.)
    # biased (population) variance over the n_cols valid entries
    row_var = tl.sum(row_mean_centered * row_mean_centered, axis=0) / n_cols
    # despite the name, this is 1 / std
    inv_var = 1. / tl.sqrt(row_var + eps)
    normed = row_mean_centered * inv_var
    output = normed * gammas
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, output, mask=mask)
    # store the mean-centered row; the backward kernel recomputes the statistics from it
    mean_centered_row_start_ptr = mean_centered_ptr + row_idx * mean_centered_row_stride
    mean_centered_ptrs = mean_centered_row_start_ptr + col_offsets
    tl.store(mean_centered_ptrs, row_mean_centered, mask=mask)
    # store the normalized (pre-gamma) row, intended for the gamma gradient sum(dy * normed)
    normed_row_start_ptr = normed_ptr + row_idx * normed_row_stride
    normed_ptrs = normed_row_start_ptr + col_offsets
    tl.store(normed_ptrs, normed, mask=mask)
# Triton kernel (inference path): same gain-only layer norm as the training kernel,
# but it writes only the output (no intermediates are saved for backward).
# One program instance per row; a whole row must fit within BLOCK_SIZE columns.
@triton.jit
def layernorm_kernel_forward_inference(
    output_ptr,
    input_ptr,
    gamma_ptr,
    input_row_stride,
    gamma_row_stride,
    output_row_stride,
    n_cols,
    stable,
    eps,
    **meta
):
    # one program per row
    row_idx = tl.program_id(0)
    BLOCK_SIZE = meta['BLOCK_SIZE']
    row_start_ptr = input_ptr + row_idx * input_row_stride
    gamma_row_start_ptr = gamma_ptr + row_idx * gamma_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    # columns are addressed with unit stride inside a row
    input_ptrs = row_start_ptr + col_offsets
    gamma_ptrs = gamma_row_start_ptr + col_offsets
    # lanes past the end of the row are masked out
    mask = col_offsets < n_cols
    row = tl.load(input_ptrs, mask=mask, other=0.)
    gammas = tl.load(gamma_ptrs, mask=mask, other=0.)
    if stable:
        # divide the row by its max (padding lanes excluded via -inf) before normalizing
        row_max = tl.max(tl.where(mask, row, float('-inf')), axis=0)
        row /= row_max
    # padding lanes hold 0, so they do not contribute to the sum
    row_mean = tl.sum(row, axis=0) / n_cols
    row_mean_centered = tl.where(mask, row - row_mean, 0.)
    # biased (population) variance over the n_cols valid entries
    row_var = tl.sum(row_mean_centered * row_mean_centered, axis=0) / n_cols
    # despite the name, this is 1 / std
    inv_var = 1. / tl.sqrt(row_var + eps)
    normed = row_mean_centered * inv_var
    output = normed * gammas
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, output, mask=mask)
# Triton kernel: gradient of layer norm with respect to its normalized input, one row per program.
# dy_ptr is expected to hold the upstream gradient already multiplied by gamma (this file's
# launcher passes dy * gamma). The normalization statistics are recomputed from the
# saved mean-centered row rather than being stored.
@triton.jit
def layernorm_kernel_backward(
    output_ptr,
    dy_ptr,
    mean_centered_ptr,
    output_row_stride,
    dy_row_stride,
    mean_centered_row_stride,
    n_cols,
    eps,
    **meta
):
    # one program per row
    row_idx = tl.program_id(0)
    BLOCK_SIZE = meta['BLOCK_SIZE']
    dy_row_start_ptr = dy_ptr + row_idx * dy_row_stride
    mean_centered_row_start_ptr = mean_centered_ptr + row_idx * mean_centered_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    # columns are addressed with unit stride inside a row
    dy_ptrs = dy_row_start_ptr + col_offsets
    mean_centered_ptrs = mean_centered_row_start_ptr + col_offsets
    mask = col_offsets < n_cols
    # padding lanes load 0 so they drop out of the row sums below
    dy = tl.load(dy_ptrs, mask=mask, other=0.)
    mean_centered = tl.load(mean_centered_ptrs, mask=mask, other=0.)
    # recompute biased variance and 1 / std exactly as in the forward kernels
    row_var = tl.sum(mean_centered * mean_centered, axis=0) / n_cols
    inv_var = 1. / tl.sqrt(row_var + eps)
    normed = mean_centered * inv_var
    # standard layer-norm input gradient:
    # dx = inv_std / N * (N * dy - sum(dy) - normed * sum(dy * normed))
    output = 1. / n_cols * inv_var * (n_cols * dy - tl.sum(dy, axis=0) - normed * tl.sum(dy * normed, axis=0))
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, output, mask=mask)
# Triton kernel: partial gamma gradients, dgamma = sum over rows of dy * normed.
# Launched on a 2D grid: program_id(0) selects a BLOCK_SIZE-wide column tile and
# program_id(1) a BLOCK_SIZE_ROW-tall row tile. Each program writes its column sums to
# row `row_idx` of dgamma; the launcher then sums dgamma over its first axis.
# The @triton.jit decorator is required: the launcher indexes this function with a
# grid (`layernorm_gamma_kernel_backward[(cols, rows)](...)`), which a plain Python
# function does not support.
@triton.jit
def layernorm_gamma_kernel_backward(
    dgamma_ptr,         # output: (num_row_programs, n_cols) partial sums
    norm_ptr,           # normalized (pre-gamma) activations
    dy_ptr,             # upstream gradient
    norm_stride,        # row stride of norm
    dy_stride,          # row stride of dy
    dgamma_row_stride,  # row stride of dgamma
    n_rows,
    n_cols,
    **meta
):
    col_idx = tl.program_id(0)
    row_idx = tl.program_id(1)
    BLOCK_SIZE = meta['BLOCK_SIZE']
    ROW_BLOCK_SIZE = meta['BLOCK_SIZE_ROW']
    col_offsets = tl.arange(0, BLOCK_SIZE)
    row_offsets = tl.arange(0, ROW_BLOCK_SIZE)
    # absolute row/column indices covered by this tile
    col_range = col_idx * BLOCK_SIZE + col_offsets
    row_range = row_idx * ROW_BLOCK_SIZE + row_offsets
    col_mask = col_range < n_cols
    # 2D mask: drop rows past n_rows and columns past n_cols
    mask = (row_range < n_rows)[:, None] & col_mask[None, :]
    # 2D pointer tiles (columns addressed with unit stride)
    dy_ptr += row_range[:, None] * dy_stride + col_range[None, :]
    norm_ptr += row_range[:, None] * norm_stride + col_range[None, :]
    # masked lanes load 0 and contribute nothing to the sum
    dy = tl.load(dy_ptr, mask=mask, other=0.)
    norm = tl.load(norm_ptr, mask=mask, other=0.)
    # reduce over the rows of this tile
    dgamma = tl.sum(dy * norm, axis=0)
    dgamma_ptr += row_idx * dgamma_row_stride + col_range
    tl.store(dgamma_ptr, dgamma, mask=col_mask)
class _layernorm(autograd.Function):
    """Triton-backed, gain-only layer norm over the last dimension.

    forward(x, gamma, eps, stable) normalizes each row of ``x.view(-1, dim)`` and
    multiplies it by ``gamma``. When ``stable`` is true, each row is first divided by
    its maximum. backward returns gradients for ``x`` and ``gamma`` only.
    """

    @classmethod
    def forward(cls, ctx, x, gamma, eps, stable):
        shape = x.shape
        dim = shape[-1]
        x = x.view(-1, dim)
        n_rows, n_cols = x.shape

        # kernels read gamma per row; a stride-0 expand avoids materializing copies
        expanded_gamma = gamma[None, :].expand(n_rows, -1)

        # a whole row must fit in one block
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)

        out = torch.empty_like(x)
        ctx.eps = eps
        ctx.stable = stable

        if x.requires_grad:
            scaled_x = torch.empty_like(x)
            normed = torch.empty_like(x)
            layernorm_kernel_forward_training[(n_rows,)](
                out,
                scaled_x,
                normed,
                x,
                expanded_gamma,
                x.stride(0),
                expanded_gamma.stride(0),
                out.stride(0),
                scaled_x.stride(0),
                normed.stride(0),
                n_cols,
                stable,
                eps,
                num_warps=num_warps,
                BLOCK_SIZE=BLOCK_SIZE,
            )
            # per-row max used by the stable path; backward needs it to undo the x / row_max scaling
            row_max = x.amax(dim = -1, keepdim = True) if stable else None
            # save the pre-gamma `normed` (not `out` = normed * gamma): dgamma = sum(dy * normed)
            ctx.save_for_backward(scaled_x, gamma, normed, row_max)
        else:
            layernorm_kernel_forward_inference[(n_rows,)](
                out,
                x,
                expanded_gamma,
                x.stride(0),
                expanded_gamma.stride(0),
                out.stride(0),
                n_cols,
                stable,
                eps,
                num_warps=num_warps,
                BLOCK_SIZE=BLOCK_SIZE,
            )

        return out.view(*shape)

    @classmethod
    def backward(cls, ctx, dy):
        shape, device = dy.shape, dy.device
        dim = shape[-1]
        # the kernels address columns with unit stride, so make the gradient contiguous
        dy = dy.contiguous().view(-1, dim)

        scaled_x, gamma, normed, row_max = ctx.saved_tensors
        n_rows, n_cols = dy.shape

        # gamma gradient: per-tile partial sums, reduced over tiles afterwards
        num_col_programs = triton.cdiv(n_cols, GAMMA_BLOCK_SIZE)
        num_row_programs = triton.cdiv(n_rows, GAMMA_ROW_BLOCK_SIZE)
        dgamma = torch.empty((num_row_programs, n_cols), device=device)
        layernorm_gamma_kernel_backward[(num_col_programs, num_row_programs)](
            dgamma,
            normed,
            dy,
            normed.stride(0),
            dy.stride(0),
            dgamma.stride(0),
            n_rows,
            n_cols,
            num_warps=4,
            BLOCK_SIZE=GAMMA_BLOCK_SIZE,
            BLOCK_SIZE_ROW=GAMMA_ROW_BLOCK_SIZE
        )
        dgamma = dgamma.sum(dim=0)

        # input gradient: fold gamma into dy, then apply the layer-norm backward kernel
        dxhat = dy * gamma
        dx = torch.empty_like(dy)
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)
        layernorm_kernel_backward[(n_rows,)](
            dx,
            dxhat,
            scaled_x,
            dx.stride(0),
            dxhat.stride(0),
            scaled_x.stride(0),
            n_cols,
            ctx.eps,
            num_warps=num_warps,
            BLOCK_SIZE=BLOCK_SIZE,
        )

        if ctx.stable:
            # the kernel returns the gradient w.r.t. x / row_max; chain rule through the
            # division gives a 1 / row_max factor. The term through row_max itself is
            # dropped: layer norm is (up to eps) invariant to rescaling its input, so it is ~0.
            dx = dx / row_max

        dx = dx.view(*shape)
        return dx, dgamma, None, None
def layernorm(x, gamma, eps = 1e-5, use_triton = False, stable = False):
    """Layer norm over the last dimension with gain ``gamma`` and no bias.

    Args:
        x: input tensor; normalized over ``x.shape[-1]``.
        gamma: gain of size ``x.shape[-1]``.
        eps: added to the variance for numerical stability.
        use_triton: dispatch to the Triton autograd op instead of ``F.layer_norm``.
        stable: divide every row by its maximum before normalizing.
    """
    if use_triton:
        return _layernorm.apply(x, gamma, eps, stable)

    if stable:
        x = x / torch.amax(x, dim = -1, keepdim = True)

    zero_bias = torch.zeros_like(gamma)
    return F.layer_norm(x, (x.shape[-1],), gamma, zero_bias, eps = eps)
.\lucidrains\triton-transformer\triton_transformer\softmax.py
# 导入 torch 库
import torch
# 从 torch 库中导入 autograd 模块
from torch import autograd
# 从 torch.nn.functional 模块中导入 F 函数
import torch.nn.functional as F
# 导入 triton 库
import triton
# 从 triton.language 模块中导入 tl
import triton.language as tl
# 从 triton_transformer.utils 模块中导入 calc_num_warps 函数
from triton_transformer.utils import calc_num_warps
# Triton kernel: numerically stable row-wise softmax, one row per program instance.
# A whole row must fit within BLOCK_SIZE columns.
# With `causal`, column j of row r is masked when j > r % n_cols. This assumes each run of
# n_cols consecutive rows is one square (query x key) score matrix, so that r % n_cols is
# the query index. NOTE(review): confirm for non-square inputs.
@triton.jit
def softmax_kernel_forward(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    causal,
    **meta
):
    # one program per row
    row_idx = tl.program_id(0)
    BLOCK_SIZE = meta['BLOCK_SIZE']
    row_start_ptr = input_ptr + row_idx * input_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    # columns are addressed with unit stride inside a row
    input_ptrs = row_start_ptr + col_offsets
    mask = col_offsets < n_cols
    # padding lanes load -inf so they become exp(-inf) = 0 below
    row = tl.load(input_ptrs, mask = mask, other = -float('inf'))
    if causal:
        # mask out "future" columns by adding -inf
        causal_mask = col_offsets > (row_idx % n_cols)
        row = row + tl.where(causal_mask, -float('inf'), 0.)
    # subtract the row max before exponentiating to avoid overflow
    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask = mask)
# Triton kernel: softmax backward, one row per program instance.
# Given probabilities p (input_ptr) and upstream gradient g (grad_ptr), it writes
#     dx = p * g - p * sum(p * g)
@triton.jit
def softmax_kernel_backward(
    output_ptr,
    input_ptr,
    grad_ptr,
    grad_row_stride,
    input_row_stride,
    output_row_stride,
    n_cols,
    **meta
):
    # one program per row
    row_idx = tl.program_id(0)
    BLOCK_SIZE = meta['BLOCK_SIZE']
    row_start_ptr = input_ptr + row_idx * input_row_stride
    grad_row_start_ptr = grad_ptr + row_idx * grad_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    # columns are addressed with unit stride inside a row
    input_ptrs = row_start_ptr + col_offsets
    grad_ptrs = grad_row_start_ptr + col_offsets
    mask = col_offsets < n_cols
    # padding lanes load 0 so they drop out of the row sum
    probs_row = tl.load(input_ptrs, mask = mask, other = 0.)
    grad_row = tl.load(grad_ptrs, mask = mask, other = 0.)
    dxhat = probs_row * grad_row
    softmax_grad_output = dxhat - probs_row * tl.sum(dxhat, axis = 0)
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_grad_output, mask = mask)
class _softmax(autograd.Function):
    """Triton row-wise softmax over the last dimension, optionally causal.

    All leading dimensions are flattened into rows via ``view``. backward returns
    a gradient only for ``x``; ``causal`` gets ``None``.
    """

    @classmethod
    def forward(cls, ctx, x, causal):
        orig_shape = x.shape
        x2d = x.view(-1, orig_shape[-1])
        n_rows, n_cols = x2d.shape

        # a whole row must fit in one block
        block_size = triton.next_power_of_2(n_cols)
        probs = torch.empty_like(x2d)
        softmax_kernel_forward[(n_rows,)](
            probs,
            x2d,
            x2d.stride(0),
            probs.stride(0),
            n_cols,
            causal,
            num_warps = calc_num_warps(block_size),
            BLOCK_SIZE = block_size,
        )

        # probabilities are all backward needs
        if x2d.requires_grad:
            ctx.save_for_backward(probs)
        return probs.view(*orig_shape)

    @classmethod
    def backward(cls, ctx, grad_probs):
        orig_shape = grad_probs.shape
        probs, = ctx.saved_tensors
        grad2d = grad_probs.view(-1, grad_probs.shape[-1])
        n_rows, n_cols = grad2d.shape

        block_size = triton.next_power_of_2(n_cols)
        dx = torch.empty_like(probs)
        softmax_kernel_backward[(n_rows,)](
            dx,
            probs,
            grad2d,
            grad2d.stride(0),
            probs.stride(0),
            dx.stride(0),
            n_cols,
            num_warps = calc_num_warps(block_size),
            BLOCK_SIZE = block_size
        )
        return dx.view(*orig_shape), None
# Functional entry point for the Triton softmax autograd op: triton_softmax(x, causal).
triton_softmax = _softmax.apply
def softmax(x, causal = False, use_triton = False):
    """Softmax over the last dimension.

    With ``use_triton`` the Triton op is used and ``causal`` is honored by the kernel.
    On the PyTorch path ``causal`` is not used; any causal masking must already be
    applied to ``x`` by the caller.
    """
    if not use_triton:
        return F.softmax(x, dim = -1)
    return triton_softmax(x, causal)
.\lucidrains\triton-transformer\triton_transformer\transformer.py
# 导入必要的库
from functools import partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
# 导入自定义的模块
from triton_transformer.layernorm import layernorm
from triton_transformer.softmax import softmax
from triton_transformer.cross_entropy import cross_entropy_fn
from triton_transformer.bmm import fused_relu_squared
from triton_transformer.dropout import dropout_fn
from triton_transformer.utils import exists, default
class PreNormResidual(nn.Module):
    """Residual wrapper computing ``fn(layernorm(x), **kwargs) + x``.

    Only the wrapped ``nn.LayerNorm``'s weight is used as the gain; its bias is not
    passed to ``layernorm``. ``kwargs`` (e.g. ``mask``, ``use_triton``) are forwarded to ``fn``.
    """

    def __init__(self, dim, fn, use_triton = False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
        self.use_triton = use_triton

    def forward(self, x, **kwargs):
        # an explicit use_triton=None falls back to the module default, matching how
        # Attention and FeedForward resolve the same argument
        use_triton = default(kwargs.get('use_triton'), self.use_triton)
        normed = layernorm(x, self.norm.weight, use_triton = use_triton)
        return self.fn(normed, **kwargs) + x
class Attention(nn.Module):
    """Multi-head self-attention with optional Triton softmax and dropout.

    ``mask`` (optional) is a boolean tensor where True marks attention logits to be
    filled with a large negative value. It may broadcast against the
    (batch * heads, n, n) logits directly (e.g. a (1, n, n) causal mask), or be a
    per-batch (batch, n, n) mask, which is repeated once per head.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        causal = False,
        dropout = 0.,
        use_triton = False
    ):
        super().__init__()
        self.use_triton = use_triton
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal
        inner_dim = dim_head * heads
        self.dropout = dropout
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, mask = None, use_triton = None):
        use_triton = default(use_triton, self.use_triton)
        h = self.heads

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        # fold heads into the batch axis: (b * h, n, d), heads varying fastest
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        q = q * self.scale
        sim = einsum('b i d, b j d -> b i j', q, k)

        if exists(mask):
            # a per-batch (b, i, j) mask does not broadcast against (b * h, i, j) logits
            # when b > 1 and h > 1; repeat it per head to match the '(b h)' layout
            if mask.ndim == 3 and mask.shape[0] not in (1, sim.shape[0]):
                mask = mask.repeat_interleave(h, dim = 0)
            mask_value = -torch.finfo(sim.dtype).max
            sim = sim.masked_fill(mask, mask_value)

        attn = softmax(sim, causal = self.causal, use_triton = use_triton)
        attn = dropout_fn(attn, self.dropout, use_triton = use_triton)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        out = self.to_out(out)
        return dropout_fn(out, self.dropout, use_triton = use_triton)
class FeedForward(nn.Module):
    """Feed-forward block: fused (x @ W_in) -> relu^2, dropout, then a linear projection back to ``dim``."""

    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.,
        use_triton = False
    ):
        super().__init__()
        inner_dim = dim * mult
        self.use_triton = use_triton
        self.dropout = dropout
        # parameters are created in the same order as before (W_in, then proj_out)
        self.proj_in_weight = nn.Parameter(torch.randn(dim, inner_dim))
        self.proj_out = nn.Linear(inner_dim, dim)

    def forward(self, x, use_triton = None):
        triton_on = default(use_triton, self.use_triton)
        hidden = dropout_fn(
            fused_relu_squared(x, self.proj_in_weight, use_triton = triton_on),
            self.dropout,
            use_triton = triton_on,
        )
        return self.proj_out(hidden)
# Top-level model: token embedding plus learned absolute position embedding, `depth`
# pre-norm residual (attention, feed-forward) pairs, a final "stable" layer norm, and a
# linear projection to token logits. When `labels` is given, forward returns the loss.
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        max_seq_len,
        depth,
        causal = False,
        heads = 8,
        dim_head = 64,
        ff_dropout = 0.,
        ff_mult = 4,
        attn_dropout = 0.,
        use_triton = False
    ):
        super().__init__()
        self.max_seq_len = max_seq_len
        # token embedding and learned position embedding for positions 0..max_seq_len-1
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)
        self.layers = nn.ModuleList([])
        # every sublayer is wrapped as PreNormResidual(dim, sublayer)
        wrapper = partial(PreNormResidual, dim)
        # NOTE: module creation order fixes parameter-init (RNG) order and state_dict layout
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                wrapper(Attention(dim, heads = heads, dim_head = dim_head, causal = causal, dropout = attn_dropout, use_triton = use_triton)),
                wrapper(FeedForward(dim, dropout = ff_dropout, mult = ff_mult, use_triton = use_triton))
            ]))
        # final norm; only its weight is used (passed as the gain to layernorm in forward)
        self.norm = nn.LayerNorm(dim)
        self.to_logits = nn.Linear(dim, num_tokens)
        self.use_triton = use_triton
        self.causal = causal
        # causal mask: True strictly above the diagonal marks future positions to mask out.
        # Registered as a non-persistent buffer, so it follows .to(device) but is not saved in state_dict.
        mask = torch.ones(max_seq_len, max_seq_len, dtype = torch.bool).triu(1) if causal else None
        self.register_buffer('mask', mask, persistent = False)

    def forward(
        self,
        x,
        mask = None,
        *,
        labels = None,
        use_triton = None
    ):
        """Run the model.

        Args:
            x: integer token ids; x.shape[1] is the sequence length (presumably (batch, seq)).
            mask: optional boolean padding mask of shape (batch, seq), True = keep.
                Only allowed when the model is not causal.
            labels: optional targets; when given, the cross-entropy loss is returned,
                with label 0 ignored.
            use_triton: overrides the constructor setting when not None.

        Returns:
            Logits with a trailing num_tokens axis, or the scalar loss when ``labels`` is given.
        """
        use_triton = default(use_triton, self.use_triton)
        n, device = x.shape[1], x.device
        x = self.token_emb(x)
        pos_emb = self.pos_emb(torch.arange(n, device = device))
        x = x + rearrange(pos_emb, 'n d -> () n d')
        assert not (self.causal and exists(mask)), 'mask is not needed during autoregressive mode'
        if self.causal and not use_triton:
            # PyTorch path: causality is enforced by an explicit mask. The Triton softmax
            # applies causality itself, so no mask is built in that case.
            mask = self.mask[:n, :n]
            mask = rearrange(mask, 'i j -> () i j')
        elif not self.causal and exists(mask):
            # pair (i, j) is kept only if both positions are kept; then invert so that
            # True marks attention logits to be masked out
            mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
            mask = ~mask
        for attn, ff in self.layers:
            x = attn(x, mask = mask, use_triton = use_triton)
            x = ff(x, use_triton = use_triton)
        # final norm in stable mode (each row divided by its max first); gain only, no bias
        x = layernorm(x, self.norm.weight, use_triton = use_triton, stable = True)
        logits = self.to_logits(x)
        if not exists(labels):
            return logits
        loss = cross_entropy_fn(logits, labels, ignore_index = 0, use_triton = use_triton)
        return loss
.\lucidrains\triton-transformer\triton_transformer\utils.py
def exists(val):
    """Return True for any value other than ``None``."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return ``d``."""
    if val is None:
        return d
    return val
def calc_num_warps(block_size):
    """Pick a Triton warp count for a row block: 16 at >= 4096, 8 at >= 2048, else 4."""
    if block_size >= 4096:
        return 16
    if block_size >= 2048:
        return 8
    return 4
.\lucidrains\triton-transformer\triton_transformer\__init__.py
# 从 triton_transformer.transformer 模块中导入 Transformer 类
from triton_transformer.transformer import Transformer
Uformer - Pytorch
Implementation of Uformer, an attention-based U-Net, in PyTorch. It will only offer the concat-cross-skip connection.
This repository will be geared towards use in a project for learning protein structures. Specifically, it will include the ability to condition on time steps (needed for DDPM), as well as 2D relative positional encoding using rotary embeddings (instead of the bias on the attention matrix used in the paper).
Install
$ pip install uformer-pytorch
Usage
import torch
from uformer_pytorch import Uformer
model = Uformer(
dim = 64, # initial dimensions after input projection, which increases by 2x each stage
stages = 4, # number of stages
num_blocks = 2, # number of transformer blocks per stage
window_size = 16, # set window size (along one side) for which to do the attention within
dim_head = 64,
heads = 8,
ff_mult = 4
)
x = torch.randn(1, 3, 256, 256)
pred = model(x) # (1, 3, 256, 256)
To condition on time steps for DDPM training:
import torch
from uformer_pytorch import Uformer
model = Uformer(
dim = 64,
stages = 4,
num_blocks = 2,
window_size = 16,
dim_head = 64,
heads = 8,
ff_mult = 4,
time_emb = True # set this to true
)
x = torch.randn(1, 3, 256, 256)
time = torch.arange(1)
pred = model(x, time = time) # (1, 3, 256, 256)
Citations
@misc{wang2021uformer,
title = {Uformer: A General U-Shaped Transformer for Image Restoration},
author = {Zhendong Wang and Xiaodong Cun and Jianmin Bao and Jianzhuang Liu},
year = {2021},
eprint = {2106.03106},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
.\lucidrains\uformer-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# Package metadata for uformer-pytorch, grouped into named lists before the setup() call.
KEYWORDS = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'image segmentation',
    'unet',
]

INSTALL_REQUIRES = [
    'einops>=0.3',
    'torch>=1.6',
]

CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name='uformer-pytorch',
    packages=find_packages(),
    version='0.0.8',
    license='MIT',
    description='Uformer - Pytorch',
    author='Phil Wang',
    author_email='lucidrains@gmail.com',
    url='https://github.com/lucidrains/uformer-pytorch',
    keywords=KEYWORDS,
    install_requires=INSTALL_REQUIRES,
    classifiers=CLASSIFIERS,
)
.\lucidrains\uformer-pytorch\uformer_pytorch\uformer_pytorch.py
# 导入 math 模块
import math
# 从 math 模块导入 log, pi, sqrt 函数
from math import log, pi, sqrt
# 从 functools 模块导入 partial 函数
from functools import partial
# 导入 torch 模块
import torch
# 从 torch 模块导入 nn, einsum 函数
from torch import nn, einsum
# 从 torch.nn 模块导入 functional 模块
import torch.nn.functional as F
# 导入 einops 模块中的 rearrange, repeat 函数
from einops import rearrange, repeat
# Short alias for the module container used throughout this file.
List = torch.nn.ModuleList
# ---- helpers ----

def exists(val):
    """Return True for any value other than ``None``."""
    return not (val is None)
def default(val, d):
    """Return ``val`` if it is not ``None``, otherwise the fallback ``d``."""
    if val is None:
        return d
    return val
def cast_tuple(val, depth = 1):
    """Return tuples unchanged; otherwise repeat ``val`` ``depth`` times in a tuple."""
    if isinstance(val, tuple):
        return val
    return tuple(val for _ in range(depth))
# ---- positional embedding ----

def apply_rotary_emb(q, k, pos_emb):
    """Apply rotary embedding ``pos_emb = (sin, cos)`` to the leading features of q and k.

    Only the first ``sin.shape[-1]`` features of the last axis are rotated; the
    remaining features pass through unchanged. Returns the new ``(q, k)``.
    """
    sin, cos = pos_emb
    rot_dim = sin.shape[-1]

    def rotate(t):
        rot, rest = t[..., :rot_dim], t[..., rot_dim:]
        rot = (rot * cos) + (rotate_every_two(rot) * sin)
        return torch.cat((rot, rest), dim = -1)

    return rotate(q), rotate(k)
def rotate_every_two(x):
    """Map each consecutive feature pair (a, b) on the last axis to (-b, a)."""
    pairs = rearrange(x, '... (d j) -> ... d j', j = 2)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim = -1)
    return rearrange(rotated, '... d j -> ... (d j)')
class AxialRotaryEmbedding(nn.Module):
    """2D (axial) rotary position embedding for a feature map's last two axes.

    forward(x) returns ``(sin, cos)``, each of shape (1, h, w, 4 * (dim // 4)) for
    ``x`` of shape (..., h, w). The first half of the features encodes the row
    coordinate and the second half the column coordinate. Each frequency is
    duplicated so it lines up with the feature pairs rotated by ``rotate_every_two``.
    """

    def __init__(self, dim, max_freq = 10):
        super().__init__()
        self.dim = dim
        # dim // 4 frequencies per axis; two axes x pair duplication brings it back to ~dim
        scales = torch.linspace(1., max_freq / 2, self.dim // 4)
        self.register_buffer('scales', scales)

    def forward(self, x):
        device = x.device
        h, w = x.shape[-2:]

        # normalized coordinates in [-1, 1] along each spatial axis: (h, 1) and (w, 1)
        seq_x = torch.linspace(-1., 1., steps = h, device = device).unsqueeze(-1)
        seq_y = torch.linspace(-1., 1., steps = w, device = device).unsqueeze(-1)

        # (1, dim // 4), matched to x's dtype/device; shared by both axes (computed once)
        scales = self.scales[None, :].to(x)
        seq_x = seq_x * scales * pi
        seq_y = seq_y * scales * pi

        # row angles broadcast across columns, column angles across rows: (h, w, dim // 4)
        x_sinu = seq_x[:, None, :].expand(h, w, -1)
        y_sinu = seq_y[None, :, :].expand(h, w, -1)

        sin = torch.cat((x_sinu.sin(), y_sinu.sin()), dim = -1)
        cos = torch.cat((x_sinu.cos(), y_sinu.cos()), dim = -1)

        # duplicate each frequency for its rotated pair and add a leading broadcast axis
        sin, cos = (t.repeat_interleave(2, dim = -1).unsqueeze(0) for t in (sin, cos))
        return sin, cos
class TimeSinuPosEmb(nn.Module):
    """Sinusoidal embedding of a 1D tensor of time values: [sin(t * f_k), cos(t * f_k)].

    Frequencies f_k form a geometric ladder from 1 down to 1/10000 over dim // 2 steps.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(-step * torch.arange(half, device = x.device))
        angles = einsum('i, j -> i j', x, freqs)
        return torch.cat((angles.sin(), angles.cos()), dim = -1)
# ---- helper modules ----

class LayerNorm(nn.Module):
    """Layer norm across channels (axis 1) with a learned per-channel gain and bias.

    Parameters are shaped (1, dim, 1, 1), so inputs are expected with channels on
    axis 1 followed by two spatial axes.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        # the closing parenthesis was missing here, which made the module a SyntaxError
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        # population (biased) standard deviation over channels
        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
        mean = torch.mean(x, dim = 1, keepdim = True)
        # eps is added to the std (not the variance), unlike nn.LayerNorm
        return (x - mean) / (std + self.eps) * self.g + self.b
class PreNorm(nn.Module):
    """Normalize the input with ``LayerNorm`` and then apply ``fn`` (kwargs forwarded)."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
class Attention(nn.Module):
    """Windowed multi-head attention over a feature map, optionally attending to a skip input.

    Queries come from ``x``. Keys and values come from ``x`` and, when given, ``skip``
    (concatenated along the batch axis, later folded into each window's key sequence).
    Attention is computed independently inside non-overlapping ``window_size`` x
    ``window_size`` windows.
    """
    def __init__(self, dim, dim_head = 64, heads = 8, window_size = 16):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.window_size = window_size
        inner_dim = dim_head * heads
        # 1x1 convolutions act as per-position linear projections
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x, skip = None, time_emb = None, pos_emb = None):
        """Args:
            x: feature map (b, c, H, W); H and W must be divisible by window_size.
            skip: optional feature map; concatenated with x on the batch axis, so it
                presumably has the same shape as x (TODO confirm with callers).
            time_emb: optional (b, c) embedding added to every spatial position of x.
            pos_emb: optional (sin, cos) rotary embedding applied to q and k.
        """
        h, w, b = self.heads, self.window_size, x.shape[0]
        if exists(time_emb):
            time_emb = rearrange(time_emb, 'b c -> b c () ()')
            x = x + time_emb
        q = self.to_q(x)
        kv_input = x
        # stack skip after x along the batch axis so one conv projects both
        if exists(skip):
            kv_input = torch.cat((kv_input, skip), dim = 0)
        k, v = self.to_kv(kv_input).chunk(2, dim = 1)
        # split heads out of channels and move channels last: (batch * heads, H, W, c)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) x y c', h = h), (q, k, v))
        if exists(pos_emb):
            q, k = apply_rotary_emb(q, k, pos_emb)
        # partition into non-overlapping w x w windows; each window becomes its own batch entry
        q, k, v = map(lambda t: rearrange(t, 'b (x w1) (y w2) c -> (b x y) (w1 w2) c', w1 = w, w2 = w), (q, k, v))
        # undo the batch concatenation (r = 2 is the outermost batch factor): each window's
        # key/value sequence becomes the x-derived tokens followed by the skip-derived tokens
        if exists(skip):
            k, v = map(lambda t: rearrange(t, '(r b) n d -> b (r n) d', r = 2), (k, v))
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        # merge windows and heads back into a (b, heads * c, H, W) feature map
        out = rearrange(out, '(b h x y) (w1 w2) c -> b (h c) (x w1) (y w2)', b = b, h = h, y = x.shape[-1] // w, w1 = w, w2 = w)
        return self.to_out(out)
class FeedForward(nn.Module):
    """Convolutional feed-forward block: 1x1 expand, optional time embedding, 3x3 conv, GELU, 1x1 project."""

    def __init__(self, dim, mult = 4):
        super().__init__()
        hidden_dim = dim * mult
        self.project_in = nn.Conv2d(dim, hidden_dim, 1)
        self.project_out = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding = 1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1)
        )

    def forward(self, x, time_emb = None):
        hidden = self.project_in(x)
        if exists(time_emb):
            # (b, hidden_dim) embedding broadcast over every spatial position
            hidden = hidden + rearrange(time_emb, 'b c -> b c () ()')
        return self.project_out(hidden)
class Block(nn.Module):
    """A stack of ``depth`` pre-norm (windowed attention, feed-forward) residual pairs.

    With ``time_emb_dim`` set, separate GELU+Linear heads project the time embedding to
    the attention width (dim) and the feed-forward hidden width (dim * ff_mult).
    With ``rotary_emb``, an axial rotary embedding is computed from the input and shared
    by every attention layer.
    """

    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        window_size = 16,
        time_emb_dim = None,
        rotary_emb = True
    ):
        super().__init__()
        has_time = exists(time_emb_dim)
        # module creation order (and therefore parameter init order) matches the original
        self.attn_time_emb = nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim)) if has_time else None
        self.ff_time_emb = nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim * ff_mult)) if has_time else None
        self.pos_emb = AxialRotaryEmbedding(dim_head) if rotary_emb else None

        self.layers = List([])
        for _ in range(depth):
            attn_layer = PreNorm(dim, Attention(dim, dim_head = dim_head, heads = heads, window_size = window_size))
            ff_layer = PreNorm(dim, FeedForward(dim, mult = ff_mult))
            self.layers.append(List([attn_layer, ff_layer]))

    def forward(self, x, skip = None, time = None):
        attn_time_emb = ff_time_emb = None
        if exists(time):
            assert exists(self.attn_time_emb) and exists(self.ff_time_emb), 'time_emb_dim must be given on init if you are conditioning based on time'
            attn_time_emb = self.attn_time_emb(time)
            ff_time_emb = self.ff_time_emb(time)

        pos_emb = self.pos_emb(x) if exists(self.pos_emb) else None

        for attn, ff in self.layers:
            x = attn(x, skip = skip, time_emb = attn_time_emb, pos_emb = pos_emb) + x
            x = ff(x, time_emb = ff_time_emb) + x
        return x
# U-shaped transformer: one attention Block per resolution on the way down (followed by a
# strided conv that halves H/W and doubles channels), a middle Block, and mirrored Blocks on
# the way up (after a transposed conv) that receive the matching encoder features as `skip`.
class Uformer(nn.Module):
    def __init__(
        self,
        dim = 64,
        channels = 3,
        stages = 4,
        num_blocks = 2,
        dim_head = 64,
        window_size = 16,
        heads = 8,
        ff_mult = 4,
        time_emb = False,
        input_channels = None,
        output_channels = None
    ):
        super().__init__()
        # input/output channel counts fall back to `channels`
        input_channels = default(input_channels, channels)
        output_channels = default(output_channels, channels)

        # optional time conditioning: sinusoidal features -> MLP back to `dim`
        time_emb_dim = dim if time_emb else None
        self.to_time_emb = nn.Sequential(
            TimeSinuPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        ) if time_emb else None

        # image channels <-> model width
        self.project_in = nn.Sequential(
            nn.Conv2d(input_channels, dim, 3, padding = 1),
            nn.GELU()
        )
        self.project_out = nn.Sequential(
            nn.Conv2d(dim, output_channels, 3, padding = 1),
        )

        self.downs = List([])
        self.ups = List([])

        # cast_tuple (defined elsewhere) presumably broadcasts a scalar to `stages` entries
        per_stage = map(partial(cast_tuple, depth = stages), (heads, window_size, dim_head, num_blocks))

        for stage, stage_heads, stage_window, stage_dim_head, stage_blocks in zip(range(stages), *per_stage):
            block_kwargs = dict(
                depth = stage_blocks,
                dim_head = stage_dim_head,
                heads = stage_heads,
                ff_mult = ff_mult,
                window_size = stage_window,
                time_emb_dim = time_emb_dim
            )

            # encoder stage: Block at this width, then downsample (H/2, W/2, 2x channels)
            self.downs.append(List([
                Block(dim, **block_kwargs),
                nn.Conv2d(dim, dim * 2, 4, stride = 2, padding = 1)
            ]))
            # decoder stage (used in reverse): upsample (2H, 2W, channels/2), then Block
            self.ups.append(List([
                nn.ConvTranspose2d(dim * 2, dim, 2, stride = 2),
                Block(dim, **block_kwargs)
            ]))

            dim *= 2

            # bottleneck: the deepest stage's settings at the doubled width
            if stage == stages - 1:
                self.mid = Block(dim = dim, **block_kwargs)

    def forward(
        self,
        x,
        time = None
    ):
        """Map an image batch to the output channels; `time` optionally conditions every Block."""
        if exists(time):
            assert exists(self.to_time_emb), 'time_emb must be set to true to condition on time'
            # match x's dtype/device before embedding
            time = self.to_time_emb(time.to(x))

        x = self.project_in(x)

        skips = []
        for block, downsample in self.downs:
            x = block(x, time = time)
            skips.append(x)
            x = downsample(x)

        x = self.mid(x, time = time)

        for (upsample, block), skip in zip(reversed(self.ups), reversed(skips)):
            x = block(upsample(x), skip = skip, time = time)

        return self.project_out(x)
.\lucidrains\uformer-pytorch\uformer_pytorch\__init__.py
# 从uformer_pytorch.uformer_pytorch模块中导入Uformer类
from uformer_pytorch.uformer_pytorch import Uformer
UNet Stylegan2
An implementation of Stylegan2 with a UNet discriminator. This repository works largely the same way as Stylegan2 Pytorch: simply replace every `stylegan2_pytorch` command with `unet_stylegan2`.
Update: Results have been very good. I will need to investigate combining this with a few other techniques, and will then write up full instructions for use.
Install
$ pip install unet-stylegan2
Usage
$ unet_stylegan2 --data ./path/to/data
Citations
@misc{karras2019analyzing,
title={Analyzing and Improving the Image Quality of StyleGAN},
author={Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
year={2019},
eprint={1912.04958},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
@misc{schnfeld2020unet,
title={A U-Net Based Discriminator for Generative Adversarial Networks},
author={Edgar Schönfeld and Bernt Schiele and Anna Khoreva},
year={2020},
eprint={2002.12655},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
.\lucidrains\unet-stylegan2\setup.py
# Packaging metadata for unet_stylegan2 (also installs the `unet_stylegan2` CLI script).
from setuptools import setup, find_packages

setup(
    name = 'unet_stylegan2',
    packages = find_packages(),
    scripts=['bin/unet_stylegan2'],
    version = '0.5.1',
    license='GPLv3+',
    description = 'StyleGan2 with UNet Discriminator, in Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/unet-stylegan2',
    keywords = ['generative adversarial networks', 'artificial intelligence'],
    install_requires=[
        'fire',
        'numpy',
        'retry',
        'tqdm',
        'torch',
        'torchvision',
        'pillow',
        'linear_attention_transformer>=0.12.1'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        # must agree with license='GPLv3+' above (this previously advertised MIT)
        'License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\unet-stylegan2\unet_stylegan2\diff_augment.py
# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块
import torch.nn.functional as F
def DiffAugment(x, types=()):
    """Apply every augmentation group named in ``types`` to ``x``, in order.

    Args:
        x: image batch; the augmentation functions index it as
            (batch, channels, height, width).
        types: iterable of keys into AUGMENT_FNS ('color', 'translation', 'cutout').
            Defaults to an empty tuple (no augmentation) instead of a shared
            mutable list.

    Returns:
        The augmented tensor in standard contiguous memory format.
    """
    for p in types:
        for f in AUGMENT_FNS[p]:
            x = f(x)
    return x.contiguous(memory_format=torch.contiguous_format)
def rand_brightness(x):
    """Add one random offset per sample, drawn uniformly from [-0.5, 0.5)."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset
def rand_saturation(x):
    """Scale each pixel's deviation from its cross-channel mean by a per-sample factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean
def rand_contrast(x):
    """Scale each sample's deviation from its overall mean by a per-sample factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean
def rand_translation(x, ratio=0.125):
    """Shift each sample by a random integer offset of up to ``ratio`` of its height/width.

    Pixels shifted in from outside the image are zero (taken from a one-pixel zero border).
    """
    n, h, w = x.size(0), x.size(2), x.size(3)
    max_dx, max_dy = int(h * ratio + 0.5), int(w * ratio + 0.5)
    dx = torch.randint(-max_dx, max_dx + 1, size=[n, 1, 1], device=x.device)
    dy = torch.randint(-max_dy, max_dy + 1, size=[n, 1, 1], device=x.device)

    idx_b, idx_x, idx_y = torch.meshgrid(
        torch.arange(n, dtype=torch.long, device=x.device),
        torch.arange(h, dtype=torch.long, device=x.device),
        torch.arange(w, dtype=torch.long, device=x.device),
    )
    # +1 accounts for the zero border added below; clamping lands on that border
    idx_x = torch.clamp(idx_x + dx + 1, 0, h + 1)
    idx_y = torch.clamp(idx_y + dy + 1, 0, w + 1)

    padded = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    channels_last = padded.permute(0, 2, 3, 1).contiguous()
    shifted = channels_last[idx_b, idx_x, idx_y]
    return shifted.permute(0, 3, 1, 2).contiguous(memory_format=torch.contiguous_format)
def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangle per sample (``ratio`` of height by ``ratio`` of width).

    The rectangle is placed around a random offset and clipped at the image borders.
    """
    n, h, w = x.size(0), x.size(2), x.size(3)
    cut_h, cut_w = int(h * ratio + 0.5), int(w * ratio + 0.5)
    # the offset range is one larger when the cut size is even
    offset_x = torch.randint(0, h + (1 - cut_h % 2), size=[n, 1, 1], device=x.device)
    offset_y = torch.randint(0, w + (1 - cut_w % 2), size=[n, 1, 1], device=x.device)

    idx_b, idx_x, idx_y = torch.meshgrid(
        torch.arange(n, dtype=torch.long, device=x.device),
        torch.arange(cut_h, dtype=torch.long, device=x.device),
        torch.arange(cut_w, dtype=torch.long, device=x.device),
    )
    idx_x = torch.clamp(idx_x + offset_x - cut_h // 2, min=0, max=h - 1)
    idx_y = torch.clamp(idx_y + offset_y - cut_w // 2, min=0, max=w - 1)

    keep = torch.ones(n, h, w, dtype=x.dtype, device=x.device)
    keep[idx_b, idx_x, idx_y] = 0
    return x * keep.unsqueeze(1)
# Registry of augmentation groups accepted by DiffAugment; each group's functions are
# applied in list order.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    translation=[rand_translation],
    cutout=[rand_cutout],
)
.\lucidrains\unet-stylegan2\unet_stylegan2\unet_stylegan2.py
# 导入必要的库
import os
import sys
import math
import fire
import json
from tqdm import tqdm
from math import floor, log2
from random import random
from shutil import rmtree
from functools import partial
import multiprocessing
import numpy as np
import torch
from torch import nn
from torch.utils import data
import torch.nn.functional as F
from torch.optim import Adam
from torch.autograd import grad as torch_grad
import torchvision
from torchvision import transforms
from linear_attention_transformer import ImageLinearAttention
from PIL import Image
from pathlib import Path
# apex is optional: it is only needed for fp16 (mixed-precision) training.
# Any import failure simply disables it; `except Exception` (not a bare except) so that
# KeyboardInterrupt / SystemExit during the import are not swallowed.
try:
    from apex import amp
    APEX_AVAILABLE = True
except Exception:
    APEX_AVAILABLE = False
# Fail fast at import time: this module moves models and tensors to the GPU with .cuda().
assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
# CPU core count; the default DataLoader worker count in Trainer.set_data_src
num_cores = multiprocessing.cpu_count()

# constants

# image file extensions collected (recursively) by Dataset
EXTS = 'jpg jpeg png webp'.split()

# small numerical epsilon
EPS = 1e-8
# helper classes

class NanException(Exception):
    """Raised by raise_if_nan when a (loss) tensor is NaN."""
class EMA():
    """Exponential moving average helper with decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``; ``new`` itself when there is no ``old`` yet."""
        return new if old is None else old * self.beta + (1 - self.beta) * new
class RandomApply(nn.Module):
    """Call ``fn`` with probability ``prob``; otherwise call ``fn_else`` (identity by default)."""

    def __init__(self, prob, fn, fn_else = lambda x: x):
        super().__init__()
        self.fn = fn
        self.fn_else = fn_else
        self.prob = prob

    def forward(self, x):
        if random() < self.prob:
            return self.fn(x)
        return self.fn_else(x)
class Residual(nn.Module):
    """Residual wrapper: ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x
class Flatten(nn.Module):
    """Flatten every dimension from ``index`` onward into one."""

    def __init__(self, index):
        super().__init__()
        self.index = index

    def forward(self, x):
        return torch.flatten(x, self.index)
class Rezero(nn.Module):
    """Gate ``fn``'s output with a learnable scalar initialised to zero (ReZero)."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        gate = self.g
        return self.fn(x) * gate
def attn_and_ff(chan):
    """Build a residual, rezero-gated linear-attention layer followed by a residual,
    rezero-gated 1x1-conv feed-forward (chan -> 2 * chan -> chan) for ``chan``-channel maps."""
    attention = Residual(Rezero(ImageLinearAttention(chan, norm_queries = True)))
    feed_forward = Residual(Rezero(nn.Sequential(
        nn.Conv2d(chan, chan * 2, 1),
        leaky_relu(),
        nn.Conv2d(chan * 2, chan, 1)
    )))
    return nn.Sequential(attention, feed_forward)
# helper functions

def default(value, d):
    """Return ``value`` unless it is None, in which case return ``d``."""
    if value is None:
        return d
    return value
def cycle(iterable):
    """Yield the items of ``iterable`` forever, iterating it afresh on every pass."""
    while True:
        for item in iterable:
            yield item
def cast_list(el):
    """Return ``el`` unchanged if it is a list, otherwise wrap it in a one-element list."""
    return [el] if not isinstance(el, list) else el
def is_empty(t):
    """True for a tensor with no elements or for ``None``; False for anything else."""
    if not isinstance(t, torch.Tensor):
        return t is None
    return t.nelement() == 0
def raise_if_nan(t):
    """Raise NanException if ``t`` is NaN (``t`` is expected to hold a single element)."""
    if not torch.isnan(t):
        return
    raise NanException
def loss_backwards(fp16, loss, optimizer, **kwargs):
    """Backpropagate ``loss``. With ``fp16`` the loss is scaled through apex amp first.

    Extra keyword arguments are passed to ``backward()``.
    """
    if not fp16:
        loss.backward(**kwargs)
        return
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward(**kwargs)
def gradient_penalty(images, outputs, weight = 10):
    """Gradient penalty ``weight * mean((||grad||_2 - 1) ** 2)`` over the batch.

    Args:
        images: input batch the outputs were computed from; must require grad.
        outputs: a tensor, or a sequence of tensors (e.g. the discriminator's
            encoder and decoder outputs), computed from ``images``.
        weight: scale applied to the penalty.

    Returns:
        Scalar tensor. ``grad`` is the per-sample gradient of the sum of all
        ``outputs`` with respect to ``images``.
    """
    # a single tensor would otherwise be iterated row by row below
    if isinstance(outputs, torch.Tensor):
        outputs = (outputs,)
    outputs = list(outputs)

    batch_size = images.shape[0]
    # ones on each output's own device instead of the default CUDA device
    grad_outputs = [torch.ones(t.size(), device=t.device) for t in outputs]
    gradients = torch_grad(outputs=outputs, inputs=images,
                           grad_outputs=grad_outputs,
                           create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.reshape(batch_size, -1)
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
def calc_pl_lengths(styles, images):
    """Per-sample path lengths for path-length regularisation.

    Projects ``images`` onto pixel noise scaled by 1/sqrt(H * W), takes the gradient of
    that scalar with respect to ``styles``, and returns
    ``sqrt(mean_over_dim1(sum_over_dim2(grad ** 2)))``. This presumes ``styles`` is
    (batch, num_layers, latent_dim) and ``images`` is (batch, C, H, W); confirm against
    callers.
    """
    num_pixels = images.shape[2] * images.shape[3]
    # sample on the default (CPU) generator as before, then move to the images' device
    pl_noise = torch.randn(images.shape).to(images.device) / math.sqrt(num_pixels)
    outputs = (images * pl_noise).sum()

    pl_grads = torch_grad(outputs=outputs, inputs=styles,
                          grad_outputs=torch.ones(outputs.shape, device=outputs.device),
                          create_graph=True, retain_graph=True, only_inputs=True)[0]

    return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()
def noise(n, latent_dim, device = 'cuda'):
    """Sample an (n, latent_dim) batch of standard-normal latent vectors.

    Values are drawn on the default (CPU) generator and then moved to ``device``.
    CUDA remains the default, as before, so seeded runs keep producing the same values.
    """
    return torch.randn(n, latent_dim).to(device)
def noise_list(n, layers, latent_dim):
    """Style definition with a single entry: one latent batch shared by all ``layers``."""
    sample = noise(n, latent_dim)
    return [(sample, layers)]
def mixed_list(n, layers, latent_dim):
    """Style-mixing definition: split ``layers`` at a random point and draw an
    independent latent batch for each part (two concatenated noise_list entries)."""
    split_at = int(torch.rand(()).numpy() * layers)
    head = noise_list(n, split_at, latent_dim)
    tail = noise_list(n, layers - split_at, latent_dim)
    return head + tail
def latent_to_w(style_vectorizer, latent_descr):
    """Map each ``(z, num_layers)`` pair through ``style_vectorizer``, keeping ``num_layers``."""
    w_descr = []
    for z, num_layers in latent_descr:
        w_descr.append((style_vectorizer(z), num_layers))
    return w_descr
def image_noise(n, im_size, device = 'cuda'):
    """Per-pixel noise of shape (n, im_size, im_size, 1), uniform in [0, 1), float32.

    Drawn on the default (CPU) generator, as the legacy ``torch.FloatTensor`` constructor
    did, and then moved to ``device`` (CUDA by default, as before).
    """
    return torch.empty(n, im_size, im_size, 1, dtype=torch.float32).uniform_(0., 1.).to(device)
def leaky_relu(p=0.2):
    """Return a new ``nn.LeakyReLU`` module with negative slope ``p``."""
    activation = nn.LeakyReLU(p)
    return activation
def evaluate_in_chunks(max_batch_size, model, *args):
    """Call ``model`` on aligned slices of at most ``max_batch_size`` rows of every
    argument and concatenate the outputs along dim 0 (a single chunk is returned as-is)."""
    per_arg_chunks = [arg.split(max_batch_size, dim=0) for arg in args]
    outputs = [model(*chunk) for chunk in zip(*per_arg_chunks)]
    return outputs[0] if len(outputs) == 1 else torch.cat(outputs, dim=0)
def styles_def_to_tensor(styles_def):
    """Expand each ``(w, n)`` pair to (batch, n, dim) and concatenate along dim 1 (layers)."""
    expanded = [w.unsqueeze(1).expand(-1, n, -1) for w, n in styles_def]
    return torch.cat(expanded, dim=1)
def set_requires_grad(model, bool):
    """Set ``requires_grad`` to the given flag on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = bool
def slerp(val, low, high):
    """Spherical linear interpolation between the rows of ``low`` and ``high``.

    Args:
        val: interpolation factor, a Python float or a 0-d tensor
            (generate_interpolation passes entries of ``torch.linspace``).
            Values outside [0, 1] extrapolate along the great circle.
        low, high: (batch, dim) tensors.

    Returns:
        (batch, dim) tensor. Rows whose directions are (anti)parallel, where
        sin(omega) == 0, fall back to linear interpolation instead of dividing by zero.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    # rounding can push the cosine slightly outside [-1, 1], where acos returns NaN
    cos_omega = (low_norm * high_norm).sum(1).clamp(-1., 1.)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)

    degenerate = so.abs() < 1e-6
    ones = torch.ones_like(so)
    safe_so = torch.where(degenerate, ones, so)
    low_coef = torch.where(degenerate, ones * (1.0 - val), torch.sin((1.0 - val) * omega) / safe_so)
    high_coef = torch.where(degenerate, ones * val, torch.sin(val * omega) / safe_so)

    return low_coef.unsqueeze(1) * low + high_coef.unsqueeze(1) * high
def warmup(start, end, max_steps, current_step):
    """Ramp linearly from ``start`` to ``end`` over ``max_steps``, then hold at ``end``."""
    if current_step <= max_steps:
        return (end - start) * (current_step / max_steps) + start
    return end
def log(t, eps = 1e-6):
    """Numerically safe natural log: ``log(t + eps)``."""
    shifted = t + eps
    return torch.log(shifted)
def cutmix_coordinates(height, width, alpha = 1.):
    """Sample a CutMix box.

    ``lam`` is drawn from Beta(alpha, alpha). The box is sqrt(1 - lam) of each side,
    centred uniformly at random and clipped to the image.

    Returns:
        ``(((y0, y1), (x0, x1)), lam)`` with integer, rounded bounds.
    """
    lam = np.random.beta(alpha, alpha)
    center_x = np.random.uniform(0, width)
    center_y = np.random.uniform(0, height)

    box_w = width * np.sqrt(1 - lam)
    box_h = height * np.sqrt(1 - lam)

    def to_int(v):
        return int(np.round(v))

    x0 = to_int(max(center_x - box_w / 2, 0))
    x1 = to_int(min(center_x + box_w / 2, width))
    y0 = to_int(max(center_y - box_h / 2, 0))
    y1 = to_int(min(center_y + box_h / 2, height))
    return ((y0, y1), (x0, x1)), lam
def cutmix(source, target, coors, alpha = 1.):
    """Return a copy of ``source`` whose box from ``coors`` is filled from ``target``.

    ``coors`` is the ``(((y0, y1), (x0, x1)), lam)`` tuple from cutmix_coordinates.
    The inputs are not modified; ``alpha`` is unused.
    """
    (rows, cols), _ = coors
    y0, y1 = rows
    x0, x1 = cols
    mixed = torch.clone(source)
    donor = torch.clone(target)
    mixed[:, :, y0:y1, x0:x1] = donor[:, :, y0:y1, x0:x1]
    return mixed
def mask_src_tgt(source, target, mask):
    """Blend: ``source`` where ``mask`` is 1 and ``target`` where it is 0 (linear in between)."""
    kept = source * mask
    return kept + (1 - mask) * target
# dataset

def convert_rgb_to_transparent(image):
    """Convert an RGB image to RGBA; images in any other mode are returned unchanged."""
    return image.convert('RGBA') if image.mode == 'RGB' else image
def convert_transparent_to_rgb(image):
    """Convert an RGBA image to RGB; images in any other mode are returned unchanged."""
    return image.convert('RGB') if image.mode == 'RGBA' else image
class expand_greyscale(object):
    """Transform that broadcasts a tensor's leading (channel) dim to ``num_channels``.

    Assumes a (C, H, W) tensor with C == 1 or C == num_channels.
    """

    def __init__(self, num_channels):
        self.num_channels = num_channels

    def __call__(self, tensor):
        target_shape = (self.num_channels, -1, -1)
        return tensor.expand(*target_shape)
def resize_to_minimum_size(min_size, image):
    """Upscale ``image`` (smaller edge -> ``min_size``) only when both sides are below
    ``min_size``; otherwise return it unchanged."""
    if max(*image.size) >= min_size:
        return image
    return torchvision.transforms.functional.resize(image, min_size)
# Folder-of-images dataset used by Trainer.set_data_src.
class Dataset(data.Dataset):
    def __init__(self, folder, image_size, transparent = False, aug_prob = 0.):
        """Collect every file under ``folder`` (recursively) whose extension is in EXTS.

        transparent: produce 4-channel RGBA tensors instead of 3-channel RGB.
        aug_prob: probability of a random resized crop instead of a center crop.
        """
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [path for ext in EXTS for path in Path(f'{folder}').glob(f'**/*.{ext}')]

        if transparent:
            convert_image_fn, num_channels = convert_rgb_to_transparent, 4
        else:
            convert_image_fn, num_channels = convert_transparent_to_rgb, 3

        self.transform = transforms.Compose([
            transforms.Lambda(convert_image_fn),
            transforms.Lambda(partial(resize_to_minimum_size, image_size)),
            transforms.Resize(image_size),
            RandomApply(aug_prob, transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)), transforms.CenterCrop(image_size)),
            transforms.ToTensor(),
            transforms.Lambda(expand_greyscale(num_channels))
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Open image ``index`` from disk and return it after the transform pipeline."""
        image = Image.open(self.paths[index])
        return self.transform(image)
# The RGB output head was mislabelled `GeneratorBlock`, and the real GeneratorBlock methods
# had lost their class header. GeneratorBlock.__init__ instantiates `RGBBlock`, so both
# classes are restored here under their proper names.
class RGBBlock(nn.Module):
    """Project features to an RGB (or RGBA) image with a style-modulated 1x1 conv, add the
    previous stage's image, and optionally upsample 2x for the next stage."""

    def __init__(self, latent_dim, input_channel, upsample, rgba=False):
        super().__init__()
        self.input_channel = input_channel
        # per-sample style that modulates the 1x1 projection
        self.to_style = nn.Linear(latent_dim, input_channel)

        out_filters = 3 if not rgba else 4
        # demod=False presumably disables weight demodulation for the RGB projection
        # (Conv2DMod is defined elsewhere)
        self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None

    def forward(self, x, prev_rgb, istyle):
        """Return this stage's image: conv(x, style) + prev_rgb (if any), optionally upsampled."""
        style = self.to_style(istyle)
        x = self.conv(x, style)

        # skip connection from the previous stage's image
        if prev_rgb is not None:
            x = x + prev_rgb

        if self.upsample is not None:
            x = self.upsample(x)

        return x


class GeneratorBlock(nn.Module):
    """One synthesis stage: optional 2x upsample, then two style-modulated 3x3 convs.

    Each conv is followed by noise injection (a learned Linear(1, filters) over the
    per-pixel noise) and LeakyReLU. An RGBBlock then produces this stage's image.
    """

    def __init__(self, latent_dim, input_channels, filters, upsample = True, upsample_rgb = True, rgba = False):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None

        self.to_style1 = nn.Linear(latent_dim, input_channels)
        self.to_noise1 = nn.Linear(1, filters)
        self.conv1 = Conv2DMod(input_channels, filters, 3)

        self.to_style2 = nn.Linear(latent_dim, filters)
        self.to_noise2 = nn.Linear(1, filters)
        self.conv2 = Conv2DMod(filters, filters, 3)

        self.activation = leaky_relu()
        self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba)

    def forward(self, x, prev_rgb, istyle, inoise):
        """Return ``(features, rgb)`` for this stage.

        inoise: per-pixel noise indexed as (batch, height, width, 1) and cropped here to
        the current resolution.
        """
        if self.upsample is not None:
            x = self.upsample(x)

        inoise = inoise[:, :x.shape[2], :x.shape[3], :]
        # NOTE(review): permute((0, 3, 2, 1)) gives (b, filters, W, H) while x is
        # (b, filters, H, W); this lines up only because the generator's maps are square
        # (4x4 start, 2x upsampling).
        noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1))
        noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1))

        style1 = self.to_style1(istyle)
        x = self.conv1(x, style1)
        x = self.activation(x + noise1)

        style2 = self.to_style2(istyle)
        x = self.conv2(x, style2)
        x = self.activation(x + noise2)

        rgb = self.to_rgb(x, prev_rgb, istyle)
        return x, rgb
def double_conv(chan_in, chan_out):
    """Two same-padded 3x3 convolutions (chan_in -> chan_out -> chan_out), each
    followed by leaky_relu()."""
    layers = []
    for conv_in in (chan_in, chan_out):
        layers += [nn.Conv2d(conv_in, chan_out, 3, padding=1), leaky_relu()]
    return nn.Sequential(*layers)
class DownBlock(nn.Module):
    """Discriminator encoder stage: double conv plus a 1x1 shortcut, optionally halving H/W.

    forward returns ``(output, features)``, where ``features`` is the pre-downsampling
    double-conv output kept for the U-Net decoder.
    """

    def __init__(self, input_channels, filters, downsample=True):
        super().__init__()
        shortcut_stride = 2 if downsample else 1
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=shortcut_stride)
        self.net = double_conv(input_channels, filters)
        self.down = nn.Conv2d(filters, filters, 3, padding=1, stride=2) if downsample else None

    def forward(self, x):
        shortcut = self.conv_res(x)
        features = self.net(x)
        out = features if self.down is None else self.down(features)
        return out + shortcut, features
class UpBlock(nn.Module):
    """Discriminator decoder stage.

    Upsamples 2x (bilinear), concatenates the encoder skip ``res`` along channels, and
    applies a double conv. A strided 1x1 transposed-conv shortcut is added at the end.
    """

    def __init__(self, input_channels, filters):
        super().__init__()
        # shortcut sees only x, which carries half of the concatenated channels
        self.conv_res = nn.ConvTranspose2d(input_channels // 2, filters, 1, stride=2)
        self.net = double_conv(input_channels, filters)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x, res):
        height, width = x.shape[-2:]
        # output_size pins the transposed conv to exactly twice the input size
        shortcut = self.conv_res(x, output_size=(height * 2, width * 2))
        merged = torch.cat((self.up(x), res), dim=1)
        return self.net(merged) + shortcut
# Synthesis network. It starts from 4x4 features (a learned constant, or a transposed conv
# of the mean style when no_const), then applies one (attention, GeneratorBlock) pair per
# resolution up to image_size. forward returns the accumulated RGB(A) image.
class Generator(nn.Module):
    def __init__(self, image_size, latent_dim, network_capacity=16, transparent=False, no_const=False, fmap_max=512):
        super().__init__()
        self.image_size = image_size
        self.latent_dim = latent_dim
        # one block per resolution: 4, 8, ..., image_size
        self.num_layers = int(log2(image_size) - 1)

        # channel widths, widest at 4x4, each capped at fmap_max
        widths = [min(fmap_max, network_capacity * (2 ** (i + 1))) for i in reversed(range(self.num_layers))]
        init_channels = widths[0]
        filters = [init_channels, *widths]

        self.no_const = no_const
        if no_const:
            self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False)
        else:
            self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4, 4)))

        self.initial_conv = nn.Conv2d(filters[0], filters[0], 3, padding=1)
        self.blocks = nn.ModuleList([])
        self.attns = nn.ModuleList([])

        last_index = self.num_layers - 1
        for ind, (in_chan, out_chan) in enumerate(zip(filters[:-1], filters[1:])):
            self.attns.append(attn_and_ff(in_chan))
            self.blocks.append(GeneratorBlock(
                latent_dim,
                in_chan,
                out_chan,
                upsample=ind != 0,                 # the first block stays at 4x4
                upsample_rgb=ind != last_index,    # the last RGB output is already full size
                rgba=transparent
            ))

    def forward(self, styles, input_noise):
        """styles: presumably (batch, num_layers, latent_dim) as built by styles_def_to_tensor;
        input_noise: per-pixel noise for every block (see image_noise)."""
        if self.no_const:
            avg_style = styles.mean(dim=1)[:, :, None, None]
            x = self.to_initial_block(avg_style)
        else:
            x = self.initial_block.expand(styles.shape[0], -1, -1, -1)

        x = self.initial_conv(x)

        rgb = None
        # layer-major iteration: one (batch, latent_dim) style per block
        for style, block, attn in zip(styles.transpose(0, 1), self.blocks, self.attns):
            if attn is not None:
                x = attn(x)
            x, rgb = block(x, rgb, style, input_noise)

        return rgb
class Discriminator(nn.Module):
    """U-Net discriminator.

    The encoder of DownBlocks (each followed by attention) ends in a global logit. The
    decoder of UpBlocks with encoder skips ends in a 1-channel per-pixel logit map.
    forward returns ``(encoder_logits, decoder_logits)``.
    """

    def __init__(self, image_size, network_capacity = 16, transparent = False, fmap_max = 512):
        super().__init__()
        # number of downsamplings; the encoder ends at image_size / 2 ** num_layers
        num_layers = int(log2(image_size) - 3)
        num_init_filters = 3 if not transparent else 4

        filters = [num_init_filters] + [(network_capacity) * (2 ** i) for i in range(num_layers + 1)]
        set_fmap_max = partial(min, fmap_max)
        filters = list(map(set_fmap_max, filters))
        filters[-1] = filters[-2]

        chan_in_out = list(zip(filters[:-1], filters[1:]))
        chan_in_out = list(map(list, chan_in_out))

        down_blocks = []
        attn_blocks = []
        for ind, (in_chan, out_chan) in enumerate(chan_in_out):
            is_not_last = ind != (len(chan_in_out) - 1)
            down_blocks.append(DownBlock(in_chan, out_chan, downsample = is_not_last))
            attn_blocks.append(attn_and_ff(out_chan))

        self.down_blocks = nn.ModuleList(down_blocks)
        self.attn_blocks = nn.ModuleList(attn_blocks)

        last_chan = filters[-1]

        self.to_logit = nn.Sequential(
            leaky_relu(),
            nn.AvgPool2d(image_size // (2 ** num_layers)),
            Flatten(1),
            nn.Linear(last_chan, 1)
        )

        self.conv = double_conv(last_chan, last_chan)

        # decoder mirrors the encoder (minus its last, non-downsampling stage)
        dec_chan_in_out = chan_in_out[:-1][::-1]
        self.up_blocks = nn.ModuleList(list(map(lambda c: UpBlock(c[1] * 2, c[0]), dec_chan_in_out)))

        # The last UpBlock outputs chan_in_out[0][0] == num_init_filters channels
        # (4 when transparent). It was hard-coded to 3, which broke transparent=True.
        # With no up blocks, the conv features (last_chan) reach conv_out directly.
        dec_out_chan = dec_chan_in_out[-1][0] if dec_chan_in_out else last_chan
        self.conv_out = nn.Conv2d(dec_out_chan, 1, 1)

    def forward(self, x):
        residuals = []

        for down_block, attn_block in zip(self.down_blocks, self.attn_blocks):
            x, unet_res = down_block(x)
            residuals.append(unet_res)

            if attn_block is not None:
                x = attn_block(x)

        x = self.conv(x) + x
        enc_out = self.to_logit(x)

        # deepest skip is the bottleneck's own features, so it is not reused
        for up_block, res in zip(self.up_blocks, residuals[:-1][::-1]):
            x = up_block(x, res)

        dec_out = self.conv_out(x)
        return enc_out.squeeze(), dec_out
class StyleGAN2(nn.Module):
    """Container for the models and optimizers of one GAN.

    Holds the style vectorizer S, generator G, U-Net discriminator D, and their
    exponential-moving-average copies SE / GE. D_aug wraps D. There is one Adam
    optimizer for (G + S) and one for D.
    """

    def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, steps = 1, lr = 1e-4, ttur_mult = 2, no_const = False, lr_mul = 0.1, aug_types = None):
        super().__init__()
        # fresh default list per instance instead of a shared mutable default argument
        if aug_types is None:
            aug_types = ['translation', 'cutout']

        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        self.S = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mul)
        self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent, no_const = no_const, fmap_max = fmap_max)
        self.D = Discriminator(image_size, network_capacity, transparent = transparent, fmap_max = fmap_max)

        # EMA copies. GE must share G's fmap_max: reset_parameter_averaging() loads G's
        # state dict into GE, which raises on mismatched channel sizes when fmap_max != 512.
        self.SE = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mul)
        self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent, no_const = no_const, fmap_max = fmap_max)

        # AugWrapper (defined elsewhere) presumably augments every input fed to D
        self.D_aug = AugWrapper(self.D, image_size, aug_types)

        # EMA copies are only updated via EMA(), never by the optimizers
        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)

        generator_params = list(self.G.parameters()) + list(self.S.parameters())
        self.G_opt = Adam(generator_params, lr = self.lr, betas=(0.5, 0.9))
        # D's learning rate is ttur_mult times G's
        self.D_opt = Adam(self.D.parameters(), lr = self.lr * ttur_mult, betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda()

        self.fp16 = fp16
        if fp16:
            (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize([self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level='O1')

    def _init_weights(self):
        """Kaiming-normal init for every Conv2d / Linear, then zero the generator's noise projections."""
        for m in self.modules():
            if type(m) in {nn.Conv2d, nn.Linear}:
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

        for block in self.G.blocks:
            nn.init.zeros_(block.to_noise1.weight)
            nn.init.zeros_(block.to_noise2.weight)
            nn.init.zeros_(block.to_noise1.bias)
            nn.init.zeros_(block.to_noise2.bias)

    def EMA(self):
        """Move SE / GE parameters towards S / G using ``self.ema_updater``."""
        def update_moving_average(ma_model, current_model):
            for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
                old_weight, up_weight = ma_params.data, current_params.data
                ma_params.data = self.ema_updater.update_average(old_weight, up_weight)

        update_moving_average(self.SE, self.S)
        update_moving_average(self.GE, self.G)

    def reset_parameter_averaging(self):
        """Copy S / G weights into the EMA copies SE / GE."""
        self.SE.load_state_dict(self.S.state_dict())
        self.GE.load_state_dict(self.G.state_dict())

    def forward(self, x):
        """Identity; the sub-networks are called directly by the trainer."""
        return x
class Trainer():
# 定义 Trainer 类
# Trainer state: paths, hyper-parameters and running losses; the GAN itself is built by init_GAN.
def __init__(self, name, results_dir, models_dir, image_size, network_capacity, transparent = False, batch_size = 4, mixed_prob = 0.9, gradient_accumulate_every=1, lr = 2e-4, ttur_mult = 2, num_workers = None, save_every = 1000, trunc_psi = 0.6, fp16 = False, no_const = False, aug_prob = 0., dataset_aug_prob = 0., cr_weight = 0.2, apply_pl_reg = False, lr_mul = 0.1, *args, **kwargs):
    """Store the training configuration and create this run's result/model folders.

    Extra positional and keyword arguments are kept in ``GAN_params``; ``init_GAN``
    forwards them to ``StyleGAN2``.
    """
    self.GAN_params = [args, kwargs]
    self.GAN = None

    # run name and output locations
    self.name = name
    self.results_dir = Path(results_dir)
    self.models_dir = Path(models_dir)
    self.config_path = self.models_dir / name / '.config.json'

    # architecture
    assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
    self.image_size = image_size
    self.network_capacity = network_capacity
    self.transparent = transparent
    self.no_const = no_const
    self.aug_prob = aug_prob

    # optimisation and data loading
    self.lr = lr
    self.ttur_mult = ttur_mult
    self.lr_mul = lr_mul
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.mixed_prob = mixed_prob

    self.save_every = save_every
    self.steps = 0

    # cached mean W vector for truncation (filled lazily by generate_truncated)
    self.av = None
    self.trunc_psi = trunc_psi

    self.apply_pl_reg = apply_pl_reg
    self.pl_mean = None

    self.gradient_accumulate_every = gradient_accumulate_every

    # fp16 training needs apex
    assert not fp16 or APEX_AVAILABLE, 'Apex is not available for you to use mixed precision training'
    self.fp16 = fp16

    # latest loss values, reported by print_log
    self.d_loss = 0
    self.g_loss = 0
    self.last_gp_loss = 0
    self.last_cr_loss = 0

    self.pl_length_ma = EMA(0.99)

    self.init_folders()

    self.loader = None
    self.dataset_aug_prob = dataset_aug_prob
    self.cr_weight = cr_weight
def init_GAN(self):
    """Build ``self.GAN`` from the trainer settings plus the extra arguments in ``GAN_params``."""
    args, kwargs = self.GAN_params
    self.GAN = StyleGAN2(
        *args,
        lr = self.lr,
        ttur_mult = self.ttur_mult,
        lr_mul = self.lr_mul,
        image_size = self.image_size,
        network_capacity = self.network_capacity,
        transparent = self.transparent,
        fp16 = self.fp16,
        no_const = self.no_const,
        **kwargs
    )
def write_config(self):
    """Serialise ``self.config()`` as JSON into ``self.config_path``."""
    payload = json.dumps(self.config())
    self.config_path.write_text(payload)
def load_config(self):
    """Restore architecture settings from ``config_path`` (or the current ones if the file
    is missing), then rebuild the GAN."""
    if self.config_path.exists():
        config = json.loads(self.config_path.read_text())
    else:
        config = self.config()

    self.image_size = config['image_size']
    self.network_capacity = config['network_capacity']
    self.transparent = config['transparent']
    # older configs may not record no_const
    self.no_const = config.pop('no_const', False)

    del self.GAN
    self.init_GAN()
def config(self):
    """Architecture settings written by write_config and restored by load_config."""
    return {
        'image_size': self.image_size,
        'network_capacity': self.network_capacity,
        'transparent': self.transparent,
        'no_const': self.no_const,
    }
def set_data_src(self, folder):
    """Create the Dataset for ``folder`` and an endless iterator over a shuffled DataLoader."""
    self.dataset = Dataset(folder, self.image_size, transparent = self.transparent, aug_prob = self.dataset_aug_prob)
    loader = data.DataLoader(self.dataset, num_workers = default(self.num_workers, num_cores), batch_size = self.batch_size, drop_last = True, shuffle=True, pin_memory=True)
    self.loader = cycle(loader)
# Render sample grids for checkpoint `num`; no gradients are needed.
@torch.no_grad()
def evaluate(self, num = 0, num_image_tiles = 8, trunc = 1.0):
    """Save three num_image_tiles x num_image_tiles grids under results_dir / name.

    The grids are:
    - ``{num}.{ext}``: from S / G;
    - ``{num}-ema.{ext}``: from the EMA models SE / GE;
    - ``{num}-mr.{ext}``: style mixing with SE / GE (row latents drive the first half
      of the layers, column latents the rest).

    ``trunc`` is unused; truncation uses ``self.trunc_psi``.
    """
    self.GAN.eval()
    ext = 'png' if self.transparent else 'jpg'
    rows = num_image_tiles
    gen = self.GAN.G
    latent_dim, image_size, num_layers = gen.latent_dim, gen.image_size, gen.num_layers
    out_dir = self.results_dir / self.name

    # shared latents and per-pixel noise for the first two grids
    latents = noise_list(rows ** 2, num_layers, latent_dim)
    n = image_noise(rows ** 2, image_size)

    # regular
    images = self.generate_truncated(self.GAN.S, self.GAN.G, latents, n, trunc_psi = self.trunc_psi)
    torchvision.utils.save_image(images, str(out_dir / f'{str(num)}.{ext}'), nrow=rows)

    # moving averages
    images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi = self.trunc_psi)
    torchvision.utils.save_image(images, str(out_dir / f'{str(num)}-ema.{ext}'), nrow=rows)

    # mixing regularities
    def tile(a, dim, n_tile):
        """Repeat each slice of ``a`` along ``dim`` n_tile times in a row (a0, a0, ..., a1, a1, ...)."""
        init_dim = a.size(dim)
        repeat_idx = [1] * a.dim()
        repeat_idx[dim] = n_tile
        a = a.repeat(*(repeat_idx))
        order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).cuda()
        return torch.index_select(a, dim, order_index)

    # (renamed from `nn`, which shadowed the torch.nn module name)
    base_latents = noise(rows, latent_dim)
    tmp1 = tile(base_latents, 0, rows)
    tmp2 = base_latents.repeat(rows, 1)
    tt = int(num_layers / 2)
    mixed_latents = [(tmp1, tt), (tmp2, num_layers - tt)]

    images = self.generate_truncated(self.GAN.SE, self.GAN.GE, mixed_latents, n, trunc_psi = self.trunc_psi)
    torchvision.utils.save_image(images, str(out_dir / f'{str(num)}-mr.{ext}'), nrow=rows)
@torch.no_grad()
def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8):
    """Generate images with the truncation trick applied in W space.

    Args:
        S: style network applied to each latent tensor in ``style``.
        G: generator, called as ``G(w_styles, noi)`` in chunks of ``self.batch_size``.
        style: list of ``(latent_tensor, num_layers)`` pairs.
        noi: per-pixel noise passed to the generator.
        trunc_psi: truncation factor; 1.0 leaves the mapped styles unchanged.
        num_image_tiles: unused; kept for interface compatibility.

    Returns:
        Generated images, clamped in place to [0, 1].
    """
    latent_dim = G.latent_dim

    # The mean W vector is estimated once, from 2000 samples through the S passed on
    # the first call, and cached on the trainer.
    if self.av is None:
        z = noise(2000, latent_dim)
        samples = evaluate_in_chunks(self.batch_size, S, z).cpu().numpy()
        self.av = np.mean(samples, axis = 0)
        self.av = np.expand_dims(self.av, axis = 0)

    # loop-invariant: convert and move the cached mean once, not once per style entry
    av_torch = torch.from_numpy(self.av).cuda()

    w_space = []
    for tensor, num_layers in style:
        tmp = S(tensor)
        tmp = trunc_psi * (tmp - av_torch) + av_torch
        w_space.append((tmp, num_layers))

    w_styles = styles_def_to_tensor(w_space)
    generated_images = evaluate_in_chunks(self.batch_size, G, w_styles, noi)
    return generated_images.clamp_(0., 1.)
@torch.no_grad()
def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, save_frames = False):
    """Render a 100-frame GIF that slerps between two random latent grids using SE / GE.

    The interpolation ratio runs from 0 to 8, so slerp continues past the second
    latent along the great circle. With ``save_frames`` every frame is also written
    to ``results_dir/name/{num}/``. ``trunc`` is unused; ``self.trunc_psi`` is used.
    """
    self.GAN.eval()
    ext = 'png' if self.transparent else 'jpg'
    rows = num_image_tiles

    gen = self.GAN.G
    latent_dim, image_size, num_layers = gen.latent_dim, gen.image_size, gen.num_layers

    # endpoints and per-pixel noise shared by all frames
    latents_low = noise(rows ** 2, latent_dim)
    latents_high = noise(rows ** 2, latent_dim)
    n = image_noise(rows ** 2, image_size)

    ratios = torch.linspace(0., 8., 100)
    to_pil = transforms.ToPILImage()

    frames = []
    for ratio in tqdm(ratios):
        # spherical interpolation of the latents
        interp_latents = slerp(ratio, latents_low, latents_high)
        generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, [(interp_latents, num_layers)], n, trunc_psi = self.trunc_psi)
        grid = torchvision.utils.make_grid(generated_images, nrow = rows)
        frames.append(to_pil(grid.cpu()))

    run_dir = self.results_dir / self.name
    frames[0].save(str(run_dir / f'{str(num)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True)

    if not save_frames:
        return

    folder_path = (run_dir / f'{str(num)}')
    folder_path.mkdir(parents=True, exist_ok=True)
    for ind, frame in enumerate(frames):
        frame.save(str(folder_path / f'{str(ind)}.{ext}'))
def print_log(self):
    """Print a one-line summary of the latest G, D, GP, PL and CR losses."""
    pl_mean = default(self.pl_mean, 0)
    fields = [
        f'G: {self.g_loss:.2f}',
        f'D: {self.d_loss:.2f}',
        f'GP: {self.last_gp_loss:.2f}',
        f'PL: {pl_mean:.2f}',
        f'CR: {self.last_cr_loss:.2f}',
    ]
    print(' | '.join(fields))
def model_name(self, num):
    """Return the checkpoint path for save index `num`, as a string."""
    checkpoint_file = f'model_{num}.pt'
    run_dir = self.models_dir / self.name
    return str(run_dir / checkpoint_file)
def init_folders(self):
    """Create this run's results and models directories (with parents) if missing."""
    for base_dir in (self.results_dir, self.models_dir):
        (base_dir / self.name).mkdir(parents=True, exist_ok=True)
def clear(self):
    """Delete this run's saved models, results and config, then recreate empty folders.

    All deletions ignore errors (e.g. a folder that does not exist yet).
    """
    # Use the configured directories. The previous hard-coded './models/<name>' and
    # './results/<name>' removed the wrong folders whenever models_dir/results_dir
    # were customised, which is inconsistent with init_folders() and model_name().
    rmtree(self.models_dir / self.name, ignore_errors = True)
    rmtree(self.results_dir / self.name, ignore_errors = True)
    # NOTE(review): rmtree only removes directories; if config_path is a file this
    # call silently does nothing (errors are ignored).
    rmtree(str(self.config_path), ignore_errors = True)
    self.init_folders()
def save(self, num):
    """Write checkpoint `num` (GAN state, plus AMP state when fp16 is on), then persist the config."""
    checkpoint = dict(GAN = self.GAN.state_dict())
    if self.GAN.fp16:
        # the AMP state is only stored when mixed precision is enabled
        checkpoint['amp'] = amp.state_dict()
    target_path = self.model_name(num)
    torch.save(checkpoint, target_path)
    self.write_config()
def load(self, num = -1):
    """Load checkpoint `num` into the GAN; `num = -1` picks the highest saved index.

    Reloads the config first. With `num = -1` and no ``model_*.pt`` files present,
    returns without loading anything. Sets ``self.steps = num * self.save_every``.
    """
    self.load_config()

    name = num
    if num == -1:
        checkpoint_dir = Path(self.models_dir / self.name)
        # file names look like model_<index>.pt; compare indices numerically
        saved_nums = sorted(int(path.stem.split('_')[1]) for path in checkpoint_dir.glob('model_*.pt'))
        if not saved_nums:
            return
        name = saved_nums[-1]
        print(f'continuing from previous epoch - {name}')

    self.steps = name * self.save_every

    load_data = torch.load(self.model_name(name))
    self.GAN.load_state_dict(load_data['GAN'])

    if self.GAN.fp16 and 'amp' in load_data:
        amp.load_state_dict(load_data['amp'])
.\lucidrains\unet-stylegan2\unet_stylegan2\__init__.py
# 从 unet_stylegan2 模块中导入 Trainer, StyleGAN2 和 NanException 类
from unet_stylegan2.unet_stylegan2 import Trainer, StyleGAN2, NanException
Uniformer - Pytorch
Implementation of Uniformer, a simple attention and 3d convolutional net that achieved SOTA in a number of video classification tasks
Install
$ pip install uniformer-pytorch
Usage
Uniformer-S
import torch
from uniformer_pytorch import Uniformer
model = Uniformer(
num_classes = 1000, # number of output classes
dims = (64, 128, 256, 512), # feature dimensions per stage (4 stages)
depths = (3, 4, 8, 3), # depth at each stage
mhsa_types = ('l', 'l', 'g', 'g') # aggregation type at each stage, 'l' stands for local, 'g' stands for global
)
video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, time, height, width)
logits = model(video) # (1, 1000)
Uniformer-B
import torch
from uniformer_pytorch import Uniformer
model = Uniformer(
num_classes = 1000,
depths = (5, 8, 20, 7)
)
Citations
@inproceedings{anonymous2022uniformer,
title = {UniFormer: Unified Transformer for Efficient Spatial-Temporal Representation Learning},
author = {Anonymous},
booktitle = {Submitted to The Tenth International Conference on Learning Representations },
year = {2022},
url = {https://openreview.net/forum?id=nBU_u6DLvoK},
note = {under review}
}
.\lucidrains\uniformer-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# Package metadata for uniformer-pytorch.
KEYWORDS = [
    'artificial intelligence',
    'attention mechanism',
    'video classification'
]

# Runtime dependencies.
INSTALL_REQUIRES = [
    'einops>=0.3',
    'torch>=1.6'
]

# Trove classifiers shown on PyPI.
CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name = 'uniformer-pytorch',
    packages = find_packages(),
    version = '0.0.4',
    license = 'MIT',
    description = 'Uniformer - Pytorch',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/uniformer-pytorch',
    keywords = KEYWORDS,
    install_requires = INSTALL_REQUIRES,
    classifiers = CLASSIFIERS,
)
.\lucidrains\uniformer-pytorch\uniformer_pytorch\uniformer_pytorch.py
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Reduce
# helpers
def exists(val):
    """Return True for any value other than None (falsy values such as 0 still count)."""
    return not (val is None)
# classes
class LayerNorm(nn.Module):
    """Layer norm over the channel axis (dim 1) of a 5D feature map.

    Uses biased (population) statistics, adds `eps` to the standard deviation
    (not the variance), then applies a learned per-channel gain `g` and bias `b`,
    both shaped ``(1, dim, 1, 1, 1)``.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1, 1))

    def forward(self, x):
        centered = x - x.mean(dim = 1, keepdim = True)
        std = x.var(dim = 1, unbiased = False, keepdim = True).sqrt()
        return centered / (std + self.eps) * self.g + self.b
def FeedForward(dim, mult = 4, dropout = 0.):
    """Pointwise feed-forward block built from 1x1x1 convolutions.

    LayerNorm -> expand channels by `mult` -> GELU -> Dropout -> project back to `dim`.
    """
    hidden_dim = dim * mult
    layers = [
        LayerNorm(dim),
        nn.Conv3d(dim, hidden_dim, 1),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Conv3d(hidden_dim, dim, 1),
    ]
    return nn.Sequential(*layers)
# MHRAs (multi-head relation aggregators)
class LocalMHRA(nn.Module):
    """Local multi-head relation aggregator.

    No attention matrix is computed. Values are projected with a 1x1x1 conv, and
    per-head spatio-temporal mixing is done by a grouped (one group per head)
    3D convolution of size `local_aggr_kernel`.
    """

    def __init__(
        self,
        dim,
        heads,
        dim_head = 64,
        local_aggr_kernel = 5
    ):
        super().__init__()
        self.heads = heads
        inner_dim = dim_head * heads

        # BatchNorm3d rather than LayerNorm for the local blocks
        self.norm = nn.BatchNorm3d(dim)
        # only values are needed; the "attention" is the convolution below
        self.to_v = nn.Conv3d(dim, inner_dim, 1, bias = False)
        # one k x k x k kernel per head (groups = heads), padded to keep the size
        self.rel_pos = nn.Conv3d(heads, heads, local_aggr_kernel, padding = local_aggr_kernel // 2, groups = heads)
        # merge all heads back to `dim` channels
        self.to_out = nn.Conv3d(inner_dim, dim, 1)

    def forward(self, x):
        heads = self.heads
        normed = self.norm(x)
        batch = normed.shape[0]

        values = self.to_v(normed)
        # split inner_dim as (dim_head, heads) and fold dim_head into the batch axis
        per_head = rearrange(values, 'b (c h) ... -> (b c) h ...', h = heads)
        aggregated = self.rel_pos(per_head)
        merged = rearrange(aggregated, '(b c) h ... -> b (c h) ...', b = batch)
        return self.to_out(merged)
class GlobalMHRA(nn.Module):
    """Global multi-head relation aggregator.

    Applies softmax self-attention over every position of the input feature map,
    with all axes after the channel axis flattened into one sequence. The output
    has the same shape as the input.

    Args:
        dim: number of input/output channels.
        heads: number of attention heads.
        dim_head: channels per head.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(
        self,
        dim,
        heads,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)
        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv1d(inner_dim, dim, 1)
        # `dropout` used to be accepted but silently ignored (so Uniformer's attn_dropout
        # had no effect). nn.Dropout has no parameters, so existing checkpoints still load.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.norm(x)
        shape, h = x.shape, self.heads

        # flatten every axis after the channel axis into a single sequence axis
        x = rearrange(x, 'b c ... -> b c (...)')
        q, k, v = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h d) n -> b h n d', h = h), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b (h d) n', h = h)
        out = self.to_out(out)
        # restore the original (b, c, ...) layout
        return out.view(*shape)
class Transformer(nn.Module):
    """A stack of `depth` Uniformer blocks.

    Each block applies three residual sub-layers in order:
    - a 3x3x3 conv (named `dpe` originally, presumably dynamic position encoding);
    - a local ('l') or global ('g') MHRA;
    - a pointwise feed-forward.

    Raises ValueError('unknown mhsa_type') for any other `mhsa_type` when depth > 0.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        mhsa_type = 'g',
        local_aggr_kernel = 5,
        dim_head = 64,
        ff_mult = 4,
        ff_dropout = 0.,
        attn_dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            # build attention first, then the conv and feed-forward (parameter init order matters)
            attn = self._build_attention(mhsa_type, dim, heads, dim_head, local_aggr_kernel, attn_dropout)
            position_conv = nn.Conv3d(dim, dim, 3, padding = 1)
            feed_forward = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
            self.layers.append(nn.ModuleList([position_conv, attn, feed_forward]))

    @staticmethod
    def _build_attention(mhsa_type, dim, heads, dim_head, local_aggr_kernel, attn_dropout):
        """Create the attention module selected by `mhsa_type`."""
        if mhsa_type == 'l':
            return LocalMHRA(dim, heads = heads, dim_head = dim_head, local_aggr_kernel = local_aggr_kernel)
        if mhsa_type == 'g':
            return GlobalMHRA(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)
        raise ValueError('unknown mhsa_type')

    def forward(self, x):
        for block in self.layers:
            for sublayer in block:
                x = sublayer(x) + x
        return x
class Uniformer(nn.Module):
    """Uniformer video classifier.

    Input video is ``(batch, channels, time, height, width)``. A strided 3D conv
    tokenizes it (time / 2, height and width / 4). Each stage is a Transformer
    followed by a 2x spatial downsampling conv, except the last stage. The
    logits come from a mean over (t, h, w), a LayerNorm, and a Linear layer.

    Args:
        num_classes: number of output classes.
        dims: channel width of each stage.
        depths: number of blocks in each stage.
        mhsa_types: 'l' (local) or 'g' (global) aggregation per stage (case-insensitive).
        local_aggr_kernel: kernel size of the local aggregator conv.
        channels: input video channels.
        ff_mult: feed-forward expansion factor.
        dim_head: channels per attention head; each stage uses
            ``dims[i] // dim_head`` heads.
        ff_dropout: feed-forward dropout.
        attn_dropout: attention dropout.
    """

    def __init__(
        self,
        *,
        num_classes,
        dims = (64, 128, 256, 512),
        depths = (3, 4, 8, 3),
        mhsa_types = ('l', 'l', 'g', 'g'),
        local_aggr_kernel = 5,
        channels = 3,
        ff_mult = 4,
        dim_head = 64,
        ff_dropout = 0.,
        attn_dropout = 0.
    ):
        super().__init__()
        init_dim, *_, last_dim = dims

        # tokenize the video: downsample time by 2 and space by 4
        self.to_tokens = nn.Conv3d(channels, init_dim, (3, 4, 4), stride = (2, 4, 4), padding = (1, 0, 0))

        mhsa_types = tuple(map(lambda t: t.lower(), mhsa_types))
        self.stages = nn.ModuleList([])

        for ind, (depth, mhsa_type) in enumerate(zip(depths, mhsa_types)):
            is_last = ind == len(depths) - 1
            stage_dim = dims[ind]
            heads = stage_dim // dim_head

            self.stages.append(nn.ModuleList([
                Transformer(
                    dim = stage_dim,
                    depth = depth,
                    heads = heads,
                    mhsa_type = mhsa_type,
                    # previously not forwarded: a non-default dim_head changed `heads`
                    # while the stage itself still used dim_head = 64, and
                    # local_aggr_kernel was ignored entirely
                    local_aggr_kernel = local_aggr_kernel,
                    dim_head = dim_head,
                    ff_mult = ff_mult,
                    ff_dropout = ff_dropout,
                    attn_dropout = attn_dropout
                ),
                nn.Sequential(
                    nn.Conv3d(stage_dim, dims[ind + 1], (1, 2, 2), stride = (1, 2, 2)),
                    LayerNorm(dims[ind + 1]),
                ) if not is_last else None
            ]))

        self.to_logits = nn.Sequential(
            Reduce('b c t h w -> b c', 'mean'),
            nn.LayerNorm(last_dim),
            nn.Linear(last_dim, num_classes)
        )

    def forward(self, video):
        x = self.to_tokens(video)

        for transformer, conv in self.stages:
            x = transformer(x)
            # the last stage has no downsampling conv
            if exists(conv):
                x = conv(x)

        return self.to_logits(x)
.\lucidrains\uniformer-pytorch\uniformer_pytorch\__init__.py
# 从 uniformer_pytorch 包中导入 Uniformer 类
from uniformer_pytorch.uniformer_pytorch import Uniformer
.\lucidrains\vector-quantize-pytorch\examples\autoencoder.py
# 导入所需的库
from tqdm.auto import trange
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from vector_quantize_pytorch import VectorQuantize
# Hyperparameters for the VQ autoencoder example.
seed = 1234  # torch RNG seed, set right before the model is built
lr = 3e-4  # AdamW learning rate
train_iter = 1000  # number of optimizer steps
num_codes = 256  # codebook size; also used to report codebook usage
device = "cpu" if not torch.cuda.is_available() else "cuda"
class SimpleVQAutoEncoder(nn.Module):
    """Tiny conv autoencoder for 1-channel images with a VectorQuantize bottleneck.

    Two conv + max-pool stages downsample by 4. VectorQuantize quantizes the
    32-channel feature map, and two upsample + conv stages restore the
    resolution. Keyword arguments are forwarded to VectorQuantize.

    forward returns (reconstruction clamped to [-1, 1], code indices, commitment loss).
    """

    def __init__(self, **vq_kwargs):
        super().__init__()
        encoder = [
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        quantizer = VectorQuantize(dim=32, accept_image_fmap=True, **vq_kwargs)
        decoder = [
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
        ]
        self.layers = nn.ModuleList(encoder + [quantizer] + decoder)

    def forward(self, x):
        for layer in self.layers:
            if not isinstance(layer, VectorQuantize):
                x = layer(x)
                continue
            x, indices, commit_loss = layer(x)
        return x.clamp(-1, 1), indices, commit_loss
def train(model, train_loader, train_iterations=1000, alpha=10):
    """Train `model` for `train_iterations` steps on batches cycled from `train_loader`.

    The loss is the L1 reconstruction error plus `alpha` times the commitment loss.
    Relies on the module-level `opt`, `device` and `num_codes` globals. The progress
    bar reports the share of codebook entries used by the current batch.
    """

    def iterate_dataset(data_loader):
        # Yield batches forever, restarting the loader whenever an epoch is exhausted.
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    # Create the batch stream once. Previously a new generator (and therefore a new
    # DataLoader iterator) was created every step, so training never walked through an
    # epoch and paid the iterator start-up cost on every iteration.
    batches = iterate_dataset(train_loader)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(batches)
        out, indices, cmt_loss = model(x)
        rec_loss = (out - x).abs().mean()
        (rec_loss + alpha * cmt_loss).backward()
        opt.step()
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"cmt loss: {cmt_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )
# ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1], matching the
# model's clamp(-1, 1) output range.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
fashion_mnist = datasets.FashionMNIST(
    root="~/data/fashion_mnist", train=True, download=True, transform=transform
)
train_dataset = DataLoader(fashion_mnist, batch_size=256, shuffle=True)

# Seed right before model construction so the weight init is reproducible.
print("baseline")
torch.random.manual_seed(seed)
model = SimpleVQAutoEncoder(codebook_size=num_codes).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
.\lucidrains\vector-quantize-pytorch\examples\autoencoder_fsq.py
# 导入所需的库
from tqdm.auto import trange
import math
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from vector_quantize_pytorch import FSQ
# Hyperparameters for the FSQ autoencoder example.
seed = 1234  # torch RNG seed, set right before the model is built
lr = 3e-4  # AdamW learning rate
train_iter = 1000  # number of optimizer steps
levels = [8, 6, 5]  # aims for ~2^8 codes; the actual count is 8 * 6 * 5 = 240
num_codes = math.prod(levels)  # number of distinct FSQ codes
device = "cpu" if not torch.cuda.is_available() else "cuda"
class SimpleFSQAutoEncoder(nn.Module):
    """Tiny conv autoencoder for 1-channel images with an FSQ bottleneck.

    The encoder downsamples by 4 and projects to ``len(levels)`` channels, which
    FSQ quantizes. The decoder maps back to 32 channels and upsamples to the
    input resolution.

    forward returns (reconstruction clamped to [-1, 1], code indices).
    """

    def __init__(self, levels: list[int]):
        super().__init__()
        code_dim = len(levels)
        encoder = [
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, code_dim, kernel_size=1),
        ]
        quantizer = FSQ(levels)
        decoder = [
            nn.Conv2d(code_dim, 32, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
        ]
        self.layers = nn.ModuleList(encoder + [quantizer] + decoder)

    def forward(self, x):
        for layer in self.layers:
            if not isinstance(layer, FSQ):
                x = layer(x)
                continue
            x, indices = layer(x)
        return x.clamp(-1, 1), indices
def train(model, train_loader, train_iterations=1000):
    """Train `model` for `train_iterations` steps on batches cycled from `train_loader`.

    Only the L1 reconstruction loss is optimized; no auxiliary loss is used. Relies on
    the module-level `opt`, `device` and `num_codes` globals. The progress bar reports
    the share of codes used by the current batch.
    """

    def iterate_dataset(data_loader):
        # Yield batches forever, restarting the loader whenever an epoch is exhausted.
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    # Create the batch stream once. Previously a new generator (and therefore a new
    # DataLoader iterator) was created every step, so training never walked through an
    # epoch and paid the iterator start-up cost on every iteration.
    batches = iterate_dataset(train_loader)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(batches)
        out, indices = model(x)
        rec_loss = (out - x).abs().mean()
        rec_loss.backward()
        opt.step()
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )
# ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1], matching the
# model's clamp(-1, 1) output range.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
fashion_mnist = datasets.FashionMNIST(
    root="~/data/fashion_mnist", train=True, download=True, transform=transform
)
train_dataset = DataLoader(fashion_mnist, batch_size=256, shuffle=True)

# Seed right before model construction so the weight init is reproducible.
print("baseline")
torch.random.manual_seed(seed)
model = SimpleFSQAutoEncoder(levels).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
.\lucidrains\vector-quantize-pytorch\examples\autoencoder_lfq.py
# 导入所需的库
from tqdm.auto import trange
from math import log2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 导入自定义的 LFQ 模块
from vector_quantize_pytorch import LFQ
# Hyperparameters for the LFQ autoencoder example.
seed = 1234  # torch RNG seed, set right before the model is built
lr = 3e-4  # AdamW learning rate
train_iter = 1000  # number of optimizer steps
codebook_size = 1 << 8  # 256 codes -> log2(256) = 8 quantized channels
entropy_loss_weight = 0.02  # forwarded to LFQ
diversity_gamma = 1.  # forwarded to LFQ
device = "cpu" if not torch.cuda.is_available() else "cuda"
class LFQAutoEncoder(nn.Module):
    """Tiny conv autoencoder for 1-channel images with a lookup-free (LFQ) bottleneck.

    `codebook_size` must be a power of two; the quantized feature map has
    ``log2(codebook_size)`` channels. Extra keyword arguments are forwarded to LFQ.

    forward returns (reconstruction clamped to [-1, 1], code indices, entropy aux loss).
    """

    def __init__(
        self,
        codebook_size,
        **vq_kwargs
    ):
        super().__init__()
        bits = log2(codebook_size)
        assert bits.is_integer()
        quantize_dim = int(bits)

        self.encode = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # non-affine normalization before projecting to the quantizer's channels
            nn.GroupNorm(4, 32, affine=False),
            nn.Conv2d(32, quantize_dim, kernel_size=1),
        )

        self.quantize = LFQ(dim=quantize_dim, **vq_kwargs)

        self.decode = nn.Sequential(
            nn.Conv2d(quantize_dim, 32, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        quantized, indices, entropy_aux_loss = self.quantize(self.encode(x))
        reconstruction = self.decode(quantized)
        return reconstruction.clamp(-1, 1), indices, entropy_aux_loss
def train(model, train_loader, train_iterations=1000):
    """Train `model` for `train_iterations` steps on batches cycled from `train_loader`.

    The loss is the L1 reconstruction error plus the model's entropy auxiliary loss.
    Relies on the module-level `opt`, `device` and `codebook_size` globals. The
    progress bar reports the share of codes used by the current batch.
    """

    def iterate_dataset(data_loader):
        # Yield batches forever, restarting the loader whenever an epoch is exhausted.
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    # Create the batch stream once. Previously a new generator (and therefore a new
    # DataLoader iterator) was created every step, so training never walked through an
    # epoch and paid the iterator start-up cost on every iteration.
    batches = iterate_dataset(train_loader)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(batches)
        out, indices, entropy_aux_loss = model(x)
        rec_loss = F.l1_loss(out, x)
        (rec_loss + entropy_aux_loss).backward()
        opt.step()
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"entropy aux loss: {entropy_aux_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / codebook_size * 100:.3f}"
        )
# ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1], matching the
# model's clamp(-1, 1) output range.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
fashion_mnist = datasets.FashionMNIST(
    root="~/data/fashion_mnist", train=True, download=True, transform=transform
)
train_dataset = DataLoader(fashion_mnist, batch_size=256, shuffle=True)

# Seed right before model construction so the weight init is reproducible.
print("baseline")
torch.random.manual_seed(seed)
model = LFQAutoEncoder(
    codebook_size = codebook_size,
    entropy_loss_weight = entropy_loss_weight,
    diversity_gamma = diversity_gamma
).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
Vector Quantization - Pytorch
A vector quantization library originally transcribed from Deepmind's tensorflow implementation, made conveniently into a package. It uses exponential moving averages to update the dictionary.
VQ has been successfully used by Deepmind and OpenAI for high quality generation of images (VQ-VAE-2) and music (Jukebox).
Install
$ pip install vector-quantize-pytorch
Usage
import torch
from vector_quantize_pytorch import VectorQuantize
vq = VectorQuantize(
dim = 256,
codebook_size = 512, # codebook size
decay = 0.8, # the exponential moving average decay, lower means the dictionary will change faster
commitment_weight = 1. # the weight on the commitment loss
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)
Residual VQ
This paper proposes using multiple vector quantizers to recursively quantize the residuals of the waveform. You can use this with the `ResidualVQ` class and one extra initialization parameter.
import torch
from vector_quantize_pytorch import ResidualVQ
residual_vq = ResidualVQ(
dim = 256,
num_quantizers = 8, # specify number of quantizers
codebook_size = 1024, # codebook size
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
# (1, 1024, 256), (1, 1024, 8), (1, 8)
# (batch, seq, dim), (batch, seq, quantizer), (batch, quantizer)
# if you need all the codes across the quantization layers, just pass return_all_codes = True
quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)
# *_, (8, 1, 1024, 256)
# all_codes - (quantizer, batch, seq, dim)
Furthermore, this paper uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes.
They make two modifications. The first is to share the codebook across all quantizers. The second is to stochastically sample the codes rather than always taking the closest match. You can use both of these features with two extra keyword arguments.
import torch
from vector_quantize_pytorch import ResidualVQ
residual_vq = ResidualVQ(
dim = 256,
num_quantizers = 8,
codebook_size = 1024,
stochastic_sample_codes = True,
sample_codebook_temp = 0.1, # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic
shared_codebook = True # whether to share the codebooks for all quantizers or not
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
# (1, 1024, 256), (8, 1, 1024), (8, 1)
# (batch, seq, dim), (quantizer, batch, seq), (quantizer, batch)
A recent paper further proposes to do residual VQ on groups of the feature dimension, showing equivalent results to Encodec while using far fewer codebooks. You can use it by importing GroupedResidualVQ
import torch
from vector_quantize_pytorch import GroupedResidualVQ
residual_vq = GroupedResidualVQ(
dim = 256,
num_quantizers = 8, # specify number of quantizers
groups = 2,
codebook_size = 1024, # codebook size
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)
# (batch, seq, dim), (groups, batch, seq, quantizer), (groups, batch, quantizer)
Initialization
The SoundStream paper proposes that the codebook should be initialized with the k-means centroids of the first batch. You can easily turn on this feature with one flag, `kmeans_init = True`, for either the `VectorQuantize` or the `ResidualVQ` class.
import torch
from vector_quantize_pytorch import ResidualVQ
residual_vq = ResidualVQ(
dim = 256,
codebook_size = 256,
num_quantizers = 4,
kmeans_init = True, # set to True
kmeans_iters = 10 # number of kmeans iterations to calculate the centroids for the codebook on init
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
Increasing codebook usage
This repository will contain a few techniques from various papers to combat "dead" codebook entries, which is a common problem when using vector quantizers.
Lower codebook dimension
The Improved VQGAN paper proposes keeping the codebook in a lower dimension. The encoder values are projected down before quantization and projected back up to the high dimension afterwards. You can set this with the `codebook_dim` hyperparameter.
import torch
from vector_quantize_pytorch import VectorQuantize
vq = VectorQuantize(
dim = 256,
codebook_size = 256,
codebook_dim = 16 # paper proposes setting this to 32 or as low as 8 to increase codebook usage
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
Cosine similarity
The Improved VQGAN paper also proposes to l2 normalize the codes and the encoded vectors, which boils down to using cosine similarity for the distance. They claim enforcing the vectors on a sphere leads to improvements in code usage and downstream reconstruction. You can turn this on by setting use_cosine_sim = True
import torch
from vector_quantize_pytorch import VectorQuantize
vq = VectorQuantize(
dim = 256,
codebook_size = 256,
use_cosine_sim = True # set this to True
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
Expiring stale codes
Finally, the SoundStream paper has a scheme where codes whose hit counts fall below a certain threshold are replaced with a randomly selected vector from the current batch. You can set this threshold with the `threshold_ema_dead_code` keyword.
import torch
from vector_quantize_pytorch import VectorQuantize
vq = VectorQuantize(
dim = 256,
codebook_size = 512,
threshold_ema_dead_code = 2 # should actively replace any codes that have an exponential moving average cluster size less than 2
)
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
Orthogonal regularization loss
VQ-VAE / VQ-GAN is quickly gaining popularity. A recent paper proposes that when using vector quantization on images, enforcing the codebook to be orthogonal leads to translation equivariance of the discretized codes, leading to large improvements in downstream text to image generation tasks.
You can use this feature by simply setting `orthogonal_reg_weight` to a value greater than `0`, in which case the orthogonal regularization will be added to the auxiliary loss output by the module.
import torch
from vector_quantize_pytorch import VectorQuantize
vq = VectorQuantize(
dim = 256,
codebook_size = 256,
accept_image_fmap = True, # set this true to be able to pass in an image feature map
orthogonal_reg_weight = 10, # in paper, they recommended a value of 10
orthogonal_reg_max_codes = 128, # this would randomly sample from the codebook for the orthogonal regularization loss, for limiting memory usage
orthogonal_reg_active_codes_only = False # set this to True if you have a very large codebook, and would only like to enforce the loss on the activated codes per batch
)
img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)
# loss now contains the orthogonal regularization loss with the weight as assigned
Multi-headed VQ
A number of papers have proposed variants of discrete latent representations with a multi-headed approach (multiple codes per feature). I have decided to offer one variant, where the same codebook is used to vector quantize across the input dimension `heads` times.
You can also use a more proven approach (memcodes) from the NWT paper.
import torch
from vector_quantize_pytorch import VectorQuantize
vq = VectorQuantize(
dim = 256,
codebook_dim = 32, # a number of papers have shown smaller codebook dimension to be acceptable
heads = 8, # number of heads to vector quantize, codebook shared across all heads
separate_codebook_per_head = True, # whether to have a separate codebook per head. False would mean 1 shared codebook
codebook_size = 8196,
accept_image_fmap = True
)
img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32, 8), (1,)
# indices shape - (batch, height, width, heads)
Random Projection Quantizer
This paper first proposed to use a random projection quantizer for masked speech modeling, where signals are projected with a randomly initialized matrix and then matched with a random initialized codebook. One therefore does not need to learn the quantizer. This technique was used by Google's Universal Speech Model to achieve SOTA for speech-to-text modeling.
USM further proposes using multiple codebooks, together with masked speech modeling using a multi-softmax objective. You can do this easily by setting `num_codebooks` to a value greater than 1.
import torch
from vector_quantize_pytorch import RandomProjectionQuantizer
quantizer = RandomProjectionQuantizer(
dim = 512, # input dimensions
num_codebooks = 16, # in USM, they used up to 16 for 5% gain
codebook_dim = 256, # codebook dimension
codebook_size = 1024 # codebook size
)
x = torch.randn(1, 1024, 512)
indices = quantizer(x) # (1, 1024, 16) - (batch, seq, num_codebooks)
This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it doesn't, please open an issue. You can override whether or not to synchronize codebooks by setting `sync_codebook = True | False`.
Finite Scalar Quantization
| | VQ | FSQ |
|---|---|---|
| Quantization | $\arg\min_c \lVert z - c \rVert$ | $\mathrm{round}(f(z))$ |
| Gradients | Straight Through Estimation (STE) | STE |
| Auxiliary Losses | Commitment, codebook, entropy loss, ... | N/A |
| Tricks | EMA on codebook, codebook splitting, projections, ... | N/A |
| Parameters | Codebook | N/A |
This work out of Google Deepmind aims to vastly simplify the way vector quantization is done for generative modeling, removing the need for commitment losses, EMA updating of the codebook, as well as tackle the issues with codebook collapse or insufficient utilization. They simply round each scalar into discrete levels with straight through gradients; the codes become uniform points in a hypercube.
Thanks goes out to @sekstini for porting over this implementation in record time!
import torch
from vector_quantize_pytorch import FSQ
levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
quantizer = FSQ(levels)
x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
xhat, indices = quantizer(x)
print(xhat.shape) # (1, 1024, 4) - (batch, seq, dim)
print(indices.shape) # (1, 1024) - (batch, seq)
assert xhat.shape == x.shape
assert torch.all(xhat == quantizer.indices_to_codes(indices))
An improvised Residual FSQ, for an attempt to improve audio encoding.
Credit goes to @sekstini for originally incepting the idea here
import torch
from vector_quantize_pytorch import ResidualFSQ
residual_fsq = ResidualFSQ(
dim = 256,
levels = [8, 5, 5, 3],
num_quantizers = 8
)
x = torch.randn(1, 1024, 256)
residual_fsq.eval()
quantized, indices = residual_fsq(x)
# (1, 1024, 256), (1, 1024, 8), (8)
# (batch, seq, dim), (batch, seq, quantizers), (quantizers)
quantized_out = residual_fsq.get_output_from_indices(indices)
# (8, 1, 1024, 8)
# (residual layers, batch, seq, quantizers)
assert torch.all(quantized == quantized_out)
Lookup Free Quantization
The research team behind MagViT has released new SOTA results for generative video modeling. A core change between v1 and v2 is a new type of quantization, look-up free quantization (LFQ), which eliminates the codebook and embedding lookup entirely.
This paper presents a simple LFQ quantizer that uses independent binary latents. Other implementations of LFQ exist. However, the team shows that MAGVIT-v2 with LFQ significantly improves on the ImageNet benchmark. The differences between LFQ and 2-level FSQ include entropy regularization as well as a retained commitment loss.
Developing a more advanced method of LFQ quantization without codebook lookup could revolutionize generative modeling.
You can use it simply as follows. It will be dogfooded in the MagViT2 PyTorch port.
import torch
from vector_quantize_pytorch import LFQ
# you can specify either dim or codebook_size
# if both specified, will be validated against each other
quantizer = LFQ(
codebook_size = 65536, # codebook size, must be a power of 2
dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined
entropy_loss_weight = 0.1, # how much weight to place on entropy loss
diversity_gamma = 1. # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
)
image_feats = torch.randn(1, 16, 32, 32)
quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature=100.) # you may want to experiment with temperature
# (1, 16, 32, 32), (1, 32, 32), (1,)
assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
You can also pass in video features as `(batch, feat, time, height, width)` or sequences as `(batch, seq, feat)`.
seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)
assert seq.shape == quantized.shape
video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)
assert video_feats.shape == quantized.shape
Or support multiple codebooks
import torch
from vector_quantize_pytorch import LFQ
quantizer = LFQ(
codebook_size = 4096,
dim = 16,
num_codebooks = 4 # 4 codebooks, total codebook dimension is log2(4096) * 4
)
image_feats = torch.randn(1, 16, 32, 32)
quantized, indices, entropy_aux_loss = quantizer(image_feats)
# (1, 16, 32, 32), (1, 32, 32, 4), (1,)
assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
An improvised Residual LFQ, to see if it can lead to an improvement for audio compression.
import torch
from vector_quantize_pytorch import ResidualLFQ
residual_lfq = ResidualLFQ(
dim = 256,
codebook_size = 256,
num_quantizers = 8
)
x = torch.randn(1, 1024, 256)
residual_lfq.eval()
quantized, indices, commit_loss = residual_lfq(x)
# (1, 1024, 256), (1, 1024, 8), (8)
# (batch, seq, dim), (batch, seq, quantizers), (quantizers)
quantized_out = residual_lfq.get_output_from_indices(indices)
# (8, 1, 1024, 8)
# (residual layers, batch, seq, quantizers)
assert torch.all(quantized == quantized_out)
Latent Quantization
Disentanglement is essential for representation learning as it promotes interpretability, generalization, improved learning, and robustness. It aligns with the goal of capturing meaningful and independent features of the data, facilitating more effective use of learned representations across various applications. For better disentanglement, the challenge is to disentangle underlying variations in a dataset without explicit ground truth information. This work introduces a key inductive bias aimed at encoding and decoding within an organized latent space. The strategy incorporated encompasses discretizing the latent space by assigning discrete code vectors through the utilization of an individual learnable scalar codebook for each dimension. This methodology enables their models to surpass robust prior methods effectively.
Be aware they had to use a very high weight decay for the results in this paper.
import torch
from vector_quantize_pytorch import LatentQuantize
# you can specify either dim or codebook_size
# if both specified, will be validated against each other
quantizer = LatentQuantize(
levels = [5, 5, 8], # number of levels per codebook dimension
dim = 16, # input dim
commitment_loss_weight=0.1,
quantization_loss_weight=0.1,
)
image_feats = torch.randn(1, 16, 32, 32)
quantized, indices, loss = quantizer(image_feats)
# (1, 16, 32, 32), (1, 32, 32), (1,)
assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
You can also pass in video features as (batch, feat, time, height, width)
or sequences as (batch, seq, feat)
seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)
assert seq.shape == quantized.shape
video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)
assert video_feats.shape == quantized.shape
Or support multiple codebooks
import torch
from vector_quantize_pytorch import LatentQuantize
levels = [4, 8, 16]
dim = 9
num_codebooks = 3
model = LatentQuantize(levels, dim, num_codebooks=num_codebooks)
input_tensor = torch.randn(2, 3, dim)
output_tensor, indices, loss = model(input_tensor)
assert output_tensor.shape == input_tensor.shape
assert indices.shape == (2, 3, num_codebooks)
assert loss.item() >= 0
Citations
@misc{oord2018neural,
title = {Neural Discrete Representation Learning},
author = {Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
year = {2018},
eprint = {1711.00937},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{zeghidour2021soundstream,
title = {SoundStream: An End-to-End Neural Audio Codec},
author = {Neil Zeghidour and Alejandro Luebs and Ahmed Omran and Jan Skoglund and Marco Tagliasacchi},
year = {2021},
eprint = {2107.03312},
archivePrefix = {arXiv},
primaryClass = {cs.SD}
}
@inproceedings{anonymous2022vectorquantized,
title = {Vector-quantized Image Modeling with Improved {VQGAN}},
author = {Anonymous},
booktitle = {Submitted to The Tenth International Conference on Learning Representations },
year = {2022},
url = {https://openreview.net/forum?id=pfNyExj7z2},
note = {under review}
}
@unknown{unknown,
author = {Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
year = {2022},
month = {03},
title = {Autoregressive Image Generation using Residual Quantization}
}
@article{Defossez2022HighFN,
title = {High Fidelity Neural Audio Compression},
author = {Alexandre Défossez and Jade Copet and Gabriel Synnaeve and Yossi Adi},
journal = {ArXiv},
year = {2022},
volume = {abs/2210.13438}
}
@inproceedings{Chiu2022SelfsupervisedLW,
title = {Self-supervised Learning with Random-projection Quantizer for Speech Recognition},
author = {Chung-Cheng Chiu and James Qin and Yu Zhang and Jiahui Yu and Yonghui Wu},
booktitle = {International Conference on Machine Learning},
year = {2022}
}
@inproceedings{Zhang2023GoogleUS,
title = {Google USM: Scaling Automatic Speech Recognition Beyond 100 Languages},
author = {Yu Zhang and Wei Han and James Qin and Yongqiang Wang and Ankur Bapna and Zhehuai Chen and Nanxin Chen and Bo Li and Vera Axelrod and Gary Wang and Zhong Meng and Ke Hu and Andrew Rosenberg and Rohit Prabhavalkar and Daniel S. Park and Parisa Haghani and Jason Riesa and Ginger Perng and Hagen Soltau and Trevor Strohman and Bhuvana Ramabhadran and Tara N. Sainath and Pedro J. Moreno and Chung-Cheng Chiu and Johan Schalkwyk and Françoise Beaufays and Yonghui Wu},
year = {2023}
}
@inproceedings{Shen2023NaturalSpeech2L,
title = {NaturalSpeech 2: Latent Diffusion Models are Natural and Zero-Shot Speech and Singing Synthesizers},
author = {Kai Shen and Zeqian Ju and Xu Tan and Yanqing Liu and Yichong Leng and Lei He and Tao Qin and Sheng Zhao and Jiang Bian},
year = {2023}
}
@inproceedings{Yang2023HiFiCodecGV,
title = {HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec},
author = {Dongchao Yang and Songxiang Liu and Rongjie Huang and Jinchuan Tian and Chao Weng and Yuexian Zou},
year = {2023}
}
@article{Liu2023BridgingDA,
title = {Bridging Discrete and Backpropagation: Straight-Through and Beyond},
author = {Liyuan Liu and Chengyu Dong and Xiaodong Liu and Bin Yu and Jianfeng Gao},
journal = {ArXiv},
year = {2023},
volume = {abs/2304.08612}
}
@inproceedings{huh2023improvedvqste,
title = {Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks},
author = {Huh, Minyoung and Cheung, Brian and Agrawal, Pulkit and Isola, Phillip},
booktitle = {International Conference on Machine Learning},
year = {2023},
organization = {PMLR}
}
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{shin2021translationequivariant,
title = {Translation-equivariant Image Quantizer for Bi-directional Image-Text Generation},
author = {Woncheol Shin and Gyubok Lee and Jiyoung Lee and Joonseok Lee and Edward Choi},
year = {2021},
eprint = {2112.00384},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@misc{mentzer2023finite,
title = {Finite Scalar Quantization: VQ-VAE Made Simple},
author = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
year = {2023},
eprint = {2309.15505},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@misc{yu2023language,
title = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
author = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
year = {2023},
eprint = {2310.05737},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@misc{hsu2023disentanglement,
title = {Disentanglement via Latent Quantization},
author = {Kyle Hsu and Will Dorrell and James C. R. Whittington and Jiajun Wu and Chelsea Finn},
year = {2023},
eprint = {2305.18378},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
.\lucidrains\vector-quantize-pytorch\setup.py
# packaging script: declares the distribution metadata for setuptools
from setuptools import setup, find_packages

setup(
    # distribution name
    name = 'vector_quantize_pytorch',
    # include every package found under the project root
    packages = find_packages(),
    version = '1.14.5',
    license='MIT',
    description = 'Vector Quantization - Pytorch',
    long_description_content_type = 'text/markdown',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    # project homepage: the repository is named `vector-quantize-pytorch`
    url = 'https://github.com/lucidrains/vector-quantize-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'pytorch',
        'quantization'
    ],
    # runtime dependencies
    install_requires=[
        'einops>=0.7.0',
        'einx[torch]>=0.1.3',
        'torch'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\finite_scalar_quantization.py
"""
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
Code adapted from Jax version in Appendix A.1
"""
from typing import List, Tuple, Optional
import torch
import torch.nn as nn
from torch.nn import Module
from torch import Tensor, int32
from torch.cuda.amp import autocast
from einops import rearrange, pack, unpack
# helper functions

def exists(v):
    """Return True when ``v`` is anything other than ``None``."""
    return not (v is None)


def default(*args):
    """Return the first argument that is not ``None``, or ``None`` if all are."""
    return next((candidate for candidate in args if exists(candidate)), None)
# pack a single tensor according to an einops pattern
def pack_one(t, pattern):
    """Pack one tensor with ``einops.pack``; returns ``(packed, packed_shapes)``."""
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes


# unpack a tensor produced by `pack_one` back into its original shape
def unpack_one(t, ps, pattern):
    """Inverse of ``pack_one``: unpack with ``einops.unpack`` and return the first tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
# tensor helpers

def round_ste(z: Tensor) -> Tensor:
    """Round ``z`` to the nearest integer with a straight-through gradient.

    The forward value equals ``z.round()``. The rounding offset is detached, so
    in the backward pass the operation behaves like the identity on ``z``.
    """
    rounding_offset = (z.round() - z).detach()
    return z + rounding_offset
# main class

class FSQ(Module):
    """Finite Scalar Quantization.

    Each of the ``len(levels)`` code dimensions is bounded and rounded to one of
    ``levels[i]`` values. The cartesian product of these per-dimension values
    forms an implicit codebook of size ``prod(levels)``. Linear projections are
    added around the quantizer when ``dim`` differs from the effective codebook
    dimension ``len(levels) * num_codebooks``.

    Args:
        levels: number of quantization levels for each code dimension.
        dim: input feature dimension; defaults to ``len(levels) * num_codebooks``.
        num_codebooks: number of codebooks the (projected) features are split into.
        keep_num_codebooks_dim: keep a trailing codebook axis on the indices;
            defaults to ``num_codebooks > 1`` and must be true in that case.
        scale: stored as ``self.scale``; not read by the code in this class.
        allowed_dtypes: input dtypes used as-is; others are cast to float32.
    """

    def __init__(
        self,
        levels: List[int],
        dim: Optional[int] = None,
        num_codebooks = 1,
        keep_num_codebooks_dim: Optional[bool] = None,
        scale: Optional[float] = None,
        allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64)
    ):
        super().__init__()

        _levels = torch.tensor(levels, dtype=int32)
        self.register_buffer("_levels", _levels, persistent = False)

        # mixed-radix basis: index = sum_i level_index_i * prod(levels[:i])
        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
        self.register_buffer("_basis", _basis, persistent = False)

        self.scale = scale

        codebook_dim = len(levels)
        self.codebook_dim = codebook_dim

        effective_codebook_dim = codebook_dim * num_codebooks
        self.num_codebooks = num_codebooks
        self.effective_codebook_dim = effective_codebook_dim

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        self.dim = default(dim, len(_levels) * num_codebooks)

        has_projections = self.dim != effective_codebook_dim
        self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
        self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.codebook_size = self._levels.prod().item()

        # Every code of a single codebook: shape (codebook_size, codebook_dim).
        # Built from the per-dimension helper rather than `indices_to_codes`, whose
        # `keep_num_codebooks_dim` flattening would merge the codebook-size axis.
        implicit_codebook = self._indices_to_level_codes(torch.arange(self.codebook_size))
        self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)

        self.allowed_dtypes = allowed_dtypes

    def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
        """Bound `z`, an array of shape (..., d).

        Squashes each dimension with tanh so that rounding yields ``levels[i]``
        distinct values; even level counts get a half-step offset.
        """
        half_l = (self._levels - 1) * (1 + eps) / 2
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: Tensor) -> Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
        """Map normalized codes to non-negative per-dimension level indices."""
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
        """Inverse of `_scale_and_shift`: level indices back to normalized codes."""
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: Tensor) -> Tensor:
        """Converts a `code` to an index in the codebook."""
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat)
        # round before casting: the float sum may land a hair below an integer,
        # and `.to(int32)` truncates toward zero
        return (zhat * self._basis).sum(dim=-1).round().to(int32)

    def _indices_to_level_codes(self, indices: Tensor) -> Tensor:
        """Decode integer indices (...) into normalized codes (..., codebook_dim)."""
        indices = rearrange(indices, '... -> ... 1')
        codes_non_centered = (indices // self._basis) % self._levels
        return self._scale_and_shift_inverse(codes_non_centered)

    def indices_to_codes(
        self,
        indices: Tensor,
        project_out = True
    ) -> Tensor:
        """Inverse of `codes_to_indices`.

        Args:
            indices: integer indices; with ``keep_num_codebooks_dim`` the last axis
                indexes codebooks.
            project_out: apply ``project_out`` to return features of size ``dim``.

        Returns:
            The decoded codes. Inputs with enough dimensions to be an image/video
            layout get the feature axis moved back to position 1.
        """
        # decide the layout from the caller's indices, before any reshaping
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        codes = self._indices_to_level_codes(indices)

        # merge the codebook axis into the feature axis
        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, '... c d -> ... (c d)')

        if project_out:
            codes = self.project_out(codes)

        if is_img_or_video:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    @autocast(enabled = False)
    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension
        c - number of codebook dim

        Returns ``(out, indices)``: the quantized tensor in the input's layout and
        dtype, and the integer code indices.
        """
        orig_dtype = z.dtype
        is_img_or_video = z.ndim >= 4

        # inputs in a dtype outside `allowed_dtypes` are upcast to float32
        if z.dtype not in self.allowed_dtypes:
            z = z.float()

        # standardize image or video into (batch, seq, dimension)
        if is_img_or_video:
            z = rearrange(z, 'b d ... -> b ... d')
            z, ps = pack_one(z, 'b * d')

        assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'

        z = self.project_in(z)

        z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)

        codes = self.quantize(z)
        indices = self.codes_to_indices(codes)

        codes = rearrange(codes, 'b n c d -> b n (c d)')

        out = self.project_out(codes)

        # reconstitute image or video dimensions
        if is_img_or_video:
            out = unpack_one(out, ps, 'b * d')
            out = rearrange(out, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        # cast back to the caller's dtype
        if out.dtype != orig_dtype:
            out = out.type(orig_dtype)

        return out, indices
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\latent_quantization.py
"""
Disentanglement via Latent Quantization
- https://arxiv.org/abs/2305.18378
Code adapted from Jax version in https://github.com/kylehkhsu/latent_quantization
"""
# 导入所需的库
from typing import List, Optional, Union, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module
from torch import Tensor, int32
from torch.optim import Optimizer
from einops import rearrange, pack, unpack
# helper functions

def exists(v):
    """Report whether ``v`` holds a value, i.e. is not ``None``."""
    return v is not None


def default(*args):
    """Return the first argument that is not ``None``; ``None`` if there is none."""
    for candidate in filter(exists, args):
        return candidate
    return None
# pack a single tensor according to an einops pattern
def pack_one(t, pattern):
    """Wrap ``t`` in a list and pack it; returns ``(packed, packed_shapes)``."""
    result = pack([t], pattern)
    return result


# unpack a tensor packed by `pack_one`
def unpack_one(t, ps, pattern):
    """Unpack ``t`` with the recorded shapes ``ps`` and return the first tensor."""
    first_tensor, *_ = unpack(t, ps, pattern)
    return first_tensor
# main class

class LatentQuantize(Module):
    """Latent quantization ("Disentanglement via Latent Quantization").

    Each latent dimension is snapped to the closest value of its own scalar
    codebook (``self.values_per_latent``), with straight-through gradients.

    NOTE(review): no ``__init__`` is visible for this class here. The methods read
    ``_levels``, ``_basis``, ``_equal_levels``, ``values_per_latent``, ``dim``,
    ``codebook_dim``, ``num_codebooks``, ``keep_num_codebooks_dim``,
    ``project_in``, ``project_out``, ``in_place_codebook_optimizer``,
    ``optimize_values``, ``commitment_loss_weight`` and
    ``quantization_loss_weight`` -- confirm where they are initialized.
    """

    def quantization_loss(self, z: Tensor, zhat: Tensor, reduce="mean") -> Tensor:
        """Computes the quantization loss.

        Mean-squared error between ``zhat`` (detached) and ``z``, so gradients
        reach only ``z``. ``reduce`` is passed as ``F.mse_loss(reduction=...)``.
        """
        return F.mse_loss(zhat.detach(), z, reduction=reduce)

    def commitment_loss(self, z: Tensor, zhat: Tensor, reduce="mean") -> Tensor:
        """Computes the commitment loss.

        Mean-squared error between ``z`` (detached) and ``zhat``, so gradients
        reach only ``zhat``.
        """
        return F.mse_loss(z.detach(), zhat, reduction=reduce)

    def quantize(self, z: Tensor) -> Tensor:
        """Quantizes z, returns quantized zhat, same shape as z.

        The quantization is done by measuring the distance between the input and
        the codebook values per latent dimension and returning the closest
        codebook value (straight-through gradients to ``z``).
        """
        def distance(x, y):
            return torch.abs(x - y)

        if self._equal_levels:
            # values_per_latent can be indexed as a single tensor: one vectorized argmin
            index = torch.argmin(distance(z[..., None], self.values_per_latent), dim=-1)
            quantize = self.values_per_latent[torch.arange(self.dim), index]
        else:
            # per-dimension codebooks of different sizes: quantize dimension by dimension
            index = torch.stack([torch.argmin(distance(z[..., i, None], self.values_per_latent[i]), dim=-1) for i in range(self.codebook_dim)], dim=-1)
            quantize = torch.stack([self.values_per_latent[i][index[..., i]] for i in range(self.codebook_dim)], dim=-1)

        # straight-through estimator: forward value is `quantize`, gradient flows to `z`
        quantize = z + (quantize - z).detach()
        return quantize

    def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
        """ scale and shift zhat from [-0.5, 0.5] to [0, level_per_dim]"""
        half_width = self._levels // 2
        return (zhat_normalized * 2 * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
        """normalize zhat to [-0.5, 0.5]"""
        half_width = self._levels // 2
        return (zhat - half_width) / half_width / 2

    def codes_to_indices(self, zhat: Tensor) -> Tensor:
        """Converts a `code` which contains the number per latent to an index in the codebook."""
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat)
        return (zhat * self._basis).sum(dim=-1).to(int32)

    def indices_to_codes(
        self,
        indices: Tensor,
        project_out = True
    ) -> Tensor:
        """Inverse of `codes_to_indices`."""
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        indices = rearrange(indices, '... -> ... 1')
        codes_non_centered = (indices // self._basis) % self._levels
        codes = self._scale_and_shift_inverse(codes_non_centered)

        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, '... c d -> ... (c d)')

        if project_out:
            codes = self.project_out(codes)

        if is_img_or_video:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    def quantize_and_project(self, z: Tensor, is_img_or_video, ps) -> Tensor:
        """Quantize ``z`` of shape (b, n, c, d), compute indices, and project out.

        ``ps`` (the packed shapes from ``pack_one``) is only read when
        ``is_img_or_video`` is true. Returns ``(codes, out, indices)``.
        """
        codes = self.quantize(z)
        indices = self.codes_to_indices(codes)

        codes = rearrange(codes, 'b n c d -> b n (c d)')

        out = self.project_out(codes)

        # reconstitute image or video dimensions
        if is_img_or_video:
            out = unpack_one(out, ps, 'b * d')
            out = rearrange(out, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b * c')

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, '... 1 -> ...')

        return codes, out, indices

    def forward(self,
                z: Tensor) -> Tensor:
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension
        c - number of codebook dim

        Returns ``(out, indices, loss)``. ``loss`` is the weighted commitment plus
        quantization loss in training mode; in eval mode both terms are zero.
        """
        is_img_or_video = z.ndim >= 4
        original_input = z
        should_inplace_optimize = exists(self.in_place_codebook_optimizer)

        # packed shapes; only populated (and only read) for image / video inputs
        ps = None

        # standardize image or video into (batch, seq, dimension)
        if is_img_or_video:
            z = rearrange(z, 'b d ... -> b ... d')
            z, ps = pack_one(z, 'b * d')

        assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'

        z = self.project_in(z)
        z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)

        _, out, indices = self.quantize_and_project(z, is_img_or_video, ps)

        if should_inplace_optimize and self.training and not self.optimize_values:
            # update the codebook with the dedicated optimizer, then quantize again
            # NOTE(review): `z` is (b, n, c, d) here while `out` is in the input's
            # layout, so these losses only work when the shapes happen to broadcast
            # -- confirm the intended operands.
            loss = self.commitment_loss(z, out) if self.commitment_loss_weight!=0 else torch.tensor(0.)
            loss+= self.quantization_loss(z, out) if self.quantization_loss_weight!=0 else torch.tensor(0.)
            loss.backward()
            self.in_place_codebook_optimizer.step()
            self.in_place_codebook_optimizer.zero_grad()

            _, out, indices = self.quantize_and_project(z, is_img_or_video, ps)

        # auxiliary losses (only in training mode and only for non-zero weights)
        commitment_loss = self.commitment_loss(original_input, out) if self.training and self.commitment_loss_weight!=0 else torch.tensor(0.)
        quantization_loss = self.quantization_loss(original_input, out) if self.training and self.quantization_loss_weight!=0 else torch.tensor(0.)

        loss = self.commitment_loss_weight * commitment_loss + self.quantization_loss_weight * quantization_loss

        return out, indices, loss
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\lookup_free_quantization.py
"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737
In the simplest setup, each dimension is quantized into {-1, 1}.
An entropy penalty is used to encourage utilization.
"""
from math import log2, ceil
from collections import namedtuple
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module
from torch.cuda.amp import autocast
from einops import rearrange, reduce, pack, unpack
# constants

# named-tuple types for the quantizer's outputs and its loss components
Return = namedtuple('Return', 'quantized indices entropy_aux_loss')
LossBreakdown = namedtuple('LossBreakdown', 'per_sample_entropy batch_entropy commitment')
# helper functions

def exists(v):
    """True if ``v`` is not ``None``."""
    return not (v is None)


def default(*args):
    """Return the first non-``None`` argument, calling it first if it is callable.

    Returns ``None`` when every argument is ``None``.
    """
    for candidate in filter(exists, args):
        if callable(candidate):
            return candidate()
        return candidate
    return None
def pack_one(t, pattern):
    """Pack a single tensor with ``einops.pack``; returns ``(packed, packed_shapes)``."""
    packed, shapes = pack([t], pattern)
    return packed, shapes


def unpack_one(t, ps, pattern):
    """Unpack a tensor packed by ``pack_one`` and return the first tensor."""
    pieces = unpack(t, ps, pattern)
    return pieces[0]
# entropy

def log(t, eps = 1e-5):
    """Natural log of ``t`` after clamping it from below at ``eps`` (avoids ``log(0)``)."""
    clamped = torch.clamp(t, min = eps)
    return torch.log(clamped)


def entropy(prob):
    """Shannon entropy (in nats) of the distributions along the last dimension of ``prob``."""
    return torch.sum(-prob * log(prob), dim = -1)
# class
class LFQ(Module):
def __init__(
    self,
    *,
    dim = None,
    codebook_size = None,
    entropy_loss_weight = 0.1,
    commitment_loss_weight = 0.25,
    diversity_gamma = 1.,
    straight_through_activation = nn.Identity(),
    num_codebooks = 1,
    keep_num_codebooks_dim = None,
    codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
    frac_per_sample_entropy = 1. # make less than 1. to only use a random fraction of the probs for per sample entropy
):
    """Set up a lookup-free quantizer.

    Either ``dim`` or ``codebook_size`` must be given. ``codebook_size`` must be a
    power of 2 and defaults to ``2 ** dim``. Each codebook has
    ``log2(codebook_size)`` binary dimensions whose codes are
    ``-codebook_scale`` or ``+codebook_scale``. Linear projections are added when
    ``dim`` differs from ``log2(codebook_size) * num_codebooks``.

    NOTE(review): the default ``straight_through_activation = nn.Identity()`` is
    created once, at function definition time, and the same instance is shared by
    every LFQ built with the default. This is harmless for a parameter-free
    ``nn.Identity``.
    """
    super().__init__()

    # some assert validations

    assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
    assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'

    # `default` calls the lambda only when codebook_size was not given
    codebook_size = default(codebook_size, lambda: 2 ** dim)
    codebook_dim = int(log2(codebook_size))

    # total number of binary code dimensions across all codebooks
    codebook_dims = codebook_dim * num_codebooks
    dim = default(dim, codebook_dims)

    # project in / out only when the feature dim differs from the total code dim
    has_projections = dim != codebook_dims
    self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity()
    self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity()
    self.has_projections = has_projections

    self.dim = dim
    self.codebook_dim = codebook_dim
    self.num_codebooks = num_codebooks

    # multiple codebooks require keeping the codebook axis on the indices
    keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
    assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
    self.keep_num_codebooks_dim = keep_num_codebooks_dim

    # straight through activation

    self.activation = straight_through_activation

    # entropy aux loss related weights

    assert 0 < frac_per_sample_entropy <= 1.
    self.frac_per_sample_entropy = frac_per_sample_entropy

    self.diversity_gamma = diversity_gamma
    self.entropy_loss_weight = entropy_loss_weight

    # codebook scale

    self.codebook_scale = codebook_scale

    # commitment loss

    self.commitment_loss_weight = commitment_loss_weight

    # per-bit weights, most significant bit first: [2**(codebook_dim-1), ..., 2, 1]
    self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
    # scalar zero buffer (non-persistent); presumably stands in for the auxiliary
    # loss when none is computed, e.g. at inference -- TODO confirm in forward
    self.register_buffer('zero', torch.tensor(0.), persistent = False)

    # precompute the code for every index: each index's bits (MSB first) mapped to
    # -codebook_scale / +codebook_scale, shape (codebook_size, codebook_dim)

    all_codes = torch.arange(codebook_size)
    bits = ((all_codes[..., None].int() & self.mask) != 0).float()
    codebook = self.bits_to_codes(bits)

    self.register_buffer('codebook', codebook, persistent = False)
def bits_to_codes(self, bits):
    """Map {0, 1} bits to {-codebook_scale, +codebook_scale} codes."""
    scale = self.codebook_scale
    return bits * scale * 2 - scale
@property
def dtype(self):
    """dtype of the precomputed ``codebook`` buffer (used for decoded codes)."""
    codebook = self.codebook
    return codebook.dtype
# convert integer code indices back into codes
def indices_to_codes(
    self,
    indices,
    project_out = True
):
    """Decode integer indices into codes.

    Each index is expanded into its ``codebook_dim`` bits (tested against
    ``self.mask``), and every bit becomes ``-codebook_scale`` or
    ``+codebook_scale``. When ``project_out`` is true the result is passed through
    ``self.project_out``. Inputs with enough dimensions to be feature maps get the
    feature axis moved back to position 1.
    """
    # decide the layout from the caller's indices before adding any axis
    is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

    # without an explicit codebook axis, add a singleton one so the layout is uniform
    if not self.keep_num_codebooks_dim:
        indices = rearrange(indices, '... -> ... 1')

    bit_is_set = (indices[..., None].int() & self.mask) != 0
    codes = self.bits_to_codes(bit_is_set.to(self.dtype))

    # merge codebook and bit axes into one feature axis
    codes = rearrange(codes, '... c d -> ... (c d)')

    # back to the input feature dimension (identity when no projection exists)
    if project_out:
        codes = self.project_out(codes)

    if is_img_or_video:
        codes = rearrange(codes, 'b ... d -> b d ...')

    return codes
# 前向传播函数
@autocast(enabled = False)
def forward(
self,
x,
inv_temperature = 100.,
return_loss_breakdown = False,
mask = None,
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\random_projection_quantizer.py
import torch
from torch import nn, einsum
import torch.nn.functional as F
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
from einops import rearrange, repeat, pack, unpack
def exists(val):
    """Return whether ``val`` is something other than ``None``."""
    return not (val is None)
class RandomProjectionQuantizer(nn.Module):
    """Random-projection quantizer (https://arxiv.org/abs/2202.01855).

    Inputs are optionally layer-normalized, projected by a fixed random matrix per
    codebook, and quantized with a cosine-similarity ``VectorQuantize`` that has a
    separate codebook per head.
    """

    def __init__(
        self,
        *,
        dim,
        codebook_size,
        codebook_dim,
        num_codebooks = 1,
        norm = True,
        **kwargs
    ):
        super().__init__()
        self.num_codebooks = num_codebooks

        # one (dim -> codebook_dim) random projection per codebook, Xavier-normal init,
        # kept as a buffer rather than a parameter
        projections = torch.empty(num_codebooks, dim, codebook_dim)
        nn.init.xavier_normal_(projections)
        self.register_buffer('rand_projs', projections)

        # optional layer norm without learnable affine parameters
        if norm:
            self.norm = nn.LayerNorm(dim, elementwise_affine = False)
        else:
            self.norm = nn.Identity()

        # one head per codebook, each with its own codebook, using cosine similarity
        self.vq = VectorQuantize(
            dim = codebook_dim * num_codebooks,
            heads = num_codebooks,
            codebook_size = codebook_size,
            use_cosine_sim = True,
            separate_codebook_per_head = True,
            **kwargs
        )

    def forward(
        self,
        x,
        indices = None
    ):
        """Return code indices for ``x``, or the quantizer's loss when ``indices`` is given."""
        return_loss = exists(indices)

        x = self.norm(x)

        # project with every codebook's random matrix: (b, n, d) -> (b, n, h, e)
        x = einsum('b n d, h d e -> b n h e', x, self.rand_projs)
        # fold heads and codebook dims together: (b, n, h * e)
        x, _ = pack([x], 'b n *')

        # the quantizer is always switched to eval mode before use
        self.vq.eval()
        vq_out = self.vq(x, indices = indices)

        if not return_loss:
            _, indices, _ = vq_out
            return indices

        _, ce_loss = vq_out
        return ce_loss
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\residual_fsq.py
import random
from math import log2
from functools import partial
from typing import List
import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.cuda.amp import autocast
from vector_quantize_pytorch.finite_scalar_quantization import FSQ
from einops import rearrange, repeat, reduce, pack, unpack
from einx import get_at
# helper functions

def exists(val):
    """True when ``val`` is not ``None``."""
    return not (val is None)


def first(l):
    """Return the first element of the sequence ``l``."""
    head = l[0]
    return head


def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return ``d``."""
    if exists(val):
        return val
    return d
# round a number up to the nearest multiple

def round_up_multiple(num, mult):
    """Round ``num`` up to the nearest multiple of ``mult``.

    Example: ``round_up_multiple(5, 4) == 8``; exact multiples are unchanged.
    """
    # local import: the module-level `from math import ...` only brings in `log2`,
    # so a bare `ceil` here raised NameError
    from math import ceil
    return ceil(num / mult) * mult
# main class

class ResidualFSQ(Module):
    """Residual stack of FSQ quantizers.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf: each FSQ layer
    quantizes the residual left by the previous layers, and the layer outputs are
    summed. Layer ``i`` works at the per-dimension scale ``(levels - 1) ** -i``.
    """
    def __init__(
        self,
        *,
        dim,
        levels: List[int],
        num_quantizers,
        quantize_dropout = False,
        quantize_dropout_cutoff_index = 0,
        quantize_dropout_multiple_of = 1,
        **kwargs
    ):
        super().__init__()
        codebook_dim = len(levels)

        # project in / out only when the input dim differs from the FSQ code dim
        requires_projection = codebook_dim != dim
        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
        self.has_projections = requires_projection

        self.num_quantizers = num_quantizers

        self.levels = levels
        self.layers = nn.ModuleList([])

        levels_tensor = torch.Tensor(levels)

        scales = []

        for ind in range(num_quantizers):
            # per-dimension scale of layer `ind`: (levels - 1) ** -ind
            scales.append((levels_tensor - 1) ** -ind)

            fsq = FSQ(
                levels = levels,
                dim = codebook_dim,
                **kwargs
            )

            self.layers.append(fsq)

        # projections are done once here, so the individual FSQ layers must not add their own
        assert all([not fsq.has_projections for fsq in self.layers])

        self.codebook_size = self.layers[0].codebook_size

        # scales as a non-persistent buffer of shape (num_quantizers, codebook_dim)
        self.register_buffer('scales', torch.stack(scales), persistent = False)

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec paper proposes structured dropout, believe this was set to 4

    @property
    def codebooks(self):
        """Implicit codebooks of all layers, stacked along a new leading axis."""
        codebooks = [layer.implicit_codebook for layer in self.layers]
        codebooks = torch.stack(codebooks, dim = 0)
        return codebooks

    def get_codes_from_indices(self, indices):
        """Look up and scale each layer's codes for ``indices``.

        ``indices`` has the quantizer axis last. Returns codes of shape
        (quantizers, batch, ..., codebook_dim); entries whose index is -1
        (dropped quantizers) are zero.
        """
        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)

        indices, ps = pack([indices], 'b * q')

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct

        if quantize_dim < self.num_quantizers:
            assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)

        # take care of quantizer dropout: -1 marks a dropped layer

        mask = indices == -1
        indices = indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later

        all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)

        # mask out any codes that were dropout-ed

        all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)

        # scale the codes

        scales = rearrange(self.scales, 'q d -> q 1 1 d')
        all_codes = all_codes * scales

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)

        all_codes, = unpack(all_codes, ps, 'q b * d')

        return all_codes

    def get_output_from_indices(self, indices):
        """Reconstruct the (projected-out) output by summing all layers' codes."""
        codes = self.get_codes_from_indices(indices)
        codes_summed = reduce(codes, 'q ... -> ...', 'sum')
        return self.project_out(codes_summed)

    def forward(
        self,
        x,
        return_all_codes = False,
        rand_quantize_dropout_fixed_seed = None
    ):
        """Quantize ``x`` with the residual FSQ stack.

        Returns ``(quantized_out, all_indices)``, plus ``all_codes`` when
        ``return_all_codes`` is true. In training with quantize dropout, layers
        after a randomly chosen index are skipped and recorded with index -1; a
        fixed seed makes that choice reproducible.
        """
        num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device

        # project into the FSQ code dimension (identity if dims already match)

        x = self.project_in(x)

        quantized_out = 0.
        # the residual starts as the bounded input
        residual = first(self.layers).bound(x)

        all_indices = []

        should_quantize_dropout = self.training and self.quantize_dropout

        # sample a layer index after which residual quantization is dropped,
        # and prepare the null indices recorded for the dropped layers

        if should_quantize_dropout:
            rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random

            rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)

            # structured dropout: round the number of kept layers up to a multiple
            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1

            null_indices = torch.full(x.shape[:2], -1., device = device, dtype = torch.long)

        # go through the layers, with autocast disabled

        with autocast(enabled = False):
            for quantizer_index, (layer, scale) in enumerate(zip(self.layers, self.scales)):

                if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                    all_indices.append(null_indices)
                    continue

                quantized, indices = layer(residual / scale)
                quantized = quantized * scale

                # subtract a detached copy: later layers' inputs do not backpropagate through this layer's output
                residual = residual - quantized.detach()
                quantized_out = quantized_out + quantized

                all_indices.append(indices)

        # project out, if needed

        quantized_out = self.project_out(quantized_out)

        # stack all indices along a trailing quantizer axis

        all_indices = torch.stack(all_indices, dim = -1)

        ret = (quantized_out, all_indices)

        if not return_all_codes:
            return ret

        # whether to return all codes from all codebooks across layers

        all_codes = self.get_codes_from_indices(all_indices)

        # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)

        return (*ret, all_codes)
# residual FSQ applied independently to equal-sized groups of the feature dimension

class GroupedResidualFSQ(Module):
    """Split the features into ``groups`` chunks and quantize each with its own ResidualFSQ.

    Args:
        dim: total feature dimension; must be divisible by ``groups``.
        groups: number of feature groups / ResidualFSQ instances.
        accept_image_fmap: features live on axis 1 (feature maps) instead of the last axis.
        **kwargs: forwarded to every ``ResidualFSQ``.
    """

    def __init__(
        self,
        *,
        dim,
        groups = 1,
        accept_image_fmap = False,
        **kwargs
    ):
        super().__init__()
        self.dim = dim
        self.groups = groups
        assert (dim % groups) == 0
        dim_per_group = dim // groups

        self.accept_image_fmap = accept_image_fmap

        self.rvqs = nn.ModuleList([])

        for _ in range(groups):
            self.rvqs.append(ResidualFSQ(
                dim = dim_per_group,
                **kwargs
            ))

        self.codebook_size = self.rvqs[0].codebook_size

    @property
    def codebooks(self):
        """Codebooks of every group, stacked along a new leading axis."""
        return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))

    @property
    def split_dim(self):
        """Axis holding the features: 1 for image feature maps, otherwise the last axis."""
        return 1 if self.accept_image_fmap else -1

    def get_codes_from_indices(self, indices):
        """Per-group codes for per-group ``indices`` (iterated along the first axis), stacked."""
        codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
        return torch.stack(codes)

    def get_output_from_indices(self, indices):
        """Reconstruct each group's output and concatenate along the feature axis."""
        outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
        return torch.cat(outputs, dim = self.split_dim)

    def forward(
        self,
        x,
        return_all_codes = False
    ):
        """Quantize every feature group.

        Returns ``(quantized, all_indices)``, plus the stacked per-group codes when
        ``return_all_codes`` is true; ``all_indices`` has a leading group axis.
        """
        shape, split_dim = x.shape, self.split_dim
        assert shape[split_dim] == self.dim

        # split the feature dimension into groups

        x = x.chunk(self.groups, dim = split_dim)

        forward_kwargs = dict(
            return_all_codes = return_all_codes,
            # one seed shared by all groups, so every group drops the same quantizer layers;
            # randint needs integer bounds (a float bound raises TypeError on Python 3.12+)
            rand_quantize_dropout_fixed_seed = random.randint(0, int(1e7))
        )

        # invoke each residual FSQ on its chunk

        out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
        out = tuple(zip(*out))

        # combine the zipped per-group outputs

        quantized, all_indices, *maybe_all_codes = out

        quantized = torch.cat(quantized, dim = split_dim)
        all_indices = torch.stack(all_indices)

        ret = (quantized, all_indices, *maybe_all_codes)
        return ret
.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\residual_lfq.py
# 导入所需的库
import random
from math import log2
from functools import partial
import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.cuda.amp import autocast
# 导入自定义的 LFQ 模块
from vector_quantize_pytorch.lookup_free_quantization import LFQ
# 导入 einops 库中的函数
from einops import rearrange, repeat, reduce, pack, unpack
# 从 einx 库中导入 get_at 函数
from einx import get_at
# 辅助函数
# 检查变量是否存在
def exists(val):
    """Return True when ``val`` is anything other than ``None`` (falsy values count as existing)."""
    return not (val is None)
# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val
# 将数字向上取整到最接近的倍数
def round_up_multiple(num, mult):
    """Round ``num`` up to the nearest multiple of ``mult``.

    Args:
        num: value to round (an int at the call site in this module).
        mult: the multiple to round up to; must be non-zero.

    Returns:
        ``ceil(num / mult) * mult``, computed in integer arithmetic.
    """
    # Ceiling division via negated floor division. This needs no `math.ceil`
    # (never imported in this module) and stays exact for large ints, where
    # `num / mult` would lose precision as a float.
    return -(-num // mult) * mult
# 主类
class ResidualLFQ(Module):
    """Residual lookup-free quantization: a stack of ``num_quantizers`` LFQ layers.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    Each layer quantizes the residual left by the previous layers. Layer
    ``ind`` is built with ``codebook_scale = 2 ** -ind``. Input and output are
    linearly projected when ``int(log2(codebook_size)) != dim``. Extra keyword
    arguments are forwarded to every ``LFQ`` layer.
    """
    def __init__(
        self,
        *,
        dim,
        num_quantizers,
        codebook_size,
        quantize_dropout = False,
        quantize_dropout_cutoff_index = 0,
        quantize_dropout_multiple_of = 1,
        **kwargs
    ):
        super().__init__()
        codebook_dim = int(log2(codebook_size))

        requires_projection = codebook_dim != dim
        # project in/out only when the codebook dimension (log2 of codebook_size) differs from dim
        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
        self.has_projections = requires_projection

        self.num_quantizers = num_quantizers

        self.layers = nn.ModuleList([])
        # build num_quantizers LFQ layers; each successive layer gets a halved codebook_scale
        for ind in range(num_quantizers):
            codebook_scale = 2 ** -ind

            lfq = LFQ(
                dim = codebook_dim,
                codebook_scale = codebook_scale,
                **kwargs
            )

            self.layers.append(lfq)

        # projections are handled by this module, so no LFQ layer may carry its own
        assert all([not lfq.has_projections for lfq in self.layers])

        # quantize dropout only makes sense with more than one quantizer
        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        # the lowest layer index that quantize dropout may pick as its cutoff
        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # when != 1, forward rounds the dropout cutoff up to this multiple (structured dropout); default 1 disables it

    @property
    def codebooks(self):
        """Stack every layer's ``codebook`` along a new leading (quantizer) axis."""
        codebooks = [layer.codebook for layer in self.layers]
        codebooks = torch.stack(codebooks, dim = 0)
        return codebooks

    def get_codes_from_indices(self, indices):
        """Map per-quantizer indices back to codebook vectors.

        ``indices`` has batch on its first axis and quantizers on its last axis
        (pack pattern ``'b * q'``). Entries equal to -1 mark dropped quantizers
        and yield all-zero codes. If fewer than ``num_quantizers`` index columns
        are given, the missing ones are padded with -1; this is only allowed
        when quantize dropout is enabled.

        Returns:
            Codes shaped ``(q, b, *, d)``, where ``unpack`` restores the middle
            axes of ``indices``.
        """
        # NOTE(review): `batch` is unused below
        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # indices may also arrive as 'b h w q' (accept_image_fmap); flatten the middle axes
        indices, ps = pack([indices], 'b * q')

        # because of quantize dropout, coarse indices (fewer quantizers) may be passed in and
        # the network should still reconstruct; pad the missing quantizers with -1
        if quantize_dim < self.num_quantizers:
            assert self.quantize_dropout > 0., '如果希望从较少的精细量化信号重构,则 quantize dropout 必须大于 0'
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)

        # handle quantizer dropout: -1 marks a dropped quantizer
        mask = indices == -1.
        indices = indices.masked_fill(mask, 0) # fetch a dummy code (index 0) that is masked out below

        all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)

        # zero out the codes of dropped quantizers
        all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)

        # restore the packed axes; with image feature maps the shape is (quantize, batch, height, width, dimension)
        all_codes, = unpack(all_codes, ps, 'q b * d')

        return all_codes

    def get_output_from_indices(self, indices):
        """Decode indices: sum the codes over the quantizer axis, then apply ``project_out``."""
        codes = self.get_codes_from_indices(indices)
        codes_summed = reduce(codes, 'q ... -> ...', 'sum')
        return self.project_out(codes_summed)

    def forward(
        self,
        x,
        mask = None,
        return_all_codes = False,
        rand_quantize_dropout_fixed_seed = None
    ):
        """Residually quantize ``x`` through every LFQ layer.

        Args:
            x: input; passed through ``project_in`` before quantization.
            mask: forwarded unchanged to every LFQ layer.
            return_all_codes: if True, also return the per-quantizer codes
                obtained from ``get_codes_from_indices``.
            rand_quantize_dropout_fixed_seed: when given, seeds a private
                ``random.Random`` used to pick the dropout cutoff. Otherwise the
                global ``random`` module is used.

        Returns:
            ``(quantized_out, all_indices, all_losses)``, plus ``all_codes`` when
            ``return_all_codes`` is True. Indices and losses are stacked on their
            last axis, one entry per quantizer.
        """
        num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device

        # project the input into the codebook dimension (identity when not needed)
        x = self.project_in(x)

        quantized_out = 0.
        residual = x

        all_losses = []
        all_indices = []

        # quantize dropout is active only in training mode
        should_quantize_dropout = self.training and self.quantize_dropout

        # pick a random cutoff layer index; every layer after it is dropped.
        # Also prepare -1 indices and a zero loss to stand in for the dropped layers.
        if should_quantize_dropout:
            rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random

            rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)

            if quant_dropout_multiple_of != 1:
                # round the number of kept layers up to a multiple of quant_dropout_multiple_of
                rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1

            # NOTE(review): shape x.shape[:2] assumes LFQ returns indices shaped like the first two axes of x — confirm
            null_indices = torch.full(x.shape[:2], -1., device=device, dtype=torch.long)
            null_loss = torch.tensor(0., device=device, dtype=x.dtype)

        # run through the layers with autocast disabled
        with autocast(enabled=False):
            for quantizer_index, layer in enumerate(self.layers):

                # layers past the dropout cutoff contribute null indices and zero loss
                if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                    all_indices.append(null_indices)
                    all_losses.append(null_loss)
                    continue

                quantized, indices, loss = layer(residual, mask=mask)

                # subtract the detached quantization so the next layer quantizes what remains
                residual = residual - quantized.detach()
                quantized_out = quantized_out + quantized

                all_indices.append(indices)
                all_losses.append(loss)

        # project back out to dim (identity when not needed)
        quantized_out = self.project_out(quantized_out)

        # stack per-quantizer losses and indices on their last axis
        all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))

        ret = (quantized_out, all_indices, all_losses)

        if not return_all_codes:
            return ret

        # also return the codes of every quantizer layer
        all_codes = self.get_codes_from_indices(all_indices)

        # all_codes is shaped (quantizer, batch, sequence length, codebook dimension)
        return (*ret, all_codes)
# 定义一个名为 GroupedResidualLFQ 的类,继承自 Module 类
class GroupedResidualLFQ(Module):
    """Split features into ``groups`` equal chunks and quantize each with its own ``ResidualLFQ``.

    Extra keyword arguments are forwarded to every per-group ``ResidualLFQ``.
    """
    def __init__(
        self,
        *,
        dim,
        groups = 1,
        accept_image_fmap = False,
        **kwargs
    ):
        super().__init__()
        self.dim = dim
        self.groups = groups

        # every group must receive the same number of features
        assert (dim % groups) == 0
        dim_per_group = dim // groups

        self.accept_image_fmap = accept_image_fmap

        # one residual LFQ per group
        self.rvqs = nn.ModuleList([])

        for _ in range(groups):
            self.rvqs.append(ResidualLFQ(
                dim = dim_per_group,
                **kwargs
            ))

    @property
    def codebooks(self):
        """Stack each group's codebooks on a new leading (group) axis."""
        return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))

    @property
    def split_dim(self):
        """Axis 1 for image feature maps (``accept_image_fmap``), otherwise the last axis."""
        return 1 if self.accept_image_fmap else -1

    def get_codes_from_indices(self, indices):
        """Look up codes per group (``indices`` zipped with ``self.rvqs``) and stack them on a new leading axis."""
        codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
        return torch.stack(codes)

    def get_output_from_indices(self, indices):
        """Decode per group (``indices`` zipped with ``self.rvqs``) and concatenate along ``self.split_dim``."""
        outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
        return torch.cat(outputs, dim = self.split_dim)

    def forward(
        self,
        x,
        mask = None,
        return_all_codes = False
    ):
        """Quantize ``x`` group by group.

        Args:
            x: input whose size along ``self.split_dim`` must equal ``self.dim``.
            mask: forwarded unchanged to every per-group quantizer.
            return_all_codes: forwarded unchanged to every per-group quantizer.

        Returns:
            ``(quantized, all_indices, commit_losses, *maybe_all_codes)``:
            - ``quantized`` is concatenated along ``self.split_dim``.
            - ``all_indices`` and ``commit_losses`` are stacked on a new leading group axis.
            - ``maybe_all_codes`` holds any extra per-group outputs as tuples.
        """
        shape, split_dim = x.shape, self.split_dim
        assert shape[split_dim] == self.dim

        # split the feature dimension into `groups` chunks
        x = x.chunk(self.groups, dim = split_dim)

        # one seed shared by every group, so all groups pick the same dropout cutoff
        # (given equal cutoff settings). The bound must be an int: a float such as 1e7
        # makes random.randint raise TypeError on Python 3.12+
        forward_kwargs = dict(
            mask = mask,
            return_all_codes = return_all_codes,
            rand_quantize_dropout_fixed_seed = random.randint(0, int(1e7))
        )

        # run each group's residual LFQ on its chunk
        out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))

        # transpose: per-group tuples of outputs -> per-output tuples over groups
        out = tuple(zip(*out))

        quantized, all_indices, commit_losses, *maybe_all_codes = out

        quantized = torch.cat(quantized, dim = split_dim)
        all_indices = torch.stack(all_indices)
        commit_losses = torch.stack(commit_losses)

        ret = (quantized, all_indices, commit_losses, *maybe_all_codes)
        return ret