我们登上并非我们所选择的舞台,演出并非我们所选择的剧本。|

afengleafs

园龄:2个月粉丝:0关注:0

从零开始搭建扩散模型

diffusers库介绍

!pip install -q diffusers
Diffusers 库简介
Hugging Face 的 Diffusers 是一个开源库,提供了对扩散模型的支持,包括:
文本生成
图像生成
图像修复(Inpainting)
超分辨率(Super-Resolution)
文本到图像生成(Text-to-Image Generation)
这些模型的核心是扩散过程,这是一种深度学习技术,主要用于生成式任务。

1.导入相关的库

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler,UNet2DModel
from matplotlib import pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device:{device}")

2.导入数据集

dataset = torchvision.datasets.MNIST(root="mnist/",train=True,download=True,
                                    transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset,batch_size=8,shuffle=True)
x,y = next(iter(train_dataloader))
print("input shape:",x.shape)
print("Labels:",y)
plt.imshow(torchvision.utils.make_grid(x)[0],cmap='Greys')

3.图片添加噪声

def corrupt(x,amount):
  noise = torch.rand_like(x)
  amount = amount.view(-1,1,1,1)
  return x*(1-amount) + noise*amount

fig,axs = plt.subplots(2,1,figsize=(12,5))
axs[0].set_title("Input Data")
axs[0].imshow(torchvision.utils.make_grid(x)[0],cmap="Greys")

amount = torch.linspace(0,1,x.shape[0])
noised_x = corrupt(x,amount)

axs[1].set_title("Corrupted Data")
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0],cmap="Greys")

简单的UNet

在进行训练之前,我们需要一个模型,要求它能够接收28×28像 素的噪声图像,并输出相同大小图片的预测结果。业界比较流行的选 择是UNet网络,UNet网络最初被发明用于完成医学图像中的分割任 务。UNet网络由一条“压缩路径”和一条“扩展路径”组成。“压缩 路径”会使通过该路径的数据维度被压缩,而“扩展路径”则会将数 据扩展回原始维度(类似于自动编码器)。

UNet网络中的残差连接允 许信息和梯度在不同层级之间流动。 有些UNet网络在每个阶段的设计中都包含复杂的模块,但在这 里,我们仅构建一个非常简单的示例,它能够接收一个单通道图像, 并使其通过下行路径的3个卷积层(参见图3-1与后面代码实现中的 down_layers)和上行路径的3个卷积层。下行层和上行层之间有残差 连接,我们使用最大池化层进行下采样,并使用nn.Upsample模块进 行上采样。某些更复杂的UNet网络还可能使用带有可学习参数的上采 样层和下采样层

4.使用一个简单的UNet网络对diffusion过程进行学习

class BasicUNet(nn.Module):
  def __init__(self,in_channels=1,out_channels=1):
    super().__init__()
    self.down_layers = torch.nn.ModuleList([
        nn.Conv2d(in_channels,32,kernel_size=5,padding=2),
        nn.Conv2d(32,64,kernel_size=5,padding=2),
        nn.Conv2d(64,64,kernel_size=5,padding=2),
    ])

    self.up_layers = torch.nn.ModuleList([
        nn.Conv2d(64,64,kernel_size=5,padding=2),
        nn.Conv2d(64,32,kernel_size=5,padding=2),
        nn.Conv2d(32,out_channels,kernel_size=5,padding=2),
    ])

    self.act = nn.ReLU()
    self.downscale = nn.MaxPool2d(2)
    self.upscale = nn.Upsample(scale_factor=2)

  def forward(self,x):
    h = []
    for i,l in enumerate(self.down_layers):
      x = self.act(l(x))
      if i < 2:
        h.append(x)
        x = self.downscale(x)

    for i,l in enumerate(self.up_layers):
      if i > 0:
        x = self.upscale(x)
        x += h.pop()
      x = self.act(l(x))
    return x
net = BasicUNet()
x = torch.rand(8,1,28,28)
net(x).shape
sum([p.numel() for p in net.parameters()])
batch_size = 128
train_dataloader = DataLoader(dataset,batch_size=batch_size,
                              shuffle=True)
n_epoches = 3
net = BasicUNet()
net.to(device)

loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(),lr=1e-3)
losses = []

for epoch in range(n_epoches):
  for x,y in train_dataloader:
    x = x.to(device)
    noised_amount = torch.rand(x.shape[0]).to(device)

    noisy_x = corrupt(x,noised_amount)

    pred = net(noisy_x)
    loss = loss_fn(pred,x)
    opt.zero_grad()
    loss.backward()
    opt.step()

    losses.append(loss.item())
  avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
  print(f'Finished epoch {epoch}. Average loss for this epoch:{avg_loss:.5f}')


plt.plot(losses)
plt.ylim(0,0.1)

5.对得到的模型进行展示

x,y = next(iter(train_dataloader))
x = x[:8]

amount = torch.linspace(0,1,x.shape[0])
noised_x = corrupt(x,amount)
with torch.no_grad():
  preds = net(noised_x.to(device)).detach().cpu()

fig,axs = plt.subplots(3,1,figsize=(12,7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0,1),cmap='Greys')

axs[1].set_title('Corrupted data')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0,1),cmap='Greys')

axs[2].set_title('Network Predictions')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0,1),cmap='Greys')

n_steps = 5
x = torch.rand(8,1,28,28).to(device)
step_history = [x.detach().cpu()]
pred_output_history = []

for i in range(n_steps):
  with torch.no_grad():
    pred = net(x)
  pred_output_history.append(pred.detach().cpu())
  mix_factor = 1/(n_steps - i)
  x = x*(1-mix_factor) + pred*mix_factor
  step_history.append(x.detach().cpu())
fig,axs = plt.subplots(n_steps,2,figsize=(9,4),sharex=True)
axs[0,0].set_title('x (model input)')
axs[0,1].set_title('x (model predition)')
for i in range(n_steps):
  axs[i,0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0,1),cmap='Greys')
  axs[i,1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0,1),cmap='Greys')


左侧是每个阶段模型输入的可视化结果,右侧是 预测的“去噪”(即去除噪声)后的图像。注意,即使模型在第1步 就输出去除了一些噪声的图像,但也只是向最终目标前进了一点点。 如此重复几次后,图像的轮廓开始逐渐出现并得到改善,直到获得最 终结果为止。

n_steps = 40
x = torch.rand(64,1,28,28).to(device)
for i in range(n_steps):
  noise_amount = torch.ones((x.shape[0],)).to(device)*(1-i/n_steps)
  with torch.no_grad():
    pred = net(x)
  mix_factor = 1/(n_steps - i)
  x = x*(1-mix_factor) + pred*mix_factor
fig,ax = plt.subplots(1,1,figsize=(12,12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(),nrow=8)[0].clip(0,1),cmap="Greys")

6.使用精度更高的一个模型进行训练

训练目标不同,旨在预测噪声而不是“去噪”的图像。

UNet2DModel模型通过调节时间步来调节噪声量,t作为一个额外参数被传入前向过程。

net = UNet2DModel(
    sample_size=28,           # 将目标图像分辨率由28提升为64
    in_channels=1,            # 输入改为3通道(RGB)
    out_channels=1,           # 输出也改为3通道,以便生成RGB图像
    layers_per_block=2,       # 每个UNet块中仍使用2个ResNet层
    block_out_channels=(32, 64, 64),  # 增大通道数,以适应更高分辨率
    down_block_types=(
        "DownBlock2D",        # 标准ResNet下采样模块
        "AttnDownBlock2D",    # 带有空间维度self-att的ResNet下采样模块
        "AttnDownBlock2D",    # 再增加一层带有self-att的下采样
    ),
    up_block_types=(
        "AttnUpBlock2D",      # 带有空间维度self-att的ResNet上采样模块
        "AttnUpBlock2D",      # 再增加一层带有self-att的上采样
        "UpBlock2D",          # 标准ResNet上采样模块
    ),
)

net.to(device)
# 首先给定一个“带噪”(即加入了噪声)的输入noisy_x,扩散模型应
# 该输出其对原始输入x的最佳预测。我们需要通过均方误差对预测值与 真实值进行比较。

batch_size = 128
train_dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True)
n_epochs = 3

loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(),lr=1e-3)
losses = []
for epoch in range(n_epochs):
  for x,y in train_dataloader:
      x = x.to(device)
      noise_amount = torch.rand(x.shape[0], device=device)
      noisy_x = corrupt(x, noise_amount)

      # 随机选择时间步
      t = torch.randint(0, 1000, (x.shape[0],), device=device).long()

      # 模型前向传播需要 noisy_x 和 t
      pred = net(noisy_x, t).sample

      loss = loss_fn(pred, x)
      opt.zero_grad()
      loss.backward()
      opt.step()
      losses.append(loss.item())


n_steps = 40
x = torch.rand(64,1,28,28).to(device)

for i in range(n_steps):
    noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - i/n_steps)
    with torch.no_grad():
        # 定义时间步 t,可以是整型,也可以是与 batch 大小一致的张量
        # 这里假设 i 是扩散时间步,范围为0到n_steps-1。
        t = torch.full((x.shape[0],), i, device=device, dtype=torch.long)

        # 调用net需要传入 x 和 t
        pred = net(x, t).sample  # net(...) 返回的是 UNet2DOutput,需要 .sample 获取结果张量

    mix_factor = 1/(n_steps - i)
    x = x*(1-mix_factor) + pred*mix_factor

fig,ax = plt.subplots(1,1,figsize=(12,12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(),nrow=8)[0].clip(0,1),cmap="Greys")
plt.show()

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
plt.plot(noise_scheduler.alphas_cumprod.cpu()**0.5,
         label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1-noise_scheduler.alphas_cumprod.cpu())**0.5,
         label=r"${\sqrt{1-\bar{\alpha}_t}}$")
plt.legend(fontsize="x-large")

fig,axs = plt.subplots(3,1,figsize=(16,10))
xb,yb = next(iter(train_dataloader))
xb = xb.to(device)[:8]
xb = xb*2 -1
print('X shape',xb.shape)

axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(),cmap='Greys')
axs[0].set_title('Clean X')

timesteps = torch.linspace(0,999,8).long().to(device)
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb,noise,timesteps)
print("Noisy X shape",noisy_xb.shape)

axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1,1),
              cmap="Greys")
axs[1].set_title("Noisy X (clipped to (-1,1))")

axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(),
              cmap="Greys")
axs[2].set_title("Noisy X")

预测噪声(从中可以得出“去噪”图像的样子)等同 于直接预测“去噪”图像。但为什么要这么做呢?难道仅仅是为了数 学上的方便吗? 这里其实还有一些精妙之处。我们在训练过程中会计算不同(随 机选择)时间步的损失函数,不同任务目标计算得到的结果会根据损 失值向不同的“隐含权重”收敛,而“预测噪声”这个目标会使权重 更倾向于预测得到更低的噪声量。你可以通过选择更复杂的目标来改 变这种“隐性损失权重”,这样你所选择的噪声调度器就能够直接在 较高的噪声量下产生更多的样本。

你也可以把模型设计成预测“velocity”,我们将其定义为同时 受图像和噪声量影响的组合〔参见论文“Progressive Distillation for Fast Sampling of Diffusion Models”(扩散模型快速采样的渐 进蒸馏)〕。 你还可以将模型设计成预测噪声,但需要基于一些参数对损失进 行缩放。例如,一些研究指出,可以使用噪声量参数〔参见论文 “Perception Prioritized Training of Diffusion Models”(扩散模 型的感知优先训练)〕或者基于一些探索添加的最佳噪声量实验〔参 见 论 文 “Elucidating the Design Space of Diffusion-Based Generative Models”(基于扩散的生成模型的设计空间说明)〕。

本文作者:afengleafs

本文链接:https://www.cnblogs.com/afengleafs/p/18604853

版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。

posted @   afengleafs  阅读(38)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示
评论
收藏
关注
推荐
深色
回顶
收起