Lucidrains Series Source Code Walkthrough (Part 12)

.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\karras_unet.py

"""
the magnitude-preserving unet proposed in https://arxiv.org/abs/2312.02696 by Karras et al.
"""

import math
from math import sqrt, ceil
from functools import partial

import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F

from einops import rearrange, repeat, pack, unpack

from denoising_diffusion_pytorch.attend import Attend

# helpers functions

# check whether a variable exists
def exists(x):
    return x is not None

# return val if it exists, otherwise the default (calling it if callable)
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# logical xnor (true when both or neither)
def xnor(x, y):
    return not (x ^ y)

# append an element to a list
def append(arr, el):
    arr.append(el)

# prepend an element to a list
def prepend(arr, el):
    arr.insert(0, el)

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# cast a value to a tuple of the given length
def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return ((t,) * length)

# check whether numer is divisible by denom
def divisible_by(numer, denom):
    return (numer % denom) == 0

# l2 normalize along a dimension
def l2norm(t, dim = -1, eps = 1e-12):
    return F.normalize(t, dim = dim, eps = eps)

# mp activations
# section 2.5

# MPSiLU activation (magnitude preserving)
class MPSiLU(Module):
    def forward(self, x):
        return F.silu(x) / 0.596
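
The divisor 0.596 is approximately the standard deviation of silu(x) for x drawn from a unit Gaussian, so this activation maps unit-variance inputs back to roughly unit variance. A quick empirical check (an illustrative sketch, not part of the library):

import torch
import torch.nn.functional as F

x = torch.randn(1_000_000)
print(F.silu(x).std())            # ~0.596
print((F.silu(x) / 0.596).std())  # ~1.0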

# gain - layer scaling

# gain layer
class Gain(Module):
    def __init__(self):
        super().__init__()
        self.gain = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return x * self.gain

# magnitude preserving concat
# equation (103) - default to 0.5, which they recommended

# magnitude-preserving concatenation layer
class MPCat(Module):
    def __init__(self, t = 0.5, dim = -1):
        super().__init__()
        self.t = t
        self.dim = dim

    def forward(self, a, b):
        dim, t = self.dim, self.t
        Na, Nb = a.shape[dim], b.shape[dim]

        C = sqrt((Na + Nb) / ((1. - t) ** 2 + t ** 2))

        a = a * (1. - t) / sqrt(Na)
        b = b * t / sqrt(Nb)

        return C * torch.cat((a, b), dim = dim)
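
Per equation (103), each input is scaled down by the square root of its own width, blended with weights (1 - t) and t, and the constant C rescales the concatenated result back to unit magnitude. A quick check, using the class above (illustrative sketch):

a, b = torch.randn(4, 64), torch.randn(4, 64)
out = MPCat(t = 0.5, dim = -1)(a, b)
print(out.shape, out.std())  # torch.Size([4, 128]), ~1.0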

# magnitude preserving sum
# equation (88)
# empirically, they found t=0.3 for encoder / decoder / attention residuals
# and for embedding, t=0.5

# magnitude-preserving sum layer
class MPAdd(Module):
    def __init__(self, t):
        super().__init__()
        self.t = t

    def forward(self, x, res):
        a, b, t = x, res, self.t
        num = a * (1. - t) + b * t
        den = sqrt((1 - t) ** 2 + t ** 2)
        return num / den
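
For independent unit-variance x and res, the weighted sum has variance (1 - t)² + t², so dividing by its square root restores unit variance; with t = 0.3 the denominator is sqrt(0.49 + 0.09) ≈ 0.762. A quick check (illustrative sketch):

x, res = torch.randn(4, 64), torch.randn(4, 64)
print(MPAdd(t = 0.3)(x, res).std())  # ~1.0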

# pixelnorm
# equation (30)

# pixel normalization layer
class PixelNorm(Module):
    def __init__(self, dim, eps = 1e-4):
        super().__init__()
        # high epsilon for the pixel norm in the paper
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        dim = self.dim
        return l2norm(x, dim = dim, eps = self.eps) * sqrt(x.shape[dim])
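
PixelNorm scales each feature vector to norm sqrt(dim), i.e. unit magnitude per element. A quick check (illustrative sketch):

x = torch.randn(2, 16, 8, 8)
normed = PixelNorm(dim = 1)(x)
print(normed.norm(dim = 1)[0, 0, 0])  # sqrt(16) = 4.0 at every pixel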

# forced weight normed conv2d and linear
# algorithm 1 in paper

# normalize the weight
def normalize_weight(weight, eps = 1e-4):
    weight, ps = pack_one(weight, 'o *')
    normed_weight = l2norm(weight, eps = eps)
    normed_weight = normed_weight * sqrt(weight.numel() / weight.shape[0])
    return unpack_one(normed_weight, ps, 'o *')
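
normalize_weight flattens each output unit's weights into a row, l2-normalizes the row, then rescales by sqrt(numel / out_dim), so every row ends up with norm sqrt(fan_in). A quick check (illustrative sketch):

w = torch.randn(8, 4, 3, 3)
print(normalize_weight(w).flatten(1).norm(dim = -1))  # all equal sqrt(4 * 3 * 3) = 6.0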

# conv2d with forced weight normalization
class Conv2d(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        eps = 1e-4,
        concat_ones_to_input = False   # they use this in the input block to protect against loss of expressivity due to removal of all biases, even though they claim they observed none
    ):
        super().__init__()
        weight = torch.randn(dim_out, dim_in + int(concat_ones_to_input), kernel_size, kernel_size)
        self.weight = nn.Parameter(weight)

        self.eps = eps
        self.fan_in = dim_in * kernel_size ** 2
        self.concat_ones_to_input = concat_ones_to_input
    def forward(self, x):

        # during training, renormalize the stored weight in place (algorithm 1)
        if self.training:
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)

        # normalize the weight for the forward pass and scale by 1 / sqrt(fan_in)
        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)

        # optionally concat a channel of ones to the input (used in the input block)
        if self.concat_ones_to_input:
            x = F.pad(x, (0, 0, 0, 0, 1, 0), value = 1.)

        return F.conv2d(x, weight, padding = 'same')

# linear layer with forced weight normalization

class Linear(Module):
    def __init__(self, dim_in, dim_out, eps = 1e-4):
        super().__init__()
        weight = torch.randn(dim_out, dim_in)
        self.weight = nn.Parameter(weight)
        self.eps = eps
        self.fan_in = dim_in

    def forward(self, x):
        # during training, renormalize the stored weight in place (algorithm 1)
        if self.training:
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)

        # normalize the weight for the forward pass and scale by 1 / sqrt(fan_in)
        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
        return F.linear(x, weight)
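
Algorithm 1 has two halves: during training the stored weight is renormalized in place (under no_grad, so the copy is not part of the graph), and the weight actually used on every forward pass is the normalized one divided by sqrt(fan_in). A quick check of the in-place half (illustrative sketch):

lin = Linear(16, 8)
lin.train()
_ = lin(torch.randn(2, 16))
print(lin.weight.norm(dim = -1))  # every row has norm sqrt(16) = 4.0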

# mp fourier embeds

class MPFourierEmbedding(Module):
    def __init__(self, dim):
        super().__init__()
        assert divisible_by(dim, 2)
        half_dim = dim // 2
        # fixed random frequencies, not trained
        self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = False)

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        # concat of sin and cos, scaled by sqrt(2) to preserve magnitude
        return torch.cat((freqs.sin(), freqs.cos()), dim = -1) * sqrt(2)
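
Since sin² + cos² = 1, the concatenated features have mean square 1/2, and the sqrt(2) factor brings the embedding back to exactly unit magnitude per example. A quick check (illustrative sketch):

emb = MPFourierEmbedding(16)
out = emb(torch.rand(4) * 1000)
print(out.pow(2).mean(dim = -1))  # 1.0 per example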

# building block modules

class Encoder(Module):
    # encoder block with optional downsampling and attention
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        emb_dim = None,
        dropout = 0.1,
        mp_add_t = 0.3,
        has_attn = False,
        attn_dim_head = 64,
        attn_res_mp_add_t = 0.3,
        attn_flash = False,
        downsample = False
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.downsample = downsample
        self.downsample_conv = None

        curr_dim = dim
        # if downsampling, use a 1x1 conv to change dimension
        if downsample:
            self.downsample_conv = Conv2d(curr_dim, dim_out, 1)
            curr_dim = dim_out

        # pixel normalization
        self.pixel_norm = PixelNorm(dim = 1)

        self.to_emb = None
        # if an embedding dimension is given, add a linear layer followed by a gain
        if exists(emb_dim):
            self.to_emb = nn.Sequential(
                Linear(emb_dim, dim_out),
                Gain()
            )

        # first block
        self.block1 = nn.Sequential(
            MPSiLU(),
            Conv2d(curr_dim, dim_out, 3)
        )

        # second block
        self.block2 = nn.Sequential(
            MPSiLU(),
            nn.Dropout(dropout),
            Conv2d(dim_out, dim_out, 3)
        )

        # magnitude-preserving residual add
        self.res_mp_add = MPAdd(t = mp_add_t)

        self.attn = None
        # add an attention block if requested
        if has_attn:
            self.attn = Attention(
                dim = dim_out,
                heads = max(ceil(dim_out / attn_dim_head), 2),
                dim_head = attn_dim_head,
                mp_add_t = attn_res_mp_add_t,
                flash = attn_flash
            )

    # forward
    def forward(
        self,
        x,
        emb = None
    ):
        # if downsampling, bilinear-interpolate to half resolution, then conv
        if self.downsample:
            h, w = x.shape[-2:]
            x = F.interpolate(x, (h // 2, w // 2), mode = 'bilinear')
            x = self.downsample_conv(x)

        # pixel normalization
        x = self.pixel_norm(x)

        res = x.clone()

        x = self.block1(x)

        # if an embedding is given, scale the features
        if exists(emb):
            scale = self.to_emb(emb) + 1
            x = x * rearrange(scale, 'b c -> b c 1 1')

        x = self.block2(x)

        x = self.res_mp_add(x, res)

        # apply attention if present
        if exists(self.attn):
            x = self.attn(x)

        return x

class Decoder(Module):
    # decoder block with optional upsampling and attention
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        emb_dim = None,
        dropout = 0.1,
        mp_add_t = 0.3,
        has_attn = False,
        attn_dim_head = 64,
        attn_res_mp_add_t = 0.3,
        attn_flash = False,
        upsample = False
    ):
        super().__init__()
        # default the output dimension to the input dimension
        dim_out = default(dim_out, dim)

        self.upsample = upsample
        # decoder blocks that upsample do not receive a skip connection
        self.needs_skip = not upsample

        self.to_emb = None
        # if an embedding dimension is given, add a linear layer followed by a gain
        if exists(emb_dim):
            self.to_emb = nn.Sequential(
                Linear(emb_dim, dim_out),
                Gain()
            )

        # first block
        self.block1 = nn.Sequential(
            MPSiLU(),
            Conv2d(dim, dim_out, 3)
        )

        # second block
        self.block2 = nn.Sequential(
            MPSiLU(),
            nn.Dropout(dropout),
            Conv2d(dim_out, dim_out, 3)
        )

        # 1x1 conv on the residual branch when dimensions differ
        self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

        # magnitude-preserving residual add
        self.res_mp_add = MPAdd(t = mp_add_t)

        self.attn = None
        # add an attention block if requested
        if has_attn:
            self.attn = Attention(
                dim = dim_out,
                heads = max(ceil(dim_out / attn_dim_head), 2),
                dim_head = attn_dim_head,
                mp_add_t = attn_res_mp_add_t,
                flash = attn_flash
            )

    # forward
    def forward(
        self,
        x,
        emb = None
    ):
        # if upsampling, bilinear-interpolate to double resolution
        if self.upsample:
            h, w = x.shape[-2:]
            x = F.interpolate(x, (h * 2, w * 2), mode = 'bilinear')

        # residual branch
        res = self.res_conv(x)

        # first block
        x = self.block1(x)

        # if an embedding is given, scale the features
        if exists(emb):
            scale = self.to_emb(emb) + 1
            x = x * rearrange(scale, 'b c -> b c 1 1')

        # second block
        x = self.block2(x)

        # magnitude-preserving residual add
        x = self.res_mp_add(x, res)

        # apply attention if present
        if exists(self.attn):
            x = self.attn(x)

        return x

# attention

class Attention(Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 64,
        num_mem_kv = 4,
        flash = False,
        mp_add_t = 0.3
    ):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads

        # pixel norm for queries / keys / values
        self.pixel_norm = PixelNorm(dim = -1)

        # attend (scaled dot product attention, optionally flash)
        self.attend = Attend(flash = flash)

        # learned memory key / values
        self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
        self.to_qkv = Conv2d(dim, hidden_dim * 3, 1)
        self.to_out = Conv2d(hidden_dim, dim, 1)

        # magnitude-preserving residual add
        self.mp_add = MPAdd(t = mp_add_t)

    def forward(self, x):
        res, b, c, h, w = x, *x.shape

        # project input to queries, keys, values
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv)

        # prepend the memory key / values
        mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.mem_kv)
        k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v)))

        # pixel norm the queries, keys, values
        q, k, v = map(self.pixel_norm, (q, k, v))

        # attention
        out = self.attend(q, k, v)

        # merge heads back into the feature map and project out
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        out = self.to_out(out)

        return self.mp_add(out, res)
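
Note that attention here operates on pixel-normalized queries, keys and values, with a few learned "memory" key / value slots prepended, and the output rejoins the residual through a magnitude-preserving add. A quick shape check using the class above (illustrative sketch; flash left off so only the repo's default Attend path is exercised):

attn = Attention(dim = 64, heads = 4, dim_head = 32)
x = torch.randn(1, 64, 16, 16)
print(attn(x).shape)  # torch.Size([1, 64, 16, 16])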

# the unet proposed by Karras
# no biases, no group norms, and magnitude-preserving operations throughout

class KarrasUnet(Module):
    """
    going by figure 21, config G
    """

    def __init__(
        self,
        *,
        image_size,
        dim = 192,
        dim_max = 768,            # channels will double every downsample and cap out at this value
        num_classes = None,       # in the paper, they do 1000 classes for a popular benchmark
        channels = 4,             # 4 channels in the paper, perhaps to account for alpha?
        num_downsamples = 3,
        num_blocks_per_stage = 4,
        attn_res = (16, 8),
        fourier_dim = 16,
        attn_dim_head = 64,
        attn_flash = False,
        mp_cat_t = 0.5,
        mp_add_emb_t = 0.5,
        attn_res_mp_add_t = 0.3,
        resnet_mp_add_t = 0.3,
        dropout = 0.1,
        self_condition = False
    ):
        super().__init__()

        self.self_condition = self_condition

        # determine dimensions

        self.channels = channels
        self.image_size = image_size
        input_channels = channels * (2 if self_condition else 1)

        # input and output blocks

        self.input_block = Conv2d(input_channels, dim, 3, concat_ones_to_input = True)

        self.output_block = nn.Sequential(
            Conv2d(dim, channels, 3),
            Gain()
        )

        # time embedding

        emb_dim = dim * 4

        self.to_time_emb = nn.Sequential(
            MPFourierEmbedding(fourier_dim),
            Linear(fourier_dim, emb_dim)
        )

        # class embedding

        self.needs_class_labels = exists(num_classes)
        self.num_classes = num_classes

        if self.needs_class_labels:
            self.to_class_emb = Linear(num_classes, 4 * dim)
            self.add_class_emb = MPAdd(t = mp_add_emb_t)

        # final embedding activation

        self.emb_activation = MPSiLU()

        # number of downsamples

        self.num_downsamples = num_downsamples

        # attention

        attn_res = set(cast_tuple(attn_res))

        # resnet block kwargs

        block_kwargs = dict(
            dropout = dropout,
            emb_dim = emb_dim,
            attn_dim_head = attn_dim_head,
            attn_res_mp_add_t = attn_res_mp_add_t,
            attn_flash = attn_flash
        )

        # unet encoders and decoders

        self.downs = ModuleList([])
        self.ups = ModuleList([])

        curr_dim = dim
        curr_res = image_size

        # take care of the skip connections for the initial input block and the first three encoder blocks
        self.skip_mp_cat = MPCat(t = mp_cat_t, dim = 1)

        prepend(self.ups, Decoder(dim * 2, dim, **block_kwargs))

        assert num_blocks_per_stage >= 1

        for _ in range(num_blocks_per_stage):
            enc = Encoder(curr_dim, curr_dim, **block_kwargs)
            dec = Decoder(curr_dim * 2, curr_dim, **block_kwargs)

            append(self.downs, enc)
            prepend(self.ups, dec)

        # stages

        for _ in range(self.num_downsamples):
            dim_out = min(dim_max, curr_dim * 2)
            upsample = Decoder(dim_out, curr_dim, has_attn = curr_res in attn_res, upsample = True, **block_kwargs)

            curr_res //= 2
            has_attn = curr_res in attn_res

            downsample = Encoder(curr_dim, dim_out, downsample = True, has_attn = has_attn, **block_kwargs)

            append(self.downs, downsample)
            prepend(self.ups, upsample)
            prepend(self.ups, Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs))

            for _ in range(num_blocks_per_stage):
                enc = Encoder(dim_out, dim_out, has_attn = has_attn, **block_kwargs)
                dec = Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs)

                append(self.downs, enc)
                prepend(self.ups, dec)

            curr_dim = dim_out

        # take care of the two middle decoders

        mid_has_attn = curr_res in attn_res

        self.mids = ModuleList([
            Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
            Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
        ])

        self.out_dim = channels

    @property
    def downsample_factor(self):
        return 2 ** self.num_downsamples

    def forward(
        self,
        x,
        time,
        self_cond = None,
        class_labels = None
    ):
        # validate image shape

        assert x.shape[1:] == (self.channels, self.image_size, self.image_size)

        # self conditioning

        if self.self_condition:
            self_cond = default(self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((self_cond, x), dim = 1)
        else:
            assert not exists(self_cond)

        # time condition

        time_emb = self.to_time_emb(time)

        # class condition

        assert xnor(exists(class_labels), self.needs_class_labels)

        if self.needs_class_labels:
            if class_labels.dtype in (torch.int, torch.long):
                # convert integer class labels to one-hot
                class_labels = F.one_hot(class_labels, self.num_classes)

            assert class_labels.shape[-1] == self.num_classes
            # cast to float and scale by sqrt of the number of classes (magnitude preserving)
            class_labels = class_labels.float() * sqrt(self.num_classes)

            class_emb = self.to_class_emb(class_labels)

            # add the class embedding to the time embedding
            time_emb = self.add_class_emb(time_emb, class_emb)

        # final mp-silu for the embedding

        emb = self.emb_activation(time_emb)

        # skip connections

        skips = []

        # input block

        x = self.input_block(x)

        skips.append(x)

        # downsample

        for encoder in self.downs:
            x = encoder(x, emb = emb)
            skips.append(x)

        # mid

        for decoder in self.mids:
            x = decoder(x, emb = emb)

        # upsample

        for decoder in self.ups:
            if decoder.needs_skip:
                skip = skips.pop()
                x = self.skip_mp_cat(x, skip)

            x = decoder(x, emb = emb)

        # output block

        return self.output_block(x)

# mp feedforward for the transformer (magnitude preserving)

class MPFeedForward(Module):
    def __init__(
        self,
        *,
        dim,
        mult = 4,          # inner dimension multiplier
        mp_add_t = 0.3     # t for the residual mp add
    ):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            PixelNorm(dim = 1),
            Conv2d(dim, dim_inner, 1),
            MPSiLU(),
            Conv2d(dim_inner, dim, 1)
        )

        self.mp_add = MPAdd(t = mp_add_t)

    def forward(self, x):
        res = x
        out = self.net(x)
        return self.mp_add(out, res)

# mp transformer for images (magnitude preserving)

class MPImageTransformer(Module):
    def __init__(
        self,
        *,
        dim,  # 输入维度
        depth,  # 深度
        dim_head = 64,
        heads = 8,
        num_mem_kv = 4,
        ff_mult = 4,
        attn_flash = False,          # whether to use flash attention
        residual_mp_add_t = 0.3
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, num_mem_kv = num_mem_kv, flash = attn_flash, mp_add_t = residual_mp_add_t),
                MPFeedForward(dim = dim, mult = ff_mult, mp_add_t = residual_mp_add_t)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x)

        return x

# inverse square root decay learning rate schedule

def InvSqrtDecayLRSched(
    optimizer,
    t_ref = 70000,
    sigma_ref = 0.01
):
    """
    refer to equation 67 and Table 1
    """
    def inv_sqrt_decay_fn(t: int):
        return sigma_ref / sqrt(max(t / t_ref, 1.))

    return LambdaLR(optimizer, lr_lambda = inv_sqrt_decay_fn)
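
With LambdaLR the lambda's value multiplies the optimizer's base learning rate, so the base lr should be set to 1.0 for the effective rate to equal sigma_ref / sqrt(max(t / t_ref, 1)), i.e. constant at sigma_ref until t_ref steps, then decaying as 1/sqrt(t). A usage sketch (an assumed pairing with Adam, not prescribed by this file; `unet` is a model as in the example below):

optimizer = torch.optim.Adam(unet.parameters(), lr = 1.)
scheduler = InvSqrtDecayLRSched(optimizer)

for _ in range(100):
    # ... training step ...
    optimizer.step()
    scheduler.step()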

# example

if __name__ == '__main__':
    unet = KarrasUnet(
        image_size = 64,
        dim = 192,
        dim_max = 768,
        num_classes = 1000,
    )

    images = torch.randn(2, 4, 64, 64)

    denoised_images = unet(
        images,
        time = torch.ones(2,),
        class_labels = torch.randint(0, 1000, (2,))
    )

    assert denoised_images.shape == images.shape

.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\karras_unet_1d.py

"""
the magnitude-preserving unet proposed in https://arxiv.org/abs/2312.02696 by Karras et al.
"""

import math
from math import sqrt, ceil
from functools import partial

import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F

from einops import rearrange, repeat, pack, unpack

from denoising_diffusion_pytorch.attend import Attend

# helpers functions

# check whether a variable exists
def exists(x):
    return x is not None

# return val if it exists, otherwise the default (calling it if callable)
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# logical xnor (true when both or neither)
def xnor(x, y):
    return not (x ^ y)

# append an element to a list
def append(arr, el):
    arr.append(el)

# prepend an element to a list
def prepend(arr, el):
    arr.insert(0, el)

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# cast a value to a tuple of the given length
def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return ((t,) * length)

# check whether numer is divisible by denom
def divisible_by(numer, denom):
    return (numer % denom) == 0

# l2 normalize along a dimension
def l2norm(t, dim = -1, eps = 1e-12):
    return F.normalize(t, dim = dim, eps = eps)

# interpolate along one dimension
def interpolate_1d(x, length, mode = 'bilinear'):
    x = rearrange(x, 'b c t -> b c t 1')
    x = F.interpolate(x, (length, 1), mode = mode)
    return rearrange(x, 'b c t 1 -> b c t')
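
F.interpolate has no 'bilinear' mode for 3d (batch, channel, length) tensors, so this helper temporarily appends a dummy spatial dimension of size 1, interpolates as 2d, and squeezes it back; the net effect is plain linear interpolation along the sequence. A quick check (illustrative sketch):

x = torch.arange(4.).reshape(1, 1, 4)
print(interpolate_1d(x, 8).shape)  # torch.Size([1, 1, 8])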

# mp activations
# section 2.5

# MPSiLU activation (magnitude preserving)
class MPSiLU(Module):
    def forward(self, x):
        return F.silu(x) / 0.596

# gain - layer scaling

# gain layer
class Gain(Module):
    def __init__(self):
        super().__init__()
        self.gain = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return x * self.gain

# magnitude preserving concat
# equation (103) - default to 0.5, which they recommended

# magnitude-preserving concatenation layer
class MPCat(Module):
    def __init__(self, t = 0.5, dim = -1):
        super().__init__()
        self.t = t
        self.dim = dim

    def forward(self, a, b):
        dim, t = self.dim, self.t
        Na, Nb = a.shape[dim], b.shape[dim]

        C = sqrt((Na + Nb) / ((1. - t) ** 2 + t ** 2))

        a = a * (1. - t) / sqrt(Na)
        b = b * t / sqrt(Nb)

        return C * torch.cat((a, b), dim = dim)

# magnitude preserving sum
# equation (88)
# empirically, they found t=0.3 for encoder / decoder / attention residuals
# and for embedding, t=0.5

# magnitude-preserving sum layer
class MPAdd(Module):
    def __init__(self, t):
        super().__init__()
        self.t = t

    def forward(self, x, res):
        a, b, t = x, res, self.t
        num = a * (1. - t) + b * t
        den = sqrt((1 - t) ** 2 + t ** 2)
        return num / den

# pixelnorm
# equation (30)

# pixel normalization layer
class PixelNorm(Module):
    def __init__(self, dim, eps = 1e-4):
        super().__init__()
        # high epsilon for the pixel norm in the paper
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        dim = self.dim
        return l2norm(x, dim = dim, eps = self.eps) * sqrt(x.shape[dim])

# forced weight normed conv2d and linear
# algorithm 1 in paper

# normalize the weight
def normalize_weight(weight, eps = 1e-4):
    weight, ps = pack_one(weight, 'o *')
    normed_weight = l2norm(weight, eps = eps)
    normed_weight = normed_weight * sqrt(weight.numel() / weight.shape[0])
    return unpack_one(normed_weight, ps, 'o *')

# conv1d with forced weight normalization
class Conv1d(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        eps = 1e-4,
        init_dirac = False,
        concat_ones_to_input = False   # they use this in the input block to protect against loss of expressivity due to removal of all biases, even though they claim they observed none
    ):
        super().__init__()
        weight = torch.randn(dim_out, dim_in + int(concat_ones_to_input), kernel_size)
        self.weight = nn.Parameter(weight)

        if init_dirac:
            nn.init.dirac_(self.weight)

        self.eps = eps
        self.fan_in = dim_in * kernel_size
        self.concat_ones_to_input = concat_ones_to_input
    def forward(self, x):
        # during training, renormalize the stored weight in place (algorithm 1)
        if self.training:
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)

        # normalize the weight for the forward pass and scale by 1 / sqrt(fan_in)
        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)

        # optionally concat a channel of ones to the input (used in the input block)
        if self.concat_ones_to_input:
            x = F.pad(x, (0, 0, 1, 0), value = 1.)

        return F.conv1d(x, weight, padding = 'same')

# linear layer with forced weight normalization

class Linear(Module):
    def __init__(self, dim_in, dim_out, eps = 1e-4):
        super().__init__()
        weight = torch.randn(dim_out, dim_in)
        self.weight = nn.Parameter(weight)
        self.eps = eps
        self.fan_in = dim_in

    def forward(self, x):
        # during training, renormalize the stored weight in place (algorithm 1)
        if self.training:
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)

        # normalize the weight for the forward pass and scale by 1 / sqrt(fan_in)
        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
        return F.linear(x, weight)

# mp fourier embedding

class MPFourierEmbedding(Module):
    def __init__(self, dim):
        super().__init__()
        assert divisible_by(dim, 2)
        half_dim = dim // 2
        # fixed random frequencies, not trained
        self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = False)

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        # concat of sin and cos, scaled by sqrt(2) to preserve magnitude
        return torch.cat((freqs.sin(), freqs.cos()), dim = -1) * sqrt(2)

# building block modules

class Encoder(Module):
    # encoder block with optional downsampling and attention
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        emb_dim = None,
        dropout = 0.1,
        mp_add_t = 0.3,
        has_attn = False,
        attn_dim_head = 64,
        attn_res_mp_add_t = 0.3,
        attn_flash = False,
        downsample = False
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.downsample = downsample
        self.downsample_conv = None

        curr_dim = dim
        # if downsampling, use a 1x1 conv to change dimension
        if downsample:
            self.downsample_conv = Conv1d(curr_dim, dim_out, 1)
            curr_dim = dim_out

        # pixel normalization
        self.pixel_norm = PixelNorm(dim = 1)

        self.to_emb = None
        # if an embedding dimension is given, add a linear layer followed by a gain
        if exists(emb_dim):
            self.to_emb = nn.Sequential(
                Linear(emb_dim, dim_out),
                Gain()
            )

        # first block
        self.block1 = nn.Sequential(
            MPSiLU(),
            Conv1d(curr_dim, dim_out, 3)
        )

        # second block
        self.block2 = nn.Sequential(
            MPSiLU(),
            nn.Dropout(dropout),
            Conv1d(dim_out, dim_out, 3)
        )

        # magnitude-preserving residual add
        self.res_mp_add = MPAdd(t = mp_add_t)

        self.attn = None
        # add an attention block if requested
        if has_attn:
            self.attn = Attention(
                dim = dim_out,
                heads = max(ceil(dim_out / attn_dim_head), 2),
                dim_head = attn_dim_head,
                mp_add_t = attn_res_mp_add_t,
                flash = attn_flash
            )

    # forward
    def forward(
        self,
        x,
        emb = None
    ):
        # if downsampling, linearly interpolate to half length, then conv
        if self.downsample:
            x = interpolate_1d(x, x.shape[-1] // 2, mode = 'bilinear')
            x = self.downsample_conv(x)

        x = self.pixel_norm(x)

        res = x.clone()

        x = self.block1(x)

        # if an embedding is given, scale the features
        if exists(emb):
            scale = self.to_emb(emb) + 1
            x = x * rearrange(scale, 'b c -> b c 1')

        x = self.block2(x)

        x = self.res_mp_add(x, res)

        # apply attention if present
        if exists(self.attn):
            x = self.attn(x)

        return x

# decoder

class Decoder(Module):
    # decoder block with optional upsampling and attention
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        emb_dim = None,
        dropout = 0.1,
        mp_add_t = 0.3,
        has_attn = False,
        attn_dim_head = 64,
        attn_res_mp_add_t = 0.3,
        attn_flash = False,
        upsample = False
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.upsample = upsample
        # decoder blocks that upsample do not receive a skip connection
        self.needs_skip = not upsample

        self.to_emb = None
        # if an embedding dimension is given, add a linear layer followed by a gain
        if exists(emb_dim):
            self.to_emb = nn.Sequential(
                Linear(emb_dim, dim_out),
                Gain()
            )

        # first block
        self.block1 = nn.Sequential(
            MPSiLU(),
            Conv1d(dim, dim_out, 3)
        )

        # second block
        self.block2 = nn.Sequential(
            MPSiLU(),
            nn.Dropout(dropout),
            Conv1d(dim_out, dim_out, 3)
        )

        # 1x1 conv on the residual branch when dimensions differ
        self.res_conv = Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

        # magnitude-preserving residual add
        self.res_mp_add = MPAdd(t = mp_add_t)

        self.attn = None
        # add an attention block if requested
        if has_attn:
            self.attn = Attention(
                dim = dim_out,
                heads = max(ceil(dim_out / attn_dim_head), 2),
                dim_head = attn_dim_head,
                mp_add_t = attn_res_mp_add_t,
                flash = attn_flash
            )

    # forward
    def forward(
        self,
        x,
        emb = None
    ):
        # if upsampling, linearly interpolate to double length
        if self.upsample:
            x = interpolate_1d(x, x.shape[-1] * 2, mode = 'bilinear')

        res = self.res_conv(x)

        x = self.block1(x)

        # if an embedding is given, scale the features
        if exists(emb):
            scale = self.to_emb(emb) + 1
            x = x * rearrange(scale, 'b c -> b c 1')

        x = self.block2(x)

        x = self.res_mp_add(x, res)

        # apply attention if present
        if exists(self.attn):
            x = self.attn(x)

        return x

# attention

class Attention(Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 64,
        num_mem_kv = 4,
        flash = False,
        mp_add_t = 0.3
    ):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads

        # pixel norm for queries / keys / values
        self.pixel_norm = PixelNorm(dim = -1)

        # attend (scaled dot product attention, optionally flash)
        self.attend = Attend(flash = flash)

        # learned memory key / values
        self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
        self.to_qkv = Conv1d(dim, hidden_dim * 3, 1)
        self.to_out = Conv1d(hidden_dim, dim, 1)

        # magnitude-preserving residual add
        self.mp_add = MPAdd(t = mp_add_t)

    # forward
    def forward(self, x):
        res, b, c, n = x, *x.shape

        # project input to queries, keys, values
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) n -> b h n c', h = self.heads), qkv)

        # prepend the memory key / values
        mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.mem_kv)
        k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v)))

        # pixel norm the queries, keys, values
        q, k, v = map(self.pixel_norm, (q, k, v))

        # attention
        out = self.attend(q, k, v)

        out = rearrange(out, 'b h n d -> b (h d) n')
        out = self.to_out(out)

        return self.mp_add(out, res)

# the unet proposed by Karras, adapted to 1d
class KarrasUnet1D(Module):
    """
    going by figure 21. config G
    """

    def __init__(
        self,
        *,
        seq_len,
        dim = 192,
        dim_max = 768,            
        num_classes = None,       
        channels = 4,             
        num_downsamples = 3,
        num_blocks_per_stage = 4,
        attn_res = (16, 8),
        fourier_dim = 16,
        attn_dim_head = 64,
        attn_flash = False,
        mp_cat_t = 0.5,
        mp_add_emb_t = 0.5,
        attn_res_mp_add_t = 0.3,
        resnet_mp_add_t = 0.3,
        dropout = 0.1,
        self_condition = False
    ):
        super().__init__()

        self.self_condition = self_condition

        # determine dimensions

        self.channels = channels
        self.seq_len = seq_len
        input_channels = channels * (2 if self_condition else 1)

        # input and output blocks

        self.input_block = Conv1d(input_channels, dim, 3, concat_ones_to_input = True)

        self.output_block = nn.Sequential(
            Conv1d(dim, channels, 3),
            Gain()
        )

        # time embedding

        emb_dim = dim * 4

        self.to_time_emb = nn.Sequential(
            MPFourierEmbedding(fourier_dim),
            Linear(fourier_dim, emb_dim)
        )

        # class embedding

        self.needs_class_labels = exists(num_classes)
        self.num_classes = num_classes

        if self.needs_class_labels:
            self.to_class_emb = Linear(num_classes, 4 * dim)
            self.add_class_emb = MPAdd(t = mp_add_emb_t)

        # final embedding activation

        self.emb_activation = MPSiLU()

        # number of downsamples

        self.num_downsamples = num_downsamples

        # attention

        attn_res = set(cast_tuple(attn_res))

        # resnet block kwargs

        block_kwargs = dict(
            dropout = dropout,
            emb_dim = emb_dim,
            attn_dim_head = attn_dim_head,
            attn_res_mp_add_t = attn_res_mp_add_t,
            attn_flash = attn_flash
        )

        # unet encoders and decoders

        self.downs = ModuleList([])
        self.ups = ModuleList([])

        curr_dim = dim
        curr_res = seq_len

        self.skip_mp_cat = MPCat(t = mp_cat_t, dim = 1)

        # take care of the skip connections for the initial input block and the first three encoder blocks

        prepend(self.ups, Decoder(dim * 2, dim, **block_kwargs))

        assert num_blocks_per_stage >= 1

        for _ in range(num_blocks_per_stage):
            enc = Encoder(curr_dim, curr_dim, **block_kwargs)
            dec = Decoder(curr_dim * 2, curr_dim, **block_kwargs)

            append(self.downs, enc)
            prepend(self.ups, dec)

        # stages

        for _ in range(self.num_downsamples):
            dim_out = min(dim_max, curr_dim * 2)
            upsample = Decoder(dim_out, curr_dim, has_attn = curr_res in attn_res, upsample = True, **block_kwargs)

            curr_res //= 2
            has_attn = curr_res in attn_res

            downsample = Encoder(curr_dim, dim_out, downsample = True, has_attn = has_attn, **block_kwargs)

            append(self.downs, downsample)
            prepend(self.ups, upsample)
            prepend(self.ups, Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs))

            for _ in range(num_blocks_per_stage):
                enc = Encoder(dim_out, dim_out, has_attn = has_attn, **block_kwargs)
                dec = Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs)

                append(self.downs, enc)
                prepend(self.ups, dec)

            curr_dim = dim_out

        # take care of the two middle decoders

        mid_has_attn = curr_res in attn_res

        self.mids = ModuleList([
            Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
            Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
        ])

        self.out_dim = channels

    @property
    def downsample_factor(self):
        return 2 ** self.num_downsamples

    def forward(
        self,
        x,
        time,
        self_cond = None,
        class_labels = None
    ):
        # validate input shape

        assert x.shape[1:] == (self.channels, self.seq_len)

        # self conditioning

        if self.self_condition:
            self_cond = default(self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((self_cond, x), dim = 1)
        else:
            assert not exists(self_cond)

        # time condition

        time_emb = self.to_time_emb(time)

        # class condition

        assert xnor(exists(class_labels), self.needs_class_labels)

        if self.needs_class_labels:
            if class_labels.dtype in (torch.int, torch.long):
                class_labels = F.one_hot(class_labels, self.num_classes)

            assert class_labels.shape[-1] == self.num_classes
            class_labels = class_labels.float() * sqrt(self.num_classes)

            class_emb = self.to_class_emb(class_labels)

            time_emb = self.add_class_emb(time_emb, class_emb)

        # final mp-silu for the embedding

        emb = self.emb_activation(time_emb)

        # skip connections

        skips = []

        # input block

        x = self.input_block(x)

        skips.append(x)

        # downsample

        for encoder in self.downs:
            x = encoder(x, emb = emb)
            skips.append(x)

        # mid

        for decoder in self.mids:
            x = decoder(x, emb = emb)

        # upsample

        for decoder in self.ups:
            if decoder.needs_skip:
                skip = skips.pop()
                x = self.skip_mp_cat(x, skip)

            x = decoder(x, emb = emb)

        # output block

        return self.output_block(x)

# mp feedforward for the transformer (magnitude preserving)

class MPFeedForward(Module):
    def __init__(
        self,
        *,
        dim,
        mult = 4,
        mp_add_t = 0.3
    ):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            PixelNorm(dim = 1),
            Conv1d(dim, dim_inner, 1),   # pointwise convs (Conv1d here; only Conv1d is defined in this file)
            MPSiLU(),
            Conv1d(dim_inner, dim, 1)
        )

        self.mp_add = MPAdd(t = mp_add_t)

    def forward(self, x):
        res = x
        out = self.net(x)
        return self.mp_add(out, res)

# mp transformer (magnitude preserving)

class MPImageTransformer(Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_mem_kv = 4,
        ff_mult = 4,
        attn_flash = False,          # whether to use flash attention
        residual_mp_add_t = 0.3
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, num_mem_kv = num_mem_kv, flash = attn_flash, mp_add_t = residual_mp_add_t),
                MPFeedForward(dim = dim, mult = ff_mult, mp_add_t = residual_mp_add_t)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x)

        return x

# example

if __name__ == '__main__':
    unet = KarrasUnet1D(
        seq_len = 64,
        dim = 192,
        dim_max = 768,
        num_classes = 1000,
    )

    # random input sequences of shape (batch, channels, seq_len)
    sequences = torch.randn(2, 4, 64)

    denoised = unet(
        sequences,
        time = torch.ones(2,),
        class_labels = torch.randint(0, 1000, (2,))
    )

    assert denoised.shape == sequences.shape

.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\karras_unet_3d.py

"""
the magnitude-preserving unet proposed in https://arxiv.org/abs/2312.02696 by Karras et al.
"""

import math
from math import sqrt, ceil
from functools import partial
from typing import Optional, Union, Tuple

import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F

from einops import rearrange, repeat, pack, unpack

from denoising_diffusion_pytorch.attend import Attend

# helpers functions

# check whether a variable exists
def exists(x):
    return x is not None

# return val if it exists, otherwise the default (calling it if callable)
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# logical xnor (true when both or neither)
def xnor(x, y):
    return not (x ^ y)

# append an element to a list
def append(arr, el):
    arr.append(el)

# prepend an element to a list
def prepend(arr, el):
    arr.insert(0, el)

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# cast a value to a tuple of the given length
def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return ((t,) * length)

# check whether numer is divisible by denom
def divisible_by(numer, denom):
    return (numer % denom) == 0

# in the paper, they use an eps of 1e-4 for the pixel norm

# l2 normalize along a dimension
def l2norm(t, dim = -1, eps = 1e-12):
    return F.normalize(t, dim = dim, eps = eps)

# mp activations
# section 2.5

# MPSiLU activation (magnitude preserving)
class MPSiLU(Module):
    def forward(self, x):
        return F.silu(x) / 0.596

# gain - layer scaling

# gain layer
class Gain(Module):
    def __init__(self):
        super().__init__()
        self.gain = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return x * self.gain

# magnitude preserving concat
# equation (103) - default to 0.5, which they recommended

# magnitude-preserving concatenation layer
class MPCat(Module):
    def __init__(self, t = 0.5, dim = -1):
        super().__init__()
        self.t = t
        self.dim = dim

    def forward(self, a, b):
        dim, t = self.dim, self.t
        Na, Nb = a.shape[dim], b.shape[dim]

        C = sqrt((Na + Nb) / ((1. - t) ** 2 + t ** 2))

        a = a * (1. - t) / sqrt(Na)
        b = b * t / sqrt(Nb)

        return C * torch.cat((a, b), dim = dim)

# magnitude preserving sum
# equation (88)
# empirically, they found t=0.3 for encoder / decoder / attention residuals
# and for embedding, t=0.5

# magnitude-preserving sum layer
class MPAdd(Module):
    def __init__(self, t):
        super().__init__()
        self.t = t

    def forward(self, x, res):
        a, b, t = x, res, self.t
        num = a * (1. - t) + b * t
        den = sqrt((1 - t) ** 2 + t ** 2)
        return num / den

# pixelnorm
# equation (30)

# pixel normalization layer
class PixelNorm(Module):
    def __init__(self, dim, eps = 1e-4):
        super().__init__()
        # high epsilon for the pixel norm in the paper
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        dim = self.dim
        return l2norm(x, dim = dim, eps = self.eps) * sqrt(x.shape[dim])

# forced weight normed conv3d and linear
# algorithm 1 in paper

# normalize the weight
def normalize_weight(weight, eps = 1e-4):
    weight, ps = pack_one(weight, 'o *')
    normed_weight = l2norm(weight, eps = eps)
    normed_weight = normed_weight * sqrt(weight.numel() / weight.shape[0])
    return unpack_one(normed_weight, ps, 'o *')

# conv3d with forced weight normalization
class Conv3d(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        eps = 1e-4,
        concat_ones_to_input = False   # they use this in the input block to protect against loss of expressivity due to removal of all biases, even though they claim they observed none
    ):
        super().__init__()
        weight = torch.randn(dim_out, dim_in + int(concat_ones_to_input), kernel_size, kernel_size, kernel_size)
        self.weight = nn.Parameter(weight)

        self.eps = eps
        self.fan_in = dim_in * kernel_size ** 3
        self.concat_ones_to_input = concat_ones_to_input
    def forward(self, x):

        # during training, renormalize the stored weight in place (algorithm 1)
        if self.training:
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)

        # normalize the weight for the forward pass and scale by 1 / sqrt(fan_in)
        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)

        # optionally concat a channel of ones to the input (used in the input block)
        if self.concat_ones_to_input:
            x = F.pad(x, (0, 0, 0, 0, 0, 0, 1, 0), value = 1.)

        return F.conv3d(x, weight, padding = 'same')

# linear layer with forced weight normalization

class Linear(Module):
    def __init__(self, dim_in, dim_out, eps = 1e-4):
        super().__init__()
        weight = torch.randn(dim_out, dim_in)
        self.weight = nn.Parameter(weight)
        self.eps = eps
        self.fan_in = dim_in

    def forward(self, x):
        # during training, renormalize the stored weight in place (algorithm 1)
        if self.training:
            with torch.no_grad():
                normed_weight = normalize_weight(self.weight, eps = self.eps)
                self.weight.copy_(normed_weight)

        # normalize the weight for the forward pass and scale by 1 / sqrt(fan_in)
        weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in)
        return F.linear(x, weight)

# mp fourier embedding

class MPFourierEmbedding(Module):
    def __init__(self, dim):
        super().__init__()
        assert divisible_by(dim, 2)
        half_dim = dim // 2
        # fixed random frequencies, not trained
        self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = False)

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        # concat of sin and cos, scaled by sqrt(2) to preserve magnitude
        return torch.cat((freqs.sin(), freqs.cos()), dim = -1) * sqrt(2)

# building block modules

class Encoder(Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        emb_dim = None,
        dropout = 0.1,
        mp_add_t = 0.3,
        has_attn = False,
        attn_dim_head = 64,
        attn_res_mp_add_t = 0.3,
        attn_flash = False,
        factorize_space_time_attn = False,
        downsample = False,
        downsample_config: Tuple[bool, bool, bool] = (True, True, True)
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.downsample = downsample
        self.downsample_config = downsample_config

        self.downsample_conv = None

        curr_dim = dim
        # if downsampling, use a 1x1 conv to change dimension
        if downsample:
            self.downsample_conv = Conv3d(curr_dim, dim_out, 1)
            curr_dim = dim_out

        # pixel normalization
        self.pixel_norm = PixelNorm(dim = 1)

        self.to_emb = None
        # if an embedding dimension is given, add a linear layer followed by a gain
        if exists(emb_dim):
            self.to_emb = nn.Sequential(
                Linear(emb_dim, dim_out),
                Gain()
            )

        # first block
        self.block1 = nn.Sequential(
            MPSiLU(),
            Conv3d(curr_dim, dim_out, 3)
        )

        # second block
        self.block2 = nn.Sequential(
            MPSiLU(),
            nn.Dropout(dropout),
            Conv3d(dim_out, dim_out, 3)
        )

        # magnitude-preserving residual add
        self.res_mp_add = MPAdd(t = mp_add_t)

        self.attn = None
        self.factorized_attn = factorize_space_time_attn

        # add an attention block if requested
        if has_attn:
            attn_kwargs = dict(
                dim = dim_out,
                heads = max(ceil(dim_out / attn_dim_head), 2),
                dim_head = attn_dim_head,
                mp_add_t = attn_res_mp_add_t,
                flash = attn_flash
            )

            # factorize attention across space and time if requested
            if factorize_space_time_attn:
                self.attn = nn.ModuleList([
                    Attention(**attn_kwargs, only_space = True),
                    Attention(**attn_kwargs, only_time = True),
                ])
            else:
                self.attn = Attention(**attn_kwargs)

    # forward
    def forward(
        self,
        x,
        emb = None
    ):
        # if downsampling, interpolate down along the configured dimensions, then conv
        if self.downsample:
            t, h, w = x.shape[-3:]
            resize_factors = tuple((2 if downsample else 1) for downsample in self.downsample_config)
            interpolate_shape = tuple(shape // factor for shape, factor in zip((t, h, w), resize_factors))

            x = F.interpolate(x, interpolate_shape, mode = 'trilinear')
            x = self.downsample_conv(x)

        # pixel normalization
        x = self.pixel_norm(x)

        res = x.clone()

        x = self.block1(x)

        # if an embedding is given, scale the features
        if exists(emb):
            scale = self.to_emb(emb) + 1
            x = x * rearrange(scale, 'b c -> b c 1 1 1')

        x = self.block2(x)

        x = self.res_mp_add(x, res)

        # apply attention if present
        if exists(self.attn):
            if self.factorized_attn:
                # attend over space first, then over time
                attn_space, attn_time = self.attn
                x = attn_space(x)
                x = attn_time(x)

            else:
                x = self.attn(x)

        return x

# decoder

class Decoder(Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        *,
        emb_dim = None,
        dropout = 0.1,
        mp_add_t = 0.3,
        has_attn = False,
        attn_dim_head = 64,
        attn_res_mp_add_t = 0.3,
        attn_flash = False,
        factorize_space_time_attn = False,
        upsample = False,
        upsample_config: Tuple[bool, bool, bool] = (True, True, True)
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.upsample = upsample
        self.upsample_config = upsample_config

        # decoder blocks that upsample do not receive a skip connection
        self.needs_skip = not upsample

        # if an embedding dimension is given, add a linear layer followed by a gain
        self.to_emb = None
        if exists(emb_dim):
            self.to_emb = nn.Sequential(
                Linear(emb_dim, dim_out),
                Gain()
            )

        # first block: mp-silu followed by a 3d conv
        self.block1 = nn.Sequential(
            MPSiLU(),
            Conv3d(dim, dim_out, 3)
        )

        # second block: mp-silu, dropout, then a 3d conv
        self.block2 = nn.Sequential(
            MPSiLU(),
            nn.Dropout(dropout),
            Conv3d(dim_out, dim_out, 3)
        )

        # 1x1 conv on the residual branch when dimensions differ
        self.res_conv = Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

        # magnitude-preserving residual add
        self.res_mp_add = MPAdd(t = mp_add_t)

        # attention
        self.attn = None
        self.factorized_attn = factorize_space_time_attn

        # add an attention block if requested
        if has_attn:
            attn_kwargs = dict(
                dim = dim_out,
                heads = max(ceil(dim_out / attn_dim_head), 2),
                dim_head = attn_dim_head,
                mp_add_t = attn_res_mp_add_t,
                flash = attn_flash
            )

            # factorize attention across space and time if requested
            if factorize_space_time_attn:
                self.attn = nn.ModuleList([
                    Attention(**attn_kwargs, only_space = True),
                    Attention(**attn_kwargs, only_time = True),
                ])
            else:
                self.attn = Attention(**attn_kwargs)

    # forward
    def forward(
        self,
        x,
        emb = None
    ):
        # if upsampling, interpolate up along the configured dimensions
        if self.upsample:
            t, h, w = x.shape[-3:]
            resize_factors = tuple((2 if upsample else 1) for upsample in self.upsample_config)
            interpolate_shape = tuple(shape * factor for shape, factor in zip((t, h, w), resize_factors))

            x = F.interpolate(x, interpolate_shape, mode = 'trilinear')

        # residual branch
        res = self.res_conv(x)

        # first block
        x = self.block1(x)

        # if an embedding is given, scale the features
        if exists(emb):
            scale = self.to_emb(emb) + 1
            x = x * rearrange(scale, 'b c -> b c 1 1 1')

        # second block
        x = self.block2(x)

        # magnitude-preserving residual add
        x = self.res_mp_add(x, res)

        # apply attention if present
        if exists(self.attn):
            # attend over space first, then over time, if factorized
            if self.factorized_attn:
                attn_space, attn_time = self.attn
                x = attn_space(x)
                x = attn_time(x)

            else:
                x = self.attn(x)

        return x

# attention

class Attention(Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 64,
        num_mem_kv = 4,
        flash = False,
        mp_add_t = 0.3,
        only_space = False,
        only_time = False
    ):
        super().__init__()
        # at most one of only_space / only_time may be set
        assert (int(only_space) + int(only_time)) <= 1

        self.heads = heads
        hidden_dim = dim_head * heads

        # pixel norm for queries / keys / values
        self.pixel_norm = PixelNorm(dim = -1)

        # attend (scaled dot product attention, optionally flash)
        self.attend = Attend(flash = flash)

        # learned memory key / values
        self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
        self.to_qkv = Conv3d(dim, hidden_dim * 3, 1)
        self.to_out = Conv3d(hidden_dim, dim, 1)

        # magnitude-preserving residual add
        self.mp_add = MPAdd(t = mp_add_t)

        # whether to attend only over space or only over time
        self.only_space = only_space
        self.only_time = only_time
    # forward, taking in the input x
    def forward(self, x):
        # save the residual and the original shape of x
        res, orig_shape = x, x.shape
        b, c, t, h, w = orig_shape

        # project x to queries, keys, values
        qkv = self.to_qkv(x)

        # fold the non-attended axis into the batch, depending on only_space / only_time
        if self.only_space:
            qkv = rearrange(qkv, 'b c t x y -> (b t) c x y')
        elif self.only_time:
            qkv = rearrange(qkv, 'b c t x y -> (b x y) c t')

        # split into queries, keys, values
        qkv = qkv.chunk(3, dim = 1)

        # split out heads and flatten the attended axis
        q, k, v = map(lambda t: rearrange(t, 'b (h c) ... -> b h (...) c', h = self.heads), qkv)

        # expand the memory key / values across the batch
        mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = k.shape[0]), self.mem_kv)

        # concat the memory key / values onto the projected ones
        k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v)))

        # pixel norm the queries, keys, values
        q, k, v = map(self.pixel_norm, (q, k, v))

        # attention
        out = self.attend(q, k, v)

        # merge the heads back into the channel dimension
        out = rearrange(out, 'b h n d -> b (h d) n')

        # unfold the batch back out, depending on only_space / only_time
        if self.only_space:
            out = rearrange(out, '(b t) c n -> b c (t n)', t = t)
        elif self.only_time:
            out = rearrange(out, '(b x y) c n -> b c (n x y)', x = h, y = w)

        # restore the original shape
        out = out.reshape(orig_shape)

        # final output projection
        out = self.to_out(out)

        # magnitude preserving add with the residual
        return self.mp_add(out, res)
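
To make the `only_space` / `only_time` folding concrete, here is a minimal einops sketch (independent of the class above, with illustrative sizes) showing how the axis that is not attended over gets absorbed into the batch:

import torch
from einops import rearrange

x = torch.randn(2, 8, 4, 16, 16)  # (batch, channels, time, height, width)

# space-only attention: fold time into the batch, attend over the h*w axis
space_tokens = rearrange(x, 'b c t h w -> (b t) (h w) c')

# time-only attention: fold space into the batch, attend over the time axis
time_tokens = rearrange(x, 'b c t h w -> (b h w) t c')

print(space_tokens.shape)  # torch.Size([8, 256, 8])
print(time_tokens.shape)   # torch.Size([512, 4, 8])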
# KarrasUnet3D - the 3d unet proposed by Karras et al.
# no biases, no group norms, magnitude preserving operations throughout

class KarrasUnet3D(Module):
    """
    going by configuration G in figure 21 of the paper
    """

    def __init__(
        self,
        *,
        image_size,              # image size
        frames,                  # number of frames
        dim = 192,               # dimension
        dim_max = 768,           # channels will double on every downsample and cap out at this value
        num_classes = None,      # 1000 classes were used in the paper for a popular benchmark
        channels = 4,            # 4 channels in the paper, perhaps accounting for an alpha channel?
        num_downsamples = 3,     # number of downsamples
        num_blocks_per_stage: Union[int, Tuple[int, ...]] = 4,  # number of blocks per stage
        downsample_types: Optional[Tuple[str, ...]] = None,     # downsample types
        attn_res = (16, 8),      # resolutions at which to use attention
        fourier_dim = 16,        # fourier dimension
        attn_dim_head = 64,      # dimension per attention head (not the number of heads)
        attn_flash = False,      # whether to use flash attention
        mp_cat_t = 0.5,          # magnitude preserving concat t
        mp_add_emb_t = 0.5,      # magnitude preserving add t for the embedding
        attn_res_mp_add_t = 0.3, # magnitude preserving add t for attention residuals
        resnet_mp_add_t = 0.3,   # magnitude preserving add t for resnet residuals
        dropout = 0.1,           # dropout rate
        self_condition = False,  # whether to use self conditioning
        factorize_space_time_attn = False  # whether to factorize attention across space and time
    ):
        ...  # (constructor body not included in this excerpt)

    @property
    def downsample_factor(self):
        return 2 ** self.num_downsamples

    def forward(
        self,
        x,
        time,
        self_cond = None,
        class_labels = None
    ):
        # validate the image shape

        assert x.shape[1:] == (self.channels, self.frames, self.image_size, self.image_size)

        # self conditioning

        if self.self_condition:
            self_cond = default(self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((self_cond, x), dim = 1)
        else:
            assert not exists(self_cond)

        # time conditioning

        time_emb = self.to_time_emb(time)

        # class conditioning

        assert xnor(exists(class_labels), self.needs_class_labels)

        if self.needs_class_labels:
            if class_labels.dtype in (torch.int, torch.long):
                class_labels = F.one_hot(class_labels, self.num_classes)

            assert class_labels.shape[-1] == self.num_classes
            class_labels = class_labels.float() * sqrt(self.num_classes)

            class_emb = self.to_class_emb(class_labels)

            time_emb = self.add_class_emb(time_emb, class_emb)

        # final mp-silu for the embedding

        emb = self.emb_activation(time_emb)

        # skip connections

        skips = []

        # input block

        x = self.input_block(x)

        skips.append(x)

        # downsampling

        for encoder in self.downs:
            x = encoder(x, emb = emb)
            skips.append(x)

        # middle

        for decoder in self.mids:
            x = decoder(x, emb = emb)

        # upsampling

        for decoder in self.ups:
            if decoder.needs_skip:
                skip = skips.pop()
                x = self.skip_mp_cat(x, skip)

            x = decoder(x, emb = emb)

        # output block

        return self.output_block(x)

# improvised MP Transformer

class MPFeedForward(Module):
    def __init__(
        self,
        *,
        dim,
        mult = 4,
        mp_add_t = 0.3
    ):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            PixelNorm(dim = 1),
            Conv3d(dim, dim_inner, 1),
            MPSiLU(),
            Conv3d(dim_inner, dim, 1)
        )

        self.mp_add = MPAdd(t = mp_add_t)

    def forward(self, x):
        res = x
        out = self.net(x)
        return self.mp_add(out, res)

class MPImageTransformer(Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_mem_kv = 4,
        ff_mult = 4,
        attn_flash = False,
        residual_mp_add_t = 0.3
    ):
        # invoke the parent constructor
        super().__init__()
        # an empty ModuleList to hold the transformer layers
        self.layers = ModuleList([])

        # build one layer per unit of depth
        for _ in range(depth):
            # each layer pairs an Attention module with an MPFeedForward module
            self.layers.append(ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, num_mem_kv = num_mem_kv, flash = attn_flash, mp_add_t = residual_mp_add_t),
                MPFeedForward(dim = dim, mult = ff_mult, mp_add_t = residual_mp_add_t)
            ]))

    # forward
    def forward(self, x):

        # run attention then feedforward for each layer
        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x)

        # return the result
        return x

# sanity check when run as the main script
if __name__ == '__main__':

    # create a KarrasUnet3D instance
    unet = KarrasUnet3D(
        frames = 32,  # number of video frames
        image_size = 64,  # image size
        dim = 8,  # dimension
        dim_max = 768,  # max dimension
        num_downsamples = 6,  # number of downsamples
        num_blocks_per_stage = (4, 3, 2, 2, 2, 2),  # number of blocks per stage
        downsample_types = (
            'image',  # downsample across the image
            'frame',  # downsample across frames
            'image',
            'frame',
            'image',
            'frame',
        ),
        attn_dim_head = 8,  # dimension per attention head
        num_classes = 1000,  # number of classes
        factorize_space_time_attn = True  # whether to attend separately over space and time
    )

    # a random tensor of shape (2, 4, 32, 64, 64) as the video input
    video = torch.randn(2, 4, 32, 64, 64)

    # denoise the video with the unet
    denoised_video = unet(
        video,  # input video
        time = torch.ones(2,),  # time conditioning
        class_labels = torch.randint(0, 1000, (2,))  # class labels
    )

.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\learned_gaussian_diffusion.py

import torch
from collections import namedtuple
from functools import partial
from math import pi, sqrt, log as ln
from inspect import isfunction
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange

from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, extract, unnormalize_to_zero_to_one, identity

# constants
NAT = 1. / ln(2)

# named tuple for model predictions
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start', 'pred_variance'])

# helper functions

# check whether a value exists
def exists(x):
    return x is not None

# return the value if it exists, otherwise a default
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# tensor helpers

# log of a tensor, clamped for numerical stability
def log(t, eps = 1e-15):
    return torch.log(t.clamp(min = eps))

# mean over all non-batch dimensions
def meanflat(x):
    return x.mean(dim = tuple(range(1, len(x.shape))))

# KL divergence between two normal distributions
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    KL divergence between normal distributions parameterized by mean and log-variance.
    """
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
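
A quick numerical sanity check of `normal_kl`, assuming the function above is in scope:

import torch

zero = torch.tensor(0.)

# KL between identical standard normals is exactly zero
print(normal_kl(zero, zero, zero, zero))  # tensor(0.)

# widening the second distribution to sigma = 2 (logvar2 = ln 4) gives
# 0.5 * (ln 4 + 1/4 - 1) ≈ 0.318
print(normal_kl(zero, zero, zero, torch.tensor(4.).log()))  # tensor(0.3181)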

# approximate the cumulative distribution function of a standard normal
def approx_standard_normal_cdf(x):
    return 0.5 * (1.0 + torch.tanh(sqrt(2.0 / pi) * (x + 0.044715 * (x ** 3))))

# log likelihood of a discretized gaussian
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
    assert x.shape == means.shape == log_scales.shape

    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1. / 255.)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1. / 255.)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = log(cdf_plus)
    log_one_minus_cdf_min = log(1. - cdf_min)
    cdf_delta = cdf_plus - cdf_min

    log_probs = torch.where(x < -thres,
        log_cdf_plus,
        torch.where(x > thres,
            log_one_minus_cdf_min,
            log(cdf_delta)))

    return log_probs

# https://arxiv.org/abs/2102.09672

# i thought the results were questionable, if one were to focus only on FID
# but may as well get this in here for others to try, as GLIDE is using it (and DALL-E2 first stage of cascade)
# gaussian diffusion for learned variance + hybrid eps simple + vb loss

# LearnedGaussianDiffusion, subclassing GaussianDiffusion
class LearnedGaussianDiffusion(GaussianDiffusion):
    def __init__(
        self,
        model,
        vb_loss_weight = 0.001,  # lambda was 0.001 in the paper
        *args,
        **kwargs
    ):
        super().__init__(model, *args, **kwargs)
        assert model.out_dim == (model.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`'
        assert not model.self_condition, 'not supported yet'

        self.vb_loss_weight = vb_loss_weight

    # model predictions
    def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
        model_output = self.model(x, t)
        model_output, pred_variance = model_output.chunk(2, dim = 1)

        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, model_output)

        elif self.objective == 'pred_x0':
            pred_noise = self.predict_noise_from_start(x, t, model_output)
            x_start = model_output

        x_start = maybe_clip(x_start)

        return ModelPrediction(pred_noise, x_start, pred_variance)
    # compute the predicted mean, variance and log variance given x and time t, optionally clipping the denoised estimate
    def p_mean_variance(self, *, x, t, clip_denoised, model_output = None, **kwargs):
        # if no model output is provided, fall back to calling the model
        model_output = default(model_output, lambda: self.model(x, t))
        # split the model output into predicted noise and the unnormalized variance interpolation fraction
        pred_noise, var_interp_frac_unnormalized = model_output.chunk(2, dim = 1)

        # extract the min and max of the posterior log variance
        min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
        max_log = extract(torch.log(self.betas), t, x.shape)
        # normalize the variance interpolation fraction to [0, 1]
        var_interp_frac = unnormalize_to_zero_to_one(var_interp_frac_unnormalized)

        # interpolate the model log variance and exponentiate for the variance
        model_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
        model_variance = model_log_variance.exp()

        # predict x_start from the predicted noise at time t
        x_start = self.predict_start_from_noise(x, t, pred_noise)

        # clip x_start to [-1, 1] if requested
        if clip_denoised:
            x_start.clamp_(-1., 1.)

        # compute the model mean from the posterior
        model_mean, _, _ = self.q_posterior(x_start, x, t)

        # return the model mean, variance, log variance and x_start
        return model_mean, model_variance, model_log_variance, x_start

    # the loss, combining the KL (vb) term with the simple loss
    def p_losses(self, x_start, t, noise = None, clip_denoised = False):
        # sample gaussian noise if none is given
        noise = default(noise, lambda: torch.randn_like(x_start))
        # generate x_t from x_start, time t and the noise
        x_t = self.q_sample(x_start = x_start, t = t, noise = noise)

        # get the model output
        model_output = self.model(x_t, t)

        # KL divergence for the learned (interpolated) variance
        true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_t, t = t)
        model_mean, _, model_log_variance, _ = self.p_mean_variance(x = x_t, t = t, clip_denoised = clip_denoised, model_output = model_output)

        # for stability, compute the KL with the model predicted mean detached
        detached_model_mean = model_mean.detach()

        kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
        kl = meanflat(kl) * NAT

        # decoder negative log likelihood
        decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
        decoder_nll = meanflat(decoder_nll) * NAT

        # at the first timestep return the decoder NLL, otherwise the KL divergence
        vb_losses = torch.where(t == 0, decoder_nll, kl)

        # simple loss - predicting noise, x0, or x_prev
        pred_noise, _ = model_output.chunk(2, dim = 1)
        simple_losses = F.mse_loss(pred_noise, noise)

        # return the simple loss plus the mean vb loss scaled by the vb loss weight
        return simple_losses + vb_losses.mean() * self.vb_loss_weight
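
A minimal usage sketch for this class (parameter choices are illustrative; the `learned_variance` flag on the Unet is the one referenced in the assert above, and the noise-prediction objective matches what `model_predictions` handles):

import torch
from denoising_diffusion_pytorch import Unet
from denoising_diffusion_pytorch.learned_gaussian_diffusion import LearnedGaussianDiffusion

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    learned_variance = True      # doubles out_dim so the unet also emits the variance interpolation
)

diffusion = LearnedGaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,
    objective = 'pred_noise'     # the hybrid eps + vb loss predicts noise
)

training_images = torch.rand(4, 3, 128, 128)
loss = diffusion(training_images)
loss.backward()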

.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\simple_diffusion.py

# import the math library
import math
# import partial and wraps from functools
from functools import partial, wraps

# import the torch library
import torch
# import sqrt from torch
from torch import sqrt
# import the nn module and einsum from torch
from torch import nn, einsum
# import torch.nn.functional under the F alias
import torch.nn.functional as F
# import expm1 from torch.special
from torch.special import expm1
# import autocast from torch.cuda.amp
from torch.cuda.amp import autocast

# import tqdm
from tqdm import tqdm
# import rearrange, repeat, reduce, pack, unpack from einops
from einops import rearrange, repeat, reduce, pack, unpack
# import the Rearrange layer from einops.layers.torch
from einops.layers.torch import Rearrange
# helpers

# check whether a value exists
def exists(val):
    return val is not None

# identity function
def identity(t):
    return t

# check whether f is a lambda
def is_lambda(f):
    return callable(f) and f.__name__ == "<lambda>"

# return the value if it exists, otherwise a default
def default(val, d):
    if exists(val):
        return val
    return d() if is_lambda(d) else d

# cast the input to a tuple of length l
def cast_tuple(t, l = 1):
    return ((t,) * l) if not isinstance(t, tuple) else t

# append trailing singleton dimensions to a tensor
def append_dims(t, dims):
    shape = t.shape
    return t.reshape(*shape, *((1,) * dims))

# l2 normalize along the last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# u-vit related functions and modules

# upsample module
class Upsample(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        factor = 2
    ):
        super().__init__()
        self.factor = factor
        self.factor_squared = factor ** 2

        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * self.factor_squared, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            nn.PixelShuffle(factor)
        )

        self.init_conv_(conv)

    # initialize the conv weights, repeated across the pixel shuffle groups
    def init_conv_(self, conv):
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // self.factor_squared, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = self.factor_squared)

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)

# downsample module
def Downsample(
    dim,
    dim_out = None,
    factor = 2
):
    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = factor, p2 = factor),
        nn.Conv2d(dim * (factor ** 2), default(dim_out, dim), 1)
    )
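
A shape-level sanity check for the two modules, assuming `Upsample` and `Downsample` above are in scope:

import torch

x = torch.randn(1, 64, 32, 32)

down = Downsample(64)  # space-to-depth rearrange followed by a 1x1 conv
up = Upsample(64)      # 1x1 conv followed by a silu and pixel shuffle

print(down(x).shape)   # torch.Size([1, 64, 16, 16]) - spatial halved, channels preserved
print(up(x).shape)     # torch.Size([1, 64, 64, 64]) - spatial doubled, channels preserved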

# rms norm module
class RMSNorm(nn.Module):
    def __init__(self, dim, scale = True, normalize_dim = 2):
        super().__init__()
        self.g = nn.Parameter(torch.ones(dim)) if scale else 1

        self.scale = scale
        self.normalize_dim = normalize_dim

    def forward(self, x):
        normalize_dim = self.normalize_dim
        scale = append_dims(self.g, x.ndim - self.normalize_dim - 1) if self.scale else 1
        return F.normalize(x, dim = normalize_dim) * scale * (x.shape[normalize_dim] ** 0.5)

# learned sinusoidal positional embedding module
class LearnedSinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered
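
A quick shape check, assuming the class above is in scope - note the output dimension is dim + 1, since the raw time is concatenated onto the fourier features:

import torch

emb = LearnedSinusoidalPosEmb(16)
times = torch.rand(8)      # a batch of diffusion times

print(emb(times).shape)    # torch.Size([8, 17])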

# basic building block
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

class ResnetBlock(nn.Module):
    # constructor, defining the block structure
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        # invoke the parent constructor
        super().__init__()
        # if a time embedding dimension is given, build an activation + linear projection
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        # first block
        self.block1 = Block(dim, dim_out, groups = groups)
        # second block
        self.block2 = Block(dim_out, dim_out, groups = groups)
        # 1x1 conv for the residual if dimensions differ, otherwise identity
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    # forward
    def forward(self, x, time_emb = None):

        scale_shift = None
        # if both the mlp and a time embedding exist
        if exists(self.mlp) and exists(time_emb):
            # project the time embedding
            time_emb = self.mlp(time_emb)
            # add spatial singleton dimensions
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            # split into scale and shift
            scale_shift = time_emb.chunk(2, dim = 1)

        # first block, conditioned by scale and shift
        h = self.block1(x, scale_shift = scale_shift)

        # second block
        h = self.block2(h)

        # add the (possibly projected) residual
        return h + self.res_conv(x)
class LinearAttention(nn.Module):
    # linear attention module
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        # scaling factor
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        # normalization layer
        self.norm = RMSNorm(dim, normalize_dim = 1)
        # project the input to queries, keys, values
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)

        # output projection
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            RMSNorm(dim, normalize_dim = 1)
        )

    # forward
    def forward(self, x):
        residual = x

        b, c, h, w = x.shape

        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)

        return self.to_out(out) + residual

class Attention(nn.Module):
    # attention module
    def __init__(self, dim, heads = 4, dim_head = 32, scale = 8, dropout = 0.):
        super().__init__()
        self.scale = scale
        self.heads = heads
        hidden_dim = dim_head * heads

        # normalization layer
        self.norm = RMSNorm(dim)

        self.attn_dropout = nn.Dropout(dropout)
        # project the input to queries, keys, values
        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # output projection
        self.to_out = nn.Linear(hidden_dim, dim, bias = False)

    # forward
    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        q, k = map(l2norm, (q, k))

        q = q * self.q_scale
        k = k * self.k_scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class FeedForward(nn.Module):
    # feedforward module
    def __init__(
        self,
        dim,
        cond_dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        # normalization layer
        self.norm = RMSNorm(dim, scale = False)
        dim_hidden = dim * mult

        # scale and shift projection from the conditioning
        self.to_scale_shift = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, dim_hidden * 2),
            Rearrange('b d -> b 1 d')
        )

        to_scale_shift_linear = self.to_scale_shift[-2]
        nn.init.zeros_(to_scale_shift_linear.weight)
        nn.init.zeros_(to_scale_shift_linear.bias)

        # input projection
        self.proj_in = nn.Sequential(
            nn.Linear(dim, dim_hidden, bias = False),
            nn.SiLU()
        )

        # output projection
        self.proj_out = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(dim_hidden, dim, bias = False)
        )

    # forward
    def forward(self, x, t):
        x = self.norm(x)
        x = self.proj_in(x)

        scale, shift = self.to_scale_shift(t).chunk(2, dim = -1)
        x = x * (scale + 1) + shift

        return self.proj_out(x)

# vit

class Transformer(nn.Module):
    # transformer module
    def __init__(
        self,
        dim,
        time_cond_dim,
        depth,
        dim_head = 32,
        heads = 4,
        ff_mult = 4,
        dropout = 0.,
    ):
        super().__init__()

        self.layers = nn.ModuleList([])
        # build the transformer layers
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
                FeedForward(dim = dim, mult = ff_mult, cond_dim = time_cond_dim, dropout = dropout)
            ]))

    # forward
    def forward(self, x, t):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x, t) + x

        return x
# UViT class, inheriting from nn.Module
class UViT(nn.Module):
    # constructor, accepting a number of parameters
    def __init__(
        self,
        dim,  # feature dimension
        init_dim = None,  # initial dimension, defaults to dim
        out_dim = None,  # output dimension, defaults to the input channels
        dim_mults = (1, 2, 4, 8),  # dimension multipliers
        downsample_factor = 2,  # downsample factor
        channels = 3,  # number of channels
        vit_depth = 6,  # vit depth
        vit_dropout = 0.2,  # vit dropout probability
        attn_dim_head = 32,  # dimension per attention head
        attn_heads = 4,  # number of attention heads
        ff_mult = 4,  # feedforward expansion factor
        resnet_block_groups = 8,  # number of resnet block groups
        learned_sinusoidal_dim = 16,  # learned sinusoidal embedding dimension
        init_img_transform: callable = None,  # initial image transform, defaults to None
        final_img_itransform: callable = None,  # final inverse image transform, defaults to None
        patch_size = 1,  # patch size
        dual_patchnorm = False  # dual patchnorm
        ):
        # invoke the parent constructor
        super().__init__()

        # for an initial dwt transform (or any other transform researchers may want to try)

        if exists(init_img_transform) and exists(final_img_itransform):
            # mock tensor of shape 1x1x32x32
            init_shape = torch.Size((1, 1, 32, 32))
            mock_tensor = torch.randn(init_shape)
            # the transform followed by its inverse must restore the original shape
            assert final_img_itransform(init_img_transform(mock_tensor)).shape == init_shape

        # set the initial image transform and the final inverse transform
        self.init_img_transform = default(init_img_transform, identity)
        self.final_img_itransform = default(final_img_itransform, identity)

        input_channels = channels

        init_dim = default(init_dim, dim)
        # initial conv: input_channels -> init_dim, 7x7 kernel, padding 3
        self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)

        # whether to do initial patching, as an alternative to dwt
        self.unpatchify = identity

        input_channels = channels * (patch_size ** 2)
        needs_patch = patch_size > 1

        if needs_patch:
            if not dual_patchnorm:
                # without dual patchnorm, patchify with a strided conv
                self.init_conv = nn.Conv2d(channels, init_dim, patch_size, stride = patch_size)
            else:
                # with dual patchnorm
                self.init_conv = nn.Sequential(
                    Rearrange('b c (h p1) (w p2) -> b h w (c p1 p2)', p1 = patch_size, p2 = patch_size),
                    nn.LayerNorm(input_channels),
                    nn.Linear(input_channels, init_dim),
                    nn.LayerNorm(init_dim),
                    Rearrange('b h w c -> b c h w')
                )

            # transposed conv to restore patches to the original image resolution
            self.unpatchify = nn.ConvTranspose2d(input_channels, channels, patch_size, stride = patch_size)

        # determine dimensions
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # partially applied resnet block
        resnet_block = partial(ResnetBlock, groups = resnet_block_groups)

        # time embedding
        time_dim = dim * 4

        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
        fourier_dim = learned_sinusoidal_dim + 1

        # time mlp
        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # downsample factors
        downsample_factor = cast_tuple(downsample_factor, len(dim_mults))
        assert len(downsample_factor) == len(dim_mults)

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, ((dim_in, dim_out), factor) in enumerate(zip(in_out, downsample_factor)):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                resnet_block(dim_in, dim_in, time_emb_dim = time_dim),
                resnet_block(dim_in, dim_in, time_emb_dim = time_dim),
                LinearAttention(dim_in),
                Downsample(dim_in, dim_out, factor = factor)
            ]))

        mid_dim = dims[-1]

        # vit in the middle
        self.vit = Transformer(
            dim = mid_dim,
            time_cond_dim = time_dim,
            depth = vit_depth,
            dim_head = attn_dim_head,
            heads = attn_heads,
            ff_mult = ff_mult,
            dropout = vit_dropout
        )

        for ind, ((dim_in, dim_out), factor) in enumerate(zip(reversed(in_out), reversed(downsample_factor))):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(nn.ModuleList([
                Upsample(dim_out, dim_in, factor = factor),
                resnet_block(dim_in * 2, dim_in, time_emb_dim = time_dim),
                resnet_block(dim_in * 2, dim_in, time_emb_dim = time_dim),
                LinearAttention(dim_in),
            ]))

        default_out_dim = input_channels
        self.out_dim = default(out_dim, default_out_dim)

        # final resnet block and conv
        self.final_res_block = resnet_block(dim * 2, dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
    # forward, taking the input x and time conditioning
    def forward(self, x, time):
        # initial image transform
        x = self.init_img_transform(x)

        # initial conv
        x = self.init_conv(x)
        # keep a copy of the initial feature map for the final residual
        r = x.clone()

        # time conditioning through the mlp
        t = self.time_mlp(time)

        # list for storing intermediate feature maps
        h = []

        # downsampling stages
        for block1, block2, attn, downsample in self.downs:
            # first block
            x = block1(x, t)
            h.append(x)

            # second block
            x = block2(x, t)
            # linear attention
            x = attn(x)
            h.append(x)

            # downsample
            x = downsample(x)

        # move channels last
        x = rearrange(x, 'b c h w -> b h w c')
        # pack the spatial dimensions into a token axis
        x, ps = pack([x], 'b * c')

        # vision transformer in the middle
        x = self.vit(x, t)

        # unpack the token axis back into spatial dimensions
        x, = unpack(x, ps, 'b * c')
        # move channels first again
        x = rearrange(x, 'b h w c -> b c h w')

        # upsampling stages
        for upsample, block1, block2, attn in self.ups:
            # upsample
            x = upsample(x)

            # concat the skip connection
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)

            # concat the next skip connection
            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)

        # concat the initial feature map
        x = torch.cat((x, r), dim = 1)

        # final residual block
        x = self.final_res_block(x, t)
        # final conv
        x = self.final_conv(x)

        # undo the patching
        x = self.unpatchify(x)
        # final inverse image transform
        return self.final_img_itransform(x)
# normalization functions

# normalize image data to [-1, 1]
def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

# unnormalize back to [0, 1]
def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

# diffusion helpers

# right pad the dimensions of t to match x
def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))
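
For example, assuming `right_pad_dims_to` above is in scope, a per-sample scalar broadcasts against an image batch like so:

import torch

x = torch.randn(2, 3, 64, 64)   # image batch
t = torch.randn(2)              # one scalar per sample

print(right_pad_dims_to(x, t).shape)  # torch.Size([2, 1, 1, 1]) - ready to broadcast against x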

# logsnr schedules and shifting / interpolating decorators
# only cosine for now

# log of a tensor, clamped to avoid log of values below eps
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# cosine log(snr) schedule
def logsnr_schedule_cosine(t, logsnr_min = -15, logsnr_max = 15):
    t_min = math.atan(math.exp(-0.5 * logsnr_max))
    t_max = math.atan(math.exp(-0.5 * logsnr_min))
    return -2 * log(torch.tan(t_min + t * (t_max - t_min)))

# shift a log(snr) schedule
def logsnr_schedule_shifted(fn, image_d, noise_d):
    shift = 2 * math.log(noise_d / image_d)
    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal shift
        return fn(*args, **kwargs) + shift
    return inner

# interpolate between two shifted log(snr) schedules
def logsnr_schedule_interpolated(fn, image_d, noise_d_low, noise_d_high):
    logsnr_low_fn = logsnr_schedule_shifted(fn, image_d, noise_d_low)
    logsnr_high_fn = logsnr_schedule_shifted(fn, image_d, noise_d_high)

    @wraps(fn)
    def inner(t, *args, **kwargs):
        nonlocal logsnr_low_fn
        nonlocal logsnr_high_fn
        return t * logsnr_low_fn(t, *args, **kwargs) + (1 - t) * logsnr_high_fn(t, *args, **kwargs)

    return inner
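
Assuming the schedule functions above are in scope, a quick check of what the shifting decorator does - it offsets the log(snr) everywhere by 2 * log(noise_d / image_d):

import torch

t = torch.tensor(0.5)
base = logsnr_schedule_cosine(t)

# shifting for a 256 pixel image with a 64 pixel reference noise resolution
shifted_fn = logsnr_schedule_shifted(logsnr_schedule_cosine, image_d = 256, noise_d = 64)

print(shifted_fn(t) - base)  # tensor(-2.7726), i.e. 2 * ln(64 / 256)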

# main gaussian diffusion class

# gaussian diffusion class
class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        model: UViT,
        *,
        image_size,
        channels = 3,
        pred_objective = 'v',
        noise_schedule = logsnr_schedule_cosine,
        noise_d = None,
        noise_d_low = None,
        noise_d_high = None,
        num_sample_steps = 500,
        clip_sample_denoised = True,
        min_snr_loss_weight = True,
        min_snr_gamma = 5
    ):
        super().__init__()
        assert pred_objective in {'v', 'eps'}, 'whether to predict v-space (progressive distillation paper) or noise'

        self.model = model

        # image dimensions

        self.channels = channels
        self.image_size = image_size

        # training objective

        self.pred_objective = pred_objective

        # noise schedule

        assert not all([*map(exists, (noise_d, noise_d_low, noise_d_high))]), 'you must either set noise_d for shifted schedule, or noise_d_low and noise_d_high for shifted and interpolated schedule'

        # determine shifting or interpolated schedules

        self.log_snr = noise_schedule

        if exists(noise_d):
            self.log_snr = logsnr_schedule_shifted(self.log_snr, image_size, noise_d)

        if exists(noise_d_low) or exists(noise_d_high):
            assert exists(noise_d_low) and exists(noise_d_high), 'both noise_d_low and noise_d_high must be set'

            self.log_snr = logsnr_schedule_interpolated(self.log_snr, image_size, noise_d_low, noise_d_high)

        # sampling

        self.num_sample_steps = num_sample_steps
        self.clip_sample_denoised = clip_sample_denoised

        # loss weight

        self.min_snr_loss_weight = min_snr_loss_weight
        self.min_snr_gamma = min_snr_gamma

    @property
    def device(self):
        return next(self.model.parameters()).device
    # compute the mean and variance
    def p_mean_variance(self, x, time, time_next):

        # log(snr) at the current and next time step
        log_snr = self.log_snr(time)
        log_snr_next = self.log_snr(time_next)
        # compute c
        c = -expm1(log_snr - log_snr_next)

        # compute alpha and sigma
        squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid()
        squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid()
        alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next))

        # repeat log_snr across the batch
        batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0])
        # model prediction
        pred = self.model(x, batch_log_snr)

        # recover x_start depending on the prediction objective
        if self.pred_objective == 'v':
            x_start = alpha * x - sigma * pred
        elif self.pred_objective == 'eps':
            x_start = (x - sigma * pred) / alpha

        # clamp x_start to [-1, 1]
        x_start.clamp_(-1., 1.)

        # compute the model mean and posterior variance
        model_mean = alpha_next * (x * (1 - c) / alpha + c * x_start)
        posterior_variance = squared_sigma_next * c

        return model_mean, posterior_variance

    # sampling related functions

    @torch.no_grad()
    def p_sample(self, x, time, time_next):
        batch, *_, device = *x.shape, x.device

        # compute the model mean and variance
        model_mean, model_variance = self.p_mean_variance(x = x, time = time, time_next = time_next)

        # at the final step, return the model mean directly
        if time_next == 0:
            return model_mean

        # otherwise add noise scaled by the posterior standard deviation
        noise = torch.randn_like(x)
        return model_mean + sqrt(model_variance) * noise

    @torch.no_grad()
    def p_sample_loop(self, shape):
        batch = shape[0]

        # start from random gaussian noise
        img = torch.randn(shape, device = self.device)
        steps = torch.linspace(1., 0., self.num_sample_steps + 1, device = self.device)

        # sampling loop
        for i in tqdm(range(self.num_sample_steps), desc = 'sampling loop time step', total = self.num_sample_steps):
            times = steps[i]
            times_next = steps[i + 1]
            img = self.p_sample(img, times, times_next)

        # clamp to [-1, 1] and unnormalize to [0, 1]
        img.clamp_(-1., 1.)
        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def sample(self, batch_size = 16):
        return self.p_sample_loop((batch_size, self.channels, self.image_size, self.image_size))

    # training related functions - noise prediction

    @autocast(enabled = False)
    def q_sample(self, x_start, times, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        # compute alpha and sigma, then produce the noised image
        log_snr = self.log_snr(times)
        log_snr_padded = right_pad_dims_to(x_start, log_snr)
        alpha, sigma = sqrt(log_snr_padded.sigmoid()), sqrt((-log_snr_padded).sigmoid())
        x_noised =  x_start * alpha + noise * sigma

        return x_noised, log_snr

    # loss function
    def p_losses(self, x_start, times, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        # produce the noised image and get the model output
        x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise)
        model_out = self.model(x, log_snr)

        # construct the target depending on the prediction objective
        if self.pred_objective == 'v':
            padded_log_snr = right_pad_dims_to(x, log_snr)
            alpha, sigma = padded_log_snr.sigmoid().sqrt(), (-padded_log_snr).sigmoid().sqrt()
            target = alpha * noise - sigma * x_start
        elif self.pred_objective == 'eps':
            target = noise

        # mean squared error loss
        loss = F.mse_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        snr = log_snr.exp()

        maybe_clip_snr = snr.clone()
        if self.min_snr_loss_weight:
            maybe_clip_snr.clamp_(max = self.min_snr_gamma)

        # loss weight depends on the prediction objective
        if self.pred_objective == 'v':
            loss_weight = maybe_clip_snr / (snr + 1)
        elif self.pred_objective == 'eps':
            loss_weight = maybe_clip_snr / snr

        return (loss * loss_weight).mean()
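
To see what the min-snr clamp does to the loss weights, a small standalone computation (illustrative snr values, gamma = 5 as defaulted above):

import torch

snr = torch.tensor([0.1, 1., 5., 25., 100.])
min_snr_gamma = 5

clipped = snr.clamp(max = min_snr_gamma)

# v-objective weight: clipped snr / (snr + 1)
print(clipped / (snr + 1))  # tensor([0.0909, 0.5000, 0.8333, 0.1923, 0.0495])

# eps-objective weight: clipped snr / snr - equals 1 until snr exceeds gamma
print(clipped / snr)        # tensor([1.0000, 1.0000, 1.0000, 0.2000, 0.0500])
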
    # forward, taking the image and any extra arguments
    def forward(self, img, *args, **kwargs):
        # unpack the image shape, device and configured image size
        b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        # height and width must match the configured image size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'

        # normalize the image to [-1, 1]
        img = normalize_to_neg_one_to_one(img)
        # sample a random time per image, uniform in [0, 1]
        times = torch.zeros((img.shape[0],), device = self.device).float().uniform_(0, 1)

        # compute the losses
        return self.p_losses(img, times, *args, **kwargs)
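
A minimal usage sketch for this file's UViT + GaussianDiffusion pairing (parameter choices are illustrative):

import torch
from denoising_diffusion_pytorch.simple_diffusion import UViT, GaussianDiffusion

model = UViT(dim = 64)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    num_sample_steps = 250
)

training_images = torch.rand(4, 3, 128, 128)  # images normalized from 0 to 1
loss = diffusion(training_images)
loss.backward()

# after training
sampled_images = diffusion.sample(batch_size = 4)  # (4, 3, 128, 128)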

.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\version.py

# define the package version
__version__ = '1.11.0'

.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\v_param_continuous_time_gaussian_diffusion.py

# import the math and torch libraries
import math
import torch
# import sqrt from torch
from torch import sqrt
# import the nn module and einsum from torch
from torch import nn, einsum
# import torch.nn.functional under the F alias
import torch.nn.functional as F
# import expm1 from torch.special
from torch.special import expm1
# import autocast from torch.cuda.amp
from torch.cuda.amp import autocast

# import tqdm
from tqdm import tqdm
# import rearrange, repeat, reduce from einops
from einops import rearrange, repeat, reduce
# import the Rearrange layer from einops.layers.torch
from einops.layers.torch import Rearrange

# helpers

# check whether a value exists
def exists(val):
    return val is not None

# return the value if it exists, otherwise a default
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# normalization functions

# normalize an image to [-1, 1]
def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

# unnormalize a tensor back to [0, 1]
def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

# diffusion helpers

# right pad the dimensions of t to match x
def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))

# continuous schedules
# log(snr) that approximates the original linear schedule

# log of a tensor, clamped to avoid log of values below eps
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# cosine log(snr) schedule
def alpha_cosine_log_snr(t, s = 0.008):
    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5)

# VParamContinuousTimeGaussianDiffusion class
class VParamContinuousTimeGaussianDiffusion(nn.Module):
    """
    a new type of parameterization in v-space proposed in https://arxiv.org/abs/2202.00512 that
    (1) allows for improved distillation over noise prediction objective and
    (2) noted in imagen-video to improve upsampling unets by removing the color shifting artifacts
    """

    # constructor
    def __init__(
        self,
        model,
        *,
        image_size,
        channels = 3,
        num_sample_steps = 500,
        clip_sample_denoised = True,
    ):
        super().__init__()
        assert model.random_or_learned_sinusoidal_cond
        assert not model.self_condition, 'not supported yet'

        self.model = model

        # image dimensions

        self.channels = channels
        self.image_size = image_size

        # continuous noise schedule related stuff

        self.log_snr = alpha_cosine_log_snr

        # sampling

        self.num_sample_steps = num_sample_steps
        self.clip_sample_denoised = clip_sample_denoised        

    # device of the model
    @property
    def device(self):
        return next(self.model.parameters()).device

    # compute the mean and variance
    def p_mean_variance(self, x, time, time_next):
        # reviewer found an error in the equation in the paper (missing sigma)
        # following - https://openreview.net/forum?id=2LdBqxc1Yv&noteId=rIQgH0zKsRt

        log_snr = self.log_snr(time)
        log_snr_next = self.log_snr(time_next)
        c = -expm1(log_snr - log_snr_next)

        squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid()
        squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid()

        alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next))

        batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0])

        pred_v = self.model(x, batch_log_snr)

        # shown in Appendix D in the paper
        x_start = alpha * x - sigma * pred_v

        if self.clip_sample_denoised:
            x_start.clamp_(-1., 1.)

        model_mean = alpha_next * (x * (1 - c) / alpha + c * x_start)

        posterior_variance = squared_sigma_next * c

        return model_mean, posterior_variance

    # sampling related functions

    @torch.no_grad()
    def p_sample(self, x, time, time_next):
        batch, *_, device = *x.shape, x.device

        model_mean, model_variance = self.p_mean_variance(x = x, time = time, time_next = time_next)

        if time_next == 0:
            return model_mean

        noise = torch.randn_like(x)
        return model_mean + sqrt(model_variance) * noise

    @torch.no_grad()
    # sampling loop
    def p_sample_loop(self, shape):
        # batch size
        batch = shape[0]

        # start from random gaussian noise of the given shape
        img = torch.randn(shape, device = self.device)
        # linearly spaced time steps from 1 down to 0
        steps = torch.linspace(1., 0., self.num_sample_steps + 1, device = self.device)

        # iterate the sampling steps
        for i in tqdm(range(self.num_sample_steps), desc = 'sampling loop time step', total = self.num_sample_steps):
            times = steps[i]
            times_next = steps[i + 1]
            img = self.p_sample(img, times, times_next)

        # clamp to [-1, 1]
        img.clamp_(-1., 1.)
        # unnormalize from [-1, 1] to [0, 1]
        img = unnormalize_to_zero_to_one(img)
        return img

    # sampling without gradients
    @torch.no_grad()
    def sample(self, batch_size = 16):
        return self.p_sample_loop((batch_size, self.channels, self.image_size, self.image_size))

    # training related functions - noise prediction

    # q_sample - produce the noised input
    @autocast(enabled = False)
    def q_sample(self, x_start, times, noise = None):
        # sample gaussian noise if none is given
        noise = default(noise, lambda: torch.randn_like(x_start))

        # compute the log(snr)
        log_snr = self.log_snr(times)

        # pad the log(snr) to the input dimensions
        log_snr_padded = right_pad_dims_to(x_start, log_snr)
        alpha, sigma = sqrt(log_snr_padded.sigmoid()), sqrt((-log_snr_padded).sigmoid())
        x_noised =  x_start * alpha + noise * sigma

        return x_noised, log_snr, alpha, sigma

    # sample random times
    def random_times(self, batch_size):
        return torch.zeros((batch_size,), device = self.device).float().uniform_(0, 1)

    # loss function
    def p_losses(self, x_start, times, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        x, log_snr, alpha, sigma = self.q_sample(x_start = x_start, times = times, noise = noise)

        # the v prediction target, described in section 4 and derived in appendix D
        v = alpha * noise - sigma * x_start

        model_out = self.model(x, log_snr)

        return F.mse_loss(model_out, v)

    # forward
    def forward(self, img, *args, **kwargs):
        # unpack the image shape, device and configured image size
        b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        # height and width must match the configured image size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'

        # sample random times
        times = self.random_times(b)
        # normalize the image from [0, 1] to [-1, 1]
        img = normalize_to_neg_one_to_one(img)
        return self.p_losses(img, times, *args, **kwargs)
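
A minimal usage sketch (illustrative parameters; the `learned_sinusoidal_cond` flag satisfies the `random_or_learned_sinusoidal_cond` assert in the constructor above):

import torch
from denoising_diffusion_pytorch import Unet
from denoising_diffusion_pytorch.v_param_continuous_time_gaussian_diffusion import VParamContinuousTimeGaussianDiffusion

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    learned_sinusoidal_cond = True   # required, since conditioning is on continuous log(snr)
)

diffusion = VParamContinuousTimeGaussianDiffusion(
    model,
    image_size = 128,
    num_sample_steps = 250
)

training_images = torch.rand(4, 3, 128, 128)
loss = diffusion(training_images)
loss.backward()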

.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\weighted_objective_gaussian_diffusion.py

# import the torch library
import torch
# import isfunction from inspect
from inspect import isfunction
# import the nn module and einsum from torch
from torch import nn, einsum
# import torch.nn.functional under the F alias
import torch.nn.functional as F
# import rearrange from einops
from einops import rearrange

# import the GaussianDiffusion class from denoising_diffusion_pytorch.denoising_diffusion_pytorch
from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion

# helper functions

# check whether a value exists
def exists(x):
    return x is not None

# return the value if it exists, otherwise a default
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# in this improvisation, the model learns to predict both the noise and x0,
# along with a weighted sum of the two learned per time step

# WeightedObjectiveGaussianDiffusion, subclassing GaussianDiffusion
class WeightedObjectiveGaussianDiffusion(GaussianDiffusion):
    def __init__(
        self,
        model,
        *args,
        pred_noise_loss_weight = 0.1,
        pred_x_start_loss_weight = 0.1,
        **kwargs
    ):
        # invoke the parent constructor
        super().__init__(model, *args, **kwargs)
        channels = model.channels
        # the model output dimension must be twice the channels plus 2
        assert model.out_dim == (channels * 2 + 2), 'dimension out (out_dim) of unet must be twice the number of channels + 2 (for the softmax weighted sum) - for channels of 3, this should be (3 * 2) + 2 = 8'
        assert not model.self_condition, 'not supported yet'
        assert not self.is_ddim_sampling, 'ddim sampling cannot be used'

        self.split_dims = (channels, channels, 2)
        self.pred_noise_loss_weight = pred_noise_loss_weight
        self.pred_x_start_loss_weight = pred_x_start_loss_weight

    # compute the mean and variance
    def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):
        # get the model output
        model_output = self.model(x, t)

        # split the model output into predicted noise, predicted x_start, and weights
        pred_noise, pred_x_start, weights = model_output.split(self.split_dims, dim = 1)
        normalized_weights = weights.softmax(dim = 1)

        # derive x_start from the predicted noise
        x_start_from_noise = self.predict_start_from_noise(x, t = t, noise = pred_noise)
        
        x_starts = torch.stack((x_start_from_noise, pred_x_start), dim = 1)
        weighted_x_start = einsum('b j h w, b j c h w -> b c h w', normalized_weights, x_starts)

        if clip_denoised:
            weighted_x_start.clamp_(-1., 1.)

        # compute the model mean, variance and log variance from the posterior
        model_mean, model_variance, model_log_variance = self.q_posterior(weighted_x_start, x, t)

        return model_mean, model_variance, model_log_variance

    # loss
    def p_losses(self, x_start, t, noise = None, clip_denoised = False):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_t = self.q_sample(x_start = x_start, t = t, noise = noise)

        model_output = self.model(x_t, t)
        pred_noise, pred_x_start, weights = model_output.split(self.split_dims, dim = 1)

        # losses for the predicted noise and x_start, scaled by the weights given at init
        noise_loss = F.mse_loss(noise, pred_noise) * self.pred_noise_loss_weight
        x_start_loss = F.mse_loss(x_start, pred_x_start) * self.pred_x_start_loss_weight

        # derive x_start from the predicted noise, then take the softmax weighted sum of the two x_start estimates
        x_start_from_pred_noise = self.predict_start_from_noise(x_t, t, pred_noise)
        x_start_from_pred_noise = x_start_from_pred_noise.clamp(-2., 2.)
        weighted_x_start = einsum('b j h w, b j c h w -> b c h w', weights.softmax(dim = 1), torch.stack((x_start_from_pred_noise, pred_x_start), dim = 1))

        # the main loss is the mse between x_start and the weighted x_start
        weighted_x_start_loss = F.mse_loss(x_start, weighted_x_start)
        return weighted_x_start_loss + x_start_loss + noise_loss
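
A minimal usage sketch (illustrative parameters; `out_dim = 8` satisfies the `channels * 2 + 2` assert in the constructor above):

import torch
from denoising_diffusion_pytorch import Unet
from denoising_diffusion_pytorch.weighted_objective_gaussian_diffusion import WeightedObjectiveGaussianDiffusion

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    out_dim = 8   # (3 channels * 2) + 2 logits for the softmax weighted sum
)

diffusion = WeightedObjectiveGaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000
)

training_images = torch.rand(4, 3, 128, 128)
loss = diffusion(training_images)
loss.backward()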

.\lucidrains\denoising-diffusion-pytorch\denoising_diffusion_pytorch\__init__.py

# GaussianDiffusion, Unet and Trainer classes
from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, Unet, Trainer

# LearnedGaussianDiffusion class
from denoising_diffusion_pytorch.learned_gaussian_diffusion import LearnedGaussianDiffusion

# ContinuousTimeGaussianDiffusion class
from denoising_diffusion_pytorch.continuous_time_gaussian_diffusion import ContinuousTimeGaussianDiffusion

# WeightedObjectiveGaussianDiffusion class
from denoising_diffusion_pytorch.weighted_objective_gaussian_diffusion import WeightedObjectiveGaussianDiffusion

# ElucidatedDiffusion class
from denoising_diffusion_pytorch.elucidated_diffusion import ElucidatedDiffusion

# VParamContinuousTimeGaussianDiffusion class
from denoising_diffusion_pytorch.v_param_continuous_time_gaussian_diffusion import VParamContinuousTimeGaussianDiffusion

# 1d variants: GaussianDiffusion1D, Unet1D, Trainer1D and Dataset1D classes
from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D, Trainer1D, Dataset1D

# KarrasUnet and InvSqrtDecayLRSched classes
from denoising_diffusion_pytorch.karras_unet import KarrasUnet, InvSqrtDecayLRSched

# KarrasUnet1D class
from denoising_diffusion_pytorch.karras_unet_1d import KarrasUnet1D

# KarrasUnet3D class
from denoising_diffusion_pytorch.karras_unet_3d import KarrasUnet3D

Denoising Diffusion Probabilistic Model, in Pytorch

Implementation of Denoising Diffusion Probabilistic Model in Pytorch. It is a new approach to generative modeling that may have the potential to rival GANs. It uses denoising score matching to estimate the gradient of the data distribution, followed by Langevin sampling to sample from the true distribution.

This implementation was inspired by the official Tensorflow version here

Youtube AI Educators - Yannic Kilcher | AI Coffeebreak with Letitia | Outlier

Flax implementation from YiYi Xu

Annotated code by Research Scientists / Engineers from 🤗 Huggingface

Update: Turns out none of the technicalities really matters at all | "Cold Diffusion" paper | Muse

Install

$ pip install denoising_diffusion_pytorch

Usage

import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    flash_attn = True
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000    # number of steps
)

training_images = torch.rand(8, 3, 128, 128) # images are normalized from 0 to 1
loss = diffusion(training_images)
loss.backward()

# after a lot of training

sampled_images = diffusion.sample(batch_size = 4)
sampled_images.shape # (4, 3, 128, 128)

Or, if you simply want to pass in a folder name and the desired image dimensions, you can use the Trainer class to easily train a model.

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    flash_attn = True
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,           # number of steps
    sampling_timesteps = 250    # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
)

trainer = Trainer(
    diffusion,
    'path/to/your/images',
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True,                       # turn on mixed precision
    calculate_fid = True              # whether to calculate fid during training
)

trainer.train()

Samples and model checkpoints will be logged to ./results periodically

Multi-GPU Training

The Trainer class is now equipped with 🤗 Accelerator. You can easily do multi-gpu training in two steps using their accelerate CLI

At the project root directory, where the training script is, run

$ accelerate config

Then, in the same directory

$ accelerate launch train.py

Miscellaneous

1D Sequence

By popular request, a 1D Unet + Gaussian Diffusion implementation.

import torch
from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D, Trainer1D, Dataset1D

model = Unet1D(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels = 32
)

diffusion = GaussianDiffusion1D(
    model,
    seq_length = 128,
    timesteps = 1000,
    objective = 'pred_v'
)

training_seq = torch.rand(64, 32, 128) # features are normalized from 0 to 1
dataset = Dataset1D(training_seq)  # this is just an example, but you can formulate your own Dataset and pass it into the `Trainer1D` below

loss = diffusion(training_seq)
loss.backward()

# Or using trainer

trainer = Trainer1D(
    diffusion,
    dataset = dataset,
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True,                       # turn on mixed precision
)
trainer.train()

# after a lot of training

sampled_seq = diffusion.sample(batch_size = 4)
sampled_seq.shape # (4, 32, 128)

Trainer1D does not evaluate the generated samples in any way since the type of data is not known.

You could consider adding a suitable metric to the training loop yourself after doing an editable install of this package

$ pip install -e .

Citations

@inproceedings{NEURIPS2020_4c5bcfec,
    author      = {Ho, Jonathan and Jain, Ajay and Abbeel, Pieter},
    booktitle   = {Advances in Neural Information Processing Systems},
    editor      = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
    pages       = {6840--6851},
    publisher   = {Curran Associates, Inc.},
    title       = {Denoising Diffusion Probabilistic Models},
    url         = {https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf},
    volume      = {33},
    year        = {2020}
}
@InProceedings{pmlr-v139-nichol21a,
    title       = {Improved Denoising Diffusion Probabilistic Models},
    author      = {Nichol, Alexander Quinn and Dhariwal, Prafulla},
    booktitle   = {Proceedings of the 38th International Conference on Machine Learning},
    pages       = {8162--8171},
    year        = {2021},
    editor      = {Meila, Marina and Zhang, Tong},
    volume      = {139},
    series      = {Proceedings of Machine Learning Research},
    month       = {18--24 Jul},
    publisher   = {PMLR},
    pdf         = {http://proceedings.mlr.press/v139/nichol21a/nichol21a.pdf},
    url         = {https://proceedings.mlr.press/v139/nichol21a.html},
}
@inproceedings{kingma2021on,
    title       = {On Density Estimation with Diffusion Models},
    author      = {Diederik P Kingma and Tim Salimans and Ben Poole and Jonathan Ho},
    booktitle   = {Advances in Neural Information Processing Systems},
    editor      = {A. Beygelzimer and Y. Dauphin and P. Liang and J. Wortman Vaughan},
    year        = {2021},
    url         = {https://openreview.net/forum?id=2LdBqxc1Yv}
}
@article{Karras2022ElucidatingTD,
    title   = {Elucidating the Design Space of Diffusion-Based Generative Models},
    author  = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2206.00364}
}
@article{Song2021DenoisingDI,
    title   = {Denoising Diffusion Implicit Models},
    author  = {Jiaming Song and Chenlin Meng and Stefano Ermon},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2010.02502}
}
@misc{chen2022analog,
    title   = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
    author  = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
    year    = {2022},
    eprint  = {2208.04202},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Salimans2022ProgressiveDF,
    title   = {Progressive Distillation for Fast Sampling of Diffusion Models},
    author  = {Tim Salimans and Jonathan Ho},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.00512}
}
@article{Ho2022ClassifierFreeDG,
    title   = {Classifier-Free Diffusion Guidance},
    author  = {Jonathan Ho},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2207.12598}
}
@article{Sunkara2022NoMS,
    title   = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
    author  = {Raja Sunkara and Tie Luo},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.03641}
}
@inproceedings{Jabri2022ScalableAC,
    title   = {Scalable Adaptive Computation for Iterative Generation},
    author  = {A. Jabri and David J. Fleet and Ting Chen},
    year    = {2022}
}
@article{Cheng2022DPMSolverPlusPlus,
    title   = {DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models},
    author  = {Cheng Lu and Yuhao Zhou and Fan Bao and Jianfei Chen and Chongxuan Li and Jun Zhu},
    journal = {NeurIPS 2022 Oral},
    year    = {2022},
    volume  = {abs/2211.01095}
}
@inproceedings{Hoogeboom2023simpleDE,
    title   = {simple diffusion: End-to-end diffusion for high resolution images},
    author  = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
    year    = {2023}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
@inproceedings{Hang2023EfficientDT,
    title   = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
    author  = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
    year    = {2023}
}
@misc{Guttenberg2023,
    author  = {Nicholas Guttenberg},
    url     = {https://www.crosslabs.org/blog/diffusion-with-offset-noise}
}
@inproceedings{Lin2023CommonDN,
    title   = {Common Diffusion Noise Schedules and Sample Steps are Flawed},
    author  = {Shanchuan Lin and Bingchen Liu and Jiashi Li and Xiao Yang},
    year    = {2023}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@article{Bondarenko2023QuantizableTR,
    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.12929},
    url     = {https://api.semanticscholar.org/CorpusID:259224568}
}
@article{Karras2023AnalyzingAI,
    title   = {Analyzing and Improving the Training Dynamics of Diffusion Models},
    author  = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2312.02696},
    url     = {https://api.semanticscholar.org/CorpusID:265659032}
}

.\lucidrains\denoising-diffusion-pytorch\setup.py

# import setuptools' setup and package-discovery helpers
from setuptools import setup, find_packages

# execute the version file so __version__ is available in the current namespace
exec(open('denoising_diffusion_pytorch/version.py').read())

# package metadata
setup(
  name = 'denoising-diffusion-pytorch', # package name
  packages = find_packages(), # discover all packages
  version = __version__, # version imported above
  license='MIT', # license
  description = 'Denoising Diffusion Probabilistic Models - Pytorch', # description
  author = 'Phil Wang', # author
  author_email = 'lucidrains@gmail.com', # author email
  url = 'https://github.com/lucidrains/denoising-diffusion-pytorch', # project url
  long_description_content_type = 'text/markdown', # content type of the long description
  keywords = [
    'artificial intelligence', # keywords
    'generative models'
  ],
  install_requires=[ # dependencies
    'accelerate',
    'einops',
    'ema-pytorch>=0.4.2',
    'numpy',
    'pillow',
    'pytorch-fid',
    'torch',
    'torchvision',
    'tqdm'
  ],
  classifiers=[ # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
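
With the metadata above, the package installs from PyPI as denoising-diffusion-pytorch. As a quick smoke test, here is a sketch based on the repository's documented top-level exports (Unet and GaussianDiffusion); the argument values are illustrative only:

import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(dim = 64, dim_mults = (1, 2, 4, 8))

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000
)

training_images = torch.rand(8, 3, 128, 128)  # images normalized to [0, 1]
loss = diffusion(training_images)             # forward pass returns the training loss
loss.backward()

sampled_images = diffusion.sample(batch_size = 4)  # (4, 3, 128, 128)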

Differentiable Signed Distance Function Rendering - Pytorch (wip)

Implementation of Differentiable Signed Distance Function Rendering from EPFL, in Pytorch.

Citations

@article{Vicini2022sdf,
    author = {Delio Vicini and Sébastien Speierer and Wenzel Jakob},
    title = {Differentiable Signed Distance Function Rendering},
    journal = {Transactions on Graphics (Proceedings of SIGGRAPH)},
    volume = {41},
    number = {4},
    pages = {125:1--125:18},
    year = {2022},
    month = jul,
    doi = {10.1145/3528223.3530139}
}

.\lucidrains\diffusion-policy\diffusion_policy\diffusion_policy.py

import math
from pathlib import Path
from random import random
from functools import partial
from multiprocessing import cpu_count

import torch
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.optim import Adam  # optimizer used by the Trainer below
from torch import nn, einsum, Tensor
from torch.special import expm1
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms as T, utils

from beartype import beartype

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

from PIL import Image
from tqdm.auto import tqdm

from ema_pytorch import EMA

from accelerate import (
    Accelerator,
    DistributedDataParallelKwargs
)

# helper functions

# check whether a value exists (is not None)
def exists(x):
    return x is not None

# return the input unchanged
def identity(x):
    return x

# return the value if it exists, otherwise the default
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# check whether one number divides another evenly
def divisible_by(numer, denom):
    return (numer % denom) == 0

# divide while guarding against division by zero
def safe_div(numer, denom, eps = 1e-10):
    return numer / denom.clamp(min = eps)

# endlessly cycle over a dataloader
def cycle(dl):
    while True:
        for data in dl:
            yield data

# check whether a number has an integer square root
def has_int_squareroot(num):
    num_sqrt = math.sqrt(num)
    return int(num_sqrt) == num_sqrt

# split a number into groups of at most `divisor`
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

# convert an image to the given mode, if it differs
# (named with the _fn suffix so the Dataset keyword argument below does not shadow it)
def convert_image_to_fn(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

# build an nn.Sequential, filtering out Nones
def Sequential(*mods):
    return nn.Sequential(*filter(exists, mods))

# normalize and unnormalize image

# scale an image from [0, 1] to [-1, 1]
def normalize_img(x):
    return x * 2 - 1

# scale an image from [-1, 1] back to [0, 1]
def unnormalize_img(x):
    return (x + 1) * 0.5

# standardize the variance of a noised image (used when scale is not 1)
def normalize_img_variance(x, eps = 1e-5):
    std = reduce(x, 'b c h w -> b 1 1 1', partial(torch.std, unbiased = False))
    return x / std.clamp(min = eps)

# helper functions

# numerically stable log
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# pad t on the right with singleton dimensions until it matches x's ndim
def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))

# noise schedules

# simple linear schedule
def simple_linear_schedule(t, clip_min = 1e-9):
    return (1 - t).clamp(min = clip_min)

# cosine schedule
def cosine_schedule(t, start = 0, end = 1, tau = 1, clip_min = 1e-9):
    power = 2 * tau
    v_start = math.cos(start * math.pi / 2) ** power
    v_end = math.cos(end * math.pi / 2) ** power
    output = math.cos((t * (end - start) + start) * math.pi / 2) ** power
    output = (v_end - output) / (v_end - v_start)
    return output.clamp(min = clip_min)

# sigmoid schedule
def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-9):
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    gamma = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
    return gamma.clamp_(min = clamp_min, max = 1.)

# converting gamma to alpha, sigma or logsnr

# convert gamma to alpha and sigma
def gamma_to_alpha_sigma(gamma, scale = 1):
    return torch.sqrt(gamma) * scale, torch.sqrt(1 - gamma)

# convert gamma to log(snr)
def gamma_to_log_snr(gamma, scale = 1, eps = 1e-5):
    return log(gamma * (scale ** 2) / (1 - gamma), eps = eps)
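
# quick sanity check on the helpers above (illustrative annotation, not part of
# the original file): with scale = 1, alpha ** 2 + sigma ** 2 = gamma + (1 - gamma) = 1
# at every time, and the sigmoid schedule takes gamma from ~1 at t = 0 (clean data)
# down to ~0 at t = 1 (pure noise)

t = torch.tensor([0., 0.5, 1.])

gamma = sigmoid_schedule(t)                 # approximately [1.0, 0.5, 0.0]
alpha, sigma = gamma_to_alpha_sigma(gamma)  # signal and noise coefficients

# the variance-preserving identity holds for every sampled time
assert torch.allclose(alpha ** 2 + sigma ** 2, torch.ones_like(t))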

# gaussian diffusion

# diffusion policy class
class DiffusionPolicy(Module):
    @beartype
    def __init__(
        self,
        model: Module,
        *,
        timesteps = 1000,
        use_ddim = True,
        noise_schedule = 'sigmoid',
        objective = 'v',
        schedule_kwargs: dict = dict(),
        time_difference = 0.,
        min_snr_loss_weight = True,
        min_snr_gamma = 5,
        train_prob_self_cond = 0.9,
        scale = 1.                      # this will be set to < 1. for better convergence when training on higher resolution images
    ):
        # call the parent constructor
        super().__init__()
        # set the model and number of channels
        self.model = model
        self.channels = self.model.channels

        # the objective must be to predict x0, the noise, or v
        assert objective in {'x0', 'eps', 'v'}, 'objective must be one of x0, eps or v'
        self.objective = objective

        # image size
        self.image_size = model.image_size

        # pick the gamma schedule function for the chosen noise schedule
        if noise_schedule == "linear":
            self.gamma_schedule = simple_linear_schedule
        elif noise_schedule == "cosine":
            self.gamma_schedule = cosine_schedule
        elif noise_schedule == "sigmoid":
            self.gamma_schedule = sigmoid_schedule
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        # adjust the noise scale for the image resolution
        assert scale <= 1, 'scale must be less than or equal to 1'
        self.scale = scale
        self.maybe_normalize_img_variance = normalize_img_variance if scale < 1 else identity

        # bind the schedule kwargs to the gamma schedule function
        self.gamma_schedule = partial(self.gamma_schedule, **schedule_kwargs)

        # number of sampling timesteps and whether to use DDIM
        self.timesteps = timesteps
        self.use_ddim = use_ddim

        # as proposed in the paper, add a time difference to time_next, to fix poor
        # self-conditioning and to lower FID when the number of sampling timesteps is < 400
        self.time_difference = time_difference

        # probability of self-conditioning during training
        self.train_prob_self_cond = train_prob_self_cond

        # min snr loss weight
        self.min_snr_loss_weight = min_snr_loss_weight
        self.min_snr_gamma = min_snr_gamma

    # device of the model parameters
    @property
    def device(self):
        return next(self.model.parameters()).device

    # build the pairs of sampling timesteps
    def get_sampling_timesteps(self, batch, *, device):
        # evenly spaced times from 1 down to 0 on the device
        times = torch.linspace(1., 0., self.timesteps + 1, device=device)
        times = repeat(times, 't -> b t', b=batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim=0)
        times = times.unbind(dim=-1)
        return times

    # sample images with the DDPM (ancestral) sampler, gradients disabled
    @torch.no_grad()
    def ddpm_sample(self, shape, time_difference = None):
        # batch size and device
        batch, device = shape[0], self.device

        # time difference, defaulting to the value set at init
        time_difference = default(time_difference, self.time_difference)

        # pairs of sampling timesteps
        time_pairs = self.get_sampling_timesteps(batch, device = device)

        # start from pure gaussian noise
        img = torch.randn(shape, device=device)

        x_start = None
        last_latents = None

        # iterate over the timestep pairs
        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.timesteps):

            # add the time delay
            time_next = (time_next - time_difference).clamp(min = 0.)

            noise_cond = time

            # predict x0
            maybe_normalized_img = self.maybe_normalize_img_variance(img)
            model_output, last_latents = self.model(maybe_normalized_img, noise_cond, x_start, last_latents, return_latents = True)

            # get log(snr)
            gamma = self.gamma_schedule(time)
            gamma_next = self.gamma_schedule(time_next)
            gamma, gamma_next = map(partial(right_pad_dims_to, img), (gamma, gamma_next))

            # get alpha and sigma
            alpha, sigma = gamma_to_alpha_sigma(gamma, self.scale)
            alpha_next, sigma_next = gamma_to_alpha_sigma(gamma_next, self.scale)

            # recover x0 from the model output, depending on the objective
            if self.objective == 'x0':
                x_start = model_output
            elif self.objective == 'eps':
                x_start = safe_div(img - sigma * model_output, alpha)
            elif self.objective == 'v':
                x_start = alpha * img - sigma * model_output

            # clip x0 to [-1, 1]
            x_start.clamp_(-1., 1.)
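
            # for reference (annotation added here, not in the original source):
            # the next step is the standard DDPM posterior written in the log-SNR
            # parameterization - with L = log SNR now and L' = log SNR at the next time,
            #   c        = 1 - exp(L - L')
            #   mean     = alpha' * ((1 - c) * x_t / alpha + c * x0_hat)
            #   variance = sigma' ** 2 * c
            # -expm1(log_snr - log_snr_next) computes 1 - exp(L - L') stably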

            # derive the posterior mean and variance
            log_snr, log_snr_next = map(gamma_to_log_snr, (gamma, gamma_next))
            c = -expm1(log_snr - log_snr_next)
            mean = alpha_next * (img * (1 - c) / alpha + c * x_start)
            variance = (sigma_next ** 2) * c
            log_variance = log(variance)

            # inject noise, except at the final step
            noise = torch.where(
                rearrange(time_next > 0, 'b -> b 1 1 1'),
                torch.randn_like(img),
                torch.zeros_like(img)
            )

            # update the image
            img = mean + (0.5 * log_variance).exp() * noise

        # return the image mapped back to [0, 1]
        return unnormalize_img(img)

    # sample images with the DDIM sampler, gradients disabled
    @torch.no_grad()
    def ddim_sample(self, shape, time_difference = None):
        # batch size and device
        batch, device = shape[0], self.device

        # time difference, defaulting to the value set at init
        time_difference = default(time_difference, self.time_difference)

        # pairs of sampling timesteps
        time_pairs = self.get_sampling_timesteps(batch, device = device)

        # start from pure gaussian noise
        img = torch.randn(shape, device = device)

        x_start = None
        last_latents = None

        # iterate over the timestep pairs
        for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):

            # noise levels at the current and next times
            gamma = self.gamma_schedule(times)
            gamma_next = self.gamma_schedule(times_next)

            # pad gamma to the image's dimensionality
            padded_gamma, padded_gamma_next = map(partial(right_pad_dims_to, img), (gamma, gamma_next))

            # convert gamma to alpha and sigma
            alpha, sigma = gamma_to_alpha_sigma(padded_gamma, self.scale)
            alpha_next, sigma_next = gamma_to_alpha_sigma(padded_gamma_next, self.scale)

            # add the time delay
            times_next = (times_next - time_difference).clamp(min = 0.)

            # predict x0
            maybe_normalized_img = self.maybe_normalize_img_variance(img)
            model_output, last_latents = self.model(maybe_normalized_img, times, x_start, last_latents, return_latents = True)

            # recover x0 from the model output, depending on the objective
            if self.objective == 'x0':
                x_start = model_output
            elif self.objective == 'eps':
                x_start = safe_div(img - sigma * model_output, alpha)
            elif self.objective == 'v':
                x_start = alpha * img - sigma * model_output

            # clip x0 to [-1, 1]
            x_start.clamp_(-1., 1.)

            # derive the predicted noise
            pred_noise = safe_div(img - alpha * x_start, sigma)

            # step to the next x
            img = x_start * alpha_next + pred_noise * sigma_next

        # return the image mapped back to [0, 1]
        return unnormalize_img(img)

    # sample a batch of images, gradients disabled
    @torch.no_grad()
    def sample(self, batch_size = 16):
        image_size, channels = self.image_size, self.channels
        # choose the sampler based on whether DDIM is enabled
        sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample
        return sample_fn((batch_size, channels, image_size, image_size))
    # forward pass: compute the training loss for a batch of images
    def forward(self, img, *args, **kwargs):
        # unpack the image shape and device
        batch, c, h, w, device, img_size = *img.shape, img.device, self.image_size
        # height and width must match the configured image size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'

        # sample random times
        times = torch.zeros((batch,), device=device).float().uniform_(0, 1.)

        # normalize the image from [0, 1] to [-1, 1]
        img = normalize_img(img)

        # noise sample
        noise = torch.randn_like(img)

        # compute gamma
        gamma = self.gamma_schedule(times)
        padded_gamma = right_pad_dims_to(img, gamma)
        alpha, sigma = gamma_to_alpha_sigma(padded_gamma, self.scale)

        # add noise to the image
        noised_img = alpha * img + sigma * noise

        # possibly normalize the variance of the noised image
        noised_img = self.maybe_normalize_img_variance(noised_img)

        # in the paper, they had to apply latent self-conditioning with a very
        # high probability, up to 90% of the time - a slight drawback
        self_cond = self_latents = None

        if random() < self.train_prob_self_cond:
            with torch.no_grad():
                model_output, self_latents = self.model(noised_img, times, return_latents=True)
                self_latents = self_latents.detach()

                if self.objective == 'x0':
                    self_cond = model_output

                elif self.objective == 'eps':
                    self_cond = safe_div(noised_img - sigma * model_output, alpha)

                elif self.objective == 'v':
                    self_cond = alpha * noised_img - sigma * model_output

                self_cond.clamp_(-1., 1.)
                self_cond = self_cond.detach()

        # predict, for the gradient descent step
        pred = self.model(noised_img, times, self_cond, self_latents)

        if self.objective == 'eps':
            target = noise

        elif self.objective == 'x0':
            target = img

        elif self.objective == 'v':
            target = alpha * noise - sigma * img

        # compute the per-sample loss
        loss = F.mse_loss(pred, target, reduction='none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        # min snr loss weighting
        snr = (alpha * alpha) / (sigma * sigma)
        # flatten to (b,) so the per-sample weight broadcasts against the (b,) loss
        snr = rearrange(snr, 'b 1 1 1 -> b')
        maybe_clipped_snr = snr.clone()

        if self.min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max=self.min_snr_gamma)

        if self.objective == 'eps':
            loss_weight = maybe_clipped_snr / snr

        elif self.objective == 'x0':
            loss_weight = maybe_clipped_snr

        elif self.objective == 'v':
            loss_weight = maybe_clipped_snr / (snr + 1)

        return (loss * loss_weight).mean()
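
# usage sketch (annotation added for illustration, not in the original file):
# ToyDenoiser is a hypothetical stand-in showing the interface DiffusionPolicy
# expects from its model - .channels and .image_size attributes, plus a forward
# of (x, times, x_self_cond, latents, return_latents) that returns
# (prediction, latents) when latents are requested

class ToyDenoiser(Module):
    def __init__(self, channels = 3, image_size = 32, dim_latent = 16):
        super().__init__()
        self.channels = channels
        self.image_size = image_size
        self.to_pred = nn.Conv2d(channels, channels, 3, padding = 1)
        self.to_latent = nn.Conv2d(channels, dim_latent, 3, padding = 1)

    def forward(self, x, times, x_self_cond = None, latents = None, return_latents = False):
        pred = self.to_pred(x)

        if not return_latents:
            return pred

        return pred, self.to_latent(x)

policy = DiffusionPolicy(ToyDenoiser(), timesteps = 10)

images = torch.rand(2, 3, 32, 32)        # images in [0, 1]
loss = policy(images)                    # forward pass returns the training loss
loss.backward()

samples = policy.sample(batch_size = 2)  # (2, 3, 32, 32), DDIM sampler by default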
# dataset classes

# dataset over a folder of images; the name intentionally shadows
# torch.utils.data.Dataset, which it subclasses
class Dataset(Dataset):
    def __init__(
        self,
        folder,  # path to the dataset folder
        image_size,  # target image size
        exts = ['jpg', 'jpeg', 'png', 'tiff'],  # image file extensions
        augment_horizontal_flip = False,  # whether to augment with horizontal flips
        convert_image_to = None  # optional image mode to convert to
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # gather all file paths in the folder with the given extensions
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        # partially apply the conversion helper defined above
        maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity()

        # image transform pipeline
        self.transform = T.Compose([
            T.Lambda(maybe_convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    # dataset length
    def __len__(self):
        return len(self.paths)

    # fetch and transform the image at the given index
    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# trainer class

# define the Trainer class
@beartype
class Trainer(object):
    # initializer
    def __init__(
        self,
        diffusion_model: DiffusionPolicy,  # the diffusion model to train
        folder,  # path to the dataset folder
        *,
        train_batch_size = 16,  # training batch size
        gradient_accumulate_every = 1,  # gradient accumulation steps
        augment_horizontal_flip = True,  # whether to augment with horizontal flips
        train_lr = 1e-4,  # learning rate
        train_num_steps = 100000,  # number of training steps
        max_grad_norm = 1.,  # gradient clipping threshold
        ema_update_every = 10,  # EMA update frequency
        ema_decay = 0.995,  # EMA decay rate
        betas = (0.9, 0.99),  # betas for the Adam optimizer
        save_and_sample_every = 1000,  # how often to save and sample
        num_samples = 25,  # number of samples per milestone
        results_folder = './results',  # folder for results
        amp = False,  # whether to use mixed precision training
        mixed_precision_type = 'fp16',  # mixed precision type
        split_batches = True,  # whether to split batches across devices
        convert_image_to = None  # optional image mode to convert to
    ):
        super().__init__()

        # accelerator
        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = mixed_precision_type if amp else 'no',
            kwargs_handlers = [DistributedDataParallelKwargs(find_unused_parameters=True)]
        )

        # diffusion model
        self.model = diffusion_model

        # the number of samples must have an integer square root (for the sample grid)
        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every

        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.max_grad_norm = max_grad_norm

        self.train_num_steps = train_num_steps
        self.image_size = diffusion_model.image_size

        # dataset and dataloader

        self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to)
        dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())

        # let the accelerator prepare the dataloader
        dl = self.accelerator.prepare(dl)
        self.dl = cycle(dl)

        # optimizer

        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = betas)

        # folder for periodically logging results

        self.results_folder = Path(results_folder)

        if self.accelerator.is_local_main_process:
            self.results_folder.mkdir(exist_ok = True)

        if self.accelerator.is_main_process:
            self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)

        # step counter state

        self.step = 0

        # prepare the model and optimizer with the accelerator

        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

    # save a checkpoint for the given milestone
    def save(self, milestone):
        if not self.accelerator.is_local_main_process:
            return

        data = {
            'step': self.step + 1,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
        }

        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
    # load the checkpoint for the given milestone
    def load(self, milestone):
        # load the checkpoint data from file
        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'))

        # get the unwrapped (non-accelerated) model
        model = self.accelerator.unwrap_model(self.model)
        # load the model state dict
        model.load_state_dict(data['model'])

        # restore the current training step
        self.step = data['step']
        # load the optimizer state dict
        self.opt.load_state_dict(data['opt'])

        # on the main process, load the EMA model's state dict
        if self.accelerator.is_main_process:
            self.ema.load_state_dict(data['ema'])

        # if both the accelerator and the checkpoint have a grad scaler, restore it
        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    # training loop
    def train(self):
        # accelerator and device
        accelerator = self.accelerator
        device = accelerator.device

        # show progress with tqdm
        with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:

            # loop until the training step budget is exhausted
            while self.step < self.train_num_steps:

                total_loss = 0.

                # run as many forward passes as gradient accumulation steps
                for _ in range(self.gradient_accumulate_every):
                    # fetch the next batch and move it to the device
                    data = next(self.dl).to(device)

                    # compute the model loss under autocast (mixed precision)
                    with accelerator.autocast():
                        loss = self.model(data)
                        loss = loss / self.gradient_accumulate_every
                        total_loss += loss.item()

                    # backpropagate
                    accelerator.backward(loss)

                # show the current loss on the progress bar
                pbar.set_description(f'loss: {total_loss:.4f}')

                # wait for all processes
                accelerator.wait_for_everyone()
                # clip the gradients
                accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                # optimizer step
                self.opt.step()
                # zero the gradients
                self.opt.zero_grad()

                # wait for all processes
                accelerator.wait_for_everyone()

                # save milestones on every local main process, sample only on the global main process
                if accelerator.is_local_main_process:
                    milestone = self.step // self.save_and_sample_every
                    save_and_sample = self.step != 0 and self.step % self.save_and_sample_every == 0

                    if accelerator.is_main_process:
                        # move the EMA model to the device and update it
                        self.ema.to(device)
                        self.ema.update()

                        if save_and_sample:
                            # put the EMA model in eval mode
                            self.ema.ema_model.eval()

                            with torch.no_grad():
                                # split the number of samples into batches and generate sample images
                                batches = num_to_groups(self.num_samples, self.batch_size)
                                all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))

                            all_images = torch.cat(all_images_list, dim = 0)
                            # save the generated sample grid
                            utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))

                    if save_and_sample:
                        # save a checkpoint for this milestone
                        self.save(milestone)

                # advance the step counter and the progress bar
                self.step += 1
                pbar.update(1)

        # done
        accelerator.print('training complete')
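
# training sketch (annotation added for illustration, not in the original file):
# assumes a folder './images' of training images and reuses the hypothetical
# ToyDenoiser from the sketch after DiffusionPolicy above

if __name__ == '__main__':
    diffusion = DiffusionPolicy(ToyDenoiser(channels = 3, image_size = 32), timesteps = 1000)

    trainer = Trainer(
        diffusion,
        './images',
        train_batch_size = 16,
        train_num_steps = 10000,
        save_and_sample_every = 1000
    )

    trainer.train()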

.\lucidrains\diffusion-policy\diffusion_policy\__init__.py

# import the DiffusionPolicy class from the diffusion_policy.diffusion_policy module
from diffusion_policy.diffusion_policy import DiffusionPolicy

Diffusion Policy (wip)

Implementation of Diffusion Policy, Toyota Research's supposed breakthrough in leveraging DDPMs for learning policies for real-world robotics

What seems to have happened is that a research group at Columbia adapted the popular SOTA text-to-image models (complete with denoising diffusion and cross-attention conditioning) to policy generation (predicting robot actions conditioned on observations). Toyota Research then validated this at a certain scale for imitation learning with real-world robotic demonstrations. It is hard to know how much of a breakthrough this is, given that corporate press is prone to exaggeration, but let me try to get a clean implementation out, just in case it is.

The great thing is, if this really works, all the advances being made in text-to-image space can translate to robotics. Yes, this includes stuff like dreambooth.

Todo

Citations

@article{Chi2023DiffusionPV,
    title   = {Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
    author  = {Cheng Chi and Siyuan Feng and Yilun Du and Zhenjia Xu and Eric A. Cousineau and Benjamin Burchfiel and Shuran Song},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2303.04137},
    url     = {https://api.semanticscholar.org/CorpusID:257378658}
}
@article{Sauer2023AdversarialDD,
    title   = {Adversarial Diffusion Distillation},
    author  = {Axel Sauer and Dominik Lorenz and A. Blattmann and Robin Rombach},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2311.17042},
    url     = {https://api.semanticscholar.org/CorpusID:265466173}
}