Lucidrains Series Projects: Source Code Walkthrough (Part 46)

.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\residual_vq.py

# import the necessary libraries
import random
from math import ceil
from functools import partial
from itertools import zip_longest

import torch
from torch import nn
import torch.nn.functional as F
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize

from einops import rearrange, repeat, reduce, pack, unpack

from einx import get_at

# helper functions

# check whether a value exists (is not None)
def exists(val):
    return val is not None

# return the value if it exists, otherwise the default
def default(val, d):
    return val if exists(val) else d

# round a number up to the nearest multiple
def round_up_multiple(num, mult):
    return ceil(num / mult) * mult

# main class

class ResidualVQ(nn.Module):
    """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
    def __init__(
        self,
        *,
        dim,
        num_quantizers,
        codebook_dim = None,
        shared_codebook = False,
        heads = 1,
        quantize_dropout = False,
        quantize_dropout_cutoff_index = 0,
        quantize_dropout_multiple_of = 1,
        accept_image_fmap = False,
        **kwargs
    ):
        super().__init__()
        assert heads == 1, 'residual vq is not compatible with multi-headed codes'
        codebook_dim = default(codebook_dim, dim)
        codebook_input_dim = codebook_dim * heads

        requires_projection = codebook_input_dim != dim
        # create input/output projections if the codebook dimension differs from the model dimension
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
        self.has_projections = requires_projection

        self.num_quantizers = num_quantizers

        self.accept_image_fmap = accept_image_fmap
        # create the stack of VectorQuantize layers
        self.layers = nn.ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)])

        assert all([not vq.has_projections for vq in self.layers])

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec paper proposes structured dropout, believe this was set to 4

        if not shared_codebook:
            return

        # if the codebook is shared, point every VectorQuantize layer at the first layer's codebook
        first_vq, *rest_vq = self.layers
        codebook = first_vq._codebook

        for vq in rest_vq:
            vq._codebook = codebook

    @property
    def codebooks(self):
        # gather the embedding vectors of all codebooks
        codebooks = [layer._codebook.embed for layer in self.layers]
        codebooks = torch.stack(codebooks, dim = 0)
        codebooks = rearrange(codebooks, 'q 1 c d -> q c d')
        return codebooks

    def get_codes_from_indices(self, indices):

        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # may receive indices shaped 'b h w q' (when accept_image_fmap is True)

        indices, ps = pack([indices], 'b * q')

        # because of quantize dropout, indices covering only the coarser quantizers may be passed in; the network should still be able to reconstruct from them

        if quantize_dim < self.num_quantizers:
            assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)

        # handle quantizer dropout

        mask = indices == -1.
        indices = indices.masked_fill(mask, 0) # fill with a dummy code id, to be masked out afterwards

        all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)

        # mask out any codes that were dropped out

        all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)

        # if (accept_image_fmap = True), the returned shape is (quantizers, batch, height, width, dimension)

        all_codes, = unpack(all_codes, ps, 'q b * d')

        return all_codes
    # get the reconstructed output for the given indices
    def get_output_from_indices(self, indices):
        # look up the codes for the indices
        codes = self.get_codes_from_indices(indices)
        # sum the codes across quantizer stages
        codes_summed = reduce(codes, 'q ... -> ...', 'sum')
        # project back out
        return self.project_out(codes_summed)

    # forward pass
    def forward(
        self,
        x,
        mask = None,
        indices = None,
        return_all_codes = False,
        sample_codebook_temp = None,
        freeze_codebook = False,
        rand_quantize_dropout_fixed_seed = None
    ):
        # unpack a few attributes and flags
        num_quant, quant_dropout_multiple_of, return_loss, device = self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device

        # project the input in
        x = self.project_in(x)

        # indices (for the cross entropy loss) cannot be used together with accept_image_fmap
        assert not (self.accept_image_fmap and exists(indices))

        quantized_out = 0.
        residual = x

        all_losses = []
        all_indices = []

        if return_loss:
            assert not torch.any(indices == -1), 'some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss'
            ce_losses = []

        should_quantize_dropout = self.training and self.quantize_dropout and not return_loss

        # sample a layer index at which to drop out further residual quantization
        # also prepare the null indices and null loss
        if should_quantize_dropout:
            rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random

            rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1

            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
            null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
            null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)

        # go through all quantizer layers
        for quantizer_index, layer in enumerate(self.layers):

            if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                continue

            layer_indices = None
            if return_loss:
                layer_indices = indices[..., quantizer_index]

            quantized, *rest = layer(
                residual,
                mask = mask,
                indices = layer_indices,
                sample_codebook_temp = sample_codebook_temp,
                freeze_codebook = freeze_codebook
            )

            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized

            if return_loss:
                ce_loss = rest[0]
                ce_losses.append(ce_loss)
                continue

            embed_indices, loss = rest

            all_indices.append(embed_indices)
            all_losses.append(loss)

        # project out, if needed
        quantized_out = self.project_out(quantized_out)

        # whether to early return the cross entropy loss
        if return_loss:
            return quantized_out, sum(ce_losses)

        # stack all losses and indices
        all_losses, all_indices = map(partial(torch.stack, dim = -1), (all_losses, all_indices))

        ret = (quantized_out, all_indices, all_losses)

        if return_all_codes:
            # whether to return all codes from all codebooks across layers
            all_codes = self.get_codes_from_indices(all_indices)

            # the returned codes have shape (quantizers, batch, seq length, codebook dimension)
            ret = (*ret, all_codes)

        return ret
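To make the return signature of forward concrete, here is a minimal usage sketch (the constructor values are illustrative, and the shape comments follow the stacking at the end of forward above):

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(dim = 256, num_quantizers = 8, codebook_size = 1024)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
# quantized: (1, 1024, 256) - sum of the codes chosen at each of the 8 stages
# indices:   (1, 1024, 8)   - one code id per quantizer (may contain -1 under quantize dropout)
# commit_loss: one loss per quantizer, stacked along the last dimension

# the quantized output can be recovered from the indices alone
reconstructed = residual_vq.get_output_from_indices(indices)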
# GroupedResidualVQ: splits the feature dimension into groups, each quantized by its own ResidualVQ
class GroupedResidualVQ(nn.Module):
    # constructor, accepting dim, groups, accept_image_fmap and extra kwargs
    def __init__(
        self,
        *,
        dim,
        groups = 1,
        accept_image_fmap = False,
        **kwargs
    ):
        super().__init__()
        # store attributes
        self.dim = dim
        self.groups = groups
        assert (dim % groups) == 0
        dim_per_group = dim // groups

        self.accept_image_fmap = accept_image_fmap

        self.rvqs = nn.ModuleList([])

        # create one ResidualVQ per group
        for _ in range(groups):
            self.rvqs.append(ResidualVQ(
                dim = dim_per_group,
                accept_image_fmap = accept_image_fmap,
                **kwargs
            ))

    # return the codebooks of all rvqs
    @property
    def codebooks(self):
        return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))

    # the dimension along which features are split into groups
    @property
    def split_dim(self):
        return 1 if self.accept_image_fmap else -1

    # get codes from indices
    def get_codes_from_indices(self, indices):
        codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
        return torch.stack(codes)

    # get output from indices
    def get_output_from_indices(self, indices):
        outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
        return torch.cat(outputs, dim = self.split_dim)

    # forward pass
    def forward(
        self,
        x,
        indices = None,
        return_all_codes = False,
        sample_codebook_temp = None,
        freeze_codebook = False,
        mask = None,
    ):
        shape, split_dim = x.shape, self.split_dim
        assert shape[split_dim] == self.dim

        # split the feature dimension into groups

        x = x.chunk(self.groups, dim = split_dim)

        indices = default(indices, tuple())
        return_ce_loss = len(indices) > 0
        assert len(indices) == 0 or len(indices) == self.groups

        forward_kwargs = dict(
            return_all_codes = return_all_codes,
            sample_codebook_temp = sample_codebook_temp,
            mask = mask,
            freeze_codebook = freeze_codebook,
            rand_quantize_dropout_fixed_seed = random.randint(0, 1e7)
        )

        # invoke the ResidualVQ of each group

        out = tuple(rvq(chunk, indices = chunk_indices, **forward_kwargs) for rvq, chunk, chunk_indices in zip_longest(self.rvqs, x, indices))
        out = tuple(zip(*out))

        # if returning the cross entropy loss to the rvq codebooks

        if return_ce_loss:
            quantized, ce_losses = out
            return torch.cat(quantized, dim = split_dim), sum(ce_losses)

        # otherwise, gather all outputs and combine them

        quantized, all_indices, commit_losses, *maybe_all_codes = out

        quantized = torch.cat(quantized, dim = split_dim)
        all_indices = torch.stack(all_indices)
        commit_losses = torch.stack(commit_losses)

        ret = (quantized, all_indices, commit_losses, *maybe_all_codes)
        return ret
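A companion sketch for GroupedResidualVQ, which chunks the feature dimension into groups and runs an independent ResidualVQ per group (values again illustrative):

import torch
from vector_quantize_pytorch import GroupedResidualVQ

grouped_rvq = GroupedResidualVQ(dim = 256, groups = 2, num_quantizers = 8, codebook_size = 1024)

x = torch.randn(1, 1024, 256)                   # the 256 features are split into 2 groups of 128
quantized, indices, commit_loss = grouped_rvq(x)
# quantized: (1, 1024, 256); indices and commit_loss gain a leading `groups` dimension from the torch.stack above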

.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\vector_quantize_pytorch.py

# import the necessary libraries
from functools import partial

import torch
from torch import nn, einsum
import torch.nn.functional as F
import torch.distributed as distributed
from torch.optim import Optimizer
from torch.cuda.amp import autocast

# import helper functions from einops
from einops import rearrange, repeat, reduce, pack, unpack

# import the Callable type
from typing import Callable

# 检查变量是否存在的函数
def exists(val):
    return val is not None

# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
    return val if exists(val) else d

# 空函数
def noop(*args, **kwargs):
    pass

# 返回输入的函数
def identity(t):
    return t

# 对输入进行 L2 归一化
def l2norm(t):
    return F.normalize(t, p = 2, dim = -1)

# 计算输入张量 x 和 y 之间的欧氏距离
def cdist(x, y):
    x2 = reduce(x ** 2, 'b n d -> b n', 'sum')
    y2 = reduce(y ** 2, 'b n d -> b n', 'sum')
    xy = einsum('b i d, b j d -> b i j', x, y) * -2
    return (rearrange(x2, 'b i -> b i 1') + rearrange(y2, 'b j -> b 1 j') + xy).clamp(min = 0).sqrt()

# 计算输入张量的自然对数
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# in-place exponential moving average update
def ema_inplace(old, new, decay):
    is_mps = str(old.device).startswith('mps:')

    if not is_mps:
        old.lerp_(new, 1 - decay)
    else:
        old.mul_(decay).add_(new * (1 - decay))
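Written out, the update above (in both its lerp_ and mul_/add_ forms) is the usual exponential moving average:

$$\text{old} \leftarrow \text{decay} \cdot \text{old} + (1 - \text{decay}) \cdot \text{new}$$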

# 将输入张量按照指定模式打包
def pack_one(t, pattern):
    return pack([t], pattern)

# 将输入张量按照指定模式解包
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# 使用均匀分布初始化输入形状的张量
def uniform_init(*shape):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t

# 生成 Gumbel 噪声
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# gumbel sampling, with optional straight-through
def gumbel_sample(
    logits,
    temperature = 1.,
    stochastic = False,
    straight_through = False,
    reinmax = False,
    dim = -1,
    training = True
):
    dtype, size = logits.dtype, logits.shape[dim]

    if training and stochastic and temperature > 0:
        sampling_logits = (logits / temperature) + gumbel_noise(logits)
    else:
        sampling_logits = logits

    ind = sampling_logits.argmax(dim = dim)
    one_hot = F.one_hot(ind, size).type(dtype)

    assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax'

    if not straight_through or temperature <= 0. or not training:
        return ind, one_hot

    # use reinmax for better second-order accuracy of the straight-through gradient
    if reinmax:
        π0 = logits.softmax(dim = dim)
        π1 = (one_hot + (logits / temperature).softmax(dim = dim)) / 2
        π1 = ((log(π1) - logits).detach() + logits).softmax(dim = 1)
        π2 = 2 * π1 - 0.5 * π0
        one_hot = π2 - π2.detach() + one_hot
    else:
        π1 = (logits / temperature).softmax(dim = dim)
        one_hot = one_hot + π1 - π1.detach()

    return ind, one_hot
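The non-reinmax branch above is the standard straight-through estimator: the forward value stays the hard one-hot, while gradients flow through the tempered softmax,

$$y_{\text{ST}} = \text{onehot}(\arg\max z) + \pi - \operatorname{sg}(\pi), \qquad \pi = \operatorname{softmax}(z / \tau),$$

where sg denotes stop-gradient (.detach()); the reinmax branch refines this with a second-order correction.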

# laplace smoothing
def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
    denom = x.sum(dim = dim, keepdim = True)
    return (x + eps) / (denom + n_categories * eps)
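This is standard additive (Laplace) smoothing; with $K$ = n_categories it returns, per entry,

$$\frac{x_i + \epsilon}{\sum_j x_j + K \epsilon},$$

which keeps every entry strictly positive while staying close to the empirical frequencies.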

# 从样本中随机抽取指定数量的向量
def sample_vectors(samples, num):
    num_samples, device = samples.shape[0], samples.device
    if num_samples >= num:
        indices = torch.randperm(num_samples, device = device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device = device)

    return samples[indices]

# 批量从样本中随机抽取指定数量的向量
def batched_sample_vectors(samples, num):
    return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0)

# 在指定维度上填充形状
def pad_shape(shape, size, dim = 0):
    return [size if i == dim else s for i, s in enumerate(shape)]

# 多项式分布采样
def sample_multinomial(total_count, probs):
    device = probs.device
    probs = probs.cpu()

    total_count = probs.new_full((), total_count)
    remainder = probs.new_ones(())
    sample = torch.empty_like(probs, dtype = torch.long)

    for i, p in enumerate(probs):
        s = torch.binomial(total_count, p / remainder)
        sample[i] = s
        total_count -= s
        remainder -= p

    return sample.to(device)

# 收集所有进程的指定维度大小
def all_gather_sizes(x, dim):
    size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device)
    all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())]
    distributed.all_gather(all_sizes, size)
    # 使用torch.stack将列表中的张量按照第0维度进行堆叠
    return torch.stack(all_sizes)
def all_gather_variably_sized(x, sizes, dim = 0):
    # 获取当前进程的排名
    rank = distributed.get_rank()
    # 初始化一个空列表用于存储所有进程的数据
    all_x = []

    # 遍历每个进程的数据大小
    for i, size in enumerate(sizes):
        # 如果当前进程是当前循环的进程,则直接使用原始数据x,否则创建一个新的空tensor
        t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim))
        # 使用广播将数据传输到所有进程
        distributed.broadcast(t, src = i, async_op = True)
        # 将数据添加到列表中
        all_x.append(t)

    # 等待所有进程完成数据传输
    distributed.barrier()
    return all_x

def sample_vectors_distributed(local_samples, num):
    # 重新排列本地样本数据的维度
    local_samples = rearrange(local_samples, '1 ... -> ...')

    # 获取当前进程的排名
    rank = distributed.get_rank()
    # 获取所有进程的样本数量
    all_num_samples = all_gather_sizes(local_samples, dim = 0)

    # 如果当前进程是主进程
    if rank == 0:
        # 对所有进程的样本数量进行多项式采样
        samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
    else:
        # 创建一个与所有进程样本数量相同的空tensor
        samples_per_rank = torch.empty_like(all_num_samples)

    # 使用广播将采样结果传输到所有进程
    distributed.broadcast(samples_per_rank, src = 0)
    # 将tensor转换为列表
    samples_per_rank = samples_per_rank.tolist()

    # 对本地样本进行采样
    local_samples = sample_vectors(local_samples, samples_per_rank[rank])
    # 将所有进程的样本数据按照不同大小进行收集
    all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim = 0)
    # 拼接所有进程的样本数据
    out = torch.cat(all_samples, dim = 0)

    return rearrange(out, '... -> 1 ...')

def batched_bincount(x, *, minlength):
    # 获取batch大小、数据类型和设备信息
    batch, dtype, device = x.shape[0], x.dtype, x.device
    # 初始化一个全零tensor用于存储结果
    target = torch.zeros(batch, minlength, dtype = dtype, device = device)
    # 初始化一个全一tensor
    values = torch.ones_like(x)
    # 对目标tensor进行scatter_add操作
    target.scatter_add_(-1, x, values)
    return target

def kmeans(
    samples,
    num_clusters,
    num_iters = 10,
    use_cosine_sim = False,
    sample_fn = batched_sample_vectors,
    all_reduce_fn = noop
):
    # get the number of codebooks, the feature dimension, dtype and device of the samples
    num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device

    # 使用指定函数对样本数据进行采样得到初始均值
    means = sample_fn(samples, num_clusters)

    # 迭代更新均值
    for _ in range(num_iters):
        # 计算样本数据与均值之间的距离
        if use_cosine_sim:
            dists = samples @ rearrange(means, 'h n d -> h d n')
        else:
            dists = -cdist(samples, means)

        # 将样本分配到最近的均值点
        buckets = torch.argmax(dists, dim = -1)
        # 对分配结果进行统计
        bins = batched_bincount(buckets, minlength = num_clusters)
        # 对统计结果进行全局归约
        all_reduce_fn(bins)

        # 处理空簇
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        # 计算新的均值
        new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype)
        new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples)
        new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1')
        all_reduce_fn(new_means)

        # 对新的均值进行归一化
        if use_cosine_sim:
            new_means = l2norm(new_means)

        # 更新均值
        means = torch.where(
            rearrange(zero_mask, '... -> ... 1'),
            means,
            new_means
        )

    return means, bins
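A minimal sketch of calling this kmeans helper directly, assuming the functions above are in scope (the shapes here are made up):

import torch

samples = torch.randn(1, 2048, 64)                 # (num codebooks, num samples, dim)
means, bins = kmeans(samples, num_clusters = 512)  # 10 iterations, euclidean distance by default
# means: (1, 512, 64) centroids; bins: (1, 512) cluster counts from the final assignment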

def batched_embedding(indices, embeds):
    # 获取batch大小和嵌入维度
    batch, dim = indices.shape[1], embeds.shape[-1]
    # 将索引数据扩展到与嵌入数据相同的维度
    indices = repeat(indices, 'h b n -> h b n d', d = dim)
    # 将嵌入数据扩展到与索引数据相同的维度
    embeds = repeat(embeds, 'h c d -> h b c d', b = batch)
    # 根据索引获取对应的嵌入数据
    return embeds.gather(2, indices)

# regularization losses

def orthogonal_loss_fn(t):
    # compute the orthogonal regularization loss
    # see eq. (2) of the referenced paper
    h, n = t.shape[:2]
    normed_codes = l2norm(t)
    cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes)
    return (cosine_sim ** 2).sum() / (h * n ** 2) - (1 / n)
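For reference, with $\hat C \in \mathbb{R}^{n \times d}$ the row-l2-normalized codebook of one head, the value returned above equals, averaged over the $h$ heads,

$$\frac{1}{n^2} \lVert \hat C \hat C^\top - I \rVert_F^2 = \frac{1}{n^2} \sum_{i,j} \cos^2(c_i, c_j) - \frac{1}{n},$$

since every diagonal entry of $\hat C \hat C^\top$ is 1 after normalization.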

# distance types

class EuclideanCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        num_codebooks = 1,
        kmeans_init = False,
        kmeans_iters = 10,
        sync_kmeans = True,
        decay = 0.8,
        eps = 1e-5,
        threshold_ema_dead_code = 2,
        reset_cluster_size = None,
        use_ddp = False,
        learnable_codebook = False,
        gumbel_sample = gumbel_sample,
        sample_codebook_temp = 1.,
        ema_update = True,
        affine_param = False,
        sync_affine_param = False,
        affine_param_batch_decay = 0.99,
        affine_param_codebook_decay = 0.9
    ):
        # 调用父类的构造函数
        super().__init__()
        # 设置输入变换函数为恒等映射
        self.transform_input = identity

        # 设置衰减率和指数移动平均更新标志
        self.decay = decay
        self.ema_update = ema_update

        # 根据是否使用 kmeans 初始化选择初始化函数
        init_fn = uniform_init if not kmeans_init else torch.zeros
        # 初始化嵌入矩阵
        embed = init_fn(num_codebooks, codebook_size, dim)

        # 设置码书大小和码书数量
        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks

        # 设置 kmeans 迭代次数和阈值
        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)

        # 确保 gumbel_sample 是可调用的
        assert callable(gumbel_sample)
        self.gumbel_sample = gumbel_sample
        self.sample_codebook_temp = sample_codebook_temp

        # 检查是否在分布式环境中使用 kmeans 初始化
        assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now'

        # 根据是否使用分布式和同步 kmeans 选择采样函数
        self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
        self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
        self.all_reduce_fn = distributed.all_reduce if use_ddp else noop

        # 注册缓冲区
        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
        self.register_buffer('embed_avg', embed.clone())

        # 设置是否可学习码书
        self.learnable_codebook = learnable_codebook
        if learnable_codebook:
            self.embed = nn.Parameter(embed)
        else:
            self.register_buffer('embed', embed)

        # 仿射相关参数

        self.affine_param = affine_param
        self.sync_affine_param = sync_affine_param

        if not affine_param:
            return

        # 设置仿射参数的衰减率
        self.affine_param_batch_decay = affine_param_batch_decay
        self.affine_param_codebook_decay = affine_param_codebook_decay

        # 注册缓冲区
        self.register_buffer('batch_mean', None)
        self.register_buffer('batch_variance', None)

        self.register_buffer('codebook_mean_needs_init', torch.Tensor([True]))
        self.register_buffer('codebook_mean', torch.empty(num_codebooks, 1, dim))
        self.register_buffer('codebook_variance_needs_init', torch.Tensor([True]))
        self.register_buffer('codebook_variance', torch.empty(num_codebooks, 1, dim))

    @torch.jit.ignore
    def init_embed_(self, data, mask = None):
        # 如果已经初始化,则直接返回
        if self.initted:
            return

        # 如果存在掩码,则重新排列数据
        if exists(mask):
            c = data.shape[0]
            data = rearrange(data[mask], '(c n) d -> c n d', c = c)

        # 使用 kmeans 初始化 embed 和 cluster_size
        embed, cluster_size = kmeans(
            data,
            self.codebook_size,
            self.kmeans_iters,
            sample_fn = self.sample_fn,
            all_reduce_fn = self.kmeans_all_reduce_fn
        )

        embed_sum = embed * rearrange(cluster_size, '... -> ... 1')

        # 更新 embed 和 cluster_size
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed_sum)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    @torch.jit.ignore
    def update_with_decay(self, buffer_name, new_value, decay):
        # 获取旧值
        old_value = getattr(self, buffer_name)

        # 获取是否需要初始化的标志
        needs_init = getattr(self, buffer_name + "_needs_init", False)

        # 如果需要初始化,则更新标志
        if needs_init:
            self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False]))

        # 如果旧值不存在或需要初始化,则注册新值
        if not exists(old_value) or needs_init:
            self.register_buffer(buffer_name, new_value.detach())

            return

        # 更新值
        value = old_value * decay + new_value.detach() * (1 - decay)
        self.register_buffer(buffer_name, value)

    @torch.jit.ignore
    # 更新仿射变换参数,根据输入数据和嵌入向量,可选地使用掩码
    def update_affine(self, data, embed, mask = None):
        # 断言仿射参数已存在
        assert self.affine_param

        # 创建一个偏函数,用于计算方差
        var_fn = partial(torch.var, unbiased = False)

        # 计算码书均值和方差
        embed = rearrange(embed, 'h ... d -> h (...) d')

        # 如果处于训练模式
        if self.training:
            # 使用指数衰减更新码书均值
            self.update_with_decay('codebook_mean', reduce(embed, 'h n d -> h 1 d', 'mean'), self.affine_param_codebook_decay)
            # 使用指数衰减更新码书方差
            self.update_with_decay('codebook_variance', reduce(embed, 'h n d -> h 1 d', var_fn), self.affine_param_codebook_decay)

        # 准备批量数据,取决于是否有掩码
        data = rearrange(data, 'h ... d -> h (...) d')

        # 如果存在掩码
        if exists(mask):
            c = data.shape[0]
            data = rearrange(data[mask], '(c n) d -> c n d', c = c)

        # 计算批量均值和方差
        if not self.sync_affine_param:
            # 如果不同步仿射参数,使用指数衰减更新批量均值和方差
            self.update_with_decay('batch_mean', reduce(data, 'h n d -> h 1 d', 'mean'), self.affine_param_batch_decay)
            self.update_with_decay('batch_variance', reduce(data, 'h n d -> h 1 d', var_fn), self.affine_param_batch_decay)
            return

        # 计算分布式均值和方差
        num_vectors, device, dtype = data.shape[-2], data.device, data.dtype

        # 计算向量数量,用作分母
        num_vectors = torch.tensor([num_vectors], device = device, dtype = dtype)
        distributed.all_reduce(num_vectors)

        # 计算分布式均值
        batch_sum = reduce(data, 'h n d -> h 1 d', 'sum')
        distributed.all_reduce(batch_sum)
        batch_mean = batch_sum / num_vectors

        self.update_with_decay('batch_mean', batch_mean, self.affine_param_batch_decay)

        # 计算分布式方差
        variance_numer = reduce((data - batch_mean) ** 2, 'h n d -> h 1 d', 'sum')
        distributed.all_reduce(variance_numer)
        batch_variance = variance_numer / num_vectors

        self.update_with_decay('batch_variance', batch_variance, self.affine_param_batch_decay)

    # replace expired codes with freshly sampled vectors
    def replace(self, batch_samples, batch_mask):
        for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):
            if not torch.any(mask):
                continue

            # 从样本中采样新的码字
            sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
            sampled = rearrange(sampled, '1 ... -> ...')
            
            # 替换过期的码字
            self.embed.data[ind][mask] = sampled

            self.cluster_size.data[ind][mask] = self.reset_cluster_size
            self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size

    # 过期码字
    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        # 检查哪些码字过期
        expired_codes = self.cluster_size < self.threshold_ema_dead_code

        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
        self.replace(batch_samples, batch_mask = expired_codes)

    # 前向传播函数
    @autocast(enabled = False)
    def forward(
        self,
        x,
        sample_codebook_temp = None,
        mask = None,
        freeze_codebook = False
        ):
            # 检查输入张量的维度是否小于4
            needs_codebook_dim = x.ndim < 4
            # 如果sample_codebook_temp未指定,则使用默认值
            sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)

            # 将输入张量转换为浮点型
            x = x.float()

            # 如果需要增加codebook的维度
            if needs_codebook_dim:
                x = rearrange(x, '... -> 1 ...')

            # 获取输入张量的数据类型
            dtype = x.dtype
            # 将输入张量打包成一维数组,并返回打包后的数组和打包参数ps
            flatten, ps = pack_one(x, 'h * d')

            # 如果存在mask,则重复mask以匹配flatten的形状
            if exists(mask):
                mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1]))

            # 初始化嵌入层
            self.init_embed_(flatten, mask = mask)

            # 如果使用仿射参数
            if self.affine_param:
                # 更新仿射参数
                self.update_affine(flatten, self.embed, mask = mask)

            # 获取嵌入层,如果不可学习则使用detach
            embed = self.embed if self.learnable_codebook else self.embed.detach()

            # 如果使用仿射参数
            if self.affine_param:
                # 计算codebook的标准差和批次的标准差
                codebook_std = self.codebook_variance.clamp(min = 1e-5).sqrt()
                batch_std = self.batch_variance.clamp(min = 1e-5).sqrt()
                # 对嵌入层进行仿射变换
                embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean

            # 计算输入张量和嵌入层之间的距离
            dist = -cdist(flatten, embed)

            # 使用Gumbel采样获取嵌入层索引和独热编码
            embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)

            # 解包嵌入层索引
            embed_ind = unpack_one(embed_ind, ps, 'h *')

            # 如果处于训练状态
            if self.training:
                # 解包独热编码
                unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c')
                # 量化操作
                quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)
            else:
                # 批量嵌入操作
                quantize = batched_embedding(embed_ind, embed)

            # 如果处于训练状态且需要EMA更新且未冻结codebook
            if self.training and self.ema_update and not freeze_codebook:

                # 如果使用仿射参数
                if self.affine_param:
                    # 对输入张量进行仿射变换
                    flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean

                # 如果存在mask,则将未被mask覆盖的部分置零
                if exists(mask):
                    embed_onehot[~mask] = 0.

                # 计算聚类大小
                cluster_size = embed_onehot.sum(dim = 1)

                # 全局归约操作
                self.all_reduce_fn(cluster_size)
                # EMA更新聚类大小
                ema_inplace(self.cluster_size.data, cluster_size, self.decay)

                # 计算嵌入层总和
                embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
                embed_sum = embed_sum.contiguous()
                # 全局归约操作
                self.all_reduce_fn(embed_sum)

                # EMA更新嵌入层平均值
                ema_inplace(self.embed_avg.data, embed_sum, self.decay)

                # 对聚类大小进行拉普拉斯平滑
                cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)

                # 归一化嵌入层
                embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
                self.embed.data.copy_(embed_normalized)
                # 清除过时的code
                self.expire_codes_(x)

            # 如果需要增加codebook的维度
            if needs_codebook_dim:
                quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))

            # 解包距离
            dist = unpack_one(dist, ps, 'h * d')

            # 返回量化结果、嵌入层索引和距离
            return quantize, embed_ind, dist
class CosineSimCodebook(nn.Module):
    # 定义一个继承自 nn.Module 的类 CosineSimCodebook
    def __init__(
        self,
        dim,
        codebook_size,
        num_codebooks = 1,
        kmeans_init = False,
        kmeans_iters = 10,
        sync_kmeans = True,
        decay = 0.8,
        eps = 1e-5,
        threshold_ema_dead_code = 2,
        reset_cluster_size = None,
        use_ddp = False,
        learnable_codebook = False,
        gumbel_sample = gumbel_sample,
        sample_codebook_temp = 1.,
        ema_update = True
    ):
        # 初始化函数,接受多个参数
        super().__init__()
        # 调用父类的初始化函数

        self.transform_input = l2norm
        # 设置 transform_input 为 l2norm 函数

        self.ema_update = ema_update
        self.decay = decay
        # 设置 ema_update 和 decay 的值

        if not kmeans_init:
            embed = l2norm(uniform_init(num_codebooks, codebook_size, dim))
        else:
            embed = torch.zeros(num_codebooks, codebook_size, dim)
        # 根据 kmeans_init 的值初始化 embed

        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks
        # 设置 codebook_size 和 num_codebooks 的值

        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
        # 设置 kmeans_iters、eps、threshold_ema_dead_code 和 reset_cluster_size 的值

        assert callable(gumbel_sample)
        self.gumbel_sample = gumbel_sample
        self.sample_codebook_temp = sample_codebook_temp
        # 断言 gumbel_sample 是可调用的,设置 gumbel_sample 和 sample_codebook_temp 的值

        self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
        self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
        self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
        # 根据 use_ddp 和 sync_kmeans 的值选择 sample_fn、kmeans_all_reduce_fn 和 all_reduce_fn 的函数

        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
        self.register_buffer('embed_avg', embed.clone())
        # 注册缓冲区 initted、cluster_size 和 embed_avg

        self.learnable_codebook = learnable_codebook
        if learnable_codebook:
            self.embed = nn.Parameter(embed)
        else:
            self.register_buffer('embed', embed)
        # 设置 learnable_codebook 和 embed

    @torch.jit.ignore
    def init_embed_(self, data, mask = None):
        # 定义一个忽略 Torch JIT 的函数 init_embed_
        if self.initted:
            return
        # 如果已经初始化过,则直接返回

        if exists(mask):
            c = data.shape[0]
            data = rearrange(data[mask], '(c n) d -> c n d', c = c)
        # 如果 mask 存在,则重新排列数据

        embed, cluster_size = kmeans(
            data,
            self.codebook_size,
            self.kmeans_iters,
            use_cosine_sim = True,
            sample_fn = self.sample_fn,
            all_reduce_fn = self.kmeans_all_reduce_fn
        )
        # 使用 kmeans 函数初始化 embed 和 cluster_size

        embed_sum = embed * rearrange(cluster_size, '... -> ... 1')
        # 计算 embed_sum

        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed_sum)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))
        # copy the data into the corresponding buffers

    def replace(self, batch_samples, batch_mask):
        # 定义一个替换函数 replace
        batch_samples = l2norm(batch_samples)
        # 对 batch_samples 进行 l2norm 处理

        for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):
            # iterate over batch_samples and batch_mask
            if not torch.any(mask):
                continue
            # 如果 mask 中没有任何元素,则继续下一次循环

            sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
            sampled = rearrange(sampled, '1 ... -> ...')
            # 对样本进行采样和重新排列

            self.embed.data[ind][mask] = sampled
            self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
            self.cluster_size.data[ind][mask] = self.reset_cluster_size
            # 更新 embed、embed_avg 和 cluster_size

    def expire_codes_(self, batch_samples):
        # 定义一个过期代码的函数 expire_codes_
        if self.threshold_ema_dead_code == 0:
            return
        # 如果阈值为 0,则直接返回

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        # 计算过期代码

        if not torch.any(expired_codes):
            return
        # 如果没有过期代码,则直接返回

        batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
        self.replace(batch_samples, batch_mask = expired_codes)
        # 重新排列 batch_samples 并调用 replace 函数

    @autocast(enabled = False)
    def forward(
        self,
        x,
        sample_codebook_temp = None,
        mask = None,
        freeze_codebook = False
        # 定义前向传播函数 forward,接受多个参数
        ):
        # 检查输入张量的维度是否小于4,如果是则需要添加一个维度
        needs_codebook_dim = x.ndim < 4
        # 如果未指定sample_codebook_temp,则使用默认值
        sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)

        # 将输入张量转换为浮点型
        x = x.float()

        # 如果需要添加一个维度,则重新排列张量
        if needs_codebook_dim:
            x = rearrange(x, '... -> 1 ...')

        # 获取输入张量的数据类型
        dtype = x.dtype

        # 将输入张量打包成一维数组,并返回打包后的数组和打包方案
        flatten, ps = pack_one(x, 'h * d')

        # 如果存在掩码,则重复掩码以匹配打包后的数组形状
        if exists(mask):
            mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1]))

        # 初始化嵌入层,传入打包后的数组和掩码
        self.init_embed_(flatten, mask = mask)

        # 如果学习可学习码书,则使用可学习码书,否则使用固定码书
        embed = self.embed if self.learnable_codebook else self.embed.detach()

        # 计算嵌入距离
        dist = einsum('h n d, h c d -> h n c', flatten, embed)

        # 使用Gumbel采样获取嵌入索引和独热编码
        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
        # 解包嵌入索引
        embed_ind = unpack_one(embed_ind, ps, 'h *')

        # 如果处于训练状态
        if self.training:
            # 解包独热编码
            unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c')
            # 量化操作
            quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)
        else:
            # 使用批量嵌入获取量化结果
            quantize = batched_embedding(embed_ind, embed)

        # 如果处于训练状态且需要EMA更新且未冻结码书
        if self.training and self.ema_update and not freeze_codebook:
            # 如果存在掩码,则将掩码外的元素置零
            if exists(mask):
                embed_onehot[~mask] = 0.

            # 计算码书中每个码字的数量
            bins = embed_onehot.sum(dim = 1)
            self.all_reduce_fn(bins)

            # 更新EMA
            ema_inplace(self.cluster_size.data, bins, self.decay)

            # 计算码书的均值
            embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
            embed_sum = embed_sum.contiguous()
            self.all_reduce_fn(embed_sum)

            # 更新EMA
            ema_inplace(self.embed_avg.data, embed_sum, self.decay)

            # 对码书大小进行Laplace平滑
            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)

            # 归一化嵌入向量
            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
            embed_normalized = l2norm(embed_normalized)

            # 更新嵌入层参数
            self.embed.data.copy_(l2norm(embed_normalized))
            # 清除过时码字
            self.expire_codes_(x)

        # 如果需要添加一个维度,则重新排列量化结果和嵌入索引
        if needs_codebook_dim:
            quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))

        # 解包嵌入距离
        dist = unpack_one(dist, ps, 'h * d')
        # 返回量化结果、嵌入索引和嵌入距离
        return quantize, embed_ind, dist
# main class

class VectorQuantize(nn.Module):
    # constructor
    def __init__(
        self,
        dim,  # dimension of the input vectors
        codebook_size,  # number of codes in the codebook
        codebook_dim = None,  # dimension of the codebook (falls back to dim when None)
        heads = 1,  # number of heads
        separate_codebook_per_head = False,  # whether each head gets its own codebook
        decay = 0.8,  # decay rate of the exponential moving average
        eps = 1e-5,  # small value for numerical stability
        freeze_codebook = False,  # whether to freeze the codebook
        kmeans_init = False,  # whether to initialize the codebook with k-means
        kmeans_iters = 10,  # number of k-means iterations for initialization
        sync_kmeans = True,  # whether to synchronize the k-means init across processes
        use_cosine_sim = False,  # whether to use cosine similarity instead of euclidean distance
        threshold_ema_dead_code = 0,  # EMA cluster-size threshold below which a code counts as dead
        channel_last = True,  # whether the input is channel-last
        accept_image_fmap = False,  # whether to accept image feature maps
        commitment_weight = 1.,  # weight of the commitment loss
        commitment_use_cross_entropy_loss = False,  # whether to use a cross entropy commitment loss
        orthogonal_reg_weight = 0.,  # weight of the orthogonal regularization
        orthogonal_reg_active_codes_only = False,  # whether to regularize only the active codes
        orthogonal_reg_max_codes = None,  # maximum number of codes used for orthogonal regularization
        stochastic_sample_codes = False,  # whether to sample codes stochastically
        sample_codebook_temp = 1.,  # temperature used when sampling codes
        straight_through = False,  # whether to use the straight-through estimator
        reinmax = False,  # whether to use reinmax to improve the straight-through gradients
        sync_codebook = None,  # whether to synchronize the codebook across processes (auto-detected when None)
        sync_affine_param = False,  # whether to synchronize the affine parameters
        ema_update = True,  # whether to update the codebook with EMA
        learnable_codebook = False,  # whether the codebook is learnable via gradients
        in_place_codebook_optimizer: Callable[..., Optimizer] = None,  # optimizer used to update a learnable codebook in place
        affine_param = False,  # whether to use the affine reparameterization of the codebook
        affine_param_batch_decay = 0.99,  # decay for the batch statistics of the affine parameters
        affine_param_codebook_decay = 0.9,  # decay for the codebook statistics of the affine parameters
        sync_update_v = 0.  # the parameter v controlling optimistic vs pessimistic updates in the synchronous update rule
    ):
        # 调用父类的构造函数
        super().__init__()
        # 初始化模型的维度
        self.dim = dim
        # 初始化头数
        self.heads = heads
        # 是否为每个头单独使用码书
        self.separate_codebook_per_head = separate_codebook_per_head

        # 设置码书维度,默认为模型维度
        codebook_dim = default(codebook_dim, dim)
        # 计算码书输入维度
        codebook_input_dim = codebook_dim * heads

        # 判断是否需要投影
        requires_projection = codebook_input_dim != dim
        # 如果需要投影,则使用线性层进行投影,否则使用恒等映射
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

        # 记录是否有投影
        self.has_projections = requires_projection

        # 设置 epsilon
        self.eps = eps
        # 设置码书权重
        self.commitment_weight = commitment_weight
        # 是否使用交叉熵损失作为码书的约束损失
        self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss

        # 是否可学习的码书
        self.learnable_codebook = learnable_codebook

        # 是否有码书正交损失
        has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
        self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        # 检查是否EMA更新和可学习码书不兼容
        assert not (ema_update and learnable_codebook), 'learnable codebook not compatible with EMA update'

        # 检查同步更新参数的范围
        assert 0 <= sync_update_v <= 1.
        assert not (sync_update_v > 0. and not learnable_codebook), 'learnable codebook must be turned on'

        # 设置同步更新参数
        self.sync_update_v = sync_update_v

        # 根据是否使用余弦相似度选择码书类
        codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook

        # 部分应用函数,用于生成 Gumbel 样本
        gumbel_sample_fn = partial(
            gumbel_sample,
            stochastic = stochastic_sample_codes,
            reinmax = reinmax,
            straight_through = straight_through
        )

        # 如果未提供同步码书,则根据分布式环境设置同步码书
        if not exists(sync_codebook):
            sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1

        # 设置码书参数
        codebook_kwargs = dict(
            dim = codebook_dim,
            num_codebooks = heads if separate_codebook_per_head else 1,
            codebook_size = codebook_size,
            kmeans_init = kmeans_init,
            kmeans_iters = kmeans_iters,
            sync_kmeans = sync_kmeans,
            decay = decay,
            eps = eps,
            threshold_ema_dead_code = threshold_ema_dead_code,
            use_ddp = sync_codebook,
            learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook,
            sample_codebook_temp = sample_codebook_temp,
            gumbel_sample = gumbel_sample_fn,
            ema_update = ema_update
        )

        # 如果使用仿射参数,则更新码书参数
        if affine_param:
            assert not use_cosine_sim, 'affine param is only compatible with euclidean codebook'
            codebook_kwargs = dict(
                **codebook_kwargs,
                affine_param = True,
                sync_affine_param = sync_affine_param,
                affine_param_batch_decay = affine_param_batch_decay,
                affine_param_codebook_decay = affine_param_codebook_decay,
            )

        # 初始化码书对象
        self._codebook = codebook_class(**codebook_kwargs)

        # 如果存在码书优化器,则初始化
        self.in_place_codebook_optimizer = in_place_codebook_optimizer(self._codebook.parameters()) if exists(in_place_codebook_optimizer) else None

        # 设置码书大小
        self.codebook_size = codebook_size

        # 是否接受图像特征图
        self.accept_image_fmap = accept_image_fmap
        # 是否通道在最后
        self.channel_last = channel_last

    @property
    def codebook(self):
        # 获取码书
        codebook = self._codebook.embed

        # 如果每个头单独使用码书,则直接返回码书
        if self.separate_codebook_per_head:
            return codebook

        # 否则重新排列码书维度
        return rearrange(codebook, '1 ... -> ...')

    @codebook.setter
    def codebook(self, codes):
        # 如果不是每个头单独使用码书,则重新排列码书维度
        if not self.separate_codebook_per_head:
            codes = rearrange(codes, '... -> 1 ...')

        # 将码书赋值给内部码书对象
        self._codebook.embed.copy_(codes)
    # 从给定的索引中获取对应的编码
    def get_codes_from_indices(self, indices):
        # 获取编码簿
        codebook = self.codebook
        # 判断是否为多头编码
        is_multiheaded = codebook.ndim > 2

        # 如果不是多头编码
        if not is_multiheaded:
            # 从编码簿中获取对应索引的编码
            codes = codebook[indices]
        else:
            # 打包索引
            indices, ps = pack_one(indices, 'b * h')
            # 重新排列索引
            indices = rearrange(indices, 'b n h -> b h n')

            # 重复索引
            indices = repeat(indices, 'b h n -> b h n d', d = codebook.shape[-1])
            # 重复编码簿
            codebook = repeat(codebook, 'h n d -> b h n d', b = indices.shape[0])

            # 从编码簿中收集编码
            codes = codebook.gather(2, indices)
            # 重新排列编码
            codes = rearrange(codes, 'b h n d -> b n (h d)')
            # 解包编码
            codes = unpack_one(codes, ps, 'b * d')

        # 如果不是通道在最后
        if not self.channel_last:
            # 重新排列编码
            codes = rearrange(codes, 'b ... d -> b d ...')

        # 返回编码
        return codes

    # 从给定的索引中获取输出
    def get_output_from_indices(self, indices):
        # 获取编码
        codes = self.get_codes_from_indices(indices)
        # 对编码进行投影
        return self.project_out(codes)

    # 前向传播函数
    def forward(
        self,
        x,
        indices = None,
        mask = None,
        sample_codebook_temp = None,
        freeze_codebook = False
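For orientation, a minimal usage sketch of the VectorQuantize module defined above (the argument values are illustrative):

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    decay = 0.8,             # EMA decay for the codebook
    commitment_weight = 1.   # weight of the commitment loss
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
# quantized: (1, 1024, 256), indices: (1, 1024), commit_loss: a single-element tensor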

.\lucidrains\vector-quantize-pytorch\vector_quantize_pytorch\__init__.py

# expose the VectorQuantize class
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
# expose the ResidualVQ and GroupedResidualVQ classes
from vector_quantize_pytorch.residual_vq import ResidualVQ, GroupedResidualVQ
# expose the RandomProjectionQuantizer class
from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
# expose the FSQ (finite scalar quantization) class
from vector_quantize_pytorch.finite_scalar_quantization import FSQ
# expose the LFQ (lookup-free quantization) class
from vector_quantize_pytorch.lookup_free_quantization import LFQ
# expose the ResidualLFQ and GroupedResidualLFQ classes
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ
# expose the ResidualFSQ and GroupedResidualFSQ classes
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ

.\lucidrains\video-diffusion-pytorch\README.md

machine imagined fireworks

these fireworks do not exist

Video Diffusion - Pytorch

Text to video, it is happening! Official Project Page

Implementation of Video Diffusion Models, Jonathan Ho's new paper extending DDPMs to Video Generation - in Pytorch. It uses a special space-time factored U-net, extending generation from 2d images to 3d videos

Status

14k for difficult moving mnist (converging much faster and better than NUWA) - wip

The above experiments are possible only due to resources provided by Stability.ai

Any new developments for text-to-video synthesis will be centralized at Imagen-pytorch

Install

$ pip install video-diffusion-pytorch

Usage

import torch
from video_diffusion_pytorch import Unet3D, GaussianDiffusion

model = Unet3D(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    num_frames = 5,
    timesteps = 1000,   # number of steps
    loss_type = 'l1'    # L1 or L2
)

videos = torch.randn(1, 3, 5, 32, 32) # video (batch, channels, frames, height, width) - normalized from -1 to +1
loss = diffusion(videos)
loss.backward()
# after a lot of training

sampled_videos = diffusion.sample(batch_size = 4)
sampled_videos.shape # (4, 3, 5, 32, 32)

For conditioning on text, they derived text embeddings by first passing the tokenized text through BERT-large. Then you just have to train it like so

import torch
from video_diffusion_pytorch import Unet3D, GaussianDiffusion

model = Unet3D(
    dim = 64,
    cond_dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    num_frames = 5,
    timesteps = 1000,   # number of steps
    loss_type = 'l1'    # L1 or L2
)

videos = torch.randn(2, 3, 5, 32, 32) # video (batch, channels, frames, height, width)
text = torch.randn(2, 64)             # assume output of BERT-large has dimension of 64

loss = diffusion(videos, cond = text)
loss.backward()
# after a lot of training

sampled_videos = diffusion.sample(cond = text)
sampled_videos.shape # (2, 3, 5, 32, 32)

You can also directly pass in the descriptions of the video as strings, if you plan on using BERT-base for text conditioning

import torch
from video_diffusion_pytorch import Unet3D, GaussianDiffusion

model = Unet3D(
    dim = 64,
    use_bert_text_cond = True,  # this must be set to True to auto-use the bert model dimensions
    dim_mults = (1, 2, 4, 8),
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,    # height and width of frames
    num_frames = 5,     # number of video frames
    timesteps = 1000,   # number of steps
    loss_type = 'l1'    # L1 or L2
)

videos = torch.randn(3, 3, 5, 32, 32) # video (batch, channels, frames, height, width)

text = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
]

loss = diffusion(videos, cond = text)
loss.backward()
# after a lot of training

sampled_videos = diffusion.sample(cond = text, cond_scale = 2)
sampled_videos.shape # (3, 3, 5, 32, 32)

Training

This repository also contains a handy Trainer class for training on a folder of gifs. Each gif must be of the correct dimensions image_size and num_frames.

import torch
from video_diffusion_pytorch import Unet3D, GaussianDiffusion, Trainer

model = Unet3D(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
)

diffusion = GaussianDiffusion(
    model,
    image_size = 64,
    num_frames = 10,
    timesteps = 1000,   # number of steps
    loss_type = 'l1'    # L1 or L2
).cuda()

trainer = Trainer(
    diffusion,
    './data',                         # this folder path needs to contain all your training data, as .gif files, of correct image size and number of frames
    train_batch_size = 32,
    train_lr = 1e-4,
    save_and_sample_every = 1000,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True                        # turn on mixed precision
)

trainer.train()

Sample videos (as gif files) will be saved to ./results periodically, as are the diffusion model parameters.

Co-training Images and Video

One of the claims in the paper is that by doing factored space-time attention, one can force the network to attend on the present for training images and video in conjunction, leading to better results.

It was not clear how they achieved this, but I furthered a guess.

To arrest attention to the present moment for a certain percentage of batch videos samples, simply pass prob_focus_present = <prob> on the diffusion forward method

loss = diffusion(videos, cond = text, prob_focus_present = 0.5) # for 50% of videos, focus on the present during training
loss.backward()

If you have a better idea how this is done, just open a github issue.

Todo

Citations

@misc{ho2022video,
  title   = {Video Diffusion Models}, 
  author  = {Jonathan Ho and Tim Salimans and Alexey Gritsenko and William Chan and Mohammad Norouzi and David J. Fleet},
  year    = {2022},
  eprint  = {2204.03458},
  archivePrefix = {arXiv},
  primaryClass = {cs.CV}
}
@misc{Saharia2022,
    title   = {Imagen: unprecedented photorealism × deep level of language understanding},
    author  = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
    year    = {2022}
}

.\lucidrains\video-diffusion-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
    name='video-diffusion-pytorch',  # 包名
    packages=find_packages(exclude=[]),  # 查找所有包
    version='0.6.3',  # 版本号
    license='MIT',  # 许可证
    description='Video Diffusion - Pytorch',  # 描述
    long_description_content_type='text/markdown',  # 长描述内容类型
    author='Phil Wang',  # 作者
    author_email='lucidrains@gmail.com',  # 作者邮箱
    url='https://github.com/lucidrains/video-diffusion-pytorch',  # 项目链接
    keywords=[  # 关键词列表
        'artificial intelligence',
        'deep learning',
        'denoising diffusion probabilistic models',
        'video generation'
    ],
    install_requires=[  # 依赖的包列表
        'einops>=0.4',
        'einops-exts',
        'rotary-embedding-torch',
        'sacremoses',
        'sentencepiece',
        'torch>=1.10',
        'torchvision',
        'transformers[torch]',
        'tqdm'
    ],
    classifiers=[  # 分类器列表
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)

.\lucidrains\video-diffusion-pytorch\video_diffusion_pytorch\text.py

# 导入 torch 库
import torch
# 从 einops 库中导入 rearrange 函数
from einops import rearrange

# 检查变量是否存在的函数
def exists(val):
    return val is not None

# 全局单例变量

# 模型和分词器的初始化为 None
MODEL = None
TOKENIZER = None
# BERT 模型的维度为 768
BERT_MODEL_DIM = 768

# 获取分词器函数
def get_tokenizer():
    global TOKENIZER
    # 如果 TOKENIZER 不存在,则加载 'bert-base-cased' 模型的分词器
    if not exists(TOKENIZER):
        TOKENIZER = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased')
    return TOKENIZER

# 获取 BERT 模型函数
def get_bert():
    global MODEL
    # 如果 MODEL 不存在,则加载 'bert-base-cased' 模型
    if not exists(MODEL):
        MODEL = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-cased')
        # 如果有 GPU 可用,则将模型移动到 GPU
        if torch.cuda.is_available():
            MODEL = MODEL.cuda()

    return MODEL

# 分词函数

def tokenize(texts, add_special_tokens = True):
    # 如果 texts 不是列表或元组,则转换为列表
    if not isinstance(texts, (list, tuple)):
        texts = [texts]

    # 获取分词器
    tokenizer = get_tokenizer()

    # 对文本进行编码
    encoding = tokenizer.batch_encode_plus(
        texts,
        add_special_tokens = add_special_tokens,
        padding = True,
        return_tensors = 'pt'
    )

    # 获取 token_ids
    token_ids = encoding.input_ids
    return token_ids

# 嵌入函数

@torch.no_grad()
def bert_embed(
    token_ids,
    return_cls_repr = False,
    eps = 1e-8,
    pad_id = 0.
):
    # 获取 BERT 模型
    model = get_bert()
    # 创建 mask,标记不为 pad_id 的位置
    mask = token_ids != pad_id

    # 如果有 GPU 可用,则将 token_ids 和 mask 移动到 GPU
    if torch.cuda.is_available():
        token_ids = token_ids.cuda()
        mask = mask.cuda()

    # 使用 BERT 模型进行前向传播
    outputs = model(
        input_ids = token_ids,
        attention_mask = mask,
        output_hidden_states = True
    )

    # 获取最后一层的隐藏状态
    hidden_state = outputs.hidden_states[-1]

    # 如果需要返回 [cls] 的表示
    if return_cls_repr:
        return hidden_state[:, 0]               # 返回 [cls] 作为表示

    # 如果 mask 不存在,则返回所有 token 的平均值
    if not exists(mask):
        return hidden_state.mean(dim = 1)

    # 重新定义 mask,排除 [cls],考虑长度
    mask = mask[:, 1:]                          # 平均所有 token,排除 [cls]
    mask = rearrange(mask, 'b n -> b n 1')

    # 计算加权平均值
    numer = (hidden_state[:, 1:] * mask).sum(dim = 1)
    denom = mask.sum(dim = 1)
    masked_mean =  numer / (denom + eps)
    return masked_mean
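
下面给出一个最小的使用示意(非原文件内容,文本与变量名仅作演示,并假设运行环境可以通过 torch.hub 下载 'bert-base-cased'):

texts = ['a cat playing piano', 'a dog running in the park']
token_ids = tokenize(texts)                                   # (2, n),已按最长文本自动 padding
embeds = bert_embed(token_ids)                                # (2, 768),对非 padding token 做掩码平均
cls_embeds = bert_embed(token_ids, return_cls_repr = True)    # (2, 768),直接取 [CLS] 的表示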

.\lucidrains\video-diffusion-pytorch\video_diffusion_pytorch\video_diffusion_pytorch.py

# 导入数学库
import math
# 导入拷贝库
import copy
# 导入 torch 库
import torch
# 从 torch 库中导入 nn, einsum 模块
from torch import nn, einsum
# 从 torch 库中导入 F 模块
import torch.nn.functional as F
# 从 functools 库中导入 partial 函数
from functools import partial
# 从 torch.utils 库中导入 data 模块
from torch.utils import data
# 从 pathlib 库中导入 Path 类
from pathlib import Path
# 从 torch.optim 库中导入 Adam 优化器
from torch.optim import Adam
# 从 torchvision 库中导入 transforms, utils 模块
from torchvision import transforms as T, utils
# 从 torch.cuda.amp 库中导入 autocast, GradScaler 模块
from torch.cuda.amp import autocast, GradScaler
# 从 PIL 库中导入 Image 类
from PIL import Image
# 从 tqdm 库中导入 tqdm 函数
from tqdm import tqdm
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 从 einops_exts 库中导入 check_shape, rearrange_many 函数
from einops_exts import check_shape, rearrange_many
# 从 rotary_embedding_torch 库中导入 RotaryEmbedding 类
from rotary_embedding_torch import RotaryEmbedding
# 从 video_diffusion_pytorch.text 模块中导入 tokenize, bert_embed 函数和 BERT_MODEL_DIM 常量
from video_diffusion_pytorch.text import tokenize, bert_embed, BERT_MODEL_DIM

# 辅助函数

# 判断变量是否存在
def exists(x):
    return x is not None

# 空操作函数
def noop(*args, **kwargs):
    pass

# 判断一个数是否为奇数
def is_odd(n):
    return (n % 2) == 1

# 返回默认值
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# 无限循环生成器
def cycle(dl):
    while True:
        for data in dl:
            yield data

# 将一个数分成若干组
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr
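
一个简单的示意(非原文件内容):训练器在采样时用它把总样本数拆成若干个不超过 batch_size 的批次。

assert num_to_groups(25, 16) == [16, 9]   # 先填满整批,余数单独成最后一组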

# 生成概率掩码
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob

# 判断列表中是否全为字符串
def is_list_str(x):
    if not isinstance(x, (list, tuple)):
        return False
    return all([type(el) == str for el in x])

# 相对位置偏置

class RelativePositionBias(nn.Module):
    def __init__(
        self,
        heads=8,
        num_buckets=32,
        max_distance=128
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
        ret = 0
        n = -relative_position

        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, n, device):
        q_pos = torch.arange(n, dtype=torch.long, device=device)
        k_pos = torch.arange(n, dtype=torch.long, device=device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        return rearrange(values, 'i j h -> h i j')
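
一个形状示意(非原文件内容):该模块按 T5 风格的分桶方式,为时间维度上的每一对位置产生逐头的偏置,结果可直接加到注意力相似度矩阵上。

import torch

rel_pos_bias = RelativePositionBias(heads = 8)
bias = rel_pos_bias(16, device = torch.device('cpu'))   # (8, 16, 16):头数 x 帧数 x 帧数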

# 小助手模块

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    # 定义一个前向传播函数,接受输入张量 x
    def forward(self, x):
        # 获取输入张量 x 的设备信息
        device = x.device
        # 计算嵌入维度的一半
        half_dim = self.dim // 2
        # 计算嵌入的指数
        emb = math.log(10000) / (half_dim - 1)
        # 计算嵌入的指数值
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # 计算嵌入矩阵
        emb = x[:, None] * emb[None, :]
        # 将正弦和余弦值拼接在一起,形成最终的嵌入矩阵
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        # 返回嵌入矩阵
        return emb
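
一个形状示意(非原文件内容):把一批标量时间步映射为正弦位置嵌入。

import torch

pos_emb = SinusoidalPosEmb(dim = 64)
t = torch.randint(0, 1000, (4,))
emb = pos_emb(t)   # (4, 64),前一半为 sin,后一半为 cos
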
# 定义一个上采样函数,使用 ConvTranspose3d 实现
def Upsample(dim):
    return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))

# 定义一个下采样函数,使用 Conv3d 实现
def Downsample(dim):
    return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))

# 定义 LayerNorm 类,用于实现层归一化
class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        # 计算输入张量 x 的方差
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        # 计算输入张量 x 的均值
        mean = torch.mean(x, dim = 1, keepdim = True)
        # 返回经过层归一化处理后的结果
        return (x - mean) / (var + self.eps).sqrt() * self.gamma

# 定义 PreNorm 类,结合层归一化和函数 fn 的处理
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x, **kwargs):
        # 对输入张量 x 进行层归一化处理
        x = self.norm(x)
        # 返回经过函数 fn 处理后的结果
        return self.fn(x, **kwargs)

# 构建基础模块

# 定义 Block 类,包含投影、归一化和激活函数
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding = (0, 1, 1))
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        # 投影操作
        x = self.proj(x)
        # 归一化操作
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            # 应用缩放和平移
            x = x * (scale + 1) + shift

        return self.act(x)

# 定义 ResnetBlock 类,包含 MLP、Block 和残差连接
class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):

        scale_shift = None
        if exists(self.mlp):
            assert exists(time_emb), 'time emb must be passed in'
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift)

        h = self.block2(h)
        return h + self.res_conv(x)
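
一个前向形状示意(非原文件内容,数值仅作演示):条件向量经 MLP 变成按通道的 scale 和 shift,注入到第一个 Block 中。

import torch

block = ResnetBlock(64, 128, time_emb_dim = 256)
x = torch.randn(1, 64, 16, 32, 32)   # (b, c, f, h, w)
t = torch.randn(1, 256)              # 时间(或时间+文本)条件向量
out = block(x, t)                    # (1, 128, 16, 32, 32)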

# 定义 SpatialLinearAttention 类,包含注意力计算和输出转换
class SpatialLinearAttention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, f, h, w = x.shape
        x = rearrange(x, 'b c f h w -> (b f) c h w')

        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = rearrange_many(qkv, 'b (h c) x y -> b h c (x y)', h = self.heads)

        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        out = self.to_out(out)
        return rearrange(out, '(b f) c h w -> b c f h w', b = b)
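
一个形状示意(非原文件内容):该模块先把帧维度并入 batch,只在空间维度 (h, w) 上做线性注意力,计算量随像素数线性增长。

import torch

attn = SpatialLinearAttention(dim = 64, heads = 4)
x = torch.randn(1, 64, 10, 32, 32)   # (b, c, f, h, w)
out = attn(x)                        # (1, 64, 10, 32, 32)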

# 定义 EinopsToAndFrom 类,用于实现输入输出形状的转换
class EinopsToAndFrom(nn.Module):
    def __init__(self, from_einops, to_einops, fn):
        super().__init__()
        self.from_einops = from_einops
        self.to_einops = to_einops
        self.fn = fn

    def forward(self, x, **kwargs):
        shape = x.shape
        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))
        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
        x = self.fn(x, **kwargs)
        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
        return x

# 定义 Attention 类
class Attention(nn.Module):
    # 初始化函数,设置模型参数
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32,
        rotary_emb = None
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 计算缩放因子
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        # 初始化旋转嵌入和线性变换层
        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias = False)
        self.to_out = nn.Linear(hidden_dim, dim, bias = False)

    # 前向传播函数
    def forward(
        self,
        x,
        pos_bias = None,
        focus_present_mask = None
    ):
        # 获取输入张量的维度和设备信息
        n, device = x.shape[-2], x.device

        # 将输入张量通过线性变换层得到查询、键、值
        qkv = self.to_qkv(x).chunk(3, dim = -1)

        # 如果存在焦点存在掩码并且所有焦点都存在,则直接输出值
        if exists(focus_present_mask) and focus_present_mask.all():
            values = qkv[-1]
            return self.to_out(values)

        # 将查询、键、值按头数拆分
        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h = self.heads)

        # 缩放查询
        q = q * self.scale

        # 将位置旋转到查询和键中以进行时间注意力
        if exists(self.rotary_emb):
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        # 计算相似度
        sim = einsum('... h i d, ... h j d -> ... h i j', q, k)

        # 添加相对位置偏置
        if exists(pos_bias):
            sim = sim + pos_bias

        # 如果存在焦点存在掩码并且不是所有焦点都存在,则进行掩码处理
        if exists(focus_present_mask) and not (~focus_present_mask).all():
            attend_all_mask = torch.ones((n, n), device = device, dtype = torch.bool)
            attend_self_mask = torch.eye(n, device = device, dtype = torch.bool)

            mask = torch.where(
                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
                rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
                rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
            )

            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # 数值稳定性处理
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)

        # 聚合值
        out = einsum('... h i j, ... h j d -> ... h i d', attn, v)
        out = rearrange(out, '... h n d -> ... n (h d)')
        return self.to_out(out)
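
把 EinopsToAndFrom 和 Attention 组合起来,就得到 Unet3D 中使用的时间注意力:特征图先被重排成 'b (h w) f c',注意力只沿帧维度 f 进行,最后再还原形状。下面是一个形状示意(非原文件内容):

import torch

temporal_attn = EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention(dim = 64, heads = 4, dim_head = 32))
x = torch.randn(1, 64, 10, 8, 8)   # (b, c, f, h, w)
out = temporal_attn(x)             # (1, 64, 10, 8, 8),每个空间位置独立地在 10 帧之间做注意力
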
# 定义一个名为Unet3D的类,继承自nn.Module
class Unet3D(nn.Module):
    # 初始化函数,接受多个参数
    def __init__(
        self,
        dim,  # 输入数据的维度
        cond_dim = None,  # 条件数据的维度,默认为None
        out_dim = None,  # 输出数据的维度,默认为None
        dim_mults=(1, 2, 4, 8),  # 每个层级的维度倍增系数
        channels = 3,  # 输入数据的通道数,默认为3
        attn_heads = 8,  # 注意力头的数量,默认为8
        attn_dim_head = 32,  # 每个注意力头的维度,默认为32
        use_bert_text_cond = False,  # 是否使用BERT文本条件,默认为False
        init_dim = None,  # 初始化维度,默认为None
        init_kernel_size = 7,  # 初始化卷积核大小,默认为7
        use_sparse_linear_attn = True,  # 是否使用稀疏线性注意力,默认为True
        block_type = 'resnet',  # 块类型,默认为'resnet'
        resnet_groups = 8  # ResNet块的数量,默认为8
    ):
        # 调用父类的构造函数
        super().__init__()
        # 设置通道数
        self.channels = channels

        # 时间注意力和其相对位置编码

        # 创建旋转嵌入对象
        rotary_emb = RotaryEmbedding(min(32, attn_dim_head))

        # 定义时间注意力函数
        temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention(dim, heads = attn_heads, dim_head = attn_dim_head, rotary_emb = rotary_emb))

        # 创建相对位置偏置对象
        self.time_rel_pos_bias = RelativePositionBias(heads = attn_heads, max_distance = 32) # 现实中不太可能生成那么多帧的视频……至少目前如此

        # 初始卷积

        # 初始化维度
        init_dim = default(init_dim, dim)
        assert is_odd(init_kernel_size)

        init_padding = init_kernel_size // 2
        # 创建初始卷积层
        self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, init_kernel_size), padding = (0, init_padding, init_padding))

        # 创建初始时间注意力层
        self.init_temporal_attn = Residual(PreNorm(init_dim, temporal_attn(init_dim)))

        # 维度

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # 时间条件

        time_dim = dim * 4
        # 创建时间 MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # 文本条件

        self.has_cond = exists(cond_dim) or use_bert_text_cond
        cond_dim = BERT_MODEL_DIM if use_bert_text_cond else cond_dim

        self.null_cond_emb = nn.Parameter(torch.randn(1, cond_dim)) if self.has_cond else None

        cond_dim = time_dim + int(cond_dim or 0)

        # 层

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        num_resolutions = len(in_out)

        # 块类型

        block_klass = partial(ResnetBlock, groups = resnet_groups)
        block_klass_cond = partial(block_klass, time_emb_dim = cond_dim)

        # 所有层的模块

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                block_klass_cond(dim_in, dim_out),
                block_klass_cond(dim_out, dim_out),
                Residual(PreNorm(dim_out, SpatialLinearAttention(dim_out, heads = attn_heads))) if use_sparse_linear_attn else nn.Identity(),
                Residual(PreNorm(dim_out, temporal_attn(dim_out))),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        mid_dim = dims[-1]
        # 创建中间块1
        self.mid_block1 = block_klass_cond(mid_dim, mid_dim)

        spatial_attn = EinopsToAndFrom('b c f h w', 'b f (h w) c', Attention(mid_dim, heads = attn_heads))

        # 创建中间空间注意力层
        self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn))
        # 创建中间时间注意力层
        self.mid_temporal_attn = Residual(PreNorm(mid_dim, temporal_attn(mid_dim)))

        # 创建中间块2
        self.mid_block2 = block_klass_cond(mid_dim, mid_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                block_klass_cond(dim_out * 2, dim_in),
                block_klass_cond(dim_in, dim_in),
                Residual(PreNorm(dim_in, SpatialLinearAttention(dim_in, heads = attn_heads))) if use_sparse_linear_attn else nn.Identity(),
                Residual(PreNorm(dim_in, temporal_attn(dim_in))),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        out_dim = default(out_dim, channels)
        # 创建最终卷积层
        self.final_conv = nn.Sequential(
            block_klass(dim * 2, dim),
            nn.Conv3d(dim, out_dim, 1)
        )

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 2.,
        **kwargs
    ):
        # 根据给定参数计算输出 logits,并按 cond_scale 在有条件与无条件输出之间做 classifier-free guidance 外推
        logits = self.forward(*args, null_cond_prob = 0., **kwargs)
        if cond_scale == 1 or not self.has_cond:
            return logits

        null_logits = self.forward(*args, null_cond_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        x,
        time,
        cond = None,
        null_cond_prob = 0.,
        focus_present_mask = None,
        prob_focus_present = 0.  # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)
    ):
        # 检查是否存在条件 cond,如果 cond_dim 被指定,则必须传入 cond
        assert not (self.has_cond and not exists(cond)), 'cond must be passed in if cond_dim specified'
        # 获取输入 x 的 batch 大小和设备信息
        batch, device = x.shape[0], x.device

        # 如果未提供 focus_present_mask,则根据概率 prob_focus_present 创建一个概率掩码
        focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device = device))

        # 根据输入 x 的时间维度创建时间相对位置偏置
        time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device = x.device)

        # 对输入 x 进行初始卷积操作
        x = self.init_conv(x)

        # 对输入 x 进行初始时间注意力操作,使用时间相对位置偏置
        x = self.init_temporal_attn(x, pos_bias = time_rel_pos_bias)

        # 克隆输入 x,用于后续操作
        r = x.clone()

        # 如果存在时间多层感知机,则计算时间多层感知机的输出
        t = self.time_mlp(time) if exists(self.time_mlp) else None

        # 如果模型具有条件 cond,则进行条件处理
        if self.has_cond:
            # 重新获取 batch 和设备信息
            batch, device = x.shape[0], x.device
            # 根据 null_cond_prob 创建一个概率掩码
            mask = prob_mask_like((batch,), null_cond_prob, device = device)
            # 如果掩码为真,则使用 null_cond_emb 替换 cond
            cond = torch.where(rearrange(mask, 'b -> b 1'), self.null_cond_emb, cond)
            # 将 cond 与 t 连接在一起
            t = torch.cat((t, cond), dim = -1)

        # 初始化一个空列表 h,用于存储中间结果
        h = []

        # 遍历下采样模块,依次进行操作
        for block1, block2, spatial_attn, temporal_attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = spatial_attn(x)
            x = temporal_attn(x, pos_bias = time_rel_pos_bias, focus_present_mask = focus_present_mask)
            h.append(x)
            x = downsample(x)

        # 中间块1操作
        x = self.mid_block1(x, t)
        x = self.mid_spatial_attn(x)
        x = self.mid_temporal_attn(x, pos_bias = time_rel_pos_bias, focus_present_mask = focus_present_mask)
        x = self.mid_block2(x, t)

        # 遍历上采样模块,依次进行操作
        for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)
            x = block2(x, t)
            x = spatial_attn(x)
            x = temporal_attn(x, pos_bias = time_rel_pos_bias, focus_present_mask = focus_present_mask)
            x = upsample(x)

        # 将最终输出与 r 进行连接,并返回最终卷积结果
        x = torch.cat((x, r), dim = 1)
        return self.final_conv(x)
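
Unet3D 是一个输入输出同形状的去噪网络:给定 (b, c, f, h, w) 的视频张量和每个样本的扩散时间步,输出对噪声的预测。下面是一个最小的前向示意(非原文件内容,参数仅作演示;三次下采样要求空间尺寸能被 8 整除):

import torch

unet = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8))
video = torch.randn(2, 3, 5, 32, 32)    # (b, c, f, h, w)
time = torch.randint(0, 1000, (2,))     # 每个样本一个扩散时间步
pred = unet(video, time)                # (2, 3, 5, 32, 32)
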
# gaussian diffusion trainer class

# 从输入张量 a 中根据索引张量 t 提取对应元素,然后重塑形状
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

# 根据给定的时间步数生成余弦调度表
def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.9999)
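
一个简单的数值示意(非原文件内容):余弦调度生成的 beta 序列起始处远小于 1e-3,随时间步逐步增大,最后一个值被截断到 0.9999。

betas = cosine_beta_schedule(1000)
print(betas.shape)              # torch.Size([1000])
print(betas[0], betas[-1])      # 起始接近 0,末端为 0.9999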

# 定义 GaussianDiffusion 类
class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        denoise_fn,
        *,
        image_size,
        num_frames,
        text_use_bert_cls = False,
        channels = 3,
        timesteps = 1000,
        loss_type = 'l1',
        use_dynamic_thres = False, # from the Imagen paper
        dynamic_thres_percentile = 0.9
    ):
        super().__init__()
        self.channels = channels
        self.image_size = image_size
        self.num_frames = num_frames
        self.denoise_fn = denoise_fn

        # 生成余弦调度表
        betas = cosine_beta_schedule(timesteps)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # 注册缓冲区辅助函数,将 float64 类型转换为 float32 类型
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # 计算扩散 q(x_t | x_{t-1}) 和其他参数

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # 计算后验 q(x_{t-1} | x_t, x_0) 参数

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # 上述:等于 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)

        # 下面:对后验方差进行对数计算,因为扩散链的开始处后验方差为 0

        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # 文本条件参数

        self.text_use_bert_cls = text_use_bert_cls

        # 在采样时使用动态阈值

        self.use_dynamic_thres = use_dynamic_thres
        self.dynamic_thres_percentile = dynamic_thres_percentile

    # 计算 q(x_t | x_{t-1}) 的均值和方差
    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    # 从噪声中预测起始点
    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )
    # 计算后验分布的均值、方差和截断后的对数方差
    def q_posterior(self, x_start, x_t, t):
        # 计算后验分布的均值
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        # 计算后验分布的方差
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        # 获取截断后的对数方差
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    # 计算模型的均值、方差和截断后的对数方差
    def p_mean_variance(self, x, t, clip_denoised: bool, cond = None, cond_scale = 1.):
        # 从噪声中预测起始值
        x_recon = self.predict_start_from_noise(x, t=t, noise = self.denoise_fn.forward_with_cond_scale(x, t, cond = cond, cond_scale = cond_scale))

        if clip_denoised:
            s = 1.
            if self.use_dynamic_thres:
                # 计算动态阈值
                s = torch.quantile(
                    rearrange(x_recon, 'b ... -> b (...)').abs(),
                    self.dynamic_thres_percentile,
                    dim = -1
                )

                s.clamp_(min = 1.)
                s = s.view(-1, *((1,) * (x_recon.ndim - 1)))

            # 根据阈值截断,取决于是静态还是动态
            x_recon = x_recon.clamp(-s, s) / s

        # 获取模型的均值、后验方差和后验对数方差
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    # 生成样本
    @torch.inference_mode()
    def p_sample(self, x, t, cond = None, cond_scale = 1., clip_denoised = True):
        b, *_, device = *x.shape, x.device
        # 获取模型的均值、方差和对数方差
        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, clip_denoised = clip_denoised, cond = cond, cond_scale = cond_scale)
        noise = torch.randn_like(x)
        # 当 t == 0 时不添加噪声
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    # 循环生成样本
    @torch.inference_mode()
    def p_sample_loop(self, shape, cond = None, cond_scale = 1.):
        device = self.betas.device

        b = shape[0]
        img = torch.randn(shape, device=device)

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), cond = cond, cond_scale = cond_scale)

        return unnormalize_img(img)

    # 生成样本
    @torch.inference_mode()
    def sample(self, cond = None, cond_scale = 1., batch_size = 16):
        device = next(self.denoise_fn.parameters()).device

        if is_list_str(cond):
            cond = bert_embed(tokenize(cond)).to(device)

        batch_size = cond.shape[0] if exists(cond) else batch_size
        image_size = self.image_size
        channels = self.channels
        num_frames = self.num_frames
        return self.p_sample_loop((batch_size, channels, num_frames, image_size, image_size), cond = cond, cond_scale = cond_scale)

    # 插值
    @torch.inference_mode()
    def interpolate(self, x1, x2, t = None, lam = 0.5):
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))

        return img

    # 从起始值生成样本
    def q_sample(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
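
    # 注:q_sample 即前向扩散的闭式公式 x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise,
    # 而上面的 predict_start_from_noise 正是这条公式对 x_0 的反解
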
    # 计算像素损失函数
    def p_losses(self, x_start, t, cond = None, noise = None, **kwargs):
        # 获取输入张量的形状和设备信息
        b, c, f, h, w, device = *x_start.shape, x_start.device
        # 如果没有提供噪声数据,则生成一个与输入张量相同形状的随机噪声张量
        noise = default(noise, lambda: torch.randn_like(x_start))

        # 生成带有噪声的输入张量
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        # 如果条件是字符串列表,则将其转换为BERT嵌入表示,并根据需要返回CLS表示
        if is_list_str(cond):
            cond = bert_embed(tokenize(cond), return_cls_repr = self.text_use_bert_cls)
            cond = cond.to(device)

        # 使用去噪函数对带有噪声的输入张量进行去噪处理
        x_recon = self.denoise_fn(x_noisy, t, cond = cond, **kwargs)

        # 根据损失类型计算损失值
        if self.loss_type == 'l1':
            loss = F.l1_loss(noise, x_recon)
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, x_recon)
        else:
            raise NotImplementedError()

        # 返回计算得到的损失值
        return loss

    # 前向传播函数
    def forward(self, x, *args, **kwargs):
        # 获取输入张量的形状信息、设备信息和图像大小
        b, device, img_size, = x.shape[0], x.device, self.image_size
        # 检查输入张量的形状是否符合要求
        check_shape(x, 'b c f h w', c = self.channels, f = self.num_frames, h = img_size, w = img_size)
        # 生成随机时间步长
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
        # 对输入图像进行归一化处理
        x = normalize_img(x)
        # 调用像素损失函数计算损失值并返回
        return self.p_losses(x, t, *args, **kwargs)
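
把 Unet3D 与 GaussianDiffusion 串起来的最小训练/采样示意(非原文件内容,写法与该仓库 README 的用法类似,具体参数仅作演示):

import torch

model = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8))

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    num_frames = 5,
    timesteps = 1000,
    loss_type = 'l1'
)

videos = torch.rand(1, 3, 5, 32, 32)   # 取值范围 [0, 1] 的视频张量 (b, c, f, h, w)
loss = diffusion(videos)               # 随机采样时间步并计算去噪损失
loss.backward()

sampled = diffusion.sample(batch_size = 2)   # (2, 3, 5, 32, 32),完整执行 1000 步反向采样,较慢

若要做文本条件生成,可在构建 Unet3D 时传入 use_bert_text_cond = True,然后把字符串列表作为 cond 传给 diffusion 或 diffusion.sample。
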
# trainer class

# 定义通道数与模式的映射关系
CHANNELS_TO_MODE = {
    1 : 'L',
    3 : 'RGB',
    4 : 'RGBA'
}

# 遍历图像的所有帧并转换为指定通道数的图像
def seek_all_images(img, channels = 3):
    # 检查通道数是否有效
    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
    # 获取对应通道数的图像模式
    mode = CHANNELS_TO_MODE[channels]

    i = 0
    while True:
        try:
            # 尝试定位到第i帧图像并转换为指定通道数的图像
            img.seek(i)
            yield img.convert(mode)
        except EOFError:
            break
        i += 1

# 将张量转换为 GIF 图像并保存
def video_tensor_to_gif(tensor, path, duration = 120, loop = 0, optimize = True):
    # 将张量解绑定为图像列表
    images = map(T.ToPILImage(), tensor.unbind(dim = 1))
    first_img, *rest_imgs = images
    # 保存 GIF 图像
    first_img.save(path, save_all = True, append_images = rest_imgs, duration = duration, loop = loop, optimize = optimize)
    return images

# 将 GIF 图像转换为张量
def gif_to_tensor(path, channels = 3, transform = T.ToTensor()):
    # 打开 GIF 图像
    img = Image.open(path)
    # 对 GIF 图像的每一帧进行转换为张量
    tensors = tuple(map(transform, seek_all_images(img, channels = channels)))
    return torch.stack(tensors, dim = 1)
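
这两个函数互为往返转换,下面是一个示意(非原文件内容,路径仅为示例):

video = gif_to_tensor('./data/example.gif', channels = 3)   # (3, f, h, w),f 为 GIF 帧数,像素值在 [0, 1]
video_tensor_to_gif(video, './data/roundtrip.gif')          # 再写回 GIF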

# 定义恒等函数
def identity(t, *args, **kwargs):
    return t

# 将图像张量归一化到[-1, 1]范围
def normalize_img(t):
    return t * 2 - 1

# 将归一化后的图像张量反归一化
def unnormalize_img(t):
    return (t + 1) * 0.5

# 调整张量的帧数
def cast_num_frames(t, *, frames):
    f = t.shape[1]

    if f == frames:
        return t

    if f > frames:
        return t[:, :frames]

    return F.pad(t, (0, 0, 0, 0, 0, frames - f))

# 数据集类
class Dataset(data.Dataset):
    def __init__(
        self,
        folder,
        image_size,
        channels = 3,
        num_frames = 16,
        horizontal_flip = False,
        force_num_frames = True,
        exts = ['gif']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.channels = channels
        # 获取指定文件夹下所有指定扩展名的文件路径
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        # 根据是否强制指定帧数选择相应的函数
        self.cast_num_frames_fn = partial(cast_num_frames, frames = num_frames) if force_num_frames else identity

        # 图像转换操作
        self.transform = T.Compose([
            T.Resize(image_size),
            T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # 获取指定索引的文件路径并将其转换为张量
        path = self.paths[index]
        tensor = gif_to_tensor(path, self.channels, transform = self.transform)
        return self.cast_num_frames_fn(tensor)
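
使用示意(非原文件内容,路径仅为示例):

ds = Dataset('./data/gifs', image_size = 64, num_frames = 16)
video = ds[0]   # (3, 16, 64, 64):通道 x 帧 x 高 x 宽,帧数不足时会在时间维度上补零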

# trainer class

# 训练器类
class Trainer(object):
    def __init__(
        self,
        diffusion_model,
        folder,
        *,
        ema_decay = 0.995,
        num_frames = 16,
        train_batch_size = 32,
        train_lr = 1e-4,
        train_num_steps = 100000,
        gradient_accumulate_every = 2,
        amp = False,
        step_start_ema = 2000,
        update_ema_every = 10,
        save_and_sample_every = 1000,
        results_folder = './results',
        num_sample_rows = 4,
        max_grad_norm = None
    # 初始化 Diffusion Trainer 类
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 设置模型
        self.model = diffusion_model
        # 创建指数移动平均对象
        self.ema = EMA(ema_decay)
        # 复制模型用于指数移动平均
        self.ema_model = copy.deepcopy(self.model)
        # 每隔一定步数更新指数移动平均
        self.update_ema_every = update_ema_every

        # 开始使用指数移动平均的步数
        self.step_start_ema = step_start_ema
        # 每隔一定步数保存模型和生成样本
        self.save_and_sample_every = save_and_sample_every

        # 训练批次大小
        self.batch_size = train_batch_size
        # 图像大小
        self.image_size = diffusion_model.image_size
        # 梯度累积步数
        self.gradient_accumulate_every = gradient_accumulate_every
        # 训练步数
        self.train_num_steps = train_num_steps

        # 获取图像大小、通道数和帧数
        image_size = diffusion_model.image_size
        channels = diffusion_model.channels
        num_frames = diffusion_model.num_frames

        # 创建数据集对象
        self.ds = Dataset(folder, image_size, channels = channels, num_frames = num_frames)

        # 打印数据集信息
        print(f'found {len(self.ds)} videos as gif files at {folder}')
        # 断言数据集长度大于0
        assert len(self.ds) > 0, 'need to have at least 1 video to start training (although 1 is not great, try 100k)'

        # 创建数据加载器
        self.dl = cycle(data.DataLoader(self.ds, batch_size = train_batch_size, shuffle=True, pin_memory=True))
        # 创建优化器
        self.opt = Adam(diffusion_model.parameters(), lr = train_lr)

        # 初始化步数
        self.step = 0

        # 是否使用混合精度训练
        self.amp = amp
        # 创建梯度缩放器
        self.scaler = GradScaler(enabled = amp)
        # 最大梯度范数
        self.max_grad_norm = max_grad_norm

        # 生成样本的行数
        self.num_sample_rows = num_sample_rows
        # 结果保存文件夹
        self.results_folder = Path(results_folder)
        # 创建结果保存文件夹
        self.results_folder.mkdir(exist_ok = True, parents = True)

        # 重置参数
        self.reset_parameters()

    # 重置参数
    def reset_parameters(self):
        # 加载模型参数到指数移动平均模型
        self.ema_model.load_state_dict(self.model.state_dict())

    # 更新指数移动平均模型
    def step_ema(self):
        # 若步数小于开始使用指数移动平均的步数,则重置参数
        if self.step < self.step_start_ema:
            self.reset_parameters()
            return
        # 更新指数移动平均模型
        self.ema.update_model_average(self.ema_model, self.model)

    # 保存模型
    def save(self, milestone):
        # 保存训练状态
        data = {
            'step': self.step,
            'model': self.model.state_dict(),
            'ema': self.ema_model.state_dict(),
            'scaler': self.scaler.state_dict()
        }
        # 将数据保存到文件
        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    # 加载模型
    def load(self, milestone, **kwargs):
        # 若加载最新的检查点
        if milestone == -1:
            # 获取所有里程碑
            all_milestones = [int(p.stem.split('-')[-1]) for p in Path(self.results_folder).glob('**/*.pt')]
            # 断言至少有一个里程碑
            assert len(all_milestones) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)'
            # 获取最大的里程碑
            milestone = max(all_milestones)

        # 加载模型数据
        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))

        # 更新步数和模型参数
        self.step = data['step']
        self.model.load_state_dict(data['model'], **kwargs)
        self.ema_model.load_state_dict(data['ema'], **kwargs)
        self.scaler.load_state_dict(data['scaler'])

    # 训练方法
    def train(
        self,
        prob_focus_present = 0.,
        focus_present_mask = None,
        log_fn = noop
        ):
        # 断言日志函数是可调用的
        assert callable(log_fn)

        # 当步数小于训练步数时,执行训练循环
        while self.step < self.train_num_steps:
            # 对于每个梯度累积周期
            for i in range(self.gradient_accumulate_every):
                # 从数据加载器中获取下一个数据批次并移至 GPU
                data = next(self.dl).cuda()

                # 使用自动混合精度计算
                with autocast(enabled = self.amp):
                    # 计算模型损失
                    loss = self.model(
                        data,
                        prob_focus_present = prob_focus_present,
                        focus_present_mask = focus_present_mask
                    )

                    # 反向传播并缩放损失
                    self.scaler.scale(loss / self.gradient_accumulate_every).backward()

                # 打印当前步数和损失值
                print(f'{self.step}: {loss.item()}')

            # 记录损失值
            log = {'loss': loss.item()}

            # 如果存在最大梯度范数,则对梯度进行裁剪
            if exists(self.max_grad_norm):
                self.scaler.unscale_(self.opt)
                nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

            # 更新模型参数
            self.scaler.step(self.opt)
            self.scaler.update()
            self.opt.zero_grad()

            # 每隔一定步数更新指数移动平均
            if self.step % self.update_ema_every == 0:
                self.step_ema()

            # 每隔一定步数保存模型并生成样本
            if self.step != 0 and self.step % self.save_and_sample_every == 0:
                milestone = self.step // self.save_and_sample_every
                num_samples = self.num_sample_rows ** 2
                batches = num_to_groups(num_samples, self.batch_size)

                # 生成所有样本视频
                all_videos_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
                all_videos_list = torch.cat(all_videos_list, dim = 0)

                # 对视频进行填充
                all_videos_list = F.pad(all_videos_list, (2, 2, 2, 2))

                # 重新排列视频帧以生成 GIF
                one_gif = rearrange(all_videos_list, '(i j) c f h w -> c f (i h) (j w)', i = self.num_sample_rows)
                video_path = str(self.results_folder / str(f'{milestone}.gif'))
                video_tensor_to_gif(one_gif, video_path)
                log = {**log, 'sample': video_path}
                self.save(milestone)

            # 记录日志
            log_fn(log)
            self.step += 1

        # 训练完成后打印消息
        print('training completed')
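
训练入口的最小示意(非原文件内容,写法与该仓库 README 的用法类似;假设机器有可用的 GPU,且示例路径 './data/gifs' 下放有若干 .gif 视频):

trainer = Trainer(
    diffusion,                     # 上文构建的 GaussianDiffusion 实例
    './data/gifs',                 # 训练数据文件夹
    train_batch_size = 4,
    train_lr = 1e-4,
    save_and_sample_every = 1000,
    train_num_steps = 700000,
    gradient_accumulate_every = 2,
    ema_decay = 0.995,
    amp = True
)

trainer.train()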

.\lucidrains\video-diffusion-pytorch\video_diffusion_pytorch\__init__.py

# 从 video_diffusion_pytorch 包中导入 Unet3D、GaussianDiffusion 和 Trainer 类
from video_diffusion_pytorch.video_diffusion_pytorch import Unet3D, GaussianDiffusion, Trainer

Vision Transformer - Pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in Yannic Kilcher's video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.

For a Pytorch implementation with pretrained models, please see Ross Wightman's repository here.

The official Jax repository is here.

A tensorflow2 translation also exists here, created by research scientist Junho Kim! 🙏

Flax translation by Enrico Shippole!

Install

$ pip install vit-pytorch

Usage

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

Parameters

  • image_size: int.
    Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
  • patch_size: int.
    Size of patches. image_size must be divisible by patch_size.
    The number of patches is: n = (image_size // patch_size) ** 2 and n must be greater than 16.
  • num_classes: int.
    Number of classes to classify.
  • dim: int.
    Last dimension of output tensor after linear transformation nn.Linear(..., dim).
  • depth: int.
    Number of Transformer blocks.
  • heads: int.
    Number of heads in Multi-head Attention layer.
  • mlp_dim: int.
    Dimension of the MLP (FeedForward) layer.
  • channels: int, default 3.
    Number of image's channels.
  • dropout: float between [0, 1], default 0..
    Dropout rate.
  • emb_dropout: float between [0, 1], default 0.
    Embedding dropout rate.
  • pool: string, either cls token pooling or mean pooling

Simple ViT

An update from some of the same authors of the original paper proposes simplifications to ViT that allows it to train faster and better.

Among these simplifications include 2d sinusoidal positional embedding, global average pooling (no CLS token), no dropout, batch sizes of 1024 rather than 4096, and use of RandAugment and MixUp augmentations. They also show that a simple linear at the end is not significantly worse than the original MLP head

You can use it by importing the SimpleViT as shown below

import torch
from vit_pytorch import SimpleViT

v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

NaViT

This paper proposes to leverage the flexibility of attention and masking for variable lengthed sequences to train images of multiple resolution, packed into a single batch. They demonstrate much faster training and improved accuracies, with the only cost being extra complexity in the architecture and dataloading. They use factorized 2d positional encodings, token dropping, as well as query-key normalization.

You can use it as follows

import torch
from vit_pytorch.na_vit import NaViT

v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    token_dropout_prob = 0.1  # token dropout of 10% (keep 90% of tokens)
)

# 5 images of different resolutions - List[List[Tensor]]

# for now, you'll have to correctly place images in same batch element as to not exceed maximum allowed sequence length for self-attention w/ masking

images = [
    [torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
    [torch.randn(3, 128, 256), torch.randn(3, 256, 128)],
    [torch.randn(3, 64, 256)]
]

preds = v(images) # (5, 1000) - 5, because 5 images of different resolution above

Or if you would rather that the framework auto group the images into variable lengthed sequences that do not exceed a certain max length

images = [
    torch.randn(3, 256, 256),
    torch.randn(3, 128, 128),
    torch.randn(3, 128, 256),
    torch.randn(3, 256, 128),
    torch.randn(3, 64, 256)
]

preds = v(
    images,
    group_images = True,
    group_max_seq_len = 64
) # (5, 1000)

Distillation

A recent paper has shown that use of a distillation token for distilling knowledge from convolutional nets to vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.

ex. distilling from Resnet50 (or any teacher) to a vision transformer

import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,           # temperature of distillation
    alpha = 0.5,               # trade between main loss and distillation loss
    hard = False               # whether to use soft or hard distillation
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)

The DistillableViT class is identical to ViT except for how the forward pass is handled, so you should be able to load the parameters back to ViT after you have completed distillation training.

You can also use the handy .to_vit method on the DistillableViT instance to get back a ViT instance.

v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>

Deep ViT

This paper notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the Talking Heads paper from NLP.

You can use it as follows

import torch
from vit_pytorch.deepvit import DeepViT

v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

CaiT

This paper also notes difficulty in training vision transformers at greater depths and proposes two solutions. First it proposes to do per-channel multiplication of the output of the residual block. Second, it proposes to have the patches attend to one another, and only allow the CLS token to attend to the patches in the last few layers.

They also add Talking Heads, noting improvements

You can use this scheme as follows

import torch
from vit_pytorch.cait import CaiT

v = CaiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,             # depth of transformer for patch to patch attention only
    cls_depth = 2,          # depth of cross attention of CLS tokens to patch
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    layer_dropout = 0.05    # randomly dropout 5% of the layers
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

Token-to-Token ViT

This paper proposes that the first couple layers should downsample the image sequence by unfolding, leading to overlapping image data in each token as shown in the figure above. You can use this variant of the ViT as follows.

import torch
from vit_pytorch.t2t import T2TViT

v = T2TViT(
    dim = 512,
    image_size = 224,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 1000,
    t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layers of the initial token to token module
)

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)

CCT

CCT proposes compact transformers
by using convolutions instead of patching and performing sequence pooling. This
allows for CCT to have high accuracy and a low number of parameters.

You can use this with two methods

import torch
from vit_pytorch.cct import CCT

cct = CCT(
    img_size = (224, 448),
    embedding_dim = 384,
    n_conv_layers = 2,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_layers = 14,
    num_heads = 6,
    mlp_ratio = 3.,
    num_classes = 1000,
    positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
)

img = torch.randn(1, 3, 224, 448)
pred = cct(img) # (1, 1000)

Alternatively you can use one of several pre-defined models [2,4,6,7,8,14,16]
which pre-define the number of layers, number of attention heads, the mlp ratio,
and the embedding dimension.

import torch
from vit_pytorch.cct import cct_14

cct = cct_14(
    img_size = 224,
    n_conv_layers = 1,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_classes = 1000,
    positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
)

The official repository includes links to pretrained model checkpoints.

Cross ViT

This paper proposes to have two vision transformers processing the image at different scales, cross attending to one every so often. They show improvements on top of the base vision transformer.

import torch
from vit_pytorch.cross_vit import CrossViT

v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,               # number of multi-scale encoding blocks
    sm_dim = 192,            # high res dimension
    sm_patch_size = 16,      # high res patch size (should be smaller than lg_patch_size)
    sm_enc_depth = 2,        # high res depth
    sm_enc_heads = 8,        # high res heads
    sm_enc_mlp_dim = 2048,   # high res feedforward dimension
    lg_dim = 384,            # low res dimension
    lg_patch_size = 64,      # low res patch size
    lg_enc_depth = 3,        # low res depth
    lg_enc_heads = 8,        # low res heads
    lg_enc_mlp_dim = 2048,   # low res feedforward dimensions
    cross_attn_depth = 2,    # cross attention rounds
    cross_attn_heads = 8,    # cross attention heads
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

pred = v(img) # (1, 1000)

PiT

This paper proposes to downsample the tokens through a pooling procedure using depth-wise convolutions.

import torch
from vit_pytorch.pit import PiT

v = PiT(
    image_size = 224,
    patch_size = 14,
    dim = 256,
    num_classes = 1000,
    depth = (3, 3, 3),     # list of depths, indicating the number of rounds of each stage before a downsample
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# forward pass now returns predictions and the attention maps

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)

LeViT

This paper proposes a number of changes, including (1) convolutional embedding instead of patch-wise projection (2) downsampling in stages (3) extra non-linearity in attention (4) 2d relative positional biases instead of initial absolute positional bias (5) batchnorm in place of layernorm.

Official repository

import torch
from vit_pytorch.levit import LeViT

levit = LeViT(
    image_size = 224,
    num_classes = 1000,
    stages = 3,             # number of stages
    dim = (256, 384, 512),  # dimensions at each stage
    depth = 4,              # transformer of depth 4 at each stage
    heads = (4, 6, 8),      # heads at each stage
    mlp_mult = 2,
    dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)

levit(img) # (1, 1000)

CvT

This paper proposes mixing convolutions and attention. Specifically, convolutions are used to embed and downsample the image / feature map in three stages. Depthwise-convolution is also used to project the queries, keys, and values for attention.

import torch
from vit_pytorch.cvt import CvT

v = CvT(
    num_classes = 1000,
    s1_emb_dim = 64,        # stage 1 - dimension
    s1_emb_kernel = 7,      # stage 1 - conv kernel
    s1_emb_stride = 4,      # stage 1 - conv stride
    s1_proj_kernel = 3,     # stage 1 - attention ds-conv kernel size
    s1_kv_proj_stride = 2,  # stage 1 - attention key / value projection stride
    s1_heads = 1,           # stage 1 - heads
    s1_depth = 1,           # stage 1 - depth
    s1_mlp_mult = 4,        # stage 1 - feedforward expansion factor
    s2_emb_dim = 192,       # stage 2 - (same as above)
    s2_emb_kernel = 3,
    s2_emb_stride = 2,
    s2_proj_kernel = 3,
    s2_kv_proj_stride = 2,
    s2_heads = 3,
    s2_depth = 2,
    s2_mlp_mult = 4,
    s3_emb_dim = 384,       # stage 3 - (same as above)
    s3_emb_kernel = 3,
    s3_emb_stride = 2,
    s3_proj_kernel = 3,
    s3_kv_proj_stride = 2,
    s3_heads = 4,
    s3_depth = 10,
    s3_mlp_mult = 4,
    dropout = 0.
)

img = torch.randn(1, 3, 224, 224)

pred = v(img) # (1, 1000)

Twins SVT

This paper proposes mixing local and global attention, along with position encoding generator (proposed in CPVT) and global average pooling, to achieve the same results as Swin, without the extra complexity of shifted windows, CLS tokens, nor positional embeddings.

import torch
from vit_pytorch.twins_svt import TwinsSVT

model = TwinsSVT(
    num_classes = 1000,       # number of output classes
    s1_emb_dim = 64,          # stage 1 - patch embedding projected dimension
    s1_patch_size = 4,        # stage 1 - patch size for patch embedding
    s1_local_patch_size = 7,  # stage 1 - patch size for local attention
    s1_global_k = 7,          # stage 1 - global attention key / value reduction factor, defaults to 7 as specified in paper
    s1_depth = 1,             # stage 1 - number of transformer blocks (local attn -> ff -> global attn -> ff)
    s2_emb_dim = 128,         # stage 2 (same as above)
    s2_patch_size = 2,
    s2_local_patch_size = 7,
    s2_global_k = 7,
    s2_depth = 1,
    s3_emb_dim = 256,         # stage 3 (same as above)
    s3_patch_size = 2,
    s3_local_patch_size = 7,
    s3_global_k = 7,
    s3_depth = 5,
    s4_emb_dim = 512,         # stage 4 (same as above)
    s4_patch_size = 2,
    s4_local_patch_size = 7,
    s4_global_k = 7,
    s4_depth = 4,
    peg_kernel_size = 3,      # positional encoding generator kernel size
    dropout = 0.              # dropout
)

img = torch.randn(1, 3, 224, 224)

pred = model(img) # (1, 1000)

RegionViT

This paper proposes to divide up the feature map into local regions, whereby the local tokens attend to each other. Each local region has its own regional token which then attends to all its local tokens, as well as other regional tokens.

You can use it as follows

import torch
from vit_pytorch.regionvit import RegionViT

model = RegionViT(
    dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
    depth = (2, 2, 8, 2),           # depth of the region to local transformer at each stage
    window_size = 7,                # window size, which should be either 7 or 14
    num_classes = 1000,             # number of output classes
    tokenize_local_3_conv = False,  # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
    use_peg = False,                # whether to use positional generating module. they used this for object detection for a boost in performance
)

img = torch.randn(1, 3, 224, 224)

pred = model(img) # (1, 1000)

CrossFormer

This paper beats PVT and Swin using alternating local and global attention. The global attention is done across the windowing dimension for reduced complexity, much like the scheme used for axial attention.

They also have a cross-scale embedding layer, which they show to be a generic layer that can improve all vision transformers. Dynamic relative positional bias was also formulated to allow the net to generalize to images of greater resolution.

import torch
from vit_pytorch.crossformer import CrossFormer

model = CrossFormer(
    num_classes = 1000,                # number of output classes
    dim = (64, 128, 256, 512),         # dimension at each stage
    depth = (2, 2, 8, 2),              # depth of transformer at each stage
    global_window_size = (8, 4, 2, 1), # global window sizes at each stage
    local_window_size = 7,             # local window size (can be customized for each stage, but in paper, held constant at 7 for all stages)
)

img = torch.randn(1, 3, 224, 224)

pred = model(img) # (1, 1000)

ScalableViT

This Bytedance AI paper proposes the Scalable Self Attention (SSA) and the Interactive Windowed Self Attention (IWSA) modules. The SSA alleviates the computation needed at earlier stages by reducing the key / value feature map by some factor (reduction_factor), while modulating the dimension of the queries and keys (ssa_dim_key). The IWSA performs self attention within local windows, similar to other vision transformer papers. However, they add a residual of the values, passed through a convolution of kernel size 3, which they named Local Interactive Module (LIM).

They make the claim in this paper that this scheme outperforms Swin Transformer, and also demonstrate competitive performance against Crossformer.

You can use it as follows (ex. ScalableViT-S)

import torch
from vit_pytorch.scalable_vit import ScalableViT

model = ScalableViT(
    num_classes = 1000,
    dim = 64,                               # starting model dimension. at every stage, dimension is doubled
    heads = (2, 4, 8, 16),                  # number of attention heads at each stage
    depth = (2, 2, 20, 2),                  # number of transformer blocks at each stage
    ssa_dim_key = (40, 40, 40, 32),         # the dimension of the attention keys (and queries) for SSA. in the paper, they represented this as a scale factor on the base dimension per key (ssa_dim_key / dim_key)
    reduction_factor = (8, 4, 2, 1),        # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2)
    window_size = (64, 32, None, None),     # window size of the IWSA at each stage. None means no windowing needed
    dropout = 0.1,                          # attention and feedforward dropout
)

img = torch.randn(1, 3, 256, 256)

preds = model(img) # (1, 1000)

SepViT

Another Bytedance AI paper, it proposes a depthwise-pointwise self-attention layer that seems largely inspired by mobilenet's depthwise-separable convolution. The most interesting aspect is the reuse of the feature map from the depthwise self-attention stage as the values for the pointwise self-attention, as shown in the diagram above.

I have decided to include only the version of SepViT with this specific self-attention layer, as the grouped attention layers are not remarkable nor novel, and the authors were not clear on how they treated the window tokens for the group self-attention layer. Besides, it seems like with the DSSA layer alone, they were able to beat Swin.

ex. SepViT-Lite

import torch
from vit_pytorch.sep_vit import SepViT

v = SepViT(
    num_classes = 1000,
    dim = 32,               # dimensions of first stage, which doubles every stage (32, 64, 128, 256) for SepViT-Lite
    dim_head = 32,          # attention head dimension
    heads = (1, 2, 4, 8),   # number of heads per stage
    depth = (1, 2, 6, 2),   # number of transformer blocks per stage
    window_size = 7,        # window size of DSS Attention block
    dropout = 0.1           # dropout
)

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)

MaxViT

This paper proposes a hybrid convolutional / attention network, using MBConv from the convolution side, and then block / grid axial sparse attention.

They also claim this specific vision transformer is good for generative models (GANs).

ex. MaxViT-S

import torch
from vit_pytorch.max_vit import MaxViT

v = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,               # dimension of the convolutional stem, would default to dimension of first layer if not specified
    dim = 96,                         # dimension of first layer, doubles every layer
    dim_head = 32,                    # dimension of attention heads, kept at 32 in paper
    depth = (2, 2, 5, 2),             # number of MaxViT blocks per stage, which consists of MBConv, block-like attention, grid-like attention
    window_size = 7,                  # window size for block and grids
    mbconv_expansion_rate = 4,        # expansion rate of MBConv
    mbconv_shrinkage_rate = 0.25,     # shrinkage rate of squeeze-excitation in MBConv
    dropout = 0.1                     # dropout
)

img = torch.randn(2, 3, 224, 224)

preds = v(img) # (2, 1000)

NesT

This paper processes the image in hierarchical stages, with attention only among the tokens of local blocks, which are aggregated as they move up the hierarchy. The aggregation is done in the image plane and consists of a convolution and a subsequent maxpool, which allows information to pass across block boundaries.

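A minimal sketch of that aggregation step (an illustration with assumed dimensions, not the exact code in vit_pytorch.nest):

import torch
from torch import nn

feats = torch.randn(1, 96, 56, 56)   # (batch, dim, height, width) between hierarchies

aggregate = nn.Sequential(
    nn.Conv2d(96, 192, kernel_size = 3, padding = 1),        # convolution lets information cross block boundaries
    nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)   # maxpool downsamples for the next hierarchy
)

out = aggregate(feats)  # (1, 192, 28, 28)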
You can use it with the following code (ex. NesT-T)

import torch
from vit_pytorch.nest import NesT

nest = NesT(
    image_size = 224,
    patch_size = 4,
    dim = 96,
    heads = 3,
    num_hierarchies = 3,        # number of hierarchies
    block_repeats = (2, 2, 8),  # the number of transformer blocks at each hierarchy, starting from the bottom
    num_classes = 1000
)

img = torch.randn(1, 3, 224, 224)

pred = nest(img) # (1, 1000)

MobileViT

This paper introduces MobileViT, a light-weight, general-purpose vision transformer for mobile devices. MobileViT presents a different
perspective on the global processing of information with transformers.

You can use it with the following code (ex. mobilevit_xs)

import torch
from vit_pytorch.mobile_vit import MobileViT

mbvit_xs = MobileViT(
    image_size = (256, 256),
    dims = [96, 120, 144],
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)

pred = mbvit_xs(img) # (1, 1000)

XCiT

This paper introduces the cross covariance attention (abbreviated XCA). One can think of it as doing attention across the feature dimension rather than the spatial one (another perspective would be a dynamic 1x1 convolution, the kernel being the attention map defined by spatial correlations).

Technically, this amounts to simply transposing the queries, keys and values before executing cosine similarity attention with a learned temperature.

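To make that concrete, here is a minimal sketch of the cross covariance attention step in plain PyTorch (the class name and default head count are assumptions; this is an illustration, not the module used by XCiT below):

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

class XCASketch(nn.Module):
    # minimal sketch of cross covariance attention (XCA)
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.heads = heads
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))   # learned temperature per head
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # transpose so that attention happens over the feature dimension
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h d n', h = self.heads), (q, k, v))

        q, k = map(lambda t: F.normalize(t, dim = -1), (q, k))     # cosine similarity across tokens

        attn = (q @ k.transpose(-1, -2)) * self.temperature        # (b, h, dim_head, dim_head)
        attn = attn.softmax(dim = -1)

        out = attn @ v                                             # (b, h, dim_head, n)
        out = rearrange(out, 'b h d n -> b n (h d)')
        return self.to_out(out)

xca = XCASketch(dim = 1024)
tokens = torch.randn(1, 64, 1024)
out = xca(tokens)  # (1, 64, 1024)

The full XCiT can be used as follows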
import torch
from vit_pytorch.xcit import XCiT

v = XCiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,                     # depth of xcit transformer
    cls_depth = 2,                  # depth of cross attention of CLS tokens to patch, attention pool at end
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    layer_dropout = 0.05,           # randomly dropout 5% of the layers
    local_patch_kernel_size = 3     # kernel size of the local patch interaction module (depthwise convs)
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

Simple Masked Image Modeling

This paper proposes a simple masked image modeling (SimMIM) scheme, using only a linear projection off the masked tokens into pixel space followed by an L1 loss with the pixel values of the masked patches. Results are competitive with other more complicated approaches.

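To make the objective concrete, here is a minimal sketch of the reconstruction loss on its own (the shapes and names are assumptions; how the masked tokens are produced is omitted): a linear head maps the encoded masked tokens into pixel space, and an L1 loss compares them with the true pixel values of the masked patches.

import torch
import torch.nn.functional as F
from torch import nn

dim, patch_values = 1024, 32 * 32 * 3                 # assumed model dim and flattened patch size

to_pixels = nn.Linear(dim, patch_values)              # linear projection into pixel space

masked_tokens = torch.randn(8, 128, dim)              # encoder output at the masked positions
target_patches = torch.randn(8, 128, patch_values)    # flattened pixel values of the masked patches

pred_pixels = to_pixels(masked_tokens)
loss = F.l1_loss(pred_pixels, target_patches)         # L1 reconstruction loss on the masked patches only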
You can use this as follows

import torch
from vit_pytorch import ViT
from vit_pytorch.simmim import SimMIM

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mim = SimMIM(
    encoder = v,
    masking_ratio = 0.5  # they found 50% to yield the best results
)

images = torch.randn(8, 3, 256, 256)

loss = mim(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn

torch.save(v.state_dict(), './trained-vit.pt')

Masked Autoencoder

A new Kaiming He paper proposes a simple autoencoder scheme where the vision transformer attends to a set of unmasked patches, and a smaller decoder tries to reconstruct the masked pixel values.

DeepReader quick paper review

AI Coffeebreak with Letitia

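The heart of the scheme is the random masking itself: shuffle the patch indices, give a small fraction to the encoder, and let the decoder (fed mask tokens plus positions) reconstruct the rest. A minimal sketch of that index selection, with assumed shapes:

import torch

batch, num_patches, masking_ratio = 8, 64, 0.75
num_masked = int(masking_ratio * num_patches)

# a random permutation of patch indices for every sample in the batch
rand_indices = torch.rand(batch, num_patches).argsort(dim = -1)

masked_indices = rand_indices[:, :num_masked]      # to be reconstructed by the small decoder
unmasked_indices = rand_indices[:, num_masked:]    # the only patches the encoder ever sees

# the visible patch tokens can then be gathered with these indices before encoding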
You can use it with the following code

import torch
from vit_pytorch import ViT, MAE

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mae = MAE(
    encoder = v,
    masking_ratio = 0.75,   # the paper recommended 75% masked patches
    decoder_dim = 512,      # paper showed good results with just 512
    decoder_depth = 6       # anywhere from 1 to 8
)

images = torch.randn(8, 3, 256, 256)

loss = mae(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn

# save your improved vision transformer
torch.save(v.state_dict(), './trained-vit.pt')

Masked Patch Prediction

Thanks to Zach, you can train using the original masked patch prediction task presented in the paper, with the following code.

import torch
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP

model = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1
)

mpp_trainer = MPP(
    transformer=model,
    patch_size=32,
    dim=1024,
    mask_prob=0.15,          # probability of using token in masked prediction task
    random_patch_prob=0.30,  # probability of randomly replacing a token being used for mpp
    replace_prob=0.50,       # probability of replacing a token being used for mpp with the mask token
)

opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.FloatTensor(20, 3, 256, 256).uniform_(0., 1.)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = mpp_trainer(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')

Masked Position Prediction

A new paper that introduces a masked position prediction pre-training criterion. This strategy is more efficient than the Masked Autoencoder strategy and has comparable performance.

import torch
from vit_pytorch.mp3 import ViT, MP3

v = ViT(
    num_classes = 1000,
    image_size = 256,
    patch_size = 8,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
)

mp3 = MP3(
    vit = v,
    masking_ratio = 0.75
)

images = torch.randn(8, 3, 256, 256)

loss = mp3(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn

# save your improved vision transformer
torch.save(v.state_dict(), './trained-vit.pt')

Adaptive Token Sampling

This paper proposes to use the CLS attention scores, re-weighted by the norms of the value heads, as a means to discard unimportant tokens at different layers.

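As a rough sketch of the scoring step (an illustration with assumed shapes, not the exact code in vit_pytorch.ats_vit): the post-softmax attention of the CLS token to every patch token is weighted by the norm of that token's value vector, giving a significance score used for sampling.

import torch

# hypothetical post-softmax attention and values: (batch, heads, tokens, tokens) and (batch, heads, tokens, dim_head)
attn = torch.softmax(torch.randn(4, 16, 65, 65), dim = -1)
values = torch.randn(4, 16, 65, 64)

cls_attn = attn[:, :, 0, 1:]                        # CLS token's attention to all patch tokens
value_norms = values[:, :, 1:, :].norm(dim = -1)    # norm of each patch token's value vector

significance = (cls_attn * value_norms).sum(dim = 1)                       # combine over heads -> (batch, tokens - 1)
significance = significance / significance.sum(dim = -1, keepdim = True)   # normalized scores used for sampling

The full ViT with adaptive token sampling can be used as follows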
import torch
from vit_pytorch.ats_vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    max_tokens_per_depth = (256, 128, 64, 32, 16, 8), # a tuple that denotes the maximum number of tokens that any given layer should have. if the layer has greater than this amount, it will undergo adaptive token sampling
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)

preds = v(img) # (4, 1000)

# you can also get a list of the final sampled patch ids
# a value of -1 denotes padding

preds, token_ids = v(img, return_sampled_token_ids = True) # (4, 1000), (4, <=8)

Patch Merger

This paper proposes a simple module (Patch Merger) for reducing the number of tokens at any layer of a vision transformer without sacrificing performance.

import torch
from vit_pytorch.vit_with_patch_merger import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 12,
    heads = 8,
    patch_merge_layer = 6,        # at which transformer layer to do patch merging
    patch_merge_num_tokens = 8,   # the output number of tokens from the patch merge
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)

preds = v(img) # (4, 1000)

One can also use the PatchMerger module by itself

import torch
from vit_pytorch.vit_with_patch_merger import PatchMerger

merger = PatchMerger(
    dim = 1024,
    num_tokens_out = 8   # output number of tokens
)

features = torch.randn(4, 256, 1024) # (batch, num tokens, dimension)

out = merger(features) # (4, 8, 1024)

Vision Transformer for Small Datasets

This paper proposes a new image-to-patch function that incorporates shifts of the image before normalizing and dividing the image into patches. I have found shifting to be extremely helpful in some other transformer work, so I decided to include this for further exploration. It also includes the LSA with the learned temperature and masking out of a token's attention to itself.

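A minimal sketch of the shifted patch idea (an illustration only; the shift amount is an assumption, and the standalone SPT module below shows the actual interface): the image is concatenated with diagonally shifted copies of itself along the channel dimension before being cut into patches.

import torch
import torch.nn.functional as F
from einops import rearrange

img = torch.randn(4, 3, 256, 256)
patch_size = 16

# four diagonal shifts, realized by padding one side and cropping the other (shift amount assumed to be half the patch size)
s = patch_size // 2
shifts = ((s, -s, s, -s), (-s, s, s, -s), (s, -s, -s, s), (-s, s, -s, s))
shifted = [F.pad(img, shift) for shift in shifts]   # each stays (4, 3, 256, 256)

# concatenate the original with its shifted copies along channels, then patchify
x = torch.cat([img, *shifted], dim = 1)             # (4, 15, 256, 256)
tokens = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)
tokens.shape  # (4, 256, 3840)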
You can use as follows:

import torch
from vit_pytorch.vit_for_small_dataset import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)

preds = v(img) # (4, 1000)

You can also use the SPT from this paper as a standalone module

import torch
from vit_pytorch.vit_for_small_dataset import SPT

spt = SPT(
    dim = 1024,
    patch_size = 16,
    channels = 3
)

img = torch.randn(4, 3, 256, 256)

tokens = spt(img) # (4, 256, 1024)

3D ViT

By popular request, I will start extending a few of the architectures in this repository to 3D ViTs, for use with video, medical imaging, etc.

You will need to pass in two additional hyperparameters: (1) the number of frames, frames, and (2) the patch size along the frame dimension, frame_patch_size

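Conceptually, the tokenization just adds a third patching factor along the frame axis. A minimal einops sketch of the spatio-temporal patch extraction (the exact axis ordering inside each patch is an assumption):

import torch
from einops import rearrange

video = torch.randn(4, 3, 16, 128, 128)  # (batch, channels, frames, height, width)

image_patch_size, frame_patch_size = 16, 2

patches = rearrange(
    video,
    'b c (f pf) (h p1) (w p2) -> b (f h w) (pf p1 p2 c)',
    pf = frame_patch_size, p1 = image_patch_size, p2 = image_patch_size
)

patches.shape  # (4, 512, 1536) - 8 x 8 x 8 patches, each flattened to 2 * 16 * 16 * 3 values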
For starters, 3D ViT

import torch
from vit_pytorch.vit_3d import ViT

v = ViT(
    image_size = 128,          # image size
    frames = 16,               # number of frames
    image_patch_size = 16,     # image patch size
    frame_patch_size = 2,      # frame patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)

preds = v(video) # (4, 1000)

3D Simple ViT

import torch
from vit_pytorch.simple_vit_3d import SimpleViT

v = SimpleViT(
    image_size = 128,          # image size
    frames = 16,               # number of frames
    image_patch_size = 16,     # image patch size
    frame_patch_size = 2,      # frame patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)

preds = v(video) # (4, 1000)

3D version of CCT

import torch
from vit_pytorch.cct_3d import CCT

cct = CCT(
    img_size = 224,
    num_frames = 8,
    embedding_dim = 384,
    n_conv_layers = 2,
    frame_kernel_size = 3,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_layers = 14,
    num_heads = 6,
    mlp_ratio = 3.,
    num_classes = 1000,
    positional_embedding = 'learnable'
)

video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, frames, height, width)
pred = cct(video)

ViViT

This paper offers 3 different types of architectures for efficient attention over videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.

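The factorization itself is just a reshape between the two transformers: fold time into the batch for spatial attention, then fold space into the batch for temporal attention. A minimal einops sketch with assumed shapes:

import torch
from einops import rearrange

# hypothetical patch tokens: (batch, frames, patches per frame, dim)
tokens = torch.randn(4, 8, 64, 1024)

# spatial transformer: attend within each frame, frames folded into the batch
spatial_in = rearrange(tokens, 'b f n d -> (b f) n d')    # (32, 64, 1024)

# temporal transformer: attend across frames, spatial positions folded into the batch
temporal_in = rearrange(tokens, 'b f n d -> (b n) f d')   # (256, 8, 1024)

The full model can be used as follows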
import torch
from vit_pytorch.vivit import ViT

v = ViT(
    image_size = 128,          # image size
    frames = 16,               # number of frames
    image_patch_size = 16,     # image patch size
    frame_patch_size = 2,      # frame patch size
    num_classes = 1000,
    dim = 1024,
    spatial_depth = 6,         # depth of the spatial transformer
    temporal_depth = 6,        # depth of the temporal transformer
    heads = 8,
    mlp_dim = 2048
)

video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)

preds = v(video) # (4, 1000)

Parallel ViT

This paper proposes parallelizing multiple attention and feedforward blocks per layer (2 blocks), claiming that this makes the network easier to train without loss of performance.

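The layer structure is easy to picture: instead of one attention followed by one feedforward, each layer adds the summed outputs of several parallel attention blocks and several parallel feedforward blocks to the residual stream. A minimal sketch (the class name is an assumption, and plain Linear layers stand in for the attention / feedforward blocks):

import torch
from torch import nn

class ParallelBlockSketch(nn.Module):
    # minimal sketch: sum of parallel branches added to the residual stream
    def __init__(self, attn_branches, ff_branches):
        super().__init__()
        self.attns = nn.ModuleList(attn_branches)
        self.ffs = nn.ModuleList(ff_branches)

    def forward(self, x):
        x = x + sum(attn(x) for attn in self.attns)
        x = x + sum(ff(x) for ff in self.ffs)
        return x

dim = 1024
block = ParallelBlockSketch(
    attn_branches = [nn.Linear(dim, dim) for _ in range(2)],  # stand-ins for attention blocks
    ff_branches = [nn.Linear(dim, dim) for _ in range(2)]     # stand-ins for feedforward blocks
)

tokens = torch.randn(1, 64, dim)
out = block(tokens)  # (1, 64, 1024)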
You can try this variant as follows

import torch
from vit_pytorch.parallel_vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    num_parallel_branches = 2,  # in paper, they claimed 2 was optimal
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)

preds = v(img) # (4, 1000)

Learnable Memory ViT

This paper shows that adding learnable memory tokens at each layer of a vision transformer can greatly enhance fine-tuning results (in addition to learnable task specific CLS token and adapter head).

You can use this with a specially modified ViT as follows

import torch
from vit_pytorch.learnable_memory_vit import ViT, Adapter

# normal base ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)
logits = v(img) # (4, 1000)

# do your usual training with ViT
# ...


# then, to finetune, just pass the ViT into the Adapter class
# you can do this for multiple Adapters, as shown below

adapter1 = Adapter(
    vit = v,
    num_classes = 2,               # number of output classes for this specific task
    num_memories_per_layer = 5     # number of learnable memories per layer, 10 was sufficient in paper
)

logits1 = adapter1(img) # (4, 2) - predict 2 classes off frozen ViT backbone with learnable memories and task specific head

# yet another task to finetune on, this time with 4 classes

adapter2 = Adapter(
    vit = v,
    num_classes = 4,
    num_memories_per_layer = 10
)

logits2 = adapter2(img) # (4, 4) - predict 4 classes off frozen ViT backbone with learnable memories and task specific head

Dino

You can train ViT with the recent SOTA self-supervised learning technique, Dino, with the following code.

Yannic Kilcher video

import torch
from vit_pytorch import ViT, Dino

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

learner = Dino(
    model,
    image_size = 256,
    hidden_layer = 'to_latent',        # hidden layer name or index, from which to extract the embedding
    projection_hidden_size = 256,      # projector network hidden dimension
    projection_layers = 4,             # number of layers in projection network
    num_classes_K = 65336,             # output logits dimensions (referenced as K in paper)
    student_temp = 0.9,                # student temperature
    teacher_temp = 0.04,               # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
    local_upper_crop_scale = 0.4,      # upper bound for local crop - 0.4 was recommended in the paper 
    global_lower_crop_scale = 0.5,     # lower bound for global crop - 0.5 was recommended in the paper
    moving_average_decay = 0.9,        # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)

opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of teacher encoder and teacher centers

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')

EsViT

EsViT is a variant of Dino (from above) re-engineered to support efficient ViTs with patch merging / downsampling by taking into account an extra regional loss between the augmented views. To quote the abstract, it outperforms its supervised counterpart on 17 out of 18 datasets at 3 times higher throughput.

Even though it is named as though it were a new ViT variant, it actually is just a strategy for training any multistage ViT (in the paper, they focused on Swin). The example below will show how to use it with CvT. You'll need to set the hidden_layer to the name of the layer within your efficient ViT that outputs the non-average pooled visual representations, just before the global pooling and projection to logits.

import torch
from vit_pytorch.cvt import CvT
from vit_pytorch.es_vit import EsViTTrainer

cvt = CvT(
    num_classes = 1000,
    s1_emb_dim = 64,
    s1_emb_kernel = 7,
    s1_emb_stride = 4,
    s1_proj_kernel = 3,
    s1_kv_proj_stride = 2,
    s1_heads = 1,
    s1_depth = 1,
    s1_mlp_mult = 4,
    s2_emb_dim = 192,
    s2_emb_kernel = 3,
    s2_emb_stride = 2,
    s2_proj_kernel = 3,
    s2_kv_proj_stride = 2,
    s2_heads = 3,
    s2_depth = 2,
    s2_mlp_mult = 4,
    s3_emb_dim = 384,
    s3_emb_kernel = 3,
    s3_emb_stride = 2,
    s3_proj_kernel = 3,
    s3_kv_proj_stride = 2,
    s3_heads = 4,
    s3_depth = 10,
    s3_mlp_mult = 4,
    dropout = 0.
)

learner = EsViTTrainer(
    cvt,
    image_size = 256,
    hidden_layer = 'layers',           # hidden layer name or index, from which to extract the embedding
    projection_hidden_size = 256,      # projector network hidden dimension
    projection_layers = 4,             # number of layers in projection network
    num_classes_K = 65336,             # output logits dimensions (referenced as K in paper)
    student_temp = 0.9,                # student temperature
    teacher_temp = 0.04,               # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
    local_upper_crop_scale = 0.4,      # upper bound for local crop - 0.4 was recommended in the paper
    global_lower_crop_scale = 0.5,     # lower bound for global crop - 0.5 was recommended in the paper
    moving_average_decay = 0.9,        # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)

opt = torch.optim.AdamW(learner.parameters(), lr = 3e-4)

def sample_unlabelled_images():
    return torch.randn(8, 3, 256, 256)

for _ in range(1000):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of teacher encoder and teacher centers

# save your improved network
torch.save(cvt.state_dict(), './pretrained-net.pt')

Accessing Attention

If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below

import torch
from vit_pytorch.vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# import Recorder and wrap the ViT

from vit_pytorch.recorder import Recorder
v = Recorder(v)

# forward pass now returns predictions and the attention maps

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)

# there is one extra patch due to the CLS token

attns # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)

To clean up the class and the hooks once you have collected enough data

v = v.eject()  # wrapper is discarded and original ViT instance is returned

Accessing Embeddings

You can similarly access the embeddings with the Extractor wrapper

import torch
from vit_pytorch.vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# import Extractor and wrap the ViT

from vit_pytorch.extractor import Extractor
v = Extractor(v)

# forward pass now returns predictions and the embeddings

img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)

# there is one extra token due to the CLS token

embeddings # (1, 65, 1024) - (batch x patches x model dim)

Or say for CrossViT, which has a multi-scale encoder that outputs two sets of embeddings for 'large' and 'small' scales

import torch
from vit_pytorch.cross_vit import CrossViT

v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,
    sm_dim = 192,
    sm_patch_size = 16,
    sm_enc_depth = 2,
    sm_enc_heads = 8,
    sm_enc_mlp_dim = 2048,
    lg_dim = 384,
    lg_patch_size = 64,
    lg_enc_depth = 3,
    lg_enc_heads = 8,
    lg_enc_mlp_dim = 2048,
    cross_attn_depth = 2,
    cross_attn_heads = 8,
    dropout = 0.1,
    emb_dropout = 0.1
)

# wrap the CrossViT

from vit_pytorch.extractor import Extractor
v = Extractor(v, layer_name = 'multi_scale_encoder') # take embedding coming from the output of multi-scale-encoder

# forward pass now returns predictions and the embeddings

img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)

# there is one extra token due to the CLS token

embeddings # ((1, 257, 192), (1, 17, 384)) - (batch x patches x dimension) <- large and small scales respectively

Research Ideas

Efficient Attention

There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plug in your own sparse attention transformer.

An example with Nystromformer

$ pip install nystrom-attention
import torch
from vit_pytorch.efficient import ViT
from nystrom_attention import Nystromformer

efficient_transformer = Nystromformer(
    dim = 512,
    depth = 12,
    heads = 8,
    num_landmarks = 256
)

v = ViT(
    dim = 512,
    image_size = 2048,
    patch_size = 32,
    num_classes = 1000,
    transformer = efficient_transformer
)

img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
v(img) # (1, 1000)

Other sparse attention frameworks I would highly recommend are Routing Transformer and Sinkhorn Transformer

Combining with other Transformer improvements

This paper purposely used the most vanilla of attention networks to make a statement. If you would like to use some of the latest improvements for attention nets, please use the Encoder from this repository.

ex.

$ pip install x-transformers
import torch
from vit_pytorch.efficient import ViT
from x_transformers import Encoder

v = ViT(
    dim = 512,
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    transformer = Encoder(
        dim = 512,                  # set to be the same as the wrapper
        depth = 12,
        heads = 8,
        ff_glu = True,              # ex. feed forward GLU variant https://arxiv.org/abs/2002.05202
        residual_attn = True        # ex. residual attention https://arxiv.org/abs/2012.11747
    )
)

img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)

FAQ

  • How do I pass in non-square images?

You can already pass in non-square images - you just have to make sure your height and width are less than or equal to the image_size, and that both are divisible by the patch_size

ex.

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 128) # <-- not a square

preds = v(img) # (1, 1000)
  • How do I pass in non-square patches?
import torch
from vit_pytorch import ViT

v = ViT(
    num_classes = 1000,
    image_size = (256, 128),  # image size is a tuple of (height, width)
    patch_size = (32, 16),    # patch size is a tuple of (height, width)
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 128)

preds = v(img)

Resources

Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.

  1. Illustrated Transformer - Jay Alammar

  2. Transformers from Scratch - Peter Bloem

  3. The Annotated Transformer - Harvard NLP

Citations

@article{hassani2021escaping,
    title   = {Escaping the Big Data Paradigm with Compact Transformers},
    author  = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
    year    = 2021,
    url     = {https://arxiv.org/abs/2104.05704},
    eprint  = {2104.05704},
    archiveprefix = {arXiv},
    primaryclass = {cs.CV}
}
@misc{dosovitskiy2020image,
    title   = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
    author  = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
    year    = {2020},
    eprint  = {2010.11929},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{touvron2020training,
    title   = {Training data-efficient image transformers & distillation through attention}, 
    author  = {Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Hervé Jégou},
    year    = {2020},
    eprint  = {2012.12877},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yuan2021tokenstotoken,
    title   = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
    author  = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
    year    = {2021},
    eprint  = {2101.11986},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{zhou2021deepvit,
    title   = {DeepViT: Towards Deeper Vision Transformer},
    author  = {Daquan Zhou and Bingyi Kang and Xiaojie Jin and Linjie Yang and Xiaochen Lian and Qibin Hou and Jiashi Feng},
    year    = {2021},
    eprint  = {2103.11886},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{touvron2021going,
    title   = {Going deeper with Image Transformers}, 
    author  = {Hugo Touvron and Matthieu Cord and Alexandre Sablayrolles and Gabriel Synnaeve and Hervé Jégou},
    year    = {2021},
    eprint  = {2103.17239},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{chen2021crossvit,
    title   = {CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification},
    author  = {Chun-Fu Chen and Quanfu Fan and Rameswar Panda},
    year    = {2021},
    eprint  = {2103.14899},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{wu2021cvt,
    title   = {CvT: Introducing Convolutions to Vision Transformers},
    author  = {Haiping Wu and Bin Xiao and Noel Codella and Mengchen Liu and Xiyang Dai and Lu Yuan and Lei Zhang},
    year    = {2021},
    eprint  = {2103.15808},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{heo2021rethinking,
    title   = {Rethinking Spatial Dimensions of Vision Transformers}, 
    author  = {Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
    year    = {2021},
    eprint  = {2103.16302},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{graham2021levit,
    title   = {LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
    author  = {Ben Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Hervé Jégou and Matthijs Douze},
    year    = {2021},
    eprint  = {2104.01136},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{li2021localvit,
    title   = {LocalViT: Bringing Locality to Vision Transformers},
    author  = {Yawei Li and Kai Zhang and Jiezhang Cao and Radu Timofte and Luc Van Gool},
    year    = {2021},
    eprint  = {2104.05707},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{chu2021twins,
    title   = {Twins: Revisiting Spatial Attention Design in Vision Transformers},
    author  = {Xiangxiang Chu and Zhi Tian and Yuqing Wang and Bo Zhang and Haibing Ren and Xiaolin Wei and Huaxia Xia and Chunhua Shen},
    year    = {2021},
    eprint  = {2104.13840},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{zhang2021aggregating,
    title   = {Aggregating Nested Transformers},
    author  = {Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
    year    = {2021},
    eprint  = {2105.12723},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{chen2021regionvit,
    title   = {RegionViT: Regional-to-Local Attention for Vision Transformers}, 
    author  = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
    year    = {2021},
    eprint  = {2106.02689},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{wang2021crossformer,
    title   = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention}, 
    author  = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
    year    = {2021},
    eprint  = {2108.00154},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{caron2021emerging,
    title   = {Emerging Properties in Self-Supervised Vision Transformers},
    author  = {Mathilde Caron and Hugo Touvron and Ishan Misra and Hervé Jégou and Julien Mairal and Piotr Bojanowski and Armand Joulin},
    year    = {2021},
    eprint  = {2104.14294},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{he2021masked,
    title   = {Masked Autoencoders Are Scalable Vision Learners}, 
    author  = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
    year    = {2021},
    eprint  = {2111.06377},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{xie2021simmim,
    title   = {SimMIM: A Simple Framework for Masked Image Modeling}, 
    author  = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
    year    = {2021},
    eprint  = {2111.09886},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{fayyaz2021ats,
    title   = {ATS: Adaptive Token Sampling For Efficient Vision Transformers},
    author  = {Mohsen Fayyaz and Soroush Abbasi Kouhpayegani and Farnoush Rezaei Jafari and Eric Sommerlade and Hamid Reza Vaezi Joze and Hamed Pirsiavash and Juergen Gall},
    year    = {2021},
    eprint  = {2111.15667},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{mehta2021mobilevit,
    title   = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
    author  = {Sachin Mehta and Mohammad Rastegari},
    year    = {2021},
    eprint  = {2110.02178},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{lee2021vision,
    title   = {Vision Transformer for Small-Size Datasets}, 
    author  = {Seung Hoon Lee and Seunghyun Lee and Byung Cheol Song},
    year    = {2021},
    eprint  = {2112.13492},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{renggli2022learning,
    title   = {Learning to Merge Tokens in Vision Transformers},
    author  = {Cedric Renggli and André Susano Pinto and Neil Houlsby and Basil Mustafa and Joan Puigcerver and Carlos Riquelme},
    year    = {2022},
    eprint  = {2202.12015},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yang2022scalablevit,
    title   = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer}, 
    author  = {Rui Yang and Hailong Ma and Jie Wu and Yansong Tang and Xuefeng Xiao and Min Zheng and Xiu Li},
    year    = {2022},
    eprint  = {2203.10790},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{Touvron2022ThreeTE,
    title   = {Three things everyone should know about Vision Transformers},
    author  = {Hugo Touvron and Matthieu Cord and Alaaeldin El-Nouby and Jakob Verbeek and Hervé Jégou},
    year    = {2022}
}
@inproceedings{Sandler2022FinetuningIT,
    title   = {Fine-tuning Image Transformers using Learnable Memory},
    author  = {Mark Sandler and Andrey Zhmoginov and Max Vladymyrov and Andrew Jackson},
    year    = {2022}
}
@inproceedings{Li2022SepViTSV,
    title   = {SepViT: Separable Vision Transformer},
    author  = {Wei Li and Xing Wang and Xin Xia and Jie Wu and Xuefeng Xiao and Minghang Zheng and Shiping Wen},
    year    = {2022}
}
@inproceedings{Tu2022MaxViTMV,
    title   = {MaxViT: Multi-Axis Vision Transformer},
    author  = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
    year    = {2022}
}
@article{Li2021EfficientSV,
    title   = {Efficient Self-supervised Vision Transformers for Representation Learning},
    author  = {Chunyuan Li and Jianwei Yang and Pengchuan Zhang and Mei Gao and Bin Xiao and Xiyang Dai and Lu Yuan and Jianfeng Gao},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2106.09785}
}
@misc{Beyer2022BetterPlainViT,
    title     = {Better plain ViT baselines for ImageNet-1k},
    author    = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander},
    publisher = {arXiv},
    year      = {2022}
}

@article{Arnab2021ViViTAV,
    title   = {ViViT: A Video Vision Transformer},
    author  = {Anurag Arnab and Mostafa Dehghani and Georg Heigold and Chen Sun and Mario Lucic and Cordelia Schmid},
    journal = {2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
    year    = {2021},
    pages   = {6816-6826}
}
@article{Liu2022PatchDropoutEV,
    title   = {PatchDropout: Economizing Vision Transformers Using Patch Dropout},
    author  = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.07220}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
@inproceedings{Dehghani2023PatchNP,
    title   = {Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution},
    author  = {Mostafa Dehghani and Basil Mustafa and Josip Djolonga and Jonathan Heek and Matthias Minderer and Mathilde Caron and Andreas Steiner and Joan Puigcerver and Robert Geirhos and Ibrahim M. Alabdulmohsin and Avital Oliver and Piotr Padlewski and Alexey A. Gritsenko and Mario Lučić and Neil Houlsby},
    year    = {2023}
}
@misc{vaswani2017attention,
    title   = {Attention Is All You Need},
    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year    = {2017},
    eprint  = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@inproceedings{Darcet2023VisionTN,
    title   = {Vision Transformers Need Registers},
    author  = {Timothée Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:263134283}
}
@inproceedings{ElNouby2021XCiTCI,
    title   = {XCiT: Cross-Covariance Image Transformers},
    author  = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
    booktitle = {Neural Information Processing Systems},
    year    = {2021},
    url     = {https://api.semanticscholar.org/CorpusID:235458262}
}

I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines. — Claude Shannon
