Lucidrains Series: Source Code Walkthrough (Part 20)

.\lucidrains\imagen-pytorch\imagen_pytorch\imagen_pytorch.py

# import the math module
import math
# import the random function from the random module
from random import random
# import the List and Union types from beartype
from beartype.typing import List, Union
# import the beartype runtime type-checking decorator
from beartype import beartype
# import the tqdm progress bar
from tqdm.auto import tqdm
# import partial and wraps from functools
from functools import partial, wraps
# import contextmanager and nullcontext from contextlib
from contextlib import contextmanager, nullcontext
# import the Path class from pathlib
from pathlib import Path

# import torch
import torch
# import torch.nn.functional under the alias F
import torch.nn.functional as F
# import the DistributedDataParallel wrapper
from torch.nn.parallel import DistributedDataParallel
# import the nn module and the einsum function from torch
from torch import nn, einsum
# import autocast for mixed precision
from torch.cuda.amp import autocast
# import expm1 from torch.special
from torch.special import expm1

# import torchvision transforms under the alias T
import torchvision.transforms as T

# import kornia augmentations under the alias K
import kornia.augmentation as K

# import the einops rearrange, repeat, reduce, pack and unpack functions
from einops import rearrange, repeat, reduce, pack, unpack
# import the Rearrange layer from einops
from einops.layers.torch import Rearrange

# import t5_encode_text, get_encoded_dim and the DEFAULT_T5_NAME constant from imagen_pytorch.t5
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# import Unet3D, resize_video_to and scale_video_time from imagen_pytorch.imagen_video
from imagen_pytorch.imagen_video import Unet3D, resize_video_to, scale_video_time

# helper functions

# check whether a value is not None
def exists(val):
    return val is not None

# identity function, returns its input unchanged
def identity(t, *args, **kwargs):
    return t

# check whether numer is evenly divisible by denom
def divisible_by(numer, denom):
    return (numer % denom) == 0

# return the first element of a sequence, or a default value if it is empty
def first(arr, d = None):
    if len(arr) == 0:
        return d
    return arr[0]

# wrap fn so that it is only applied when the input exists (is not None)
def maybe(fn):
    @wraps(fn)
    def inner(x):
        if not exists(x):
            return x
        return fn(x)
    return inner

# decorator that lets a function run only once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print that only fires once
print_once = once(print)

# return val if it exists, otherwise the default (calling it if it is callable)
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# cast a value to a tuple, optionally broadcasting to / validating a given length
def cast_tuple(val, length = None):
    if isinstance(val, list):
        val = tuple(val)

    output = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length):
        assert len(output) == length

    return output
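
# usage sketch (illustrative, not part of the original file): cast_tuple broadcasts a scalar
# into a tuple of the requested length and validates lengths when one is given
#
#   cast_tuple(3)           # -> (3,)
#   cast_tuple(3, 4)        # -> (3, 3, 3, 3)
#   cast_tuple([1, 2], 2)   # -> (1, 2)
#   cast_tuple((1, 2), 3)   # -> AssertionError (length mismatch)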

# compact a dict by dropping key / value pairs whose value is None
def compact(input_dict):
    return {key: value for key, value in input_dict.items() if exists(value)}

# apply fn to the value at a given key of a dict, if present (returns a copy)
def maybe_transform_dict_key(input_dict, key, fn):
    if key not in input_dict:
        return input_dict

    copied_dict = input_dict.copy()
    copied_dict[key] = fn(copied_dict[key])
    return copied_dict

# cast uint8 images to floats in the range [0, 1]
def cast_uint8_images_to_float(images):
    if not images.dtype == torch.uint8:
        return images
    return images / 255

# get the device a module's parameters live on
def module_device(module):
    return next(module.parameters()).device

# zero-initialize a module's weight (and bias, if present)
def zero_init_(m):
    nn.init.zeros_(m.weight)
    if exists(m.bias):
        nn.init.zeros_(m.bias)

# decorator that runs the wrapped method in eval mode, then restores the previous training mode
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# right-pad a tuple to the given length with a fill value
def pad_tuple_to_length(t, length, fillvalue = None):
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

# helper classes

# no-op module that returns its input unchanged
class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

# tensor helpers

# log of a tensor, clamped for numerical stability
def log(t, eps: float = 1e-12):
    return torch.log(t.clamp(min = eps))

# l2 normalize a tensor along its last dimension
def l2norm(t):
    return F.normalize(t, dim = -1)

# right-pad t with singleton dimensions until it has as many dimensions as x
def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))
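
# quick sketch (illustrative): right_pad_dims_to appends singleton dims to t so it broadcasts
# against x, e.g. per-sample scalars against an image batch
#
#   x = torch.randn(8, 3, 64, 64)
#   t = torch.randn(8)
#   right_pad_dims_to(x, t).shape   # -> torch.Size([8, 1, 1, 1])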

# mean over a dimension, ignoring positions that are masked out
def masked_mean(t, *, dim, mask = None):
    if not exists(mask):
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)

    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)

# resize an image to a target spatial size
def resize_image_to(
    image,
    target_image_size,
    clamp_range = None,
    mode = 'nearest'
):
    orig_image_size = image.shape[-1]

    if orig_image_size == target_image_size:
        return image

    out = F.interpolate(image, target_image_size, mode = mode)

    if exists(clamp_range):
        out = out.clamp(*clamp_range)

    return out

# compute the frame dimension for every unet stage of a video model
def calc_all_frame_dims(
    downsample_factors: List[int],
    frames
):
    # if no frame count is given, return an empty tuple per downsample factor
    if not exists(frames):
        return (tuple(),) * len(downsample_factors)

    all_frame_dims = []

    for divisor in downsample_factors:
        # each stage's temporal downsample factor must divide the frame count evenly
        assert divisible_by(frames, divisor)
        all_frame_dims.append((frames // divisor,))

    return all_frame_dims

# safely index into a tuple, returning a default value when the index is out of range
def safe_get_tuple_index(tup, index, default = None):
    if len(tup) <= index:
        return default
    return tup[index]

# image normalization functions
# ddpms expect images to be in the range of -1 to 1

def normalize_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_zero_to_one(normed_img):
    return (normed_img + 1) * 0.5

# classifier free guidance functions

def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
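
# usage sketch (illustrative, not part of the original file): during training a boolean keep-mask
# is drawn per batch element, so text conditioning is randomly dropped for classifier free guidance
#
#   keep_mask = prob_mask_like((4,), 0.9, device = torch.device('cpu'))
#   # e.g. tensor([ True,  True, False,  True]) - about 10% of rows lose their text conditioning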

# continuous time gaussian diffusion helper functions and classes
# a large part of this is thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py

@torch.jit.script
def beta_linear_log_snr(t):
    return -torch.log(expm1(1e-4 + 10 * (t ** 2)))

@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008):
    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in the discrete version

def log_snr_to_alpha_sigma(log_snr):
    return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))
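
# quick check (illustrative, not part of the original file): since sigmoid(x) + sigmoid(-x) = 1,
# the (alpha, sigma) pair above always satisfies alpha ** 2 + sigma ** 2 == 1, i.e. the variance
# preserving parameterization used by the continuous time diffusion class below
#
#   log_snr = torch.tensor(2.)
#   alpha, sigma = log_snr_to_alpha_sigma(log_snr)
#   assert torch.allclose(alpha ** 2 + sigma ** 2, torch.tensor(1.))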

class GaussianDiffusionContinuousTimes(nn.Module):
    def __init__(self, *, noise_schedule, timesteps = 1000):
        super().__init__()

        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        self.num_timesteps = timesteps

    def get_times(self, batch_size, noise_level, *, device):
        return torch.full((batch_size,), noise_level, device = device, dtype = torch.float32)

    def sample_random_times(self, batch_size, *, device):
        return torch.zeros((batch_size,), device = device).float().uniform_(0, 1)

    def get_condition(self, times):
        return maybe(self.log_snr)(times)

    def get_sampling_timesteps(self, batch, *, device):
        times = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    def q_posterior(self, x_start, x_t, t, *, t_next = None):
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """
        log_snr = self.log_snr(t)
        log_snr_next = self.log_snr(t_next)
        log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        dtype = x_start.dtype

        if isinstance(t, float):
            batch = x_start.shape[0]
            t = torch.full((batch,), t, device = x_start.device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_start))
        log_snr = self.log_snr(t).type(dtype)
        log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
        alpha, sigma =  log_snr_to_alpha_sigma(log_snr_padded_dim)

        return alpha * x_start + sigma * noise, log_snr, alpha, sigma
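
    # usage sketch (illustrative, not part of the original file): q_sample noises clean data
    # x_start to time t in one step, returning the noised sample along with log_snr, alpha and sigma
    #
    #   scheduler = GaussianDiffusionContinuousTimes(noise_schedule = 'cosine', timesteps = 1000)
    #   x_start = torch.randn(2, 3, 64, 64)
    #   t = scheduler.sample_random_times(2, device = x_start.device)
    #   x_noised, log_snr, alpha, sigma = scheduler.q_sample(x_start, t)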

    # re-noise x_from from time from_t to a later (noisier) time to_t
    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
        shape, device, dtype = x_from.shape, x_from.device, x_from.dtype
        batch = shape[0]

        # promote float times to per-batch tensors
        if isinstance(from_t, float):
            from_t = torch.full((batch,), from_t, device = device, dtype = dtype)

        if isinstance(to_t, float):
            to_t = torch.full((batch,), to_t, device = device, dtype = dtype)

        # default to fresh gaussian noise
        noise = default(noise, lambda: torch.randn_like(x_from))

        # log snr (and alpha / sigma) at the source time, broadcast to x_from's dims
        log_snr = self.log_snr(from_t)
        log_snr_padded_dim = right_pad_dims_to(x_from, log_snr)
        alpha, sigma =  log_snr_to_alpha_sigma(log_snr_padded_dim)

        # log snr (and alpha / sigma) at the destination time
        log_snr_to = self.log_snr(to_t)
        log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to)
        alpha_to, sigma_to =  log_snr_to_alpha_sigma(log_snr_padded_dim_to)

        # rescale and add noise so the result is distributed as a sample at time to_t
        return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha

    # predict x_start from x_t and the predicted velocity v (v-parameterization)
    def predict_start_from_v(self, x_t, t, v):
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        return alpha * x_t - sigma * v

    # predict x_start from x_t and the predicted noise (epsilon-parameterization)
    def predict_start_from_noise(self, x_t, t, noise):
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)
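
# worked check (illustrative, not from the original file): predict_start_from_noise simply inverts
# q_sample - since x_t = alpha * x_start + sigma * noise, it follows that
# x_start = (x_t - sigma * noise) / alpha, with the clamp guarding against division by ~0 at the
# very noisiest times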

# layer norm with an optional stability trick and a learned gain only
class LayerNorm(nn.Module):
    def __init__(self, feats, stable = False, dim = -1):
        super().__init__()
        self.stable = stable
        self.dim = dim

        # learned gain, shaped so it broadcasts over the normalized dimension
        self.g = nn.Parameter(torch.ones(feats, *((1,) * (-dim - 1))))

    def forward(self, x):
        dtype, dim = x.dtype, self.dim

        # optionally divide by the max for numerical stability
        if self.stable:
            x = x / x.amax(dim = dim, keepdim = True).detach()

        # use a larger eps for half precision
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = dim, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = dim, keepdim = True)

        return (x - mean) * (var + eps).rsqrt().type(dtype) * self.g.type(dtype)

# channel-wise layer norm - the same LayerNorm applied over the channel dimension (dim = -3)
ChanLayerNorm = partial(LayerNorm, dim = -3)

# callable that always returns a fixed value, no matter the arguments
class Always():
    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val

# residual connection wrapper
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# apply several modules to the same input and sum their outputs
class Parallel(nn.Module):
    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        outputs = [fn(x) for fn in self.fns]
        return sum(outputs)

# attention block used by the PerceiverResampler below - latents attend to the sequence and to themselves
class PerceiverAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        # layer norms for the sequence and the latents, plus the query / key-value projections
        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        # learned per-dimension scales for queries and keys
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # output projection
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    # forward pass - the latents attend to the sequence x together with themselves
    def forward(self, x, latents, mask = None):
        x = self.norm(x)
        latents = self.norm_latents(latents)

        b, h = x.shape[0], self.heads

        q = self.to_q(latents)

        # keys / values come from the sequence concatenated with the latents
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # l2 normalize queries and keys, then apply the learned scales (qk rmsnorm)
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # compute similarities, then apply the mask
        sim = einsum('... i d, ... j d  -> ... i j', q, k) * self.scale

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)

# perceiver resampler - pools a variable length sequence of text embeddings into a fixed set of latents
class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4
    ):
        super().__init__()
        # positional embeddings for the incoming sequence
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # learned latent vectors that will attend to the sequence
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # optional projection from the mean pooled sequence to extra latents
        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        # alternating perceiver attention and feedforward layers
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, mask = None):
        n, device = x.shape[1], x.device
        # positional embedding for each token position
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        # add positional information to the sequence
        x_with_pos = x + pos_emb

        # expand the learned latents across the batch
        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])

        # optionally derive extra latents from the mean pooled sequence and prepend them
        if exists(self.to_latents_from_mean_pooled_seq):
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        # latents attend to the sequence, then pass through a feedforward, with residuals
        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents
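
# usage sketch (illustrative, hypothetical sizes): the resampler pools a variable length text
# embedding sequence into a fixed set of latents (attention pooling of the text embeddings)
#
#   resampler = PerceiverResampler(dim = 512, depth = 2, num_latents = 32)
#   text_tokens = torch.randn(2, 77, 512)
#   latents = resampler(text_tokens)   # (2, 32 + 4, 512) - 4 extra latents come from the mean pooled sequence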

# standard attention block, with a null key / value for classifier free guidance and optional text conditioning
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        context_dim = None,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context = None, mask = None, attn_bias = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # add text conditioning, if present

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # qk rmsnorm

        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # calculate query / key similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale

        # relative positional encoding (T5 style)

        if exists(attn_bias):
            sim = sim + attn_bias

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        # aggregate values

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# upsample: nearest neighbor upsampling followed by a 3x3 conv
def Upsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, dim_out, 3, padding = 1)
    )

# pixel shuffle upsample, with an init that avoids checkerboard artifacts
class PixelShuffleUpsample(nn.Module):
    """
    code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
    """
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)

# downsample function
def Downsample(dim, dim_out = None):
    # https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample
    # named SP-conv in the paper, but basically a pixel unshuffle
    dim_out = default(dim_out, dim)
    # the rearrange folds each 2x2 spatial neighborhood into the channel dimension (a pixel unshuffle),
    # then a 1x1 conv maps the resulting dim * 4 channels down to dim_out
    return nn.Sequential(
        Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
        nn.Conv2d(dim * 4, dim_out, 1)
    )
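
# shape sketch (illustrative): resolution is halved and each 2x2 neighborhood is folded into
# channels before the 1x1 conv
#
#   down = Downsample(64)
#   down(torch.randn(1, 64, 32, 32)).shape   # -> torch.Size([1, 64, 16, 16])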

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)  # spacing of the frequencies in log space
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)  # geometric series of frequencies
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')  # outer product of times and frequencies
        return torch.cat((emb.sin(), emb.cos()), dim = -1)  # concatenate sine and cosine features

class LearnedSinusoidalPosEmb(nn.Module):
    """ following @crowsonkb 's lead with learned sinusoidal pos emb """
    """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))  # learned frequencies

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')  # (b,) times -> (b, 1)
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi  # times scaled by the learned frequencies
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)  # sine and cosine features
        fouriered = torch.cat((x, fouriered), dim = -1)  # prepend the raw time to the fourier features
        return fouriered

class Block(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()  # group norm (optional)
        self.activation = nn.SiLU()  # SiLU activation
        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)  # 3x3 conv projection

    def forward(self, x, scale_shift = None):
        x = self.groupnorm(x)

        # apply the (time conditioned) scale and shift, if provided
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.activation(x)
        return self.project(x)

class ResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_mlp = None

        # MLP that turns the time embedding into a per-channel scale and shift
        if exists(time_cond_dim):
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.cross_attn = None

        # optional cross attention to the conditioning tokens (full or linear attention)
        if exists(cond_dim):
            attn_klass = CrossAttention if not linear_attn else LinearCrossAttention

            self.cross_attn = attn_klass(
                dim = dim_out,
                context_dim = cond_dim,
                **attn_kwargs
            )

        self.block1 = Block(dim, dim_out, groups = groups)  # first conv block
        self.block2 = Block(dim_out, dim_out, groups = groups)  # second conv block

        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)  # global context attention (squeeze-excite-like gating)

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()  # residual projection when channel counts differ

    def forward(self, x, time_emb = None, cond = None):

        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)  # split into scale and shift

        h = self.block1(x)

        # cross attend to the conditioning, flattening the spatial dims into a sequence
        if exists(self.cross_attn):
            assert exists(cond)
            h = rearrange(h, 'b c h w -> b h w c')
            h, ps = pack([h], 'b * c')
            h = self.cross_attn(h, context = cond) + h
            h, = unpack(h, ps, 'b * c')
            h = rearrange(h, 'b h w c -> b c h w')

        h = self.block2(h, scale_shift = scale_shift)

        h = h * self.gca(h)  # gate by global context

        return h + self.res_conv(x)  # residual connection

class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        # default the context dimension to the query dimension
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        # null key / value pair, always available to attend to
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        # project the input to queries
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # project the context to keys and values
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        # learned per-dimension scales for queries and keys (qk rmsnorm)
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # output projection
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        # queries from x, keys and values from the context
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net
        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # cosine similarity attention - l2 normalized queries and keys with learned scales
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # query / key similarities
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # masking
        max_neg_value = -torch.finfo(sim.dtype).max
        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        # aggregate values
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# linear attention variant of the cross attention above
class LinearCrossAttention(CrossAttention):
    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        # queries from x, keys and values from the context
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        # fold the heads into the batch dimension for linear attention
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking

        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b n -> b n 1')
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)

# linear attention over feature maps, with depthwise convs for the q / k / v projections
class LinearAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        # fmap is a feature map of shape (b, c, x, y); context is an optional set of conditioning tokens
        h, x, y = self.heads, *fmap.shape[-2:]

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (ck, cv))
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)

class GlobalContext(nn.Module):
    # attention-based global pooling followed by a gating network
    """ basically a superior form of squeeze-excitation that is attention-esque """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        # project the feature map to a single-channel attention map
        self.to_k = nn.Conv2d(dim_in, 1, 1)
        hidden_dim = max(3, dim_out // 2)

        # bottleneck network (1x1 convs) ending in a sigmoid gate
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # attention logits over spatial positions
        context = self.to_k(x)
        # flatten the spatial dimensions of both the input and the attention map
        x, context = map(lambda t: rearrange(t, 'b n ... -> b n (...)'), (x, context))
        # attention-weighted pooling over the spatial positions
        out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
        # restore a trailing singleton spatial dimension
        out = rearrange(out, '... -> ... 1')
        return self.net(out)

# feedforward module: layer norm, linear, GELU, layer norm, linear
def FeedForward(dim, mult = 2):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, dim, bias = False)
    )

# channel-wise feedforward: the same pattern as above, but with 1x1 convs over feature maps
def ChanFeedForward(dim, mult = 2):  # in paper, it seems for self attention layers they did feedforwards with twice channel width
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        ChanLayerNorm(dim),
        nn.Conv2d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(hidden_dim),
        nn.Conv2d(hidden_dim, dim, 1, bias = False)
    )

# transformer block - a stack of self attention and feedforward layers over flattened feature maps
class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),  # self attention layer
                FeedForward(dim = dim, mult = ff_mult)  # feedforward layer
            ]))

    def forward(self, x, context = None):
        x = rearrange(x, 'b c h w -> b h w c')
        x, ps = pack([x], 'b * c')

        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x

        x, = unpack(x, ps, 'b * c')
        x = rearrange(x, 'b h w c -> b c h w')
        return x

# linear attention transformer block - linear attention layers paired with channel feedforwards
class LinearAttentionTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),  # linear attention layer
                ChanFeedForward(dim = dim, mult = ff_mult)  # channel feedforward layer
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x

# cross embedding layer - parallel convs with different kernel sizes, concatenated along channels
class CrossEmbedLayer(nn.Module):
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # split the output channels across the scales, halving per scale, with the remainder going to the last (largest) kernel
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        fmaps = tuple(map(lambda conv: conv(x), self.convs))
        return torch.cat(fmaps, dim = 1)
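
# worked example (illustrative): with dim_out = 128 and kernel_sizes = (3, 7, 15) the channel
# budget is split per scale as [64, 32, 32], so the concatenated output still has 128 channels
# while mixing several receptive field sizes at the same stride
#
#   layer = CrossEmbedLayer(3, kernel_sizes = (3, 7, 15), dim_out = 128, stride = 1)
#   layer(torch.randn(1, 3, 64, 64)).shape   # -> torch.Size([1, 128, 64, 64])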

# combines upsampled feature maps from earlier stages with the final feature map
class UpsampleCombiner(nn.Module):
    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if not self.enabled:
            self.dim_out = dim
            return

        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
    def forward(self, x, fmaps = None):
        # the spatial size of x is the target size for every feature map
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        # pass through unchanged if disabled or there is nothing to combine
        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        # resize every feature map to the target size and run it through its conv block
        fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
        # concatenate x with all processed feature maps along the channel dimension
        return torch.cat((x, *outs), dim = 1)

# the main Unet class
class Unet(nn.Module):
    def __init__(
        self,
        *,
        dim,
        text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),  # dimension of the text embeddings (defaults to the T5 encoder dim)
        num_resnet_blocks = 1,  # number of resnet blocks per stage
        cond_dim = None,  # conditioning dimension
        num_image_tokens = 4,  # number of image tokens
        num_time_tokens = 2,  # number of time tokens
        learned_sinu_pos_emb_dim = 16,  # dimension of the learned sinusoidal position embedding
        out_dim = None,  # output dimension
        dim_mults=(1, 2, 4, 8),  # channel multipliers per resolution
        cond_images_channels = 0,  # number of channels of the conditioning images
        channels = 3,  # input channels
        channels_out = None,  # output channels
        attn_dim_head = 64,  # attention head dimension
        attn_heads = 8,  # number of attention heads
        ff_mult = 2.,  # feedforward expansion factor
        lowres_cond = False,  # whether to condition on a low resolution image (for super-resolution unets)
        layer_attns = True,  # per-layer self attention
        layer_attns_depth = 1,  # depth of the per-layer self attention
        layer_mid_attns_depth = 1,  # depth of the attention at the middle (bottleneck)
        layer_attns_add_text_cond = True,  # whether to condition the self attention blocks with the text embeddings
        attend_at_middle = True,  # whether to attend at the bottleneck
        layer_cross_attns = True,  # per-layer cross attention
        use_linear_attn = False,  # whether to use linear attention
        use_linear_cross_attn = False,  # whether to use linear cross attention
        cond_on_text = True,  # whether to condition on text
        max_text_len = 256,  # maximum text length
        init_dim = None,  # initial dimension
        resnet_groups = 8,  # number of groups for group norm in the resnet blocks
        init_conv_kernel_size = 7,  # kernel size of the initial conv
        init_cross_embed = True,  # whether to use a cross embed layer for the initial conv
        init_cross_embed_kernel_sizes = (3, 7, 15),  # kernel sizes of the initial cross embed layer
        cross_embed_downsample = False,  # whether to downsample with a cross embed layer
        cross_embed_downsample_kernel_sizes = (2, 4),  # kernel sizes for cross embed downsampling
        attn_pool_text = True,  # whether to attention pool the text embeddings
        attn_pool_num_latents = 32,  # number of latents for attention pooling
        dropout = 0.,  # dropout rate
        memory_efficient = False,  # memory efficient mode
        init_conv_to_final_conv_residual = False,  # residual connection from the initial conv to the final conv
        use_global_context_attn = True,  # use global context attention
        scale_skip_connection = True,  # scale the skip connections
        final_resnet_block = True,  # use a final resnet block
        final_conv_kernel_size = 3,  # kernel size of the final conv
        self_cond = False,  # self conditioning
        resize_mode = 'nearest',  # interpolation mode for resizing
        combine_upsample_fmaps = False,  # combine the feature maps from all upsample blocks
        pixel_shuffle_upsample = True,  # use pixel shuffle upsampling
    # if this Unet is not configured with the right settings, re-instantiate it with the correct ones
    def cast_model_parameters(
        self,
        *,
        lowres_cond,
        text_embed_dim,
        channels,
        channels_out,
        cond_on_text
    ):
        # if the settings match the current Unet's, return the current Unet unchanged
        if lowres_cond == self.lowres_cond and \
            channels == self.channels and \
            cond_on_text == self.cond_on_text and \
            text_embed_dim == self._locals['text_embed_dim'] and \
            channels_out == self.channels_out:
            return self

        # otherwise, update the differing arguments and re-instantiate
        updated_kwargs = dict(
            lowres_cond = lowres_cond,
            text_embed_dim = text_embed_dim,
            channels = channels,
            channels_out = channels_out,
            cond_on_text = cond_on_text
        )

        return self.__class__(**{**self._locals, **updated_kwargs})

    # return the full Unet config along with its parameter state dict
    def to_config_and_state_dict(self):
        return self._locals, self.state_dict()

    # classmethod that re-creates a Unet from a config and state dict
    @classmethod
    def from_config_and_state_dict(klass, config, state_dict):
        unet = klass(**config)
        unet.load_state_dict(state_dict)
        return unet

    # persist the Unet to disk
    def persist_to_file(self, path):
        path = Path(path)
        path.parents[0].mkdir(exist_ok = True, parents = True)

        config, state_dict = self.to_config_and_state_dict()
        pkg = dict(config = config, state_dict = state_dict)
        torch.save(pkg, str(path))

    # classmethod that re-creates a Unet from a file saved with `persist_to_file`
    @classmethod
    def hydrate_from_file(klass, path):
        path = Path(path)
        assert path.exists()
        # load the saved package (config + state dict)
        pkg = torch.load(str(path))

        assert 'config' in pkg and 'state_dict' in pkg
        config, state_dict = pkg['config'], pkg['state_dict']

        # instantiate the Unet from its config and load the weights
        return Unet.from_config_and_state_dict(config, state_dict)

    # forward pass with classifier free guidance

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        # regular conditional forward pass
        logits = self.forward(*args, **kwargs)

        # with a cond_scale of 1, classifier free guidance is a no-op
        if cond_scale == 1:
            return logits

        # unconditional forward pass (conditioning fully dropped), then interpolate / extrapolate
        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale
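        # note (illustrative, not from the original file): this is standard classifier free
        # guidance - the returned prediction equals null + cond_scale * (cond - null), so
        # cond_scale = 1 recovers the conditional prediction and cond_scale > 1 pushes further
        # away from the unconditional one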

    # regular forward pass

    def forward(
        self,
        x,
        time,
        *,
        lowres_cond_img = None,
        lowres_noise_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        self_cond = None,
        cond_drop_prob = 0.

# a null / placeholder Unet that simply passes its input through
class NullUnet(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lowres_cond = False
        self.dummy_parameter = nn.Parameter(torch.tensor([0.]))

    # casting parameters is a no-op for the null unet - just return self
    def cast_model_parameters(self, *args, **kwargs):
        return self

    # forward pass returns the input unchanged
    def forward(self, x, *args, **kwargs):
        return x

# predefined Unets, with configs matching the hyperparameters in the appendix of the paper
class BaseUnet64(Unet):
    def __init__(self, *args, **kwargs):
        default_kwargs = dict(
            dim = 512,
            dim_mults = (1, 2, 3, 4),
            num_resnet_blocks = 3,
            layer_attns = (False, True, True, True),
            layer_cross_attns = (False, True, True, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = False
        )
        super().__init__(*args, **{**default_kwargs, **kwargs})

class SRUnet256(Unet):
    def __init__(self, *args, **kwargs):
        default_kwargs = dict(
            dim = 128,
            dim_mults = (1, 2, 4, 8),
            num_resnet_blocks = (2, 4, 8, 8),
            layer_attns = (False, False, False, True),
            layer_cross_attns = (False, False, False, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = True
        )
        super().__init__(*args, **{**default_kwargs, **kwargs})

class SRUnet1024(Unet):
    def __init__(self, *args, **kwargs):
        default_kwargs = dict(
            dim = 128,
            dim_mults = (1, 2, 4, 8),
            num_resnet_blocks = (2, 4, 8, 8),
            layer_attns = False,
            layer_cross_attns = (False, False, False, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = True
        )
        super().__init__(*args, **{**default_kwargs, **kwargs})
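
# usage sketch (illustrative, hypothetical hyperparameters): the predefined unets above are meant
# to be chained inside the cascading Imagen class defined next, one unet per image size
#
#   unet1 = BaseUnet64(dim = 128)
#   unet2 = SRUnet256(dim = 64)
#   imagen = Imagen((unet1, unet2), image_sizes = (64, 256), timesteps = 1000)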

# main Imagen class, the cascading DDPM from Ho et al.
class Imagen(nn.Module):
    def __init__(
        self,
        unets,
        *,
        image_sizes,                                # for cascading ddpm, image size at each stage
        text_encoder_name = DEFAULT_T5_NAME,
        text_embed_dim = None,
        channels = 3,
        timesteps = 1000,
        cond_drop_prob = 0.1,
        loss_type = 'l2',
        noise_schedules = 'cosine',
        pred_objectives = 'noise',
        random_crop_sizes = None,
        lowres_noise_schedule = 'linear',
        lowres_sample_noise_level = 0.2,            # a new trick from the paper: noise the low resolution conditioning image, and at sample time fix it to a certain level (0.1 or 0.3) - the unets are also conditioned on this noise level
        per_sample_random_aug_noise_level = False,  # unclear whether each batch element should receive its own random augmentation noise level when conditioning on it - turned off due to @marunine's findings
        condition_on_text = True,
        auto_normalize_img = True,                  # whether to automatically normalize images from [0, 1] to [-1, 1] and back - turn this off if you pass [-1, 1] ranged images from your dataloader yourself
        dynamic_thresholding = True,
        dynamic_thresholding_percentile = 0.95,     # unclear from the paper what this value was based on
        only_train_unet_number = None,
        temporal_downsample_factor = 1,
        resize_cond_video_frames = True,
        resize_mode = 'nearest',
        min_snr_loss_weight = True,                 # https://arxiv.org/abs/2303.09556
        min_snr_gamma = 5
    def force_unconditional_(self):
        self.condition_on_text = False
        self.unconditional = True

        for unet in self.unets:
            unet.cond_on_text = False

    @property
    def device(self):
        return self._temp.device

    # fetch the unet with the given (1-indexed) number, moving it to the right device
    def get_unet(self, unet_number):
        assert 0 < unet_number <= len(self.unets)
        index = unet_number - 1

        # convert self.unets from an nn.ModuleList into a plain python list
        if isinstance(self.unets, nn.ModuleList):
            unets_list = [unet for unet in self.unets]
            delattr(self, 'unets')
            self.unets = unets_list

        # move the requested unet to the main device and all others to cpu
        if index != self.unet_being_trained_index:
            for unet_index, unet in enumerate(self.unets):
                unet.to(self.device if unet_index == index else 'cpu')

        # remember which unet is currently being trained
        self.unet_being_trained_index = index
        return self.unets[index]

    # move all unets back onto one device, wrapped in an nn.ModuleList again
    def reset_unets_all_one_device(self, device = None):
        device = default(device, self.device)
        self.unets = nn.ModuleList([*self.unets])
        self.unets.to(device)

        # no unet is currently marked as being trained
        self.unet_being_trained_index = -1

    # context manager that keeps just one unet on the gpu, restoring device placement afterwards
    @contextmanager
    def one_unet_in_gpu(self, unet_number = None, unet = None):
        # exactly one of unet_number or unet must be given
        assert exists(unet_number) ^ exists(unet)

        if exists(unet_number):
            unet = self.unets[unet_number - 1]

        cpu = torch.device('cpu')

        # remember where each unet currently lives
        devices = [module_device(unet) for unet in self.unets]

        # move everything to cpu, then the chosen unet to the main device
        self.unets.to(cpu)
        unet.to(self.device)

        yield

        # restore every unet to its original device
        for unet, device in zip(self.unets, devices):
            unet.to(device)

    # overridden state_dict, making sure all unets are on one device first
    def state_dict(self, *args, **kwargs):
        self.reset_unets_all_one_device()
        return super().state_dict(*args, **kwargs)

    # overridden load_state_dict, making sure all unets are on one device first
    def load_state_dict(self, *args, **kwargs):
        self.reset_unets_all_one_device()
        return super().load_state_dict(*args, **kwargs)

    # gaussian diffusion methods

    def p_mean_variance(
        self,
        unet,
        x,
        t,
        *,
        noise_scheduler,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        cond_video_frames = None,
        post_cond_video_frames = None,
        lowres_cond_img = None,
        self_cond = None,
        lowres_noise_times = None,
        cond_scale = 1.,
        model_output = None,
        t_next = None,
        pred_objective = 'noise',
        dynamic_threshold = True
    ):
        # classifier free guidance is only possible if the model was trained with conditional dropout
        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

        # video conditioning kwargs, only used in video mode
        video_kwargs = dict()
        if self.is_video:
            video_kwargs = dict(
                cond_video_frames = cond_video_frames,
                post_cond_video_frames = post_cond_video_frames,
            )

        # run the unet with classifier free guidance, unless a model output was already supplied
        pred = default(model_output, lambda: unet.forward_with_cond_scale(
            x,
            noise_scheduler.get_condition(t),
            text_embeds = text_embeds,
            text_mask = text_mask,
            cond_images = cond_images,
            cond_scale = cond_scale,
            lowres_cond_img = lowres_cond_img,
            self_cond = self_cond,
            lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times),
            **video_kwargs
        ))

        # convert the prediction to x_start, depending on the prediction objective
        if pred_objective == 'noise':
            x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
        elif pred_objective == 'x_start':
            x_start = pred
        elif pred_objective == 'v':
            x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
        else:
            raise ValueError(f'unknown objective {pred_objective}')

        if dynamic_threshold:
            # dynamic thresholding: clamp to a per-sample percentile of absolute values, then rescale
            s = torch.quantile(
                rearrange(x_start, 'b ... -> b (...)').abs(),
                self.dynamic_thresholding_percentile,
                dim = -1
            )

            s.clamp_(min = 1.)
            s = right_pad_dims_to(x_start, s)
            x_start = x_start.clamp(-s, s) / s
        else:
            x_start.clamp_(-1., 1.)

        # posterior mean and variance for the reverse step
        mean_and_variance = noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next)
        return mean_and_variance, x_start

    # single reverse diffusion sampling step (no gradients)
    @torch.no_grad()
    def p_sample(
        self,
        unet,
        x,
        t,
        *,
        noise_scheduler,
        t_next = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        cond_video_frames = None,
        post_cond_video_frames = None,
        cond_scale = 1.,
        self_cond = None,
        lowres_cond_img = None,
        lowres_noise_times = None,
        pred_objective = 'noise',
        dynamic_threshold = True
    ):
        b, *_, device = *x.shape, x.device

        # video conditioning kwargs, only used in video mode
        video_kwargs = dict()
        if self.is_video:
            video_kwargs = dict(
                cond_video_frames = cond_video_frames,
                post_cond_video_frames = post_cond_video_frames,
            )

        # posterior mean / variance of the reverse step, plus the predicted x_start
        (model_mean, _, model_log_variance), x_start = self.p_mean_variance(
            unet,
            x = x,
            t = t,
            t_next = t_next,
            noise_scheduler = noise_scheduler,
            text_embeds = text_embeds,
            text_mask = text_mask,
            cond_images = cond_images,
            cond_scale = cond_scale,
            lowres_cond_img = lowres_cond_img,
            self_cond = self_cond,
            lowres_noise_times = lowres_noise_times,
            pred_objective = pred_objective,
            dynamic_threshold = dynamic_threshold,
            **video_kwargs
        )

        noise = torch.randn_like(x)
        # no noise is added on the final sampling timestep
        is_last_sampling_timestep = (t_next == 0) if isinstance(noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0)
        nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        # sample from the posterior
        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return pred, x_start

    # full reverse-diffusion sampling loop, run without gradients
    @torch.no_grad()
    def p_sample_loop(
        self,
        unet,
        shape,
        *,
        noise_scheduler,
        lowres_cond_img = None,
        lowres_noise_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        cond_video_frames = None,
        post_cond_video_frames = None,
        inpaint_images = None,
        inpaint_videos = None,
        inpaint_masks = None,
        inpaint_resample_times = 5,
        init_images = None,
        skip_steps = None,
        cond_scale = 1,
        pred_objective = 'noise',
        dynamic_threshold = True,
        use_tqdm = True
    ):
        device = self.device

        batch = shape[0]
        # start from pure gaussian noise of the requested shape
        img = torch.randn(shape, device = device)

        # video

        # a 5-dimensional shape (batch, channels, frames, height, width) means video
        is_video = len(shape) == 5
        frames = shape[-3] if is_video else None
        # pass the target frame count to resizing functions only when sampling video
        resize_kwargs = dict(target_frames = frames) if exists(frames) else dict()

        # for initialization with an image or video

        if exists(init_images):
            img += init_images

        # keep track of x0, for self conditioning

        x_start = None

        # prepare inpainting

        # inpaint_videos takes precedence over inpaint_images when both are given
        inpaint_images = default(inpaint_videos, inpaint_images)

        has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
        # resample multiple times per timestep only when inpainting
        resample_times = inpaint_resample_times if has_inpainting else 1

        if has_inpainting:
            # normalize and resize the inpainting images, and resize the masks to match
            inpaint_images = self.normalize_img(inpaint_images)
            inpaint_images = self.resize_to(inpaint_images, shape[-1], **resize_kwargs)
            inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1], **resize_kwargs).bool()

        # time

        # pairs of (current, next) times for every sampling step
        timesteps = noise_scheduler.get_sampling_timesteps(batch, device = device)

        # whether to skip any steps

        skip_steps = default(skip_steps, 0)
        timesteps = timesteps[skip_steps:]

        # video conditioning kwargs

        video_kwargs = dict()
        if self.is_video:
            video_kwargs = dict(
                cond_video_frames = cond_video_frames,
                post_cond_video_frames = post_cond_video_frames,
            )

        for times, times_next in tqdm(timesteps, desc = 'sampling loop time step', total = len(timesteps), disable = not use_tqdm):
            is_last_timestep = times_next == 0

            for r in reversed(range(resample_times)):
                is_last_resample_step = r == 0

                # when inpainting, keep the known region fixed by re-noising it to the current time
                if has_inpainting:
                    noised_inpaint_images, *_ = noise_scheduler.q_sample(inpaint_images, t = times)
                    img = img * ~inpaint_masks + noised_inpaint_images * inpaint_masks

                # pass the previous x_start prediction as self conditioning, if enabled
                self_cond = x_start if unet.self_cond else None

                # one reverse diffusion step
                img, x_start = self.p_sample(
                    unet,
                    img,
                    times,
                    t_next = times_next,
                    text_embeds = text_embeds,
                    text_mask = text_mask,
                    cond_images = cond_images,
                    cond_scale = cond_scale,
                    self_cond = self_cond,
                    lowres_cond_img = lowres_cond_img,
                    lowres_noise_times = lowres_noise_times,
                    noise_scheduler = noise_scheduler,
                    pred_objective = pred_objective,
                    dynamic_threshold = dynamic_threshold,
                    **video_kwargs
                )

                # for inpainting resampling, re-noise back to the current time (except on the last resample step or last timestep)
                if has_inpainting and not (is_last_resample_step or torch.all(is_last_timestep)):
                    renoised_img = noise_scheduler.q_sample_from_to(img, times_next, times)

                    img = torch.where(
                        self.right_pad_dims_to_datatype(is_last_timestep),
                        img,
                        renoised_img
                    )

        # clamp pixel values to [-1, 1]
        img.clamp_(-1., 1.)

        # final inpainting

        if has_inpainting:
            img = img * ~inpaint_masks + inpaint_images * inpaint_masks

        # unnormalize back to [0, 1] and return
        unnormalize_img = self.unnormalize_img(img)
        return unnormalize_img

    # 禁用梯度计算
    @torch.no_grad()
    # 设置评估模式装饰器
    @eval_decorator
    # 设置类型检查装饰器
    @beartype
    # 定义一个方法用于生成样本
    def sample(
        self,
        texts: List[str] = None,  # 文本列表,默认为 None
        text_masks = None,  # 文本掩码,默认为 None
        text_embeds = None,  # 文本嵌入,默认为 None
        video_frames = None,  # 视频帧,默认为 None
        cond_images = None,  # 条件图像,默认为 None
        cond_video_frames = None,  # 条件视频帧,默认为 None
        post_cond_video_frames = None,  # 后置条件视频帧,默认为 None
        inpaint_videos = None,  # 修复视频,默认为 None
        inpaint_images = None,  # 修复图像,默认为 None
        inpaint_masks = None,  # 修复掩码,默认为 None
        inpaint_resample_times = 5,  # 修复重采样次数,默认为 5
        init_images = None,  # 初始图像,默认为 None
        skip_steps = None,  # 跳过步骤,默认为 None
        batch_size = 1,  # 批量大小,默认为 1
        cond_scale = 1.,  # 条件比例,默认为 1.0
        lowres_sample_noise_level = None,  # 低分辨率采样噪声级别,默认为 None
        start_at_unet_number = 1,  # 开始于 Unet 编号,默认为 1
        start_image_or_video = None,  # 开始图像或视频,默认为 None
        stop_at_unet_number = None,  # 停止于 Unet 编号,默认为 None
        return_all_unet_outputs = False,  # 返回所有 Unet 输出,默认为 False
        return_pil_images = False,  # 返回 PIL 图像,默认为 False
        device = None,  # 设备,默认为 None
        use_tqdm = True,  # 使用 tqdm,默认为 True
        use_one_unet_in_gpu = True  # 在 GPU 中使用一个 Unet,默认为 True
    # 定义一个方法用于计算损失
    @beartype
    def p_losses(
        self,
        unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel],  # Unet 对象,默认为 None
        x_start,  # 起始值
        times,  # 时间
        *,
        noise_scheduler,  # 噪声调度器
        lowres_cond_img = None,  # 低分辨率条件图像,默认为 None
        lowres_aug_times = None,  # 低分辨率增强次数,默认为 None
        text_embeds = None,  # 文本嵌入,默认为 None
        text_mask = None,  # 文本掩码,默认为 None
        cond_images = None,  # 条件图像,默认为 None
        noise = None,  # 噪声,默认为 None
        times_next = None,  # 下一个时间,默认为 None
        pred_objective = 'noise',  # 预测目标,默认为 'noise'
        min_snr_gamma = None,  # 最小信噪比伽马,默认为 None
        random_crop_size = None,  # 随机裁剪大小,默认为 None
        **kwargs  # 其他关键字参数
    # 定义一个方法用于前向传播
    @beartype
    def forward(
        self,
        images,  # 图像或视频
        unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,  # Unet 对象,默认为 None
        texts: List[str] = None,  # 文本列表,默认为 None
        text_embeds = None,  # 文本嵌入,默认为 None
        text_masks = None,  # 文本掩码,默认为 None
        unet_number = None,  # Unet 编号,默认为 None
        cond_images = None,  # 条件图像,默认为 None
        **kwargs  # 其他关键字参数

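# 使用示意(补充,非原始源码;调用方式基于上方 sample 方法的签名与本仓库 README 的典型用法,仅作示意):
# 训练完成后,可通过 Imagen.sample 按文本生成图像,sample 内部会依次对各级 unet 调用 p_sample_loop
#
# imagen = Imagen(unets = (unet1, unet2), image_sizes = (64, 256), timesteps = 1000)
# images = imagen.sample(
#     texts = ['a photo of a dog chasing a ball'],
#     cond_scale = 3.,            # classifier-free guidance 的条件放大系数
#     return_pil_images = True    # 直接返回 PIL 图像列表
# )
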
.\lucidrains\imagen-pytorch\imagen_pytorch\imagen_video.py

# 导入数学、操作符、函数工具等模块
import math
import operator
import functools
from tqdm.auto import tqdm
from functools import partial, wraps
from pathlib import Path

# 导入 PyTorch 相关模块
import torch
import torch.nn.functional as F
from torch import nn, einsum

# 导入 einops 相关模块
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# 导入自定义模块
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# 辅助函数

# 检查值是否存在
def exists(val):
    return val is not None

# 返回输入值
def identity(t, *args, **kwargs):
    return t

# 返回数组的第一个元素,如果数组为空则返回默认值
def first(arr, d = None):
    if len(arr) == 0:
        return d
    return arr[0]

# 检查一个数是否能被另一个数整除
def divisible_by(numer, denom):
    return (numer % denom) == 0

# 可能执行函数,如果输入值不存在则直接返回
def maybe(fn):
    @wraps(fn)
    def inner(x):
        if not exists(x):
            return x
        return fn(x)
    return inner

# 仅执行一次函数,用于打印信息
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# 仅打印一次信息
print_once = once(print)

# 返回默认值或默认函数的值
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# 将输入值转换为元组
def cast_tuple(val, length = None):
    if isinstance(val, list):
        val = tuple(val)

    output = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length):
        assert len(output) == length

    return output

# 将 uint8 类型的图像转换为 float 类型
def cast_uint8_images_to_float(images):
    if not images.dtype == torch.uint8:
        return images
    return images / 255

# 获取模块的设备信息
def module_device(module):
    return next(module.parameters()).device

# 初始化权重为零
def zero_init_(m):
    nn.init.zeros_(m.weight)
    if exists(m.bias):
        nn.init.zeros_(m.bias)

# 模型评估装饰器
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# 将元组填充到指定长度
def pad_tuple_to_length(t, length, fillvalue = None):
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

# 辅助类

# 简单的返回输入值的模块
class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

# 创建序列模块
def Sequential(*modules):
    return nn.Sequential(*filter(exists, modules))

# 张量辅助函数

# 对数函数
def log(t, eps: float = 1e-12):
    return torch.log(t.clamp(min = eps))

# L2 归一化
def l2norm(t):
    return F.normalize(t, dim = -1)

# 将右侧维度填充到相同维度
def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))

# 带掩码的均值计算
def masked_mean(t, *, dim, mask = None):
    if not exists(mask):
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)

    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)

# 调整视频大小
def resize_video_to(
    video,
    target_image_size,
    target_frames = None,
    clamp_range = None,
    mode = 'nearest'
):
    orig_video_size = video.shape[-1]

    frames = video.shape[2]
    target_frames = default(target_frames, frames)

    target_shape = (target_frames, target_image_size, target_image_size)

    if tuple(video.shape[-3:]) == target_shape:
        return video

    out = F.interpolate(video, target_shape, mode = mode)

    if exists(clamp_range):
        out = out.clamp(*clamp_range)
        
    return out

# 缩放视频时间
def scale_video_time(
    video,
    downsample_scale = 1,
    mode = 'nearest'
):
    if downsample_scale == 1:
        return video

    image_size, frames = video.shape[-1], video.shape[-3]
    assert divisible_by(frames, downsample_scale), f'trying to temporally downsample a conditioning video frames of length {frames} by {downsample_scale}, however it is not neatly divisible'

    target_frames = frames // downsample_scale
    # 调用 resize_video_to 函数,将视频调整大小为指定尺寸
    resized_video = resize_video_to(
        video,  # 原始视频
        image_size,  # 目标图像尺寸
        target_frames = target_frames,  # 目标帧数
        mode = mode  # 调整模式
    )

    # 返回调整大小后的视频
    return resized_video
# classifier free guidance functions

# 根据给定形状、概率和设备创建一个布尔类型的掩码
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
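
# 使用示意(补充,非原始源码):prob_mask_like 用于 classifier-free guidance 训练时随机丢弃条件。
# 例如以 0.1 的概率丢弃文本条件,得到形状为 (batch,) 的布尔"保留"掩码,
# 再用 torch.where 把被丢弃样本的文本嵌入替换为可学习的空(null)嵌入(下面的 text_embeds / null_text_embed 仅为示意名称)
#
# keep_mask = prob_mask_like((4,), 1 - 0.1, device = torch.device('cpu'))
# text_embeds = torch.where(rearrange(keep_mask, 'b -> b 1 1'), text_embeds, null_text_embed)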

# norms and residuals

# Layer normalization模块
class LayerNorm(nn.Module):
    def __init__(self, dim, stable=False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim=-1, keepdim=True).detach()

        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=-1, keepdim=True)
        return (x - mean) * (var + eps).rsqrt() * self.g

# 通道层规范化模块
class ChanLayerNorm(nn.Module):
    def __init__(self, dim, stable=False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim=1, keepdim=True).detach()

        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) * (var + eps).rsqrt() * self.g

# 始终返回相同值的类
class Always():
    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val

# 残差连接模块
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# 并行执行多个函数模块
class Parallel(nn.Module):
    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        outputs = [fn(x) for fn in self.fns]
        return sum(outputs)

# rearranging

# 时间为中心的重排模块
class RearrangeTimeCentric(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        x = rearrange(x, 'b c f ... -> b ... f c')
        x, ps = pack([x], '* f c')

        x = self.fn(x)

        x, = unpack(x, ps, '* f c')
        x = rearrange(x, 'b ... f c -> b c f ...')
        return x

# attention pooling

# PerceiverAttention模块
class PerceiverAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        scale=8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias=False),
            nn.LayerNorm(dim)
        )
    # 前向传播函数,接收输入 x、潜在变量 latents 和可选的 mask
    def forward(self, x, latents, mask = None):
        # 对输入 x 进行归一化处理
        x = self.norm(x)
        # 对潜在变量 latents 进行归一化处理
        latents = self.norm_latents(latents)

        # 获取输入 x 的 batch 大小和头数
        b, h = x.shape[0], self.heads

        # 生成查询向量 q
        q = self.to_q(latents)

        # 将输入 x 和潜在变量 latents 连接起来,作为键值对的输入
        kv_input = torch.cat((x, latents), dim = -2)
        # 将连接后的输入转换为键和值
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        # 对查询、键、值进行维度重排
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # 对查询和键进行 L2 归一化
        q, k = map(l2norm, (q, k))
        # 对查询和键进行缩放
        q = q * self.q_scale
        k = k * self.k_scale

        # 计算相似度矩阵
        sim = einsum('... i d, ... j d  -> ... i j', q, k) * self.scale

        # 如果存在 mask,则进行填充和掩码处理
        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # 计算注意力权重
        attn = sim.softmax(dim = -1)

        # 计算输出
        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        # 返回输出结果
        return self.to_out(out)
# 定义 PerceiverResampler 类,继承自 nn.Module
class PerceiverResampler(nn.Module):
    # 初始化函数
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # 从序列的均值池化表示派生的潜在变量数量
        max_seq_len = 512,
        ff_mult = 4
    ):
        super().__init__()
        # 创建位置嵌入层
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # 初始化潜在变量
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        # 如果均值池化的潜在变量数量大于0,则创建相应的层
        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        # 创建多层感知器
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    # 前向传播函数
    def forward(self, x, mask = None):
        n, device = x.shape[1], x.device
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        x_with_pos = x + pos_emb

        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])

        # 如果存在均值池化的潜在变量,则将其与原始潜在变量拼接
        if exists(self.to_latents_from_mean_pooled_seq):
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        # 遍历每一层的注意力机制和前馈网络
        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents
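
# 使用示意(补充,非原始源码):PerceiverResampler 将变长的文本编码序列重采样为固定数量的 latent token,
# 输出长度为 num_latents + num_latents_mean_pooled,与输入序列长度无关
#
# resampler = PerceiverResampler(dim = 512, depth = 2, num_latents = 64, num_latents_mean_pooled = 4)
# text_tokens = torch.randn(2, 77, 512)     # (batch, seq_len, dim) 的文本编码
# latents = resampler(text_tokens)          # -> (2, 64 + 4, 512)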

# 定义 Conv3d 类,继承自 nn.Module
class Conv3d(nn.Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        dim_out = None,
        kernel_size = 3,
        *,
        temporal_kernel_size = None,
        **kwargs
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        temporal_kernel_size = default(temporal_kernel_size, kernel_size)

        # 创建空间卷积层
        self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)
        # 创建时间卷积层(如果 kernel_size 大于1)
        self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size) if kernel_size > 1 else None
        self.kernel_size = kernel_size

        # 初始化时间卷积层的权重为单位矩阵
        if exists(self.temporal_conv):
            nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
            nn.init.zeros_(self.temporal_conv.bias.data)

    # 前向传播函数
    def forward(
        self,
        x,
        ignore_time = False
    ):
        b, c, *_, h, w = x.shape

        is_video = x.ndim == 5
        ignore_time &= is_video

        if is_video:
            x = rearrange(x, 'b c f h w -> (b f) c h w')

        x = self.spatial_conv(x)

        if is_video:
            x = rearrange(x, '(b f) c h w -> b c f h w', b = b)

        if ignore_time or not exists(self.temporal_conv):
            return x

        x = rearrange(x, 'b c f h w -> (b h w) c f')

        # 因果时间卷积 - 时间在 imagen-video 中是因果的

        if self.kernel_size > 1:
            x = F.pad(x, (self.kernel_size - 1, 0))

        x = self.temporal_conv(x)

        x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)

        return x
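
# 使用示意(补充,非原始源码):Conv3d 实际是"伪 3D"卷积——先对每一帧做 2D 空间卷积,
# 再对每个空间位置沿时间维做因果 1D 卷积(只在左侧 padding,保证当前帧只能看到过去的帧)
#
# conv = Conv3d(dim = 32, dim_out = 64, kernel_size = 3)
# video = torch.randn(1, 32, 8, 16, 16)     # (batch, channels, frames, height, width)
# out = conv(video)                          # -> (1, 64, 8, 16, 16),时空尺寸保持不变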

# 定义 Attention 类,继承自 nn.Module
class Attention(nn.Module):
    # 初始化函数
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        causal = False,
        context_dim = None,
        rel_pos_bias = False,
        rel_pos_bias_mlp_depth = 2,
        init_zero = False,
        scale = 8
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 设置缩放因子和是否因果的标志
        self.scale = scale
        self.causal = causal

        # 如果启用相对位置偏置,则创建动态位置偏置对象
        self.rel_pos_bias = DynamicPositionBias(dim = dim, heads = heads, depth = rel_pos_bias_mlp_depth) if rel_pos_bias else None

        # 初始化头数和内部维度
        self.heads = heads
        inner_dim = dim_head * heads

        # 初始化 LayerNorm
        self.norm = LayerNorm(dim)

        # 初始化空注意力偏置和空键值对
        self.null_attn_bias = nn.Parameter(torch.randn(heads))
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        # 初始化缩放参数
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # 如果存在上下文维度,则初始化上下文处理层
        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        # 初始化输出层
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

        # 如果初始化为零,则将输出层的偏置初始化为零
        if init_zero:
            nn.init.zeros_(self.to_out[-1].g)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None
    ):
        # 获取输入张量的形状和设备信息
        b, n, device = *x.shape[:2], x.device

        # 对输入张量进行 LayerNorm 处理
        x = self.norm(x)
        # 分别计算查询、键、值
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        # 将查询张量重排为多头形式
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # 添加空键/值,用于先验网络中的无分类器引导 (classifier free guidance)
        nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # 如果存在上下文,则添加文本条件
        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # 对查询、键进行 L2 归一化
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # 计算查询/键的相似性
        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale

        # 相对位置编码(T5 风格)
        if not exists(attn_bias) and exists(self.rel_pos_bias):
            attn_bias = self.rel_pos_bias(n, device = device, dtype = q.dtype)

        if exists(attn_bias):
            null_attn_bias = repeat(self.null_attn_bias, 'h -> h n 1', n = n)
            attn_bias = torch.cat((null_attn_bias, attn_bias), dim = -1)
            sim = sim + attn_bias

        # 掩码
        max_neg_value = -torch.finfo(sim.dtype).max

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, max_neg_value)

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # 注意力
        attn = sim.softmax(dim = -1)

        # 聚合值
        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# 定义一个伪 Conv2d 函数,使用 Conv3d 但在帧维度上使用大小为1的卷积核
def Conv2d(dim_in, dim_out, kernel, stride = 1, padding = 0, **kwargs):
    # 将 kernel 转换为元组
    kernel = cast_tuple(kernel, 2)
    # 将 stride 转换为元组
    stride = cast_tuple(stride, 2)
    # 将 padding 转换为元组
    padding = cast_tuple(padding, 2)

    # 如果 kernel 的长度为2,则在前面添加1
    if len(kernel) == 2:
        kernel = (1, *kernel)

    # 如果 stride 的长度为2,则在前面添加1
    if len(stride) == 2:
        stride = (1, *stride)

    # 如果 padding 的长度为2,则在前面添加0
    if len(padding) == 2:
        padding = (0, *padding)

    # 返回一个 Conv3d 对象
    return nn.Conv3d(dim_in, dim_out, kernel, stride = stride, padding = padding, **kwargs)
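
# 补充说明(非原始源码):这里的 Conv2d 返回的其实是 nn.Conv3d,只是把卷积核/步长/填充
# 在帧(时间)维上固定为 1/1/0,因此它作用在视频张量 (b, c, f, h, w) 上等价于逐帧的 2D 卷积
#
# conv = Conv2d(16, 32, 3, padding = 1)
# video = torch.randn(2, 16, 4, 8, 8)
# assert conv(video).shape == (2, 32, 4, 8, 8)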

# 定义一个 Pad 类
class Pad(nn.Module):
    def __init__(self, padding, value = 0.):
        super().__init__()
        self.padding = padding
        self.value = value

    # 前向传播函数
    def forward(self, x):
        return F.pad(x, self.padding, value = self.value)

# 定义一个 Upsample 函数
def Upsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)

    # 返回一个包含 Upsample 和 Conv2d 的序列
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        Conv2d(dim, dim_out, 3, padding = 1)
    )

# 定义一个 PixelShuffleUpsample 类
class PixelShuffleUpsample(nn.Module):
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU()
        )

        self.pixel_shuffle = nn.PixelShuffle(2)

        self.init_conv_(conv)

    # 初始化卷积层的权重
    def init_conv_(self, conv):
        o, i, f, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, f, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    # 前向传播函数
    def forward(self, x):
        out = self.net(x)
        frames = x.shape[2]
        out = rearrange(out, 'b c f h w -> (b f) c h w')
        out = self.pixel_shuffle(out)
        return rearrange(out, '(b f) c h w -> b c f h w', f = frames)

# 定义一个 Downsample 函数
def Downsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)
    return nn.Sequential(
        Rearrange('b c f (h p1) (w p2) -> b (c p1 p2) f h w', p1 = 2, p2 = 2),
        Conv2d(dim * 4, dim_out, 1)
    )

# 定义一个 TemporalPixelShuffleUpsample 类
class TemporalPixelShuffleUpsample(nn.Module):
    def __init__(self, dim, dim_out = None, stride = 2):
        super().__init__()
        self.stride = stride
        dim_out = default(dim_out, dim)
        conv = nn.Conv1d(dim, dim_out * stride, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU()
        )

        self.pixel_shuffle = Rearrange('b (c r) n -> b c (n r)', r = stride)

        self.init_conv_(conv)

    # 初始化卷积层的权重
    def init_conv_(self, conv):
        o, i, f = conv.weight.shape
        conv_weight = torch.empty(o // self.stride, i, f)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = self.stride)

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    # 前向传播函数
    def forward(self, x):
        b, c, f, h, w = x.shape
        x = rearrange(x, 'b c f h w -> (b h w) c f')
        out = self.net(x)
        out = self.pixel_shuffle(out)
        return rearrange(out, '(b h w) c f -> b c f h w', h = h, w = w)

# 定义一个 TemporalDownsample 函数
def TemporalDownsample(dim, dim_out = None, stride = 2):
    dim_out = default(dim_out, dim)
    return nn.Sequential(
        Rearrange('b c (f p) h w -> b (c p) f h w', p = stride),
        Conv2d(dim * stride, dim_out, 1)
    )

# 定义一个 SinusoidalPosEmb 类
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    # 前向传播函数
    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim = -1)
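
# 使用示意(补充,非原始源码):SinusoidalPosEmb 为一个 batch 的标量时间步生成固定的正弦/余弦位置编码
#
# pos_emb = SinusoidalPosEmb(dim = 16)
# t = torch.tensor([0., 1., 2.])      # (batch,) 的时间步
# assert pos_emb(t).shape == (3, 16)  # 前一半维度为 sin,后一半为 cos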

# 定义一个 LearnedSinusoidalPosEmb 类
class LearnedSinusoidalPosEmb(nn.Module):
    # 初始化函数,接受维度参数
    def __init__(self, dim):
        # 调用父类的初始化函数
        super().__init__()
        # 断言维度为偶数
        assert (dim % 2) == 0
        # 计算维度的一半
        half_dim = dim // 2
        # 初始化权重参数为服从标准正态分布的张量
        self.weights = nn.Parameter(torch.randn(half_dim))

    # 前向传播函数,接受输入张量 x
    def forward(self, x):
        # 重新排列输入张量 x 的维度,增加一个维度
        x = rearrange(x, 'b -> b 1')
        # 计算频率,乘以权重参数和 2π
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        # 将正弦和余弦值拼接在一起
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        # 将输入张量 x 和频率值拼接在一起
        fouriered = torch.cat((x, fouriered), dim = -1)
        # 返回拼接后的张量
        return fouriered
class Block(nn.Module):
    # 定义一个块模块,包含归一化、激活函数和卷积操作
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        # 初始化 GroupNorm 归一化层,如果不需要归一化则使用 Identity 函数
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
        # 初始化激活函数为 SiLU
        self.activation = nn.SiLU()
        # 初始化卷积操作,输出维度为 dim_out,卷积核大小为 3,填充为 1
        self.project = Conv3d(dim, dim_out, 3, padding = 1)

    # 前向传播函数,对输入进行归一化、缩放平移、激活和卷积操作
    def forward(
        self,
        x,
        scale_shift = None,
        ignore_time = False
    ):
        # 对输入进行归一化
        x = self.groupnorm(x)

        # 如果有缩放平移参数,则对输入进行缩放平移操作
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        # 对归一化后的输入进行激活函数操作
        x = self.activation(x)
        # 返回卷积操作后的结果
        return self.project(x, ignore_time = ignore_time)

class ResnetBlock(nn.Module):
    # 定义一个 ResNet 块模块,包含时间 MLP、交叉注意力、块模块和全局上下文注意力
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_mlp = None

        # 如果存在时间条件维度,则初始化时间 MLP
        if exists(time_cond_dim):
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.cross_attn = None

        # 如果存在条件维度,则初始化交叉注意力模块
        if exists(cond_dim):
            attn_klass = CrossAttention if not linear_attn else LinearCrossAttention

            self.cross_attn = attn_klass(
                dim = dim_out,
                context_dim = cond_dim,
                **attn_kwargs
            )

        # 初始化两个块模块
        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        # 如果使用全局上下文注意力,则初始化全局上下文模块
        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        # 如果输入维度不等于输出维度,则初始化卷积操作
        self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()


    # 前向传播函数,包括时间 MLP、交叉注意力、块模块和全局上下文注意力的操作
    def forward(
        self,
        x,
        time_emb = None,
        cond = None,
        ignore_time = False
    ):

        scale_shift = None
        # 如果存在时间 MLP 和时间嵌入,则进行时间 MLP 操作
        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        # 第一个块模块操作
        h = self.block1(x, ignore_time = ignore_time)

        # 如果存在交叉注意力模块,则进行交叉注意力操作
        if exists(self.cross_attn):
            assert exists(cond)
            h = rearrange(h, 'b c ... -> b ... c')
            h, ps = pack([h], 'b * c')

            h = self.cross_attn(h, context = cond) + h

            h, = unpack(h, ps, 'b * c')
            h = rearrange(h, 'b ... c -> b c ...')

        # 第二个块模块操作
        h = self.block2(h, scale_shift = scale_shift, ignore_time = ignore_time)

        # 全局上下文注意力操作
        h = h * self.gca(h)

        # 返回结果加上残差连接
        return h + self.res_conv(x)

class CrossAttention(nn.Module):
    # 定义交叉注意力模块,包含查询、键值映射和输出映射
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        # 初始化 LayerNorm 归一化层
        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        # 初始化查询映射和键值映射
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # 初始化输出映射
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )
    # 定义前向传播函数,接受输入 x、上下文 context 和可选的掩码 mask
    def forward(self, x, context, mask = None):
        # 获取输入 x 的形状信息,包括 batch 大小 b、序列长度 n、设备信息 device
        b, n, device = *x.shape[:2], x.device

        # 对输入 x 和上下文 context 进行归一化处理
        x = self.norm(x)
        context = self.norm_context(context)

        # 将输入 x 转换为查询 q,上下文 context 转换为键 k 和值 v
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        # 将查询 q、键 k 和值 v 重排为多头注意力的形式
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # 为先验网络添加空键/值,用于无分类器引导 (classifier free guidance)
        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # 对查询 q 和键 k 进行 L2 归一化处理
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # 计算相似度矩阵
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # 掩码处理
        max_neg_value = -torch.finfo(sim.dtype).max
        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # 对相似度矩阵进行 softmax 操作,得到注意力权重
        attn = sim.softmax(dim = -1, dtype = torch.float32)

        # 根据注意力权重计算输出
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        # 返回输出结果
        return self.to_out(out)
class LinearCrossAttention(CrossAttention):
    # 线性交叉注意力类,继承自CrossAttention类
    def forward(self, x, context, mask = None):
        # 前向传播函数,接受输入x、上下文context和掩码mask,默认为None
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        # 对输入x进行规范化
        context = self.norm_context(context)
        # 对上下文context进行规范化

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        # 将输入x和上下文context转换为查询q、键k和值v

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))
        # 重排查询q、键k和值v的维度

        # add null key / value for classifier free guidance in prior net
        # 为先验网络添加空键/值,用于无分类器引导 (classifier free guidance)

        nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking
        # 掩码处理

        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b n -> b n 1')
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention
        # 线性注意力

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)

class LinearAttention(nn.Module):
    # 线性注意力类,继承自nn.Module类
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        # 前向传播函数,接受特征图fmap和上下文context,默认为None
        h, x, y = self.heads, *fmap.shape[-2:]

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (ck, cv))
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)
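
# 补充说明(非原始源码):LinearAttention / LinearCrossAttention 通过"核技巧"把注意力复杂度从 O(n^2) 降为 O(n):
# 先对 q 沿特征维、对 k 沿序列维分别做 softmax,然后先聚合 context = k^T v(d x e 的小矩阵),
# 再计算 out = q · context,从而避免显式构造 n x n 的注意力矩阵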

class GlobalContext(nn.Module):
    # 全局上下文类,继承自nn.Module类
    """ basically a superior form of squeeze-excitation that is attention-esque """
    # 本质上是 squeeze-excitation(挤压-激励)的增强版本,带有类似注意力的全局加权机制

    def __init__(
        self,
        *,
        dim_in,
        dim_out
        # 初始化函数,接受输入维度dim_in和输出维度dim_out
    # 定义一个继承自 nn.Module 的类,用于实现一个自定义的注意力机制模块
    ):
        # 调用父类的构造函数
        super().__init__()
        # 定义一个将输入特征维度转换为 K 维度的卷积层
        self.to_k = Conv2d(dim_in, 1, 1)
        # 计算隐藏层维度,取最大值为 3 或者输出维度的一半
        hidden_dim = max(3, dim_out // 2)

        # 定义一个神经网络序列,包含卷积层、激活函数和输出层
        self.net = nn.Sequential(
            Conv2d(dim_in, hidden_dim, 1),
            nn.SiLU(),  # 使用 SiLU 激活函数
            Conv2d(hidden_dim, dim_out, 1),
            nn.Sigmoid()  # 使用 Sigmoid 激活函数
        )

    # 定义前向传播函数
    def forward(self, x):
        # 将输入 x 经过 to_k 卷积层得到 context
        context = self.to_k(x)
        # 对输入 x 和 context 进行维度重排
        x, context = map(lambda t: rearrange(t, 'b n ... -> b n (...)'), (x, context))
        # 使用 einsum 计算注意力权重并与输入 x 相乘
        out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
        # 对输出 out 进行维度重排
        out = rearrange(out, '... -> ... 1 1')
        # 将处理后的 out 输入到神经网络序列中得到最终输出
        return self.net(out)
# 定义一个前馈神经网络模块,包含层归一化、线性层、GELU激活函数和线性层
def FeedForward(dim, mult = 2):
    # 计算隐藏层维度
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        LayerNorm(dim),  # 层归一化
        nn.Linear(dim, hidden_dim, bias = False),  # 线性层
        nn.GELU(),  # GELU激活函数
        LayerNorm(hidden_dim),  # 层归一化
        nn.Linear(hidden_dim, dim, bias = False)  # 线性层
    )

# 定义一个时间标记位移模块
class TimeTokenShift(nn.Module):
    def forward(self, x):
        if x.ndim != 5:
            return x

        x, x_shift = x.chunk(2, dim = 1)  # 将输入张量按维度1分块
        x_shift = F.pad(x_shift, (0, 0, 0, 0, 1, -1), value = 0.)  # 对x_shift进行填充
        return torch.cat((x, x_shift), dim = 1)  # 在维度1上连接张量x和x_shift
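
# 使用示意(补充,非原始源码):TimeTokenShift 把一半通道沿帧(时间)维向后平移一帧,
# 让前馈层在不增加参数的情况下混合相邻帧的信息(token shift 技巧);非视频输入原样返回
#
# shift = TimeTokenShift()
# video_feats = torch.randn(1, 4, 3, 2, 2)                      # (batch, channels, frames, h, w)
# assert shift(video_feats).shape == video_feats.shape
# assert shift(torch.randn(1, 4, 2, 2)).shape == (1, 4, 2, 2)   # 4 维(图像)输入不做处理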

# 定义一个通道前馈神经网络模块
def ChanFeedForward(dim, mult = 2, time_token_shift = True):
    # 计算隐藏层维度
    hidden_dim = int(dim * mult)
    return Sequential(
        ChanLayerNorm(dim),  # 通道层归一化
        Conv2d(dim, hidden_dim, 1, bias = False),  # 二维卷积层
        nn.GELU(),  # GELU激活函数
        TimeTokenShift() if time_token_shift else None,  # 时间标记位移模块
        ChanLayerNorm(hidden_dim),  # 通道层归一化
        Conv2d(hidden_dim, dim, 1, bias = False)  # 二维卷积层
    )

# 定义一个Transformer块模块
class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        ff_time_token_shift = True,
        context_dim = None
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),  # 注意力机制
                ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)  # 通道前馈神经网络
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = rearrange(x, 'b c ... -> b ... c')  # 重新排列张量维度
            x, ps = pack([x], 'b * c')  # 打包张量

            x = attn(x, context = context) + x  # 注意力机制处理后与原始张量相加

            x, = unpack(x, ps, 'b * c')  # 解包张量
            x = rearrange(x, 'b ... c -> b c ...')  # 重新排列张量维度

            x = ff(x) + x  # 通道前馈神经网络处理后与原始张量相加
        return x

# 定义一个线性注意力Transformer块模块
class LinearAttentionTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        ff_time_token_shift = True,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),  # 线性注意力机制
                ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)  # 通道前馈神经网络
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x  # 线性注意力机制处理后与原始张量相加
            x = ff(x) + x  # 通道前馈神经网络处理后与原始张量相加
        return x

# 定义一个交叉嵌入层模块
class CrossEmbedLayer(nn.Module):
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # 计算每个尺度的维度
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        fmaps = tuple(map(lambda conv: conv(x), self.convs))  # 对输入张量进行卷积操作
        return torch.cat(fmaps, dim = 1)  # 在维度1上连接卷积结果
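
# 补充说明(非原始源码):CrossEmbedLayer 用多个不同大小的卷积核并行卷积,再在通道维拼接。
# 每个尺度分到的通道数按 2 的幂递减,最后一个尺度补齐余数;注意卷积核大小与 stride 的奇偶性需一致(见上方 assert)。
# 例如 dim_out = 128、三个卷积核时:dim_scales = [128 // 2, 128 // 4] + [128 - 96] = [64, 32, 32],拼接后恰为 128 通道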

# 定义一个上采样合并器模块
class UpsampleCombiner(nn.Module):
    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    # 初始化函数,设置输出维度和是否启用
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 将输出维度转换为元组,长度与输入维度相同
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        # 断言输入维度和输出维度长度相同
        assert len(dim_ins) == len(dim_outs)

        # 设置是否启用标志
        self.enabled = enabled

        # 如果未启用,则直接设置输出维度并返回
        if not self.enabled:
            self.dim_out = dim
            return

        # 根据输入维度和输出维度创建模块列表
        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
        # 计算最终输出维度
        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)

    # 前向传播函数,处理输入数据和特征图
    def forward(self, x, fmaps = None):
        # 获取输入数据的目标尺寸
        target_size = x.shape[-1]

        # 设置特征图为默认值空元组
        fmaps = default(fmaps, tuple())

        # 如果未启用或特征图为空或卷积模块为空,则直接返回输入数据
        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        # 将特征图调整为目标尺寸
        fmaps = [resize_video_to(fmap, target_size) for fmap in fmaps]
        # 对每个特征图应用对应的卷积模块
        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
        # 拼接输入数据和卷积结果,沿指定维度拼接
        return torch.cat((x, *outs), dim = 1)
# 定义一个动态位置偏置的神经网络模块
class DynamicPositionBias(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads,
        depth
    ):
        super().__init__()
        self.mlp = nn.ModuleList([])

        # 添加一个线性层、LayerNorm 和 SiLU 激活函数到 MLP 中
        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            LayerNorm(dim),
            nn.SiLU()
        ))

        # 根据深度添加多个线性层、LayerNorm 和 SiLU 激活函数到 MLP 中
        for _ in range(max(depth - 1, 0)):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                LayerNorm(dim),
                nn.SiLU()
            ))

        # 添加一个线性层到 MLP 中
        self.mlp.append(nn.Linear(dim, heads))

    # 前向传播函数
    def forward(self, n, device, dtype):
        # 创建张量 i 和 j
        i = torch.arange(n, device = device)
        j = torch.arange(n, device = device)

        # 计算位置索引
        indices = rearrange(i, 'i -> i 1') - rearrange(j, 'j -> 1 j')
        indices += (n - 1)

        # 创建位置张量
        pos = torch.arange(-n + 1, n, device = device, dtype = dtype)
        pos = rearrange(pos, '... -> ... 1')

        # 遍历 MLP 中的每一层
        for layer in self.mlp:
            pos = layer(pos)

        # 计算位置偏置
        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias

# 定义一个 3D UNet 神经网络模块
class Unet3D(nn.Module):
    def __init__(
        self,
        *,
        dim,
        text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
        num_resnet_blocks = 1,
        cond_dim = None,
        num_image_tokens = 4,
        num_time_tokens = 2,
        learned_sinu_pos_emb_dim = 16,
        out_dim = None,
        dim_mults = (1, 2, 4, 8),
        temporal_strides = 1,
        cond_images_channels = 0,
        channels = 3,
        channels_out = None,
        attn_dim_head = 64,
        attn_heads = 8,
        ff_mult = 2.,
        ff_time_token_shift = True,         # 在 feedforwards 的隐藏层中沿时间轴进行令牌移位
        lowres_cond = False,                # 用于级联扩散
        layer_attns = False,
        layer_attns_depth = 1,
        layer_attns_add_text_cond = True,   # 是否在自注意力块中加入文本嵌入
        attend_at_middle = True,            # 是否在瓶颈处进行一层注意力
        time_rel_pos_bias_depth = 2,
        time_causal_attn = True,
        layer_cross_attns = True,
        use_linear_attn = False,
        use_linear_cross_attn = False,
        cond_on_text = True,
        max_text_len = 256,
        init_dim = None,
        resnet_groups = 8,
        init_conv_kernel_size = 7,          # 初始卷积的内核大小
        init_cross_embed = True,
        init_cross_embed_kernel_sizes = (3, 7, 15),
        cross_embed_downsample = False,
        cross_embed_downsample_kernel_sizes = (2, 4),
        attn_pool_text = True,
        attn_pool_num_latents = 32,
        dropout = 0.,
        memory_efficient = False,
        init_conv_to_final_conv_residual = False,
        use_global_context_attn = True,
        scale_skip_connection = True,
        final_resnet_block = True,
        final_conv_kernel_size = 3,
        self_cond = False,
        combine_upsample_fmaps = False,      # 在所有上采样块中合并特征图
        pixel_shuffle_upsample = True,       # 可能解决棋盘伪影
        resize_mode = 'nearest'
    # 如果当前 UNet 的设置不正确,则重新初始化 UNet
    def cast_model_parameters(
        self,
        *,
        lowres_cond,
        text_embed_dim,
        channels,
        channels_out,
        cond_on_text
    # 如果当前对象的属性与传入参数相同,则直接返回当前对象
    ):
        if lowres_cond == self.lowres_cond and \
            channels == self.channels and \
            cond_on_text == self.cond_on_text and \
            text_embed_dim == self._locals['text_embed_dim'] and \
            channels_out == self.channels_out:
            return self

        # 更新参数字典
        updated_kwargs = dict(
            lowres_cond = lowres_cond,
            text_embed_dim = text_embed_dim,
            channels = channels,
            channels_out = channels_out,
            cond_on_text = cond_on_text
        )

        # 返回一个新的类实例,使用当前对象的属性和更新后的参数
        return self.__class__(**{**self._locals, **updated_kwargs})

    # 返回完整的unet配置及其参数状态字典的方法

    def to_config_and_state_dict(self):
        return self._locals, self.state_dict()

    # 从配置和状态字典中重新创建unet的类方法

    @classmethod
    def from_config_and_state_dict(klass, config, state_dict):
        unet = klass(**config)
        unet.load_state_dict(state_dict)
        return unet

    # 将unet持久化到磁盘的方法

    def persist_to_file(self, path):
        path = Path(path)
        path.parents[0].mkdir(exist_ok = True, parents = True)

        config, state_dict = self.to_config_and_state_dict()
        pkg = dict(config = config, state_dict = state_dict)
        torch.save(pkg, str(path))

    # 从使用`persist_to_file`保存的文件中重新创建unet的类方法

    @classmethod
    def hydrate_from_file(klass, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))

        assert 'config' in pkg and 'state_dict' in pkg
        config, state_dict = pkg['config'], pkg['state_dict']

        return klass.from_config_and_state_dict(config, state_dict)

    # 带有分类器自由引导的前向传播

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        logits = self.forward(*args, **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale
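
    # 补充说明(非原始源码):上面即 classifier-free guidance 的线性外推公式
    # output = null_logits + (cond_logits - null_logits) * cond_scale
    # cond_scale = 1 时等价于普通的条件预测;cond_scale > 1 时放大文本条件对输出的影响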

    def forward(
        self,
        x,
        time,
        *,
        lowres_cond_img = None,
        lowres_noise_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        cond_video_frames = None,
        post_cond_video_frames = None,
        self_cond = None,
        cond_drop_prob = 0.,
        ignore_time = False

.\lucidrains\imagen-pytorch\imagen_pytorch\t5.py

# 导入 torch 库
import torch
# 导入 transformers 库
import transformers
# 导入 List 类型
from typing import List
# 从 transformers 库中导入 T5Tokenizer, T5EncoderModel, T5Config
from transformers import T5Tokenizer, T5EncoderModel, T5Config
# 从 einops 库中导入 rearrange 函数
from einops import rearrange

# 设置 transformers 库的日志级别为 error
transformers.logging.set_verbosity_error()

# 定义函数,判断值是否存在
def exists(val):
    return val is not None

# 定义函数,返回默认值
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# 配置

# 定义最大长度为 256
MAX_LENGTH = 256

# 默认的 T5 模型名称
DEFAULT_T5_NAME = 'google/t5-v1_1-base'

# T5 配置字典
T5_CONFIGS = {}

# 全局单例变量

# 获取指定名称的 tokenizer
def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name, model_max_length=MAX_LENGTH)
    return tokenizer

# 获取指定名称的模型
def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model

# 获取指定名称的模型和 tokenizer
def get_model_and_tokenizer(name):
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']

# 获取编码维度
def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # 避免仅获取维度时加载模型
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        assert False
    return config.d_model

# 编码文本

# 对文本进行分词
def t5_tokenize(
    texts: List[str],
    name = DEFAULT_T5_NAME
):
    t5, tokenizer = get_model_and_tokenizer(name)

    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)
    return input_ids, attn_mask

# 对分词后的文本进行编码
def t5_encode_tokenized_text(
    token_ids,
    attn_mask = None,
    pad_id = None,
    name = DEFAULT_T5_NAME
):
    assert exists(attn_mask) or exists(pad_id)
    t5, _ = get_model_and_tokenizer(name)

    attn_mask = default(attn_mask, lambda: (token_ids != pad_id).long())

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = token_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    attn_mask = attn_mask.bool()

    encoded_text = encoded_text.masked_fill(~rearrange(attn_mask, '... -> ... 1'), 0.) # 强制所有填充的嵌入为 0
    return encoded_text

# 对文本进行编码
def t5_encode_text(
    texts: List[str],
    name = DEFAULT_T5_NAME,
    return_attn_mask = False
):
    token_ids, attn_mask = t5_tokenize(texts, name = name)
    encoded_text = t5_encode_tokenized_text(token_ids, attn_mask = attn_mask, name = name)

    if return_attn_mask:
        attn_mask = attn_mask.bool()
        return encoded_text, attn_mask

    return encoded_text
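
# 使用示意(补充,非原始源码):t5_encode_text 返回形状为 (batch, seq_len, d_model) 的文本编码,
# 填充位置的嵌入已被置零;需要注意力掩码时可传 return_attn_mask = True(需要联网下载 T5 权重,此处仅作示意)
#
# embeds, mask = t5_encode_text(['a red square', 'a photo of a cat'], return_attn_mask = True)
# print(embeds.shape)   # 例如 (2, seq_len, 768),768 为 google/t5-v1_1-base 的 d_model
# print(mask.shape)     # (2, seq_len) 的布尔掩码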

.\lucidrains\imagen-pytorch\imagen_pytorch\test\test_trainer.py

# 从 imagen_pytorch 包中导入 ImagenTrainer 类
# 从 imagen_pytorch 包中导入 ImagenConfig 类
# 从 imagen_pytorch 包中导入 t5_encode_text 函数
# 从 torch.utils.data 包中导入 Dataset 类
# 导入 torch 库
from imagen_pytorch.trainer import ImagenTrainer
from imagen_pytorch.configs import ImagenConfig
from imagen_pytorch.t5 import t5_encode_text
from torch.utils.data import Dataset
import torch

# 定义一个测试函数,用于测试 ImagenTrainer 类的实例化
def test_trainer_instantiation():
    # 定义 unet1 字典,包含模型的参数配置
    unet1 = dict(
        dim = 8,
        dim_mults = (1, 1, 1, 1),
        num_resnet_blocks = 1,
        layer_attns = False,
        layer_cross_attns = False,
        attn_heads = 2
    )

    # 创建 ImagenConfig 对象,传入 unet1 参数配置
    imagen = ImagenConfig(
        unets=(unet1,),
        image_sizes=(64,),
    ).create()

    # 实例化 ImagenTrainer 对象,传入 imagen 参数
    trainer = ImagenTrainer(
        imagen=imagen
    )

# 定义一个测试函数,用于测试训练步骤
def test_trainer_step():
    # 定义一个自定义的 Dataset 类,用于生成训练数据
    class TestDataset(Dataset):
        def __init__(self):
            super().__init__()
        def __len__(self):
            return 16
        def __getitem__(self, index):
            return (torch.zeros(3, 64, 64), torch.zeros(6, 768))
    
    # 定义 unet1 字典,包含模型的参数配置
    unet1 = dict(
        dim = 8,
        dim_mults = (1, 1, 1, 1),
        num_resnet_blocks = 1,
        layer_attns = False,
        layer_cross_attns = False,
        attn_heads = 2
    )

    # 创建 ImagenConfig 对象,传入 unet1 参数配置
    imagen = ImagenConfig(
        unets=(unet1,),
        image_sizes=(64,),
    ).create()

    # 实例化 ImagenTrainer 对象,传入 imagen 参数
    trainer = ImagenTrainer(
        imagen=imagen
    )

    # 创建 TestDataset 对象
    ds = TestDataset()
    # 将数据集添加到训练器中,设置批量大小为 8
    trainer.add_train_dataset(ds, batch_size=8)
    # 执行一次训练步骤
    trainer.train_step(1)
    # 断言训练步骤的数量为 1
    assert trainer.num_steps_taken(1) == 1

.\lucidrains\imagen-pytorch\imagen_pytorch\test\__init__.py

# 从 imagen_pytorch.test 模块中导入 test_trainer 函数
from imagen_pytorch.test import test_trainer

.\lucidrains\imagen-pytorch\imagen_pytorch\trainer.py

# 导入必要的库
import os
from math import ceil
from contextlib import contextmanager, nullcontext
from functools import partial, wraps
from collections.abc import Iterable

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.cuda.amp import autocast, GradScaler

import pytorch_warmup as warmup

from imagen_pytorch.imagen_pytorch import Imagen, NullUnet
from imagen_pytorch.elucidated_imagen import ElucidatedImagen
from imagen_pytorch.data import cycle

from imagen_pytorch.version import __version__
from packaging import version

import numpy as np

from ema_pytorch import EMA

from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs

from fsspec.core import url_to_fs
from fsspec.implementations.local import LocalFileSystem

# 辅助函数

# 检查值是否存在
def exists(val):
    return val is not None

# 返回值或默认值
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# 将值转换为元组
def cast_tuple(val, length = 1):
    if isinstance(val, list):
        val = tuple(val)

    return val if isinstance(val, tuple) else ((val,) * length)

# 查找第一个满足条件的元素的索引
def find_first(fn, arr):
    for ind, el in enumerate(arr):
        if fn(el):
            return ind
    return -1

# 选择并弹出指定键的值
def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

# 根据键的条件分组字典
def group_dict_by_key(cond, d):
    return_val = [dict(),dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# 检查字符串是否以指定前缀开头
def string_begins_with(prefix, str):
    return str.startswith(prefix)

# 根据键的前缀分组字典
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

# 根据前缀分组字典并修剪键
def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

# 将数字分成组
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr
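
# 用法示例(数值仅为演示假设):批大小 10 按每组最多 4 个拆分
# num_to_groups(10, 4)  # 返回 [4, 4, 2],即两整组加一个余数组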

# URL转换为文件系统、存储桶、路径 - 用于将检查点保存到云端

def url_to_bucket(url):
    if '://' not in url:
        return url

    prefix, suffix = url.split('://')

    if prefix in {'gs', 's3'}:
        return suffix.split('/')[0]
    else:
        raise ValueError(f'storage type prefix "{prefix}" is not supported yet')
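
# 用法示例(URL 仅为演示假设):从云存储 URL 中取出存储桶名称,本地路径原样返回
# url_to_bucket('gs://my-bucket/checkpoints')  # 返回 'my-bucket'
# url_to_bucket('./local/checkpoints')         # 返回 './local/checkpoints'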

# 装饰器

# 模型评估装饰器
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# 转换为Torch张量装饰器
def cast_torch_tensor(fn, cast_fp16 = False):
    @wraps(fn)
    def inner(model, *args, **kwargs):
        device = kwargs.pop('_device', model.device)
        cast_device = kwargs.pop('_cast_device', True)

        should_cast_fp16 = cast_fp16 and model.cast_half_at_training

        kwargs_keys = kwargs.keys()
        all_args = (*args, *kwargs.values())
        split_kwargs_index = len(all_args) - len(kwargs_keys)
        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))

        if cast_device:
            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

        if should_cast_fp16:
            all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(t, torch.Tensor) and t.dtype != torch.bool else t, all_args))

        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))

        out = fn(model, *args, **kwargs)
        return out
    return inner
# 定义一个函数,将可迭代对象按照指定大小分割成子列表
def split_iterable(it, split_size):
    accum = []
    # 遍历可迭代对象,根据指定大小分割成子列表
    for ind in range(ceil(len(it) / split_size)):
        start_index = ind * split_size
        accum.append(it[start_index: (start_index + split_size)])
    return accum

# 定义一个函数,根据不同类型的输入进行分割操作
def split(t, split_size = None):
    # 如果未指定分割大小,则直接返回输入
    if not exists(split_size):
        return t

    # 如果输入是 torch.Tensor 类型,则按照指定大小在指定维度上进行分割
    if isinstance(t, torch.Tensor):
        return t.split(split_size, dim = 0)

    # 如果输入是可迭代对象,则调用 split_iterable 函数进行分割
    if isinstance(t, Iterable):
        return split_iterable(t, split_size)

    # 其他情况抛出类型错误
    raise TypeError(f'unsupported type {type(t)} for splitting')

# 定义一个函数,查找满足条件的第一个元素
def find_first(cond, arr):
    # 遍历数组,找到满足条件的第一个元素并返回
    for el in arr:
        if cond(el):
            return el
    return None

# 定义一个函数,将参数和关键字参数按照指定大小分割成子列表
def split_args_and_kwargs(*args, split_size = None, **kwargs):
    # 将所有参数和关键字参数合并成一个列表
    all_args = (*args, *kwargs.values())
    len_all_args = len(all_args)
    # 找到第一个是 torch.Tensor 类型的参数
    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
    assert exists(first_tensor)

    # 获取第一个 tensor 的大小作为 batch_size
    batch_size = len(first_tensor)
    split_size = default(split_size, batch_size)
    num_chunks = ceil(batch_size / split_size)

    dict_len = len(kwargs)
    dict_keys = kwargs.keys()
    split_kwargs_index = len_all_args - dict_len

    # 对所有参数和关键字参数进行分割操作
    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
    chunk_sizes = num_to_groups(batch_size, split_size)

    # 遍历分割后的结果,生成分块大小比例和分块后的参数和关键字参数
    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
        chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
        chunk_size_frac = chunk_size / batch_size
        yield chunk_size_frac, (chunked_args, chunked_kwargs)
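
# 用法示例(形状与关键字仅为演示假设):batch 为 6 的输入按每块最多 4 个拆成两块
# for frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(torch.randn(6, 3, 64, 64), split_size = 4, texts = ['a'] * 6):
#     ...  # 第一块 frac = 4/6、images 形状 (4, 3, 64, 64);第二块 frac = 2/6、形状 (2, 3, 64, 64)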

# 定义一个装饰器函数,用于对输入的函数进行分块处理
def imagen_sample_in_chunks(fn):
    @wraps(fn)
    def inner(self, *args, max_batch_size = None, **kwargs):
        # 如果未指定最大批处理大小,则直接调用原函数
        if not exists(max_batch_size):
            return fn(self, *args, **kwargs)

        # 如果是无条件的训练,则根据最大批处理大小分块处理
        if self.imagen.unconditional:
            batch_size = kwargs.get('batch_size')
            batch_sizes = num_to_groups(batch_size, max_batch_size)
            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
        else:
            # 否则根据参数和关键字参数进行分块处理
            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]

        # 如果输出是 torch.Tensor 类型,则按照指定维度拼接
        if isinstance(outputs[0], torch.Tensor):
            return torch.cat(outputs, dim = 0)

        # 否则对输出进行拼接处理
        return list(map(lambda t: torch.cat(t, dim = 0), list(zip(*outputs))))

    return inner
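
# 用法示例(数值仅为演示假设):被该装饰器包装的 trainer.sample 支持 max_batch_size,
# 一次采样 16 张时会拆成每块 4 张分别采样,最后在 batch 维拼接返回
# trainer.sample(batch_size = 16, max_batch_size = 4)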

# 定义一个函数,用于恢复模型的部分参数
def restore_parts(state_dict_target, state_dict_from):
    for name, param in state_dict_from.items():

        if name not in state_dict_target:
            continue

        if param.size() == state_dict_target[name].size():
            state_dict_target[name].copy_(param)
        else:
            print(f"layer {name}({param.size()} different than target: {state_dict_target[name].size()}")

    return state_dict_target

# 定义一个类,用于图像生成的训练
class ImagenTrainer(nn.Module):
    locked = False

    def __init__(
        self,
        imagen = None,
        imagen_checkpoint_path = None,
        use_ema = True,
        lr = 1e-4,
        eps = 1e-8,
        beta1 = 0.9,
        beta2 = 0.99,
        max_grad_norm = None,
        group_wd_params = True,
        warmup_steps = None,
        cosine_decay_max_steps = None,
        only_train_unet_number = None,
        fp16 = False,
        precision = None,
        split_batches = True,
        dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'),
        verbose = True,
        split_valid_fraction = 0.025,
        split_valid_from_train = False,
        split_random_seed = 42,
        checkpoint_path = None,
        checkpoint_every = None,
        checkpoint_fs = None,
        fs_kwargs: dict = None,
        max_checkpoints_keep = 20,
        **kwargs
    # 准备训练器,确保训练器尚未准备好,设置只训练的 UNet 编号,并将 prepared 标记为 True
    def prepare(self):
        assert not self.prepared, 'The trainer is already prepared'
        self.validate_and_set_unet_being_trained(self.only_train_unet_number)
        self.prepared = True
    # 计算属性

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self):
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    @property
    def unwrapped_unet(self):
        return self.accelerator.unwrap_model(self.unet_being_trained)

    # 优化器辅助函数

    def get_lr(self, unet_number):
        self.validate_unet_number(unet_number)
        unet_index = unet_number - 1

        optim = getattr(self, f'optim{unet_index}')

        return optim.param_groups[0]['lr']

    # 仅允许同时训练一个 UNet 的函数

    def validate_and_set_unet_being_trained(self, unet_number = None):
        if exists(unet_number):
            self.validate_unet_number(unet_number)

        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'

        self.only_train_unet_number = unet_number
        self.imagen.only_train_unet_number = unet_number

        if not exists(unet_number):
            return

        self.wrap_unet(unet_number)

    def wrap_unet(self, unet_number):
        if hasattr(self, 'one_unet_wrapped'):
            return

        unet = self.imagen.get_unet(unet_number)
        unet_index = unet_number - 1

        optimizer = getattr(self, f'optim{unet_index}')
        scheduler = getattr(self, f'scheduler{unet_index}')

        if self.train_dl:
            self.unet_being_trained, self.train_dl, optimizer = self.accelerator.prepare(unet, self.train_dl, optimizer)
        else:
            self.unet_being_trained, optimizer = self.accelerator.prepare(unet, optimizer)

        if exists(scheduler):
            scheduler = self.accelerator.prepare(scheduler)

        setattr(self, f'optim{unet_index}', optimizer)
        setattr(self, f'scheduler{unet_index}', scheduler)

        self.one_unet_wrapped = True

    # 由于没有每个优化器单独的 gradscaler,对 accelerator 进行修改

    def set_accelerator_scaler(self, unet_number):
        def patch_optimizer_step(accelerated_optimizer, method):
            def patched_step(*args, **kwargs):
                accelerated_optimizer._accelerate_step_called = True
                return method(*args, **kwargs)
            return patched_step

        unet_number = self.validate_unet_number(unet_number)
        scaler = getattr(self, f'scaler{unet_number - 1}')

        self.accelerator.scaler = scaler
        for optimizer in self.accelerator._optimizers:
            optimizer.scaler = scaler
            optimizer._accelerate_step_called = False
            optimizer._optimizer_original_step_method = optimizer.optimizer.step
            optimizer._optimizer_patched_step_method = patch_optimizer_step(optimizer, optimizer.optimizer.step)

    # 辅助打印函数

    def print(self, msg):
        if not self.is_main:
            return

        if not self.verbose:
            return

        return self.accelerator.print(msg)

    # 验证 UNet 编号

    def validate_unet_number(self, unet_number = None):
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}'
        return unet_number

    # 训练步骤数
    # 返回指定 U-Net 编号的训练步数
    def num_steps_taken(self, unet_number = None):
        # 如果只有一个 U-Net,则默认使用编号为 1
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        # 返回指定 U-Net 的训练步数
        return self.steps[unet_number - 1].item()

    # 打印未训练的 U-Net
    def print_untrained_unets(self):
        print_final_error = False

        # 遍历训练步数和 U-Net 对象,检查是否未训练
        for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
            if steps > 0 or isinstance(unet, NullUnet):
                continue

            # 打印未训练的 U-Net 编号
            self.print(f'unet {ind + 1} has not been trained')
            print_final_error = True

        # 如果存在未训练的 U-Net,则打印提示信息
        if print_final_error:
            self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')

    # 数据相关函数

    # 添加训练数据加载器
    def add_train_dataloader(self, dl = None):
        if not exists(dl):
            return

        # 确保训练数据加载器未添加过
        assert not exists(self.train_dl), 'training dataloader was already added'
        assert not self.prepared, f'You need to add the dataset before preparation'
        self.train_dl = dl

    # 添加验证数据加载器
    def add_valid_dataloader(self, dl):
        if not exists(dl):
            return

        # 确保验证数据加载器未添加过
        assert not exists(self.valid_dl), 'validation dataloader was already added'
        assert not self.prepared, f'You need to add the dataset before preparation'
        self.valid_dl = dl

    # 添加训练数据集
    def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs):
        if not exists(ds):
            return

        # 确保训练数据加载器未添加过
        assert not exists(self.train_dl), 'training dataloader was already added'

        # 如果需要从训练数据集中分割验证数据集
        valid_ds = None
        if self.split_valid_from_train:
            # 计算训练数据集和验证数据集的大小
            train_size = int((1 - self.split_valid_fraction) * len(ds))
            valid_size = len(ds) - train_size

            # 随机分割数据集
            ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed))
            self.print(f'training with dataset of {len(ds)} samples and validating with randomly split {len(valid_ds)} samples')

        # 创建数据加载器并添加训练数据加载器
        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.add_train_dataloader(dl)

        # 如果不需要从训练数据集中分割验证数据集,则直接返回
        if not self.split_valid_from_train:
            return

        # 添加验证数据集
        self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs)

    # 添加验证数据集
    def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
        if not exists(ds):
            return

        # 确保验证数据加载器未添加过
        assert not exists(self.valid_dl), 'validation dataloader was already added'

        # 创建数据加载器并添加验证数据加载器
        dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.add_valid_dataloader(dl)

    # 创建训练数据迭代器
    def create_train_iter(self):
        assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'

        if exists(self.train_dl_iter):
            return

        self.train_dl_iter = cycle(self.train_dl)

    # 创建验证数据迭代器
    def create_valid_iter(self):
        assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'

        if exists(self.valid_dl_iter):
            return

        self.valid_dl_iter = cycle(self.valid_dl)

    # 训练步骤
    def train_step(self, *, unet_number = None, **kwargs):
        if not self.prepared:
            self.prepare()
        self.create_train_iter()

        kwargs = {'unet_number': unet_number, **kwargs}
        loss = self.step_with_dl_iter(self.train_dl_iter, **kwargs)
        self.update(unet_number = unet_number)
        return loss

    # 验证步骤
    @torch.no_grad()
    @eval_decorator
    def valid_step(self, **kwargs):
        if not self.prepared:
            self.prepare()
        self.create_valid_iter()
        context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext
        with context():
            loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
        return loss
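
    # 用法示例(参数值仅为演示假设):传入 use_ema_unets = True 可在 EMA 权重下计算验证损失
    # valid_loss = trainer.valid_step(unet_number = 1, use_ema_unets = True, max_batch_size = 4)
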
    # 使用 dl_iter 迭代器获取下一个数据元组
    def step_with_dl_iter(self, dl_iter, **kwargs):
        dl_tuple_output = cast_tuple(next(dl_iter))
        # 将数据元组转换为字典
        model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
        # 调用 forward 方法计算损失
        loss = self.forward(**{**kwargs, **model_input})
        return loss

    # 检查点函数

    # 获取所有按照时间排序的检查点文件
    @property
    def all_checkpoints_sorted(self):
        glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
        checkpoints = self.fs.glob(glob_pattern)
        sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
        return sorted_checkpoints

    # 从检查点文件夹加载模型
    def load_from_checkpoint_folder(self, last_total_steps = -1):
        if last_total_steps != -1:
            filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
            self.load(filepath)
            return

        sorted_checkpoints = self.all_checkpoints_sorted

        if len(sorted_checkpoints) == 0:
            self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
            return

        last_checkpoint = sorted_checkpoints[0]
        self.load(last_checkpoint)

    # 保存到检查点文件夹
    def save_to_checkpoint_folder(self):
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        total_steps = int(self.steps.sum().item())
        filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')

        self.save(filepath)

        if self.max_checkpoints_keep <= 0:
            return

        sorted_checkpoints = self.all_checkpoints_sorted
        checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]

        for checkpoint in checkpoints_to_discard:
            self.fs.rm(checkpoint)
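
    # 用法示例(路径与数值仅为演示假设):实例化时配置 checkpoint_path 与 checkpoint_every 后,
    # 每当累计步数整除 checkpoint_every 时,update() 会自动调用 save_to_checkpoint_folder(),
    # 并按 max_checkpoints_keep 只保留最近的若干个检查点
    # trainer = ImagenTrainer(imagen = imagen, checkpoint_path = './checkpoints', checkpoint_every = 1000, max_checkpoints_keep = 20)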

    # 保存和加载函数

    # 保存模型到指定路径
    def save(
        self,
        path,
        overwrite = True,
        without_optim_and_sched = False,
        **kwargs
    ):
        self.accelerator.wait_for_everyone()

        if not self.can_checkpoint:
            return

        fs = self.fs

        assert not (fs.exists(path) and not overwrite)

        self.reset_ema_unets_all_one_device()

        # 构建保存对象
        save_obj = dict(
            model = self.imagen.state_dict(),
            version = __version__,
            steps = self.steps.cpu(),
            **kwargs
        )

        save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()

        # 保存优化器和调度器状态
        for ind in save_optim_and_sched_iter:
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            if exists(scheduler):
                save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}

            if exists(warmup_scheduler):
                save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}

            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

        # 确定是否存在 imagen 配置
        if hasattr(self.imagen, '_config'):
            self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>"')
            save_obj = {
                **save_obj,
                'imagen_type': 'elucidated' if self.is_elucidated else 'original',
                'imagen_params': self.imagen._config
            }

        # 保存到指定路径
        with fs.open(path, 'wb') as f:
            torch.save(save_obj, f)

        self.print(f'checkpoint saved to {path}')
    # 加载模型参数和优化器状态
    def load(self, path, only_model = False, strict = True, noop_if_not_exist = False):
        # 获取文件系统对象
        fs = self.fs

        # 如果文件不存在且设置了不执行操作,则打印消息并返回
        if noop_if_not_exist and not fs.exists(path):
            self.print(f'trainer checkpoint not found at {str(path)}')
            return

        # 断言文件存在,否则抛出异常
        assert fs.exists(path), f'{path} does not exist'

        # 重置所有 EMA 模型到同一设备上
        self.reset_ema_unets_all_one_device()

        # 避免在主进程中使用 Accelerate 时产生额外的 GPU 内存使用
        with fs.open(path) as f:
            # 加载模型参数和优化器状态
            loaded_obj = torch.load(f, map_location='cpu')

        # 检查加载的模型版本是否与当前包版本一致
        if version.parse(__version__) != version.parse(loaded_obj['version']):
            self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')

        try:
            # 加载模型参数
            self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
        except RuntimeError:
            print("Failed loading state dict. Trying partial load")
            # 尝试部分加载模型参数
            self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
                                                      loaded_obj['model']))

        # 如果只加载模型参数,则返回加载的对象
        if only_model:
            return loaded_obj

        # 复制加载的步数
        self.steps.copy_(loaded_obj['steps'])

        # 遍历所有 U-Net 模型
        for ind in range(0, self.num_unets):
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scheduler_key = f'scheduler{ind}'
            warmup_scheduler_key = f'warmup{ind}'

            # 获取对应的 scaler、optimizer、scheduler 和 warmup_scheduler
            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            scheduler = getattr(self, scheduler_key)
            warmup_scheduler = getattr(self, warmup_scheduler_key)

            # 如果 scheduler 存在且在加载对象中有对应的键,则加载其状态
            if exists(scheduler) and scheduler_key in loaded_obj:
                scheduler.load_state_dict(loaded_obj[scheduler_key])

            # 如果 warmup_scheduler 存在且在加载对象中有对应的键,则加载其状态
            if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
                warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])

            # 如果 optimizer 存在,则尝试加载其状态
            if exists(optimizer):
                try:
                    optimizer.load_state_dict(loaded_obj[optimizer_key])
                    scaler.load_state_dict(loaded_obj[scaler_key])
                except:
                    self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers')

        # 如果使用 EMA,则加载 EMA 模型参数
        if self.use_ema:
            assert 'ema' in loaded_obj
            try:
                self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
            except RuntimeError:
                print("Failed loading state dict. Trying partial load")
                self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
                                                             loaded_obj['ema']))

        # 打印加载成功的消息,并返回加载的对象
        self.print(f'checkpoint loaded from {path}')
        return loaded_obj

    # 获取所有 EMA 模型
    @property
    def unets(self):
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    # 获取指定编号的 EMA 模型
    def get_ema_unet(self, unet_number = None):
        # 如果不使用 EMA,则返回
        if not self.use_ema:
            return

        # 验证并获取正确的 U-Net 编号
        unet_number = self.validate_unet_number(unet_number)
        index = unet_number - 1

        # 如果 unets 是 nn.ModuleList,则转换为列表并更新 ema_unets
        if isinstance(self.unets, nn.ModuleList):
            unets_list = [unet for unet in self.ema_unets]
            delattr(self, 'ema_unets')
            self.ema_unets = unets_list

        # 将当前训练的 EMA 模型移到指定设备上
        if index != self.ema_unet_being_trained_index:
            for unet_index, unet in enumerate(self.ema_unets):
                unet.to(self.device if unet_index == index else 'cpu')

        # 更新当前训练的 EMA 模型索引,并返回对应的 EMA 模型
        self.ema_unet_being_trained_index = index
        return self.ema_unets[index]

    # 重置所有 EMA 模型到指定设备上
    def reset_ema_unets_all_one_device(self, device = None):
        # 如果不使用 EMA,则返回
        if not self.use_ema:
            return

        # 获取默认设备
        device = default(device, self.device)
        # 将所有 EMA 模型转移到指定设备上
        self.ema_unets = nn.ModuleList([*self.ema_unets])
        self.ema_unets.to(device)

        # 重置当前训练的 EMA 模型索引
        self.ema_unet_being_trained_index = -1

    # 禁用梯度计算
    @torch.no_grad()
    # 定义一个上下文管理器,用于控制是否使用指数移动平均的 U-Net 模型
    @contextmanager
    def use_ema_unets(self):
        # 如果不使用指数移动平均模型,则直接返回输出
        if not self.use_ema:
            output = yield
            return output

        # 重置所有 U-Net 模型为同一设备上的指数移动平均模型
        self.reset_ema_unets_all_one_device()
        self.imagen.reset_unets_all_one_device()

        # 将 U-Net 模型设置为评估模式
        self.unets.eval()

        # 保存可训练的 U-Net 模型,然后将指数移动平均模型用于采样
        trainable_unets = self.imagen.unets
        self.imagen.unets = self.unets

        output = yield

        # 恢复原始的训练 U-Net 模型
        self.imagen.unets = trainable_unets

        # 将指数移动平均模型的 U-Net 恢复到原始设备
        for ema in self.ema_unets:
            ema.restore_ema_model_device()

        return output

    # 打印 U-Net 模型的设备信息
    def print_unet_devices(self):
        self.print('unet devices:')
        for i, unet in enumerate(self.imagen.unets):
            device = next(unet.parameters()).device
            self.print(f'\tunet {i}: {device}')

        # 如果不使用指数移动平均模型,则直接返回
        if not self.use_ema:
            return

        self.print('\nema unet devices:')
        for i, ema_unet in enumerate(self.ema_unets):
            device = next(ema_unet.parameters()).device
            self.print(f'\tema unet {i}: {device}')

    # 重写状态字典函数

    def state_dict(self, *args, **kwargs):
        # 重置所有 U-Net 模型为同一设备上的指数移动平均模型
        self.reset_ema_unets_all_one_device()
        return super().state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        # 重置所有 U-Net 模型为同一设备上的指数移动平均模型
        self.reset_ema_unets_all_one_device()
        return super().load_state_dict(*args, **kwargs)

    # 编码文本函数

    def encode_text(self, text, **kwargs):
        return self.imagen.encode_text(text, **kwargs)

    # 前向传播函数和梯度更新步骤

    def update(self, unet_number = None):
        unet_number = self.validate_unet_number(unet_number)
        self.validate_and_set_unet_being_trained(unet_number)
        self.set_accelerator_scaler(unet_number)

        index = unet_number - 1
        unet = self.unet_being_trained

        optimizer = getattr(self, f'optim{index}')
        scaler = getattr(self, f'scaler{index}')
        scheduler = getattr(self, f'scheduler{index}')
        warmup_scheduler = getattr(self, f'warmup{index}')

        # 在加速器上设置梯度缩放器,因为我们每个 U-Net 管理一个

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

        optimizer.step()
        optimizer.zero_grad()

        if self.use_ema:
            ema_unet = self.get_ema_unet(unet_number)
            ema_unet.update()

        # 调度器,如果需要

        maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening()

        with maybe_warmup_context:
            if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # 推荐在文档中
                scheduler.step()

        self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps))

        if not exists(self.checkpoint_path):
            return

        total_steps = int(self.steps.sum().item())

        if total_steps % self.checkpoint_every:
            return

        self.save_to_checkpoint_folder()

    @torch.no_grad()
    @cast_torch_tensor
    @imagen_sample_in_chunks
    def sample(self, *args, **kwargs):
        context = nullcontext if  kwargs.pop('use_non_ema', False) else self.use_ema_unets

        self.print_untrained_unets()

        if not self.is_main:
            kwargs['use_tqdm'] = False

        with context():
            output = self.imagen.sample(*args, device = self.device, **kwargs)

        return output

    @partial(cast_torch_tensor, cast_fp16 = True)
    def forward(
        self,
        *args,
        unet_number = None,
        max_batch_size = None,
        **kwargs
        ):
        # 验证并修正 UNet 编号
        unet_number = self.validate_unet_number(unet_number)
        # 验证并设置正在训练的 UNet 编号
        self.validate_and_set_unet_being_trained(unet_number)
        # 设置加速器缩放器
        self.set_accelerator_scaler(unet_number)

        # 断言只有训练指定 UNet 编号或者没有指定 UNet 编号
        assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'

        # 初始化总损失
        total_loss = 0.

        # 将参数和关键字参数按照最大批处理大小拆分
        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            # 使用加速器自动转换
            with self.accelerator.autocast():
                # 计算损失
                loss = self.imagen(*chunked_args, unet = self.unet_being_trained, unet_number = unet_number, **chunked_kwargs)
                # 损失乘以分块大小比例
                loss = loss * chunk_size_frac

            # 累加总损失
            total_loss += loss.item()

            # 如果处于训练状态,进行反向传播
            if self.training:
                self.accelerator.backward(loss)

        # 返回总损失
        return total_loss

.\lucidrains\imagen-pytorch\imagen_pytorch\utils.py

# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 functools 库中导入 reduce 函数
from functools import reduce
# 从 pathlib 库中导入 Path 类
from pathlib import Path

# 从 imagen_pytorch.configs 模块中导入 ImagenConfig 和 ElucidatedImagenConfig 类
from imagen_pytorch.configs import ImagenConfig, ElucidatedImagenConfig
# 从 ema_pytorch 模块中导入 EMA 类

from ema_pytorch import EMA

# 定义一个函数,用于检查变量是否存在
def exists(val):
    return val is not None

# 定义一个函数,用于安全获取字典中的值
def safeget(dictionary, keys, default = None):
    return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
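
# 用法示例(字典内容仅为演示假设):用点号路径安全地读取嵌套值,键不存在时返回默认值
# safeget({'a': {'b': 1}}, 'a.b')               # 返回 1
# safeget({'a': {'b': 1}}, 'a.c', default = 0)  # 返回 0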

# 加载模型和配置信息
def load_imagen_from_checkpoint(
    checkpoint_path,
    load_weights = True,
    load_ema_if_available = False
):
    # 创建 Path 对象
    model_path = Path(checkpoint_path)
    # 获取完整的模型路径
    full_model_path = str(model_path.resolve())
    # 断言模型路径存在
    assert model_path.exists(), f'checkpoint not found at {full_model_path}'
    # 加载模型参数
    loaded = torch.load(str(model_path), map_location='cpu')

    # 获取 imagen 参数和类型
    imagen_params = safeget(loaded, 'imagen_params')
    imagen_type = safeget(loaded, 'imagen_type')

    # 根据 imagen 类型选择对应的配置类
    if imagen_type == 'original':
        imagen_klass = ImagenConfig
    elif imagen_type == 'elucidated':
        imagen_klass = ElucidatedImagenConfig
    else:
        raise ValueError(f'unknown imagen type {imagen_type} - you need to instantiate your Imagen with configurations, using classes ImagenConfig or ElucidatedImagenConfig')

    # 断言 imagen 参数和类型存在
    assert exists(imagen_params) and exists(imagen_type), 'imagen type and configuration not saved in this checkpoint'

    # 根据配置类和参数创建 imagen 对象
    imagen = imagen_klass(**imagen_params).create()

    # 如果不加载权重,则直接返回 imagen 对象
    if not load_weights:
        return imagen

    # 检查是否存在 EMA 模型
    has_ema = 'ema' in loaded
    should_load_ema = has_ema and load_ema_if_available

    # 加载模型参数
    imagen.load_state_dict(loaded['model'])

    # 如果不需要加载 EMA 模型,则直接返回 imagen 对象
    if not should_load_ema:
        print('loading non-EMA version of unets')
        return imagen

    # 创建 EMA 模型列表
    ema_unets = nn.ModuleList([])
    # 遍历 imagen.unets,为每个 unet 创建一个 EMA 模型
    for unet in imagen.unets:
        ema_unets.append(EMA(unet))

    # 加载 EMA 模型参数
    ema_unets.load_state_dict(loaded['ema'])

    # 将 EMA 模型参数加载到对应的 unet 模型中
    for unet, ema_unet in zip(imagen.unets, ema_unets):
        unet.load_state_dict(ema_unet.ema_model.state_dict())

    # 打印信息并返回 imagen 对象
    print('loaded EMA version of unets')
    return imagen

.\lucidrains\imagen-pytorch\imagen_pytorch\version.py

# 定义变量 __version__,赋值为字符串 '1.26.2'
__version__ = '1.26.2'

.\lucidrains\imagen-pytorch\imagen_pytorch\__init__.py

# 从 imagen_pytorch 模块中导入 Imagen 和 Unet 类
from imagen_pytorch.imagen_pytorch import Imagen, Unet
# 从 imagen_pytorch 模块中导入 NullUnet 类
from imagen_pytorch.imagen_pytorch import NullUnet
# 从 imagen_pytorch 模块中导入 BaseUnet64, SRUnet256, SRUnet1024 类
from imagen_pytorch.imagen_pytorch import BaseUnet64, SRUnet256, SRUnet1024
# 从 imagen_pytorch 模块中导入 ImagenTrainer 类
from imagen_pytorch.trainer import ImagenTrainer
# 从 imagen_pytorch 模块中导入 __version__ 变量
from imagen_pytorch.version import __version__

# 使用 Tero Karras 的新论文中阐述的 ddpm 创建 imagen

# 从 imagen_pytorch 模块中导入 ElucidatedImagen 类
from imagen_pytorch.elucidated_imagen import ElucidatedImagen

# 通过配置创建 imagen 实例

# 从 imagen_pytorch 模块中导入 UnetConfig, ImagenConfig, ElucidatedImagenConfig, ImagenTrainerConfig 类
from imagen_pytorch.configs import UnetConfig, ImagenConfig, ElucidatedImagenConfig, ImagenTrainerConfig

# 工具

# 从 imagen_pytorch 模块中导入 load_imagen_from_checkpoint 函数
from imagen_pytorch.utils import load_imagen_from_checkpoint

# 视频

# 从 imagen_pytorch 模块中导入 Unet3D 类
from imagen_pytorch.imagen_video import Unet3D

Imagen - Pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network that beats DALL-E2, in Pytorch. It is the new SOTA for text-to-image synthesis.

Architecturally, it is actually much simpler than DALL-E2. It consists of a cascading DDPM conditioned on text embeddings from a large pretrained T5 model (attention network). It also contains dynamic clipping for improved classifier free guidance, noise level conditioning, and a memory efficient unet design.

It appears neither CLIP nor prior network is needed after all. And so research continues.

AI Coffee Break with Letitia | Assembly AI | Yannic Kilcher

Please join us on Discord if you are interested in helping out with the replication with the LAION community

Shoutouts

  • StabilityAI for the generous sponsorship, as well as my other sponsors out there

  • 🤗 Huggingface for their amazing transformers library. The text encoder portion is pretty much taken care of because of them

  • Jonathan Ho for bringing about a revolution in generative artificial intelligence through his seminal paper

  • Sylvain and Zachary for the Accelerate library, which this repository uses for distributed training

  • Alex for einops, indispensable tool for tensor manipulation

  • Jorge Gomes for helping out with the T5 loading code and advice on the correct T5 version

  • Katherine Crowson, for her beautiful code, which helped me understand the continuous time version of gaussian diffusion

  • Marunine and Netruk44, for reviewing code, sharing experimental results, and help with debugging

  • Marunine for providing a potential solution for a color shifting issue in the memory efficient u-nets. Thanks to Jacob for sharing experimental comparisons between the base and memory-efficient unets

  • Marunine for finding numerous bugs, resolving an issue with resize right, and for sharing his experimental configurations and results

  • MalumaDev for proposing the use of pixel shuffle upsampler to fix checkboard artifacts

  • Valentin for pointing out insufficient skip connections in the unet, as well as the specific method of attention conditioning in the base-unet in the appendix

  • BIGJUN for catching a big bug with continuous time gaussian diffusion noise level conditioning at inference time

  • Bingbing for identifying a bug with sampling and order of normalizing and noising with low resolution conditioning image

  • Kay for contributing one line command training of Imagen!

  • Hadrien Reynaud for testing out text-to-video on a medical dataset, sharing his results, and identifying issues!

Install

$ pip install imagen-pytorch

Usage

import torch
from imagen_pytorch import Unet, Imagen

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

for i in (1, 2):
    loss = imagen(images, text_embeds = text_embeds, unet_number = i)
    loss.backward()

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = imagen.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3.)

images.shape # (3, 3, 256, 256)

For simpler training, you can directly supply text strings instead of precomputing text encodings. (Although for scaling purposes, you will definitely want to precompute the textual embeddings + mask)

The number of textual captions must match the batch size of the images if you go this route.

# mock images and text (get a lot of this)

texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
]

images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

for i in (1, 2):
    loss = imagen(images, texts = texts, unet_number = i)
    loss.backward()

With the ImagenTrainer wrapper class, the exponential moving averages for all of the U-nets in the cascading DDPM will be automatically taken care of when calling update

import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    text_encoder_name = 't5-large',
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# wrap imagen with the trainer class

trainer = ImagenTrainer(imagen)

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(64, 256, 1024).cuda()
images = torch.randn(64, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

loss = trainer(
    images,
    text_embeds = text_embeds,
    unet_number = 1,            # training on unet number 1 in this example, but you will have to also save checkpoints and then reload and continue training on unet number 2
    max_batch_size = 4          # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
)

trainer.update(unet_number = 1)

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = trainer.sample(texts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
], cond_scale = 3.)

images.shape # (2, 3, 256, 256)

You can also train Imagen without text (unconditional image generation) as follows

import torch
from imagen_pytorch import Unet, Imagen, SRUnet256, ImagenTrainer

# unets for unconditional imagen

unet1 = Unet(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True),
    layer_cross_attns = False,
    use_linear_attn = True
)

unet2 = SRUnet256(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = (2, 4, 8),
    layer_attns = (False, False, True),
    layer_cross_attns = False
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    condition_on_text = False,   # this must be set to False for unconditional Imagen
    unets = (unet1, unet2),
    image_sizes = (64, 128),
    timesteps = 1000
)

trainer = ImagenTrainer(imagen).cuda()

# now get a ton of images and feed it through the Imagen trainer

training_images = torch.randn(4, 3, 256, 256).cuda()

# train each unet separately
# in this example, only training on unet number 1

loss = trainer(training_images, unet_number = 1)
trainer.update(unet_number = 1)

# do the above for many many many many steps
# now you can sample images unconditionally from the cascading unet(s)

images = trainer.sample(batch_size = 16) # (16, 3, 128, 128)

Or train only super-resoluting unets

import torch
from imagen_pytorch import Unet, NullUnet, Imagen

# unet for imagen

unet1 = NullUnet()  # add a placeholder "null" unet for the base unet

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    timesteps = 250,
    cond_drop_prob = 0.1
).cuda()

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

loss = imagen(images, text_embeds = text_embeds, unet_number = 2)
loss.backward()

# do the above for many many many many steps
# now you can sample an image based on the text embeddings as well as low resolution images

lowres_images = torch.randn(3, 3, 64, 64).cuda()  # starting un-resoluted images

images = imagen.sample(
    texts = [
        'a whale breaching from afar',
        'young girl blowing out candles on her birthday cake',
        'fireworks with blue and green sparkles'
    ],
    start_at_unet_number = 2,              # start at unet number 2
    start_image_or_video = lowres_images,  # pass in low resolution images to be resoluted
    cond_scale = 3.)

images.shape # (3, 3, 256, 256)

At any time you can save and load the trainer and all associated states with the save and load methods. It is recommended you use these methods instead of manually saving with a state_dict call, as there is some device memory management being done underneath the hood within the trainer.

ex.

trainer.save('./path/to/checkpoint.pt')

trainer.load('./path/to/checkpoint.pt')

trainer.steps # (2,) step number for each of the unets, in this case 2

Dataloader

You can also rely on the ImagenTrainer to automatically train off DataLoader instances. You simply have to craft your DataLoader to return either images (for the unconditional case), or a tuple of ('images', 'text_embeds') for text-guided generation.

ex. unconditional training

from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset

# unets for unconditional imagen

unet = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 1,
    layer_attns = (False, False, False, True),
    layer_cross_attns = False
)

# imagen, which contains the unet above

imagen = Imagen(
    condition_on_text = False,  # this must be set to False for unconditional Imagen
    unets = unet,
    image_sizes = 128,
    timesteps = 1000
)

trainer = ImagenTrainer(
    imagen = imagen,
    split_valid_from_train = True # whether to split the validation dataset from the training
).cuda()

# instantiate your dataloader, which returns the necessary inputs to the DDPM as tuple in the order of images, text embeddings, then text masks. in this case, only images is returned as it is unconditional training

dataset = Dataset('/path/to/training/images', image_size = 128)

trainer.add_train_dataset(dataset, batch_size = 16)

# working training loop

for i in range(200000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
    print(f'loss: {loss}')

    if not (i % 50):
        valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4)
        print(f'valid loss: {valid_loss}')

    if not (i % 100) and trainer.is_main: # is_main makes sure this can run in distributed
        images = trainer.sample(batch_size = 1, return_pil_images = True) # returns List[Image]
        images[0].save(f'./sample-{i // 100}.png')
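
For text-conditioned training, the dataloader just needs to yield tuples in the order ('images', 'text_embeds'). Below is a minimal sketch with a hypothetical dataset class, using random tensors in place of real images and precomputed T5 embeddings (assuming an Imagen / ImagenTrainer instantiated with text conditioning, as in the earlier examples)

import torch
from torch.utils.data import Dataset

class TextImageDataset(Dataset):
    def __init__(self, length = 100, image_size = 128, text_seq_len = 256, text_dim = 768):
        self.length = length
        self.image_size = image_size
        self.text_seq_len = text_seq_len
        self.text_dim = text_dim

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # swap the random tensors for a real image and its precomputed text embedding
        image = torch.randn(3, self.image_size, self.image_size)
        text_embed = torch.randn(self.text_seq_len, self.text_dim)
        return image, text_embed

trainer.add_train_dataset(TextImageDataset(), batch_size = 16)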

Multi GPU

Thanks to 🤗 Accelerate, you can do multi GPU training easily with two steps.

First you need to invoke accelerate config in the same directory as your training script (say it is named train.py)

$ accelerate config

Next, instead of calling python train.py as you would for single GPU, you would use the accelerate CLI as so

$ accelerate launch train.py

That's it!

Command-line

Imagen can also be used via CLI directly.

Configuration

ex.

$ imagen config

or

$ imagen config --path ./configs/config.json

In the config you are able to change settings for the trainer, dataset and the imagen config.

The Imagen config parameters can be found here

The Elucidated Imagen config parameters can be found here

The Imagen Trainer config parameters can be found here

For the dataset parameters all dataloader parameters can be used.

Training

This command allows you to train or resume training your model

ex.

$ imagen train

or

$ imagen train --unet 2 --epoches 10

You can pass following arguments to the training command.

  • --config specify the config file to use for training [default: ./imagen_config.json]
  • --unet the index of the unet to train [default: 1]
  • --epoches how many epochs to train for [default: 50]

Sampling

Be aware that when sampling, your checkpoint should have trained all unets to get a usable result.

ex.

$ imagen sample --model ./path/to/model/checkpoint.pt "a squirrel raiding the birdfeeder"
# image is saved to ./a_squirrel_raiding_the_birdfeeder.png

You can pass following arguments to the sample command.

  • --model specify the model file to use for sampling
  • --cond_scale conditioning scale (classifier free guidance) in decoder
  • --load_ema load EMA version of unets if available

In order to use a saved checkpoint with this feature, you either must instantiate your Imagen instance using the config classes, ImagenConfig and ElucidatedImagenConfig or create a checkpoint via the CLI directly

For proper training, you'll likely want to setup config-driven training anyways.

ex.

import torch
from imagen_pytorch import ImagenConfig, ElucidatedImagenConfig, ImagenTrainer

# in this example, using elucidated imagen

imagen = ElucidatedImagenConfig(
    unets = [
        dict(dim = 32, dim_mults = (1, 2, 4, 8)),
        dict(dim = 32, dim_mults = (1, 2, 4, 8))
    ],
    image_sizes = (64, 128),
    cond_drop_prob = 0.5,
    num_sample_steps = 32
).create()

trainer = ImagenTrainer(imagen)

# do your training ...

# then save it

trainer.save('./checkpoint.pt')

# you should see a message informing you that ./checkpoint.pt is commandable from the terminal

It really should be as simple as that

You can also pass this checkpoint file around, and anyone can continue finetune on their own data

from imagen_pytorch import load_imagen_from_checkpoint, ImagenTrainer

imagen = load_imagen_from_checkpoint('./checkpoint.pt')

trainer = ImagenTrainer(imagen)

# continue training / fine-tuning

Inpainting

Inpainting follows the formulation laid out by the recent Repaint paper. Simply pass in inpaint_images and inpaint_masks to the sample function on either Imagen or ElucidatedImagen


inpaint_images = torch.randn(4, 3, 512, 512).cuda()      # (batch, channels, height, width)
inpaint_masks = torch.ones((4, 512, 512)).bool().cuda()  # (batch, height, width)

inpainted_images = trainer.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'dust motes swirling in the morning sunshine on the windowsill'
], inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, cond_scale = 5.)

inpainted_images # (4, 3, 512, 512)

For video, similarly pass in your videos to inpaint_videos keyword on .sample. Inpainting mask can either be the same across all frames (batch, height, width) or different (batch, frames, height, width)


inpaint_videos = torch.randn(4, 3, 8, 512, 512).cuda()   # (batch, channels, frames, height, width)
inpaint_masks = torch.ones((4, 8, 512, 512)).bool().cuda()  # (batch, frames, height, width)

inpainted_videos = trainer.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'dust motes swirling in the morning sunshine on the windowsill'
], inpaint_videos = inpaint_videos, inpaint_masks = inpaint_masks, cond_scale = 5.)

inpainted_videos # (4, 3, 8, 512, 512)

Experimental

Tero Karras of StyleGAN fame has written a new paper with results that have been corroborated by a number of independent researchers as well as on my own machine. I have decided to create a version of Imagen, the ElucidatedImagen, so that one can use the new elucidated DDPM for text-guided cascading generation.

Simply import ElucidatedImagen, and then instantiate the instance as you did before. The hyperparameters are different than the usual ones for discrete and continuous time gaussian diffusion, and can be individualized for each unet in the cascade.

Ex.

from imagen_pytorch import ElucidatedImagen

# instantiate your unets ...

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (64, 128),
    cond_drop_prob = 0.1,
    num_sample_steps = (64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are)
    sigma_min = 0.002,           # min noise level
    sigma_max = (80, 160),       # max noise level, @crowsonkb recommends double the max noise level for upsampler
    sigma_data = 0.5,            # standard deviation of data distribution
    rho = 7,                     # controls the sampling schedule
    P_mean = -1.2,               # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                 # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

# rest is the same as above

Text to Video

This repository will also start accumulating new research around text guided video synthesis. For starters it will adopt the 3d unet architecture described by Jonathan Ho in Video Diffusion Models

Update: verified working by Hadrien Reynaud!

Ex.

import torch
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer

unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()

unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()

# elucidated imagen, which contains the unets above (base unet and super resoluting ones)

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (16, 32),
    random_crop_sizes = (None, 16),
    temporal_downsample_factor = (2, 1),        # in this example, the first unet would receive the video temporally downsampled by 2x
    num_sample_steps = 10,
    cond_drop_prob = 0.1,
    sigma_min = 0.002,                          # min noise level
    sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

# mock videos (get a lot of this) and text encodings from large T5

texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'dust motes swirling in the morning sunshine on the windowsill'
]

videos = torch.randn(4, 3, 10, 32, 32).cuda() # (batch, channels, time / video frames, height, width)

# feed images into imagen, training each unet in the cascade
# for this example, only training unet 1

trainer = ImagenTrainer(imagen)

# you can also ignore time when training on video initially, shown to improve results in video-ddpm paper. eventually will make the 3d unet trainable with either images or video. research shows it is essential (with current data regimes) to train first on text-to-image. probably won't be true in another decade. all big data becomes small data

trainer(videos, texts = texts, unet_number = 1, ignore_time = False)
trainer.update(unet_number = 1)

videos = trainer.sample(texts = texts, video_frames = 20) # extrapolating to 20 frames from training on 10 frames

videos.shape # (4, 3, 20, 32, 32)

You can also train on text-image pairs first. The Unet3D will automatically convert it to single-frame videos and learn without the temporal components (by automatically setting ignore_time = True), whether it be 1d convolutions or causal attention across time.
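
Ex. a minimal sketch based on the statement above (mock image batch, reusing the trainer and texts from the video example)

# mock single images (batch, channels, height, width) - treated as single-frame videos by the Unet3D

images = torch.randn(4, 3, 32, 32).cuda()

trainer(images, texts = texts, unet_number = 1)
trainer.update(unet_number = 1)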

This is the current approach taken by all the big artificial intelligence labs (Brain, MetaAI, Bytedance)

FAQ

  • Why are my generated images not aligning well with the text?

Imagen uses an algorithm called Classifier Free Guidance. When sampling, you apply a scale to the conditioning (text in this case) of greater than 1.0.

Researcher Netruk44 has reported 5-10 to be optimal, and anything greater than 10 to break.

trainer.sample(texts = [
    'a cloud in the shape of a roman gladiator'
], cond_scale = 5.) # <-- cond_scale is the conditioning scale, needs to be greater than 1.0 to be better than average
  • Are there any pretrained models yet?

Not at the moment but one will likely be trained and open sourced within the year, if not sooner. If you would like to participate, you can join the community of artificial neural network trainers at Laion (discord link is in the Readme above) and start collaborating.

  • Will this technology take my job?

More the reason why you should start training your own model, starting today! The last thing we need is this technology being in the hands of an elite few. Hopefully this repository reduces the work to just finding the necessary compute, and augmenting with your own curated dataset.

  • What am I allowed to do with this repository?

Anything! It is MIT licensed. In other words, you can freely copy / paste for your own research, remixed for whatever modality you can think of. Go train amazing models for profit, for science, or simply to satiate your own personal pleasure at witnessing something divine unravel in front of you.

Cool Applications!

Todo

Citations

@inproceedings{Saharia2022PhotorealisticTD,
    title   = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding},
    author  = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily L. Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and Seyedeh Sara Mahdavi and Raphael Gontijo Lopes and Tim Salimans and Jonathan Ho and David Fleet and Mohammad Norouzi},
    year    = {2022}
}
@article{Alayrac2022Flamingo,
    title   = {Flamingo: a Visual Language Model for Few-Shot Learning},
    author  = {Jean-Baptiste Alayrac et al},
    year    = {2022}
}
@inproceedings{Sankararaman2022BayesFormerTW,
    title   = {BayesFormer: Transformer with Uncertainty Estimation},
    author  = {Karthik Abinav Sankararaman and Sinong Wang and Han Fang},
    year    = {2022}
}
@article{So2021PrimerSF,
    title   = {Primer: Searching for Efficient Transformers for Language Modeling},
    author  = {David R. So and Wojciech Ma'nke and Hanxiao Liu and Zihang Dai and Noam M. Shazeer and Quoc V. Le},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2109.08668}
}
@misc{cao2020global,
    title   = {Global Context Networks},
    author  = {Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
    year    = {2020},
    eprint  = {2012.13375},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Karras2022ElucidatingTD,
    title   = {Elucidating the Design Space of Diffusion-Based Generative Models},
    author  = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2206.00364}
}
@inproceedings{NEURIPS2020_4c5bcfec,
    author      = {Ho, Jonathan and Jain, Ajay and Abbeel, Pieter},
    booktitle   = {Advances in Neural Information Processing Systems},
    editor      = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
    pages       = {6840--6851},
    publisher   = {Curran Associates, Inc.},
    title       = {Denoising Diffusion Probabilistic Models},
    url         = {https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf},
    volume      = {33},
    year        = {2020}
}
@article{Lugmayr2022RePaintIU,
    title   = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
    author  = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2201.09865}
}
@misc{ho2022video,
    title   = {Video Diffusion Models},
    author  = {Jonathan Ho and Tim Salimans and Alexey Gritsenko and William Chan and Mohammad Norouzi and David J. Fleet},
    year    = {2022},
    eprint  = {2204.03458},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{rogozhnikov2022einops,
    title   = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author  = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{chen2022analog,
    title   = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
    author  = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
    year    = {2022},
    eprint  = {2208.04202},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{Singer2022,
    author  = {Uriel Singer},
    url     = {https://makeavideo.studio/Make-A-Video.pdf}
}
@article{Sunkara2022NoMS,
    title   = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
    author  = {Raja Sunkara and Tie Luo},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.03641}
}
@article{Salimans2022ProgressiveDF,
    title   = {Progressive Distillation for Fast Sampling of Diffusion Models},
    author  = {Tim Salimans and Jonathan Ho},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.00512}
}
@article{Ho2022ImagenVH,
    title   = {Imagen Video: High Definition Video Generation with Diffusion Models},
    author  = {Jonathan Ho and William Chan and Chitwan Saharia and Jay Whang and Ruiqi Gao and Alexey A. Gritsenko and Diederik P. Kingma and Ben Poole and Mohammad Norouzi and David J. Fleet and Tim Salimans},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.02303}
}
@misc{gilmer2023intriguing,
    title  = {Intriguing Properties of Transformer Training Instabilities},
    author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
    year   = {2023},
    status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
@inproceedings{Hang2023EfficientDT,
    title   = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
    author  = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
    year    = {2023}
}
@article{Zhang2021TokenST,
    title   = {Token Shift Transformer for Video Classification},
    author  = {Hao Zhang and Y. Hao and Chong-Wah Ngo},
    journal = {Proceedings of the 29th ACM International Conference on Multimedia},
    year    = {2021}
}
@inproceedings{anonymous2022normformer,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations },
    year    = {2022},
    url     = {https://openreview.net/forum?id=GMYWzWztDx5},
    note    = {under review}
}