diffusers 源码解析(五十八)
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
创建一个 beta 调度,以离散化给定的 alpha_t_bar 函数,该函数定义了时间 t = [0,1] 上
(1-beta) 的累积乘积。
包含一个 alpha_bar 函数,该函数接受参数 t 并将其转换为扩散过程中该部分的 (1-beta) 的累积乘积。
参数:
num_diffusion_timesteps (`int`): 生成的 beta 数量。
max_beta (`float`): 使用的最大 beta 值;使用小于 1 的值以防止奇异性。
alpha_transform_type (`str`, *可选*, 默认值为 `cosine`): alpha_bar 的噪声调度类型。
可选择 `cosine` 或 `exp`
返回:
        betas (`torch.Tensor`): 调度器用于更新模型输出的 beta 值
"""
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
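# 示例(独立的示意片段,并非 diffusers 源码):调用上面的 betas_for_alpha_bar 生成 cosine 调度,
# 并检查由 betas 还原出的 alpha_bar(即 alphas 的累积乘积)从接近 1 单调衰减到接近 0。
demo_betas = betas_for_alpha_bar(num_diffusion_timesteps=1000, alpha_transform_type="cosine")
demo_alphas_cumprod = torch.cumprod(1.0 - demo_betas, dim=0)
assert demo_betas.shape == (1000,)
assert demo_alphas_cumprod[0] > 0.99 and demo_alphas_cumprod[-1] < 1e-3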
class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
`DEISMultistepScheduler` 是一个快速高阶解算器,用于扩散常微分方程(ODEs)。
    该模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。请查阅父类文档,以了解库为所有调度器实现的通用方法,例如加载和保存。

    参数:
        num_train_timesteps (`int`, defaults to 1000):
            用于训练模型的扩散步骤数量。
        beta_start (`float`, defaults to 0.0001):
            推断的起始 `beta` 值。
        beta_end (`float`, defaults to 0.02):
            最终的 `beta` 值。
        beta_schedule (`str`, defaults to `"linear"`):
            beta 调度方式,将 beta 范围映射为一系列用于模型步进的 beta 值。可选择 `linear`、`scaled_linear` 或 `squaredcos_cap_v2`。
        trained_betas (`np.ndarray`, *optional*):
            直接向构造函数传入 beta 数组,以绕过 `beta_start` 和 `beta_end`。
        solver_order (`int`, defaults to 2):
            DEIS 阶数,可以是 `1`、`2` 或 `3`。建议引导采样使用 `solver_order=2`,无条件采样使用 `solver_order=3`。
        prediction_type (`str`, defaults to `epsilon`):
            调度器函数的预测类型;可以是 `epsilon`(预测扩散过程的噪声)、`sample`(直接预测带噪样本)或 `v_prediction`(见 [Imagen Video](https://imagen.research.google/video/paper.pdf) 论文第 2.4 节)。
        thresholding (`bool`, defaults to `False`):
            是否使用"动态阈值"方法。该方法不适用于 Stable Diffusion 等潜空间扩散模型。
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            动态阈值方法的比率。仅在 `thresholding=True` 时有效。
        sample_max_value (`float`, defaults to 1.0):
            动态阈值的阈值上限。仅在 `thresholding=True` 时有效。
        algorithm_type (`str`, defaults to `deis`):
            求解器的算法类型。
        lower_order_final (`bool`, defaults to `True`):
            是否在最后几步使用低阶求解器。仅当推理步数小于 15 时生效。
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            是否在采样过程中使用 Karras sigmas 作为噪声调度中的步长。如果为 `True`,则 sigmas 根据噪声水平序列 {σi} 确定。
        timestep_spacing (`str`, defaults to `"linspace"`):
            时间步的缩放方式。更多信息请参考 [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) 的表 2。
        steps_offset (`int`, defaults to 0):
            根据某些模型系列的要求,添加到推理步骤的偏移量。
    """
# 创建一个包含所有 KarrasDiffusionSchedulers 名称的列表
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
    # 调度器的 order 属性为 1(供管道迭代时使用),与配置中的 solver_order 参数不同
order = 1
# 注册到配置中,定义初始化函数
@register_to_config
    def __init__(
        self,
# 设置训练时间步的数量,默认值为 1000
num_train_timesteps: int = 1000,
# 设置 beta 的起始值,默认值为 0.0001
beta_start: float = 0.0001,
# 设置 beta 的结束值,默认值为 0.02
beta_end: float = 0.02,
# 设置 beta 的调度方式,默认值为 "linear"
beta_schedule: str = "linear",
# 可选参数,设置训练的 beta 数组,默认值为 None
trained_betas: Optional[np.ndarray] = None,
# 设置求解器的阶数,默认值为 2
solver_order: int = 2,
# 设置预测类型,默认值为 "epsilon"
prediction_type: str = "epsilon",
# 设置是否使用阈值处理,默认值为 False
thresholding: bool = False,
# 设置动态阈值比例,默认值为 0.995
dynamic_thresholding_ratio: float = 0.995,
# 设置样本的最大值,默认值为 1.0
sample_max_value: float = 1.0,
# 设置算法类型,默认值为 "deis"
algorithm_type: str = "deis",
# 设置求解器类型,默认值为 "logrho"
solver_type: str = "logrho",
# 设置是否在最后阶段使用较低的阶数,默认值为 True
lower_order_final: bool = True,
# 可选参数,设置是否使用 Karras sigma,默认值为 False
use_karras_sigmas: Optional[bool] = False,
# 设置时间步的间距类型,默认值为 "linspace"
timestep_spacing: str = "linspace",
# 设置步数偏移,默认值为 0
steps_offset: int = 0,
):
# 检查已训练的 beta 值是否为 None
if trained_betas is not None:
# 将训练的 beta 值转换为浮点型张量
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
# 检查 beta 调度类型是否为线性
elif beta_schedule == "linear":
# 生成从 beta_start 到 beta_end 的线性序列
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
# 检查 beta 调度类型是否为缩放线性
elif beta_schedule == "scaled_linear":
# 该调度特定于潜在扩散模型
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
# 检查 beta 调度类型是否为平方余弦 cap v2
elif beta_schedule == "squaredcos_cap_v2":
# Glide 余弦调度
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
# 如果不支持的调度类型,抛出未实现错误
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
# 计算 alphas,等于 1 减去 betas
self.alphas = 1.0 - self.betas
# 计算 alphas 的累积乘积
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# 当前仅支持 VP 类型噪声调度
self.alpha_t = torch.sqrt(self.alphas_cumprod)
        # 计算 sigma_t,等于 (1 - alphas_cumprod) 的平方根
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
# 计算 lambda_t,等于 alpha_t 和 sigma_t 的对数差
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        # 计算 sigmas,等于 ((1 - alphas_cumprod) / alphas_cumprod) 的平方根
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# 设置初始噪声分布的标准差
self.init_noise_sigma = 1.0
# DEIS 设置
if algorithm_type not in ["deis"]:
# 如果算法类型是 dpmsolver 或 dpmsolver++
if algorithm_type in ["dpmsolver", "dpmsolver++"]:
# 注册算法类型到配置
self.register_to_config(algorithm_type="deis")
else:
# 抛出未实现错误
raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
# 检查求解器类型是否为 logrho
if solver_type not in ["logrho"]:
# 如果求解器类型是 midpoint, heun, bh1, bh2
if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
# 注册求解器类型到配置
self.register_to_config(solver_type="logrho")
else:
# 抛出未实现错误
raise NotImplementedError(f"solver type {solver_type} is not implemented for {self.__class__}")
# 可设置的值
self.num_inference_steps = None
# 生成从 0 到 num_train_timesteps - 1 的时间步,反转顺序
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
# 将时间步转换为张量
self.timesteps = torch.from_numpy(timesteps)
# 初始化模型输出列表,长度为 solver_order
self.model_outputs = [None] * solver_order
# 记录低阶数
self.lower_order_nums = 0
# 初始化步索引
self._step_index = None
# 初始化开始索引
self._begin_index = None
# 将 sigmas 移到 CPU,以避免过多的 CPU/GPU 通信
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def step_index(self):
"""
当前时间步的索引计数器。每次调度器步骤后增加 1。
"""
return self._step_index
@property
def begin_index(self):
"""
第一个时间步的索引。应通过 `set_begin_index` 方法从管道中设置。
"""
return self._begin_index
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 复制
# 设置调度器的起始索引,默认值为0
def set_begin_index(self, begin_index: int = 0):
# 文档字符串,说明函数的用途和参数
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
# 将传入的起始索引值存储到实例变量中
self._begin_index = begin_index
# 从 diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample 复制的函数
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
# 文档字符串,描述动态阈值处理的原理和效果
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
# 获取输入样本的数值类型
dtype = sample.dtype
# 获取样本的批次大小、通道数及剩余维度
batch_size, channels, *remaining_dims = sample.shape
# 检查数据类型,如果不是浮点数,则转换为浮点数
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# 将样本扁平化以进行量化计算
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
# 计算样本的绝对值
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
# 计算每个图像的动态阈值
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
# 限制阈值在指定范围内
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
# 扩展维度以适应广播
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
# 将样本限制在[-s, s]范围内并归一化
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
# 恢复样本的原始形状
sample = sample.reshape(batch_size, channels, *remaining_dims)
# 将样本转换回原始数据类型
sample = sample.to(dtype)
# 返回处理后的样本
return sample
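# 示例(独立的示意片段,并非 diffusers 源码):用简化版代码演示上面 _threshold_sample 的动态阈值思想:
# 取每个样本绝对值的 0.995 分位数 s(并保证 s >= 1),把预测的 x0 裁剪到 [-s, s] 后再除以 s;
# 为简洁起见,这里省略了 sample_max_value 的上限裁剪。
import torch

x0_pred = torch.randn(2, 4, 8, 8) * 3.0                   # 故意放大,制造"饱和"像素
flat = x0_pred.reshape(2, -1)
s = torch.quantile(flat.abs(), 0.995, dim=1)               # 对应 dynamic_thresholding_ratio=0.995
s = torch.clamp(s, min=1.0).unsqueeze(1)                    # s >= 1,等价于至少做 [-1, 1] 裁剪
thresholded = (torch.clamp(flat, -s, s) / s).reshape_as(x0_pred)
assert thresholded.abs().max() <= 1.0                       # 极端像素被压回 [-1, 1]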
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t 复制的函数
def _sigma_to_t(self, sigma, log_sigmas):
# 计算对数sigma值,确保不小于1e-10
log_sigma = np.log(np.maximum(sigma, 1e-10))
        # 计算 log_sigma 与调度中各 log_sigmas 之间的差值
dists = log_sigma - log_sigmas[:, np.newaxis]
# 找到sigma的范围
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
# 获取低和高的对数sigma值
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# 进行sigma的插值
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# 将插值转换为时间范围
t = (1 - w) * low_idx + w * high_idx
# 重新调整形状以匹配sigma的形状
t = t.reshape(sigma.shape)
# 返回时间值
return t
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep 导入的函数,用于将 sigma 转换为 alpha 和 sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
# 计算 alpha_t,公式为 1 / sqrt(sigma^2 + 1)
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
# 计算 sigma_t,公式为 sigma * alpha_t
sigma_t = sigma * alpha_t
# 返回计算得到的 alpha_t 和 sigma_t
return alpha_t, sigma_t
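# 示例(独立的示意片段,并非 diffusers 源码):_sigma_to_alpha_sigma_t 把 Karras 形式的 sigma
# 映射回 VP 参数化:alpha_t = 1/sqrt(sigma^2 + 1)、sigma_t = sigma * alpha_t,
# 因此恒有 alpha_t**2 + sigma_t**2 == 1。
import torch

sigma = torch.tensor([0.1, 1.0, 10.0])
alpha_t = 1.0 / (sigma**2 + 1.0) ** 0.5
sigma_t = sigma * alpha_t
assert torch.allclose(alpha_t**2 + sigma_t**2, torch.ones(3))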
# 从 diffusers.schedulers.scheduling_euler_discrete 导入的函数,用于将输入 sigma 转换为 Karras 的格式
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""构建 Karras 等人 (2022) 的噪声调度。"""
# 确保其他调度器复制此函数时不会出错的黑客方案
# TODO: 将此逻辑添加到其他调度器中
if hasattr(self.config, "sigma_min"):
# 获取 sigma_min,如果配置中存在
sigma_min = self.config.sigma_min
else:
# 如果配置中不存在,设置为 None
sigma_min = None
if hasattr(self.config, "sigma_max"):
# 获取 sigma_max,如果配置中存在
sigma_max = self.config.sigma_max
else:
# 如果配置中不存在,设置为 None
sigma_max = None
# 设置 sigma_min 为输入 sigmas 的最后一个值,如果它是 None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
# 设置 sigma_max 为输入 sigmas 的第一个值,如果它是 None
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
# 定义 rho 的值为 7.0,引用文献中使用的值
rho = 7.0 # 7.0 is the value used in the paper
# 生成从 0 到 1 的 ramp 数组,长度为 num_inference_steps
ramp = np.linspace(0, 1, num_inference_steps)
# 计算 min_inv_rho 和 max_inv_rho
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
# 根据公式生成 sigmas 数组
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
# 返回生成的 sigmas
return sigmas
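# 示例(独立的示意片段,并非 diffusers 源码):用 numpy 单独演示 Karras 等人 (2022) 的 sigma 调度
# 公式(rho=7):在 sigma_max 与 sigma_min 之间按 1/rho 次幂插值,使调度在小噪声一端更密集。
# 其中 sigma_min / sigma_max 的数值仅作示意。
import numpy as np

sigma_min, sigma_max, rho = 0.02, 80.0, 7.0
ramp = np.linspace(0, 1, 10)
min_inv_rho, max_inv_rho = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
karras_sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
# karras_sigmas 从 80.0 单调下降到 0.02,且越接近 sigma_min,相邻 sigma 的间隔越小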
# 定义 convert_model_output 函数,用于处理模型输出
def convert_model_output(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
"""
将模型输出转换为 DEIS 算法所需的对应类型。
参数:
model_output (`torch.Tensor`):
来自学习的扩散模型的直接输出。
timestep (`int`):
当前扩散链中的离散时间步。
sample (`torch.Tensor`):
扩散过程中创建的当前样本实例。
返回:
`torch.Tensor`:
转换后的模型输出。
"""
# 从 args 中提取 timestep,如果没有则从 kwargs 中提取,默认为 None
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
# 如果 sample 为 None,尝试从 args 中提取
if sample is None:
if len(args) > 1:
sample = args[1]
else:
# 如果没有提供 sample,则抛出错误
raise ValueError("missing `sample` as a required keyward argument")
# 如果 timestep 不是 None,发出弃用警告
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# 获取当前步的 sigma 值
sigma = self.sigmas[self.step_index]
# 将 sigma 转换为 alpha_t 和 sigma_t
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
# 根据配置类型进行不同的模型输出处理
if self.config.prediction_type == "epsilon":
# 计算基于 epsilon 的预测
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
# 直接将模型输出作为预测
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
# 计算基于 v 的预测
x0_pred = alpha_t * sample - sigma_t * model_output
else:
# 如果 prediction_type 不符合要求,抛出错误
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the DEISMultistepScheduler."
)
# 如果开启阈值处理,则对预测值进行阈值处理
if self.config.thresholding:
x0_pred = self._threshold_sample(x0_pred)
# 如果算法类型为 deis,返回转换后的样本
if self.config.algorithm_type == "deis":
return (sample - alpha_t * x0_pred) / sigma_t
else:
# 抛出未实现错误,表明仅支持 log-rho multistep deis
raise NotImplementedError("only support log-rho multistep deis now")
# 定义 deis_first_order_update 函数,接受模型输出和可变参数
def deis_first_order_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor: # 定义函数返回类型为 torch.Tensor
"""
One step for the first-order DEIS (equivalent to DDIM).
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The sample tensor at the previous timestep.
""" # 结束函数的文档字符串
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) # 获取当前时间步,如果没有则从关键字参数中提取
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) # 获取前一个时间步,如果没有则从关键字参数中提取
if sample is None: # 检查 sample 是否为 None
if len(args) > 2: # 如果 args 的长度大于 2
sample = args[2] # 从 args 中获取 sample
else: # 否则
raise ValueError(" missing `sample` as a required keyward argument") # 抛出缺少 sample 的异常
if timestep is not None: # 如果当前时间步不为 None
deprecate( # 调用 deprecate 函数以发出弃用警告
"timesteps", # 被弃用的参数名称
"1.0.0", # 版本号
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", # 弃用说明
)
if prev_timestep is not None: # 如果前一个时间步不为 None
deprecate( # 调用 deprecate 函数以发出弃用警告
"prev_timestep", # 被弃用的参数名称
"1.0.0", # 版本号
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", # 弃用说明
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # 获取当前和前一个时间步的 sigma 值
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) # 将 sigma_t 转换为 alpha_t 和 sigma_t
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) # 将 sigma_s 转换为 alpha_s 和 sigma_s
lambda_t = torch.log(alpha_t) - torch.log(sigma_t) # 计算 lambda_t 为 alpha_t 和 sigma_t 的对数差
lambda_s = torch.log(alpha_s) - torch.log(sigma_s) # 计算 lambda_s 为 alpha_s 和 sigma_s 的对数差
h = lambda_t - lambda_s # 计算 h 为 lambda_t 和 lambda_s 的差
if self.config.algorithm_type == "deis": # 检查算法类型是否为 "deis"
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output # 计算当前样本 x_t
else: # 否则
raise NotImplementedError("only support log-rho multistep deis now") # 抛出不支持的算法类型异常
return x_t # 返回计算得到的样本 x_t
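# 示例(独立的示意片段,并非 diffusers 源码):用标量验证一阶 DEIS 更新等价于 DDIM:
# 若 model_output 恰好是真实噪声 eps,则从 x_s = alpha_s*x0 + sigma_s*eps 一步更新后
# 正好得到 x_t = alpha_t*x0 + sigma_t*eps。数值均为任取的示意值。
import math

x0, eps = 0.7, -1.2
alpha_s, alpha_t = 0.6, 0.9                                 # 去噪方向上 alpha 逐步增大
sigma_s, sigma_t = math.sqrt(1 - alpha_s**2), math.sqrt(1 - alpha_t**2)
h = math.log(alpha_t / sigma_t) - math.log(alpha_s / sigma_s)
x_s = alpha_s * x0 + sigma_s * eps
x_t = (alpha_t / alpha_s) * x_s - sigma_t * (math.exp(h) - 1.0) * eps
assert abs(x_t - (alpha_t * x0 + sigma_t * eps)) < 1e-12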
def multistep_deis_second_order_update( # 定义 multistep_deis_second_order_update 函数
self, # 类实例
model_output_list: List[torch.Tensor], # 参数 model_output_list,类型为 torch.Tensor 的列表
*args, # 可变位置参数
sample: torch.Tensor = None, # 参数 sample,默认为 None
**kwargs, # 可变关键字参数
# 定义一个函数,返回类型为 torch.Tensor
) -> torch.Tensor:
"""
第二阶多步 DEIS 的一步计算。
参数:
model_output_list (`List[torch.Tensor]`):
当前和后续时间步的学习扩散模型直接输出。
sample (`torch.Tensor`):
扩散过程生成的当前样本实例。
返回:
`torch.Tensor`:
上一时间步的样本张量。
"""
# 获取时间步列表,如果没有则从关键字参数中获取
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
# 获取前一个时间步,如果没有则从关键字参数中获取
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
# 如果样本为 None,则尝试从参数中获取样本
if sample is None:
if len(args) > 2:
sample = args[2]
else:
# 如果样本仍然为 None,则引发错误
raise ValueError(" missing `sample` as a required keyward argument")
# 如果时间步列表不为 None,则发出弃用警告
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# 如果前一个时间步不为 None,则发出弃用警告
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# 获取当前和前后时间步的 sigma 值
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
# 将 sigma 转换为 alpha 和 sigma_t
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
# 获取最后两个模型输出
m0, m1 = model_output_list[-1], model_output_list[-2]
# 计算 rho 值
rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1
# 检查算法类型是否为 "deis"
if self.config.algorithm_type == "deis":
# 定义积分函数
def ind_fn(t, b, c):
# Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}]
return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c))
# 计算系数
coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1)
coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0)
# 计算 x_t
x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1)
# 返回计算结果
return x_t
else:
# 如果算法类型不支持,则引发未实现的错误
raise NotImplementedError("only support log-rho multistep deis now")
# 定义一个多步 DEIS 第三阶更新的函数
def multistep_deis_third_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
# 当前样本实例,默认为 None
sample: torch.Tensor = None,
**kwargs,
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep 复制
# 根据时间步初始化索引
def index_for_timestep(self, timestep, schedule_timesteps=None):
# 如果未提供时间调度步,则使用默认时间步
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# 找到与当前时间步匹配的候选索引
index_candidates = (schedule_timesteps == timestep).nonzero()
# 如果没有找到匹配的候选索引
if len(index_candidates) == 0:
# 将步骤索引设置为时间步的最后一个索引
step_index = len(self.timesteps) - 1
# 如果找到多个候选索引
# 第一个步骤的 sigma 索引总是第二个索引(如果只有一个则是最后一个)
# 这样可以确保在去噪调度中不会意外跳过 sigma
elif len(index_candidates) > 1:
# 使用第二个候选索引作为步骤索引
step_index = index_candidates[1].item()
else:
# 否则,使用第一个候选索引作为步骤索引
step_index = index_candidates[0].item()
# 返回最终步骤索引
return step_index
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index 中复制
def _init_step_index(self, timestep):
"""
初始化调度器的步骤索引计数器。
"""
# 如果开始索引为 None
if self.begin_index is None:
# 如果时间步是张量类型,则将其转移到相应设备
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
# 使用 index_for_timestep 方法初始化步骤索引
self._step_index = self.index_for_timestep(timestep)
else:
# 否则使用预设的开始索引
self._step_index = self._begin_index
# 执行一步计算
def step(
self,
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
从前一个时间步预测样本,通过反转 SDE。此函数使用多步 DEIS 传播样本。
参数:
model_output (`torch.Tensor`):
从学习的扩散模型直接输出的张量。
timestep (`int`):
扩散链中当前离散时间步。
sample (`torch.Tensor`):
通过扩散过程创建的当前样本实例。
return_dict (`bool`):
是否返回 [`~schedulers.scheduling_utils.SchedulerOutput`] 或 `tuple`。
返回:
[`~schedulers.scheduling_utils.SchedulerOutput`] 或 `tuple`:
如果 return_dict 为 `True`,则返回 [`~schedulers.scheduling_utils.SchedulerOutput`],否则返回一个元组,
其中第一个元素是样本张量。
"""
# 检查推理步骤数量是否为 None,若是则抛出异常
if self.num_inference_steps is None:
raise ValueError(
"推理步骤数量为 'None',您需要在创建调度器后运行 'set_timesteps'"
)
# 检查当前步骤索引是否为 None,若是则初始化步骤索引
if self.step_index is None:
self._init_step_index(timestep)
# 判断是否为较低阶最终更新的条件
lower_order_final = (
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
)
# 判断是否为较低阶第二更新的条件
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
)
# 转换模型输出为适合当前样本的格式
model_output = self.convert_model_output(model_output, sample=sample)
# 更新模型输出缓存,将当前模型输出存储到最后一个位置
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
# 根据配置选择合适的更新方法
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
# 使用一阶更新方法计算前一个样本
prev_sample = self.deis_first_order_update(model_output, sample=sample)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
# 使用二阶更新方法计算前一个样本
prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample)
else:
# 使用三阶更新方法计算前一个样本
prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample)
# 更新较低阶次数计数器
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# 完成后将步骤索引加一
self._step_index += 1
# 如果不返回字典,则返回包含前一个样本的元组
if not return_dict:
return (prev_sample,)
# 返回前一个样本的调度输出
return SchedulerOutput(prev_sample=prev_sample)
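# 示例(独立的示意用法,并非 diffusers 源码):DEISMultistepScheduler 的典型去噪循环:
# 先用 set_timesteps 设定推理步数,再在每个时间步用模型预测调用 step() 得到上一时间步的样本。
# 这里用随机张量代替真实 UNet 的输出,张量形状仅作示意。
import torch
from diffusers import DEISMultistepScheduler

scheduler = DEISMultistepScheduler(num_train_timesteps=1000, solver_order=2)
scheduler.set_timesteps(num_inference_steps=20)

sample = torch.randn(1, 4, 16, 16)                          # 初始纯噪声
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)                 # 实际使用时应替换为扩散模型(如 UNet)的输出
    sample = scheduler.step(model_output, t, sample).prev_sample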
# 定义一个方法,用于根据当前时间步缩放去噪模型输入
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
确保与需要根据当前时间步缩放去噪模型输入的调度器之间的互换性。
Args:
sample (`torch.Tensor`):
输入样本。
Returns:
`torch.Tensor`:
缩放后的输入样本。
"""
# 返回未修改的输入样本
return sample
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise 复制的代码
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
# 确保 sigmas 和 timesteps 与 original_samples 具有相同的设备和数据类型
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
# 检查设备类型,如果是 MPS 且 timesteps 是浮点型
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# MPS 不支持 float64 数据类型
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
# 将调度时间步转换为与原始样本相同的设备
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# 如果 begin_index 为 None,表示调度器用于训练或管道未实现 set_begin_index
if self.begin_index is None:
# 根据时间步获取步索引
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# 在第一次去噪步骤后调用 add_noise(用于修补)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# 在第一次去噪步骤之前调用 add_noise 以创建初始潜在图像(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
# 根据步索引获取 sigma,并将其扁平化
sigma = sigmas[step_indices].flatten()
# 如果 sigma 的形状小于原始样本的形状,则在最后一个维度添加一个维度
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
# 将 sigma 转换为 alpha_t 和 sigma_t
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
# 生成带噪声的样本
noisy_samples = alpha_t * original_samples + sigma_t * noise
# 返回带噪声的样本
return noisy_samples
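# 示例(独立的示意用法,并非 diffusers 源码):add_noise 实现的是前向加噪
# x_t = alpha_t * x_0 + sigma_t * noise(VP 参数化)。这里先调用 set_timesteps,
# 再从 scheduler.timesteps 中选取时间步,以保证 add_noise 内部的索引查找能找到对应的 sigma。
# 张量形状仅作示意。
import torch
from diffusers import DEISMultistepScheduler

scheduler = DEISMultistepScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50)

clean = torch.randn(2, 4, 16, 16)
noise = torch.randn_like(clean)
timesteps = scheduler.timesteps[[0, 25]]                    # 每个样本各取一个噪声水平(一大一小)
noisy = scheduler.add_noise(clean, noise, timesteps)
# 时间步越大,sigma_t 越大、alpha_t 越小,noisy 中噪声占比越高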
# 定义方法以返回训练时间步的数量
def __len__(self):
return self.config.num_train_timesteps
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
创建一个 beta 调度,它离散化给定的 alpha_t_bar 函数,定义时间 t = [0,1] 上 (1-beta) 的累积乘积。
包含一个 alpha_bar 函数,该函数接受参数 t 并将其转换为该部分扩散过程的 (1-beta) 的累积乘积。
参数:
num_diffusion_timesteps (`int`): 生成 beta 的数量。
max_beta (`float`): 使用的最大 beta 值;使用小于 1 的值以防止奇异性。
alpha_transform_type (`str`, *可选*, 默认为 `cosine`): alpha_bar 的噪声调度类型。
从 `cosine` 或 `exp` 中选择
返回:
        betas (`torch.Tensor`): 调度器用于更新模型输出的 betas
"""
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
def rescale_zero_terminal_snr(betas):
"""
根据 https://arxiv.org/pdf/2305.08891.pdf (算法 1) 重新调整 beta 以具有零终端 SNR
参数:
betas (`torch.Tensor`):
用于初始化调度器的 beta。
    返回:
        `torch.Tensor`: 重新调整后、终端信噪比(SNR)为零的 betas
"""
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
alphas_bar_sqrt -= alphas_bar_sqrt_T
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
alphas_bar = alphas_bar_sqrt**2
alphas = alphas_bar[1:] / alphas_bar[:-1]
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
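# 示例(独立的示意片段,并非 diffusers 源码):验证 rescale_zero_terminal_snr 的效果:
# 对一个普通的线性 beta 调度做重标定后,终端 alpha_bar(从而终端 SNR)变为 0,
# 而原始调度的终端 alpha_bar 仍大于 0。
import torch

betas_linear = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float32)
betas_rescaled = rescale_zero_terminal_snr(betas_linear)
alphas_bar_rescaled = torch.cumprod(1.0 - betas_rescaled, dim=0)
alphas_bar_original = torch.cumprod(1.0 - betas_linear, dim=0)
assert float(alphas_bar_rescaled[-1]) == 0.0 and float(alphas_bar_original[-1]) > 0.0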
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
`DPMSolverMultistepScheduler` 是一个快速的专用高阶求解器,用于扩散 ODE。
该模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。请查看父类文档以了解该库为所有调度器实现的通用方法,例如加载和保存。
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
    def __init__(
        self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_lu_lambdas: Optional[bool] = False,
final_sigmas_type: Optional[str] = "zero",
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
@property
def step_index(self):
"""
当前时间步的索引计数器。每次调度器步骤后会增加 1。
"""
return self._step_index
@property
def begin_index(self):
"""
第一个时间步的索引。应通过 `set_begin_index` 方法从管道设置。
"""
return self._begin_index
def set_begin_index(self, begin_index: int = 0):
"""
设置调度器的起始索引。该函数应在推理前通过管道运行。
参数:
begin_index (`int`):
调度器的起始索引。
"""
self._begin_index = begin_index
    def set_timesteps(
        self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
timesteps: Optional[List[int]] = None,
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"动态阈值处理:在每个采样步骤中,我们将 s 设置为 xt0(在时间步 t 对 x_0 的预测)中的某个百分位绝对像素值,如果 s > 1,则将 xt0 阈值化到范围 [-s, s],然后除以 s。动态阈值处理将饱和像素(接近 -1 和 1 的像素)推入内部,从而在每一步主动防止像素饱和。我们发现动态阈值处理可以显著提高照片真实感以及图像与文本的对齐,尤其是在使用非常大的引导权重时。"
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float()
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs()
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
)
s = s.unsqueeze(1)
sample = torch.clamp(sample, -s, s) / s
sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
def _sigma_to_t(self, sigma, log_sigmas):
log_sigma = np.log(np.maximum(sigma, 1e-10))
dists = log_sigma - log_sigmas[:, np.newaxis]
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""构建 Karras 等人 (2022) 的噪声调度。"""
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""构建 Lu 等人 (2022) 的噪声调度。"""
lambda_min: float = in_lambdas[-1].item()
lambda_max: float = in_lambdas[0].item()
rho = 1.0
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = lambda_min ** (1 / rho)
max_inv_rho = lambda_max ** (1 / rho)
lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return lambdas
def convert_model_output(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
**kwargs,
def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
一步用于第一阶 DPMSolver(等效于 DDIM)。
参数:
model_output (`torch.Tensor`):
从学习的扩散模型直接输出的张量。
sample (`torch.Tensor`):
扩散过程中创建的当前样本实例。
返回:
`torch.Tensor`:
上一个时间步的样本张量。
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
x_t = (
(alpha_t / alpha_s) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
return x_t
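# 示例(独立的示意片段,并非 diffusers 源码):用标量对比 dpmsolver++(以 x0 预测为输入)
# 与 dpmsolver(以噪声预测为输入)的一阶更新:当模型预测完全准确时,两种参数化都把 x_s
# 精确映射到 x_t = alpha_t*x0 + sigma_t*eps,即一步 DDIM。数值均为任取的示意值。
import math

x0, eps = 0.5, 1.3
alpha_s, alpha_t = 0.6, 0.9
sigma_s, sigma_t = math.sqrt(1 - alpha_s**2), math.sqrt(1 - alpha_t**2)
h = math.log(alpha_t / sigma_t) - math.log(alpha_s / sigma_s)
x_s = alpha_s * x0 + sigma_s * eps

x_t_pp = (sigma_t / sigma_s) * x_s - alpha_t * (math.exp(-h) - 1.0) * x0        # dpmsolver++ 公式
x_t_eps = (alpha_t / alpha_s) * x_s - sigma_t * (math.exp(h) - 1.0) * eps       # dpmsolver 公式
assert abs(x_t_pp - x_t_eps) < 1e-12
assert abs(x_t_pp - (alpha_t * x0 + sigma_t * eps)) < 1e-12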
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
**kwargs,
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
return step_index
def _init_step_index(self, timestep):
"""
初始化调度器的步索引计数器。
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
确保与需要根据当前时间步缩放去噪模型输入的调度器互换。
参数:
sample (`torch.Tensor`):
输入样本。
返回:
`torch.Tensor`:
缩放后的输入样本。
"""
return sample
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
step_indices = [self.step_index] * timesteps.shape[0]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
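# 示例(独立的示意用法,并非 diffusers 源码):DPMSolverMultistepScheduler 最常见的用法是
# 替换已有管道的调度器,用较少的推理步数完成采样。模型名与提示词仅作示意,可替换为任意
# Stable Diffusion 权重;运行需要已下载的权重和 GPU。
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=20).images[0]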
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import flax
import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
)
@flax.struct.dataclass
class DPMSolverMultistepSchedulerState:
common: CommonSchedulerState
alpha_t: jnp.ndarray
sigma_t: jnp.ndarray
lambda_t: jnp.ndarray
init_noise_sigma: jnp.ndarray
timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None
model_outputs: Optional[jnp.ndarray] = None
lower_order_nums: Optional[jnp.int32] = None
prev_timestep: Optional[jnp.int32] = None
cur_sample: Optional[jnp.ndarray] = None
@classmethod
def create(
cls,
common: CommonSchedulerState,
alpha_t: jnp.ndarray,
sigma_t: jnp.ndarray,
lambda_t: jnp.ndarray,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
):
return cls(
common=common,
alpha_t=alpha_t,
sigma_t=sigma_t,
lambda_t=lambda_t,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
)
@dataclass
class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput):
state: DPMSolverMultistepSchedulerState
class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
DPM-Solver(以及改进版 DPM-Solver++)是一个快速的专用高阶求解器,用于扩散 ODE,并提供收敛阶数保证。
实证表明,使用 DPM-Solver 仅 20 步就能生成高质量样本,即使仅用 10 步也能生成相当不错的样本。
有关更多详细信息,请参见原始论文: https://arxiv.org/abs/2206.00927 和 https://arxiv.org/abs/2211.01095
目前,我们支持多步 DPM-Solver 适用于噪声预测模型和数据预测模型。
我们建议使用 `solver_order=2` 进行引导采样,使用 `solver_order=3` 进行无条件采样。
    我们还支持 Imagen 中的"动态阈值"方法 (https://arxiv.org/abs/2205.11487)。对于像素空间的扩散模型,
    可以同时设置 `algorithm_type="dpmsolver++"` 和 `thresholding=True` 来使用动态阈值。
    注意,阈值方法不适用于潜空间扩散模型(如 stable-diffusion)。

    [`~ConfigMixin`] 负责存储调度器 `__init__` 函数中传入的所有配置属性,例如 `num_train_timesteps`,
    这些属性可以通过 `scheduler.config.num_train_timesteps` 访问。[`SchedulerMixin`] 通过
    [`SchedulerMixin.save_pretrained`] 和 [`~SchedulerMixin.from_pretrained`] 函数提供通用的加载和保存功能。
    """
# 兼容的调度器列表,从 FlaxKarrasDiffusionSchedulers 中提取名称
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
# 数据类型变量
dtype: jnp.dtype
# 属性,返回是否有状态
@property
def has_state(self):
return True
# 注册到配置的初始化函数,定义多个参数的默认值
@register_to_config
    def __init__(
        self,
# 训练时间步数,默认为 1000
num_train_timesteps: int = 1000,
# beta 的起始值,默认为 0.0001
beta_start: float = 0.0001,
# beta 的结束值,默认为 0.02
beta_end: float = 0.02,
# beta 的调度方式,默认为 "linear"
beta_schedule: str = "linear",
# 已训练的 beta 值,可选
trained_betas: Optional[jnp.ndarray] = None,
# 解算器阶数,默认为 2
solver_order: int = 2,
# 预测类型,默认为 "epsilon"
prediction_type: str = "epsilon",
# 是否启用阈值处理,默认为 False
thresholding: bool = False,
# 动态阈值比例,默认为 0.995
dynamic_thresholding_ratio: float = 0.995,
# 采样最大值,默认为 1.0
sample_max_value: float = 1.0,
# 算法类型,默认为 "dpmsolver++"
algorithm_type: str = "dpmsolver++",
# 解算器类型,默认为 "midpoint"
solver_type: str = "midpoint",
# 最后阶段是否降低阶数,默认为 True
lower_order_final: bool = True,
# 时间步的间隔类型,默认为 "linspace"
timestep_spacing: str = "linspace",
# 数据类型,默认为 jnp.float32
dtype: jnp.dtype = jnp.float32,
):
# 将数据类型赋值给实例变量
self.dtype = dtype
# 创建状态的方法,接受一个可选的公共调度状态参数,返回 DPM 求解器多步调度状态
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
# 如果没有提供公共调度状态,则创建一个新的实例
if common is None:
common = CommonSchedulerState.create(self)
# 当前仅支持 VP 类型的噪声调度
alpha_t = jnp.sqrt(common.alphas_cumprod) # 计算累积 alpha 的平方根
sigma_t = jnp.sqrt(1 - common.alphas_cumprod) # 计算 1 减去累积 alpha 的平方根
lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t) # 计算 alpha_t 和 sigma_t 的对数差
# DPM 求解器的设置
if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]:
# 如果算法类型不在支持的列表中,则抛出未实现异常
raise NotImplementedError(f"{self.config.algorithm_type} is not implemented for {self.__class__}")
if self.config.solver_type not in ["midpoint", "heun"]:
# 如果求解器类型不在支持的列表中,则抛出未实现异常
raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}")
# 初始化噪声分布的标准差
init_noise_sigma = jnp.array(1.0, dtype=self.dtype) # 创建一个值为 1.0 的数组,类型为实例的 dtype
# 生成时间步的数组,从 0 到 num_train_timesteps,取整后反转
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
# 创建并返回 DPM 求解器多步调度状态
return DPMSolverMultistepSchedulerState.create(
common=common, # 传入公共调度状态
alpha_t=alpha_t, # 传入计算得到的 alpha_t
sigma_t=sigma_t, # 传入计算得到的 sigma_t
lambda_t=lambda_t, # 传入计算得到的 lambda_t
init_noise_sigma=init_noise_sigma, # 传入初始化噪声的标准差
timesteps=timesteps, # 传入时间步数组
)
# 设置时间步的方法,接受当前状态、推理步骤数和形状作为参数
def set_timesteps(
self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
) -> DPMSolverMultistepSchedulerState: # 定义返回类型为 DPMSolverMultistepSchedulerState
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`DPMSolverMultistepSchedulerState`):
the `FlaxDPMSolverMultistepScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
shape (`Tuple`):
the shape of the samples to be generated.
""" # 文档字符串结束
last_timestep = self.config.num_train_timesteps # 获取训练时的最后时间步
if self.config.timestep_spacing == "linspace": # 检查时间步间距配置是否为线性空间
timesteps = ( # 生成线性空间的时间步
jnp.linspace(0, last_timestep - 1, num_inference_steps + 1) # 生成从0到最后时间步的线性间隔
.round()[::-1][:-1] # 取反并去掉最后一个元素
.astype(jnp.int32) # 转换为整型
)
elif self.config.timestep_spacing == "leading": # 检查时间步间距配置是否为前导
step_ratio = last_timestep // (num_inference_steps + 1) # 计算步骤比率
# creates integer timesteps by multiplying by ratio # 通过乘以比率创建整数时间步
# casting to int to avoid issues when num_inference_step is power of 3 # 强制转换为整数以避免在 num_inference_step 为 3 的幂时的问题
timesteps = ( # 生成前导时间步
(jnp.arange(0, num_inference_steps + 1) * step_ratio) # 创建范围并乘以步骤比率
.round()[::-1][:-1] # 取反并去掉最后一个元素
.copy().astype(jnp.int32) # 复制并转换为整型
)
timesteps += self.config.steps_offset # 加上步骤偏移量
elif self.config.timestep_spacing == "trailing": # 检查时间步间距配置是否为后置
step_ratio = self.config.num_train_timesteps / num_inference_steps # 计算步骤比率
# creates integer timesteps by multiplying by ratio # 通过乘以比率创建整数时间步
# casting to int to avoid issues when num_inference_step is power of 3 # 强制转换为整数以避免在 num_inference_step 为 3 的幂时的问题
            timesteps = jnp.arange(last_timestep, 0, -step_ratio).round().copy().astype(jnp.int32)  # 从最后时间步向 0 生成时间步,四舍五入、复制并转换为整型
timesteps -= 1 # 时间步减去1
else: # 如果没有匹配的时间步间距配置
raise ValueError( # 抛出值错误
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." # 提示用户选择有效的时间步间距
)
# initial running values # 初始化运行值
model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype) # 创建模型输出数组,初始化为零
lower_order_nums = jnp.int32(0) # 初始化低阶数字为0
prev_timestep = jnp.int32(-1) # 初始化前一个时间步为-1
cur_sample = jnp.zeros(shape, dtype=self.dtype) # 创建当前样本数组,初始化为零
return state.replace( # 返回更新后的状态
num_inference_steps=num_inference_steps, # 更新推断步骤数量
timesteps=timesteps, # 更新时间步
model_outputs=model_outputs, # 更新模型输出
lower_order_nums=lower_order_nums, # 更新低阶数字
prev_timestep=prev_timestep, # 更新前一个时间步
cur_sample=cur_sample, # 更新当前样本
)
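# 示例(独立的示意片段,并非 diffusers 源码):用 numpy 单独重现 set_timesteps 中三种
# timestep_spacing 在 num_train_timesteps=1000、num_inference_steps=10 时的差别,数值仅作示意。
import numpy as np

N, n = 1000, 10
ts_linspace = np.linspace(0, N - 1, n + 1).round()[::-1][:-1].astype(np.int64)
ts_leading = (np.arange(0, n + 1) * (N // (n + 1))).round()[::-1][:-1].astype(np.int64)
ts_trailing = np.arange(N, 0, -(N / n)).round().astype(np.int64) - 1
# ts_linspace: [999, 899, ..., 100];ts_leading: [900, 810, ..., 90](之后还要加上 steps_offset);
# ts_trailing: [999, 899, ..., 99]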
def convert_model_output( # 定义转换模型输出的函数
self, # 实例对象
state: DPMSolverMultistepSchedulerState, # 状态参数,类型为 DPMSolverMultistepSchedulerState
model_output: jnp.ndarray, # 模型输出参数,类型为 jnp.ndarray
timestep: int, # 当前时间步参数,类型为 int
sample: jnp.ndarray, # 样本参数,类型为 jnp.ndarray
def dpm_solver_first_order_update( # 定义一阶更新的扩散模型求解器函数
self, # 实例对象
state: DPMSolverMultistepSchedulerState, # 状态参数,类型为 DPMSolverMultistepSchedulerState
model_output: jnp.ndarray, # 模型输出参数,类型为 jnp.ndarray
timestep: int, # 当前时间步参数,类型为 int
prev_timestep: int, # 前一个时间步参数,类型为 int
sample: jnp.ndarray, # 样本参数,类型为 jnp.ndarray
# 函数返回一个一阶DPM求解器的步骤结果,等效于DDIM
) -> jnp.ndarray:
# 文档字符串,说明函数的用途及详细推导链接
"""
One step for the first-order DPM-Solver (equivalent to DDIM).
See https://arxiv.org/abs/2206.00927 for the detailed derivation.
Args:
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
prev_timestep (`int`): previous discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
Returns:
`jnp.ndarray`: the sample tensor at the previous timestep.
"""
# 将前一个时间步和当前时间步赋值给变量
t, s0 = prev_timestep, timestep
# 获取模型输出
m0 = model_output
# 获取当前和前一个时间步的lambda值
lambda_t, lambda_s = state.lambda_t[t], state.lambda_t[s0]
# 获取当前和前一个时间步的alpha值
alpha_t, alpha_s = state.alpha_t[t], state.alpha_t[s0]
# 获取当前和前一个时间步的sigma值
sigma_t, sigma_s = state.sigma_t[t], state.sigma_t[s0]
# 计算h值,表示lambda_t与lambda_s的差异
h = lambda_t - lambda_s
# 根据配置的算法类型选择相应的计算公式
if self.config.algorithm_type == "dpmsolver++":
# 计算当前样本的更新值,使用dpmsolver++公式
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0
elif self.config.algorithm_type == "dpmsolver":
# 计算当前样本的更新值,使用dpmsolver公式
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0
# 返回更新后的样本
return x_t
# 定义一个多步骤DPM求解器的二阶更新函数
def multistep_dpm_solver_second_order_update(
# 接受当前状态作为参数
self,
state: DPMSolverMultistepSchedulerState,
# 接受模型输出列表作为参数
model_output_list: jnp.ndarray,
# 接受时间步列表作为参数
timestep_list: List[int],
# 接受前一个时间步作为参数
prev_timestep: int,
# 接受当前样本作为参数
sample: jnp.ndarray,
# 返回上一个时间步的样本张量
) -> jnp.ndarray:
        """
        二阶多步 DPM-Solver 的一步更新。

        参数:
            model_output_list (`jnp.ndarray`):
                当前及后续时间步的扩散模型直接输出。
            timestep_list (`List[int]`): 当前及后续的离散时间步。
            prev_timestep (`int`): 扩散链中前一个离散时间步。
            sample (`jnp.ndarray`): 当前扩散过程中创建的样本实例。

        返回:
            `jnp.ndarray`: 前一个时间步的样本张量。
        """
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1]
alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
if self.config.algorithm_type == "dpmsolver++":
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (jnp.exp(-h) - 1.0)) * D0
- 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (jnp.exp(-h) - 1.0)) * D0
+ (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
)
elif self.config.algorithm_type == "dpmsolver":
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (jnp.exp(h) - 1.0)) * D0
- 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1
)
elif self.config.solver_type == "heun":
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (jnp.exp(h) - 1.0)) * D0
- (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
)
return x_t
    def multistep_dpm_solver_third_order_update(
        self,
state: DPMSolverMultistepSchedulerState,
model_output_list: jnp.ndarray,
timestep_list: List[int],
prev_timestep: int,
sample: jnp.ndarray,
) -> jnp.ndarray:
""" # 开始文档字符串
One step for the third-order multistep DPM-Solver. # 描述该函数为三阶多步 DPM 求解器的一步
Args: # 开始参数说明
model_output_list (`List[jnp.ndarray]`): # 定义模型输出列表参数
direct outputs from learned diffusion model at current and latter timesteps. # 描述该参数为当前及后续时间步的扩散模型直接输出
timestep (`int`): # 定义当前时间步参数
current and latter discrete timestep in the diffusion chain. # 描述该参数为扩散链中当前及后续离散时间步
prev_timestep (`int`): # 定义前一个时间步参数
previous discrete timestep in the diffusion chain. # 描述该参数为扩散链中前一个离散时间步
sample (`jnp.ndarray`): # 定义样本参数
current instance of sample being created by diffusion process. # 描述该参数为当前通过扩散过程创建的样本实例
Returns: # 开始返回值说明
`jnp.ndarray`: the sample tensor at the previous timestep. # 描述返回值为前一个时间步的样本张量
"""
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
state.lambda_t[t],
state.lambda_t[s0],
state.lambda_t[s1],
state.lambda_t[s2],
)
alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h
D0 = m0
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
if self.config.algorithm_type == "dpmsolver++":
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (jnp.exp(-h) - 1.0)) * D0
+ (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
- (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
)
elif self.config.algorithm_type == "dpmsolver":
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (jnp.exp(h) - 1.0)) * D0
- (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
return x_t
def step(
self,
state: DPMSolverMultistepSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
return_dict: bool = True,
def scale_model_input(
self,
state: DPMSolverMultistepSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None
) -> jnp.ndarray:
""" # 文档字符串,描述函数的作用及参数
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. # 确保与需要根据当前时间步缩放去噪模型输入的调度器的可互换性
Args: # 参数说明部分
state (`DPMSolverMultistepSchedulerState`): # state 参数,类型为 DPMSolverMultistepSchedulerState
the `FlaxDPMSolverMultistepScheduler` state data class instance. # FlaxDPMSolverMultistepScheduler 的状态数据类实例
sample (`jnp.ndarray`): input sample # sample 参数,类型为 jnp.ndarray,表示输入样本
timestep (`int`, optional): current timestep # timestep 参数,类型为 int,可选,表示当前时间步
Returns: # 返回值说明部分
`jnp.ndarray`: scaled input sample # 返回一个 jnp.ndarray,表示缩放后的输入样本
"""
return sample
def add_noise(
self,
state: DPMSolverMultistepSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
return add_noise_common(state.common, original_samples, noise, timesteps)
def __len__(self):
return self.config.num_train_timesteps
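# 示例(独立的示意用法,并非 diffusers 源码):Flax 版调度器是无状态的,配置保存在 scheduler
# 对象里,运行期数据放在 state 数据类中,因此 set_timesteps / step 都会返回新的 state。
# 张量形状仅作示意。
from diffusers import FlaxDPMSolverMultistepScheduler

scheduler = FlaxDPMSolverMultistepScheduler(num_train_timesteps=1000, solver_order=2)
state = scheduler.create_state()
state = scheduler.set_timesteps(state, num_inference_steps=20, shape=(1, 4, 16, 16))
# state.timesteps 即为将要遍历的离散时间步;循环中调用
# scheduler.step(state, model_output, t, sample) 会同时返回上一时间步的样本和更新后的 state。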
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
创建一个 beta 调度,离散化给定的 alpha_t_bar 函数,该函数定义了随时间变化的 (1-beta) 的累积乘积。
包含一个 alpha_bar 函数,该函数接受参数 t 并将其转换为扩散过程的累积乘积。
参数:
num_diffusion_timesteps (`int`): 生成的 beta 数量。
max_beta (`float`): 使用的最大 beta 值;使用低于 1 的值以防止奇异性。
alpha_transform_type (`str`, *可选*, 默认为 `cosine`): alpha_bar 的噪声调度类型。
可选值为 `cosine` 或 `exp`
返回:
        betas (`torch.Tensor`): 调度器用来更新模型输出的 betas。
"""
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
"""
`DPMSolverMultistepInverseScheduler` 是 [`DPMSolverMultistepScheduler`] 的反向调度器。
    该模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。请查看父类文档,以了解库为所有调度器实现的通用方法,例如加载和保存。
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
    def __init__(
        self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
self.init_noise_sigma = 1.0
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
else:
raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
if solver_type not in ["midpoint", "heun"]:
if solver_type in ["logrho", "bh1", "bh2"]:
self.register_to_config(solver_type="midpoint")
else:
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy()
self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self.sigmas = self.sigmas.to("cpu")
self.use_karras_sigmas = use_karras_sigmas
@property
def step_index(self):
"""
当前时间步的索引计数器。每次调度器步骤后增加1。
"""
return self._step_index
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"动态阈值处理:在每个采样步骤中,我们将 s 设置为 xt0(在时间步 t 预测的 x_0)中的某个百分位绝对像素值,
如果 s > 1,则我们将 xt0 阈值处理到范围 [-s, s],然后除以 s。动态阈值处理将饱和像素(接近 -1 和 1 的像素)
向内推,以主动防止每一步的饱和。我们发现动态阈值处理显著提高了照片真实感以及图像-文本对齐,特别是在使用非常大的引导权重时。"
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float()
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs()
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
)
s = s.unsqueeze(1)
sample = torch.clamp(sample, -s, s) / s
sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
def _sigma_to_t(self, sigma, log_sigmas):
log_sigma = np.log(np.maximum(sigma, 1e-10))
dists = log_sigma - log_sigmas[:, np.newaxis]
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""构建 Karras 等人(2022年)的噪声调度。"""
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def convert_model_output(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
**kwargs,
):
def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
    ) -> torch.Tensor:
"""
对第一阶 DPMSolver 执行一步(相当于 DDIM)。
参数:
model_output (`torch.Tensor`):
从学习的扩散模型直接输出的张量。
sample (`torch.Tensor`):
扩散过程中生成的当前样本实例。
返回:
`torch.Tensor`:
前一时间步的样本张量。
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
x_t = (
(alpha_t / alpha_s) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
return x_t
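For the `dpmsolver++` branch, the first-order update with a data-prediction (x_0) model output is exactly the deterministic DDIM step the docstring refers to. The toy check below is illustrative only; the sigma values are arbitrary and the helper is a local assumption, not the scheduler API.

# Toy check: the dpmsolver++ first-order update reduces to a DDIM step when model_output is x_0.
import torch

def to_alpha_sigma(sigma):
    alpha_t = 1 / (sigma**2 + 1) ** 0.5
    return alpha_t, sigma * alpha_t

x0, eps = torch.randn(4), torch.randn(4)
sigma_s, sigma_prev = torch.tensor(2.0), torch.tensor(0.5)     # current and previous (less noisy) levels
alpha_s, sig_s = to_alpha_sigma(sigma_s)
alpha_t, sig_t = to_alpha_sigma(sigma_prev)
sample = alpha_s * x0 + sig_s * eps

h = (torch.log(alpha_t) - torch.log(sig_t)) - (torch.log(alpha_s) - torch.log(sig_s))
x_t = (sig_t / sig_s) * sample - alpha_t * (torch.exp(-h) - 1.0) * x0
print(torch.allclose(x_t, alpha_t * x0 + sig_t * eps, atol=1e-6))       # True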
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
):
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
**kwargs,
):
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
self._step_index = step_index
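_init_step_index resolves the current position inside self.timesteps; when a timestep value appears more than once in the schedule, the second occurrence is chosen so that no step is skipped when denoising starts mid-schedule (e.g. image-to-image). A small illustration with made-up timesteps:

# Illustration of the duplicate-timestep rule: with two matches, index_candidates[1] is taken.
import torch

timesteps = torch.tensor([999, 749, 749, 499, 249, 0])        # example schedule with a repeated value
candidates = (timesteps == 749).nonzero()
step_index = candidates[1].item() if len(candidates) > 1 else candidates[0].item()
print(step_index)   # 2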
def step(
self,
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
确保与需要根据当前时间步缩放去噪模型输入的调度器互换性。
参数:
sample (`torch.Tensor`):
输入样本。
返回:
`torch.Tensor`:
一个缩放后的输入样本。
"""
return sample
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
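add_noise noises `original_samples` in the variance-preserving convention alpha_t * x_0 + sigma_t * eps, with one sigma per entry of `timesteps`; the trailing unsqueeze loop broadcasts the per-sample sigma over the remaining image dimensions. A shape-only sketch with example values:

# Sketch of the per-sample sigma broadcast used by add_noise; shapes and values are examples only.
import torch

original = torch.randn(4, 3, 16, 16)
noise = torch.randn_like(original)
sigma = torch.tensor([0.5, 1.0, 2.0, 4.0])        # one noise level per batch element
while sigma.ndim < original.ndim:
    sigma = sigma.unsqueeze(-1)                   # -> shape (4, 1, 1, 1)
alpha_t = 1 / (sigma**2 + 1) ** 0.5
noisy = alpha_t * original + (sigma * alpha_t) * noise
print(noisy.shape)                                # torch.Size([4, 3, 16, 16])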
def __len__(self):
return self.config.num_train_timesteps
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torchsde
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
class BatchedBrownianTree:
"""封装 torchsde.BrownianTree 以支持批量熵的类。"""
def __init__(self, x, t0, t1, seed=None, **kwargs):
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get("w0", torch.zeros_like(x))
if seed is None:
seed = torch.randint(0, 2**63 - 1, []).item()
self.batched = True
try:
assert len(seed) == x.shape[0]
w0 = w0[0]
except TypeError:
seed = [seed]
self.batched = False
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
@staticmethod
def sort(a, b):
return (a, b, 1) if a < b else (b, a, -1)
def __call__(self, t0, t1):
t0, t1, sign = self.sort(t0, t1)
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
return w if self.batched else w[0]
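BatchedBrownianTree builds one torchsde.BrownianTree per seed and stacks their increments, so a list of seeds yields one independent Brownian path per batch element. A usage sketch of the wrapper defined above (requires torchsde; the seeds and time interval are example values):

# Example use of BatchedBrownianTree; the increment between two times has one row per seed.
import torch

x = torch.zeros(2, 3)                                          # shape/device/dtype template
tree = BatchedBrownianTree(x, torch.tensor(0.0), torch.tensor(1.0), seed=[0, 1])
w = tree(torch.tensor(0.2), torch.tensor(0.5))                 # W(0.5) - W(0.2), one path per seed
print(w.shape)                                                 # torch.Size([2, 3])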
class BrownianTreeNoiseSampler:
"""基于 torchsde.BrownianTree 的噪声采样器。
参数:
x (Tensor): 用于生成随机样本的张量,其形状、设备和数据类型将被使用。
sigma_min (float): 有效区间的下限。
sigma_max (float): 有效区间的上限。
seed (int 或 List[int]): 随机种子。如果提供了种子列表而不是单个整数,
则噪声采样器将为每个批量项目使用一个 BrownianTree,每个都有自己的种子。
transform (callable): 一个函数,将 sigma 映射到采样器的内部时间步。
"""
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
self.transform = transform
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
self.tree = BatchedBrownianTree(x, t0, t1, seed)
def __call__(self, sigma, sigma_next):
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
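BrownianTreeNoiseSampler returns the Brownian increment between two sigma levels, divided by sqrt(|t1 - t0|) so the result is approximately standard normal regardless of the gap. A usage sketch with arbitrary sigma values inside the [sigma_min, sigma_max] interval (requires torchsde):

# Example use of BrownianTreeNoiseSampler defined above; sigma values are illustrative.
import torch

latents = torch.randn(2, 4, 8, 8)
sampler = BrownianTreeNoiseSampler(latents, sigma_min=0.1, sigma_max=15.0, seed=42)
noise = sampler(torch.tensor(10.0), torch.tensor(8.5))
print(noise.shape)                                             # torch.Size([2, 4, 8, 8])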
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
创建一个 beta 调度,离散化给定的 alpha_t_bar 函数,该函数定义了时间 t = [0,1] 的
(1-beta) 的累积乘积。
包含一个 alpha_bar 函数,该函数接收 t 参数并将其转换为 (1-beta) 的累积乘积
直至扩散过程的该部分。
Args:
num_diffusion_timesteps (`int`): 生成的 beta 数量。
max_beta (`float`): 使用的最大 beta 值;使用低于 1 的值来防止奇点。
alpha_transform_type (`str`, *可选*, 默认为 `cosine`): alpha_bar 的噪声调度类型。
从 `cosine` 或 `exp` 中选择
Returns:
betas (`np.ndarray`): 调度程序用于步骤模型输出的 betas
"""
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
"""
DPMSolverSDEScheduler 实现了 [Elucidating the Design Space of Diffusion-Based
Generative Models](https://huggingface.co/papers/2206.00364) 论文中的随机采样器。
此模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。请查看超类文档以获取库为所有调度器实现的通用
方法,例如加载和保存。
# 定义初始化方法的参数及其默认值
Args:
num_train_timesteps (`int`, defaults to 1000): # 训练模型的扩散步骤数量,默认为1000
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.00085): # 推理的起始 beta 值,默认为0.00085
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.012): # 推理的最终 beta 值,默认为0.012
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`): # beta 计划,定义 beta 范围到模型步骤的映射,默认为线性
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, *optional*): # 可选参数,直接传递 beta 数组以跳过 beta_start 和 beta_end
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
prediction_type (`str`, defaults to `epsilon`, *optional*): # 调度函数的预测类型,默认为预测扩散过程的噪声
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
use_karras_sigmas (`bool`, *optional*, defaults to `False`): # 是否在采样过程中使用 Karras sigmas 来调整噪声调度的步长,默认为 False
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
noise_sampler_seed (`int`, *optional*, defaults to `None`): # 噪声采样器使用的随机种子,默认为 None 时生成随机种子
The random seed to use for the noise sampler. If `None`, a random seed is generated.
timestep_spacing (`str`, defaults to `"linspace"`): # 定义时间步的缩放方式,默认为线性间隔
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0): # 推理步骤的偏移量,某些模型家族所需,默认为0
An offset added to the inference steps, as required by some model families.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 2
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
use_karras_sigmas: Optional[bool] = False,
noise_sampler_seed: Optional[int] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self.use_karras_sigmas = use_karras_sigmas
self.noise_sampler = None
self.noise_sampler_seed = noise_sampler_seed
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu")
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
@property
def init_noise_sigma(self):
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
def step_index(self):
"""
当前时间步的索引计数器。每次调度步骤后增加 1。
"""
return self._step_index
@property
def begin_index(self):
"""
返回初始时间步索引,应该通过 `set_begin_index` 方法设置。
"""
return self._begin_index
def set_begin_index(self, begin_index: int = 0):
"""
设置调度器的初始时间步。此函数应在推断前从管道运行。
Args:
begin_index (`int`):
调度器的初始时间步。
"""
self._begin_index = begin_index
def scale_model_input(
self,
sample: torch.Tensor,
timestep: Union[float, torch.Tensor],
) -> torch.Tensor:
"""
确保与需要根据当前时间步缩放去噪模型输入的调度器的可互换性。
Args:
sample (`torch.Tensor`):
输入样本。
timestep (`int`, *optional*):
扩散链中的当前时间步。
Returns:
`torch.Tensor`:
缩放后的输入样本。
"""
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma
sample = sample / ((sigma_input**2 + 1) ** 0.5)
return sample
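The division by sqrt(sigma**2 + 1) turns the variance-exploding latent (x_0 plus sigma-scaled noise) into a roughly unit-variance input, which is what epsilon-prediction UNets expect. A quick statistical sanity check with an arbitrary sigma, not part of the library source:

# Sanity check of the input scaling: std drops from ~sqrt(1 + sigma**2) to ~1 after scaling.
import torch

sigma = 5.0
x = torch.randn(10_000) + sigma * torch.randn(10_000)          # noisy latent, std ~ sqrt(1 + sigma**2)
scaled = x / (sigma**2 + 1) ** 0.5
print(round(x.std().item(), 2), round(scaled.std().item(), 2)) # roughly 5.1 and 1.0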
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
):
def _second_order_timesteps(self, sigmas, log_sigmas):
def sigma_fn(_t):
return np.exp(-_t)
def t_fn(_sigma):
return -np.log(_sigma)
midpoint_ratio = 0.5
t = t_fn(sigmas)
delta_time = np.diff(t)
t_proposed = t[:-1] + delta_time * midpoint_ratio
sig_proposed = sigma_fn(t_proposed)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sig_proposed])
return timesteps
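_second_order_timesteps places the proposed mid-points halfway between consecutive steps in t = -log(sigma) space, i.e. at the geometric mean of neighbouring sigmas. A small numeric illustration with a made-up sigma array:

# Midpoints in -log(sigma) space are geometric means of neighbouring sigmas (illustrative values).
import numpy as np

sigmas = np.array([10.0, 5.0, 2.0, 1.0])
t = -np.log(sigmas)
t_mid = t[:-1] + 0.5 * np.diff(t)
print(np.exp(-t_mid))   # [7.071..., 3.162..., 1.414...] == sqrt(10*5), sqrt(5*2), sqrt(2*1)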
def _sigma_to_t(self, sigma, log_sigmas):
log_sigma = np.log(np.maximum(sigma, 1e-10))
dists = log_sigma - log_sigmas[:, np.newaxis]
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
"""构建 Karras 等人(2022)提出的噪声调度。"""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
rho = 7.0
ramp = np.linspace(0, 1, self.num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
@property
def state_in_first_order(self):
return self.sample is None
def step(
self,
model_output: Union[torch.Tensor, np.ndarray],
timestep: Union[float, torch.Tensor],
sample: Union[torch.Tensor, np.ndarray],
return_dict: bool = True,
s_noise: float = 1.0,
):
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
step_indices = [self.step_index] * timesteps.shape[0]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma
return noisy_samples
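Note that, unlike the add_noise of the multistep scheduler shown earlier, this scheduler noises samples in the variance-exploding convention x_t = x_0 + sigma * eps, without an alpha_t factor. A minimal sketch with an example sigma:

# Variance-exploding forward noising, as computed by add_noise above; sigma is an example value.
import torch

x0 = torch.randn(2, 3, 8, 8)
noise = torch.randn_like(x0)
sigma = torch.tensor(3.0)
noisy = x0 + sigma * noise
print(round(noisy.std().item(), 2))   # roughly sqrt(1 + sigma**2) ~ 3.16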
def __len__(self):
return self.config.num_train_timesteps
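As a practical note, DPMSolverSDEScheduler can be swapped into an existing pipeline through from_config. The snippet below is a usage sketch rather than part of the source above; the model id, prompt and step count are assumptions for illustration.

# Usage sketch: plugging DPMSolverSDEScheduler into a Stable Diffusion pipeline (illustrative).
import torch
from diffusers import DiffusionPipeline, DPMSolverSDEScheduler

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
pipe = pipe.to("cuda")
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]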