Lucidrains-系列项目源码解析-二十-
Lucidrains 系列项目源码解析(二十)
.\lucidrains\imagen-pytorch\imagen_pytorch\imagen_pytorch.py
# 导入数学库
import math
# 从随机模块中导入随机函数
from random import random
# 从 beartype 库中导入 List 和 Union 类型
from beartype.typing import List, Union
# 从 beartype 库中导入 beartype 装饰器
from beartype import beartype
# 从 tqdm 库中导入 tqdm 函数
from tqdm.auto import tqdm
# 从 functools 库中导入 partial 和 wraps 函数
from functools import partial, wraps
# 从 contextlib 库中导入 contextmanager 和 nullcontext 函数
from contextlib import contextmanager, nullcontext
# 从 pathlib 库中导入 Path 类
from pathlib import Path
# 导入 torch 库
import torch
# import the torch.nn.functional module under the conventional alias F
import torch.nn.functional as F
# 从 torch.nn.parallel 模块中导入 DistributedDataParallel 类
from torch.nn.parallel import DistributedDataParallel
# 从 torch 模块中导入 nn 和 einsum 函数
from torch import nn, einsum
# 从 torch.cuda.amp 模块中导入 autocast 函数
from torch.cuda.amp import autocast
# 从 torch.special 模块中导入 expm1 函数
from torch.special import expm1
# import the torchvision.transforms module under the alias T
import torchvision.transforms as T
# import the kornia.augmentation module under the alias K
import kornia.augmentation as K
# 从 einops 模块中导入 rearrange, repeat, reduce, pack, unpack 函数
from einops import rearrange, repeat, reduce, pack, unpack
# 从 einops.layers.torch 模块中导入 Rearrange 类
from einops.layers.torch import Rearrange
# 从 imagen_pytorch.t5 模块中导入 t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME 函数
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
# 从 imagen_pytorch.imagen_video 模块中导入 Unet3D, resize_video_to, scale_video_time 函数
from imagen_pytorch.imagen_video import Unet3D, resize_video_to, scale_video_time
# helper functions
# tell whether a value was actually supplied
def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True
# pass-through function
def identity(t, *args, **kwargs):
    """Return `t` unchanged; every additional argument is ignored."""
    passthrough = t
    return passthrough
# check whether one number divides evenly by another
def divisible_by(numer, denom):
    """Return whether `numer % denom` equals zero."""
    remainder = numer % denom
    return remainder == 0
# first element of a sequence, or a fallback when the sequence is empty
def first(arr, d = None):
    """Return `arr[0]`, or `d` when `len(arr)` is zero."""
    return d if len(arr) == 0 else arr[0]
# decorator that skips the wrapped function when the input is None
def maybe(fn):
    """Wrap `fn` so a None argument is returned as-is instead of being passed to `fn`."""
    @wraps(fn)
    def inner(x):
        return fn(x) if exists(x) else x

    return inner
# decorator that lets the wrapped function run only on its first call
def once(fn):
    """Wrap `fn` so that it executes at most one time.

    The first call forwards all positional and keyword arguments to `fn` and
    returns its result. Every later call is a no-op that returns None.
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        # flip the flag before calling so a raising `fn` is still not retried
        called = True
        return fn(*args, **kwargs)

    return inner
# `print` wrapped with `once`: only the first call prints, later calls do nothing
print_once = once(print)
# fall back to a default when a value is missing
def default(val, d):
    """Return `val` unless it is None.

    When `val` is None the fallback `d` is used; a callable `d` is invoked
    (with no arguments) and its result is returned, so expensive defaults are
    only built when needed.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
# normalize a value into a tuple, optionally enforcing its length
def cast_tuple(val, length = None):
    """Coerce `val` to a tuple.

    Lists are converted to tuples and tuples pass through unchanged; any other
    value is repeated `length` times (once when `length` is None). When
    `length` is given, the result is asserted to have exactly that many items.
    """
    if isinstance(val, list):
        val = tuple(val)

    if isinstance(val, tuple):
        output = val
    else:
        repeats = default(length, 1)
        output = (val,) * repeats

    if exists(length):
        assert len(output) == length

    return output
# drop the entries of a dict whose value is None
def compact(input_dict):
    """Return a new dict holding only the items of `input_dict` whose value is not None."""
    kept = {}
    for key, value in input_dict.items():
        if exists(value):
            kept[key] = value
    return kept
# apply a function to the value stored under one key, if that key is present
def maybe_transform_dict_key(input_dict, key, fn):
    """Return a shallow copy of `input_dict` with `fn` applied to `input_dict[key]`.

    When `key` is absent the original dict object is returned untouched. The
    input dict is never mutated.
    """
    if key in input_dict:
        updated = input_dict.copy()
        updated[key] = fn(input_dict[key])
        return updated

    return input_dict
# convert uint8 images to floating point
def cast_uint8_images_to_float(images):
    """Divide uint8 tensors by 255; tensors of any other dtype are returned as-is."""
    if images.dtype == torch.uint8:
        return images / 255
    return images
# look up the device a module lives on
def module_device(module):
    """Return the device of the first parameter yielded by `module.parameters()`."""
    first_param = next(module.parameters())
    return first_param.device
# zero-initialize a layer in place
def zero_init_(m):
    """Fill `m.weight` with zeros, and `m.bias` too when the layer has one."""
    nn.init.zeros_(m.weight)

    bias = m.bias
    if not exists(bias):
        return

    nn.init.zeros_(bias)
# decorator that runs a model method in eval mode
def eval_decorator(fn):
    """Wrap `fn(model, ...)` so it executes with `model` switched to eval mode.

    The model's previous `training` flag is recorded before the call and
    restored afterwards via `model.train(was_training)`. The restore happens
    in a `finally` clause, so the model does not stay stuck in eval mode when
    `fn` raises. The wrapped function's result is returned unchanged.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            model.train(was_training)

    return inner
# right-pad a tuple up to a target length
def pad_tuple_to_length(t, length, fillvalue = None):
    """Append `fillvalue` to `t` until it has `length` items.

    A tuple that is already at least `length` long is returned unchanged.
    """
    missing = length - len(t)
    if missing > 0:
        filler = (fillvalue,) * missing
        return (*t, *filler)
    return t
# helper classes
# no-op module
class Identity(nn.Module):
    """Module that returns its first input untouched.

    Both the constructor and `forward` accept and discard arbitrary extra
    arguments, so it can stand in for modules with richer signatures.
    """

    def __init__(self, *_unused_args, **_unused_kwargs):
        super().__init__()

    def forward(self, x, *_unused_args, **_unused_kwargs):
        return x
# tensor helpers
# numerically safe logarithm
def log(t, eps: float = 1e-12):
    """Return the natural log of `t` after clamping it to be at least `eps`."""
    clamped = torch.clamp(t, min = eps)
    return clamped.log()
# normalize vectors along the last dimension to unit L2 norm
def l2norm(t):
    """Return `t` L2-normalized over its last dimension via `F.normalize`."""
    return F.normalize(t, p = 2.0, dim = -1)
# append singleton dimensions to `t` so it has as many dims as `x`
def right_pad_dims_to(x, t):
    """View `t` with trailing size-1 dimensions until `t.ndim == x.ndim`.

    When `t` already has at least as many dimensions as `x`, it is returned
    unchanged.
    """
    extra_dims = x.ndim - t.ndim
    if extra_dims > 0:
        return t.view(*t.shape, *([1] * extra_dims))
    return t
# mean of a tensor that only counts positions allowed by a mask
def masked_mean(t, *, dim, mask = None):
    """Average `t` over `dim`, ignoring positions where `mask` is False.

    Without a mask this is a plain `t.mean(dim = dim)`. With a mask (the
    rearrange pattern requires it to be 2-d, `b n`), masked-out positions are
    zeroed before summing, and the sum is divided by the number of kept
    positions, clamped to at least 1e-5 to avoid dividing by zero.
    """
    if not exists(mask):
        return t.mean(dim = dim)

    # count of kept positions, computed before the mask gains its trailing axis
    kept_count = mask.sum(dim = dim, keepdim = True)

    expanded_mask = rearrange(mask, 'b n -> b n 1')
    zeroed = t.masked_fill(~expanded_mask, 0.)
    total = zeroed.sum(dim = dim)

    return total / kept_count.clamp(min = 1e-5)
# resize an image tensor to a target size
def resize_image_to(
    image,
    target_image_size,
    clamp_range = None,
    mode = 'nearest'
):
    """Interpolate `image` to `target_image_size` using `F.interpolate`.

    The image is returned untouched when its last dimension already equals
    `target_image_size`. When `clamp_range` is given, its elements are passed
    to `clamp` on the resized output.
    """
    if image.shape[-1] == target_image_size:
        return image

    resized = F.interpolate(image, target_image_size, mode = mode)

    return resized.clamp(*clamp_range) if exists(clamp_range) else resized
# compute the frame dimension at every downsampling stage
def calc_all_frame_dims(
    downsample_factors: List[int],
    frames
):
    """Return one `(frames // factor,)` tuple per downsample factor.

    When `frames` is None, a tuple of empty tuples (one per factor) is
    returned instead. Otherwise each factor is asserted to divide `frames`
    evenly and the results are returned as a list.
    """
    if not exists(frames):
        empty = tuple()
        return (empty,) * len(downsample_factors)

    frame_dims = []

    for factor in downsample_factors:
        assert divisible_by(frames, factor)
        frame_dims.append((frames // factor,))

    return frame_dims
# index into a tuple, falling back to a default when the index is past the end
def safe_get_tuple_index(tup, index, default = None):
    """Return `tup[index]`, or `default` when `index >= len(tup)`."""
    return tup[index] if index < len(tup) else default
# image normalization functions
# ddpms expect images to be in the range of -1 to 1
def normalize_neg_one_to_one(img):
    """Compute `img * 2 - 1`, which maps the interval [0, 1] onto [-1, 1]."""
    doubled = img * 2
    return doubled - 1
def unnormalize_zero_to_one(normed_img):
    """Compute `(normed_img + 1) * 0.5`, which maps the interval [-1, 1] onto [0, 1]."""
    shifted = normed_img + 1
    return shifted * 0.5
# classifier free guidance functions
def prob_mask_like(shape, prob, device):
    """Return a boolean tensor of `shape` whose entries are True with probability `prob`.

    `prob == 1` and `prob == 0` short-circuit to all-True / all-False tensors
    without drawing random numbers; otherwise uniform samples in [0, 1) are
    compared against `prob`.
    """
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)

    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)

    uniform = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return uniform < prob
# gaussian diffusion with continuous time helper functions and classes
# large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
@torch.jit.script
def beta_linear_log_snr(t):
    # log-SNR of the "linear" schedule: -log(expm1(1e-4 + 10 * t^2))
    exponent = 1e-4 + 10 * (t ** 2)
    return -torch.log(expm1(exponent))
@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008):
    # log-SNR of the "cosine" schedule: -log(cos(((t + s) / (1 + s)) * pi / 2)^-2 - 1)
    # not sure if this accounts for beta being clipped to 0.999 in the discrete version
    angle = (t + s) / (1 + s) * math.pi * 0.5
    inverse_snr = (torch.cos(angle) ** -2) - 1
    return -log(inverse_snr, eps = 1e-5)
def log_snr_to_alpha_sigma(log_snr):
    """Convert a log-SNR tensor into the pair `(alpha, sigma)`.

    `alpha = sqrt(sigmoid(log_snr))` and `sigma = sqrt(sigmoid(-log_snr))`,
    so `alpha ** 2 + sigma ** 2 == 1`.
    """
    alpha = torch.sqrt(torch.sigmoid(log_snr))
    sigma = torch.sqrt(torch.sigmoid(-log_snr))
    return alpha, sigma
class GaussianDiffusionContinuousTimes(nn.Module):
    """Gaussian diffusion over continuous times, driven by a log-SNR schedule.

    `noise_schedule` selects the log-SNR function ("linear" or "cosine");
    any other value raises ValueError. `timesteps` is stored as
    `num_timesteps` and sets the resolution of the sampling time grid.
    """

    def __init__(self, *, noise_schedule, timesteps = 1000):
        super().__init__()

        if noise_schedule == "linear":
            log_snr_fn = beta_linear_log_snr
        elif noise_schedule == "cosine":
            log_snr_fn = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        self.log_snr = log_snr_fn
        self.num_timesteps = timesteps

    def get_times(self, batch_size, noise_level, *, device):
        """Return a float32 tensor of shape `(batch_size,)` filled with `noise_level`."""
        shape = (batch_size,)
        return torch.full(shape, noise_level, device = device, dtype = torch.float32)

    def sample_random_times(self, batch_size, *, device):
        """Draw `batch_size` times uniformly from [0, 1)."""
        times = torch.zeros((batch_size,), device = device).float()
        return times.uniform_(0, 1)

    def get_condition(self, times):
        """Return the log-SNR at `times`, or None when `times` is None."""
        safe_log_snr = maybe(self.log_snr)
        return safe_log_snr(times)

    def get_sampling_timesteps(self, batch, *, device):
        """Build the (time, next time) pairs walked during sampling.

        A grid of `num_timesteps + 1` points from 1 down to 0 is repeated
        `batch` times; consecutive points are stacked along a new leading
        axis and the result is unbound along the time axis, giving one tensor
        per sampling step.
        """
        grid = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
        grid = repeat(grid, 't -> b t', b = batch)

        current, upcoming = grid[:, :-1], grid[:, 1:]
        pairs = torch.stack((current, upcoming), dim = 0)

        return pairs.unbind(dim = -1)

    def q_posterior(self, x_start, x_t, t, *, t_next = None):
        """Posterior mean, variance and clipped log-variance for stepping `t -> t_next`.

        See https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material

        When `t_next` is not given it defaults to one grid step earlier,
        `t - 1 / num_timesteps`, clamped at zero.
        """
        if not exists(t_next):
            t_next = (t - 1. / self.num_timesteps).clamp(min = 0.)

        log_snr = right_pad_dims_to(x_t, self.log_snr(t))
        log_snr_next = right_pad_dims_to(x_t, self.log_snr(t_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)

        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        """Noise `x_start` to time `t`.

        A float `t` is broadcast to a per-sample tensor. Gaussian noise is
        drawn when `noise` is not supplied. Returns the tuple
        `(alpha * x_start + sigma * noise, log_snr, alpha, sigma)`, where
        `log_snr` is the unpadded per-sample value.
        """
        dtype = x_start.dtype

        if isinstance(t, float):
            batch = x_start.shape[0]
            t = torch.full((batch,), t, device = x_start.device, dtype = dtype)

        if not exists(noise):
            noise = torch.randn_like(x_start)

        log_snr = self.log_snr(t).type(dtype)
        padded_log_snr = right_pad_dims_to(x_start, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)

        noised = alpha * x_start + sigma * noise
        return noised, log_snr, alpha, sigma

    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
        """Move a sample already noised to `from_t` onward to `to_t`, adding noise.

        Float times are broadcast to per-sample tensors, and gaussian noise is
        drawn when `noise` is not supplied.
        """
        device, dtype = x_from.device, x_from.dtype
        batch = x_from.shape[0]

        if isinstance(from_t, float):
            from_t = torch.full((batch,), from_t, device = device, dtype = dtype)

        if isinstance(to_t, float):
            to_t = torch.full((batch,), to_t, device = device, dtype = dtype)

        if not exists(noise):
            noise = torch.randn_like(x_from)

        # alpha / sigma at the source time
        log_snr_from = right_pad_dims_to(x_from, self.log_snr(from_t))
        alpha, sigma = log_snr_to_alpha_sigma(log_snr_from)

        # alpha / sigma at the destination time
        log_snr_to = right_pad_dims_to(x_from, self.log_snr(to_t))
        alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_to)

        return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha

    def predict_start_from_v(self, x_t, t, v):
        """Recover the start estimate from a v-prediction: `alpha * x_t - sigma * v`."""
        padded_log_snr = right_pad_dims_to(x_t, self.log_snr(t))
        alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
        return alpha * x_t - sigma * v

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover the start estimate from predicted noise.

        Computes `(x_t - sigma * noise) / alpha`, with `alpha` clamped to at
        least 1e-8 to avoid dividing by zero.
        """
        padded_log_snr = right_pad_dims_to(x_t, self.log_snr(t))
        alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)
# 定义 LayerNorm 类,用于实现层归一化操作
class LayerNorm(nn.Module):
# 初始化函数,接受特征数、是否稳定、维度作为参数
def __init__(self, feats, stable = False, dim = -1):
super().__init__()
self.stable = stable
self.dim = dim
# 初始化可学习参数 g
self.g = nn.Parameter(torch.ones(feats, *((1,) * (-dim - 1))))
# 前向传播函数
def forward(self, x):
dtype, dim = x.dtype, self.dim
# 如果设置了稳定性,对输入进行归一化处理
if self.stable:
x = x / x.amax(dim = dim, keepdim = True).detach()
# 根据数据类型选择 eps 值
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
# 计算方差和均值
var = torch.var(x, dim = dim, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = dim, keepdim = True)
# 返回归一化后的结果
return (x - mean) * (var + eps).rsqrt().type(dtype) * self.g.type(dtype)
# ChanLayerNorm: LayerNorm pre-configured with dim = -3 (the normalized axis is third from last)
ChanLayerNorm = partial(LayerNorm, dim = -3)
# callable that always returns the same value
class Always():
    """Callable that ignores its arguments and returns the value given at construction."""

    def __init__(self, val):
        self.val = val

    def __call__(self, *_unused_args, **_unused_kwargs):
        constant = self.val
        return constant
# residual connection wrapper
class Residual(nn.Module):
    """Wrap `fn` so the module computes `fn(x, **kwargs) + x`."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch = self.fn(x, **kwargs)
        return branch + x
# run several modules on the same input and sum their outputs
class Parallel(nn.Module):
    """Apply every module in `fns` to the same input and return the sum of the results."""

    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        return sum(fn(x) for fn in self.fns)
# attention from learned latents onto a sequence (perceiver style)
class PerceiverAttention(nn.Module):
    """Attention where queries come from `latents` and keys / values come from
    the concatenation of the input sequence and the latents.

    Queries and keys are l2-normalized and multiplied by learned per-dimension
    scales before the similarity is computed and multiplied by `scale`.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        scale = 8
    ):
        super().__init__()
        self.scale = scale
        self.heads = heads
        inner_dim = dim_head * heads

        # separate norms for the sequence and for the latents
        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        # learned scales applied to the normalized queries and keys
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, latents, mask = None):
        x = self.norm(x)
        latents = self.norm_latents(latents)

        h = self.heads

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h = h)

        q = self.to_q(latents)

        # the key / values are attended over both the sequence and the latents
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # qk rmsnorm
        q = l2norm(q) * self.q_scale
        k = l2norm(k) * self.k_scale

        # similarities and masking
        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.scale

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            # latents are appended after the sequence and are never masked out
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention (softmax computed in float32, then cast back)
        attn = sim.softmax(dim = -1, dtype = torch.float32).to(sim.dtype)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)
# resample a variable-length sequence into a fixed number of latents
class PerceiverResampler(nn.Module):
    """Stack of PerceiverAttention + FeedForward layers that turns a sequence
    into a fixed set of latent tokens.

    `num_latents` learned latents are used; when `num_latents_mean_pooled` is
    positive, that many additional latents are derived from the mean-pooled
    sequence and prepended to the learned ones.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4
    ):
        super().__init__()
        # learned absolute positions for the incoming sequence
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # learned latent tokens
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # optional projection from the mean-pooled sequence to extra latents
        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            layer = nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ])
            self.layers.append(layer)

    def forward(self, x, mask = None):
        batch, seq_len = x.shape[0], x.shape[1]
        device = x.device

        positions = torch.arange(seq_len, device = device)
        x_with_pos = x + self.pos_emb(positions)

        latents = repeat(self.latents, 'n d -> b n d', b = batch)

        if exists(self.to_latents_from_mean_pooled_seq):
            # NOTE(review): the pooling below uses an all-True mask rather than the
            # `mask` argument, so every position contributes to the mean - confirm
            # whether that is intended before changing it
            all_true = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool)
            meanpooled_seq = masked_mean(x, dim = 1, mask = all_true)
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        # each layer is residual attention followed by a residual feedforward
        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents
# self attention with a single shared key / value head
class Attention(nn.Module):
    """Multi-query attention: per-head queries, one key / value head of size `dim_head`.

    A learned null key / value pair is always prepended, and when
    `context_dim` is given, projected context tokens can be prepended to the
    keys and values as well.
    """

    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        context_dim = None,
        scale = 8
    ):
        super().__init__()
        self.scale = scale
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        if exists(context_dim):
            self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2))
        else:
            self.to_context = None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context = None, mask = None, attn_bias = None):
        batch, _seq_len = x.shape[:2]

        x = self.norm(x)

        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim = -1)

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net
        null_k, null_v = self.null_kv.unbind(dim = -2)
        null_k = repeat(null_k, 'd -> b 1 d', b = batch)
        null_v = repeat(null_v, 'd -> b 1 d', b = batch)

        k = torch.cat((null_k, k), dim = -2)
        v = torch.cat((null_v, v), dim = -2)

        # add text conditioning, if present
        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # qk rmsnorm
        q = l2norm(q) * self.q_scale
        k = l2norm(k) * self.k_scale

        # calculate query / key similarities
        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale

        # relative positional encoding (T5 style)
        if exists(attn_bias):
            sim = sim + attn_bias

        # masking
        # NOTE(review): the mask is padded by one position only (the null key / value);
        # confirm against callers how it should line up when `context` is also given
        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention (softmax computed in float32, then cast back)
        attn = sim.softmax(dim = -1, dtype = torch.float32).to(sim.dtype)

        # aggregate values
        out = einsum('b h i j, b j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# upsampling by nearest-neighbor interpolation followed by a convolution
def Upsample(dim, dim_out = None):
    """Build a 2x nearest-neighbor upsample followed by a 3x3 conv from `dim` to `dim_out`.

    `dim_out` falls back to `dim` when not given.
    """
    dim_out = default(dim_out, dim)

    layers = [
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, dim_out, 3, padding = 1),
    ]
    return nn.Sequential(*layers)
# upsampling by pixel shuffle
class PixelShuffleUpsample(nn.Module):
    """
    code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf

    A 1x1 conv expands the channels to `dim_out * 4`, followed by SiLU and a
    pixel shuffle with upscale factor 2.
    """

    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(conv, nn.SiLU(), nn.PixelShuffle(2))

        self.init_conv_(conv)

    def init_conv_(self, conv):
        """Initialize `conv` so each group of 4 consecutive output channels shares one kaiming-uniform kernel, and zero its bias."""
        out_channels, in_channels, kernel_h, kernel_w = conv.weight.shape

        base_weight = torch.empty(out_channels // 4, in_channels, kernel_h, kernel_w)
        nn.init.kaiming_uniform_(base_weight)

        # duplicate every kernel 4 times along the output-channel axis
        tiled_weight = repeat(base_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(tiled_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)
# downsampling by pixel unshuffle followed by a 1x1 convolution
def Downsample(dim, dim_out = None):
    """Build a 2x downsample: fold each 2x2 spatial patch into channels, then a 1x1 conv.

    `dim_out` falls back to `dim` when not given.
    """
    # https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample
    # named SP-conv in the paper, but basically a pixel unshuffle
    dim_out = default(dim_out, dim)

    # moves every 2x2 patch into the channel axis: channels become dim * 4
    unshuffle = Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2)

    # projects the dim * 4 channels to dim_out
    project = nn.Conv2d(dim * 4, dim_out, 1)

    return nn.Sequential(unshuffle, project)
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of a 1-d tensor of positions.

    The output concatenates the sines and cosines of the positions multiplied
    by `dim // 2` geometrically spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2

        # frequencies decay geometrically from 1 down to 1 / 10000
        log_step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device = x.device) * -log_step)

        # outer product of positions and frequencies
        angles = rearrange(x, 'i -> i 1') * rearrange(freqs, 'j -> 1 j')

        return torch.cat((angles.sin(), angles.cos()), dim = -1)
class LearnedSinusoidalPosEmb(nn.Module):
    """ following @crowsonkb 's lead with learned sinusoidal pos emb

    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    `dim` must be even; `dim // 2` frequencies are learned. The output has
    `dim + 1` features: the raw input followed by the sines and cosines.
    """

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        # learned frequencies, randomly initialized
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi

        # raw input first, then the sine and cosine features
        return torch.cat((x, freqs.sin(), freqs.cos()), dim = -1)
class Block(nn.Module):
    """Group norm -> optional scale / shift -> SiLU -> 3x3 convolution.

    With `norm` disabled the group norm is replaced by `Identity`.
    """

    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()

        if norm:
            self.groupnorm = nn.GroupNorm(groups, dim)
        else:
            self.groupnorm = Identity()

        self.activation = nn.SiLU()
        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)

    def forward(self, x, scale_shift = None):
        normed = self.groupnorm(x)

        if exists(scale_shift):
            # modulation supplied by the caller as a (scale, shift) pair
            scale, shift = scale_shift
            normed = normed * (scale + 1) + shift

        activated = self.activation(normed)
        return self.project(activated)
class ResnetBlock(nn.Module):
    """Residual block of two `Block`s with optional time and cross-attention conditioning.

    - `time_cond_dim`: when given, a SiLU + Linear maps the time embedding to a
      scale / shift pair applied inside the second block.
    - `cond_dim`: when given, cross attention onto the conditioning tokens is
      applied between the two blocks (`LinearCrossAttention` if `linear_attn`).
    - `use_gca`: when set, the output is gated by `GlobalContext`; otherwise the
      gate is the constant 1.
    """

    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        # projection of the time conditioning into scale and shift
        self.time_mlp = None

        if exists(time_cond_dim):
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        # cross attention onto the conditioning tokens
        self.cross_attn = None

        if exists(cond_dim):
            attn_klass = LinearCrossAttention if linear_attn else CrossAttention

            self.cross_attn = attn_klass(
                dim = dim_out,
                context_dim = cond_dim,
                **attn_kwargs
            )

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        # global context gating
        if use_gca:
            self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out)
        else:
            self.gca = Always(1)

        # 1x1 conv on the skip path only when the channel count changes
        if dim != dim_out:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = Identity()

    def forward(self, x, time_emb = None, cond = None):
        scale_shift = None

        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            # first half of the channels is the scale, second half the shift
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x)

        if exists(self.cross_attn):
            assert exists(cond)

            # flatten the spatial dims into a token axis for attention, then restore them
            h = rearrange(h, 'b c h w -> b h w c')
            h, packed_shape = pack([h], 'b * c')

            h = self.cross_attn(h, context = cond) + h

            h, = unpack(h, packed_shape, 'b * c')
            h = rearrange(h, 'b h w c -> b c h w')

        h = self.block2(h, scale_shift = scale_shift)

        h = h * self.gca(h)

        return h + self.res_conv(x)
class CrossAttention(nn.Module):
    """Multi-head cross attention from `x` onto `context` tokens.

    Queries are projected from `x`, keys and values from `context`. A learned
    null key / value pair is prepended to the keys and values. Queries and
    keys are l2-normalized and multiplied by learned per-dimension scales, and
    the similarity is multiplied by `scale`.

    Args:
        dim: feature dimension of `x` (and of the output).
        context_dim: feature dimension of `context`; defaults to `dim`.
        dim_head: dimension of each attention head.
        heads: number of attention heads.
        norm_context: when True, `context` is layer-normed before projection.
        scale: multiplier applied to the query / key similarities.
    """

    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale
        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        # learned null key / value pair
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        # learned scales applied to the normalized queries and keys
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        """Attend from `x` onto `context`.

        `mask`, when given, is a boolean `b j` tensor over the context
        positions; False positions are excluded from attention. The null
        key / value slot is always attendable.
        """
        b = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net
        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # cosine sim attention
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # similarities
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # masking
        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            # one extra always-True slot in front for the null key / value
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention weights (softmax computed in float32, then cast back)
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        # aggregate values
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class LinearCrossAttention(CrossAttention):
    """Linear-attention variant of `CrossAttention`.

    Reuses the parent's projections and parameters but replaces the softmax
    over query / key similarities with linear attention: queries are
    softmaxed over the feature axis, keys over the sequence axis, and values
    are aggregated through the resulting key / value summary.
    """

    def forward(self, x, context, mask = None):
        """Attend from `x` onto `context`.

        `mask`, when given, is a boolean `b n` tensor over the context
        positions; False positions are excluded.
        """
        b = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        # heads are folded into the batch axis
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net
        nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking
        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            # one extra always-True slot in front for the null key / value
            mask = F.pad(mask, (1, 0), value = True)
            # k and v carry a fused (b h) leading axis, so the mask is repeated
            # per head in the same order; a plain 'b n 1' mask only broadcasts
            # when the batch size or the head count is 1
            mask = repeat(mask, 'b n -> (b h) n 1', h = self.heads)
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)
class LinearAttention(nn.Module):
    """Linear attention over a 2-d feature map, with optional context tokens.

    Queries, keys and values each come from dropout -> 1x1 conv -> depthwise
    3x3 conv. When `context_dim` is given, projected context tokens can be
    appended to the keys and values.
    """

    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        def make_projection():
            # dropout, pointwise conv, then a depthwise 3x3 conv
            return nn.Sequential(
                nn.Dropout(dropout),
                nn.Conv2d(dim, inner_dim, 1, bias = False),
                nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
            )

        self.to_q = make_projection()
        self.to_k = make_projection()
        self.to_v = make_projection()

        if exists(context_dim):
            self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False))
        else:
            self.to_context = None

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        h = self.heads
        x, y = fmap.shape[-2:]

        fmap = self.norm(fmap)

        def to_tokens(t):
            # fold heads into the batch axis and flatten the spatial dims
            return rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h)

        q = to_tokens(self.to_q(fmap))
        k = to_tokens(self.to_k(fmap))
        v = to_tokens(self.to_v(fmap))

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck = rearrange(ck, 'b n (h d) -> (b h) n d', h = h)
            cv = rearrange(cv, 'b n (h d) -> (b h) n d', h = h)
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        # linear attention: softmax queries over features, keys over positions
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)
class GlobalContext(nn.Module):
    """ basically a superior form of squeeze-excitation that is attention-esque

    A 1x1 conv produces one attention logit per spatial position; the softmax
    of those logits pools the feature map into a single vector per sample,
    which a small conv MLP ending in a sigmoid turns into per-channel gates.

    Args:
        dim_in: number of channels of the input feature map.
        dim_out: number of gate channels produced.
    """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        # single-channel attention logits over the spatial positions
        self.to_k = nn.Conv2d(dim_in, 1, 1)

        # bottleneck width: half of dim_out, but never below 3
        hidden_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        context = self.to_k(x)

        # flatten everything after the channel axis into one position axis
        x, context = map(lambda t: rearrange(t, 'b n ... -> b n (...)'), (x, context))

        # attention-weighted pooling of the features over all positions
        out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)

        # add a trailing singleton axis so the 1x1 convs see a 4-d input
        out = rearrange(out, '... -> ... 1')
        return self.net(out)
# feedforward over the last dimension
def FeedForward(dim, mult = 2):
    """Build LayerNorm -> Linear -> GELU -> LayerNorm -> Linear, with hidden width `int(dim * mult)`.

    Both linear layers are bias-free.
    """
    hidden_dim = int(dim * mult)

    layers = [
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, dim, bias = False),
    ]
    return nn.Sequential(*layers)
# feedforward over the channel dimension of a feature map
def ChanFeedForward(dim, mult = 2):  # in paper, it seems for self attention layers they did feedforwards with twice channel width
    """Build ChanLayerNorm -> 1x1 conv -> GELU -> ChanLayerNorm -> 1x1 conv, with hidden width `int(dim * mult)`.

    Both convolutions are bias-free.
    """
    hidden_dim = int(dim * mult)

    layers = [
        ChanLayerNorm(dim),
        nn.Conv2d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(hidden_dim),
        nn.Conv2d(hidden_dim, dim, 1, bias = False),
    ]
    return nn.Sequential(*layers)
# transformer block: `depth` layers of attention + feedforward over a feature map
class TransformerBlock(nn.Module):
    """Apply `depth` residual (Attention, FeedForward) layers to a `b c h w` feature map.

    The spatial dimensions are flattened into a token axis for the layers and
    restored afterwards.
    """

    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None
    ):
        super().__init__()

        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                FeedForward(dim = dim, mult = ff_mult)
            ])
            for _ in range(depth)
        ])

    def forward(self, x, context = None):
        # channels last, then flatten the spatial dims into one token axis
        tokens = rearrange(x, 'b c h w -> b h w c')
        tokens, packed_shape = pack([tokens], 'b * c')

        for attn, ff in self.layers:
            tokens = attn(tokens, context = context) + tokens
            tokens = ff(tokens) + tokens

        # restore the spatial layout
        tokens, = unpack(tokens, packed_shape, 'b * c')
        return rearrange(tokens, 'b h w c -> b c h w')
# transformer block built from linear attention and channel feedforwards
class LinearAttentionTransformerBlock(nn.Module):
    """Apply `depth` residual (LinearAttention, ChanFeedForward) layers to a feature map.

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        **kwargs
    ):
        super().__init__()

        self.layers = nn.ModuleList([
            nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                ChanFeedForward(dim = dim, mult = ff_mult)
            ])
            for _ in range(depth)
        ])

    def forward(self, x, context = None):
        out = x
        for attn, ff in self.layers:
            out = attn(out, context = context) + out
            out = ff(out) + out
        return out
# cross embed layer: parallel convolutions with different kernel sizes
class CrossEmbedLayer(nn.Module):
    """Run several strided convolutions with different kernel sizes and concatenate them.

    Every kernel size must have the same parity as `stride`. Kernel sizes are
    sorted ascending; each successive kernel gets half the channels of the
    previous one (starting at `dim_out / 2`), and the largest kernel receives
    whatever remains so the channel counts sum to `dim_out`.
    """

    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        assert all((kernel % 2) == (stride % 2) for kernel in kernel_sizes)

        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # calculate the dimension at each scale
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales.append(dim_out - sum(dim_scales))

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            padding = (kernel - stride) // 2
            conv = nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = padding)
            self.convs.append(conv)

    def forward(self, x):
        fmaps = tuple(conv(x) for conv in self.convs)
        return torch.cat(fmaps, dim = 1)
# combine feature maps from several upsample stages into one tensor
class UpsampleCombiner(nn.Module):
    """Optionally concatenate resized, convolved feature maps onto the input.

    When `enabled` is False the module is a pass-through and `dim_out` equals
    `dim`. Otherwise one `Block` is created per `(dim_in, dim_out)` pair and
    `dim_out` is `dim` plus the sum of `dim_outs`.
    """

    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if not self.enabled:
            # pass-through: no convolutions are created
            self.dim_out = dim
            return

        blocks = [Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)]
        self.fmap_convs = nn.ModuleList(blocks)

        extra_dim = sum(dim_outs) if len(dim_outs) > 0 else 0
        self.dim_out = dim + extra_dim

    def forward(self, x, fmaps = None):
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        # `enabled` is checked first because `fmap_convs` only exists when enabled
        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        # bring every feature map to the spatial size of `x`, then convolve it
        outs = []
        for fmap, conv in zip(fmaps, self.fmap_convs):
            resized = resize_image_to(fmap, target_size)
            outs.append(conv(resized))

        # concatenate along the channel axis
        return torch.cat((x, *outs), dim = 1)
# 定义一个名为 Unet 的类,继承自 nn.Module
class Unet(nn.Module):
# 初始化方法,设置类的属性
def __init__(
self,
*,
dim,
text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME), # 默认文本嵌入维度
num_resnet_blocks = 1, # ResNet 块的数量
cond_dim = None, # 条件维度
num_image_tokens = 4, # 图像令牌数量
num_time_tokens = 2, # 时间令牌数量
learned_sinu_pos_emb_dim = 16, # 学习的正弦位置编码维度
out_dim = None, # 输出维度
dim_mults=(1, 2, 4, 8), # 维度倍增
cond_images_channels = 0, # 条件图像通道数
channels = 3, # 通道数
channels_out = None, # 输出通道数
attn_dim_head = 64, # 注意力头维度
attn_heads = 8, # 注意力头数量
ff_mult = 2., # FeedForward 层倍增因子
lowres_cond = False, # 低分辨率条件
layer_attns = True, # 层间注意力
layer_attns_depth = 1, # 层间注意力深度
layer_mid_attns_depth = 1, # 中间层注意力深度
layer_attns_add_text_cond = True, # 是否使用文本嵌入来条件化自注意力块
attend_at_middle = True, # 是否在瓶颈处进行注意力
layer_cross_attns = True, # 层间交叉注意力
use_linear_attn = False, # 是否使用线性注意力
use_linear_cross_attn = False, # 是否使用线性交叉注意力
cond_on_text = True, # 是否在文本上进行条件化
max_text_len = 256, # 最大文本长度
init_dim = None, # 初始化维度
resnet_groups = 8, # ResNet 组数
init_conv_kernel_size = 7, # 初始卷积核大小
init_cross_embed = True, # 初始化交叉嵌入
init_cross_embed_kernel_sizes = (3, 7, 15), # 初始化交叉嵌入的卷积核大小
cross_embed_downsample = False, # 交叉嵌入下采样
cross_embed_downsample_kernel_sizes = (2, 4), # 交叉嵌入下采样的卷积核大小
attn_pool_text = True, # 注意力池化文本
attn_pool_num_latents = 32, # 注意力池化潜在数
dropout = 0., # 丢弃率
memory_efficient = False, # 内存效率
init_conv_to_final_conv_residual = False, # 初始卷积到最终卷积的残差连接
use_global_context_attn = True, # 使用全局上下文注意力
scale_skip_connection = True, # 缩放跳跃连接
final_resnet_block = True, # 最终 ResNet 块
final_conv_kernel_size = 3, # 最终卷积核大小
self_cond = False, # 自条件
resize_mode = 'nearest', # 调整模式
combine_upsample_fmaps = False, # 合并所有上采样块的特征图
pixel_shuffle_upsample = True, # 像素混洗上采样
# 如果当前 Unet 的设置不正确,重新使用正确的设置重新初始化 Unet
def cast_model_parameters(
    self,
    *,
    lowres_cond,
    text_embed_dim,
    channels,
    channels_out,
    cond_on_text
):
    """Return a unet whose key settings match the requested ones.

    If all five settings already match, ``self`` is returned. Otherwise a new
    instance of the same class is built from the stored constructor arguments
    (``self._locals``) with the five settings overridden.
    """
    comparisons = (
        (lowres_cond, self.lowres_cond),
        (channels, self.channels),
        (cond_on_text, self.cond_on_text),
        (text_embed_dim, self._locals['text_embed_dim']),
        (channels_out, self.channels_out),
    )

    if all(requested == current for requested, current in comparisons):
        return self

    init_kwargs = dict(self._locals)
    init_kwargs.update(
        lowres_cond = lowres_cond,
        text_embed_dim = text_embed_dim,
        channels = channels,
        channels_out = channels_out,
        cond_on_text = cond_on_text
    )

    return self.__class__(**init_kwargs)
# 返回完整 Unet 配置及其参数状态字典的方法
def to_config_and_state_dict(self):
    """Return ``(config, state_dict)``: the stored constructor arguments and the weights."""
    config = self._locals
    weights = self.state_dict()
    return config, weights
# 从配置和状态字典中重新创建 Unet 的类方法
@classmethod
def from_config_and_state_dict(klass, config, state_dict):
    """Build an instance of ``klass`` from ``config`` and load ``state_dict`` into it."""
    instance = klass(**config)
    instance.load_state_dict(state_dict)
    return instance
# 将 Unet 持久化到磁盘的方法
def persist_to_file(self, path):
    """Save the unet's config and weights to ``path``, creating parent directories.

    The file holds a dict with the keys ``config`` and ``state_dict``.
    """
    path = Path(path)
    parent_dir = path.parents[0]
    parent_dir.mkdir(parents = True, exist_ok = True)

    config, state_dict = self.to_config_and_state_dict()
    pkg = {'config': config, 'state_dict': state_dict}
    torch.save(pkg, str(path))
# 从使用 `persist_to_file` 保存的文件重新创建 Unet 的类方法
@classmethod
def hydrate_from_file(klass, path):
    """Recreate a unet from a file written by ``persist_to_file``.

    The instance is built through ``klass`` (not a hard-coded base class), so
    calling this on a subclass yields an instance of that subclass.

    Raises:
        AssertionError: if ``path`` does not exist or the file lacks the
            ``config`` / ``state_dict`` entries.
    """
    path = Path(path)
    assert path.exists(), f'checkpoint {str(path)} does not exist'

    # Load onto CPU: the freshly constructed unet receives the weights through
    # load_state_dict anyway, and this lets checkpoints saved from a GPU be
    # opened on machines without one.
    # NOTE(review): torch.load unpickles the file - only open trusted checkpoints.
    pkg = torch.load(str(path), map_location = 'cpu')

    assert 'config' in pkg and 'state_dict' in pkg, 'checkpoint must contain both config and state_dict'
    config, state_dict = pkg['config'], pkg['state_dict']

    return klass.from_config_and_state_dict(config, state_dict)
# 使用分类器自由指导进行前向传播
def forward_with_cond_scale(
    self,
    *args,
    cond_scale = 1.,
    **kwargs
):
    """Forward pass with classifier-free guidance.

    With ``cond_scale == 1`` this is a single plain ``forward`` call. Otherwise
    a second pass is run with ``cond_drop_prob = 1.`` and the result is
    ``null + (cond - null) * cond_scale``.
    """
    cond_out = self.forward(*args, **kwargs)

    if cond_scale != 1:
        null_out = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_out + (cond_out - null_out) * cond_scale

    return cond_out
# 普通的前向传播方法
def forward(
self,
x,
time,
*,
lowres_cond_img = None,
lowres_noise_times = None,
text_embeds = None,
text_mask = None,
cond_images = None,
self_cond = None,
cond_drop_prob = 0.
# 定义一个空的 Unet 类
class NullUnet(nn.Module):
    """Placeholder unet: ignores its configuration and returns its input unchanged."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lowres_cond = False
        # a single zero-valued parameter is registered on the module
        self.dummy_parameter = nn.Parameter(torch.zeros(1))

    def cast_model_parameters(self, *args, **kwargs):
        """No settings to adapt; always returns ``self``."""
        return self

    def forward(self, x, *args, **kwargs):
        """Return ``x`` as-is, ignoring every other argument."""
        return x
# 预定义的 Unet 类,配置与论文附录中的超参数对应
class BaseUnet64(Unet):
    """Unet with preset hyperparameters; explicit keyword arguments override the presets."""

    def __init__(self, *args, **kwargs):
        settings = dict(
            dim = 512,
            dim_mults = (1, 2, 3, 4),
            num_resnet_blocks = 3,
            layer_attns = (False, True, True, True),
            layer_cross_attns = (False, True, True, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = False
        )
        # caller-supplied values win over the presets
        settings.update(kwargs)
        super().__init__(*args, **settings)
class SRUnet256(Unet):
    """Unet with preset hyperparameters; explicit keyword arguments override the presets."""

    def __init__(self, *args, **kwargs):
        settings = dict(
            dim = 128,
            dim_mults = (1, 2, 4, 8),
            num_resnet_blocks = (2, 4, 8, 8),
            layer_attns = (False, False, False, True),
            layer_cross_attns = (False, False, False, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = True
        )
        # caller-supplied values win over the presets
        settings.update(kwargs)
        super().__init__(*args, **settings)
class SRUnet1024(Unet):
    """Unet with preset hyperparameters; explicit keyword arguments override the presets."""

    def __init__(self, *args, **kwargs):
        settings = dict(
            dim = 128,
            dim_mults = (1, 2, 4, 8),
            num_resnet_blocks = (2, 4, 8, 8),
            layer_attns = False,
            layer_cross_attns = (False, False, False, True),
            attn_heads = 8,
            ff_mult = 2.,
            memory_efficient = True
        )
        # caller-supplied values win over the presets
        settings.update(kwargs)
        super().__init__(*args, **settings)
# 主要的 Imagen 类,是来自 Ho 等人的级联 DDPM
class Imagen(nn.Module):
def __init__(
self,
unets,
*,
image_sizes, # 用于级联 ddpm,每个阶段的图像大小
text_encoder_name = DEFAULT_T5_NAME,
text_embed_dim = None,
channels = 3,
timesteps = 1000,
cond_drop_prob = 0.1,
loss_type = 'l2',
noise_schedules = 'cosine',
pred_objectives = 'noise',
random_crop_sizes = None,
lowres_noise_schedule = 'linear',
lowres_sample_noise_level = 0.2, # 论文中提到的一个新技巧,对低分辨率条件图像添加噪声,并在采样时将其固定到一定水平(0.1 或 0.3)- Unet 也被设计为在这个噪声水平上进行条件化
per_sample_random_aug_noise_level = False, # 不清楚在进行增强噪声水平条件化时,每个批次元素是否接收随机的增强噪声值-由于 @marunine 的发现,关闭此功能
condition_on_text = True,
auto_normalize_img = True, # 是否自动处理将图像从 [0, 1] 规范化为 [-1, 1] 并自动恢复-如果要自己从数据加载器传入 [-1, 1] 范围的图像,则可以关闭此功能
dynamic_thresholding = True,
dynamic_thresholding_percentile = 0.95, # 通过查阅论文,不确定这是基于什么的
only_train_unet_number = None,
temporal_downsample_factor = 1,
resize_cond_video_frames = True,
resize_mode = 'nearest',
min_snr_loss_weight = True, # https://arxiv.org/abs/2303.09556
min_snr_gamma = 5
def force_unconditional_(self):
    """Switch this model, and every unet it holds, to unconditional mode in place."""
    self.condition_on_text = False
    self.unconditional = True

    for each_unet in self.unets:
        each_unet.cond_on_text = False
@property
def device(self):
    """Device reported by the ``_temp`` attribute."""
    anchor = self._temp
    return anchor.device
# 获取指定编号的 UNet 模型
def get_unet(self, unet_number):
    """Return the unet with the given 1-based number, placing it on ``self.device``.

    When the requested unet differs from the one recorded in
    ``unet_being_trained_index``, every other unet is moved to the CPU.
    """
    assert 0 < unet_number <= len(self.unets)
    index = unet_number - 1

    # swap the ModuleList for a plain python list holding the same unets
    if isinstance(self.unets, nn.ModuleList):
        unets_list = list(self.unets)
        delattr(self, 'unets')
        self.unets = unets_list

    if index != self.unet_being_trained_index:
        for unet_index, unet in enumerate(self.unets):
            if unet_index == index:
                unet.to(self.device)
            else:
                unet.to('cpu')

    self.unet_being_trained_index = index
    return self.unets[index]
# 将所有 UNet 模型重置到同一设备上
def reset_unets_all_one_device(self, device = None):
    """Re-wrap all unets in an ``nn.ModuleList`` and move them to one device.

    ``device`` falls back to ``self.device``. The tracked training index is
    reset to ``-1``.
    """
    device = default(device, self.device)

    all_unets = list(self.unets)
    self.unets = nn.ModuleList(all_unets)
    self.unets.to(device)

    self.unet_being_trained_index = -1
# 使用上下文管理器将指定编号的 UNet 模型移到 GPU 上
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
    """Context manager keeping only one unet on ``self.device``.

    Exactly one of ``unet_number`` (1-based) or ``unet`` must be given. On
    entry every unet is moved to the CPU and the selected one to
    ``self.device``. On exit each unet goes back to the device it was on
    before, also when the body of the ``with`` block raises.
    """
    assert exists(unet_number) ^ exists(unet)

    if exists(unet_number):
        unet = self.unets[unet_number - 1]

    cpu = torch.device('cpu')

    # remember where every unet lives so the placement can be restored
    original_devices = [module_device(each_unet) for each_unet in self.unets]

    # move unets one by one: works whether self.unets is an nn.ModuleList or
    # the plain list that get_unet replaces it with
    for each_unet in self.unets:
        each_unet.to(cpu)

    unet.to(self.device)

    try:
        yield
    finally:
        for each_unet, original_device in zip(self.unets, original_devices):
            each_unet.to(original_device)
# 重写 state_dict 函数
def state_dict(self, *args, **kwargs):
    """Gather all unets onto one device, then delegate to the parent ``state_dict``."""
    self.reset_unets_all_one_device()
    collected = super().state_dict(*args, **kwargs)
    return collected
# 重写 load_state_dict 函数
def load_state_dict(self, *args, **kwargs):
    """Gather all unets onto one device, then delegate to the parent ``load_state_dict``."""
    self.reset_unets_all_one_device()
    load_result = super().load_state_dict(*args, **kwargs)
    return load_result
# 高斯扩散方法
def p_mean_variance(
    self,
    unet,
    x,
    t,
    *,
    noise_scheduler,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    cond_video_frames = None,
    post_cond_video_frames = None,
    lowres_cond_img = None,
    self_cond = None,
    lowres_noise_times = None,
    cond_scale = 1.,
    model_output = None,
    t_next = None,
    pred_objective = 'noise',
    dynamic_threshold = True
):
    """Predict ``x_start`` and the posterior returned by the noise scheduler.

    The unet is only run when ``model_output`` is not supplied. Returns
    ``(noise_scheduler.q_posterior(...), x_start)``.

    Raises:
        AssertionError: guidance requested (``cond_scale != 1``) while
            ``self.can_classifier_guidance`` is falsy.
        ValueError: ``pred_objective`` is not ``noise``, ``x_start`` or ``v``.
    """
    assert cond_scale == 1. or self.can_classifier_guidance, 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

    # video conditioning is only forwarded when the model is in video mode
    if self.is_video:
        video_kwargs = dict(
            cond_video_frames = cond_video_frames,
            post_cond_video_frames = post_cond_video_frames,
        )
    else:
        video_kwargs = dict()

    def run_unet():
        return unet.forward_with_cond_scale(
            x,
            noise_scheduler.get_condition(t),
            text_embeds = text_embeds,
            text_mask = text_mask,
            cond_images = cond_images,
            cond_scale = cond_scale,
            lowres_cond_img = lowres_cond_img,
            self_cond = self_cond,
            lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times),
            **video_kwargs
        )

    # `default` only invokes run_unet when model_output is None
    pred = default(model_output, run_unet)

    # turn the raw prediction into an estimate of x_start
    if pred_objective == 'x_start':
        x_start = pred
    elif pred_objective == 'noise':
        x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
    elif pred_objective == 'v':
        x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
    else:
        raise ValueError(f'unknown objective {pred_objective}')

    if not dynamic_threshold:
        # static thresholding, applied in place
        x_start.clamp_(-1., 1.)
    else:
        # per-sample threshold: the configured quantile of |x_start|, at least 1
        flat_abs = rearrange(x_start, 'b ... -> b (...)').abs()
        s = torch.quantile(
            flat_abs,
            self.dynamic_thresholding_percentile,
            dim = -1
        )
        s.clamp_(min = 1.)
        s = right_pad_dims_to(x_start, s)
        x_start = x_start.clamp(-s, s) / s

    mean_and_variance = noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next)
    return mean_and_variance, x_start
# 无梯度计算
@torch.no_grad()
def p_sample(
    self,
    unet,
    x,
    t,
    *,
    noise_scheduler,
    t_next = None,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    cond_video_frames = None,
    post_cond_video_frames = None,
    cond_scale = 1.,
    self_cond = None,
    lowres_cond_img = None,
    lowres_noise_times = None,
    pred_objective = 'noise',
    dynamic_threshold = True
):
    """Take one reverse-diffusion step and return ``(sample, x_start)``.

    Noise is added to the model mean except on the last sampling step, which is
    detected through ``t_next == 0`` for ``GaussianDiffusionContinuousTimes``
    schedulers and ``t == 0`` otherwise.
    """
    batch, *_ = x.shape

    # video conditioning is only forwarded when the model is in video mode
    if self.is_video:
        video_kwargs = dict(
            cond_video_frames = cond_video_frames,
            post_cond_video_frames = post_cond_video_frames,
        )
    else:
        video_kwargs = dict()

    posterior, x_start = self.p_mean_variance(
        unet,
        x = x,
        t = t,
        t_next = t_next,
        noise_scheduler = noise_scheduler,
        text_embeds = text_embeds,
        text_mask = text_mask,
        cond_images = cond_images,
        cond_scale = cond_scale,
        lowres_cond_img = lowres_cond_img,
        self_cond = self_cond,
        lowres_noise_times = lowres_noise_times,
        pred_objective = pred_objective,
        dynamic_threshold = dynamic_threshold,
        **video_kwargs
    )
    model_mean, _, model_log_variance = posterior

    noise = torch.randn_like(x)

    if isinstance(noise_scheduler, GaussianDiffusionContinuousTimes):
        is_last_sampling_timestep = (t_next == 0)
    else:
        is_last_sampling_timestep = (t == 0)

    # 0 for samples on their last step, 1 otherwise; shaped (batch, 1, 1, ...)
    trailing_ones = (1,) * (len(x.shape) - 1)
    nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(batch, *trailing_ones)

    std = (0.5 * model_log_variance).exp()
    pred = model_mean + nonzero_mask * std * noise
    return pred, x_start
# 无梯度计算
@torch.no_grad()
def p_sample_loop(
    self,
    unet,
    shape,
    *,
    noise_scheduler,
    lowres_cond_img = None,
    lowres_noise_times = None,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    cond_video_frames = None,
    post_cond_video_frames = None,
    inpaint_images = None,
    inpaint_videos = None,
    inpaint_masks = None,
    inpaint_resample_times = 5,
    init_images = None,
    skip_steps = None,
    cond_scale = 1,
    pred_objective = 'noise',
    dynamic_threshold = True,
    use_tqdm = True
):
    """Run the full sampling loop for one unet and return the unnormalized result.

    Starts from gaussian noise of ``shape`` (plus ``init_images`` when given),
    then calls ``p_sample`` for every (time, next time) pair supplied by the
    noise scheduler, optionally skipping the first ``skip_steps`` pairs.

    Inpainting is active only when both an inpaint image/video and
    ``inpaint_masks`` are given; each time step is then repeated
    ``inpaint_resample_times`` times, re-noising between repeats.
    """
    device = self.device

    batch = shape[0]
    img = torch.randn(shape, device = device)

    # a 5-element shape is treated as video; the frame count is the third-to-last entry
    is_video = len(shape) == 5
    frames = shape[-3] if is_video else None
    resize_kwargs = dict(target_frames = frames) if exists(frames) else dict()

    # optional initialization: the given images are added onto the noise in place
    if exists(init_images):
        img += init_images

    # latest x_start estimate, fed back as self-conditioning when the unet uses it
    x_start = None

    # inpainting setup - inpaint_videos takes precedence over inpaint_images
    inpaint_images = default(inpaint_videos, inpaint_images)

    has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
    resample_times = inpaint_resample_times if has_inpainting else 1

    if has_inpainting:
        inpaint_images = self.normalize_img(inpaint_images)
        inpaint_images = self.resize_to(inpaint_images, shape[-1], **resize_kwargs)
        # masks get a singleton axis after the batch axis, are resized as floats, then cast back to bool
        inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1], **resize_kwargs).bool()

    # sampling time pairs from the scheduler
    timesteps = noise_scheduler.get_sampling_timesteps(batch, device = device)

    # optionally drop the first `skip_steps` pairs
    skip_steps = default(skip_steps, 0)
    timesteps = timesteps[skip_steps:]

    # video conditioning is only forwarded when the model is in video mode
    video_kwargs = dict()
    if self.is_video:
        video_kwargs = dict(
            cond_video_frames = cond_video_frames,
            post_cond_video_frames = post_cond_video_frames,
        )

    for times, times_next in tqdm(timesteps, desc = 'sampling loop time step', total = len(timesteps), disable = not use_tqdm):
        is_last_timestep = times_next == 0

        # r counts down, so r == 0 marks the final repeat of this time step
        for r in reversed(range(resample_times)):
            is_last_resample_step = r == 0

            if has_inpainting:
                # overwrite the masked region with the reference noised to the current time
                noised_inpaint_images, *_ = noise_scheduler.q_sample(inpaint_images, t = times)
                img = img * ~inpaint_masks + noised_inpaint_images * inpaint_masks

            self_cond = x_start if unet.self_cond else None

            img, x_start = self.p_sample(
                unet,
                img,
                times,
                t_next = times_next,
                text_embeds = text_embeds,
                text_mask = text_mask,
                cond_images = cond_images,
                cond_scale = cond_scale,
                self_cond = self_cond,
                lowres_cond_img = lowres_cond_img,
                lowres_noise_times = lowres_noise_times,
                noise_scheduler = noise_scheduler,
                pred_objective = pred_objective,
                dynamic_threshold = dynamic_threshold,
                **video_kwargs
            )

            # between repeats, push the sample from times_next back to times;
            # entries already at their last time step keep the un-renoised sample
            if has_inpainting and not (is_last_resample_step or torch.all(is_last_timestep)):
                renoised_img = noise_scheduler.q_sample_from_to(img, times_next, times)

                img = torch.where(
                    self.right_pad_dims_to_datatype(is_last_timestep),
                    img,
                    renoised_img
                )

    # clamp the final sample to [-1, 1] in place
    img.clamp_(-1., 1.)

    # final inpainting: paste the (clean) reference back into the masked region
    if has_inpainting:
        img = img * ~inpaint_masks + inpaint_images * inpaint_masks

    unnormalize_img = self.unnormalize_img(img)
    return unnormalize_img
# 禁用梯度计算
@torch.no_grad()
# 设置评估模式装饰器
@eval_decorator
# 设置类型检查装饰器
@beartype
# 定义一个方法用于生成样本
def sample(
self,
texts: List[str] = None, # 文本列表,默认为 None
text_masks = None, # 文本掩码,默认为 None
text_embeds = None, # 文本嵌入,默认为 None
video_frames = None, # 视频帧,默认为 None
cond_images = None, # 条件图像,默认为 None
cond_video_frames = None, # 条件视频帧,默认为 None
post_cond_video_frames = None, # 后置条件视频帧,默认为 None
inpaint_videos = None, # 修复视频,默认为 None
inpaint_images = None, # 修复图像,默认为 None
inpaint_masks = None, # 修复掩码,默认为 None
inpaint_resample_times = 5, # 修复重采样次数,默认为 5
init_images = None, # 初始图像,默认为 None
skip_steps = None, # 跳过步骤,默认为 None
batch_size = 1, # 批量大小,默认为 1
cond_scale = 1., # 条件比例,默认为 1.0
lowres_sample_noise_level = None, # 低分辨率采样噪声级别,默认为 None
start_at_unet_number = 1, # 开始于 Unet 编号,默认为 1
start_image_or_video = None, # 开始图像或视频,默认为 None
stop_at_unet_number = None, # 停止于 Unet 编号,默认为 None
return_all_unet_outputs = False, # 返回所有 Unet 输出,默认为 False
return_pil_images = False, # 返回 PIL 图像,默认为 False
device = None, # 设备,默认为 None
use_tqdm = True, # 使用 tqdm,默认为 True
use_one_unet_in_gpu = True # 在 GPU 中使用一个 Unet,默认为 True
# 定义一个方法用于计算损失
@beartype
def p_losses(
self,
unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel], # Unet 对象,默认为 None
x_start, # 起始值
times, # 时间
*,
noise_scheduler, # 噪声调度器
lowres_cond_img = None, # 低分辨率条件图像,默认为 None
lowres_aug_times = None, # 低分辨率增强次数,默认为 None
text_embeds = None, # 文本嵌入,默认为 None
text_mask = None, # 文本掩码,默认为 None
cond_images = None, # 条件图像,默认为 None
noise = None, # 噪声,默认为 None
times_next = None, # 下一个时间,默认为 None
pred_objective = 'noise', # 预测目标,默认为 'noise'
min_snr_gamma = None, # 最小信噪比伽马,默认为 None
random_crop_size = None, # 随机裁剪大小,默认为 None
**kwargs # 其他关键字参数
# 定义一个方法用于前向传播
@beartype
def forward(
self,
images, # 图像或视频
unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None, # Unet 对象,默认为 None
texts: List[str] = None, # 文本列表,默认为 None
text_embeds = None, # 文本嵌入,默认为 None
text_masks = None, # 文本掩码,默认为 None
unet_number = None, # Unet 编号,默认为 None
cond_images = None, # 条件图像,默认为 None
**kwargs # 其他关键字参数
.\lucidrains\imagen-pytorch\imagen_pytorch\imagen_video.py
# 导入数学、操作符、函数工具等模块
import math
import operator
import functools
from tqdm.auto import tqdm
from functools import partial, wraps
from pathlib import Path
# 导入 PyTorch 相关模块
import torch
import torch.nn.functional as F
from torch import nn, einsum
# 导入 einops 相关模块
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# 导入自定义模块
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
# 辅助函数
# 检查值是否存在
def exists(val):
    """Return True for every value except ``None``."""
    return not (val is None)
# 返回输入值
def identity(t, *args, **kwargs):
    """Return the first argument untouched; any extra arguments are ignored."""
    result = t
    return result
# 返回数组的第一个元素,如果数组为空则返回默认值
def first(arr, d = None):
    """Return ``arr[0]``, or ``d`` when ``arr`` has length zero."""
    return d if len(arr) == 0 else arr[0]
# 检查一个数是否能被另一个数整除
def divisible_by(numer, denom):
    """Return whether ``numer`` leaves no remainder when divided by ``denom``."""
    remainder = numer % denom
    return remainder == 0
# 可能执行函数,如果输入值不存在则直接返回
def maybe(fn):
    """Wrap ``fn`` so that a ``None`` argument is returned as-is instead of being passed on."""
    @wraps(fn)
    def inner(x):
        return fn(x) if exists(x) else x
    return inner
# 仅执行一次函数,用于打印信息
def once(fn):
    """Wrap ``fn`` so that only the first call is executed.

    The first call forwards all positional and keyword arguments to ``fn`` and
    returns its result; every later call does nothing and returns ``None``.
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return
        # flag is set before calling, so a failing first call is not retried
        called = True
        return fn(*args, **kwargs)

    return inner
# `print` wrapped with `once`: only the first call prints, later calls are no-ops
print_once = once(print)
# 返回默认值或默认函数的值
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    A callable ``d`` is invoked (with no arguments) to produce the fallback.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
# 将输入值转换为元组
def cast_tuple(val, length = None):
    """Coerce ``val`` to a tuple.

    Lists are converted, tuples pass through, and any other value is repeated
    ``length`` times (once when ``length`` is ``None``). When ``length`` is
    given, the result must have exactly that many elements.
    """
    if isinstance(val, list):
        val = tuple(val)

    if isinstance(val, tuple):
        output = val
    else:
        repeats = default(length, 1)
        output = (val,) * repeats

    if exists(length):
        assert len(output) == length

    return output
# 将 uint8 类型的图像转换为 float 类型
def cast_uint8_images_to_float(images):
    """Divide uint8 tensors by 255; tensors of any other dtype are returned unchanged."""
    if images.dtype == torch.uint8:
        return images / 255
    return images
# 获取模块的设备信息
def module_device(module):
    """Return the device of the module's first parameter.

    ``next`` raises ``StopIteration`` when the module has no parameters.
    """
    first_param = next(module.parameters())
    return first_param.device
# 初始化权重为零
def zero_init_(m):
    """Zero ``m.weight`` in place, and ``m.bias`` too when it is not ``None``."""
    nn.init.zeros_(m.weight)

    bias = m.bias
    if exists(bias):
        nn.init.zeros_(bias)
# 模型评估装饰器
def eval_decorator(fn):
    """Decorator that runs ``fn(model, ...)`` with the model in eval mode.

    The model's previous training flag is restored afterwards, including when
    ``fn`` raises, so a failed call cannot leave the model stuck in eval mode.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            model.train(was_training)
    return inner
# 将元组填充到指定长度
def pad_tuple_to_length(t, length, fillvalue = None):
    """Right-pad ``t`` with ``fillvalue`` up to ``length`` elements.

    ``t`` itself is returned when it is already at least ``length`` long.
    """
    shortfall = length - len(t)
    if shortfall <= 0:
        return t

    filler = (fillvalue,) * shortfall
    return (*t, *filler)
# 辅助类
# 简单的返回输入值的模块
class Identity(nn.Module):
    """Module that accepts any constructor/call arguments and returns its input unchanged."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        """Return ``x``; all other arguments are ignored."""
        passthrough = x
        return passthrough
# 创建序列模块
def Sequential(*modules):
    """Build an ``nn.Sequential`` from ``modules``, dropping every ``None`` entry."""
    present = [module for module in modules if exists(module)]
    return nn.Sequential(*present)
# 张量辅助函数
# 对数函数
def log(t, eps: float = 1e-12):
    """Natural log of ``t`` after clamping it from below at ``eps``."""
    clamped = t.clamp(min = eps)
    return torch.log(clamped)
# L2 归一化
def l2norm(t):
    """Normalize ``t`` along its last dimension with ``F.normalize``."""
    normalized = F.normalize(t, dim = -1)
    return normalized
# 将右侧维度填充到相同维度
def right_pad_dims_to(x, t):
    """Append trailing singleton dims to ``t`` until it has as many dims as ``x``.

    ``t`` is returned unchanged when it already has at least as many dims.
    """
    missing_dims = x.ndim - t.ndim
    if missing_dims <= 0:
        return t

    trailing_ones = (1,) * missing_dims
    return t.view(*t.shape, *trailing_ones)
# 带掩码的均值计算
def masked_mean(t, *, dim, mask = None):
    """Mean of ``t`` along ``dim``, counting only positions where ``mask`` is True.

    Without a mask this is a plain mean. The mask is rearranged with the
    pattern ``b n -> b n 1`` before being applied to ``t``.
    """
    if not exists(mask):
        return t.mean(dim = dim)

    # number of kept positions, computed on the mask as passed in
    denom = mask.sum(dim = dim, keepdim = True)

    expanded_mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~expanded_mask, 0.)

    # the clamp avoids a division by zero when nothing is kept
    total = masked_t.sum(dim = dim)
    return total / denom.clamp(min = 1e-5)
# 调整视频大小
def resize_video_to(
    video,
    target_image_size,
    target_frames = None,
    clamp_range = None,
    mode = 'nearest'
):
    """Interpolate ``video`` to ``(target_frames, target_image_size, target_image_size)``.

    ``target_frames`` defaults to ``video.shape[2]``. The input is returned
    untouched when its last three dims already match the target. When
    ``clamp_range`` is given the result is clamped to it.
    """
    target_frames = default(target_frames, video.shape[2])
    target_shape = (target_frames, target_image_size, target_image_size)

    current_shape = tuple(video.shape[-3:])
    if current_shape == target_shape:
        return video

    resized = F.interpolate(video, target_shape, mode = mode)

    if exists(clamp_range):
        resized = resized.clamp(*clamp_range)

    return resized
# 缩放视频时间
def scale_video_time(
    video,
    downsample_scale = 1,
    mode = 'nearest'
):
    """Reduce the number of frames of ``video`` by ``downsample_scale``.

    The image size (last dim) is kept; the frame count (third-to-last dim)
    must be divisible by the scale. A scale of 1 returns the input as-is.
    """
    if downsample_scale == 1:
        return video

    frames = video.shape[-3]
    image_size = video.shape[-1]

    assert divisible_by(frames, downsample_scale), f'trying to temporally downsample a conditioning video frames of length {frames} by {downsample_scale}, however it is not neatly divisible'

    return resize_video_to(
        video,
        image_size,
        target_frames = frames // downsample_scale,
        mode = mode
    )
# classifier free guidance functions
# 根据给定形状、概率和设备创建一个布尔类型的掩码
def prob_mask_like(shape, prob, device):
    """Boolean mask of ``shape`` whose entries are True with probability ``prob``.

    ``prob`` of exactly 1 or 0 yields an all-True / all-False mask without
    drawing random numbers.
    """
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)

    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)

    draws = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return draws < prob
# norms and residuals
# Layer normalization模块
class LayerNorm(nn.Module):
    """Layer norm over the last dimension with a learned gain and no bias.

    With ``stable = True`` the input is first divided by its (detached) maximum
    along the last dimension.
    """

    def __init__(self, dim, stable = False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim = -1, keepdim = True).detach()

        # a larger epsilon is used for anything that is not float32
        if x.dtype == torch.float32:
            eps = 1e-5
        else:
            eps = 1e-3

        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = -1, keepdim = True)

        centered = x - mean
        inv_std = (var + eps).rsqrt()
        return centered * inv_std * self.g
# 通道层规范化模块
class ChanLayerNorm(nn.Module):
    """Layer norm over dimension 1 with a learned gain of shape ``(1, dim, 1, 1, 1)``.

    With ``stable = True`` the input is first divided by its (detached) maximum
    along dimension 1.
    """

    def __init__(self, dim, stable = False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim = 1, keepdim = True).detach()

        # a larger epsilon is used for anything that is not float32
        if x.dtype == torch.float32:
            eps = 1e-5
        else:
            eps = 1e-3

        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)

        centered = x - mean
        inv_std = (var + eps).rsqrt()
        return centered * inv_std * self.g
# 始终返回相同值的类
class Always():
    """Callable that returns the value given at construction, whatever it is called with."""

    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        constant = self.val
        return constant
# 残差连接模块
class Residual(nn.Module):
    """Wraps ``fn`` and adds its input back onto its output."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``."""
        branch = self.fn(x, **kwargs)
        return branch + x
# 并行执行多个函数模块
class Parallel(nn.Module):
    """Applies every wrapped module to the same input and sums the results."""

    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        """Return the sum of ``fn(x)`` over all wrapped modules."""
        branch_outputs = []
        for fn in self.fns:
            branch_outputs.append(fn(x))
        return sum(branch_outputs)
# rearranging
# 时间为中心的重排模块
class RearrangeTimeCentric(nn.Module):
    """Runs ``fn`` on a ``(*, f, c)`` view of a ``b c f ...`` tensor.

    The input is rearranged so the ``f`` and ``c`` axes come last, all leading
    axes are packed into one, ``fn`` is applied, and the original layout is
    restored afterwards.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        channels_last = rearrange(x, 'b c f ... -> b ... f c')
        packed, packed_shapes = pack([channels_last], '* f c')

        processed = self.fn(packed)

        unpacked, = unpack(processed, packed_shapes, '* f c')
        return rearrange(unpacked, 'b ... f c -> b c f ...')
# attention pooling
# PerceiverAttention模块
class PerceiverAttention(nn.Module):
    """Attention in which latents query the concatenation of the input and the latents.

    Queries and keys are l2-normalized and multiplied by learned per-dimension
    scales; the similarities are multiplied by the fixed ``scale``.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        scale = 8
    ):
        super().__init__()
        self.scale = scale
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        out_proj = nn.Linear(inner_dim, dim, bias = False)
        out_norm = nn.LayerNorm(dim)
        self.to_out = nn.Sequential(out_proj, out_norm)

    def forward(self, x, latents, mask = None):
        """Attend from ``latents`` to ``[x, latents]`` and return the updated latents.

        ``mask`` covers the positions of ``x`` only; it is extended with True
        for the latent positions appended to the keys.
        """
        x = self.norm(x)
        latents = self.norm_latents(latents)

        h = self.heads

        q = self.to_q(latents)

        # keys and values come from the input followed by the latents themselves
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h = h)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # normalized queries / keys with learned scales
        q = l2norm(q) * self.q_scale
        k = l2norm(k) * self.k_scale

        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.scale

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            num_latents = latents.shape[-2]
            mask = F.pad(mask, (0, num_latents), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        attn = sim.softmax(dim = -1)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)
# 定义 PerceiverResampler 类,继承自 nn.Module
class PerceiverResampler(nn.Module):
    """Resamples a sequence into a fixed set of latents through stacked attention.

    Learned latents, optionally preceded by latents derived from the
    mean-pooled sequence, attend to the position-embedded input across
    ``depth`` (attention, feed-forward) layers.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from the mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attention = PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads)
            feed_forward = FeedForward(dim = dim, mult = ff_mult)
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x, mask = None):
        """Return the latents after attending to ``x``.

        ``mask`` is forwarded to every attention layer.
        """
        batch, seq_len = x.shape[0], x.shape[1]

        positions = torch.arange(seq_len, device = x.device)
        x_with_pos = x + self.pos_emb(positions)

        latents = repeat(self.latents, 'n d -> b n d', b = batch)

        if exists(self.to_latents_from_mean_pooled_seq):
            # NOTE(review): the pooling uses an all-True mask, so every position
            # of x contributes even when `mask` is given - confirm this is intended
            all_true = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool)
            meanpooled_seq = masked_mean(x, dim = 1, mask = all_true)
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        for attention, feed_forward in self.layers:
            latents = attention(x_with_pos, latents, mask = mask) + latents
            latents = feed_forward(latents) + latents

        return latents
# 定义 Conv3d 类,继承自 nn.Module
class Conv3d(nn.Module):
    """Factorized convolution: a 2d spatial conv followed by a causal 1d temporal conv.

    The temporal conv is only created when ``kernel_size > 1``. Extra keyword
    arguments are accepted and ignored.
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        kernel_size = 3,
        *,
        temporal_kernel_size = None,
        **kwargs
    ):
        super().__init__()

        dim_out = default(dim_out, dim)
        temporal_kernel_size = default(temporal_kernel_size, kernel_size)

        # spatial convolution, applied to every frame independently
        self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)

        # temporal convolution over the frame axis (condition kept on kernel_size
        # so the set of parameters is unchanged for existing checkpoints)
        self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size) if kernel_size > 1 else None

        self.kernel_size = kernel_size
        # kept separately so the causal padding matches the temporal kernel
        self.temporal_kernel_size = temporal_kernel_size

        if exists(self.temporal_conv):
            # NOTE(review): dirac_ puts the unit weight at the kernel centre, while
            # the padding in forward is entirely on the left - confirm this gives
            # the intended "identity" initialization
            nn.init.dirac_(self.temporal_conv.weight.data)
            nn.init.zeros_(self.temporal_conv.bias.data)

    def forward(
        self,
        x,
        ignore_time = False
    ):
        """Apply the spatial conv and, for 5-d input, the causal temporal conv.

        4-d (image) input only goes through the spatial conv. For 5-d (video)
        input the temporal conv is skipped when ``ignore_time`` is True or when
        no temporal conv exists.
        """
        b = x.shape[0]

        is_video = x.ndim == 5

        if is_video:
            x = rearrange(x, 'b c f h w -> (b f) c h w')

        x = self.spatial_conv(x)

        if is_video:
            x = rearrange(x, '(b f) c h w -> b c f h w', b = b)

        # images have no frame axis to convolve over
        if not is_video or ignore_time or not exists(self.temporal_conv):
            return x

        # read the spatial size after the spatial conv, which may have changed it
        h, w = x.shape[-2:]

        x = rearrange(x, 'b c f h w -> (b h w) c f')

        # causal temporal convolution: pad on the left only, by the temporal
        # kernel's extent, so the number of frames is preserved
        causal_padding = self.temporal_kernel_size - 1

        if causal_padding > 0:
            x = F.pad(x, (causal_padding, 0))

        x = self.temporal_conv(x)

        x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)

        return x
# 定义 Attention 类,继承自 nn.Module
class Attention(nn.Module):
    """Multi-head attention with a single key/value head shared by all query heads.

    A learned null key/value is always prepended to the keys and values.
    Optional extras: causal masking, a learned relative position bias, and
    conditioning on an external ``context`` projected to extra keys/values.
    """

    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        causal = False,
        context_dim = None,
        rel_pos_bias = False,
        rel_pos_bias_mlp_depth = 2,
        init_zero = False,
        scale = 8
    ):
        super().__init__()
        # fixed multiplier applied to the query/key similarities
        self.scale = scale
        self.causal = causal

        # optional learned position bias (DynamicPositionBias is defined elsewhere in the file)
        self.rel_pos_bias = DynamicPositionBias(dim = dim, heads = heads, depth = rel_pos_bias_mlp_depth) if rel_pos_bias else None

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        # learned bias for the null key column (one value per head) and the null key/value pair
        self.null_attn_bias = nn.Parameter(torch.randn(heads))

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # keys/values are projected to a single head of size dim_head each
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        # learned per-dimension scales applied after l2-normalizing queries and keys
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        # projection of the conditioning context to one key and one value per token
        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

        # zeroing the gain of the final LayerNorm makes the block output zero at initialization
        if init_zero:
            nn.init.zeros_(self.to_out[-1].g)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None
    ):
        """Attend over ``x`` (plus the null key/value and optional context tokens).

        ``attn_bias`` falls back to the module's own relative position bias
        when that is enabled. ``context`` requires ``context_dim`` to have been
        set at construction.
        """
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        # only the queries are split into heads; keys/values stay single-headed
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance
        nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # add text conditioning, if present
        # NOTE(review): context keys are placed in front of the null key, while the
        # bias/mask handling below only accounts for one extra leading column -
        # verify callers never combine context with mask or attn_bias
        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # l2-normalize queries and keys, then apply the learned scales
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # query / key similarities
        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale

        # relative positional bias, computed here unless one was passed in
        if not exists(attn_bias) and exists(self.rel_pos_bias):
            attn_bias = self.rel_pos_bias(n, device = device, dtype = q.dtype)

        if exists(attn_bias):
            # prepend a bias column for the null key
            null_attn_bias = repeat(self.null_attn_bias, 'h -> h n 1', n = n)
            attn_bias = torch.cat((null_attn_bias, attn_bias), dim = -1)
            sim = sim + attn_bias

        # masking
        max_neg_value = -torch.finfo(sim.dtype).max

        if self.causal:
            i, j = sim.shape[-2:]
            # the j - i + 1 offset keeps the extra leading key columns visible to every query
            causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, max_neg_value)

        if exists(mask):
            # one leading True so the null key is never masked out
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention weights
        attn = sim.softmax(dim = -1)

        # aggregate values and merge the heads
        out = einsum('b h i j, b j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# 定义一个伪 Conv2d 函数,使用 Conv3d 但在帧维度上使用大小为1的卷积核
def Conv2d(dim_in, dim_out, kernel, stride = 1, padding = 0, **kwargs):
    """Pseudo 2d convolution: an ``nn.Conv3d`` that does not mix across frames.

    ``kernel``, ``stride`` and ``padding`` are cast to 2-tuples and given a
    leading frame entry of 1, 1 and 0 respectively.
    """
    def with_frame_axis(value, frame_value):
        pair = cast_tuple(value, 2)
        return (frame_value, *pair) if len(pair) == 2 else pair

    kernel = with_frame_axis(kernel, 1)
    stride = with_frame_axis(stride, 1)
    padding = with_frame_axis(padding, 0)

    return nn.Conv3d(dim_in, dim_out, kernel, stride = stride, padding = padding, **kwargs)
# 定义一个 Pad 类
class Pad(nn.Module):
    """Module form of ``F.pad`` with a fixed padding spec and fill value."""

    def __init__(self, padding, value = 0.):
        super().__init__()
        self.padding = padding
        self.value = value

    def forward(self, x):
        """Pad ``x`` with the stored padding and fill value."""
        padded = F.pad(x, self.padding, value = self.value)
        return padded
# 定义一个 Upsample 函数
def Upsample(dim, dim_out = None):
    """Nearest-neighbour upsampling (scale factor 2) followed by a 3x3 pseudo-2d conv."""
    dim_out = default(dim_out, dim)

    upsample = nn.Upsample(scale_factor = 2, mode = 'nearest')
    conv = Conv2d(dim, dim_out, 3, padding = 1)

    return nn.Sequential(upsample, conv)
# 定义一个 PixelShuffleUpsample 类
class PixelShuffleUpsample(nn.Module):
    """Spatial 2x upsampling through a 1x1 conv, SiLU and per-frame pixel shuffle."""

    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)

        # 4x the channels so that PixelShuffle(2) leaves dim_out channels
        conv = Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(conv, nn.SiLU())
        self.pixel_shuffle = nn.PixelShuffle(2)

        self.init_conv_(conv)

    def init_conv_(self, conv):
        """Initialize ``conv`` with a kaiming-uniform kernel repeated 4 times per output channel."""
        o, i, f, h, w = conv.weight.shape

        base_weight = torch.empty(o // 4, i, f, h, w)
        nn.init.kaiming_uniform_(base_weight)
        tiled_weight = repeat(base_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(tiled_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        frames = x.shape[2]
        out = self.net(x)

        # PixelShuffle works on 4-d input, so frames are folded into the batch
        out = rearrange(out, 'b c f h w -> (b f) c h w')
        out = self.pixel_shuffle(out)
        return rearrange(out, '(b f) c h w -> b c f h w', f = frames)
# Spatial downsampling: fold 2x2 pixel patches into channels, then a 1x1 conv.
def Downsample(dim, dim_out = None):
    """Return a space-to-depth rearrange followed by ``Conv2d(dim * 4, dim_out, 1)``.

    ``dim_out`` falls back to ``dim`` when it is not given.
    """
    out_channels = default(dim_out, dim)
    space_to_depth = Rearrange('b c f (h p1) (w p2) -> b (c p1 p2) f h w', p1 = 2, p2 = 2)
    project = Conv2d(dim * 4, out_channels, 1)
    return nn.Sequential(space_to_depth, project)
# Pixel-shuffle style upsampling along the frame axis.
class TemporalPixelShuffleUpsample(nn.Module):
    """1x1 ``nn.Conv1d`` to ``dim_out * stride`` channels, SiLU, then a shuffle
    that moves the ``stride`` factor from the channel axis onto the frame axis."""

    def __init__(self, dim, dim_out = None, stride = 2):
        super().__init__()
        self.stride = stride
        out_channels = default(dim_out, dim)

        expand = nn.Conv1d(dim, out_channels * stride, 1)
        self.net = nn.Sequential(expand, nn.SiLU())
        self.pixel_shuffle = Rearrange('b (c r) n -> b c (n r)', r = stride)

        self.init_conv_(expand)

    def init_conv_(self, conv):
        """Initialise ``conv`` so each group of ``stride`` consecutive output channels shares one kernel."""
        out_ch, in_ch, width = conv.weight.shape
        base_weight = torch.empty(out_ch // self.stride, in_ch, width)
        nn.init.kaiming_uniform_(base_weight)
        tiled_weight = repeat(base_weight, 'o ... -> (o r) ...', r = self.stride)

        conv.weight.data.copy_(tiled_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        """Upsample the frame axis of a 5-dimensional ``b c f h w`` input."""
        _, _, _, h, w = x.shape
        # every spatial location becomes its own 1D sequence over frames
        sequences = rearrange(x, 'b c f h w -> (b h w) c f')
        expanded = self.net(sequences)
        shuffled = self.pixel_shuffle(expanded)
        return rearrange(shuffled, '(b h w) c f -> b c f h w', h = h, w = w)
# Temporal downsampling: fold groups of `stride` frames into channels, then a 1x1 conv.
def TemporalDownsample(dim, dim_out = None, stride = 2):
    """Return a frame-to-channel rearrange followed by ``Conv2d(dim * stride, dim_out, 1)``.

    ``dim_out`` falls back to ``dim`` when it is not given.
    """
    out_channels = default(dim_out, dim)
    fold_frames = Rearrange('b c (f p) h w -> b (c p) f h w', p = stride)
    project = Conv2d(dim * stride, out_channels, 1)
    return nn.Sequential(fold_frames, project)
# Fixed sinusoidal embedding of a 1D tensor of positions.
class SinusoidalPosEmb(nn.Module):
    """Map each scalar in a 1D input to ``2 * (dim // 2)`` sine / cosine features.

    Note: the frequency step divides by ``dim // 2 - 1``, so ``dim`` values of
    2 or 3 raise ``ZeroDivisionError`` in ``forward``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        log_step = math.log(10000) / (half_dim - 1)
        # geometric progression of frequencies from 1 down to 1 / 10000
        frequencies = torch.exp(torch.arange(half_dim, device = x.device) * -log_step)
        angles = rearrange(x, 'i -> i 1') * rearrange(frequencies, 'j -> 1 j')
        return torch.cat((angles.sin(), angles.cos()), dim = -1)
# Sinusoidal embedding whose frequencies are learned parameters.
class LearnedSinusoidalPosEmb(nn.Module):
    """Embed a 1D tensor of scalars as ``[x, sin(2*pi*x*w), cos(2*pi*x*w)]``.

    ``w`` is a learned vector of ``dim // 2`` frequencies, so the output has
    ``dim + 1`` features per input element.
    """

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        # one learnable frequency per sine / cosine pair, drawn from N(0, 1)
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        column = rearrange(x, 'b -> b 1')
        phases = column * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        features = torch.cat((phases.sin(), phases.cos()), dim = -1)
        # the raw input is kept as the first feature
        return torch.cat((column, features), dim = -1)
class Block(nn.Module):
    """Normalisation, optional scale / shift, SiLU, then a 3x3 ``Conv3d`` projection.

    NOTE(review): ``Conv3d`` is resolved elsewhere in the file and is called
    with an ``ignore_time`` keyword, so it is presumably a project class
    rather than ``nn.Conv3d`` - confirm.
    """

    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        # GroupNorm unless normalisation is switched off
        if norm:
            self.groupnorm = nn.GroupNorm(groups, dim)
        else:
            self.groupnorm = Identity()
        self.activation = nn.SiLU()
        self.project = Conv3d(dim, dim_out, 3, padding = 1)

    def forward(
        self,
        x,
        scale_shift = None,
        ignore_time = False
    ):
        """Apply the block; ``scale_shift`` is an optional ``(scale, shift)`` pair."""
        normed = self.groupnorm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            normed = normed * (scale + 1) + shift

        activated = self.activation(normed)
        return self.project(activated, ignore_time = ignore_time)
class ResnetBlock(nn.Module):
    """Residual block made of two ``Block`` layers.

    Optional extras: a time-conditioning MLP that yields the scale / shift for
    the second block, cross attention to a conditioning sequence between the
    two blocks, and a ``GlobalContext`` gate on the output. The
    ``squeeze_excite`` argument is accepted but not used by this class.
    """

    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        # projects the time conditioning to a concatenated (scale, shift)
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_cond_dim, dim_out * 2)
        ) if exists(time_cond_dim) else None

        # cross attention is only built when a conditioning dimension is given
        if exists(cond_dim):
            attn_klass = LinearCrossAttention if linear_attn else CrossAttention
            self.cross_attn = attn_klass(
                dim = dim_out,
                context_dim = cond_dim,
                **attn_kwargs
            )
        else:
            self.cross_attn = None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        # multiplicative gate; a constant 1 when global context is disabled
        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        # 1x1 projection on the skip path only when the channel count changes
        self.res_conv = Identity() if dim == dim_out else Conv2d(dim, dim_out, 1)

    def forward(
        self,
        x,
        time_emb = None,
        cond = None,
        ignore_time = False
    ):
        """Run the block; ``cond`` is required whenever cross attention was built."""
        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            projected = self.time_mlp(time_emb)
            projected = rearrange(projected, 'b c -> b c 1 1 1')
            scale_shift = projected.chunk(2, dim = 1)

        h = self.block1(x, ignore_time = ignore_time)

        if exists(self.cross_attn):
            assert exists(cond)
            # flatten every non-batch, non-channel axis into one token axis
            tokens = rearrange(h, 'b c ... -> b ... c')
            tokens, packed_shape = pack([tokens], 'b * c')
            tokens = self.cross_attn(tokens, context = cond) + tokens
            tokens, = unpack(tokens, packed_shape, 'b * c')
            h = rearrange(tokens, 'b ... c -> b c ...')

        h = self.block2(h, scale_shift = scale_shift, ignore_time = ignore_time)
        h = h * self.gca(h)

        return h + self.res_conv(x)
class CrossAttention(nn.Module):
    """Multi-head cross attention with l2-normalised queries and keys.

    A learned null key / value pair is prepended to the context (the upstream
    comment says it is for classifier free guidance). The similarity is a
    scaled cosine similarity, with learned per-dimension scales on q and k.
    """

    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale
        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        if norm_context:
            self.norm_context = LayerNorm(context_dim)
        else:
            self.norm_context = Identity()

        # row 0 is the null key, row 1 the null value
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        """Attend from ``x`` to ``context``.

        ``mask``, when given, is a boolean ``b j`` tensor over the context; it
        is left-padded with ``True`` so the null key is always attendable.
        """
        b, _n, _device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h = self.heads)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # prepend the null key / value to every batch element and head
        null_k, null_v = self.null_kv.unbind(dim = -2)
        null_k = repeat(null_k, 'd -> b h 1 d', h = self.heads, b = b)
        null_v = repeat(null_v, 'd -> b h 1 d', h = self.heads, b = b)
        k = torch.cat((null_k, k), dim = -2)
        v = torch.cat((null_v, v), dim = -2)

        # cosine similarity attention
        q = l2norm(q) * self.q_scale
        k = l2norm(k) * self.k_scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # softmax is computed in float32
        attn = sim.softmax(dim = -1, dtype = torch.float32)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class LinearCrossAttention(CrossAttention):
    """Linear-attention variant of ``CrossAttention``.

    It reuses the parent's projections but replaces the pairwise softmax with
    separate softmaxes over the query features and over the key sequence.
    """

    def forward(self, x, context, mask = None):
        """Attend from ``x`` to ``context``; ``mask`` is an optional boolean ``b n`` tensor."""
        b, _n, _device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        def merge_heads_into_batch(t):
            return rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads)

        q, k, v = merge_heads_into_batch(q), merge_heads_into_batch(k), merge_heads_into_batch(v)

        # add null key / value for classifier free guidance in prior net
        null_k, null_v = self.null_kv.unbind(dim = -2)
        null_k = repeat(null_k, 'd -> (b h) 1 d', h = self.heads, b = b)
        null_v = repeat(null_v, 'd -> (b h) 1 d', h = self.heads, b = b)
        k = torch.cat((null_k, k), dim = -2)
        v = torch.cat((null_v, v), dim = -2)

        # masking
        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            # NOTE(review): the mask keeps batch size b while k / v carry
            # (b h) rows - confirm the broadcast when heads > 1
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b n -> b n 1')
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention
        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)
        q = q * self.scale

        summary = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, summary)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)
class LinearAttention(nn.Module):
    """Linear self-attention over a feature map, with optional context keys / values.

    Queries, keys and values each come from dropout, a 1x1 conv and a
    depthwise 3x3 conv. When ``context_dim`` is given, a LayerNorm + Linear
    head turns a context sequence into extra keys and values.
    """

    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        def make_projection():
            # dropout -> pointwise conv -> depthwise 3x3 conv
            return nn.Sequential(
                nn.Dropout(dropout),
                Conv2d(dim, inner_dim, 1, bias = False),
                Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
            )

        self.to_q = make_projection()
        self.to_k = make_projection()
        self.to_v = make_projection()

        if exists(context_dim):
            self.to_context = nn.Sequential(
                nn.LayerNorm(context_dim),
                nn.Linear(context_dim, inner_dim * 2, bias = False)
            )
        else:
            self.to_context = None

        self.to_out = nn.Sequential(
            Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        """Apply linear attention to ``fmap``; ``context`` needs ``context_dim`` to have been set."""
        h, x, y = self.heads, *fmap.shape[-2:]

        fmap = self.norm(fmap)
        q, k, v = self.to_q(fmap), self.to_k(fmap), self.to_v(fmap)

        def to_sequences(t):
            return rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h)

        q, k, v = to_sequences(q), to_sequences(k), to_sequences(v)

        if exists(context):
            assert exists(self.to_context)
            context_k, context_v = self.to_context(context).chunk(2, dim = -1)
            context_k = rearrange(context_k, 'b n (h d) -> (b h) n d', h = h)
            context_v = rearrange(context_v, 'b n (h d) -> (b h) n d', h = h)
            k = torch.cat((k, context_k), dim = -2)
            v = torch.cat((v, context_v), dim = -2)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)
        q = q * self.scale

        summary = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, summary)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)
class GlobalContext(nn.Module):
    """ basically a superior form of squeeze-excitation that is attention-esque

    A single-channel projection is soft-maxed over all positions and used to
    pool the input into one vector per channel; a small conv bottleneck ending
    in a sigmoid then turns that vector into gate values.
    """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        # one logit per position; its softmax gives the pooling weights
        self.to_k = Conv2d(dim_in, 1, 1)
        # bottleneck width is half the output width, but never below 3
        hidden_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(
            Conv2d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            Conv2d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        logits = self.to_k(x)

        def flatten_positions(t):
            return rearrange(t, 'b n ... -> b n (...)')

        x, logits = flatten_positions(x), flatten_positions(logits)
        weights = logits.softmax(dim = -1)
        pooled = einsum('b i n, b c n -> b c i', weights, x)
        pooled = rearrange(pooled, '... -> ... 1 1')
        return self.net(pooled)
# Token-wise feedforward: LayerNorm, Linear, GELU, LayerNorm, Linear.
def FeedForward(dim, mult = 2):
    """Return a bias-free two-layer MLP with hidden width ``int(dim * mult)``."""
    hidden_dim = int(dim * mult)
    layers = [
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, dim, bias = False),
    ]
    return nn.Sequential(*layers)
# Shift the second chunk of channels one step along the third-from-last axis.
class TimeTokenShift(nn.Module):
    """Token shift over the frame axis for 5-dimensional inputs.

    The input is split in two along dim 1. The first part passes through; the
    second is shifted one step along the third-from-last axis, with zeros at
    the first position and the last position dropped. Inputs that are not
    5-dimensional are returned unchanged.
    """

    def forward(self, x):
        if x.ndim == 5:
            kept, shifted = x.chunk(2, dim = 1)
            # pad one step at the front and crop one at the back of the frame axis
            shifted = F.pad(shifted, (0, 0, 0, 0, 1, -1), value = 0.)
            return torch.cat((kept, shifted), dim = 1)
        return x
# Channel-first feedforward built from 1x1 pseudo-2D convolutions.
def ChanFeedForward(dim, mult = 2, time_token_shift = True):
    """Return ChanLayerNorm, 1x1 conv, GELU, optional TimeTokenShift, ChanLayerNorm, 1x1 conv.

    When ``time_token_shift`` is false, ``None`` is passed to ``Sequential`` in
    that slot. NOTE(review): this relies on the project's ``Sequential``
    helper tolerating ``None`` entries - confirm.
    """
    hidden_dim = int(dim * mult)
    shift = TimeTokenShift() if time_token_shift else None
    return Sequential(
        ChanLayerNorm(dim),
        Conv2d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        shift,
        ChanLayerNorm(hidden_dim),
        Conv2d(hidden_dim, dim, 1, bias = False)
    )
# Stack of (attention, channel feedforward) layers with residual connections.
class TransformerBlock(nn.Module):
    """``depth`` pairs of ``Attention`` and ``ChanFeedForward``, each applied residually."""

    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        ff_time_token_shift = True,
        context_dim = None
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            attention = Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim)
            feedforward = ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
            self.layers.append(nn.ModuleList([attention, feedforward]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            # attention works on a flat token axis with channels last
            tokens = rearrange(x, 'b c ... -> b ... c')
            tokens, packed_shape = pack([tokens], 'b * c')

            tokens = attn(tokens, context = context) + tokens

            tokens, = unpack(tokens, packed_shape, 'b * c')
            x = rearrange(tokens, 'b ... c -> b c ...')

            x = ff(x) + x
        return x
# Stack of (linear attention, channel feedforward) layers with residual connections.
class LinearAttentionTransformerBlock(nn.Module):
    """``depth`` pairs of ``LinearAttention`` and ``ChanFeedForward``, each applied residually.

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        ff_time_token_shift = True,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            attention = LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim)
            feedforward = ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
            self.layers.append(nn.ModuleList([attention, feedforward]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            attended = attn(x, context = context) + x
            x = ff(attended) + attended
        return x
# Multi-scale strided convolution whose outputs are concatenated on channels.
class CrossEmbedLayer(nn.Module):
    """Run several convolutions with different kernel sizes and concatenate them.

    Kernel sizes are sorted ascending. The i-th scale (1-based) gets
    ``int(dim_out / 2 ** i)`` channels and the largest kernel gets the
    remainder, so the widths sum to ``dim_out``. Every kernel size must share
    the parity of ``stride``.
    """

    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        assert all((size % 2) == (stride % 2) for size in kernel_sizes)
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # halve the width at every scale; the last scale takes what is left
        dim_scales = [int(dim_out / (2 ** power)) for power in range(1, num_scales)]
        dim_scales.append(dim_out - sum(dim_scales))

        self.convs = nn.ModuleList([
            Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2)
            for kernel, dim_scale in zip(kernel_sizes, dim_scales)
        ])

    def forward(self, x):
        fmaps = tuple(conv(x) for conv in self.convs)
        return torch.cat(fmaps, dim = 1)
# Optionally concatenates resized, projected feature maps onto the input.
class UpsampleCombiner(nn.Module):
    """Concatenate extra feature maps onto ``x`` along the channel axis.

    When enabled, each incoming feature map is resized to ``x``'s last-axis
    size, passed through its own ``Block(dim_in, dim_out)`` and concatenated
    after ``x``. ``dim_out`` records the resulting channel count. When
    disabled, ``forward`` returns ``x`` and ``dim_out`` equals ``dim``.
    """

    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if not self.enabled:
            # no projections are created in the disabled state
            self.dim_out = dim
            return

        self.fmap_convs = nn.ModuleList([
            Block(fmap_dim_in, fmap_dim_out)
            for fmap_dim_in, fmap_dim_out in zip(dim_ins, dim_outs)
        ])
        self.dim_out = dim + sum(dim_outs)

    def forward(self, x, fmaps = None):
        target_size = x.shape[-1]
        fmaps = default(fmaps, tuple())

        # `enabled` is tested first because `fmap_convs` only exists when enabled
        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        resized = [resize_video_to(fmap, target_size) for fmap in fmaps]
        projected = [conv(fmap) for fmap, conv in zip(resized, self.fmap_convs)]
        return torch.cat((x, *projected), dim = 1)
# MLP that maps relative positions to a per-head attention bias.
class DynamicPositionBias(nn.Module):
    """Produce an ``(heads, n, n)`` relative position bias from a small MLP.

    The MLP is ``Linear(1, dim)`` + LayerNorm + SiLU, then ``depth - 1``
    further ``Linear(dim, dim)`` + LayerNorm + SiLU stages, then a final
    ``Linear(dim, heads)``.
    """

    def __init__(
        self,
        dim,
        *,
        heads,
        depth
    ):
        super().__init__()
        self.mlp = nn.ModuleList([])

        # input stage: scalar relative position -> dim features
        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            LayerNorm(dim),
            nn.SiLU()
        ))

        # hidden stages
        for _ in range(max(depth - 1, 0)):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                LayerNorm(dim),
                nn.SiLU()
            ))

        # output stage: one bias value per head
        # (this call was missing its closing parenthesis in the original)
        self.mlp.append(nn.Linear(dim, heads))

    def forward(self, n, device, dtype):
        """Return the bias for a sequence of length ``n`` as an ``h i j`` tensor."""
        positions = torch.arange(n, device = device)

        # pairwise offsets i - j, shifted by n - 1 so they index into `pos` below
        indices = rearrange(positions, 'i -> i 1') - rearrange(positions, 'j -> 1 j')
        indices += (n - 1)

        # every possible offset from -(n - 1) to n - 1, as a column of scalars
        pos = torch.arange(-n + 1, n, device = device, dtype = dtype)
        pos = rearrange(pos, '... -> ... 1')

        for layer in self.mlp:
            pos = layer(pos)

        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias
# 定义一个 3D UNet 神经网络模块
class Unet3D(nn.Module):
def __init__(
self,
*,
dim,
text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
num_resnet_blocks = 1,
cond_dim = None,
num_image_tokens = 4,
num_time_tokens = 2,
learned_sinu_pos_emb_dim = 16,
out_dim = None,
dim_mults = (1, 2, 4, 8),
temporal_strides = 1,
cond_images_channels = 0,
channels = 3,
channels_out = None,
attn_dim_head = 64,
attn_heads = 8,
ff_mult = 2.,
ff_time_token_shift = True, # 在 feedforwards 的隐藏层中沿时间轴进行令牌移位
lowres_cond = False, # 用于级联扩散
layer_attns = False,
layer_attns_depth = 1,
layer_attns_add_text_cond = True, # 是否在自注意力块中加入文本嵌入
attend_at_middle = True, # 是否在瓶颈处进行一层注意力
time_rel_pos_bias_depth = 2,
time_causal_attn = True,
layer_cross_attns = True,
use_linear_attn = False,
use_linear_cross_attn = False,
cond_on_text = True,
max_text_len = 256,
init_dim = None,
resnet_groups = 8,
init_conv_kernel_size = 7, # 初始卷积的内核大小
init_cross_embed = True,
init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
attn_pool_text = True,
attn_pool_num_latents = 32,
dropout = 0.,
memory_efficient = False,
init_conv_to_final_conv_residual = False,
use_global_context_attn = True,
scale_skip_connection = True,
final_resnet_block = True,
final_conv_kernel_size = 3,
self_cond = False,
combine_upsample_fmaps = False, # 在所有上采样块中合并特征图
pixel_shuffle_upsample = True, # 可能解决棋盘伪影
resize_mode = 'nearest'
# Re-create the unet when its current settings differ from the requested ones.
def cast_model_parameters(
    self,
    *,
    lowres_cond,
    text_embed_dim,
    channels,
    channels_out,
    cond_on_text
):
    """Return ``self`` if it already matches the given settings, else a new instance.

    The new instance is built with ``self.__class__`` from the stored
    constructor arguments in ``self._locals``, overridden by the five
    settings passed here.
    """
    already_matches = (
        lowres_cond == self.lowres_cond
        and channels == self.channels
        and cond_on_text == self.cond_on_text
        and text_embed_dim == self._locals['text_embed_dim']
        and channels_out == self.channels_out
    )
    if already_matches:
        return self

    overrides = {
        'lowres_cond': lowres_cond,
        'text_embed_dim': text_embed_dim,
        'channels': channels,
        'channels_out': channels_out,
        'cond_on_text': cond_on_text,
    }
    merged_kwargs = {**self._locals, **overrides}
    return self.__class__(**merged_kwargs)
# Return the full unet config together with its parameter state dict.
def to_config_and_state_dict(self):
    """Return ``(self._locals, self.state_dict())``."""
    config = self._locals
    weights = self.state_dict()
    return config, weights
# Class method that rebuilds a unet from a config and a state dict.
@classmethod
def from_config_and_state_dict(klass, config, state_dict):
    """Instantiate ``klass(**config)``, load ``state_dict`` into it and return it."""
    instance = klass(**config)
    instance.load_state_dict(state_dict)
    return instance
# Persist the unet (config plus weights) to disk.
def persist_to_file(self, path):
    """Save ``dict(config=..., state_dict=...)`` to ``path`` with ``torch.save``.

    The parent directory of ``path`` is created if it does not exist.
    """
    target = Path(path)
    target.parents[0].mkdir(exist_ok = True, parents = True)

    config, state_dict = self.to_config_and_state_dict()
    package = {'config': config, 'state_dict': state_dict}
    torch.save(package, str(target))
# Class method that rebuilds a unet from a file written by `persist_to_file`.
@classmethod
def hydrate_from_file(klass, path):
    """Load the package at ``path`` and rebuild the unet as an instance of ``klass``.

    The file must exist and contain a dict with ``'config'`` and
    ``'state_dict'`` keys; both conditions are checked with ``assert``.
    """
    path = Path(path)
    assert path.exists()
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources
    pkg = torch.load(str(path))

    assert 'config' in pkg and 'state_dict' in pkg
    config, state_dict = pkg['config'], pkg['state_dict']

    # build through `klass` so subclasses hydrate as themselves; the original
    # referenced `Unet`, a name not visible in this module
    return klass.from_config_and_state_dict(config, state_dict)
# Forward pass with classifier free guidance.
def forward_with_cond_scale(
    self,
    *args,
    cond_scale = 1.,
    **kwargs
):
    """Blend the conditional and fully-dropped predictions by ``cond_scale``.

    With ``cond_scale == 1`` only the conditional forward pass is run.
    Otherwise a second pass with ``cond_drop_prob = 1.`` provides the null
    prediction, and the result is ``null + (cond - null) * cond_scale``.
    """
    logits = self.forward(*args, **kwargs)

    if cond_scale != 1:
        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        guidance = (logits - null_logits) * cond_scale
        return null_logits + guidance

    return logits
def forward(
self,
x,
time,
*,
lowres_cond_img = None,
lowres_noise_times = None,
text_embeds = None,
text_mask = None,
cond_images = None,
cond_video_frames = None,
post_cond_video_frames = None,
self_cond = None,
cond_drop_prob = 0.,
ignore_time = False
.\lucidrains\imagen-pytorch\imagen_pytorch\t5.py
# 导入 torch 库
import torch
# 导入 transformers 库
import transformers
# 导入 List 类型
from typing import List
# 从 transformers 库中导入 T5Tokenizer, T5EncoderModel, T5Config
from transformers import T5Tokenizer, T5EncoderModel, T5Config
# 从 einops 库中导入 rearrange 函数
from einops import rearrange
# 设置 transformers 库的日志级别为 error
transformers.logging.set_verbosity_error()
# Helper: is the value anything other than None?
def exists(val):
    """Return ``True`` unless ``val`` is ``None``."""
    if val is None:
        return False
    return True
# Helper: fall back to a default when the value is None.
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise ``d``, called first if callable."""
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
# config
# maximum token length: passed as `model_max_length` to the tokenizer and as
# `max_length` when encoding a batch
MAX_LENGTH = 256
# T5 checkpoint name used when callers do not specify one
DEFAULT_T5_NAME = 'google/t5-v1_1-base'
# per-name cache; entries may hold "model", "tokenizer" and / or "config"
T5_CONFIGS = {}
# singleton globals
# Load the tokenizer for a named T5 checkpoint.
def get_tokenizer(name):
    """Return ``T5Tokenizer.from_pretrained(name)`` with ``model_max_length = MAX_LENGTH``."""
    return T5Tokenizer.from_pretrained(name, model_max_length=MAX_LENGTH)
# Load the encoder model for a named T5 checkpoint.
def get_model(name):
    """Return ``T5EncoderModel.from_pretrained(name)``."""
    return T5EncoderModel.from_pretrained(name)
# Fetch the cached model and tokenizer for a name, loading them on first use.
def get_model_and_tokenizer(name):
    """Return ``(model, tokenizer)`` for ``name``.

    Both objects are stored in ``T5_CONFIGS[name]`` so later calls reuse them.
    """
    entry = T5_CONFIGS.setdefault(name, dict())

    if "model" not in entry:
        entry["model"] = get_model(name)

    if "tokenizer" not in entry:
        entry["tokenizer"] = get_tokenizer(name)

    return entry['model'], entry['tokenizer']
# Look up the embedding width (d_model) of a named T5 checkpoint.
def get_encoded_dim(name):
    """Return ``config.d_model`` for ``name``.

    If nothing is cached for ``name`` yet, only the ``T5Config`` is fetched
    and cached, which avoids loading the model just to read its dimension.
    Otherwise the cached config is used, or failing that the cached model's.
    """
    if name not in T5_CONFIGS:
        # avoids loading the model if we only want to get the dim
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
        return config.d_model

    entry = T5_CONFIGS[name]
    if "config" in entry:
        return entry["config"].d_model
    if "model" in entry:
        return entry["model"].config.d_model

    # an entry with neither a config nor a model is not expected
    assert False
# encoding text
# Tokenize a batch of texts for the named T5 checkpoint.
def t5_tokenize(
    texts: List[str],
    name = DEFAULT_T5_NAME
):
    """Return ``(input_ids, attention_mask)`` placed on the T5 model's device.

    The model is moved to CUDA first when CUDA is available. Texts are padded
    to the longest item in the batch and truncated at ``MAX_LENGTH`` tokens.
    """
    t5, tokenizer = get_model_and_tokenizer(name)

    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    encode_options = dict(
        return_tensors = "pt",
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )
    encoded = tokenizer.batch_encode_plus(texts, **encode_options)

    return encoded.input_ids.to(device), encoded.attention_mask.to(device)
# Encode already-tokenized text with the T5 encoder.
def t5_encode_tokenized_text(
    token_ids,
    attn_mask = None,
    pad_id = None,
    name = DEFAULT_T5_NAME
):
    """Return the encoder's last hidden state with padded positions zeroed.

    Either ``attn_mask`` or ``pad_id`` must be given; without a mask, one is
    derived as ``token_ids != pad_id``. The encoder runs in eval mode under
    ``torch.no_grad()`` and the result is detached.
    """
    assert exists(attn_mask) or exists(pad_id)
    t5, _ = get_model_and_tokenizer(name)

    attn_mask = default(attn_mask, lambda: (token_ids != pad_id).long())

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = token_ids, attention_mask = attn_mask)
        hidden = output.last_hidden_state.detach()

    keep = rearrange(attn_mask.bool(), '... -> ... 1')

    # just force all embeddings that is padding to be equal to 0.
    return hidden.masked_fill(~keep, 0.)
# Tokenize and encode raw text in one call.
def t5_encode_text(
    texts: List[str],
    name = DEFAULT_T5_NAME,
    return_attn_mask = False
):
    """Return the T5 encodings of ``texts``.

    With ``return_attn_mask`` set, the result is ``(encodings, mask)`` and the
    mask is boolean.
    """
    token_ids, attn_mask = t5_tokenize(texts, name = name)
    encoded_text = t5_encode_tokenized_text(token_ids, attn_mask = attn_mask, name = name)

    if not return_attn_mask:
        return encoded_text

    return encoded_text, attn_mask.bool()
.\lucidrains\imagen-pytorch\imagen_pytorch\test\test_trainer.py
# 从 imagen_pytorch 包中导入 ImagenTrainer 类
# 从 imagen_pytorch 包中导入 ImagenConfig 类
# 从 imagen_pytorch 包中导入 t5_encode_text 函数
# 从 torch.utils.data 包中导入 Dataset 类
# 导入 torch 库
from imagen_pytorch.trainer import ImagenTrainer
from imagen_pytorch.configs import ImagenConfig
from imagen_pytorch.t5 import t5_encode_text
from torch.utils.data import Dataset
import torch
# Smoke test: an ImagenTrainer can be built from a minimal single-unet config.
def test_trainer_instantiation():
    """Build a tiny one-unet Imagen through ``ImagenConfig`` and wrap it in a trainer."""
    # smallest practical unet: no attention layers, two heads
    unet_config = {
        'dim': 8,
        'dim_mults': (1, 1, 1, 1),
        'num_resnet_blocks': 1,
        'layer_attns': False,
        'layer_cross_attns': False,
        'attn_heads': 2,
    }

    config = ImagenConfig(
        unets=(unet_config,),
        image_sizes=(64,),
    )
    imagen = config.create()

    trainer = ImagenTrainer(
        imagen=imagen
    )
# Test: a single training step runs and is counted.
def test_trainer_step():
    """Train a tiny Imagen for one step on zero-filled data and check the step counter."""
    class TestDataset(Dataset):
        # 16 identical samples: a 3x64x64 image and a 6x768 second tensor
        def __init__(self):
            super().__init__()

        def __len__(self):
            return 16

        def __getitem__(self, index):
            image = torch.zeros(3, 64, 64)
            embeds = torch.zeros(6, 768)
            return (image, embeds)

    unet_config = {
        'dim': 8,
        'dim_mults': (1, 1, 1, 1),
        'num_resnet_blocks': 1,
        'layer_attns': False,
        'layer_cross_attns': False,
        'attn_heads': 2,
    }

    config = ImagenConfig(
        unets=(unet_config,),
        image_sizes=(64,),
    )
    imagen = config.create()

    trainer = ImagenTrainer(
        imagen=imagen
    )

    dataset = TestDataset()
    trainer.add_train_dataset(dataset, batch_size=8)

    # one step on unet 1 should register exactly one step for that unet
    trainer.train_step(1)
    assert trainer.num_steps_taken(1) == 1
.\lucidrains\imagen-pytorch\imagen_pytorch\test\__init__.py
# import the test_trainer module from the imagen_pytorch.test package
from imagen_pytorch.test import test_trainer
.\lucidrains\imagen-pytorch\imagen_pytorch\trainer.py
# 导入必要的库
import os
from math import ceil
from contextlib import contextmanager, nullcontext
from functools import partial, wraps
from collections.abc import Iterable
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.cuda.amp import autocast, GradScaler
import pytorch_warmup as warmup
from imagen_pytorch.imagen_pytorch import Imagen, NullUnet
from imagen_pytorch.elucidated_imagen import ElucidatedImagen
from imagen_pytorch.data import cycle
from imagen_pytorch.version import __version__
from packaging import version
import numpy as np
from ema_pytorch import EMA
from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs
from fsspec.core import url_to_fs
from fsspec.implementations.local import LocalFileSystem
# helper functions
# Helper: is the value anything other than None?
def exists(val):
    """Return ``True`` unless ``val`` is ``None``."""
    if val is None:
        return False
    return True
# Helper: fall back to a default when the value is None.
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise ``d``, called first if callable."""
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
# Helper: coerce a value to a tuple.
def cast_tuple(val, length = 1):
    """Return ``val`` as a tuple.

    Lists are converted, tuples are returned as they are, and any other value
    is repeated ``length`` times.
    """
    if isinstance(val, list):
        return tuple(val)
    if isinstance(val, tuple):
        return val
    return (val,) * length
# Helper: index of the first element that satisfies a predicate.
def find_first(fn, arr):
    """Return the index of the first element of ``arr`` for which ``fn`` is truthy, or -1.

    Note: a second ``find_first`` defined later in this file returns the
    element instead of the index and replaces this one at import time.
    """
    position = 0
    for el in arr:
        if fn(el):
            return position
        position += 1
    return -1
# Helper: remove the given keys from a dict and return them as a new dict.
def pick_and_pop(keys, d):
    """Pop every key in ``keys`` from ``d`` (in place) and return ``{key: value}``.

    A key missing from ``d`` raises ``KeyError``.
    """
    popped = [d.pop(key) for key in keys]
    return dict(zip(keys, popped))
# Helper: partition a dict by a predicate on its keys.
def group_dict_by_key(cond, d):
    """Return ``(matching, rest)``, two dicts split by whether ``cond(key)`` is truthy."""
    matching, rest = dict(), dict()
    for key in d.keys():
        bucket = matching if bool(cond(key)) else rest
        bucket[key] = d[key]
    return (matching, rest)
# Helper: prefix test with the prefix as the first argument.
def string_begins_with(prefix, str):
    """Return whether ``str`` starts with ``prefix``.

    The prefix comes first, which lets callers bind it with ``partial``.
    """
    has_prefix = str.startswith(prefix)
    return has_prefix
# Helper: partition a dict by whether its keys start with a prefix.
def group_by_key_prefix(prefix, d):
    """Return ``(with_prefix, without_prefix)`` dicts, keys left unchanged."""
    starts_with_prefix = partial(string_begins_with, prefix)
    return group_dict_by_key(starts_with_prefix, d)
# Helper: partition a dict by key prefix and strip the prefix from the matches.
def groupby_prefix_and_trim(prefix, d):
    """Return ``(trimmed, rest)``.

    ``trimmed`` holds the entries whose key starts with ``prefix``, with the
    prefix removed from each key; ``rest`` holds all other entries unchanged.
    """
    starts_with_prefix = partial(string_begins_with, prefix)
    kwargs_with_prefix, kwargs = group_dict_by_key(starts_with_prefix, d)

    prefix_len = len(prefix)
    kwargs_without_prefix = {
        key[prefix_len:]: value
        for key, value in tuple(kwargs_with_prefix.items())
    }
    return kwargs_without_prefix, kwargs
# Helper: split a count into groups of at most `divisor`.
def num_to_groups(num, divisor):
    """Return ``[divisor, ..., divisor, remainder]``, which sums to ``num``.

    The trailing remainder is left out when it is zero.
    """
    full_groups, remainder = divmod(num, divisor)
    groups = [divisor] * full_groups
    if remainder > 0:
        groups.append(remainder)
    return groups
# url to fs, bucket, path - for saving checkpoints to cloud storage
def url_to_bucket(url):
    """Extract the bucket name from a ``gs://`` or ``s3://`` URL.

    A string without ``'://'`` is returned unchanged. For supported schemes
    the bucket is the text between ``'://'`` and the first ``'/'``.

    Raises:
        ValueError: if the scheme is anything other than ``gs`` or ``s3``.
    """
    if '://' not in url:
        return url

    # the original unpacked the scheme into `_` and then read `prefix`,
    # which raised NameError for every URL
    prefix, _, suffix = url.partition('://')

    if prefix in {'gs', 's3'}:
        return suffix.split('/')[0]
    else:
        raise ValueError(f'storage type prefix "{prefix}" is not supported yet')
# decorators
# Run the wrapped function with the model in eval mode, then restore its mode.
def eval_decorator(fn):
    """Wrap ``fn(model, ...)`` so that it runs with ``model`` in eval mode.

    The model's previous ``training`` flag is restored afterwards, including
    when ``fn`` raises, so a failed call cannot leave the model in eval mode.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # restore the original mode on both the success and error paths
            model.train(was_training)
    return inner
# Decorator that converts array arguments to tensors on the model's device.
def cast_torch_tensor(fn, cast_fp16 = False):
    """Wrap ``fn(model, *args, **kwargs)`` so its arguments are normalised first.

    Numpy arrays become tensors. Tensors are moved to ``model.device``, or to
    the ``_device`` keyword if given, unless ``_cast_device`` is false.
    Non-boolean tensors are cast to half precision when both ``cast_fp16``
    and ``model.cast_half_at_training`` are truthy. The two underscore
    keywords are consumed here and never reach ``fn``.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        device = kwargs.pop('_device', model.device)
        cast_device = kwargs.pop('_cast_device', True)

        should_cast_fp16 = cast_fp16 and model.cast_half_at_training

        # treat positional and keyword values uniformly, then split them again
        names = kwargs.keys()
        num_positional = len(args)
        values = [*args, *kwargs.values()]

        values = [torch.from_numpy(v) if isinstance(v, np.ndarray) else v for v in values]

        if cast_device:
            values = [v.to(device) if isinstance(v, torch.Tensor) else v for v in values]

        if should_cast_fp16:
            values = [
                v.half() if isinstance(v, torch.Tensor) and v.dtype != torch.bool else v
                for v in values
            ]

        new_args = tuple(values[:num_positional])
        new_kwargs = dict(zip(names, values[num_positional:]))

        out = fn(model, *new_args, **new_kwargs)
        return out
    return inner
# Split a sliceable sequence into consecutive chunks of a given size.
def split_iterable(it, split_size):
    """Return a list of slices of ``it``, each at most ``split_size`` long."""
    num_chunks = ceil(len(it) / split_size)
    return [
        it[chunk * split_size: (chunk + 1) * split_size]
        for chunk in range(num_chunks)
    ]
# Split a tensor or iterable into chunks, dispatching on the input type.
def split(t, split_size = None):
    """Split ``t`` into chunks of ``split_size``.

    With no ``split_size``, ``t`` is returned unchanged. Tensors are split
    along dim 0 with ``Tensor.split``; other iterables go through
    ``split_iterable``.

    Raises:
        TypeError: if ``t`` is neither a tensor nor an iterable. The original
            returned the ``TypeError`` class itself instead of raising.
    """
    if split_size is None:
        return t

    if isinstance(t, torch.Tensor):
        return t.split(split_size, dim = 0)

    if isinstance(t, Iterable):
        return split_iterable(t, split_size)

    raise TypeError(f'cannot split object of type {type(t).__name__}')
# Return the first element that satisfies a predicate.
def find_first(cond, arr):
    """Return the first element of ``arr`` for which ``cond`` is truthy, else ``None``.

    Note: this definition replaces an earlier ``find_first`` in the same file
    that returned an index rather than the element.
    """
    return next((el for el in arr if cond(el)), None)
# Split positional and keyword arguments into aligned batch chunks.
def split_args_and_kwargs(*args, split_size = None, **kwargs):
    """Yield ``(fraction, (chunk_args, chunk_kwargs))`` for each batch chunk.

    The batch size is the length of the first tensor among the arguments; at
    least one tensor is required. Tensors and other iterables are split with
    ``split``; ``None`` and non-iterable values are repeated for every chunk.
    ``fraction`` is the chunk's share of the whole batch.
    """
    all_args = (*args, *kwargs.values())
    num_positional = len(all_args) - len(kwargs)
    kwarg_names = kwargs.keys()

    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
    assert exists(first_tensor)

    batch_size = len(first_tensor)
    split_size = default(split_size, batch_size)
    num_chunks = ceil(batch_size / split_size)

    def chunk_argument(arg):
        if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)):
            return split(arg, split_size = split_size)
        # anything that cannot be split is repeated once per chunk
        return (arg,) * num_chunks

    split_all_args = [chunk_argument(arg) for arg in all_args]
    chunk_sizes = num_to_groups(batch_size, split_size)

    for chunk_size, *chunked_all_args in tuple(zip(chunk_sizes, *split_all_args)):
        chunked_args = chunked_all_args[:num_positional]
        chunked_kwargs = dict(tuple(zip(kwarg_names, chunked_all_args[num_positional:])))
        yield chunk_size / batch_size, (chunked_args, chunked_kwargs)
# Decorator that runs a sampling method in sub-batches and concatenates the results.
def imagen_sample_in_chunks(fn):
    """Wrap ``fn(self, ...)`` so it accepts an optional ``max_batch_size`` keyword.

    Without ``max_batch_size`` the call passes straight through. For an
    unconditional imagen the ``batch_size`` keyword is broken into groups of
    at most ``max_batch_size``; otherwise the arguments are chunked with
    ``split_args_and_kwargs``. Tensor outputs are concatenated along dim 0;
    tuple-like outputs are concatenated position by position into a list.
    """
    @wraps(fn)
    def inner(self, *args, max_batch_size = None, **kwargs):
        if not exists(max_batch_size):
            return fn(self, *args, **kwargs)

        if self.imagen.unconditional:
            batch_size = kwargs.get('batch_size')
            batch_sizes = num_to_groups(batch_size, max_batch_size)
            outputs = [
                fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size})
                for sub_batch_size in batch_sizes
            ]
        else:
            chunks = split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)
            outputs = [
                fn(self, *chunked_args, **chunked_kwargs)
                for _, (chunked_args, chunked_kwargs) in chunks
            ]

        if isinstance(outputs[0], torch.Tensor):
            return torch.cat(outputs, dim = 0)

        # regroup per output position before concatenating
        return [torch.cat(group, dim = 0) for group in zip(*outputs)]

    return inner
# Copy matching entries from one state dict into another, in place.
def restore_parts(state_dict_target, state_dict_from):
    """Copy each same-named, same-sized tensor from ``state_dict_from`` into the target.

    Names missing from the target are skipped. A size mismatch is reported
    with ``print`` and that entry is left untouched. Returns the target.
    """
    for name, param in state_dict_from.items():
        if name in state_dict_target:
            target_param = state_dict_target[name]
            if param.size() != target_param.size():
                print(f"layer {name}({param.size()} different than target: {target_param.size()}")
                continue
            target_param.copy_(param)
    return state_dict_target
# 定义一个类,用于图像生成的训练
class ImagenTrainer(nn.Module):
locked = False
def __init__(
self,
imagen = None,
imagen_checkpoint_path = None,
use_ema = True,
lr = 1e-4,
eps = 1e-8,
beta1 = 0.9,
beta2 = 0.99,
max_grad_norm = None,
group_wd_params = True,
warmup_steps = None,
cosine_decay_max_steps = None,
only_train_unet_number = None,
fp16 = False,
precision = None,
split_batches = True,
dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'),
verbose = True,
split_valid_fraction = 0.025,
split_valid_from_train = False,
split_random_seed = 42,
checkpoint_path = None,
checkpoint_every = None,
checkpoint_fs = None,
fs_kwargs: dict = None,
max_checkpoints_keep = 20,
**kwargs
# 准备训练器,确保训练器尚未准备好,设置只训练的 UNet 编号,并将 prepared 标记为 True
def prepare(self):
    """Run the one-time setup of the trainer.

    Asserts that the trainer has not been prepared yet, registers
    ``self.only_train_unet_number`` as the unet being trained, then marks the
    trainer as prepared.
    """
    already_prepared = self.prepared
    assert not already_prepared, f'The trainer is allready prepared'

    self.validate_and_set_unet_being_trained(self.only_train_unet_number)
    self.prepared = True
# 计算属性
@property
def device(self):
return self.accelerator.device
@property
def is_distributed(self):
    """True unless the accelerator reports no distribution AND exactly one process."""
    accelerator = self.accelerator
    runs_single_process = accelerator.distributed_type == DistributedType.NO and accelerator.num_processes == 1
    return not runs_single_process
@property
def is_main(self):
return self.accelerator.is_main_process
@property
def is_local_main(self):
return self.accelerator.is_local_main_process
@property
def unwrapped_unet(self):
return self.accelerator.unwrap_model(self.unet_being_trained)
# 优化器辅助函数
def get_lr(self, unet_number):
self.validate_unet_number(unet_number)
unet_index = unet_number - 1
optim = getattr(self, f'optim{unet_index}')
return optim.param_groups[0]['lr']
# 仅允许同时训练一个 UNet 的函数
def validate_and_set_unet_being_trained(self, unet_number = None):
    """Record which unet this trainer trains and wrap it if a number is given.

    Asserts that a previously recorded ``only_train_unet_number`` is either
    unset or equal to ``unet_number``. The number is stored on both the trainer
    and ``self.imagen``. When ``unet_number`` is ``None`` nothing is wrapped.
    """
    has_number = exists(unet_number)

    if has_number:
        self.validate_unet_number(unet_number)

    locked_number = self.only_train_unet_number
    assert not exists(locked_number) or locked_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'

    self.only_train_unet_number = unet_number
    self.imagen.only_train_unet_number = unet_number

    if has_number:
        self.wrap_unet(unet_number)
def wrap_unet(self, unet_number):
if hasattr(self, 'one_unet_wrapped'):
return
unet = self.imagen.get_unet(unet_number)
unet_index = unet_number - 1
optimizer = getattr(self, f'optim{unet_index}')
scheduler = getattr(self, f'scheduler{unet_index}')
if self.train_dl:
self.unet_being_trained, self.train_dl, optimizer = self.accelerator.prepare(unet, self.train_dl, optimizer)
else:
self.unet_being_trained, optimizer = self.accelerator.prepare(unet, optimizer)
if exists(scheduler):
scheduler = self.accelerator.prepare(scheduler)
setattr(self, f'optim{unet_index}', optimizer)
setattr(self, f'scheduler{unet_index}', scheduler)
self.one_unet_wrapped = True
# 由于没有每个优化器单独的 gradscaler,对 accelerator 进行修改
def set_accelerator_scaler(self, unet_number):
def patch_optimizer_step(accelerated_optimizer, method):
def patched_step(*args, **kwargs):
accelerated_optimizer._accelerate_step_called = True
return method(*args, **kwargs)
return patched_step
unet_number = self.validate_unet_number(unet_number)
scaler = getattr(self, f'scaler{unet_number - 1}')
self.accelerator.scaler = scaler
for optimizer in self.accelerator._optimizers:
optimizer.scaler = scaler
optimizer._accelerate_step_called = False
optimizer._optimizer_original_step_method = optimizer.optimizer.step
optimizer._optimizer_patched_step_method = patch_optimizer_step(optimizer, optimizer.optimizer.step)
# 辅助打印函数
def print(self, msg):
if not self.is_main:
return
if not self.verbose:
return
return self.accelerator.print(msg)
# 验证 UNet 编号
def validate_unet_number(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}'
return unet_number
# 训练步骤数
# 返回指定 U-Net 编号的训练步数
def num_steps_taken(self, unet_number = None):
# 如果只有一个 U-Net,则默认使用编号为 1
if self.num_unets == 1:
unet_number = default(unet_number, 1)
# 返回指定 U-Net 的训练步数
return self.steps[unet_number - 1].item()
# 打印未训练的 U-Net
def print_untrained_unets(self):
    """Report every unet that has zero recorded steps and is not a ``NullUnet``.

    If at least one such unet exists, a closing hint about
    ``stop_at_unet_number`` is printed as well.
    """
    step_counts = self.steps.tolist()

    untrained_numbers = [
        position + 1
        for position, (steps, unet) in enumerate(zip(step_counts, self.imagen.unets))
        if not (steps > 0 or isinstance(unet, NullUnet))
    ]

    for number in untrained_numbers:
        self.print(f'unet {number} has not been trained')

    if untrained_numbers:
        self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')
# 数据相关函数
# 添加训练数据加载器
def add_train_dataloader(self, dl = None):
    """Register the training dataloader; a ``None`` argument is ignored.

    Raises ``AssertionError`` if a training dataloader is already registered
    or if the trainer has already been prepared.
    """
    if exists(dl):
        assert not exists(self.train_dl), 'training dataloader was already added'
        assert not self.prepared, f'You need to add the dataset before preperation'
        self.train_dl = dl
# 添加验证数据加载器
def add_valid_dataloader(self, dl):
    """Register the validation dataloader; a ``None`` argument is ignored.

    Raises ``AssertionError`` if a validation dataloader is already registered
    or if the trainer has already been prepared.
    """
    if exists(dl):
        assert not exists(self.valid_dl), 'validation dataloader was already added'
        assert not self.prepared, f'You need to add the dataset before preperation'
        self.valid_dl = dl
# 添加训练数据集
def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs):
    """Wrap a dataset in a ``DataLoader`` and register it for training.

    Args:
        ds: the training dataset; ``None`` makes the call a no-op.
        batch_size: batch size handed to every ``DataLoader`` created here.
        **dl_kwargs: forwarded unchanged to ``DataLoader``.

    When ``self.split_valid_from_train`` is set, a fraction
    ``self.split_valid_fraction`` of the dataset is split off with
    ``random_split`` (seeded by ``self.split_random_seed``) and registered as
    the validation dataset with the same loader settings.

    Raises ``AssertionError`` if a training dataloader was already added.
    """
    if not exists(ds):
        return

    assert not exists(self.train_dl), 'training dataloader was already added'

    valid_ds = None

    if self.split_valid_from_train:
        # the training share is rounded down; validation gets the remainder
        train_size = int((1 - self.split_valid_fraction) * len(ds))
        valid_size = len(ds) - train_size

        ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed))
        self.print(f'training with dataset of {len(ds)} samples and validating with randomly splitted {len(valid_ds)} samples')

    dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
    self.add_train_dataloader(dl)

    if not self.split_valid_from_train:
        return

    self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs)
# 添加验证数据集
def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
    """Wrap a dataset in a ``DataLoader`` and register it for validation.

    A ``None`` dataset is ignored. Raises ``AssertionError`` if a validation
    dataloader was already added.
    """
    if exists(ds):
        assert not exists(self.valid_dl), 'validation dataloader was already added'

        loader = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
        self.add_valid_dataloader(loader)
# 创建训练数据迭代器
def create_train_iter(self):
    """Create ``self.train_dl_iter`` from the training dataloader if it does not exist yet.

    Raises ``AssertionError`` when no training dataloader has been registered.
    """
    assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'

    if not exists(self.train_dl_iter):
        self.train_dl_iter = cycle(self.train_dl)
# 创建验证数据迭代器
def create_valid_iter(self):
    """Create ``self.valid_dl_iter`` from the validation dataloader if it does not exist yet.

    Raises ``AssertionError`` when no validation dataloader has been registered.
    """
    assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'

    if not exists(self.valid_dl_iter):
        self.valid_dl_iter = cycle(self.valid_dl)
# 训练步骤
def train_step(self, *, unet_number = None, **kwargs):
    """Run one training step from the registered training dataloader.

    Prepares the trainer on first use, pulls the next batch, computes the loss
    via ``step_with_dl_iter`` and then calls ``update`` for the same unet.

    Returns the loss produced by ``step_with_dl_iter``.
    """
    if not self.prepared:
        self.prepare()

    self.create_train_iter()

    step_kwargs = {'unet_number': unet_number, **kwargs}
    loss = self.step_with_dl_iter(self.train_dl_iter, **step_kwargs)

    self.update(unet_number = unet_number)
    return loss
# 验证步骤
@torch.no_grad()
@eval_decorator
def valid_step(self, **kwargs):
    """Compute the loss on the next batch of the validation dataloader.

    Prepares the trainer on first use. The optional ``use_ema_unets`` keyword
    (removed from ``kwargs`` before forwarding) selects whether the step runs
    inside the ``use_ema_unets`` context; by default it does not.
    """
    if not self.prepared:
        self.prepare()

    self.create_valid_iter()

    wants_ema = kwargs.pop('use_ema_unets', False)
    context = self.use_ema_unets if wants_ema else nullcontext

    with context():
        loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)

    return loss
# 使用 dl_iter 迭代器获取下一个数据元组
def step_with_dl_iter(self, dl_iter, **kwargs):
    """Pull the next item from ``dl_iter`` and feed it to ``forward``.

    The item is turned into a tuple and its elements are paired, in order,
    with ``self.dl_tuple_output_keywords_names`` to form keyword arguments.
    Those override any same-named entry in ``kwargs``.

    Returns the loss produced by ``forward``.
    """
    batch = cast_tuple(next(dl_iter))

    # zip stops at the shorter sequence, so extra names or values are dropped
    model_input = dict(zip(self.dl_tuple_output_keywords_names, batch))

    forward_kwargs = {**kwargs, **model_input}
    return self.forward(**forward_kwargs)
# 检查点函数
# 获取所有按照时间排序的检查点文件
@property
def all_checkpoints_sorted(self):
glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
checkpoints = self.fs.glob(glob_pattern)
sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
return sorted_checkpoints
# 从检查点文件夹加载模型
def load_from_checkpoint_folder(self, last_total_steps = -1):
if last_total_steps != -1:
filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
self.load(filepath)
return
sorted_checkpoints = self.all_checkpoints_sorted
if len(sorted_checkpoints) == 0:
self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
return
last_checkpoint = sorted_checkpoints[0]
self.load(last_checkpoint)
# 保存到检查点文件夹
def save_to_checkpoint_folder(self):
self.accelerator.wait_for_everyone()
if not self.can_checkpoint:
return
total_steps = int(self.steps.sum().item())
filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')
self.save(filepath)
if self.max_checkpoints_keep <= 0:
return
sorted_checkpoints = self.all_checkpoints_sorted
checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]
for checkpoint in checkpoints_to_discard:
self.fs.rm(checkpoint)
# 保存和加载函数
# 保存模型到指定路径
def save(
self,
path,
overwrite = True,
without_optim_and_sched = False,
**kwargs
):
self.accelerator.wait_for_everyone()
if not self.can_checkpoint:
return
fs = self.fs
assert not (fs.exists(path) and not overwrite)
self.reset_ema_unets_all_one_device()
# 构建保存对象
save_obj = dict(
model = self.imagen.state_dict(),
version = __version__,
steps = self.steps.cpu(),
**kwargs
)
save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()
# 保存优化器和调度器状态
for ind in save_optim_and_sched_iter:
scaler_key = f'scaler{ind}'
optimizer_key = f'optim{ind}'
scheduler_key = f'scheduler{ind}'
warmup_scheduler_key = f'warmup{ind}'
scaler = getattr(self, scaler_key)
optimizer = getattr(self, optimizer_key)
scheduler = getattr(self, scheduler_key)
warmup_scheduler = getattr(self, warmup_scheduler_key)
if exists(scheduler):
save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}
if exists(warmup_scheduler):
save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}
save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
# 确定是否存在 imagen 配置
if hasattr(self.imagen, '_config'):
self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>"')
save_obj = {
**save_obj,
'imagen_type': 'elucidated' if self.is_elucidated else 'original',
'imagen_params': self.imagen._config
}
# 保存到指定路径
with fs.open(path, 'wb') as f:
torch.save(save_obj, f)
self.print(f'checkpoint saved to {path}')
# 加载模型参数和优化器状态
def load(self, path, only_model = False, strict = True, noop_if_not_exist = False):
    """Restore trainer state from a checkpoint written by ``save``.

    Args:
        path: location of the checkpoint, opened through ``self.fs``.
        only_model: restore only the imagen weights and return right after.
        strict: forwarded to ``load_state_dict`` for the model and EMA weights.
        noop_if_not_exist: print a message and return ``None`` instead of
            failing when ``path`` does not exist.

    Returns:
        The loaded checkpoint object, or ``None`` for the no-op case.

    Raises:
        AssertionError: if ``path`` does not exist (and ``noop_if_not_exist``
            is false), or if EMA is enabled but the checkpoint has no ``ema`` entry.

    NOTE(review): ``torch.load`` unpickles the file, so only checkpoints from
    trusted sources should be loaded.
    """
    fs = self.fs

    if noop_if_not_exist and not fs.exists(path):
        self.print(f'trainer checkpoint not found at {str(path)}')
        return

    assert fs.exists(path), f'{path} does not exist'

    self.reset_ema_unets_all_one_device()

    # load onto the cpu first to avoid extra GPU memory use on the main process
    with fs.open(path) as f:
        loaded_obj = torch.load(f, map_location='cpu')

    # a version mismatch is only reported, loading continues
    if version.parse(__version__) != version.parse(loaded_obj['version']):
        self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')

    try:
        self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
    except RuntimeError:
        print("Failed loading state dict. Trying partial load")
        # fall back to copying only the entries whose name and size match
        self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
                                                  loaded_obj['model']))

    if only_model:
        return loaded_obj

    self.steps.copy_(loaded_obj['steps'])

    for ind in range(0, self.num_unets):
        scaler_key = f'scaler{ind}'
        optimizer_key = f'optim{ind}'
        scheduler_key = f'scheduler{ind}'
        warmup_scheduler_key = f'warmup{ind}'

        scaler = getattr(self, scaler_key)
        optimizer = getattr(self, optimizer_key)
        scheduler = getattr(self, scheduler_key)
        warmup_scheduler = getattr(self, warmup_scheduler_key)

        if exists(scheduler) and scheduler_key in loaded_obj:
            scheduler.load_state_dict(loaded_obj[scheduler_key])

        if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
            warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])

        if exists(optimizer):
            try:
                optimizer.load_state_dict(loaded_obj[optimizer_key])
                scaler.load_state_dict(loaded_obj[scaler_key])
            except Exception:
                # Best effort: continue with fresh optimizer / scaler state.
                # `Exception` (not a bare except) keeps KeyboardInterrupt and
                # SystemExit propagating.
                self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers')

    if self.use_ema:
        assert 'ema' in loaded_obj
        try:
            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
        except RuntimeError:
            print("Failed loading state dict. Trying partial load")
            self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
                                                         loaded_obj['ema']))

    self.print(f'checkpoint loaded from {path}')
    return loaded_obj
# 获取所有 EMA 模型
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
# 获取指定编号的 EMA 模型
def get_ema_unet(self, unet_number = None):
# 如果不使用 EMA,则返回
if not self.use_ema:
return
# 验证并获取正确的 U-Net 编号
unet_number = self.validate_unet_number(unet_number)
index = unet_number - 1
# 如果 unets 是 nn.ModuleList,则转换为列表并更新 ema_unets
if isinstance(self.unets, nn.ModuleList):
unets_list = [unet for unet in self.ema_unets]
delattr(self, 'ema_unets')
self.ema_unets = unets_list
# 将当前训练的 EMA 模型移到指定设备上
if index != self.ema_unet_being_trained_index:
for unet_index, unet in enumerate(self.ema_unets):
unet.to(self.device if unet_index == index else 'cpu')
# 更新当前训练的 EMA 模型索引,并返回对应的 EMA 模型
self.ema_unet_being_trained_index = index
return self.ema_unets[index]
# 重置所有 EMA 模型到指定设备上
def reset_ema_unets_all_one_device(self, device = None):
# 如果不使用 EMA,则返回
if not self.use_ema:
return
# 获取默认设备
device = default(device, self.device)
# 将所有 EMA 模型转移到指定设备上
self.ema_unets = nn.ModuleList([*self.ema_unets])
self.ema_unets.to(device)
# 重置当前训练的 EMA 模型索引
self.ema_unet_being_trained_index = -1
# 禁用梯度计算
@torch.no_grad()
# 定义一个上下文管理器,用于控制是否使用指数移动平均的 U-Net 模型
@contextmanager
def use_ema_unets(self):
# 如果不使用指数移动平均模型,则直接返回输出
if not self.use_ema:
output = yield
return output
# 重置所有 U-Net 模型为同一设备上的指数移动平均模型
self.reset_ema_unets_all_one_device()
self.imagen.reset_unets_all_one_device()
# 将 U-Net 模型设置为评估模式
self.unets.eval()
# 保存可训练的 U-Net 模型,然后将指数移动平均模型用于采样
trainable_unets = self.imagen.unets
self.imagen.unets = self.unets
output = yield
# 恢复原始的训练 U-Net 模型
self.imagen.unets = trainable_unets
# 将指数移动平均模型的 U-Net 恢复到原始设备
for ema in self.ema_unets:
ema.restore_ema_model_device()
return output
# 打印 U-Net 模型的设备信息
def print_unet_devices(self):
self.print('unet devices:')
for i, unet in enumerate(self.imagen.unets):
device = next(unet.parameters()).device
self.print(f'\tunet {i}: {device}')
# 如果不使用指数移动平均模型,则直接返回
if not self.use_ema:
return
self.print('\nema unet devices:')
for i, ema_unet in enumerate(self.ema_unets):
device = next(ema_unet.parameters()).device
self.print(f'\tema unet {i}: {device}')
# 重写状态字典函数
def state_dict(self, *args, **kwargs):
# 重置所有 U-Net 模型为同一设备上的指数移动平均模型
self.reset_ema_unets_all_one_device()
return super().state_dict(*args, **kwargs)
def load_state_dict(self, *args, **kwargs):
# 重置所有 U-Net 模型为同一设备上的指数移动平均模型
self.reset_ema_unets_all_one_device()
return super().load_state_dict(*args, **kwargs)
# 编码文本函数
def encode_text(self, text, **kwargs):
return self.imagen.encode_text(text, **kwargs)
# 前向传播函数和梯度更新步骤
def update(self, unet_number = None):
    """Apply one optimizer step for a unet and do the per-step bookkeeping.

    In order: optional gradient clipping, optimizer step and zero-grad, EMA
    update (when enabled), scheduler step (skipped when the accelerator
    reports the optimizer step was skipped), step counter increment, and a
    checkpoint save whenever the total step count is a multiple of
    ``self.checkpoint_every``.
    """
    unet_number = self.validate_unet_number(unet_number)
    self.validate_and_set_unet_being_trained(unet_number)
    # one grad scaler is managed per unet, so select it on the accelerator first
    self.set_accelerator_scaler(unet_number)

    index = unet_number - 1
    unet = self.unet_being_trained

    optimizer = getattr(self, f'optim{index}')
    scaler = getattr(self, f'scaler{index}')  # looked up as in the original; not used below
    scheduler = getattr(self, f'scheduler{index}')
    warmup_scheduler = getattr(self, f'warmup{index}')

    max_norm = self.max_grad_norm
    if exists(max_norm):
        self.accelerator.clip_grad_norm_(unet.parameters(), max_norm)

    optimizer.step()
    optimizer.zero_grad()

    if self.use_ema:
        self.get_ema_unet(unet_number).update()

    # step the scheduler inside the warmup dampening context when a warmup scheduler exists
    if exists(warmup_scheduler):
        maybe_warmup_context = warmup_scheduler.dampening()
    else:
        maybe_warmup_context = nullcontext()

    with maybe_warmup_context:
        if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped:
            scheduler.step()

    self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps))

    if not exists(self.checkpoint_path):
        return

    total_steps = int(self.steps.sum().item())

    # a non-zero remainder means this is not a checkpoint step
    if total_steps % self.checkpoint_every:
        return

    self.save_to_checkpoint_folder()
@torch.no_grad()
@cast_torch_tensor
@imagen_sample_in_chunks
def sample(self, *args, **kwargs):
    """Sample from ``self.imagen``, by default with the EMA unets.

    The optional ``use_non_ema`` keyword (removed from ``kwargs`` before
    forwarding) disables the EMA swap. Processes other than the main one have
    ``use_tqdm`` forced to ``False``. Remaining arguments are passed to
    ``self.imagen.sample`` together with ``device = self.device``.
    """
    skip_ema = kwargs.pop('use_non_ema', False)
    context = nullcontext if skip_ema else self.use_ema_unets

    self.print_untrained_unets()

    if not self.is_main:
        kwargs['use_tqdm'] = False

    with context():
        output = self.imagen.sample(*args, device = self.device, **kwargs)

    return output
@partial(cast_torch_tensor, cast_fp16 = True)
def forward(
    self,
    *args,
    unet_number = None,
    max_batch_size = None,
    **kwargs
):
    """Compute the loss for one unet, accumulating gradients over chunks.

    The inputs are split with ``split_args_and_kwargs`` using
    ``max_batch_size``. Each chunk's loss is weighted by its fraction of the
    batch and, when the module is in training mode, backpropagated through
    the accelerator.

    Returns:
        The sum of the weighted chunk losses as a Python float.
    """
    unet_number = self.validate_unet_number(unet_number)
    self.validate_and_set_unet_being_trained(unet_number)
    self.set_accelerator_scaler(unet_number)

    locked_number = self.only_train_unet_number
    assert not exists(locked_number) or locked_number == unet_number, f'you can only train unet #{locked_number}'

    total_loss = 0.

    chunks = split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)

    for chunk_size_frac, (chunked_args, chunked_kwargs) in chunks:
        with self.accelerator.autocast():
            chunk_loss = self.imagen(*chunked_args, unet = self.unet_being_trained, unet_number = unet_number, **chunked_kwargs)
            # weight by the chunk's share of the batch so the chunk losses sum to the batch loss
            chunk_loss = chunk_loss * chunk_size_frac

        total_loss += chunk_loss.item()

        if self.training:
            self.accelerator.backward(chunk_loss)

    return total_loss
.\lucidrains\imagen-pytorch\imagen_pytorch\utils.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块
from torch import nn
# 从 functools 库中导入 reduce 函数
from functools import reduce
# 从 pathlib 库中导入 Path 类
from pathlib import Path
# 从 imagen_pytorch.configs 模块中导入 ImagenConfig 和 ElucidatedImagenConfig 类
from imagen_pytorch.configs import ImagenConfig, ElucidatedImagenConfig
# 从 ema_pytorch 模块中导入 EMA 类
from ema_pytorch import EMA
# 定义一个函数,用于检查变量是否存在
def exists(val):
    """Return True for any value other than ``None``."""
    is_missing = val is None
    return not is_missing
# 定义一个函数,用于安全获取字典中的值
def safeget(dictionary, keys, default = None):
    """Look up a dot-separated key path in nested dictionaries.

    Each segment of ``keys`` is fetched with ``dict.get``; whenever the current
    value is not a dict, or the segment is missing, the value becomes
    ``default`` and the walk continues from there.
    """
    current = dictionary

    for key in keys.split('.'):
        if isinstance(current, dict):
            current = current.get(key, default)
        else:
            current = default

    return current
# 加载模型和配置信息
def load_imagen_from_checkpoint(
    checkpoint_path,
    load_weights = True,
    load_ema_if_available = False
):
    """Rebuild an imagen instance from a checkpoint that stores its configuration.

    Args:
        checkpoint_path: path of the checkpoint file.
        load_weights: when false, the imagen is created from its saved
            configuration but no weights are loaded.
        load_ema_if_available: when true and the checkpoint has an ``ema``
            entry, the unets receive the EMA weights instead.

    Raises:
        AssertionError: if the file does not exist, or the configuration is missing.
        ValueError: if the stored imagen type is neither ``original`` nor ``elucidated``.

    NOTE(review): ``torch.load`` unpickles the file, so only checkpoints from
    trusted sources should be loaded.
    """
    model_path = Path(checkpoint_path)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'checkpoint not found at {full_model_path}'

    loaded = torch.load(str(model_path), map_location='cpu')

    imagen_params = safeget(loaded, 'imagen_params')
    imagen_type = safeget(loaded, 'imagen_type')

    # pick the configuration class that matches the stored imagen type
    if imagen_type == 'original':
        imagen_klass = ImagenConfig
    elif imagen_type == 'elucidated':
        imagen_klass = ElucidatedImagenConfig
    else:
        raise ValueError(f'unknown imagen type {imagen_type} - you need to instantiate your Imagen with configurations, using classes ImagenConfig or ElucidatedImagenConfig')

    assert exists(imagen_params) and exists(imagen_type), 'imagen type and configuration not saved in this checkpoint'

    imagen = imagen_klass(**imagen_params).create()

    if not load_weights:
        return imagen

    should_load_ema = ('ema' in loaded) and load_ema_if_available

    imagen.load_state_dict(loaded['model'])

    if not should_load_ema:
        print('loading non-EMA version of unets')
        return imagen

    # wrap every unet in an EMA module so the saved EMA state dict can be loaded
    ema_unets = nn.ModuleList([EMA(unet) for unet in imagen.unets])
    ema_unets.load_state_dict(loaded['ema'])

    # copy the averaged weights back into the plain unets
    for unet, ema_unet in zip(imagen.unets, ema_unets):
        unet.load_state_dict(ema_unet.ema_model.state_dict())

    print('loaded EMA version of unets')
    return imagen
.\lucidrains\imagen-pytorch\imagen_pytorch\version.py
# 定义变量 __version__,赋值为字符串 '1.26.2'
__version__ = '1.26.2'
.\lucidrains\imagen-pytorch\imagen_pytorch\__init__.py
# 从 imagen_pytorch 模块中导入 Imagen 和 Unet 类
from imagen_pytorch.imagen_pytorch import Imagen, Unet
# 从 imagen_pytorch 模块中导入 NullUnet 类
from imagen_pytorch.imagen_pytorch import NullUnet
# 从 imagen_pytorch 模块中导入 BaseUnet64, SRUnet256, SRUnet1024 类
from imagen_pytorch.imagen_pytorch import BaseUnet64, SRUnet256, SRUnet1024
# 从 imagen_pytorch 模块中导入 ImagenTrainer 类
from imagen_pytorch.trainer import ImagenTrainer
# 从 imagen_pytorch 模块中导入 __version__ 变量
from imagen_pytorch.version import __version__
# 使用 Tero Karras 的新论文中阐述的 ddpm 创建 imagen
# 从 imagen_pytorch 模块中导入 ElucidatedImagen 类
from imagen_pytorch.elucidated_imagen import ElucidatedImagen
# 通过配置创建 imagen 实例
# 从 imagen_pytorch 模块中导入 UnetConfig, ImagenConfig, ElucidatedImagenConfig, ImagenTrainerConfig 类
from imagen_pytorch.configs import UnetConfig, ImagenConfig, ElucidatedImagenConfig, ImagenTrainerConfig
# 工具
# 从 imagen_pytorch 模块中导入 load_imagen_from_checkpoint 函数
from imagen_pytorch.utils import load_imagen_from_checkpoint
# 视频
# 从 imagen_pytorch 模块中导入 Unet3D 类
from imagen_pytorch.imagen_video import Unet3D
Imagen - Pytorch
Implementation of Imagen, Google's Text-to-Image Neural Network that beats DALL-E2, in Pytorch. It is the new SOTA for text-to-image synthesis.
Architecturally, it is actually much simpler than DALL-E2. It consists of a cascading DDPM conditioned on text embeddings from a large pretrained T5 model (attention network). It also contains dynamic clipping for improved classifier free guidance, noise level conditioning, and a memory efficient unet design.
It appears neither CLIP nor prior network is needed after all. And so research continues.
AI Coffee Break with Letitia | Assembly AI | Yannic Kilcher
Please join if you are interested in helping out with the replication with the LAION community
Shoutouts
-
StabilityAI for the generous sponsorship, as well as my other sponsors out there
-
🤗 Huggingface for their amazing transformers library. The text encoder portion is pretty much taken care of because of them
-
Jonathan Ho for bringing about a revolution in generative artificial intelligence through his seminal paper
-
Sylvain and Zachary for the Accelerate library, which this repository uses for distributed training
-
Jorge Gomes for helping out with the T5 loading code and advice on the correct T5 version
-
Katherine Crowson, for her beautiful code, which helped me understand the continuous time version of gaussian diffusion
-
Marunine and Netruk44, for reviewing code, sharing experimental results, and help with debugging
-
Marunine for providing a potential solution for a color shifting issue in the memory efficient u-nets. Thanks to Jacob for sharing experimental comparisons between the base and memory-efficient unets
-
Marunine for finding numerous bugs, resolving an issue with resize right, and for sharing his experimental configurations and results
-
MalumaDev for proposing the use of pixel shuffle upsampler to fix checkerboard artifacts
-
Valentin for pointing out insufficient skip connections in the unet, as well as the specific method of attention conditioning in the base-unet in the appendix
-
BIGJUN for catching a big bug with continuous time gaussian diffusion noise level conditioning at inference time
-
Bingbing for identifying a bug with sampling and order of normalizing and noising with low resolution conditioning image
-
Kay for contributing one line command training of Imagen!
-
Hadrien Reynaud for testing out text-to-video on a medical dataset, sharing his results, and identifying issues!
Install
$ pip install imagen-pytorch
Usage
import torch
from imagen_pytorch import Unet, Imagen
# unet for imagen
unet1 = Unet(
dim = 32,
cond_dim = 512,
dim_mults = (1, 2, 4, 8),
num_resnet_blocks = 3,
layer_attns = (False, True, True, True),
layer_cross_attns = (False, True, True, True)
)
unet2 = Unet(
dim = 32,
cond_dim = 512,
dim_mults = (1, 2, 4, 8),
num_resnet_blocks = (2, 4, 8, 8),
layer_attns = (False, False, False, True),
layer_cross_attns = (False, False, False, True)
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
unets = (unet1, unet2),
image_sizes = (64, 256),
timesteps = 1000,
cond_drop_prob = 0.1
).cuda()
# mock images (get a lot of this) and text encodings from large T5
text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
for i in (1, 2):
loss = imagen(images, text_embeds = text_embeds, unet_number = i)
loss.backward()
# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm
images = imagen.sample(texts = [
'a whale breaching from afar',
'young girl blowing out candles on her birthday cake',
'fireworks with blue and green sparkles'
], cond_scale = 3.)
images.shape # (3, 3, 256, 256)
For simpler training, you can directly supply text strings instead of precomputing text encodings. (Although for scaling purposes, you will definitely want to precompute the textual embeddings + mask)
The number of textual captions must match the batch size of the images if you go this route.
# mock images and text (get a lot of this)
texts = [
'a child screaming at finding a worm within a half-eaten apple',
'lizard running across the desert on two feet',
'waking up to a psychedelic landscape',
'seashells sparkling in the shallow waters'
]
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
for i in (1, 2):
loss = imagen(images, texts = texts, unet_number = i)
loss.backward()
With the ImagenTrainer
wrapper class, the exponential moving averages for all of the U-nets in the cascading DDPM will be automatically taken care of when calling update
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer
# unet for imagen
unet1 = Unet(
dim = 32,
cond_dim = 512,
dim_mults = (1, 2, 4, 8),
num_resnet_blocks = 3,
layer_attns = (False, True, True, True),
)
unet2 = Unet(
dim = 32,
cond_dim = 512,
dim_mults = (1, 2, 4, 8),
num_resnet_blocks = (2, 4, 8, 8),
layer_attns = (False, False, False, True),
layer_cross_attns = (False, False, False, True)
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
unets = (unet1, unet2),
text_encoder_name = 't5-large',
image_sizes = (64, 256),
timesteps = 1000,
cond_drop_prob = 0.1
).cuda()
# wrap imagen with the trainer class
trainer = ImagenTrainer(imagen)
# mock images (get a lot of this) and text encodings from large T5
text_embeds = torch.randn(64, 256, 1024).cuda()
images = torch.randn(64, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
loss = trainer(
images,
text_embeds = text_embeds,
unet_number = 1, # training on unet number 1 in this example, but you will have to also save checkpoints and then reload and continue training on unet number 2
max_batch_size = 4 # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
)
trainer.update(unet_number = 1)
# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm
images = trainer.sample(texts = [
'a puppy looking anxiously at a giant donut on the table',
'the milky way galaxy in the style of monet'
], cond_scale = 3.)
images.shape # (2, 3, 256, 256)
You can also train Imagen without text (unconditional image generation) as follows
import torch
from imagen_pytorch import Unet, Imagen, SRUnet256, ImagenTrainer
# unets for unconditional imagen
unet1 = Unet(
dim = 32,
dim_mults = (1, 2, 4),
num_resnet_blocks = 3,
layer_attns = (False, True, True),
layer_cross_attns = False,
use_linear_attn = True
)
unet2 = SRUnet256(
dim = 32,
dim_mults = (1, 2, 4),
num_resnet_blocks = (2, 4, 8),
layer_attns = (False, False, True),
layer_cross_attns = False
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
condition_on_text = False, # this must be set to False for unconditional Imagen
unets = (unet1, unet2),
image_sizes = (64, 128),
timesteps = 1000
)
trainer = ImagenTrainer(imagen).cuda()
# now get a ton of images and feed it through the Imagen trainer
training_images = torch.randn(4, 3, 256, 256).cuda()
# train each unet separately
# in this example, only training on unet number 1
loss = trainer(training_images, unet_number = 1)
trainer.update(unet_number = 1)
# do the above for many many many many steps
# now you can sample images unconditionally from the cascading unet(s)
images = trainer.sample(batch_size = 16) # (16, 3, 128, 128)
Or train only super-resoluting unets
import torch
from imagen_pytorch import Unet, NullUnet, Imagen
# unet for imagen
unet1 = NullUnet() # add a placeholder "null" unet for the base unet
unet2 = Unet(
dim = 32,
cond_dim = 512,
dim_mults = (1, 2, 4, 8),
num_resnet_blocks = (2, 4, 8, 8),
layer_attns = (False, False, False, True),
layer_cross_attns = (False, False, False, True)
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
unets = (unet1, unet2),
image_sizes = (64, 256),
timesteps = 250,
cond_drop_prob = 0.1
).cuda()
# mock images (get a lot of this) and text encodings from large T5
text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
loss = imagen(images, text_embeds = text_embeds, unet_number = 2)
loss.backward()
# do the above for many many many many steps
# now you can sample an image based on the text embeddings as well as low resolution images
lowres_images = torch.randn(3, 3, 64, 64).cuda() # starting un-resoluted images
images = imagen.sample(
texts = [
'a whale breaching from afar',
'young girl blowing out candles on her birthday cake',
'fireworks with blue and green sparkles'
],
start_at_unet_number = 2, # start at unet number 2
start_image_or_video = lowres_images, # pass in low resolution images to be resoluted
cond_scale = 3.)
images.shape # (3, 3, 256, 256)
At any time you can save and load the trainer and all associated states with the save
and load
methods. It is recommended you use these methods instead of manually saving with a state_dict
call, as there is some device memory management being done under the hood within the trainer.
ex.
trainer.save('./path/to/checkpoint.pt')
trainer.load('./path/to/checkpoint.pt')
trainer.steps # (2,) step number for each of the unets, in this case 2
Dataloader
You can also rely on the ImagenTrainer
to automatically train off DataLoader
instances. You simply have to craft your DataLoader
to return either images
(for the unconditional case), or a tuple of ('images', 'text_embeds')
for text-guided generation.
ex. unconditional training
from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset
# unets for unconditional imagen
unet = Unet(
dim = 32,
dim_mults = (1, 2, 4, 8),
num_resnet_blocks = 1,
layer_attns = (False, False, False, True),
layer_cross_attns = False
)
# imagen, which contains the unet above
imagen = Imagen(
condition_on_text = False, # this must be set to False for unconditional Imagen
unets = unet,
image_sizes = 128,
timesteps = 1000
)
trainer = ImagenTrainer(
imagen = imagen,
split_valid_from_train = True # whether to split the validation dataset from the training
).cuda()
# instantiate your dataloader, which returns the necessary inputs to the DDPM as tuple in the order of images, text embeddings, then text masks. in this case, only images is returned as it is unconditional training
dataset = Dataset('/path/to/training/images', image_size = 128)
trainer.add_train_dataset(dataset, batch_size = 16)
# working training loop
for i in range(200000):
loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
print(f'loss: {loss}')
if not (i % 50):
valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4)
print(f'valid loss: {valid_loss}')
if not (i % 100) and trainer.is_main: # is_main makes sure this can run in distributed
images = trainer.sample(batch_size = 1, return_pil_images = True) # returns List[Image]
images[0].save(f'./sample-{i // 100}.png')
Multi GPU
Thanks to 🤗 Accelerate, you can do multi GPU training easily with two steps.
First you need to invoke accelerate config
in the same directory as your training script (say it is named train.py
)
$ accelerate config
Next, instead of calling python train.py
as you would for single GPU, you would use the accelerate CLI as so
$ accelerate launch train.py
That's it!
Command-line
Imagen can also be used via CLI directly.
Configuration
ex.
$ imagen config
or
$ imagen config --path ./configs/config.json
In the config you are able to change settings for the trainer, dataset and the imagen config.
The Imagen config parameters can be found here
The Elucidated Imagen config parameters can be found here
The Imagen Trainer config parameters can be found here
For the dataset parameters all dataloader parameters can be used.
Training
This command allows you to train or resume training your model
ex.
$ imagen train
or
$ imagen train --unet 2 --epoches 10
You can pass the following arguments to the training command.
--config
specify the config file to use for training [default: ./imagen_config.json]
--unet
the index of the unet to train [default: 1]
--epoches
how many epochs to train for [default: 50]
Sampling
Be aware that when sampling, your checkpoint should have all unets trained in order to get a usable result.
ex.
$ imagen sample --model ./path/to/model/checkpoint.pt "a squirrel raiding the birdfeeder"
# image is saved to ./a_squirrel_raiding_the_birdfeeder.png
You can pass the following arguments to the sample command.
--model
specify the model file to use for sampling
--cond_scale
conditioning scale (classifier free guidance) in decoder
--load_ema
load EMA version of unets if available
In order to use a saved checkpoint with this feature, you either must instantiate your Imagen instance using the config classes, ImagenConfig
and ElucidatedImagenConfig
or create a checkpoint via the CLI directly
For proper training, you'll likely want to set up config-driven training anyway.
ex.
import torch
from imagen_pytorch import ImagenConfig, ElucidatedImagenConfig, ImagenTrainer
# in this example, using elucidated imagen
imagen = ElucidatedImagenConfig(
unets = [
dict(dim = 32, dim_mults = (1, 2, 4, 8)),
dict(dim = 32, dim_mults = (1, 2, 4, 8))
],
image_sizes = (64, 128),
cond_drop_prob = 0.5,
num_sample_steps = 32
).create()
trainer = ImagenTrainer(imagen)
# do your training ...
# then save it
trainer.save('./checkpoint.pt')
# you should see a message informing you that ./checkpoint.pt is commandable from the terminal
It really should be as simple as that
You can also pass this checkpoint file around, and anyone can continue finetuning on their own data
from imagen_pytorch import load_imagen_from_checkpoint, ImagenTrainer
imagen = load_imagen_from_checkpoint('./checkpoint.pt')
trainer = ImagenTrainer(imagen)
# continue training / fine-tuning
Inpainting
Inpainting follows the formulation laid out by the recent Repaint paper. Simply pass in inpaint_images
and inpaint_masks
to the sample
function on either Imagen
or ElucidatedImagen
inpaint_images = torch.randn(4, 3, 512, 512).cuda() # (batch, channels, height, width)
inpaint_masks = torch.ones((4, 512, 512)).bool().cuda() # (batch, height, width)
inpainted_images = trainer.sample(texts = [
'a whale breaching from afar',
'young girl blowing out candles on her birthday cake',
'fireworks with blue and green sparkles',
'dust motes swirling in the morning sunshine on the windowsill'
], inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, cond_scale = 5.)
inpainted_images # (4, 3, 512, 512)
For video, similarly pass in your videos to inpaint_videos
keyword on .sample
. Inpainting mask can either be the same across all frames (batch, height, width)
or different (batch, frames, height, width)
inpaint_videos = torch.randn(4, 3, 8, 512, 512).cuda() # (batch, channels, frames, height, width)
inpaint_masks = torch.ones((4, 8, 512, 512)).bool().cuda() # (batch, frames, height, width)
inpainted_videos = trainer.sample(texts = [
'a whale breaching from afar',
'young girl blowing out candles on her birthday cake',
'fireworks with blue and green sparkles',
'dust motes swirling in the morning sunshine on the windowsill'
], inpaint_videos = inpaint_videos, inpaint_masks = inpaint_masks, cond_scale = 5.)
inpainted_videos # (4, 3, 8, 512, 512)
Experimental
Tero Karras of StyleGAN fame has written a new paper with results that have been corroborated by a number of independent researchers as well as on my own machine. I have decided to create a version of Imagen
, the ElucidatedImagen
, so that one can use the new elucidated DDPM for text-guided cascading generation.
Simply import ElucidatedImagen
, and then instantiate the instance as you did before. The hyperparameters are different than the usual ones for discrete and continuous time gaussian diffusion, and can be individualized for each unet in the cascade.
Ex.
from imagen_pytorch import ElucidatedImagen
# instantiate your unets ...
imagen = ElucidatedImagen(
unets = (unet1, unet2),
image_sizes = (64, 128),
cond_drop_prob = 0.1,
num_sample_steps = (64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are)
sigma_min = 0.002, # min noise level
sigma_max = (80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler
sigma_data = 0.5, # standard deviation of data distribution
rho = 7, # controls the sampling schedule
P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
S_tmin = 0.05,
S_tmax = 50,
S_noise = 1.003,
).cuda()
# rest is the same as above
Text to Video
This repository will also start accumulating new research around text guided video synthesis. For starters it will adopt the 3d unet architecture described by Jonathan Ho in Video Diffusion Models
Update: verified working by Hadrien Reynaud!
Ex.
import torch
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()
unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()
# elucidated imagen, which contains the unets above (base unet and super resoluting ones)
imagen = ElucidatedImagen(
unets = (unet1, unet2),
image_sizes = (16, 32),
random_crop_sizes = (None, 16),
temporal_downsample_factor = (2, 1), # in this example, the first unet would receive the video temporally downsampled by 2x
num_sample_steps = 10,
cond_drop_prob = 0.1,
sigma_min = 0.002, # min noise level
sigma_max = (80, 160), # max noise level, double the max noise level for upsampler
sigma_data = 0.5, # standard deviation of data distribution
rho = 7, # controls the sampling schedule
P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
S_tmin = 0.05,
S_tmax = 50,
S_noise = 1.003,
).cuda()
# mock videos (get a lot of this) and text encodings from large T5
texts = [
'a whale breaching from afar',
'young girl blowing out candles on her birthday cake',
'fireworks with blue and green sparkles',
'dust motes swirling in the morning sunshine on the windowsill'
]
videos = torch.randn(4, 3, 10, 32, 32).cuda() # (batch, channels, time / video frames, height, width)
# feed images into imagen, training each unet in the cascade
# for this example, only training unet 1
trainer = ImagenTrainer(imagen)
# you can also ignore time when training on video initially, shown to improve results in video-ddpm paper. eventually will make the 3d unet trainable with either images or video. research shows it is essential (with current data regimes) to train first on text-to-image. probably won't be true in another decade. all big data becomes small data
trainer(videos, texts = texts, unet_number = 1, ignore_time = False)
trainer.update(unet_number = 1)
videos = trainer.sample(texts = texts, video_frames = 20) # extrapolating to 20 frames from training on 10 frames
videos.shape # (4, 3, 20, 32, 32)
You can also train on text - image pairs first. The Unet3D
will automatically convert it to single framed videos and learn without the temporal components (by automatically setting ignore_time = True
), whether it be 1d convolutions or causal attention across time.
This is the current approach taken by all the big artificial intelligence labs (Brain, MetaAI, Bytedance)
FAQ
- Why are my generated images not aligning well with the text?
Imagen uses an algorithm called Classifier Free Guidance. When sampling, you apply a scale to the conditioning (text in this case) of greater than 1.0
.
Researcher Netruk44 has reported 5-10
to be optimal, but anything greater than 10
to break.
trainer.sample(texts = [
'a cloud in the shape of a roman gladiator'
], cond_scale = 5.) # <-- cond_scale is the conditioning scale, needs to be greater than 1.0 to be better than average
- Are there any pretrained models yet?
Not at the moment but one will likely be trained and open sourced within the year, if not sooner. If you would like to participate, you can join the community of artificial neural network trainers at Laion (discord link is in the Readme above) and start collaborating.
- Will this technology take my job?
All the more reason why you should start training your own model, starting today! The last thing we need is this technology being in the hands of an elite few. Hopefully this repository reduces the work to just finding the necessary compute, and augmenting with your own curated dataset.
- What am I allowed to do with this repository?
Anything! It is MIT licensed. In other words, you can freely copy / paste for your own research, and remix it for whatever modality you can think of. Go train amazing models for profit, for science, or simply to satiate your own personal pleasure at witnessing something divine unravel in front of you.
Cool Applications!
Related Works
Todo
Citations
@inproceedings{Saharia2022PhotorealisticTD,
title = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding},
author = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily L. Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and Seyedeh Sara Mahdavi and Raphael Gontijo Lopes and Tim Salimans and Jonathan Ho and David Fleet and Mohammad Norouzi},
year = {2022}
}
@article{Alayrac2022Flamingo,
title = {Flamingo: a Visual Language Model for Few-Shot Learning},
author = {Jean-Baptiste Alayrac et al},
year = {2022}
}
@inproceedings{Sankararaman2022BayesFormerTW,
title = {BayesFormer: Transformer with Uncertainty Estimation},
author = {Karthik Abinav Sankararaman and Sinong Wang and Han Fang},
year = {2022}
}
@article{So2021PrimerSF,
title = {Primer: Searching for Efficient Transformers for Language Modeling},
author = {David R. So and Wojciech Ma{\'n}ke and Hanxiao Liu and Zihang Dai and Noam M. Shazeer and Quoc V. Le},
journal = {ArXiv},
year = {2021},
volume = {abs/2109.08668}
}
@misc{cao2020global,
title = {Global Context Networks},
author = {Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
year = {2020},
eprint = {2012.13375},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@article{Karras2022ElucidatingTD,
title = {Elucidating the Design Space of Diffusion-Based Generative Models},
author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
journal = {ArXiv},
year = {2022},
volume = {abs/2206.00364}
}
@inproceedings{NEURIPS2020_4c5bcfec,
author = {Ho, Jonathan and Jain, Ajay and Abbeel, Pieter},
booktitle = {Advances in Neural Information Processing Systems},
editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
pages = {6840--6851},
publisher = {Curran Associates, Inc.},
title = {Denoising Diffusion Probabilistic Models},
url = {https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf},
volume = {33},
year = {2020}
}
@article{Lugmayr2022RePaintIU,
title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
author = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
journal = {ArXiv},
year = {2022},
volume = {abs/2201.09865}
}
@misc{ho2022video,
title = {Video Diffusion Models},
author = {Jonathan Ho and Tim Salimans and Alexey Gritsenko and William Chan and Mohammad Norouzi and David J. Fleet},
year = {2022},
eprint = {2204.03458},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{chen2022analog,
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
year = {2022},
eprint = {2208.04202},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@misc{Singer2022,
author = {Uriel Singer},
url = {https://makeavideo.studio/Make-A-Video.pdf}
}
@article{Sunkara2022NoMS,
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
author = {Raja Sunkara and Tie Luo},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.03641}
}
@article{Salimans2022ProgressiveDF,
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
author = {Tim Salimans and Jonathan Ho},
journal = {ArXiv},
year = {2022},
volume = {abs/2202.00512}
}
@article{Ho2022ImagenVH,
title = {Imagen Video: High Definition Video Generation with Diffusion Models},
author = {Jonathan Ho and William Chan and Chitwan Saharia and Jay Whang and Ruiqi Gao and Alexey A. Gritsenko and Diederik P. Kingma and Ben Poole and Mohammad Norouzi and David J. Fleet and Tim Salimans},
journal = {ArXiv},
year = {2022},
volume = {abs/2210.02303}
}
@misc{gilmer2023intriguing,
title = {Intriguing Properties of Transformer Training Instabilities},
author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
year = {2023},
status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
@inproceedings{Hang2023EfficientDT,
title = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
author = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
year = {2023}
}
@article{Zhang2021TokenST,
title = {Token Shift Transformer for Video Classification},
author = {Hao Zhang and Y. Hao and Chong-Wah Ngo},
journal = {Proceedings of the 29th ACM International Conference on Multimedia},
year = {2021}
}
@inproceedings{anonymous2022normformer,
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Anonymous},
booktitle = {Submitted to The Tenth International Conference on Learning Representations },
year = {2022},
url = {https://openreview.net/forum?id=GMYWzWztDx5},
note = {under review}
}