# diffusers source code walkthrough (part 61)
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretise a continuous `alpha_bar(t)` curve into a beta schedule.

    `alpha_bar(t)` is the cumulative product of `(1 - beta)` up to time `t` in `[0, 1]`. For every step `i`
    the beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): Number of betas to produce.
        max_beta (`float`): Upper bound applied to each beta; keep it below 1 to avoid singularities.
        alpha_transform_type (`str`, *optional*, defaults to `"cosine"`):
            Shape of `alpha_bar`; either `"cosine"` or `"exp"`.

    Returns:
        `torch.Tensor`: float32 tensor holding `num_diffusion_timesteps` betas.

    Raises:
        ValueError: If `alpha_transform_type` is neither `"cosine"` nor `"exp"`.
    """
    if alpha_transform_type not in ("cosine", "exp"):
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
    use_cosine = alpha_transform_type == "cosine"

    def alpha_bar_fn(t):
        if use_cosine:
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
        return math.exp(t * -12.0)

    total = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((i + 1) / total) / alpha_bar_fn(i / total), max_beta)
        for i in range(total)
    ]
    return torch.tensor(betas, dtype=torch.float32)
class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
KDPM2DiscreteScheduler is inspired by DPMSolver2 and the paper
[Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364).

This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.

Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.00085):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.012):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
"""  # end of the argument documentation
_compatibles = [e.name for e in KarrasDiffusionSchedulers] # 从 KarrasDiffusionSchedulers 中提取兼容的名称列表
order = 2 # 设置调度的顺序为2
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,  # sensible defaults
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    use_karras_sigmas: Optional[bool] = False,
    prediction_type: str = "epsilon",
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
):
    """
    Build the beta/alpha tables and prime the scheduler with the full training schedule.

    Explicit `trained_betas` take precedence over `beta_schedule`. Supported schedules are `"linear"`,
    `"scaled_linear"` (linear in sqrt(beta), then squared) and `"squaredcos_cap_v2"` (via `betas_for_alpha_bar`).

    Raises:
        NotImplementedError: If no `trained_betas` are given and `beta_schedule` is not one of the above.
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # This schedule is specific to latent diffusion models.
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        betas = sqrt_betas**2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.betas = betas
    self.alphas = 1.0 - betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Set up values for all training timesteps; `self.sigmas` is read right below, so this call is
    # expected to create it (the body of `set_timesteps` is not visible here).
    self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
    """
    Standard deviation of the initial noise distribution.

    For `"linspace"` and `"trailing"` timestep spacing this is the largest sigma; for any other spacing it
    is `sqrt(max_sigma**2 + 1)`.
    """
    spacing = self.config.timestep_spacing
    largest_sigma = self.sigmas.max()
    if spacing not in ("linspace", "trailing"):
        return (largest_sigma**2 + 1) ** 0.5
    return largest_sigma
@property
def step_index(self):
    """
    Index counter of the current timestep (the private `_step_index`).

    It is `None` until `_init_step_index` assigns it.
    """
    current = self._step_index
    return current
@property
def begin_index(self):
    """
    Index of the first timestep (the private `_begin_index`).

    It should be set from the pipeline through `set_begin_index`; until then it is `None`.
    """
    first = self._begin_index
    return first
# 从 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 复制的
def set_begin_index(self, begin_index: int = 0):
    """
    Store the index the scheduler should start from.

    Intended to be called by the pipeline before inference.

    Args:
        begin_index (`int`, defaults to 0):
            The begin index for the scheduler.
    """
    self._begin_index = begin_index
    return None
def scale_model_input(
    self,
    sample: torch.Tensor,
    timestep: Union[float, torch.Tensor],
) -> torch.Tensor:
    """
    Scale the denoising model input by `1 / sqrt(sigma**2 + 1)` for the current step.

    Keeps this scheduler interchangeable with schedulers that need timestep-dependent input scaling.

    Args:
        sample (`torch.Tensor`):
            The input sample.
        timestep (`float` or `torch.Tensor`):
            The current timestep in the diffusion chain; only used to initialise the step index when it
            has not been set yet.

    Returns:
        `torch.Tensor`: The scaled input sample.
    """
    if self.step_index is None:
        self._init_step_index(timestep)

    # First-order state reads the regular sigmas, otherwise the interpolated ones.
    sigma_table = self.sigmas if self.state_in_first_order else self.sigmas_interpol
    sigma = sigma_table[self.step_index]
    return sample / ((sigma**2 + 1) ** 0.5)
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
@property
def state_in_first_order(self):
    """Whether the scheduler is in its first-order state, i.e. `self.sample` is `None`."""
    stored_sample = self.sample
    return stored_sample is None
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep 复制而来
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Return the position of `timestep` inside the schedule.

    Args:
        timestep: The timestep value to look up.
        schedule_timesteps (*optional*): Schedule to search; defaults to `self.timesteps`.

    Returns:
        `int`: The second matching index when the timestep occurs more than once, otherwise the only one.

    Raises:
        IndexError: If `timestep` does not occur in the schedule.
    """
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (schedule == timestep).nonzero()

    # For the very first `step`, the second match is preferred when there are several.
    if len(matches) > 1:
        return matches[1].item()
    return matches[0].item()
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index 复制而来
def _init_step_index(self, timestep):
    """
    Initialise `self._step_index`.

    When a begin index has been set it is used directly; otherwise the index is looked up from
    `timestep` via `index_for_timestep`.
    """
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return

    if isinstance(timestep, torch.Tensor):
        # Compare on the same device as the stored schedule.
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)
# 从 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t 复制而来
def _sigma_to_t(self, sigma, log_sigmas):
    """
    Map `sigma` to a fractional index into `log_sigmas` by linear interpolation in log space.

    Args:
        sigma (`np.ndarray`): Sigma value(s) to convert.
        log_sigmas (`np.ndarray`): 1-D table of log-sigmas indexed by timestep.

    Returns:
        `np.ndarray`: Interpolated index, reshaped to `sigma.shape`.
    """
    # Floor sigma at 1e-10 before taking the log.
    log_sigma = np.log(np.maximum(sigma, 1e-10))

    # Signed distance of log_sigma to every table entry.
    dists = log_sigma - log_sigmas[:, np.newaxis]

    # Bracket log_sigma between two neighbouring table entries.
    last_low = log_sigmas.shape[0] - 2
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=last_low)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]

    # Interpolation weight, clamped to [0, 1].
    w = np.clip((low - log_sigma) / (low - high), 0, 1)

    t = (1 - w) * low_idx + w * high_idx
    return t.reshape(sigma.shape)
# 复制自 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
    """
    Construct the noise schedule of Karras et al. (2022).

    `sigma_min` / `sigma_max` are read from `self.config` when present and not `None`; otherwise the
    last / first entry of `in_sigmas` is used.

    Args:
        in_sigmas: Input sigmas; only the first and last entries are read.
        num_inference_steps: Number of sigmas to generate.

    Returns:
        Array of `num_inference_steps` sigmas running from `sigma_max` to `sigma_min`.
    """
    # Hack so that other schedulers which copy this function don't break.
    # TODO: Add this logic to the other schedulers
    sigma_min = getattr(self.config, "sigma_min", None)
    sigma_max = getattr(self.config, "sigma_max", None)
    if sigma_min is None:
        sigma_min = in_sigmas[-1].item()
    if sigma_max is None:
        sigma_max = in_sigmas[0].item()

    rho = 7.0  # 7.0 is the value used in the paper
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
def step(
self,
model_output: Union[torch.Tensor, np.ndarray],
timestep: Union[float, torch.Tensor],
sample: Union[torch.Tensor, np.ndarray],
return_dict: bool = True,
# 复制自 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """
    Return `original_samples + noise * sigma`, with one sigma per batch entry.

    Args:
        original_samples (`torch.Tensor`): Clean samples.
        noise (`torch.Tensor`): Noise to add.
        timesteps (`torch.Tensor`): One timestep per batch entry.

    Returns:
        `torch.Tensor`: The noisy samples.
    """
    target_device = original_samples.device
    # Make sure sigmas have the same device and dtype as original_samples.
    sigmas = self.sigmas.to(device=target_device, dtype=original_samples.dtype)

    if target_device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(target_device, dtype=torch.float32)
        timesteps = timesteps.to(target_device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(target_device)
        timesteps = timesteps.to(target_device)

    batch = timesteps.shape[0] if self.begin_index is not None else None
    if self.begin_index is None:
        # begin_index is None when the scheduler is used for training, or when the pipeline does not
        # implement set_begin_index.
        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    elif self.step_index is not None:
        # add_noise is called after the first denoising step (for inpainting).
        step_indices = [self.step_index] * batch
    else:
        # add_noise is called before the first denoising step to create the initial latent (img2img).
        step_indices = [self.begin_index] * batch

    sigma = sigmas[step_indices].flatten()
    # Append trailing singleton axes until sigma has as many dimensions as the samples.
    for _ in range(len(original_samples.shape) - len(sigma.shape)):
        sigma = sigma.unsqueeze(-1)

    return original_samples + noise * sigma
# 定义 __len__ 方法以返回训练时间步的数量
def __len__(self):
    """Return the number of training timesteps stored in the config."""
    train_steps = self.config.num_train_timesteps
    return train_steps
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin
logger = logging.get_logger(__name__)
@dataclass
class LCMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next
            model input in the denoising loop.
        denoised (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images, *optional*):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep. It
            can be used to preview progress or for guidance.
            NOTE(review): the original docstring described this field under the name
            `pred_original_sample`; the attribute actually declared below is `denoised`.
    """

    prev_sample: torch.Tensor
    denoised: Optional[torch.Tensor] = None
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretises a given `alpha_bar(t)` function.

    `alpha_bar(t)` defines the cumulative product of `(1 - beta)` over time `t` in `[0, 1]`. Beta `i` is
    `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, clipped to at most `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): The number of betas to produce.
        max_beta (`float`): The maximum beta to use; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `"cosine"`):
            The type of noise schedule for `alpha_bar`. Choose from `"cosine"` or `"exp"`.

    Returns:
        `torch.Tensor`: float32 tensor of betas.

    Raises:
        ValueError: If `alpha_transform_type` is not supported.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar_fn = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar_fn = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for step in range(num_diffusion_timesteps):
        start = step / num_diffusion_timesteps
        end = (step + 1) / num_diffusion_timesteps
        beta = 1 - alpha_bar_fn(end) / alpha_bar_fn(start)
        betas.append(min(beta, max_beta))
    return torch.tensor(betas, dtype=torch.float32)
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """
    Rescale betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`torch.Tensor`):
            The betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: Rescaled betas with zero terminal SNR.
    """
    # Convert betas to sqrt(alpha_bar).
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the original first and last values.
    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift so the last timestep becomes zero, then scale so the first timestep returns to its old value.
    shifted = sqrt_alpha_bar - last
    rescaled = shifted * (first / (first - last))

    # Convert sqrt(alpha_bar) back to betas.
    alpha_bar = rescaled**2
    step_alphas = alpha_bar[1:] / alpha_bar[:-1]
    alphas = torch.cat([alpha_bar[0:1], step_alphas])
    return 1 - alphas
class LCMScheduler(SchedulerMixin, ConfigMixin):
"""
`LCMScheduler` 扩展了在去噪扩散概率模型 (DDPM) 中引入的去噪程序,并实现了
非马尔可夫引导。
此模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。[`~ConfigMixin`] 负责存储在调度器的
`__init__` 函数中传入的所有配置属性,例如 `num_train_timesteps`。它们可以通过
`scheduler.config.num_train_timesteps` 访问。[`SchedulerMixin`] 提供通用的加载和保存
功能,通过 [`SchedulerMixin.save_pretrained`] 和 [`~SchedulerMixin.from_pretrained`] 函数。
"""
order = 1
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "scaled_linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    original_inference_steps: int = 50,
    clip_sample: bool = False,
    clip_sample_range: float = 1.0,
    set_alpha_to_one: bool = True,
    steps_offset: int = 0,
    prediction_type: str = "epsilon",
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    timestep_spacing: str = "leading",
    timestep_scaling: float = 10.0,
    rescale_betas_zero_snr: bool = False,
):
    """
    Build the beta/alpha tables and the default (full-length, descending) timestep schedule.

    Explicit `trained_betas` take precedence over `beta_schedule`; supported schedules are `"linear"`,
    `"scaled_linear"` and `"squaredcos_cap_v2"`. When `rescale_betas_zero_snr` is set the betas are passed
    through `rescale_zero_terminal_snr`.

    Raises:
        NotImplementedError: If no `trained_betas` are given and `beta_schedule` is not supported.
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        betas = sqrt_betas**2
    elif beta_schedule == "squaredcos_cap_v2":
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
    self.betas = betas

    if rescale_betas_zero_snr:
        self.betas = rescale_zero_terminal_snr(self.betas)

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Either exactly one, or the first cumulative alpha, depending on `set_alpha_to_one`.
    if set_alpha_to_one:
        self.final_alpha_cumprod = torch.tensor(1.0)
    else:
        self.final_alpha_cumprod = self.alphas_cumprod[0]

    self.init_noise_sigma = 1.0

    self.num_inference_steps = None
    descending = np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)
    self.timesteps = torch.from_numpy(descending)
    self.custom_timesteps = False

    self._step_index = None
    self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Locate `timestep` in `schedule_timesteps` (defaults to `self.timesteps`).

    Returns the second matching position when there are several matches, else the single match.
    Raises `IndexError` when the timestep is not part of the schedule.
    """
    if schedule_timesteps is None:
        schedule_timesteps = self.timesteps
    hits = (schedule_timesteps == timestep).nonzero()
    chosen = hits[1] if len(hits) > 1 else hits[0]
    return chosen.item()
def _init_step_index(self, timestep):
    """
    Set `self._step_index` from the begin index when one is configured, otherwise by looking up
    `timestep` in the current schedule.
    """
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return
    lookup = timestep
    if isinstance(lookup, torch.Tensor):
        lookup = lookup.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(lookup)
@property
def step_index(self):
    """
    Index counter of the current timestep (the private `_step_index`); `None` until initialised.

    Exposed as a property, like `begin_index` in this class, so that `self.step_index is None`
    tests the stored value instead of a bound method object.
    """
    return self._step_index
@property
def begin_index(self):
    """
    Index of the first timestep (the private `_begin_index`).

    It should be set from the pipeline with `set_begin_index`.
    """
    start = self._begin_index
    return start
def set_begin_index(self, begin_index: int = 0):
    """
    Set the begin index for the scheduler. Meant to be run from the pipeline before inference.

    Args:
        begin_index (`int`, defaults to 0):
            The begin index for the scheduler.
    """
    self._begin_index = begin_index
    return None
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
    """
    Return `sample` unchanged.

    Present so this scheduler is interchangeable with schedulers that need to scale the denoising model
    input depending on the current timestep.

    Args:
        sample (`torch.Tensor`):
            The input sample.
        timestep (`int`, *optional*):
            The current timestep in the diffusion chain; not used here.

    Returns:
        `torch.Tensor`: The same object that was passed in as `sample`.
    """
    unscaled = sample
    return unscaled
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    """
    Apply dynamic thresholding to `sample`.

    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in
    xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and
    then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby
    actively preventing pixels from saturation at each step. We find that dynamic thresholding results in
    significantly better photorealism as well as better image-text alignment, especially when using very
    large guidance weights."

    https://arxiv.org/abs/2205.11487

    Args:
        sample (`torch.Tensor`): Tensor with at least two dimensions, `(batch, channels, ...)`.

    Returns:
        `torch.Tensor`: Thresholded tensor with the shape and dtype of the input.
    """
    dtype = sample.dtype
    batch_size, channels, *remaining_dims = sample.shape

    if dtype not in (torch.float32, torch.float64):
        # Upcast so the quantile and clamp below run in float32.
        sample = sample.float()

    # Flatten everything but the batch axis so the quantile runs per sample. `np.prod([])` is the float
    # 1.0, so cast to int: otherwise a plain `(batch, channels)` input yields a non-integer reshape size.
    flat_size = channels * int(np.prod(remaining_dims))
    sample = sample.reshape(batch_size, flat_size)

    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

    s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
    # With min=1 this is equivalent to standard clipping to [-1, 1] whenever s <= 1.
    s = torch.clamp(s, min=1, max=self.config.sample_max_value)
    s = s.unsqueeze(1)  # (batch_size, 1) so that clamp broadcasts along dim=0
    sample = torch.clamp(sample, -s, s) / s  # "threshold xt0 to the range [-s, s] and then divide by s"

    sample = sample.reshape(batch_size, channels, *remaining_dims)
    return sample.to(dtype)
def set_timesteps(
self,
num_inference_steps: Optional[int] = None,
device: Union[str, torch.device] = None,
original_inference_steps: Optional[int] = None,
timesteps: Optional[List[int]] = None,
strength: int = 1.0,
def get_scalings_for_boundary_condition_discrete(self, timestep):
    """
    Compute the `(c_skip, c_out)` boundary-condition scalings for `timestep`.

    With `s = timestep * config.timestep_scaling` and `sigma_data = 0.5`:
    `c_skip = sigma_data**2 / (s**2 + sigma_data**2)` and `c_out = s / sqrt(s**2 + sigma_data**2)`.

    Side effect: stores `0.5` on `self.sigma_data`.
    """
    self.sigma_data = 0.5  # Default: 0.5
    scaled_timestep = timestep * self.config.timestep_scaling

    denominator = scaled_timestep**2 + self.sigma_data**2
    c_skip = self.sigma_data**2 / denominator
    c_out = scaled_timestep / denominator**0.5
    return c_skip, c_out
def step(
self,
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Return `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise` per batch entry.

    Side effect: `self.alphas_cumprod` is moved to the device of `original_samples`, so later calls do not
    repeat the transfer.

    Args:
        original_samples (`torch.Tensor`): Clean samples.
        noise (`torch.Tensor`): Noise to mix in.
        timesteps (`torch.IntTensor`): Indices into `alphas_cumprod`, one per batch entry.
    """
    # Move alphas_cumprod to the sample device once and keep it there.
    self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
    alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
    timesteps = timesteps.to(original_samples.device)

    target_rank = len(original_samples.shape)

    def _broadcastable(values):
        # Flatten, then append singleton axes up to the rank of the samples.
        values = values.flatten()
        while len(values.shape) < target_rank:
            values = values.unsqueeze(-1)
        return values

    sqrt_alpha_prod = _broadcastable(alphas_cumprod[timesteps] ** 0.5)
    sqrt_one_minus_alpha_prod = _broadcastable((1 - alphas_cumprod[timesteps]) ** 0.5)

    return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
    """
    Return `sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample` per batch entry.

    Side effect: `self.alphas_cumprod` is moved to the device of `sample`.

    Args:
        sample (`torch.Tensor`): Input samples.
        noise (`torch.Tensor`): Noise tensor.
        timesteps (`torch.IntTensor`): Indices into `alphas_cumprod`, one per batch entry.
    """
    # Move alphas_cumprod to the sample device once and keep it there.
    self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
    alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    rank = len(sample.shape)

    signal_scale = (alphas_cumprod[timesteps] ** 0.5).flatten()
    for _ in range(rank - len(signal_scale.shape)):
        signal_scale = signal_scale.unsqueeze(-1)

    noise_scale = ((1 - alphas_cumprod[timesteps]) ** 0.5).flatten()
    for _ in range(rank - len(noise_scale.shape)):
        noise_scale = noise_scale.unsqueeze(-1)

    return signal_scale * noise - noise_scale * sample
def __len__(self):
    """Return `config.num_train_timesteps`."""
    total_steps = self.config.num_train_timesteps
    return total_steps
def previous_timestep(self, timestep):
    """
    Return the timestep that follows `timestep` in the denoising order.

    With custom timesteps this is the next entry of `self.timesteps`, or `tensor(-1)` when `timestep` is
    the last entry. Otherwise it is `timestep - num_train_timesteps // num_inference_steps`, where the
    number of training steps is used when `num_inference_steps` is unset (falsy).
    """
    if not self.custom_timesteps:
        if self.num_inference_steps:
            num_inference_steps = self.num_inference_steps
        else:
            num_inference_steps = self.config.num_train_timesteps
        return timestep - self.config.num_train_timesteps // num_inference_steps

    index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
    if index == self.timesteps.shape[0] - 1:
        return torch.tensor(-1)
    return self.timesteps[index + 1]
import math
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@dataclass
class LMSDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next
            model input in the denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Turn an `alpha_bar(t)` curve into a discrete beta schedule.

    `alpha_bar(t)` is the cumulative product of `(1 - beta)` up to time `t` in `[0, 1]`; the beta for step
    `i` is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, limited to `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): The number of betas to produce.
        max_beta (`float`): The maximum beta to use; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `"cosine"`):
            The type of noise schedule for `alpha_bar`. Choose from `"cosine"` or `"exp"`.

    Returns:
        `torch.Tensor`: float32 tensor with one beta per step.

    Raises:
        ValueError: If `alpha_transform_type` is not `"cosine"` or `"exp"`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    def _beta_at(index):
        ratio = alpha_bar_fn((index + 1) / num_diffusion_timesteps) / alpha_bar_fn(index / num_diffusion_timesteps)
        return min(1 - ratio, max_beta)

    schedule = [_beta_at(index) for index in range(num_diffusion_timesteps)]
    return torch.tensor(schedule, dtype=torch.float32)
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
一个用于离散 beta 计划的线性多步调度器。
该模型继承自 [`SchedulerMixin`] 和 [`ConfigMixin`]。请查看超类文档以了解库为所有调度器实现的通用方法,如加载和保存。
参数:
num_train_timesteps (`int`, 默认值为 1000):
用于训练模型的扩散步骤数量。
beta_start (`float`, 默认值为 0.0001):
推断的起始 `beta` 值。
beta_end (`float`, 默认值为 0.02):
最终的 `beta` 值。
beta_schedule (`str`, 默认值为 `"linear"`):
beta 计划,将 beta 范围映射到一系列用于模型步进的 betas。可以选择 `linear` 或 `scaled_linear`。
trained_betas (`np.ndarray`, *可选*):
直接将 beta 数组传递给构造函数,以绕过 `beta_start` 和 `beta_end`。
use_karras_sigmas (`bool`, *可选*, 默认值为 `False`):
是否在采样过程中使用 Karras sigmas 作为噪声计划中的步长。如果为 `True`,则根据噪声水平序列 {σi} 确定 sigmas。
prediction_type (`str`, 默认值为 `epsilon`, *可选*):
调度器函数的预测类型;可以是 `epsilon`(预测扩散过程的噪声)、`sample`(直接预测带噪声的样本)或 `v_prediction`(参见 [Imagen Video](https://imagen.research.google/video/paper.pdf) 论文的第 2.4 节)。
timestep_spacing (`str`, 默认值为 `"linspace"`):
时间步的缩放方式。有关更多信息,请参阅 [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) 的表 2。
steps_offset (`int`, 默认值为 0):
添加到推断步骤的偏移量,某些模型系列需要该偏移量。
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    use_karras_sigmas: Optional[bool] = False,
    prediction_type: str = "epsilon",
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
):
    """
    Build the beta/alpha/sigma tables and prime the scheduler with the full training schedule.

    Explicit `trained_betas` take precedence over `beta_schedule`; supported schedules are `"linear"`,
    `"scaled_linear"` and `"squaredcos_cap_v2"`.

    Raises:
        NotImplementedError: If no `trained_betas` are given and `beta_schedule` is not supported.
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # This schedule is specific to latent diffusion models.
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        betas = sqrt_betas**2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.betas = betas
    self.alphas = 1.0 - betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # sigma = sqrt((1 - alpha_bar) / alpha_bar), stored in descending order with a trailing zero.
    ascending = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    descending = np.concatenate([ascending[::-1], [0.0]]).astype(np.float32)
    self.sigmas = torch.from_numpy(descending)

    # Setable values; `set_timesteps` replaces `self.sigmas` with the schedule for the requested steps.
    self.num_inference_steps = None
    self.use_karras_sigmas = use_karras_sigmas
    self.set_timesteps(num_train_timesteps, None)
    self.derivatives = []
    self.is_scale_input_called = False

    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
    """
    Standard deviation of the initial noise distribution.

    The largest sigma for `"linspace"` / `"trailing"` spacing, `sqrt(max_sigma**2 + 1)` otherwise.
    """
    uses_plain_max = self.config.timestep_spacing in ["linspace", "trailing"]
    peak = self.sigmas.max()
    return peak if uses_plain_max else (peak**2 + 1) ** 0.5
@property
def step_index(self):
    """
    Index counter of the current timestep (the private `_step_index`).

    It is reset to `None` by `set_timesteps` and assigned by `_init_step_index`.
    """
    position = self._step_index
    return position
@property
def begin_index(self):
    """
    Index of the first timestep (the private `_begin_index`).

    It should be set from the pipeline with `set_begin_index`; `set_timesteps` resets it to `None`.
    """
    origin = self._begin_index
    return origin
def set_begin_index(self, begin_index: int = 0):
    """
    Record the index the scheduler should begin at. To be run from the pipeline before inference.

    Args:
        begin_index (`int`, defaults to 0):
            The begin index for the scheduler.
    """
    self._begin_index = begin_index
    return None
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """
    Scale the denoising model input by `1 / sqrt(sigma**2 + 1)` for the current step.

    Also sets `self.is_scale_input_called` to `True`.

    Args:
        sample (`torch.Tensor`):
            The input sample.
        timestep (`float` or `torch.Tensor`):
            The current timestep in the diffusion chain; only used to initialise the step index when it
            has not been set yet.

    Returns:
        `torch.Tensor`: A scaled input sample.
    """
    if self.step_index is None:
        self._init_step_index(timestep)

    scale = (self.sigmas[self.step_index] ** 2 + 1) ** 0.5
    scaled = sample / scale
    self.is_scale_input_called = True
    return scaled
def get_lms_coefficient(self, order, t, current_order):
    """
    Compute the linear multistep coefficient.

    Integrates, from `sigmas[t]` to `sigmas[t + 1]`, the product over `k in range(order)`, `k !=
    current_order`, of `(tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])`.

    Args:
        order: Number of terms `k` considered in the product.
        t: Index into `self.sigmas` of the current step.
        current_order: The `k` that is left out of the product.

    Returns:
        The value of the integral as returned by `scipy.integrate.quad` (relative tolerance `1e-4`).
    """
    other_orders = [k for k in range(order) if k != current_order]

    def lms_derivative(tau):
        prod = 1.0
        for k in other_orders:
            numerator = tau - self.sigmas[t - k]
            denominator = self.sigmas[t - current_order] - self.sigmas[t - k]
            prod *= numerator / denominator
        return prod

    lower, upper = self.sigmas[t], self.sigmas[t + 1]
    integral, _abs_error = integrate.quad(lms_derivative, lower, upper, epsrel=1e-4)
    return integral
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    """
    Set the discrete timesteps used for the diffusion chain (to be run before inference).

    Stores `num_inference_steps`, rebuilds `self.timesteps` and `self.sigmas` (sigmas get a trailing 0 and
    end up on the CPU), and resets `_step_index`, `_begin_index` and `derivatives`.

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples.
        device (`str` or `torch.device`, *optional*):
            The device the timesteps should be moved to. If `None`, the timesteps are not moved.

    Raises:
        ValueError: If `config.timestep_spacing` is not `"linspace"`, `"leading"` or `"trailing"`.
    """
    self.num_inference_steps = num_inference_steps
    spacing = self.config.timestep_spacing

    # "linspace", "leading", "trailing" correspond to the annotation of Table 2 of
    # https://huggingface.co/papers/2305.08891
    if spacing == "linspace":
        ascending = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)
        timesteps = ascending[::-1].copy()
    elif spacing == "leading":
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # Integer multiples of the ratio, rounded, then reversed into descending order.
        ascending = (np.arange(0, num_inference_steps) * step_ratio).round()
        timesteps = ascending[::-1].copy().astype(np.float32)
        timesteps += self.config.steps_offset
    elif spacing == "trailing":
        step_ratio = self.config.num_train_timesteps / self.num_inference_steps
        descending = np.arange(self.config.num_train_timesteps, 0, -step_ratio)
        timesteps = descending.round().copy().astype(np.float32)
        timesteps -= 1
    else:
        raise ValueError(
            f"{spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    # Full training-resolution sigma table, then interpolate it at the chosen timesteps.
    train_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    log_sigmas = np.log(train_sigmas)
    sigmas = np.interp(timesteps, np.arange(0, len(train_sigmas)), train_sigmas)

    if self.config.use_karras_sigmas:
        sigmas = self._convert_to_karras(in_sigmas=sigmas)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

    sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

    self.sigmas = torch.from_numpy(sigmas).to(device=device)
    self.timesteps = torch.from_numpy(timesteps).to(device=device)
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    self.derivatives = []
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Find where `timestep` sits in `schedule_timesteps` (defaults to `self.timesteps`).

    When the timestep occurs more than once the second occurrence is returned, otherwise the only one.
    An `IndexError` is raised when it does not occur at all.
    """
    candidates = schedule_timesteps if schedule_timesteps is not None else self.timesteps
    positions = (candidates == timestep).nonzero()
    if len(positions) > 1:
        return positions[1].item()
    return positions[0].item()
def _init_step_index(self, timestep):
    """
    Assign `self._step_index`: the configured begin index when there is one, otherwise the position of
    `timestep` in the schedule (tensor timesteps are first moved to the schedule's device).
    """
    has_begin_index = self.begin_index is not None
    if has_begin_index:
        self._step_index = self._begin_index
    else:
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        self._step_index = self.index_for_timestep(timestep)
def _sigma_to_t(self, sigma, log_sigmas):
    """
    Convert `sigma` into a fractional index of the `log_sigmas` table via log-space interpolation.

    Args:
        sigma (`np.ndarray`): Sigma value(s) to convert.
        log_sigmas (`np.ndarray`): 1-D table of log-sigmas.

    Returns:
        `np.ndarray`: Interpolated index with the shape of `sigma`.
    """
    # get log sigma (sigma is floored at 1e-10)
    log_sigma = np.log(np.maximum(sigma, 1e-10))

    # get distribution
    dists = log_sigma - log_sigmas[:, np.newaxis]

    # get sigmas range
    upper_bound_for_low = log_sigmas.shape[0] - 2
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=upper_bound_for_low)
    high_idx = low_idx + 1
    low = log_sigmas[low_idx]
    high = log_sigmas[high_idx]

    # interpolate sigmas
    weight = np.clip((low - log_sigma) / (low - high), 0, 1)

    # transform interpolation to time range
    t = (1 - weight) * low_idx + weight * high_idx
    return t.reshape(sigma.shape)
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
    """
    Construct the noise schedule of Karras et al. (2022).

    Produces `self.num_inference_steps` sigmas running from `in_sigmas[0]` down to `in_sigmas[-1]`,
    evenly spaced in `sigma ** (1 / rho)` with `rho = 7.0`.
    """
    sigma_max: float = in_sigmas[0].item()
    sigma_min: float = in_sigmas[-1].item()

    rho = 7.0  # 7.0 is the value used in the paper
    exponent = 1 / rho
    ramp = np.linspace(0, 1, self.num_inference_steps)
    min_inv_rho = sigma_min**exponent
    max_inv_rho = sigma_max**exponent
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
order: int = 4,
return_dict: bool = True,
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """
    Return `original_samples + noise * sigma`, choosing one sigma per batch entry.

    Sigma indices come from looking up each timestep when no begin index is set; otherwise the current
    step index (if already initialised) or the begin index is used for the whole batch.
    """
    device = original_samples.device
    # Make sure sigmas and timesteps have the same device and dtype as original_samples.
    sigmas = self.sigmas.to(device=device, dtype=original_samples.dtype)

    on_mps_with_float_steps = device.type == "mps" and torch.is_floating_point(timesteps)
    if on_mps_with_float_steps:
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(device, dtype=torch.float32)
        timesteps = timesteps.to(device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(device)
        timesteps = timesteps.to(device)

    if self.begin_index is None:
        # Scheduler used for training, or the pipeline does not implement set_begin_index.
        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    elif self.step_index is not None:
        # add_noise is called after the first denoising step (for inpainting).
        step_indices = [self.step_index for _ in range(timesteps.shape[0])]
    else:
        # add_noise is called before the first denoising step to create the initial latent (img2img).
        step_indices = [self.begin_index for _ in range(timesteps.shape[0])]

    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < len(original_samples.shape):
        sigma = sigma.unsqueeze(-1)

    return original_samples + noise * sigma
def __len__(self):
    """Return the configured number of training timesteps."""
    count = self.config.num_train_timesteps
    return count
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)
@flax.struct.dataclass
class LMSDiscreteSchedulerState:
    """State container for `FlaxLMSDiscreteScheduler`, passed explicitly to the scheduler's methods."""

    # Shared scheduler tables; `create_state` reads `alphas_cumprod` from it.
    common: CommonSchedulerState

    # Values filled in by `create`.
    init_noise_sigma: jnp.ndarray
    timesteps: jnp.ndarray
    sigmas: jnp.ndarray
    num_inference_steps: Optional[int] = None

    # Running values; left at `None` by `create`.
    derivatives: Optional[jnp.ndarray] = None

    @classmethod
    def create(
        cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
    ):
        # Build a state from the four required fields; the optional fields keep their `None` defaults.
        return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
@dataclass
class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
    """Scheduler output that extends `FlaxSchedulerOutput` with the scheduler state."""

    # Scheduler state carried alongside the base output.
    state: LMSDiscreteSchedulerState
class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
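"""
Linear multistep scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf).
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""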
# 创建一个包含 FlaxKarrasDiffusionSchedulers 中每个调度器名称的列表
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
# 定义一个数据类型属性
dtype: jnp.dtype
# 定义属性,指示是否有状态
@property
def has_state(self):
return True
# Constructor of `FlaxLMSDiscreteScheduler` (decorated with `register_to_config`).
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[jnp.ndarray] = None,
    prediction_type: str = "epsilon",
    dtype: jnp.dtype = jnp.float32,
):
    """
    Store the computation dtype on the instance.

    Only `dtype` is used in this body. The other methods of the class read values such as
    `num_train_timesteps` and `prediction_type` through `self.config` — presumably populated from these
    arguments by `register_to_config`; confirm against `ConfigMixin`.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`): name of the beta schedule.
        trained_betas (`jnp.ndarray`, optional): explicit beta array.
        prediction_type (`str`): prediction type of the scheduler function.
        dtype (`jnp.dtype`): dtype used for parameters and computation.
    """
    # The listing had lost the leading `self` parameter, which this assignment requires.
    self.dtype = dtype
# Build the initial `LMSDiscreteSchedulerState`.
def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
    """
    Create the scheduler state.

    Args:
        common (`CommonSchedulerState`, optional):
            shared schedule state; when `None`, one is created with `CommonSchedulerState.create(self)`.

    Returns:
        `LMSDiscreteSchedulerState`: state holding the descending training timesteps, the per-timestep sigmas
        and the largest sigma as `init_noise_sigma`.
    """
    shared = CommonSchedulerState.create(self) if common is None else common

    # sigma = sqrt((1 - alpha_cumprod) / alpha_cumprod)
    noise_levels = ((1 - shared.alphas_cumprod) / shared.alphas_cumprod) ** 0.5

    # num_train_timesteps - 1, ..., 1, 0
    descending_steps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

    return LMSDiscreteSchedulerState.create(
        common=shared,
        # standard deviation of the initial noise distribution
        init_noise_sigma=noise_levels.max(),
        timesteps=descending_steps,
        sigmas=noise_levels,
    )
# Scale the denoising model input for the K-LMS algorithm.
def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
    """
    Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.

    Args:
        state (`LMSDiscreteSchedulerState`):
            the `FlaxLMSDiscreteScheduler` state data class instance.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.
        timestep (`int`):
            current discrete timestep in the diffusion chain.

    Returns:
        `jnp.ndarray`: scaled input sample
    """
    # Position of `timestep` inside `state.timesteps` (a single index is requested).
    (matches,) = jnp.where(state.timesteps == timestep, size=1)
    sigma = state.sigmas[matches[0]]

    scale = (sigma**2 + 1) ** 0.5
    return sample / scale
# Compute one linear multistep coefficient.
def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order):
    """
    Compute a linear multistep coefficient.

    The coefficient is the integral, from `state.sigmas[t]` to `state.sigmas[t + 1]`, of the product over
    `k in range(order)`, `k != current_order`, of
    `(tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])`.

    Args:
        state (`LMSDiscreteSchedulerState`): state providing `sigmas`.
        order: number of factors considered (the range of `k`).
        t: index into `state.sigmas` of the current step.
        current_order: the `k` that is skipped; it selects the sigma used in the denominators.

    Returns:
        The value of the integral as returned by `scipy.integrate.quad` (first element of its result).
    """
    sigmas = state.sigmas
    anchor_index = t - current_order
    other_orders = [k for k in range(order) if k != current_order]

    def lms_derivative(tau):
        value = 1.0
        for k in other_orders:
            value *= (tau - sigmas[t - k]) / (sigmas[anchor_index] - sigmas[t - k])
        return value

    lower, upper = sigmas[t], sigmas[t + 1]
    return integrate.quad(lms_derivative, lower, upper, epsrel=1e-4)[0]
# Set the timesteps used for the diffusion chain.
def set_timesteps(
    self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> LMSDiscreteSchedulerState:
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Args:
        state (`LMSDiscreteSchedulerState`):
            the `FlaxLMSDiscreteScheduler` state data class instance.
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples.
        shape (`Tuple`):
            trailing shape of the (initially empty) derivative history.

    Returns:
        `LMSDiscreteSchedulerState`: copy of `state` with `timesteps`, `sigmas`, `num_inference_steps` and
        `derivatives` replaced.
    """
    # Evenly spaced (possibly fractional) positions from the last training timestep down to 0.
    positions = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
    lower = jnp.floor(positions).astype(jnp.int32)
    upper = jnp.ceil(positions).astype(jnp.int32)
    weight = jnp.mod(positions, 1.0)

    # Per-training-timestep sigmas, linearly interpolated at the fractional positions.
    alphas_cumprod = state.common.alphas_cumprod
    train_sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
    interpolated = (1 - weight) * train_sigmas[lower] + weight * train_sigmas[upper]

    # A final sigma of 0.0 is appended after the interpolated values.
    trailing_zero = jnp.array([0.0], dtype=self.dtype)

    return state.replace(
        timesteps=positions.astype(jnp.int32),
        sigmas=jnp.concatenate([interpolated, trailing_zero]),
        num_inference_steps=num_inference_steps,
        # empty derivative history
        derivatives=jnp.zeros((0,) + shape, dtype=self.dtype),
    )
# Predict the sample at the previous timestep.
def step(
    self,
    state: LMSDiscreteSchedulerState,
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
    order: int = 4,
    return_dict: bool = True,
) -> Union[FlaxLMSSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep; used here directly as an index into `state.sigmas`.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.
        order: maximum number of stored derivatives combined in the multistep update.
        return_dict (`bool`): when `False`, return a tuple instead of a `FlaxLMSSchedulerOutput`.

    Returns:
        [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise
        the tuple `(prev_sample, state)`.

    Raises:
        ValueError: if `set_timesteps` has not been run, or `prediction_type` is neither `epsilon` nor
            `v_prediction`.
    """
    if state.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # NOTE(review): `timestep` indexes `state.sigmas` directly here (and in `get_lms_coefficient`), whereas
    # `scale_model_input` looks the position up in `state.timesteps` — confirm which convention callers use.
    sigma = state.sigmas[timestep]

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    if self.config.prediction_type == "epsilon":
        pred_original_sample = sample - sigma * model_output
    elif self.config.prediction_type == "v_prediction":
        # * c_out + input * c_skip
        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    # 2. Convert to an ODE derivative
    derivative = (sample - pred_original_sample) / sigma

    # Keep the history stacked along a leading axis. `jnp.append` / `jnp.delete` without `axis` flatten
    # their input, which turned the history into a 1-D vector of scalars.
    newest = derivative[None]
    if state.derivatives is None or state.derivatives.shape[0] == 0:
        history = newest
    else:
        history = jnp.concatenate([state.derivatives, newest], axis=0)
    if history.shape[0] > order:
        # drop the oldest derivative
        history = history[1:]
    state = state.replace(derivatives=history)

    # 3. Compute linear multistep coefficients
    order = min(timestep + 1, order)
    lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)]

    # 4. Compute previous sample based on the derivatives path (newest derivative first)
    prev_sample = sample + sum(coeff * past for coeff, past in zip(lms_coeffs, history[::-1]))

    if not return_dict:
        return (prev_sample, state)

    return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)
# Add noise to samples according to the scheduler's sigmas.
def add_noise(
    self,
    state: LMSDiscreteSchedulerState,
    original_samples: jnp.ndarray,
    noise: jnp.ndarray,
    timesteps: jnp.ndarray,
) -> jnp.ndarray:
    """
    Return `original_samples + noise * sigma`, where `sigma` is `state.sigmas[timesteps]` flattened and then
    passed through `broadcast_to_shape_from_left` with the shape of `noise`.
    """
    per_sample_sigma = broadcast_to_shape_from_left(state.sigmas[timesteps].flatten(), noise.shape)
    return original_samples + noise * per_sample_sigma
# Length of the scheduler.
def __len__(self):
    """Return `self.config.num_train_timesteps`."""
    config = self.config
    return config.num_train_timesteps
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product
    of (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of
    (1-beta) up to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` holding the betas used by
        the scheduler to step the model outputs.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at `max_beta`.
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
class PNDMScheduler(SchedulerMixin, ConfigMixin):
    """
    `PNDMScheduler` uses pseudo numerical methods for diffusion models such as the Runge-Kutta and linear
    multi-step method.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
    generic methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        skip_prk_steps (`bool`, defaults to `False`):
            Allows the scheduler to skip the Runge-Kutta steps defined in the original paper as being required before
            PLMS steps.
        set_alpha_to_one (`bool`, defaults to `False`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
            or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
            paper).
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
    """

    # Name of every member of `KarrasDiffusionSchedulers`.
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    skip_prk_steps: bool = False,
    set_alpha_to_one: bool = False,
    prediction_type: str = "epsilon",
    timestep_spacing: str = "leading",
    steps_offset: int = 0,
):
    """
    Build the beta schedule and reset the running values of the scheduler.

    `trained_betas`, when given, takes precedence over `beta_schedule`.

    Raises:
        NotImplementedError: if `beta_schedule` is not `linear`, `scaled_linear` or `squaredcos_cap_v2`
            (and no `trained_betas` were passed).
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # linear in sqrt(beta), then squared
        betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    elif beta_schedule == "squaredcos_cap_v2":
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
    self.betas = betas

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Alpha product used in place of a "previous" one when the previous timestep is negative.
    if set_alpha_to_one:
        self.final_alpha_cumprod = torch.tensor(1.0)
    else:
        self.final_alpha_cumprod = self.alphas_cumprod[0]

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    self.pndm_order = 4

    # running values
    self.cur_model_output = 0
    self.counter = 0
    self.cur_sample = None
    self.ets = []

    # values that start unset, plus the descending training timesteps
    self.num_inference_steps = None
    self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
    self.prk_timesteps = None
    self.plms_timesteps = None
    self.timesteps = None
def step(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise), and calls
    [`~PNDMScheduler.step_prk`] or [`~PNDMScheduler.step_plms`] depending on the internal variable `counter`.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.
    """
    # Runge-Kutta steps run while `counter` is still inside `prk_timesteps`, unless they are skipped by config.
    # (The listing had lost the leading `self` parameter, which this body requires.)
    use_prk = self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps
    handler = self.step_prk if use_prk else self.step_plms
    return handler(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
def step_prk(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
    the Runge-Kutta method. It performs four forward passes to approximate the solution to the differential
    equation.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        ValueError: if `num_inference_steps` is `None`.
    """
    if self.num_inference_steps is None:
        # Same wording as the identical check in `step_plms`.
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # Odd calls do not advance; even calls advance by half an inference stride.
    diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
    prev_timestep = timestep - diff_to_prev
    # Every group of four calls shares the timestep stored at the start of the group.
    timestep = self.prk_timesteps[self.counter // 4 * 4]

    phase = self.counter % 4
    if phase == 0:
        self.cur_model_output += 1 / 6 * model_output
        self.ets.append(model_output)
        self.cur_sample = sample
    elif phase in (1, 2):
        self.cur_model_output += 1 / 3 * model_output
    else:
        # Fourth call: finish the weighted sum (1/6, 1/3, 1/3, 1/6) and reset the accumulator.
        model_output = self.cur_model_output + 1 / 6 * model_output
        self.cur_model_output = 0

    # cur_sample should not be `None`
    cur_sample = self.cur_sample if self.cur_sample is not None else sample

    prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
    self.counter += 1

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
def step_plms(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
    the linear multistep method. It performs one forward pass multiple times to approximate the solution.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        ValueError: if `num_inference_steps` is `None`, or if PRK steps are not skipped and fewer than three
            model outputs have been stored in `ets`.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if not self.config.skip_prk_steps and len(self.ets) < 3:
        raise ValueError(
            f"{self.__class__} can only be run AFTER scheduler has been run "
            "in 'prk' mode for at least 12 iterations "
            "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
            "for more information."
        )

    stride = self.config.num_train_timesteps // self.num_inference_steps
    prev_timestep = timestep - stride

    if self.counter != 1:
        # Keep at most the three latest stored outputs, then store the new one.
        self.ets = self.ets[-3:]
        self.ets.append(model_output)
    else:
        # Second call: step again from one stride above `timestep` down to `timestep`.
        prev_timestep = timestep
        timestep = timestep + stride

    stored = len(self.ets)
    if stored == 1 and self.counter == 0:
        # First call: the raw model output is used; remember the sample for the second call.
        self.cur_sample = sample
    elif stored == 1 and self.counter == 1:
        model_output = (model_output + self.ets[-1]) / 2
        sample = self.cur_sample
        self.cur_sample = None
    elif stored == 2:
        model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
    elif stored == 3:
        model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
    else:
        model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

    prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
    self.counter += 1

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Return `sample` unchanged.

    The method exists so this scheduler can be used interchangeably with schedulers that need to scale the
    denoising model input depending on the current timestep; extra positional and keyword arguments are
    accepted and ignored.

    Args:
        sample (`torch.Tensor`):
            The input sample.

    Returns:
        `torch.Tensor`:
            The same `sample` object that was passed in.
    """
    unscaled = sample
    return unscaled
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
    """
    Compute the sample at `prev_timestep` from the sample at `timestep`.

    Args:
        sample: sample at `timestep`.
        timestep: index into `self.alphas_cumprod` for the current step.
        prev_timestep: index into `self.alphas_cumprod` for the target step; when negative,
            `self.final_alpha_cumprod` is used instead.
        model_output: model output; for `v_prediction` it is first converted with
            `sqrt(alpha_t) * model_output + sqrt(1 - alpha_t) * sample`.

    Returns:
        The previous sample.

    Raises:
        ValueError: if `self.config.prediction_type` is neither `epsilon` nor `v_prediction`.
    """
    alpha_t = self.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_prev = self.alphas_cumprod[prev_timestep]
    else:
        alpha_prev = self.final_alpha_cumprod
    beta_t = 1 - alpha_t
    beta_prev = 1 - alpha_prev

    if self.config.prediction_type == "v_prediction":
        model_output = (alpha_t**0.5) * model_output + (beta_t**0.5) * sample
    elif self.config.prediction_type != "epsilon":
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
        )

    # Coefficient of the current sample: sqrt(alpha_prev / alpha_t).
    sample_coeff = (alpha_prev / alpha_t) ** (0.5)

    # Denominator of the model-output term.
    denominator = alpha_t * beta_prev ** (0.5) + (alpha_t * beta_t * alpha_prev) ** (0.5)

    return sample_coeff * sample - (alpha_prev - alpha_t) * model_output / denominator
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Return `sqrt(alpha_cumprod[t]) * original_samples + sqrt(1 - alpha_cumprod[t]) * noise`.

    Args:
        original_samples (`torch.Tensor`): samples to be noised.
        noise (`torch.Tensor`): noise combined with the samples.
        timesteps (`torch.IntTensor`): indices into `self.alphas_cumprod`, one per sample.

    Returns:
        `torch.Tensor`: the noisy samples.
    """
    # Make sure alphas_cumprod and timesteps live on the same device as original_samples. The moved tensor is
    # stored back on the instance so later calls do not repeat the device transfer.
    # (The listing had lost the leading `self` parameter, which this body requires.)
    self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
    alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
    timesteps = timesteps.to(original_samples.device)

    def _per_sample(coefficient):
        # Flatten, then append trailing singleton axes until the rank matches `original_samples`.
        coefficient = coefficient.flatten()
        while len(coefficient.shape) < len(original_samples.shape):
            coefficient = coefficient.unsqueeze(-1)
        return coefficient

    sqrt_alpha_prod = _per_sample(alphas_cumprod[timesteps] ** 0.5)
    sqrt_one_minus_alpha_prod = _per_sample((1 - alphas_cumprod[timesteps]) ** 0.5)

    return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
def __len__(self):
    """Return `self.config.num_train_timesteps`."""
    config = self.config
    return config.num_train_timesteps
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
)
@flax.struct.dataclass
class PNDMSchedulerState:
    """State container passed to and returned from the methods of `FlaxPNDMScheduler`."""

    # Shared schedule data; the scheduler reads `common.alphas_cumprod` from it.
    common: CommonSchedulerState
    final_alpha_cumprod: jnp.ndarray

    # Required fields, supplied through `create`.
    init_noise_sigma: jnp.ndarray
    timesteps: jnp.ndarray

    # Optional fields; they stay `None` when the state is built through `create`.
    num_inference_steps: Optional[int] = None
    prk_timesteps: Optional[jnp.ndarray] = None
    plms_timesteps: Optional[jnp.ndarray] = None

    # Running values.
    cur_model_output: Optional[jnp.ndarray] = None
    counter: Optional[jnp.int32] = None
    cur_sample: Optional[jnp.ndarray] = None
    ets: Optional[jnp.ndarray] = None

    @classmethod
    def create(
        cls,
        common: CommonSchedulerState,
        final_alpha_cumprod: jnp.ndarray,
        init_noise_sigma: jnp.ndarray,
        timesteps: jnp.ndarray,
    ):
        """Build a state from the four required fields, leaving the optional ones at `None`."""
        required = {
            "common": common,
            "final_alpha_cumprod": final_alpha_cumprod,
            "init_noise_sigma": init_noise_sigma,
            "timesteps": timesteps,
        }
        return cls(**required)
@dataclass
class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput):
    """`FlaxSchedulerOutput` extended with the `PNDMSchedulerState` produced by a scheduler step."""

    state: PNDMSchedulerState
class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
    namely Runge-Kutta method and a linear multi-step method.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2202.09778

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose
            from `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`jnp.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        skip_prk_steps (`bool`):
            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being
            required before plms steps; defaults to `False`.
        set_alpha_to_one (`bool`, default `False`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the
            final step there is no previous alpha. When this option is `True` the previous alpha product is
            fixed to `1`, otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps, as required by some model families.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf). Note that `_get_prev_sample` in this class only
            accepts `epsilon` and `v_prediction` and raises `ValueError` otherwise.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            the `dtype` used for params and computation.
    """

    # Name of every member of `FlaxKarrasDiffusionSchedulers`.
    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

    # Computation dtype; assigned in `__init__`.
    dtype: jnp.dtype
    # Order of the PNDM method; assigned in `__init__`.
    pndm_order: int

    @property
    def has_state(self):
        """Always `True`."""
        return True
# Constructor of `FlaxPNDMScheduler` (decorated with `register_to_config`).
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[jnp.ndarray] = None,
    skip_prk_steps: bool = False,
    set_alpha_to_one: bool = False,
    steps_offset: int = 0,
    prediction_type: str = "epsilon",
    dtype: jnp.dtype = jnp.float32,
):
    """
    Store the computation dtype and the PNDM order on the instance.

    Only `dtype` is used in this body. The other methods of the class read values such as
    `num_train_timesteps`, `skip_prk_steps`, `set_alpha_to_one`, `steps_offset` and `prediction_type` through
    `self.config` — presumably populated from these arguments by `register_to_config`; confirm against
    `ConfigMixin`.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`): name of the beta schedule.
        trained_betas (`jnp.ndarray`, optional): explicit beta array.
        skip_prk_steps (`bool`): whether the Runge-Kutta steps are skipped.
        set_alpha_to_one (`bool`): whether the final previous alpha product is fixed to 1.
        steps_offset (`int`): offset added to the inference steps.
        prediction_type (`str`): prediction type of the scheduler function.
        dtype (`jnp.dtype`): dtype used for parameters and computation.
    """
    # The listing had lost the leading `self` parameter, which these assignments require.
    self.dtype = dtype

    # For now we only support F-PNDM, i.e. the runge-kutta method
    # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
    # mainly at formula (9), (12), (13) and the Algorithm 2.
    self.pndm_order = 4
# Build the initial `PNDMSchedulerState`.
def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState:
    """
    Create the scheduler state.

    Args:
        common (`CommonSchedulerState`, optional):
            shared schedule state; when `None`, one is created with `CommonSchedulerState.create(self)`.

    Returns:
        `PNDMSchedulerState`: state holding `final_alpha_cumprod`, an `init_noise_sigma` of 1.0 and the
        descending training timesteps.
    """
    shared = CommonSchedulerState.create(self) if common is None else common

    # At every step in ddim, we are looking into the previous alphas_cumprod
    # For the final step, there is no previous alphas_cumprod because we are already at 0
    # `set_alpha_to_one` decides whether we set this parameter simply to one or
    # whether we use the final alpha of the "non-previous" one.
    if self.config.set_alpha_to_one:
        final_alpha_cumprod = jnp.array(1.0, dtype=self.dtype)
    else:
        final_alpha_cumprod = shared.alphas_cumprod[0]

    # num_train_timesteps - 1, ..., 1, 0
    descending_steps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

    return PNDMSchedulerState.create(
        common=shared,
        final_alpha_cumprod=final_alpha_cumprod,
        # standard deviation of the initial noise distribution
        init_noise_sigma=jnp.array(1.0, dtype=self.dtype),
        timesteps=descending_steps,
    )
# Set the discrete timesteps used for the diffusion chain.
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState:
    """
    Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

    Args:
        state (`PNDMSchedulerState`):
            the `FlaxPNDMScheduler` state data class instance.
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples.
        shape (`Tuple`):
            the shape of the samples to be generated.

    Returns:
        `PNDMSchedulerState`: copy of `state` with the timestep arrays replaced and the running values
        (`cur_model_output`, `counter`, `cur_sample`, `ets`) reset to zeros.
    """
    stride = self.config.num_train_timesteps // num_inference_steps

    # Integer timesteps 0, stride, 2 * stride, ... shifted by the configured offset.
    _timesteps = (jnp.arange(0, num_inference_steps) * stride).round() + self.config.steps_offset

    if self.config.skip_prk_steps:
        # No Runge-Kutta timesteps; the PLMS list repeats the second-to-last timestep once and is reversed
        # into descending order.
        prk_timesteps = jnp.array([], dtype=jnp.int32)
        plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1]
    else:
        # Each of the last `pndm_order` timesteps appears twice, the second copy shifted by half a stride.
        half_stride_offsets = jnp.tile(jnp.array([0, stride // 2], dtype=jnp.int32), self.pndm_order)
        prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + half_stride_offsets

        # Duplicate every entry but the last, trim both ends, reverse into descending order.
        prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1]
        plms_timesteps = _timesteps[:-3][::-1]

    timesteps = jnp.concatenate([prk_timesteps, plms_timesteps])

    # Running values all start at zero; `ets` has four slots for stored model outputs.
    return state.replace(
        timesteps=timesteps,
        num_inference_steps=num_inference_steps,
        prk_timesteps=prk_timesteps,
        plms_timesteps=plms_timesteps,
        cur_model_output=jnp.zeros(shape, dtype=self.dtype),
        counter=jnp.int32(0),
        cur_sample=jnp.zeros(shape, dtype=self.dtype),
        ets=jnp.zeros((4,) + shape, dtype=self.dtype),
    )
# Model-input scaling hook (identity for this scheduler).
def scale_model_input(
    self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
    """
    Return `sample` unchanged.

    The method exists so this scheduler can be used interchangeably with schedulers that need to scale the
    denoising model input depending on the current timestep.

    Args:
        state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance (not used).
        sample (`jnp.ndarray`): input sample
        timestep (`int`, optional): current timestep (not used)

    Returns:
        `jnp.ndarray`: the same `sample` that was passed in
    """
    unscaled = sample
    return unscaled
def step(
    self,
    state: PNDMSchedulerState,
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
    return_dict: bool = True,
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.

    Args:
        state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.
        return_dict (`bool`): when `False`, return a tuple instead of a `FlaxPNDMSchedulerOutput`.

    Returns:
        [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise
        the tuple `(prev_sample, state)`.

    Raises:
        ValueError: if `state.num_inference_steps` is `None`.
    """
    if state.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.config.skip_prk_steps:
        prev_sample, state = self.step_plms(state, model_output, timestep, sample)
    else:
        # Both candidates are computed from the same incoming state; `jax.lax.select` then keeps the
        # Runge-Kutta result while `counter` is still inside `prk_timesteps`, the PLMS result afterwards.
        rk_sample, rk_state = self.step_prk(state, model_output, timestep, sample)
        lms_sample, lms_state = self.step_plms(state, model_output, timestep, sample)

        use_rk = state.counter < len(state.prk_timesteps)

        def pick(field):
            return jax.lax.select(use_rk, getattr(rk_state, field), getattr(lms_state, field))

        prev_sample = jax.lax.select(use_rk, rk_sample, lms_sample)
        state = state.replace(
            cur_model_output=pick("cur_model_output"),
            ets=pick("ets"),
            cur_sample=pick("cur_sample"),
            counter=pick("counter"),
        )

    if not return_dict:
        return (prev_sample, state)

    return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state)
# Runge-Kutta step of `FlaxPNDMScheduler`.
def step_prk(
    self,
    state: PNDMSchedulerState,
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
) -> Tuple[jnp.ndarray, PNDMSchedulerState]:
    """
    Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
    solution to the differential equation.

    Args:
        state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.

    Returns:
        `tuple`: `(prev_sample, state)` — the propagated sample and the updated state. This method takes no
        `return_dict` argument and never wraps its result in `FlaxPNDMSchedulerOutput`.

    Raises:
        ValueError: if `state.num_inference_steps` is `None`.
    """
    if state.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # Odd calls do not advance; even calls advance by half an inference stride.
    diff_to_prev = jnp.where(
        state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
    )
    prev_timestep = timestep - diff_to_prev
    # Every group of four calls shares the timestep stored at the start of the group.
    timestep = state.prk_timesteps[state.counter // 4 * 4]

    # On the fourth call of a group the accumulated weighted sum replaces the raw model output.
    model_output = jax.lax.select(
        (state.counter % 4) != 3,
        model_output,  # remainder 0, 1, 2
        state.cur_model_output + 1 / 6 * model_output,  # remainder 3
    )

    state = state.replace(
        # Accumulate with weights 1/6, 1/3, 1/3, then reset after the fourth call.
        cur_model_output=jax.lax.select_n(
            state.counter % 4,
            state.cur_model_output + 1 / 6 * model_output,  # remainder 0
            state.cur_model_output + 1 / 3 * model_output,  # remainder 1
            state.cur_model_output + 1 / 3 * model_output,  # remainder 2
            jnp.zeros_like(state.cur_model_output),  # remainder 3
        ),
        # On the first call of a group, shift the stored outputs by one slot and store the new one last.
        ets=jax.lax.select(
            (state.counter % 4) == 0,
            state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output),  # remainder 0
            state.ets,  # remainder 1, 2, 3
        ),
        # The sample seen on the first call of a group is reused for the whole group.
        cur_sample=jax.lax.select(
            (state.counter % 4) == 0,
            sample,  # remainder 0
            state.cur_sample,  # remainder 1, 2, 3
        ),
    )

    cur_sample = state.cur_sample
    prev_sample = self._get_prev_sample(state, cur_sample, timestep, prev_timestep, model_output)
    state = state.replace(counter=state.counter + 1)

    return (prev_sample, state)
# Signature of `step_plms`; its parameters are the state, the model output, the timestep and the sample.
# NOTE(review): this listing stops after the parameter list — the closing parenthesis, the return annotation
# and the whole body of `step_plms` are absent here. They must be restored from the original module before
# this code can run; nothing about the body can be inferred from what is shown.
def step_plms(
self,
state: PNDMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
# Compute the previous sample with formula (9) of the PNDM paper.
def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_timestep, model_output):
    """
    Compute x_(t−δ) from x_t using formula (9) of the PNDM paper (x_t is added to both sides of the equation).

    Notation (<variable name> -> <name in paper>):
        alpha_t        -> α_t
        alpha_prev     -> α_(t−δ)
        beta_t         -> (1 - α_t)
        beta_prev      -> (1 - α_(t−δ))
        sample         -> x_t
        model_output   -> e_θ(x_t, t)
        return value   -> x_(t−δ)

    Args:
        state (`PNDMSchedulerState`): provides `common.alphas_cumprod` and `final_alpha_cumprod`.
        sample: sample at `timestep`.
        timestep: index into `alphas_cumprod` for the current step.
        prev_timestep: index for the target step; when negative, `state.final_alpha_cumprod` is used.
        model_output: model output; for `v_prediction` it is first converted with
            `sqrt(α_t) * model_output + sqrt(1 - α_t) * sample`.

    Raises:
        ValueError: if `self.config.prediction_type` is neither `epsilon` nor `v_prediction`.
    """
    cumulative_alphas = state.common.alphas_cumprod
    alpha_t = cumulative_alphas[timestep]
    alpha_prev = jnp.where(prev_timestep >= 0, cumulative_alphas[prev_timestep], state.final_alpha_cumprod)
    beta_t = 1 - alpha_t
    beta_prev = 1 - alpha_prev

    if self.config.prediction_type == "v_prediction":
        model_output = (alpha_t**0.5) * model_output + (beta_t**0.5) * sample
    elif self.config.prediction_type != "epsilon":
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
        )

    # Coefficient of x_t: the fraction in formula (9) plus 1. It simplifies because
    # (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) = sqrt(α_(t−δ)) / sqrt(α_t) - 1.
    sample_coeff = (alpha_prev / alpha_t) ** (0.5)

    # Denominator of the e_θ(x_t, t) term in formula (9).
    denominator = alpha_t * beta_prev ** (0.5) + (alpha_t * beta_t * alpha_prev) ** (0.5)

    # Full formula (9).
    return sample_coeff * sample - (alpha_prev - alpha_t) * model_output / denominator
# Add noise to samples.
def add_noise(
    self,
    state: PNDMSchedulerState,
    original_samples: jnp.ndarray,
    noise: jnp.ndarray,
    timesteps: jnp.ndarray,
) -> jnp.ndarray:
    """Delegate to `add_noise_common`, passing `state.common` together with the samples, noise and timesteps."""
    shared = state.common
    return add_noise_common(shared, original_samples, noise, timesteps)
# Length of the scheduler.
def __len__(self):
    """Return `self.config.num_train_timesteps`."""
    config = self.config
    return config.num_train_timesteps
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· 三行代码完成国际化适配,妙~啊~
· .NET Core 中如何实现缓存的预热?