CogVideo & CogVideoX 微调代码源码解析(十)
.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\sampling.py
"""
# 从类型提示模块导入字典和联合类型
from typing import Dict, Union
# 导入 PyTorch 库
import torch
# 从 OmegaConf 导入列表配置和 OmegaConf 类
from omegaconf import ListConfig, OmegaConf
# 从 tqdm 导入进度条功能
from tqdm import tqdm
# 从自定义模块导入必要的函数
from ...modules.diffusionmodules.sampling_utils import (
get_ancestral_step, # 获取祖先步骤的函数
linear_multistep_coeff, # 线性多步骤系数的函数
to_d, # 转换到 d 的函数
to_neg_log_sigma, # 转换为负对数 sigma 的函数
to_sigma, # 转换为 sigma 的函数
)
# 从 util 模块导入辅助函数
from ...util import append_dims, default, instantiate_from_config
from ...util import SeededNoise # 导入带种子的噪声生成器
# 从 guiders 模块导入动态 CFG 类
from .guiders import DynamicCFG
# 默认引导器配置
DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
# 定义基础扩散采样器类
class BaseDiffusionSampler:
# 初始化方法,设置相关配置
def __init__(
self,
discretization_config: Union[Dict, ListConfig, OmegaConf], # 离散化配置
num_steps: Union[int, None] = None, # 步骤数
guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, # 引导器配置
verbose: bool = False, # 是否详细输出
device: str = "cuda", # 使用的设备
):
self.num_steps = num_steps # 保存步骤数
# 根据配置实例化离散化对象
self.discretization = instantiate_from_config(discretization_config)
# 根据配置实例化引导器对象,使用默认值如果未提供
self.guider = instantiate_from_config(
default(
guider_config,
DEFAULT_GUIDER,
)
)
self.verbose = verbose # 设置详细输出标志
self.device = device # 设置设备
# 准备采样循环的方法
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
# 计算 sigma 值
sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device)
# 默认情况下使用条件输入
uc = default(uc, cond)
# 对输入进行缩放
x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
num_sigmas = len(sigmas) # 获取 sigma 数量
# 创建与输入样本数相同的全 1 张量
s_in = x.new_ones([x.shape[0]]).float()
# 返回准备好的参数
return x, s_in, sigmas, num_sigmas, cond, uc
# 去噪声的方法
def denoise(self, x, denoiser, sigma, cond, uc):
# 准备输入并进行去噪声处理
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
# 使用引导器进一步处理去噪声结果
denoised = self.guider(denoised, sigma)
# 返回去噪声后的结果
return denoised
# 获取 sigma 生成器的方法
def get_sigma_gen(self, num_sigmas):
sigma_generator = range(num_sigmas - 1) # 创建 sigma 生成器范围
if self.verbose: # 如果启用详细输出
print("#" * 30, " Sampling setting ", "#" * 30) # 输出分隔符
print(f"Sampler: {self.__class__.__name__}") # 输出采样器类名
print(f"Discretization: {self.discretization.__class__.__name__}") # 输出离散化类名
print(f"Guider: {self.guider.__class__.__name__}") # 输出引导器类名
# 包装 sigma 生成器为 tqdm 对象,以便显示进度条
sigma_generator = tqdm(
sigma_generator,
total=num_sigmas,
desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
)
# 返回 sigma 生成器
return sigma_generator
# 定义单步扩散采样器类,继承自基础扩散采样器
class SingleStepDiffusionSampler(BaseDiffusionSampler):
# 定义采样步骤的抽象方法
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
raise NotImplementedError # 抛出未实现错误
# 欧拉步骤的方法
def euler_step(self, x, d, dt):
# 计算下一个状态
return x + dt * d
# 定义 EDM 采样器类,继承自单步扩散采样器
class EDMSampler(SingleStepDiffusionSampler):
# 初始化方法,设置相关参数
def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
super().__init__(*args, **kwargs) # 调用父类初始化
self.s_churn = s_churn # 设置 churn 参数
self.s_tmin = s_tmin # 设置最小时间
self.s_tmax = s_tmax # 设置最大时间
self.s_noise = s_noise # 设置噪声参数
# 定义sampler_step函数,用于执行采样步骤
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
# 计算sigma_hat,用于计算噪声
sigma_hat = sigma * (gamma + 1.0)
# 如果gamma大于0,生成服从标准正态分布的随机数eps,并将其乘以s_noise,再加到x上
if gamma > 0:
eps = torch.randn_like(x) * self.s_noise
x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
# 使用denoiser对x进行去噪,得到denoised
denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
# 计算d,用于后续步骤
d = to_d(x, sigma_hat, denoised)
# 计算dt,用于后续步骤
dt = append_dims(next_sigma - sigma_hat, x.ndim)
# 使用欧拉步骤更新x
euler_step = self.euler_step(x, d, dt)
# 使用可能的修正步骤更新x
x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
# 返回更新后的x
return x
# 定义__call__函数,用于执行整个采样过程
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
# 准备采样循环所需的变量
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
# 遍历sigmas,执行采样步骤
for i in self.get_sigma_gen(num_sigmas):
# 计算gamma,用于控制噪声的大小
gamma = (
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
)
# 执行采样步骤
x = self.sampler_step(
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc,
gamma,
)
# 返回最终的x
return x
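上面的 EDMSampler 把"按 gamma 注入额外噪声(churn)→ 去噪 → 欧拉步"封装在 sampler_step 里,__call__ 只负责按 sigma 序列循环。下面给出一个不依赖本仓库的简化示意,用一个假设的玩具去噪器 toy_denoiser 和假设的 sigma 序列演示同样的更新流程;这只是说明接口与公式的草图,并非 CogVideoX 的真实模型或噪声表。

import torch

# 简化示意(非仓库源码):玩具去噪器 + EDM 的 churn / 欧拉更新
def toy_denoiser(x, sigma):
    # 演示用假设:把 x 向 0 收缩,仅用于说明数据流
    return x / (1.0 + sigma**2)

def edm_euler_step(x, sigma, next_sigma, gamma=0.0, s_noise=1.0):
    sigma_hat = sigma * (gamma + 1.0)              # 对应 sampler_step 中的 sigma_hat
    if gamma > 0:
        eps = torch.randn_like(x) * s_noise
        x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5   # 先把噪声水平抬高到 sigma_hat
    denoised = toy_denoiser(x, sigma_hat)
    d = (x - denoised) / sigma_hat                 # 等价于 to_d(x, sigma_hat, denoised)
    dt = next_sigma - sigma_hat
    return x + dt * d                              # 欧拉步;EulerEDMSampler 不再做二次修正

sigmas = torch.linspace(10.0, 0.0, steps=11)       # 演示用假设的 sigma 序列
x = torch.randn(4) * sigmas[0]
for i in range(len(sigmas) - 1):
    x = edm_euler_step(x, sigmas[i], sigmas[i + 1], gamma=0.1 if i < 5 else 0.0)
print(x)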
# 定义 DDIMSampler 类,继承自 SingleStepDiffusionSampler 类
class DDIMSampler(SingleStepDiffusionSampler):
# 初始化方法,接受 s_noise 参数,默认值为 0.1
def __init__(self, s_noise=0.1, *args, **kwargs):
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 设置实例属性 s_noise 为传入的 s_noise 参数值
self.s_noise = s_noise
# sampler_step 方法,接受 sigma、next_sigma、denoiser、x、cond、uc 和 s_noise 参数
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):
# 使用 denoiser 对 x 进行去噪,得到 denoised
denoised = self.denoise(x, denoiser, sigma, cond, uc)
# 计算 d,使用 to_d 函数
d = to_d(x, sigma, denoised)
# 计算 dt
dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim)
# 计算 euler_step
euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)
# 调用 possible_correction_step 方法,得到 x
x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
# 返回 x
return x
# __call__ 方法,接受 denoiser、x、cond、uc 和 num_steps 参数
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
# 调用 prepare_sampling_loop 方法,得到 x、s_in、sigmas、num_sigmas、cond 和 uc
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
# 遍历 sigmas
for i in self.get_sigma_gen(num_sigmas):
# 调用 sampler_step 方法,得到 x
x = self.sampler_step(
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc,
self.s_noise,
)
# 返回 x
return x
# 定义 AncestralSampler 类,继承自 SingleStepDiffusionSampler 类
class AncestralSampler(SingleStepDiffusionSampler):
# 初始化方法,接受 eta 和 s_noise 参数,默认值分别为 1.0 和 1.0
def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 设置实例属性 eta 为传入的 eta 参数值
self.eta = eta
# 设置实例属性 s_noise 为传入的 s_noise 参数值
self.s_noise = s_noise
# 设置实例属性 noise_sampler 为一个 lambda 函数,用于生成噪声
self.noise_sampler = lambda x: torch.randn_like(x)
# ancestral_euler_step 方法,接受 x、denoised、sigma 和 sigma_down 参数
def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
# 计算 d,使用 to_d 函数
d = to_d(x, sigma, denoised)
# 计算 dt
dt = append_dims(sigma_down - sigma, x.ndim)
# 调用 euler_step 方法,得到结果
return self.euler_step(x, d, dt)
# ancestral_step 方法,接受 x、sigma、next_sigma 和 sigma_up 参数
def ancestral_step(self, x, sigma, next_sigma, sigma_up):
# 根据条件进行赋值操作
x = torch.where(
append_dims(next_sigma, x.ndim) > 0.0,
x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
x,
)
# 返回结果
return x
# __call__ 方法,接受 denoiser、x、cond、uc 和 num_steps 参数
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
# 调用 prepare_sampling_loop 方法,得到 x、s_in、sigmas、num_sigmas、cond 和 uc
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
# 遍历 sigmas
for i in self.get_sigma_gen(num_sigmas):
# 调用 sampler_step 方法,得到 x
x = self.sampler_step(
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc,
)
# 返回 x
return x
# 定义 LinearMultistepSampler 类,继承自 BaseDiffusionSampler 类
class LinearMultistepSampler(BaseDiffusionSampler):
# 初始化方法,接受 order 参数,默认值为 4
def __init__(
self,
order=4,
*args,
**kwargs,
):
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 设置实例属性 order 为传入的 order 参数值
self.order = order
# 定义可调用方法,用于执行去噪操作
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
# 准备采样循环,返回处理后的输入、sigma、以及条件信息
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
# 初始化导数 d 的历史列表(供线性多步法使用)
ds = []
# 将 sigma 从计算图中分离并转移到 CPU,然后转换为 NumPy 数组
sigmas_cpu = sigmas.detach().cpu().numpy()
# 遍历生成的 sigma 值
for i in self.get_sigma_gen(num_sigmas):
# 计算当前 sigma
sigma = s_in * sigmas[i]
# 使用去噪器处理输入数据,获取去噪结果
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs)
# 进一步处理去噪结果
denoised = self.guider(denoised, sigma)
# 由当前输入、sigma 和去噪结果计算导数 d
d = to_d(x, sigma, denoised)
# 将导数 d 追加到历史列表
ds.append(d)
# 如果历史列表长度超过阶数 order,则移除最旧的一项
if len(ds) > self.order:
ds.pop(0)
# 计算当前可用的阶数(历史不足时自动降阶)
cur_order = min(i + 1, self.order)
# 计算线性多步法的系数
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
# 用线性多步系数对历史导数加权求和,更新 x
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
# 返回最终的去噪结果
return x
# 定义一个基于 Euler 方法的 EDM 采样器类,继承自 EDMSampler
class EulerEDMSampler(EDMSampler):
# 可能的修正步骤,接受多个参数
def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
# 直接返回 euler_step,没有进行任何修正
return euler_step
# 定义一个基于 Heun 方法的 EDM 采样器类,继承自 EDMSampler
class HeunEDMSampler(EDMSampler):
# 可能的修正步骤,接受多个参数
def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
# 检查 next_sigma 的总和是否小于 1e-14
if torch.sum(next_sigma) < 1e-14:
# 如果所有噪声水平为 0,则返回 euler_step,避免网络评估
return euler_step
else:
# 使用 denoiser 去噪 euler_step
denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
# 将 euler_step 和去噪结果转换为新的 d 值
d_new = to_d(euler_step, next_sigma, denoised)
# 计算 d 的新值,取 d 和 d_new 的平均
d_prime = (d + d_new) / 2.0
# 如果噪声水平不为 0,则应用修正
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)
# 返回修正后的 x
return x
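HeunEDMSampler 与 EulerEDMSampler 的区别只在 possible_correction_step:在欧拉步的结果处于 next_sigma 再评估一次斜率,并与原斜率取平均。下面用一个可解析求解的标量例子做简化对比,其中 toy_denoised 为演示用假设的去噪函数:

import torch

# 简化示意(非仓库源码):比较一步 Euler 与一步 Heun 的精度
def toy_denoised(x, sigma):
    return x / (1.0 + sigma**2)                    # 演示用假设的去噪结果

def d_fn(x, sigma):
    return (x - toy_denoised(x, sigma)) / sigma    # to_d 的标量版本

x, sigma, next_sigma = torch.tensor(2.0), torch.tensor(1.0), torch.tensor(0.5)
dt = next_sigma - sigma

d = d_fn(x, sigma)
x_euler = x + dt * d                               # 一阶欧拉步
d_new = d_fn(x_euler, next_sigma)                  # 在欧拉结果处再评估一次斜率
x_heun = x + dt * (d + d_new) / 2.0                # Heun:两个斜率取平均

# 该玩具 ODE 的精确解为 x * sqrt((1 + next_sigma^2) / (1 + sigma^2))
x_exact = x * ((1 + next_sigma**2) / (1 + sigma**2)) ** 0.5
print(float(x_euler), float(x_heun), float(x_exact))   # Heun 的结果更接近精确值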
# 定义一个基于 Euler 的祖先采样器类,继承自 AncestralSampler
class EulerAncestralSampler(AncestralSampler):
# 进行采样步骤,接受多个参数
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
# 计算祖先采样所需的 sigma_down 与 sigma_up
sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
# 去噪 x
denoised = self.denoise(x, denoiser, sigma, cond, uc)
# 进行 Euler 采样步骤
x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
# 进行祖先步骤
x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
# 返回最终的 x
return x
# 定义一个基于 DPM++ 的祖先采样器类,继承自 AncestralSampler
class DPMPP2SAncestralSampler(AncestralSampler):
# 获取变量,接受 sigma 和 sigma_down
def get_variables(self, sigma, sigma_down):
# 将 sigma 和 sigma_down 转换为负对数
t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
# 计算时间差 h
h = t_next - t
# 计算 s 值
s = t + 0.5 * h
# 返回 h, s, t, t_next
return h, s, t, t_next
# 计算乘数,接受 h, s, t, t_next
def get_mult(self, h, s, t, t_next):
# 计算多个乘数
mult1 = to_sigma(s) / to_sigma(t)
mult2 = (-0.5 * h).expm1()
mult3 = to_sigma(t_next) / to_sigma(t)
mult4 = (-h).expm1()
# 返回所有计算出的乘数
return mult1, mult2, mult3, mult4
# 进行采样步骤,接受多个参数
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
# 计算祖先采样所需的 sigma_down 与 sigma_up
sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
# 去噪 x
denoised = self.denoise(x, denoiser, sigma, cond, uc)
# 进行 Euler 采样步骤
x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
# 检查 sigma_down 的总和是否小于 1e-14
if torch.sum(sigma_down) < 1e-14:
# 如果所有噪声水平为 0,则返回 x_euler,避免网络评估
x = x_euler
else:
# 获取变量
h, s, t, t_next = self.get_variables(sigma, sigma_down)
# 计算各乘数并调整维度以匹配 x
mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)]
# 计算新的 x2
x2 = mult[0] * x - mult[1] * denoised
# 对 x2 进行去噪
denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
# 计算最终的 x_dpmpp2s
x_dpmpp2s = mult[2] * x - mult[3] * denoised2
# 如果噪声水平不为 0,则应用修正
x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
# 进行祖先步骤
x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
# 返回最终的 x
return x
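祖先(ancestral)类采样器的共同点是:先用 get_ancestral_step 把 next_sigma 拆成确定性的 sigma_down 和随机的 sigma_up,欧拉步只走到 sigma_down,再补回方差为 sigma_up^2 的噪声。下面的独立小例子(公式从 sampling_utils 中复制,以便自包含)验证两者满足 sigma_down^2 + sigma_up^2 = next_sigma^2:

import torch

# 简化示意(非仓库源码):验证祖先步对噪声方差的拆分
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    # 与 sampling_utils.get_ancestral_step 相同的公式,复制到此处使示例自包含
    if not eta:
        return sigma_to, 0.0
    sigma_up = torch.minimum(
        sigma_to, eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    )
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up

sigma_from, sigma_to = torch.tensor(2.0), torch.tensor(1.0)
sigma_down, sigma_up = get_ancestral_step(sigma_from, sigma_to, eta=1.0)
print(float(sigma_down), float(sigma_up))
print(float(sigma_down**2 + sigma_up**2), float(sigma_to**2))   # 两者相等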
# 定义一个基于 DPM++ 的采样器类,继承自 BaseDiffusionSampler
class DPMPP2MSampler(BaseDiffusionSampler):
# 定义一个获取变量的函数,接受当前和下一个噪声级别,以及可选的上一个噪声级别
def get_variables(self, sigma, next_sigma, previous_sigma=None):
# 将 sigma 和 next_sigma 转换为负对数形式
t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
# 计算两个时间点之间的差值
h = t_next - t
# 如果上一个噪声级别存在
if previous_sigma is not None:
# 计算当前时间与上一个时间的差值
h_last = t - to_neg_log_sigma(previous_sigma)
# 计算当前和上一个时间差的比值
r = h_last / h
# 返回差值 h、比值 r、当前和下一个时间
return h, r, t, t_next
else:
# 返回差值 h 和当前、下一个时间,但不返回比值 r
return h, None, t, t_next
# 定义一个获取乘数的函数,接受多个参数
def get_mult(self, h, r, t, t_next, previous_sigma):
# 计算当前和下一个时间的 sigma 乘数
mult1 = to_sigma(t_next) / to_sigma(t)
# 计算 h 的负值的指数减一
mult2 = (-h).expm1()
# 如果上一个噪声级别存在
if previous_sigma is not None:
# 计算与 r 相关的乘数
mult3 = 1 + 1 / (2 * r)
mult4 = 1 / (2 * r)
# 返回所有乘数
return mult1, mult2, mult3, mult4
else:
# 返回前两个乘数
return mult1, mult2
# 定义采样步骤函数,接受多个参数
def sampler_step(
self,
old_denoised,
previous_sigma,
sigma,
next_sigma,
denoiser,
x,
cond,
uc=None,
):
# 使用去噪器对输入 x 进行去噪
denoised = self.denoise(x, denoiser, sigma, cond, uc)
# 获取变量 h、r、t、t_next
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
# 获取乘数,并将维度调整以匹配 x 的维度
mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
# 计算标准化后的 x
x_standard = mult[0] * x - mult[1] * denoised
# 如果没有旧的去噪结果或下一个噪声级别接近零
if old_denoised is None or torch.sum(next_sigma) < 1e-14:
# 所有噪声级别为 0 或处于第一步时直接返回,可省去一次网络评估
return x_standard, denoised
else:
# 将当前去噪结果与上一步的去噪结果做线性外推
denoised_d = mult[2] * denoised - mult[3] * old_denoised
# 计算二阶(多步)更新得到的 x_advanced
x_advanced = mult[0] * x - mult[1] * denoised_d
# 如果噪声级别不为 0 且不是第一步,则应用修正
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
# 返回最终的 x 和去噪结果
return x, denoised
# 定义调用函数,接受去噪器、输入 x、条件和其他可选参数
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
# 准备采样循环的参数,包括对输入 x、条件、噪声级别等的处理
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
# 初始化旧去噪结果
old_denoised = None
# 遍历生成的噪声级别
for i in self.get_sigma_gen(num_sigmas):
# 执行采样步骤,更新 x 和旧去噪结果
x, old_denoised = self.sampler_step(
old_denoised,
None if i == 0 else s_in * sigmas[i - 1],
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc=uc,
)
# 返回最终的 x
return x
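DPMPP2MSampler 的关键在于:在半对数噪声尺度(t = -log sigma)上取步长 h,并利用上一步的去噪结果 old_denoised 做二阶外推(系数 mult3、mult4 由步长比 r 决定)。下面用标量演示 get_variables / get_mult 以及一阶与二阶更新的差别,其中 sigma 与 denoised 的数值均为演示用假设:

import torch

# 简化示意(非仓库源码):DPM++(2M) 的乘数与二阶外推,标量版本
def to_neg_log_sigma(s):
    return s.log().neg()

def to_sigma(t):
    return t.neg().exp()

previous_sigma, sigma, next_sigma = torch.tensor(4.0), torch.tensor(2.0), torch.tensor(1.0)

t, t_next = to_neg_log_sigma(sigma), to_neg_log_sigma(next_sigma)
h = t_next - t                                    # 当前步在对数尺度上的步长
h_last = t - to_neg_log_sigma(previous_sigma)
r = h_last / h                                    # 上一步与当前步长之比

mult1 = to_sigma(t_next) / to_sigma(t)
mult2 = (-h).expm1()
mult3 = 1 + 1 / (2 * r)                           # 二阶外推系数
mult4 = 1 / (2 * r)

denoised, old_denoised = torch.tensor(0.3), torch.tensor(0.5)   # 演示用假设值
denoised_d = mult3 * denoised - mult4 * old_denoised            # 两次去噪结果的线性外推

x = torch.tensor(1.7)
x_standard = mult1 * x - mult2 * denoised         # 一阶更新(首步或末步使用)
x_advanced = mult1 * x - mult2 * denoised_d       # 二阶(多步)更新
print(float(x_standard), float(x_advanced))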
# 定义 SDEDPMPP2MSampler 类,继承自 BaseDiffusionSampler
class SDEDPMPP2MSampler(BaseDiffusionSampler):
# 获取变量 h、r 和时间参数 t、t_next
def get_variables(self, sigma, next_sigma, previous_sigma=None):
# 将 sigma 和 next_sigma 转换为负对数形式
t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
# 计算 h 为 t_next 和 t 的差值
h = t_next - t
# 如果 previous_sigma 不为 None
if previous_sigma is not None:
# 计算上一个 sigma 的负对数值
h_last = t - to_neg_log_sigma(previous_sigma)
# 计算 r 为 h_last 和 h 的比值
r = h_last / h
# 返回 h、r、t 和 t_next
return h, r, t, t_next
else:
# 返回 h 和 None(无 r),以及 t 和 t_next
return h, None, t, t_next
# 计算乘数值
def get_mult(self, h, r, t, t_next, previous_sigma):
# 计算 mult1 为 t_next 和 t 的 sigma 比值乘以 h 的负指数
mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp()
# 计算 mult2 为 (-2*h) 的 expm1 值
mult2 = (-2 * h).expm1()
# 如果 previous_sigma 不为 None
if previous_sigma is not None:
# 计算 mult3 为 1 + 1/(2*r)
mult3 = 1 + 1 / (2 * r)
# 计算 mult4 为 1/(2*r)
mult4 = 1 / (2 * r)
# 返回 mult1、mult2、mult3 和 mult4
return mult1, mult2, mult3, mult4
else:
# 返回 mult1 和 mult2
return mult1, mult2
# 执行采样步骤
def sampler_step(
self,
old_denoised,
previous_sigma,
sigma,
next_sigma,
denoiser,
x,
cond,
uc=None,
):
# 使用 denoiser 对 x 进行去噪处理
denoised = self.denoise(x, denoiser, sigma, cond, uc)
# 获取 h、r、t 和 t_next
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
# 计算乘数,并调整维度
mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
# 计算噪声乘数并调整维度
mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
# 计算标准化后的 x
x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
# 如果 old_denoised 为 None 或 next_sigma 的和小于 1e-14
if old_denoised is None or torch.sum(next_sigma) < 1e-14:
# 返回标准化后的 x 和去噪后的结果
return x_standard, denoised
else:
# 将当前去噪结果与上一步的去噪结果做线性外推
denoised_d = mult[2] * denoised - mult[3] * old_denoised
# 计算二阶(多步)更新得到的 x_advanced
x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)
# 如果噪声水平不为 0 且不是第一步,应用修正
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
# 返回最终的 x 和去噪后的结果
return x, denoised
# 调用采样器,执行采样循环
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
# 准备采样循环,初始化输入和 sigma
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
# 初始化 old_denoised 为 None
old_denoised = None
# 遍历 sigma 生成器
for i in self.get_sigma_gen(num_sigmas):
# 执行采样步骤并更新 x 和 old_denoised
x, old_denoised = self.sampler_step(
old_denoised,
None if i == 0 else s_in * sigmas[i - 1],
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc=uc,
)
# 返回最终的 x
return x
# 定义 SdeditEDMSampler 类,继承自 EulerEDMSampler
class SdeditEDMSampler(EulerEDMSampler):
# 初始化函数,设置编辑比例
def __init__(self, edit_ratio=0.5, *args, **kwargs):
# 调用父类的初始化函数
super().__init__(*args, **kwargs)
# 设置编辑比例
self.edit_ratio = edit_ratio
# 定义一个可调用的方法,接受多个参数进行图像去噪
def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None):
# 克隆 randn,创建 randn_unit 用于后续计算
randn_unit = randn.clone()
# 准备采样循环,处理 randn、条件、未条件和步骤数
randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps)
# 如果未指定 num_steps,则使用对象的默认步骤数
if num_steps is None:
num_steps = self.num_steps
# 如果未指定 edit_ratio,则使用对象的默认编辑比例
if edit_ratio is None:
edit_ratio = self.edit_ratio
# 初始化 x 为 None,用于后续存储结果
x = None
# 遍历 sigma 生成器,获取每个 sigma 的值
for i in self.get_sigma_gen(num_sigmas):
# 如果当前步骤比例小于 edit_ratio,则跳过此次循环
if i / num_steps < edit_ratio:
continue
# 如果 x 为 None,则初始化 x 为图像与噪声的组合
if x is None:
x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape))
# 计算 gamma 值,依据条件限制调整
gamma = (
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
)
# 进行一次采样步骤,更新 x 的值
x = self.sampler_step(
s_in * sigmas[i], # 当前 sigma 的输入
s_in * sigmas[i + 1], # 下一个 sigma 的输入
denoiser, # 去噪器
x, # 当前图像
cond, # 条件信息
uc, # 无条件(negative)条件信息
gamma, # gamma 值
)
# 返回最终处理后的图像
return x
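SdeditEDMSampler 实现的是 SDEdit:不从纯噪声出发,而是跳过前 edit_ratio 比例的步数,在第一个未跳过的步上把参考图像加噪到对应的 sigma 水平,再继续正常的 EDM 采样。下面是这段初始化逻辑的独立草图(image、sigmas 等均为演示用假设,真正的去噪步骤在此省略):

import torch

# 简化示意(非仓库源码):SDEdit 的"中途进入"初始化
image = torch.randn(1, 4, 8, 8)                   # 假设的参考 latent
randn_unit = torch.randn_like(image)              # 与采样共用的单位噪声
sigmas = torch.linspace(10.0, 0.0, steps=11)      # 假设的 sigma 序列
num_steps, edit_ratio = 10, 0.5

x = None
for i in range(num_steps):
    if i / num_steps < edit_ratio:
        continue                                  # 跳过前半段步骤
    if x is None:
        # 第一个未被跳过的步:把参考图加噪到当前 sigma 水平
        x = image + randn_unit * sigmas[i]
    # 此处应调用 sampler_step 继续去噪(省略)
print(float(sigmas[int(edit_ratio * num_steps)]), float(x.std()))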
# 定义一个名为 VideoDDIMSampler 的类,继承自 BaseDiffusionSampler
class VideoDDIMSampler(BaseDiffusionSampler):
# 初始化函数,接受固定帧数和 sdedit 标志,及其他参数
def __init__(self, fixed_frames=0, sdedit=False, **kwargs):
# 调用父类的初始化函数
super().__init__(**kwargs)
# 设置固定帧数
self.fixed_frames = fixed_frames
# 设置 sdedit 标志
self.sdedit = sdedit
# 准备采样循环,接受输入数据和条件
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
# 进行离散化,得到 alpha 累积乘积的平方根(alpha_cumprod_sqrt)和对应的时间步
alpha_cumprod_sqrt, timesteps = self.discretization(
self.num_steps if num_steps is None else num_steps, # 使用给定的步数或默认步数
device=self.device, # 指定设备
return_idx=True, # 返回索引
do_append_zero=False, # 不追加零
)
# 在 alpha_cumprod_sqrt 末尾添加一个值为 1 的新张量
alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])])
# 构造时间步张量,并在其开头添加一个值为 -1 的元素
timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))])
# 如果 uc 为空,使用 cond 作为默认值
uc = default(uc, cond)
# 计算 alpha_cumprod_sqrt 的元素数量
num_sigmas = len(alpha_cumprod_sqrt)
# 创建一个新的张量 s_in,初始值为 1
s_in = x.new_ones([x.shape[0]])
# 返回多个变量
return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps
# 去噪函数,接受多个输入参数
def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None):
# 初始化额外模型输入的字典
additional_model_inputs = {}
# 如果 scale 不是张量且等于 1(即不需要 CFG 引导,只前向条件分支)
if isinstance(scale, torch.Tensor) == False and scale == 1:
# 为额外模型输入添加当前时间步的索引
additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep
# 如果 scale_emb 不为 None,添加到额外输入
if scale_emb is not None:
additional_model_inputs["scale_emb"] = scale_emb
# 调用去噪器进行去噪,并转换为 float32 类型
denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32)
else:
# 将当前时间步索引复制两份,分别对应条件与无条件两个分支
additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
# 调用去噪器进行去噪,准备输入并转换为 float32 类型
denoised = denoiser(
*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs
).to(torch.float32)
# 如果 guider 是 DynamicCFG 的实例,进行动态调整
if isinstance(self.guider, DynamicCFG):
denoised = self.guider(
denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale
)
else:
# 否则,进行普通的调整
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale)
# 返回去噪后的结果
return denoised
# 采样步骤函数,接受多个输入参数
def sampler_step(
self,
alpha_cumprod_sqrt,
next_alpha_cumprod_sqrt,
denoiser,
x,
cond,
uc=None,
idx=None,
timestep=None,
scale=None,
scale_emb=None,
):
# 调用 denoise 方法获取去噪结果
denoised = self.denoise(
x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
).to(torch.float32)
# 计算 a_t 和 b_t 值
a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
# 更新 x 的值,通过加权当前值和去噪后的值
x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
# 返回更新后的 x
return x
# 定义一个可调用的方法,用于处理去噪过程
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
# 准备采样循环的输入,返回处理后的数据和相关参数
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
# 遍历生成的 sigma 值
for i in self.get_sigma_gen(num_sigmas):
# 执行采样步骤,更新输入数据
x = self.sampler_step(
s_in * alpha_cumprod_sqrt[i], # 当前 sigma 的缩放输入
s_in * alpha_cumprod_sqrt[i + 1], # 下一个 sigma 的缩放输入
denoiser, # 去噪器对象
x, # 当前输入
cond, # 条件输入
uc, # 无条件(negative)条件
idx=self.num_steps - i, # 当前步骤索引
timestep=timesteps[-(i + 1)], # 当前时间步
scale=scale, # 缩放因子
scale_emb=scale_emb, # 嵌入的缩放因子
)
# 返回处理后的结果
return x
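VideoDDIMSampler 的单步更新是 x_next = a_t * x + b_t * denoised,其中 a_t、b_t 由相邻两个 alpha_cumprod_sqrt 决定。一个有用的直觉检验:若去噪器能完美预测出干净样本 x0,这个更新恰好把同一份噪声从当前噪声水平"搬"到下一噪声水平。下面的独立小例子(alpha 取值为演示用假设)验证这一点:

import torch

# 简化示意(非仓库源码):验证 DDIM 更新在完美去噪下的含义
x0 = torch.randn(3)                               # 假设的干净样本
eps = torch.randn(3)                              # 固定噪声
alpha_sqrt, next_alpha_sqrt = torch.tensor(0.8), torch.tensor(0.9)

sigma = (1 - alpha_sqrt**2) ** 0.5
next_sigma = (1 - next_alpha_sqrt**2) ** 0.5
x = alpha_sqrt * x0 + sigma * eps                 # 当前带噪样本

a_t = ((1 - next_alpha_sqrt**2) / (1 - alpha_sqrt**2)) ** 0.5
b_t = next_alpha_sqrt - alpha_sqrt * a_t
x_next = a_t * x + b_t * x0                       # 假设去噪器完美预测 x0

# 与直接在下一噪声水平合成的样本一致
print(torch.allclose(x_next, next_alpha_sqrt * x0 + next_sigma * eps, atol=1e-6))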
# 定义 VPSDEDPMPP2M 采样器类,继承自 VideoDDIMSampler
class VPSDEDPMPP2MSampler(VideoDDIMSampler):
# 获取变量,计算多个参数
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
# 计算 alpha 的累积乘积
alpha_cumprod = alpha_cumprod_sqrt**2
# 计算 lamb(对数信噪比的一半)
lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
# 计算下一个 alpha 的累积乘积
next_alpha_cumprod = next_alpha_cumprod_sqrt**2
# 计算下一个 lamb
lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
# 计算 h 值
h = lamb_next - lamb
# 如果存在前一个 alpha 的累积乘积
if previous_alpha_cumprod_sqrt is not None:
# 计算前一个 alpha 的累积乘积
previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
# 计算前一个 lamb
lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
# 计算 h_last 值
h_last = lamb - lamb_previous
# 计算 r 值
r = h_last / h
# 返回 h、r、lamb 和 lamb_next
return h, r, lamb, lamb_next
else:
# 返回 h、None、lamb 和 lamb_next
return h, None, lamb, lamb_next
# 计算乘数
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
# 计算第一个乘数
mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp()
# 计算第二个乘数
mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt
# 如果存在前一个 alpha 的累积乘积
if previous_alpha_cumprod_sqrt is not None:
# 计算第三个乘数
mult3 = 1 + 1 / (2 * r)
# 计算第四个乘数
mult4 = 1 / (2 * r)
# 返回所有乘数
return mult1, mult2, mult3, mult4
else:
# 返回前两个乘数
return mult1, mult2
# 执行采样步骤
def sampler_step(
self,
old_denoised,
previous_alpha_cumprod_sqrt,
alpha_cumprod_sqrt,
next_alpha_cumprod_sqrt,
denoiser,
x,
cond,
uc=None,
idx=None,
timestep=None,
scale=None,
scale_emb=None,
):
# 使用去噪器处理输入,得到去噪后的结果
denoised = self.denoise(
x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
).to(torch.float32)
# 如果索引为 1,返回去噪结果
if idx == 1:
return denoised, denoised
# 获取相关变量
h, r, lamb, lamb_next = self.get_variables(
alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
)
# 获取乘数
mult = [
append_dims(mult, x.ndim)
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
]
# 计算噪声乘数
mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
# 计算标准化 x
x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
# 如果 old_denoised 为 None 或者下一个 alpha 的累积乘积小于阈值
if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
# 返回标准化的 x 和去噪后的结果
return x_standard, denoised
else:
# 将当前去噪结果与上一步的去噪结果做线性外推
denoised_d = mult[2] * denoised - mult[3] * old_denoised
# 计算二阶(多步)更新得到的 x_advanced
x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)
# 更新 x
x = x_advanced
# 返回最终的 x 和去噪结果
return x, denoised
# 定义可调用方法,接收去噪器、输入数据、条件、无条件条件及其它参数
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
# 准备采样循环所需的输入及参数
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
# 如果固定帧数大于0,提取前固定帧数的图像
if self.fixed_frames > 0:
prefix_frames = x[:, : self.fixed_frames]
# 初始化去噪后的图像为 None
old_denoised = None
# 遍历生成的 sigma 值
for i in self.get_sigma_gen(num_sigmas):
# 如果固定帧数大于0,进行处理
if self.fixed_frames > 0:
# 如果启用 SD 编辑模式
if self.sdedit:
# 生成与前缀帧同形状的随机噪声
rd = torch.randn_like(prefix_frames)
# 计算带噪声的前缀帧
noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(
s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape)
)
# 将带噪声的前缀帧与剩余帧连接
x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1)
else:
# 直接将前缀帧与剩余帧连接
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
# 执行去噪步骤
x, old_denoised = self.sampler_step(
old_denoised,
None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
s_in * alpha_cumprod_sqrt[i],
s_in * alpha_cumprod_sqrt[i + 1],
denoiser,
x,
cond,
uc=uc,
idx=self.num_steps - i,
timestep=timesteps[-(i + 1)],
scale=scale,
scale_emb=scale_emb,
)
# 如果固定帧数大于0,重构最终输出
if self.fixed_frames > 0:
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
# 返回最终的去噪结果
return x
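VPSDEDPMPP2MSampler 的 __call__ 额外处理了图生视频/视频续写场景:前 fixed_frames 帧在每个采样步都被重新写回,sdedit=True 时先按当前噪声水平把前缀帧加噪再拼接,保证它们与其余帧处于同一噪声水平。下面是这段拼接逻辑的独立草图(latent 形状与噪声水平均为演示用假设):

import torch

# 简化示意(非仓库源码):每个采样步把固定的前缀帧拼回 latent 序列
B, T, C, H, W = 1, 13, 4, 8, 8                    # 假设的 latent 形状 (batch, frames, ...)
fixed_frames, sdedit = 2, True
x = torch.randn(B, T, C, H, W)
prefix_frames = x[:, :fixed_frames].clone()       # 需要保持不变的前缀帧

alpha_cumprod_sqrt_i = torch.tensor(0.7)          # 假设的当前噪声水平
if sdedit:
    rd = torch.randn_like(prefix_frames)
    noised_prefix = alpha_cumprod_sqrt_i * prefix_frames + rd * (1 - alpha_cumprod_sqrt_i**2) ** 0.5
    x = torch.cat([noised_prefix, x[:, fixed_frames:]], dim=1)   # 沿帧维拼接
else:
    x = torch.cat([prefix_frames, x[:, fixed_frames:]], dim=1)
print(x.shape)                                    # 形状仍为 (B, T, C, H, W)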
# 定义 VPODEDPMPP2MSampler 类,继承自 VideoDDIMSampler
class VPODEDPMPP2MSampler(VideoDDIMSampler):
# 获取变量,由当前、下一个(及可选的上一个)alpha 累积乘积的平方根计算 h、r 等
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
# 由平方根恢复 alpha 的累积乘积
alpha_cumprod = alpha_cumprod_sqrt**2
# 计算 lambda 值并取其对数
lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
# 由平方根恢复下一个 alpha 的累积乘积
next_alpha_cumprod = next_alpha_cumprod_sqrt**2
# 计算下一个 lambda 值并取其对数
lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
# 计算 h 值
h = lamb_next - lamb
# 如果提供了上一个 alpha 的平方根
if previous_alpha_cumprod_sqrt is not None:
# 由平方根恢复上一个 alpha 的累积乘积
previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
# 计算上一个 lambda 值并取其对数
lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
# 计算上一个 h 值
h_last = lamb - lamb_previous
# 计算 r 值
r = h_last / h
# 返回 h, r, lamb, lamb_next
return h, r, lamb, lamb_next
else:
# 如果没有上一个 alpha,返回 h 和其他计算值
return h, None, lamb, lamb_next
# 获取乘数
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
# 计算第一个乘数
mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
# 计算第二个乘数
mult2 = (-h).expm1() * next_alpha_cumprod_sqrt
# 如果提供了上一个 alpha 的平方根
if previous_alpha_cumprod_sqrt is not None:
# 计算第三个乘数
mult3 = 1 + 1 / (2 * r)
# 计算第四个乘数
mult4 = 1 / (2 * r)
# 返回所有乘数
return mult1, mult2, mult3, mult4
else:
# 返回前两个乘数
return mult1, mult2
# 采样步骤
def sampler_step(
self,
old_denoised,
previous_alpha_cumprod_sqrt,
alpha_cumprod_sqrt,
next_alpha_cumprod_sqrt,
denoiser,
x,
cond,
uc=None,
idx=None,
timestep=None,
):
# 使用去噪器对输入 x 进行去噪处理
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32)
# 如果索引为 1,返回去噪结果
if idx == 1:
return denoised, denoised
# 获取变量
h, r, lamb, lamb_next = self.get_variables(
alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
)
# 获取乘数并调整维度
mult = [
append_dims(mult, x.ndim)
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
]
# 计算标准化的 x
x_standard = mult[0] * x - mult[1] * denoised
# 如果没有旧的去噪结果或下一个 alpha 的平方根总和接近 0
if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
# 返回标准化的 x 和去噪结果
return x_standard, denoised
else:
# 将当前去噪结果与上一步的去噪结果做线性外推
denoised_d = mult[2] * denoised - mult[3] * old_denoised
# 计算二阶(多步)更新得到的 x_advanced
x_advanced = mult[0] * x - mult[1] * denoised_d
# 更新 x
x = x_advanced
# 返回最终的 x 和去噪结果
return x, denoised
# 定义可调用对象,接受去噪器、输入数据、条件及其他参数
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
# 准备采样循环所需的输入数据和参数
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
# 初始化旧去噪结果为 None
old_denoised = None
# 遍历生成的 sigma 值,进行采样步骤
for i in self.get_sigma_gen(num_sigmas):
# 执行单步采样,并更新当前输入和旧去噪结果
x, old_denoised = self.sampler_step(
old_denoised,
None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], # 第一步没有上一个噪声水平(previous_alpha_cumprod_sqrt)
s_in * alpha_cumprod_sqrt[i], # 当前 sigma 的值
s_in * alpha_cumprod_sqrt[i + 1], # 下一个 sigma 的值
denoiser, # 去噪器
x, # 当前输入数据
cond, # 条件输入
uc=uc, # 无条件(negative)条件
idx=self.num_steps - i, # 当前步骤索引
timestep=timesteps[-(i + 1)], # 当前时间步
)
# 返回最终生成的输入数据
return x
.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\sampling_utils.py
import torch
from scipy import integrate
from ...util import append_dims
from einops import rearrange
class NoDynamicThresholding:
def __call__(self, uncond, cond, scale):
scale = append_dims(scale, cond.ndim) if isinstance(scale, torch.Tensor) else scale
return uncond + scale * (cond - uncond)
class StaticThresholding:
def __call__(self, uncond, cond, scale):
result = uncond + scale * (cond - uncond)
result = torch.clamp(result, min=-1.0, max=1.0)
return result
def dynamic_threshold(x, p=0.95):
N, T, C, H, W = x.shape
x = rearrange(x, "n t c h w -> n c (t h w)")
l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device), dim=-1, keepdim=True)
s = torch.maximum(-l, r)
threshold_mask = (s > 1).expand(-1, -1, H * W * T)
if threshold_mask.any():
x = torch.where(threshold_mask, x.clamp(min=-1 * s, max=s), x)
x = rearrange(x, "n c (t h w) -> n t c h w", t=T, h=H, w=W)
return x
def dynamic_thresholding2(x0):
p = 0.995
origin_dtype = x0.dtype
x0 = x0.to(torch.float32)
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim())
x0 = torch.clamp(x0, -s, s)
return x0.to(origin_dtype)
def latent_dynamic_thresholding(x0):
p = 0.9995
origin_dtype = x0.dtype
x0 = x0.to(torch.float32)
s = torch.quantile(torch.abs(x0), p, dim=2)
s = append_dims(s, x0.dim())
x0 = torch.clamp(x0, -s, s) / s
return x0.to(origin_dtype)
def dynamic_thresholding3(x0):
p = 0.995
origin_dtype = x0.dtype
x0 = x0.to(torch.float32)
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim())
x0 = torch.clamp(x0, -s, s)
return x0.to(origin_dtype)
class DynamicThresholding:
def __call__(self, uncond, cond, scale):
mean = uncond.mean()
std = uncond.std()
result = uncond + scale * (cond - uncond)
result_mean, result_std = result.mean(), result.std()
result = (result - result_mean) / result_std * std
return result
class DynamicThresholdingV1:
def __init__(self, scale_factor):
self.scale_factor = scale_factor
def __call__(self, uncond, cond, scale):
result = uncond + scale * (cond - uncond)
unscaled_result = result / self.scale_factor
B, T, C, H, W = unscaled_result.shape
flattened = rearrange(unscaled_result, "b t c h w -> b c (t h w)")
means = flattened.mean(dim=2).unsqueeze(2)
recentered = flattened - means
magnitudes = recentered.abs().max()
normalized = recentered / magnitudes
thresholded = latent_dynamic_thresholding(normalized)
denormalized = thresholded * magnitudes
uncentered = denormalized + means
unflattened = rearrange(uncentered, "b c (t h w) -> b t c h w", t=T, h=H, w=W)
scaled_result = unflattened * self.scale_factor
return scaled_result
class DynamicThresholdingV2:
def __call__(self, uncond, cond, scale):
B, T, C, H, W = uncond.shape
diff = cond - uncond
mim_target = uncond + diff * 4.0
cfg_target = uncond + diff * 8.0
mim_flattened = rearrange(mim_target, "b t c h w -> b c (t h w)")
cfg_flattened = rearrange(cfg_target, "b t c h w -> b c (t h w)")
mim_means = mim_flattened.mean(dim=2).unsqueeze(2)
cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2)
mim_centered = mim_flattened - mim_means
cfg_centered = cfg_flattened - cfg_means
mim_scaleref = mim_centered.std(dim=2).unsqueeze(2)
cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2)
cfg_renormalized = cfg_centered / cfg_scaleref * mim_scaleref
result = cfg_renormalized + cfg_means
unflattened = rearrange(result, "b c (t h w) -> b t c h w", t=T, h=H, w=W)
return unflattened
def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
if order - 1 > i:
raise ValueError(f"Order {order} too high for step {i}")
def fn(tau):
prod = 1.0
for k in range(order):
if j == k:
continue
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
return prod
return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
if not eta:
return sigma_to, 0.0
sigma_up = torch.minimum(
sigma_to,
eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
)
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
return sigma_down, sigma_up
def to_d(x, sigma, denoised):
return (x - denoised) / append_dims(sigma, x.ndim)
def to_neg_log_sigma(sigma):
return sigma.log().neg()
def to_sigma(neg_log_sigma):
return neg_log_sigma.neg().exp()
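sampling_utils 里的 linear_multistep_coeff 通过数值积分拉格朗日基函数得到 Adams-Bashforth 系数。一个简单的正确性检验:由于基函数之和恒为 1,各系数之和应当等于步长 t[i+1] - t[i]。下面的独立小例子(sigma 序列为演示用假设,函数按原公式复制以便自包含)验证这一点:

import numpy as np
from scipy import integrate

# 简化示意(非仓库源码):验证线性多步系数之和等于步长
def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
    # 与上文 sampling_utils.linear_multistep_coeff 相同的公式
    if order - 1 > i:
        raise ValueError(f"Order {order} too high for step {i}")
    def fn(tau):
        prod = 1.0
        for k in range(order):
            if j == k:
                continue
            prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return prod
    return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]

t = np.linspace(10.0, 0.0, 11)                    # 假设的 sigma 序列(递减)
i, order = 5, 4
coeffs = [linear_multistep_coeff(order, t, i, j) for j in range(order)]
print(coeffs)
print(sum(coeffs), t[i + 1] - t[i])               # 系数之和 ≈ 步长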
.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\sigma_sampling.py
import torch
import torch.distributed
from sat import mpu
from ...util import default, instantiate_from_config
class EDMSampling:
def __init__(self, p_mean=-1.2, p_std=1.2):
self.p_mean = p_mean
self.p_std = p_std
def __call__(self, n_samples, rand=None):
log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
return log_sigma.exp()
class DiscreteSampling:
def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False):
self.num_idx = num_idx
self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
world_size = mpu.get_data_parallel_world_size()
self.uniform_sampling = uniform_sampling
if self.uniform_sampling:
i = 1
while True:
if world_size % i != 0 or num_idx % (world_size // i) != 0:
i += 1
else:
self.group_num = world_size // i
break
assert self.group_num > 0
assert world_size % self.group_num == 0
self.group_width = world_size // self.group_num
self.sigma_interval = self.num_idx // self.group_num
def idx_to_sigma(self, idx):
return self.sigmas[idx]
def __call__(self, n_samples, rand=None, return_idx=False):
if self.uniform_sampling:
rank = mpu.get_data_parallel_rank()
group_index = rank // self.group_width
idx = default(
rand,
torch.randint(
group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)
),
)
else:
idx = default(
rand,
torch.randint(0, self.num_idx, (n_samples,)),
)
if return_idx:
return self.idx_to_sigma(idx), idx
else:
return self.idx_to_sigma(idx)
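DiscreteSampling 的 uniform_sampling 分支把数据并行的各个 rank 分成 group_num 组,每组只在自己负责的 sigma 区间内抽取时间步索引,使一个全局 batch 尽量均匀覆盖整个噪声范围。下面脱离 sat.mpu,用假设的 world_size 和 num_idx 演示这段分组整除逻辑:

# 简化示意(非仓库源码):uniform_sampling 的分组逻辑,world_size / num_idx 为演示用假设
world_size, num_idx = 8, 1000

i = 1
while True:
    if world_size % i != 0 or num_idx % (world_size // i) != 0:
        i += 1
    else:
        group_num = world_size // i
        break
group_width = world_size // group_num              # 每组包含的 rank 数
sigma_interval = num_idx // group_num              # 每组负责的时间步区间宽度
print(group_num, group_width, sigma_interval)

for rank in range(world_size):
    group_index = rank // group_width
    lo, hi = group_index * sigma_interval, (group_index + 1) * sigma_interval
    print(f"rank {rank} 在 [{lo}, {hi}) 区间内随机抽取时间步索引")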
class PartialDiscreteSampling:
def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
self.total_num_idx = total_num_idx
self.partial_num_idx = partial_num_idx
self.sigmas = instantiate_from_config(discretization_config)(
total_num_idx, do_append_zero=do_append_zero, flip=flip
)
def idx_to_sigma(self, idx):
return self.sigmas[idx]
def __call__(self, n_samples, rand=None):
idx = default(
rand,
torch.randint(0, self.partial_num_idx, (n_samples,)),
)
return self.idx_to_sigma(idx)
.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\util.py
"""
adopted from
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
and
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
and
https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
thanks!
"""
import math
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange, repeat
def make_beta_schedule(
schedule,
n_timestep,
linear_start=1e-4,
linear_end=2e-2,
):
if schedule == "linear":
betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2
return betas.numpy()
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def mixed_checkpoint(func, inputs: dict, params, flag):
"""
在不缓存中间激活的情况下评估函数,减少内存消耗,但会增加反向传播的计算量。
该实现允许非张量输入。
:param func: 要评估的函数。
:param inputs: 传递给 `func` 的参数字典。
:param params: func 依赖但不作为参数的参数序列。
:param flag: 如果为 False,禁用梯度检查点。
"""
if flag:
tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
tensor_inputs = [inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)]
non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)]
non_tensor_inputs = [inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)]
args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
return MixedCheckpointFunction.apply(
func,
len(tensor_inputs),
len(non_tensor_inputs),
tensor_keys,
non_tensor_keys,
*args,
)
else:
return func(**inputs)
class MixedCheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
run_function,
length_tensors,
length_non_tensors,
tensor_keys,
non_tensor_keys,
*args,
):
ctx.end_tensors = length_tensors
ctx.end_non_tensors = length_tensors + length_non_tensors
ctx.gpu_autocast_kwargs = {
"enabled": torch.is_autocast_enabled(),
"dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled(),
}
assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors
ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))}
ctx.input_non_tensors = {
key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]))
}
ctx.run_function = run_function
ctx.input_params = list(args[ctx.end_non_tensors :])
with torch.no_grad():
output_tensors = ctx.run_function(**ctx.input_tensors, **ctx.input_non_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = {key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors}
with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
shallow_copies = {key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors}
output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
input_grads = torch.autograd.grad(
output_tensors,
list(ctx.input_tensors.values()) + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (
(None, None, None, None, None)
+ input_grads[: ctx.end_tensors]
+ (None,) * (ctx.end_non_tensors - ctx.end_tensors)
+ input_grads[ctx.end_tensors :]
)
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
ctx.gpu_autocast_kwargs = {
"enabled": torch.is_autocast_enabled(),
"dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled(),
}
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtype=torch.float32):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, "b -> b d", d=dim)
return embedding.to(dtype)
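timestep_embedding 生成标准的正弦/余弦位置编码,频率按几何级数从 1 递减到约 1/max_period。下面是其主分支(dim 为偶数、repeat_only=False)的最小等价实现与调用示例,输入的时间步数值为演示用假设:

import math
import torch

# 简化示意(非仓库源码):正弦时间步嵌入的最小实现
def sinusoidal_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0, 10, 999]), dim=128)
print(emb.shape)                                   # torch.Size([3, 128])
print(emb[:, 0])                                   # 最高频分量为 cos(t),t=0 时等于 1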
def zero_module(module):
"""
将模块的参数置为零,并返回该模块。
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
将模块的参数进行缩放,并返回该模块。
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
对所有非批次维度进行求均值。
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
创建一个标准归一化层。
:param channels: 输入通道的数量。
:return: 一个用于归一化的 nn.Module。
"""
return GroupNorm32(32, channels)
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
创建一个 1D、2D 或 3D 卷积模块。
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
创建一个线性模块。
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
创建一个 1D、2D 或 3D 平均池化模块。
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
rearrange_pattern: str = "b t -> (b t) 1 1",
):
super().__init__()
self.merge_strategy = merge_strategy
self.rearrange_pattern = rearrange_pattern
assert merge_strategy in self.strategies, f"merge_strategy needs to be in {self.strategies}"
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
if self.merge_strategy == "fixed":
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..."
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
else:
raise NotImplementedError
return alpha
def forward(
self,
x_spatial: torch.Tensor,
x_temporal: torch.Tensor,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator)
x = alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
return x
.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\wrappers.py
import torch
import torch.nn as nn
from packaging import version
OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
class IdentityWrapper(nn.Module):
def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32):
super().__init__()
compile = (
torch.compile
if (version.parse(torch.__version__) >= version.parse("2.0.0")) and compile_model
else lambda x: x
)
self.diffusion_model = compile(diffusion_model)
self.dtype = dtype
def forward(self, *args, **kwargs):
return self.diffusion_model(*args, **kwargs)
class OpenAIWrapper(IdentityWrapper):
def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor:
for key in c:
c[key] = c[key].to(self.dtype)
if x.dim() == 4:
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
elif x.dim() == 5:
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=2)
else:
raise ValueError("Input tensor must be 4D or 5D")
return self.diffusion_model(
x,
timesteps=t,
context=c.get("crossattn", None),
y=c.get("vector", None),
**kwargs,
)
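OpenAIWrapper 把条件字典中的 "concat" 条件沿通道维拼接到输入上:4D 图像输入在 dim=1,5D 视频输入(B, T, C, H, W)在 dim=2,再把 "crossattn" 与 "vector" 分别作为 context / y 传给扩散模型。下面只演示拼接维度这一点,张量形状均为演示用假设:

import torch

# 简化示意(非仓库源码):concat 条件的拼接维度
x_img = torch.randn(2, 4, 32, 32)                  # 4D: (B, C, H, W)
x_vid = torch.randn(2, 13, 16, 32, 32)             # 5D: (B, T, C, H, W)
concat_img = torch.randn(2, 1, 32, 32)             # 假设的图像 concat 条件(例如 mask)
concat_vid = torch.randn(2, 13, 16, 32, 32)        # 假设的视频 concat 条件(例如首帧 latent)

print(torch.cat((x_img, concat_img), dim=1).shape)   # (2, 5, 32, 32),通道维拼接
print(torch.cat((x_vid, concat_vid), dim=2).shape)   # (2, 13, 32, 32, 32),视频在 dim=2 拼接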
.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\__init__.py
from .denoiser import Denoiser
from .discretizer import Discretization
from .model import Decoder, Encoder, Model
from .openaimodel import UNetModel
from .sampling import BaseDiffusionSampler
from .wrappers import OpenAIWrapper
.\cogvideo-finetune\sat\sgm\modules\distributions\distributions.py
import numpy as np
import torch
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn_like(self.mean).to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
计算两个高斯分布之间的 KL 散度。
形状会自动广播,支持批量比较和标量等用例。
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
return 0.5 * (
-1.0 + logvar2 - logvar1 +
torch.exp(logvar1 - logvar2) +
((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
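normal_kl 给出两个(对角)高斯分布之间 KL 散度的逐元素闭式解。可以用 torch.distributions 做一个独立的数值验证(均值、对数方差取演示用假设值):

import torch
from torch.distributions import Normal, kl_divergence

# 简化示意(非仓库源码):验证 normal_kl 的闭式公式
mean1, logvar1 = torch.tensor(0.5), torch.tensor(0.2)
mean2, logvar2 = torch.tensor(-0.3), torch.tensor(-0.1)

kl_closed_form = 0.5 * (
    -1.0 + logvar2 - logvar1
    + torch.exp(logvar1 - logvar2)
    + (mean1 - mean2) ** 2 * torch.exp(-logvar2)
)
kl_reference = kl_divergence(
    Normal(mean1, (0.5 * logvar1).exp()),          # 标准差 = exp(0.5 * logvar)
    Normal(mean2, (0.5 * logvar2).exp()),
)
print(float(kl_closed_form), float(kl_reference))  # 两者应一致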
.\cogvideo-finetune\sat\sgm\modules\distributions\__init__.py
.\cogvideo-finetune\sat\sgm\modules\ema.py
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError("Decay must be between 0 and 1")
self.m_name2s_name = {}
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
"num_updates",
torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters():
if p.requires_grad:
s_name = name.replace(".", "")
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def reset_num_updates(self):
del self.num_updates
self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
保存当前参数以便稍后恢复。
参数:
parameters: 可迭代的 `torch.nn.Parameter`;要临时存储的参数。
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
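LitEma 的核心只有两条规则:decay 随更新次数热身(min(decay, (1+n)/(10+n))),以及影子参数按 shadow -= (1-decay) * (shadow - param) 做指数滑动平均。下面脱离该类,用一个标量参数演示这两条规则(参数的变化方式为演示用假设):

import torch

# 简化示意(非仓库源码):EMA 的 decay 热身与影子参数更新
decay_cap = 0.9999
param = torch.tensor(1.0)                          # 假设的模型参数
shadow = param.clone()                             # EMA 影子参数

for num_updates in range(1, 6):
    decay = min(decay_cap, (1 + num_updates) / (10 + num_updates))   # 热身:前期用更小的 decay
    param = param + 0.1                            # 模拟一步训练带来的参数变化
    shadow = shadow - (1.0 - decay) * (shadow - param)                # 与 forward 中的 sub_ 等价
    print(num_updates, round(decay, 4), float(shadow))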
.\cogvideo-finetune\sat\sgm\modules\encoders\modules.py
import math
from contextlib import nullcontext
from functools import partial
from typing import Dict, List, Optional, Tuple, Union
import kornia
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import ListConfig
from torch.utils.checkpoint import checkpoint
from transformers import (
T5EncoderModel,
T5Tokenizer,
)
from ...util import (
append_dims,
autocast,
count_params,
default,
disabled_train,
expand_dims_like,
instantiate_from_config,
)
class AbstractEmbModel(nn.Module):
def __init__(self):
super().__init__()
self._is_trainable = None
self._ucg_rate = None
self._input_key = None
@property
def is_trainable(self) -> bool:
return self._is_trainable
@property
def ucg_rate(self) -> Union[float, torch.Tensor]:
return self._ucg_rate
@property
def input_key(self) -> str:
return self._input_key
@is_trainable.setter
def is_trainable(self, value: bool):
self._is_trainable = value
@ucg_rate.setter
def ucg_rate(self, value: Union[float, torch.Tensor]):
self._ucg_rate = value
@input_key.setter
def input_key(self, value: str):
self._input_key = value
@is_trainable.deleter
def is_trainable(self):
del self._is_trainable
@ucg_rate.deleter
def ucg_rate(self):
del self._ucg_rate
@input_key.deleter
def input_key(self):
del self._input_key
class GeneralConditioner(nn.Module):
OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
def __init__(self, emb_models: Union[List, ListConfig], cor_embs=[], cor_p=[]):
super().__init__()
embedders = []
for n, embconfig in enumerate(emb_models):
embedder = instantiate_from_config(embconfig)
assert isinstance(
embedder, AbstractEmbModel
), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
embedder.is_trainable = embconfig.get("is_trainable", False)
embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
if not embedder.is_trainable:
embedder.train = disabled_train
for param in embedder.parameters():
param.requires_grad = False
embedder.eval()
print(
f"Initialized embedder #{n}: {embedder.__class__.__name__} "
f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
)
if "input_key" in embconfig:
embedder.input_key = embconfig["input_key"]
elif "input_keys" in embconfig:
embedder.input_keys = embconfig["input_keys"]
else:
raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")
embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
if embedder.legacy_ucg_val is not None:
embedder.ucg_prng = np.random.RandomState()
embedders.append(embedder)
self.embedders = nn.ModuleList(embedders)
if len(cor_embs) > 0:
assert len(cor_p) == 2 ** len(cor_embs)
self.cor_embs = cor_embs
self.cor_p = cor_p
def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
assert embedder.legacy_ucg_val is not None
p = embedder.ucg_rate
val = embedder.legacy_ucg_val
for i in range(len(batch[embedder.input_key])):
if embedder.ucg_prng.choice(2, p=[1 - p, p]):
batch[embedder.input_key][i] = val
return batch
def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict, cond_or_not) -> Dict:
assert embedder.legacy_ucg_val is not None
val = embedder.legacy_ucg_val
for i in range(len(batch[embedder.input_key])):
if cond_or_not[i]:
batch[embedder.input_key][i] = val
return batch
def get_single_embedding(
self,
embedder,
batch,
output,
cond_or_not: Optional[np.ndarray] = None,
force_zero_embeddings: Optional[List] = None,
):
embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
with embedding_context():
if hasattr(embedder, "input_key") and (embedder.input_key is not None):
if embedder.legacy_ucg_val is not None:
if cond_or_not is None:
batch = self.possibly_get_ucg_val(embedder, batch)
else:
batch = self.surely_get_ucg_val(embedder, batch, cond_or_not)
emb_out = embedder(batch[embedder.input_key])
elif hasattr(embedder, "input_keys"):
emb_out = embedder(*[batch[k] for k in embedder.input_keys])
assert isinstance(
emb_out, (torch.Tensor, list, tuple)
), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
if not isinstance(emb_out, (list, tuple)):
emb_out = [emb_out]
for emb in emb_out:
out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
if cond_or_not is None:
emb = (
expand_dims_like(
torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)),
emb,
)
* emb
)
else:
emb = (
expand_dims_like(
torch.tensor(1 - cond_or_not, dtype=emb.dtype, device=emb.device),
emb,
)
* emb
)
if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings:
emb = torch.zeros_like(emb)
if out_key in output:
output[out_key] = torch.cat((output[out_key], emb), self.KEY2CATDIM[out_key])
else:
output[out_key] = emb
return output
def forward(self, batch: Dict, force_zero_embeddings: Optional[List] = None) -> Dict:
output = dict()
if force_zero_embeddings is None:
force_zero_embeddings = []
if len(self.cor_embs) > 0:
batch_size = len(batch[list(batch.keys())[0]])
rand_idx = np.random.choice(len(self.cor_p), size=(batch_size,), p=self.cor_p)
for emb_idx in self.cor_embs:
cond_or_not = rand_idx % 2
rand_idx //= 2
output = self.get_single_embedding(
self.embedders[emb_idx],
batch,
output=output,
cond_or_not=cond_or_not,
force_zero_embeddings=force_zero_embeddings,
)
for i, embedder in enumerate(self.embedders):
if i in self.cor_embs:
continue
output = self.get_single_embedding(
embedder, batch, output=output, force_zero_embeddings=force_zero_embeddings
)
return output
def get_unconditional_conditioning(self, batch_c, batch_uc=None, force_uc_zero_embeddings=None):
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
ucg_rates = list()
for embedder in self.embedders:
ucg_rates.append(embedder.ucg_rate)
embedder.ucg_rate = 0.0
cor_embs = self.cor_embs
cor_p = self.cor_p
self.cor_embs = []
self.cor_p = []
c = self(batch_c)
uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
for embedder, rate in zip(self.embedders, ucg_rates):
embedder.ucg_rate = rate
self.cor_embs = cor_embs
self.cor_p = cor_p
return c, uc
class FrozenT5Embedder(AbstractEmbModel):
"""使用 T5 变换器编码器处理文本"""
def __init__(
self,
model_dir="google/t5-v1_1-xxl",
device="cuda",
max_length=77,
freeze=True,
cache_dir=None,
):
super().__init__()
if model_dir != "google/t5-v1_1-xxl":  # 自定义模型路径时直接从该路径加载(字符串比较应使用 !=,原代码误用 is not)
self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
self.transformer = T5EncoderModel.from_pretrained(model_dir)
else:
self.tokenizer = T5Tokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
self.transformer = T5EncoderModel.from_pretrained(model_dir, cache_dir=cache_dir)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
with torch.autocast("cuda", enabled=False):
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
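FrozenT5Embedder 只取 T5 编码器的 last_hidden_state(形状为 [batch, max_length, hidden_size])作为文本条件。下面是与其 forward 等价的调用草图;需要安装 transformers 与 sentencepiece 并能加载相应权重,这里用较小的 "t5-small" 作为演示用假设,仓库实际使用 google/t5-v1_1-xxl:

import torch
from transformers import T5Tokenizer, T5EncoderModel

# 简化示意(非仓库源码):FrozenT5Embedder.forward 的等价调用流程
model_dir = "t5-small"                             # 演示用假设,仓库默认是 google/t5-v1_1-xxl
tokenizer = T5Tokenizer.from_pretrained(model_dir)
encoder = T5EncoderModel.from_pretrained(model_dir).eval()

batch = tokenizer(
    ["a panda is playing guitar"],
    truncation=True, max_length=77, padding="max_length", return_tensors="pt",
)
with torch.no_grad():
    z = encoder(input_ids=batch["input_ids"]).last_hidden_state
print(z.shape)                                     # (1, 77, hidden_size),作为 crossattn 条件使用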