从零开始搭建扩散模型
diffusers库介绍
!pip install -q diffusers
Diffusers 库简介
Hugging Face 的 Diffusers 是一个开源库,提供了对扩散模型的支持,包括:
文本生成
图像生成
图像修复(Inpainting)
超分辨率(Super-Resolution)
文本到图像生成(Text-to-Image Generation)
这些模型的核心是扩散过程,这是一种深度学习技术,主要用于生成式任务。
1.导入相关的库
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler,UNet2DModel
from matplotlib import pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device:{device}")
2.导入数据集
dataset = torchvision.datasets.MNIST(root="mnist/",train=True,download=True,
transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset,batch_size=8,shuffle=True)
x,y = next(iter(train_dataloader))
print("input shape:",x.shape)
print("Labels:",y)
plt.imshow(torchvision.utils.make_grid(x)[0],cmap='Greys')
3.图片添加噪声
def corrupt(x,amount):
noise = torch.rand_like(x)
amount = amount.view(-1,1,1,1)
return x*(1-amount) + noise*amount
fig,axs = plt.subplots(2,1,figsize=(12,5))
axs[0].set_title("Input Data")
axs[0].imshow(torchvision.utils.make_grid(x)[0],cmap="Greys")
amount = torch.linspace(0,1,x.shape[0])
noised_x = corrupt(x,amount)
axs[1].set_title("Corrupted Data")
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0],cmap="Greys")
简单的UNet
在进行训练之前,我们需要一个模型,要求它能够接收28×28像 素的噪声图像,并输出相同大小图片的预测结果。业界比较流行的选 择是UNet网络,UNet网络最初被发明用于完成医学图像中的分割任 务。UNet网络由一条“压缩路径”和一条“扩展路径”组成。“压缩 路径”会使通过该路径的数据维度被压缩,而“扩展路径”则会将数 据扩展回原始维度(类似于自动编码器)。
UNet网络中的残差连接允 许信息和梯度在不同层级之间流动。 有些UNet网络在每个阶段的设计中都包含复杂的模块,但在这 里,我们仅构建一个非常简单的示例,它能够接收一个单通道图像, 并使其通过下行路径的3个卷积层(参见图3-1与后面代码实现中的 down_layers)和上行路径的3个卷积层。下行层和上行层之间有残差 连接,我们使用最大池化层进行下采样,并使用nn.Upsample模块进 行上采样。某些更复杂的UNet网络还可能使用带有可学习参数的上采 样层和下采样层
4.使用一个简单的UNet网络对diffusion过程进行学习
class BasicUNet(nn.Module):
def __init__(self,in_channels=1,out_channels=1):
super().__init__()
self.down_layers = torch.nn.ModuleList([
nn.Conv2d(in_channels,32,kernel_size=5,padding=2),
nn.Conv2d(32,64,kernel_size=5,padding=2),
nn.Conv2d(64,64,kernel_size=5,padding=2),
])
self.up_layers = torch.nn.ModuleList([
nn.Conv2d(64,64,kernel_size=5,padding=2),
nn.Conv2d(64,32,kernel_size=5,padding=2),
nn.Conv2d(32,out_channels,kernel_size=5,padding=2),
])
self.act = nn.ReLU()
self.downscale = nn.MaxPool2d(2)
self.upscale = nn.Upsample(scale_factor=2)
def forward(self,x):
h = []
for i,l in enumerate(self.down_layers):
x = self.act(l(x))
if i < 2:
h.append(x)
x = self.downscale(x)
for i,l in enumerate(self.up_layers):
if i > 0:
x = self.upscale(x)
x += h.pop()
x = self.act(l(x))
return x
net = BasicUNet()
x = torch.rand(8,1,28,28)
net(x).shape
sum([p.numel() for p in net.parameters()])
batch_size = 128
train_dataloader = DataLoader(dataset,batch_size=batch_size,
shuffle=True)
n_epoches = 3
net = BasicUNet()
net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(),lr=1e-3)
losses = []
for epoch in range(n_epoches):
for x,y in train_dataloader:
x = x.to(device)
noised_amount = torch.rand(x.shape[0]).to(device)
noisy_x = corrupt(x,noised_amount)
pred = net(noisy_x)
loss = loss_fn(pred,x)
opt.zero_grad()
loss.backward()
opt.step()
losses.append(loss.item())
avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
print(f'Finished epoch {epoch}. Average loss for this epoch:{avg_loss:.5f}')
plt.plot(losses)
plt.ylim(0,0.1)
5.对得到的模型进行展示
x,y = next(iter(train_dataloader))
x = x[:8]
amount = torch.linspace(0,1,x.shape[0])
noised_x = corrupt(x,amount)
with torch.no_grad():
preds = net(noised_x.to(device)).detach().cpu()
fig,axs = plt.subplots(3,1,figsize=(12,7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0,1),cmap='Greys')
axs[1].set_title('Corrupted data')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0,1),cmap='Greys')
axs[2].set_title('Network Predictions')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0,1),cmap='Greys')
n_steps = 5
x = torch.rand(8,1,28,28).to(device)
step_history = [x.detach().cpu()]
pred_output_history = []
for i in range(n_steps):
with torch.no_grad():
pred = net(x)
pred_output_history.append(pred.detach().cpu())
mix_factor = 1/(n_steps - i)
x = x*(1-mix_factor) + pred*mix_factor
step_history.append(x.detach().cpu())
fig,axs = plt.subplots(n_steps,2,figsize=(9,4),sharex=True)
axs[0,0].set_title('x (model input)')
axs[0,1].set_title('x (model predition)')
for i in range(n_steps):
axs[i,0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0,1),cmap='Greys')
axs[i,1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0,1),cmap='Greys')
左侧是每个阶段模型输入的可视化结果,右侧是 预测的“去噪”(即去除噪声)后的图像。注意,即使模型在第1步 就输出去除了一些噪声的图像,但也只是向最终目标前进了一点点。 如此重复几次后,图像的轮廓开始逐渐出现并得到改善,直到获得最 终结果为止。
n_steps = 40
x = torch.rand(64,1,28,28).to(device)
for i in range(n_steps):
noise_amount = torch.ones((x.shape[0],)).to(device)*(1-i/n_steps)
with torch.no_grad():
pred = net(x)
mix_factor = 1/(n_steps - i)
x = x*(1-mix_factor) + pred*mix_factor
fig,ax = plt.subplots(1,1,figsize=(12,12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(),nrow=8)[0].clip(0,1),cmap="Greys")
6.使用精度更高的一个模型进行训练
训练目标不同,旨在预测噪声而不是“去噪”的图像。
UNet2DModel模型通过调节时间步来调节噪声量,t作为一个额外参数被传入前向过程。
net = UNet2DModel(
sample_size=28, # 将目标图像分辨率由28提升为64
in_channels=1, # 输入改为3通道(RGB)
out_channels=1, # 输出也改为3通道,以便生成RGB图像
layers_per_block=2, # 每个UNet块中仍使用2个ResNet层
block_out_channels=(32, 64, 64), # 增大通道数,以适应更高分辨率
down_block_types=(
"DownBlock2D", # 标准ResNet下采样模块
"AttnDownBlock2D", # 带有空间维度self-att的ResNet下采样模块
"AttnDownBlock2D", # 再增加一层带有self-att的下采样
),
up_block_types=(
"AttnUpBlock2D", # 带有空间维度self-att的ResNet上采样模块
"AttnUpBlock2D", # 再增加一层带有self-att的上采样
"UpBlock2D", # 标准ResNet上采样模块
),
)
net.to(device)
# 首先给定一个“带噪”(即加入了噪声)的输入noisy_x,扩散模型应
# 该输出其对原始输入x的最佳预测。我们需要通过均方误差对预测值与 真实值进行比较。
batch_size = 128
train_dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True)
n_epochs = 3
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(),lr=1e-3)
losses = []
for epoch in range(n_epochs):
for x,y in train_dataloader:
x = x.to(device)
noise_amount = torch.rand(x.shape[0], device=device)
noisy_x = corrupt(x, noise_amount)
# 随机选择时间步
t = torch.randint(0, 1000, (x.shape[0],), device=device).long()
# 模型前向传播需要 noisy_x 和 t
pred = net(noisy_x, t).sample
loss = loss_fn(pred, x)
opt.zero_grad()
loss.backward()
opt.step()
losses.append(loss.item())
n_steps = 40
x = torch.rand(64,1,28,28).to(device)
for i in range(n_steps):
noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - i/n_steps)
with torch.no_grad():
# 定义时间步 t,可以是整型,也可以是与 batch 大小一致的张量
# 这里假设 i 是扩散时间步,范围为0到n_steps-1。
t = torch.full((x.shape[0],), i, device=device, dtype=torch.long)
# 调用net需要传入 x 和 t
pred = net(x, t).sample # net(...) 返回的是 UNet2DOutput,需要 .sample 获取结果张量
mix_factor = 1/(n_steps - i)
x = x*(1-mix_factor) + pred*mix_factor
fig,ax = plt.subplots(1,1,figsize=(12,12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(),nrow=8)[0].clip(0,1),cmap="Greys")
plt.show()
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
plt.plot(noise_scheduler.alphas_cumprod.cpu()**0.5,
label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1-noise_scheduler.alphas_cumprod.cpu())**0.5,
label=r"${\sqrt{1-\bar{\alpha}_t}}$")
plt.legend(fontsize="x-large")
fig,axs = plt.subplots(3,1,figsize=(16,10))
xb,yb = next(iter(train_dataloader))
xb = xb.to(device)[:8]
xb = xb*2 -1
print('X shape',xb.shape)
axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(),cmap='Greys')
axs[0].set_title('Clean X')
timesteps = torch.linspace(0,999,8).long().to(device)
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb,noise,timesteps)
print("Noisy X shape",noisy_xb.shape)
axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1,1),
cmap="Greys")
axs[1].set_title("Noisy X (clipped to (-1,1))")
axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(),
cmap="Greys")
axs[2].set_title("Noisy X")
预测噪声(从中可以得出“去噪”图像的样子)等同 于直接预测“去噪”图像。但为什么要这么做呢?难道仅仅是为了数 学上的方便吗? 这里其实还有一些精妙之处。我们在训练过程中会计算不同(随 机选择)时间步的损失函数,不同任务目标计算得到的结果会根据损 失值向不同的“隐含权重”收敛,而“预测噪声”这个目标会使权重 更倾向于预测得到更低的噪声量。你可以通过选择更复杂的目标来改 变这种“隐性损失权重”,这样你所选择的噪声调度器就能够直接在 较高的噪声量下产生更多的样本。
你也可以把模型设计成预测“velocity”,我们将其定义为同时 受图像和噪声量影响的组合〔参见论文“Progressive Distillation for Fast Sampling of Diffusion Models”(扩散模型快速采样的渐 进蒸馏)〕。 你还可以将模型设计成预测噪声,但需要基于一些参数对损失进 行缩放。例如,一些研究指出,可以使用噪声量参数〔参见论文 “Perception Prioritized Training of Diffusion Models”(扩散模 型的感知优先训练)〕或者基于一些探索添加的最佳噪声量实验〔参 见 论 文 “Elucidating the Design Space of Diffusion-Based Generative Models”(基于扩散的生成模型的设计空间说明)〕。
本文作者:afengleafs
本文链接:https://www.cnblogs.com/afengleafs/p/18604853
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步